/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package versionsapi

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"net/http"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"
)

// TestMain runs the package's tests via goleak.VerifyTestMain so that the test
// binary reports goroutines that are still running once all tests are done.
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}

// TestValidate checks List.Validate on the valid major and minor fixture lists
// and on copies of them that each break a single invariant: stream,
// granularity, kind, base version format, base version vs. granularity,
// version format, and version-is-below-base.
func TestValidate(t *testing.T) {
	testCases := map[string]struct {
		newList func() *List    // builds a fresh fixture list
		mutate  func(list *List) // optional: corrupts the fixture
		wantErr bool
	}{
		"valid major list": {
			newList: majorList,
		},
		"valid minor list": {
			newList: minorList,
		},
		"invalid stream": {
			newList: majorList,
			mutate:  func(l *List) { l.Stream = "invalid" },
			wantErr: true,
		},
		"invalid granularity": {
			newList: majorList,
			mutate:  func(l *List) { l.Granularity = "invalid" },
			wantErr: true,
		},
		"invalid kind": {
			newList: majorList,
			mutate:  func(l *List) { l.Kind = "invalid" },
			wantErr: true,
		},
		"base ver is not semantic version": {
			newList: majorList,
			mutate:  func(l *List) { l.Base = "invalid" },
			wantErr: true,
		},
		"base ver does not reflect major granularity": {
			newList: majorList,
			mutate:  func(l *List) { l.Base = "v1.0" },
			wantErr: true,
		},
		"base ver does not reflect minor granularity": {
			newList: minorList,
			mutate:  func(l *List) { l.Base = "v1" },
			wantErr: true,
		},
		"version in list is not semantic version": {
			newList: majorList,
			mutate:  func(l *List) { l.Versions[0] = "invalid" },
			wantErr: true,
		},
		"version in list is not sub version of base": {
			newList: majorList,
			mutate:  func(l *List) { l.Versions[0] = "v2.1" },
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			list := tc.newList()
			if tc.mutate != nil {
				tc.mutate(list)
			}

			err := list.Validate()
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			assert.Error(t, err)
		})
	}
}

func TestList(t *testing.T) {
	majorListJSON, err := json.Marshal(majorList())
	require.NoError(t, err)
	minorListJSON, err := json.Marshal(minorList())
	require.NoError(t, err)
	inconsistentList := majorList()
	inconsistentList.Base = "v2"
	inconsistentListJSON, err := json.Marshal(inconsistentList)
	require.NoError(t, err)
	client := newTestClient(func(req *http.Request) *http.Response {
		switch req.URL.Path {
		case "/constellation/v1/ref/test-ref/stream/nightly/versions/major/v1/image.json":
			return &http.Response{
				StatusCode: http.StatusOK,
				Body:       io.NopCloser(bytes.NewBuffer(majorListJSON)),
				Header:     make(http.Header),
			}
		case "/constellation/v1/ref/test-ref/stream/nightly/versions/minor/v1.1/image.json":
			return &http.Response{
				StatusCode: http.StatusOK,
				Body:       io.NopCloser(bytes.NewBuffer(minorListJSON)),
				Header:     make(http.Header),
			}
		case "/constellation/v1/ref/test-ref/stream/nightly/versions/major/v1/500.json": // 500 error
			return &http.Response{
				StatusCode: http.StatusInternalServerError,
				Body:       io.NopCloser(bytes.NewBufferString("Server Error.")),
				Header:     make(http.Header),
			}
		case "/constellation/v1/ref/test-ref/stream/nightly/versions/major/v1/nojson.json": // invalid format
			return &http.Response{
				StatusCode: http.StatusOK,
				Body:       io.NopCloser(bytes.NewBufferString("not json")),
				Header:     make(http.Header),
			}
		case "/constellation/v1/ref/test-ref/stream/nightly/versions/major/v2/image.json": // inconsistent list
			return &http.Response{
				StatusCode: http.StatusOK,
				Body:       io.NopCloser(bytes.NewBuffer(inconsistentListJSON)),
				Header:     make(http.Header),
			}
		case "/constellation/v1/ref/test-ref/stream/nightly/versions/major/v3/image.json": // does not match requested version
			return &http.Response{
				StatusCode: http.StatusOK,
				Body:       io.NopCloser(bytes.NewBuffer(minorListJSON)),
				Header:     make(http.Header),
			}
		}
		return &http.Response{
			StatusCode: http.StatusNotFound,
			Body:       io.NopCloser(bytes.NewBufferString("Not found.")),
			Header:     make(http.Header),
		}
	})

	testCases := map[string]struct {
		ref, stream, granularity, base, kind string
		overrideFile                         string
		wantList                             List
		wantErr                              bool
	}{
		"major list fetched remotely": {
			wantList: *majorList(),
		},
		"minor list fetched remotely": {
			granularity: "minor",
			base:        "v1.1",
			wantList:    *minorList(),
		},
		"list does not exist": {
			stream:  "unknown",
			wantErr: true,
		},
		"unexpected error code": {
			kind:    "500",
			wantErr: true,
		},
		"invalid json returned": {
			kind:    "nojson",
			wantErr: true,
		},
		"invalid list returned": {
			base:    "v2",
			wantErr: true,
		},
		"response does not match request": {
			base:    "v3",
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			ref := "test-ref"
			stream := "nightly"
			granularity := "major"
			base := "v1"
			kind := "image"
			if tc.stream != "" {
				stream = tc.stream
			}
			if tc.granularity != "" {
				granularity = tc.granularity
			}
			if tc.base != "" {
				base = tc.base
			}
			if tc.kind != "" {
				kind = tc.kind
			}

			fetcher := &Fetcher{
				httpc: client,
			}
			list, err := fetcher.list(context.Background(), ref, stream, granularity, base, kind)
			if tc.wantErr {
				assert.Error(err)
				return
			}
			require.NoError(err)
			assert.Equal(tc.wantList, *list)
		})
	}
}

// roundTripFunc adapts a plain function to the http.RoundTripper interface so
// tests can stub HTTP responses without touching the network.
type roundTripFunc func(*http.Request) *http.Response

// Compile-time check that roundTripFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripFunc(nil)

// RoundTrip hands the request to the wrapped function and returns its
// response. It never reports a transport-level error.
func (fn roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	resp := fn(r)
	return resp, nil
}

// newTestClient returns an *http.Client whose Transport is fn, so every request
// sent through it is answered by fn instead of a real network call.
func newTestClient(fn roundTripFunc) *http.Client {
	c := new(http.Client)
	c.Transport = fn
	return c
}

// majorList returns a freshly allocated major-granularity fixture list:
// ref "test-ref", stream "nightly", base "v1", kind "image", versions
// v1.0-v1.2. Each call returns a new value, so callers may mutate it.
func majorList() *List {
	versions := []string{"v1.0", "v1.1", "v1.2"}

	list := &List{Ref: "test-ref", Stream: "nightly", Granularity: "major"}
	list.Base = "v1"
	list.Kind = "image"
	list.Versions = versions
	return list
}

// minorList returns a freshly allocated minor-granularity fixture list:
// ref "test-ref", stream "nightly", base "v1.1", kind "image", versions
// v1.1.0-v1.1.2. Each call returns a new value, so callers may mutate it.
func minorList() *List {
	versions := []string{"v1.1.0", "v1.1.1", "v1.1.2"}

	list := &List{Ref: "test-ref", Stream: "nightly", Granularity: "minor"}
	list.Base = "v1.1"
	list.Kind = "image"
	list.Versions = versions
	return list
}

// TestIsValidRef checks which ref strings IsValidRef accepts. Slashes, dollar
// signs, spaces, the empty string and a "refs-heads-" prefix are rejected.
func TestIsValidRef(t *testing.T) {
	cases := []struct {
		ref   string
		valid bool
	}{
		{"feat/foo", false},
		{"feat-foo", true},
		{"feat$foo", false},
		{"3234", true},
		{"feat foo", false},
		{"refs-heads-feat-foo", false},
		{"", false},
	}

	for _, c := range cases {
		t.Run(c.ref, func(t *testing.T) {
			assert.Equal(t, c.valid, IsValidRef(c.ref))
		})
	}
}

// TestCanonicalRef checks that CanonicalRef replaces disallowed characters
// (slash, dollar sign, space) with '-' and leaves already-valid refs unchanged.
func TestCanonicalRef(t *testing.T) {
	cases := []struct {
		in, want string
	}{
		{"feat/foo", "feat-foo"},
		{"feat-foo", "feat-foo"},
		{"feat$foo", "feat-foo"},
		{"3234", "3234"},
		{"feat foo", "feat-foo"},
	}

	for _, c := range cases {
		t.Run(c.in, func(t *testing.T) {
			assert.Equal(t, c.want, CanonicalRef(c.in))
		})
	}
}

func TestIsValidStream(t *testing.T) {
	testCases := []struct {
		branch string
		stream string
		want   bool
	}{
		{branch: "-", stream: "stable", want: true},
		{branch: "-", stream: "debug", want: true},
		{branch: "-", stream: "nightly", want: false},
		{branch: "-", stream: "console", want: true},
		{branch: "main", stream: "stable", want: false},
		{branch: "main", stream: "debug", want: true},
		{branch: "main", stream: "nightly", want: true},
		{branch: "main", stream: "console", want: true},
		{branch: "foo-branch", stream: "nightly", want: true},
		{branch: "foo-branch", stream: "console", want: true},
		{branch: "foo-branch", stream: "debug", want: true},
		{branch: "foo-branch", stream: "stable", want: false},
	}

	for _, tc := range testCases {
		t.Run(tc.branch+"+"+tc.stream, func(t *testing.T) {
			assert := assert.New(t)

			assert.Equal(tc.want, IsValidStream(tc.branch, tc.stream))
		})
	}
}