/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package image

import (
	"bytes"
	"context"
	"io"
	"net/http"
	"os"
	"strings"
	"testing"

	"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
	"github.com/edgelesssys/constellation/v2/internal/config"
	"github.com/spf13/afero"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"
)

// TestMain is the entry point for this package's tests. It hands the test run
// to goleak.VerifyTestMain, which (per goleak's documentation) executes the
// tests and then reports goroutines that are still running afterwards.
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}

// TestGetReference runs imageInfo.getReference against a table of lookups:
// one that is expected to resolve to a reference, and several that are
// expected to return an error (unknown CSP, unknown variant, nil receiver,
// nil per-CSP map).
func TestGetReference(t *testing.T) {
	// awsInfo builds a fresh imageInfo holding exactly one AWS variant.
	awsInfo := func() *imageInfo {
		return &imageInfo{AWS: map[string]string{"someVariant": "someReference"}}
	}

	type testCase struct {
		info          *imageInfo
		csp, variant  string
		wantReference string
		wantErr       bool
	}

	testCases := map[string]testCase{
		"reference exists": {
			info:          awsInfo(),
			csp:           "aws",
			variant:       "someVariant",
			wantReference: "someReference",
		},
		"csp does not exist": {
			info:    awsInfo(),
			csp:     "nonExistingCSP",
			variant: "someVariant",
			wantErr: true,
		},
		"variant does not exist": {
			info:    awsInfo(),
			csp:     "aws",
			variant: "nonExistingVariant",
			wantErr: true,
		},
		"info is nil": {
			info:    nil,
			csp:     "aws",
			variant: "someVariant",
			wantErr: true,
		},
		"csp is nil": {
			info:    &imageInfo{AWS: nil},
			csp:     "aws",
			variant: "someVariant",
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			got, err := tc.info.getReference(tc.csp, tc.variant)

			if !tc.wantErr {
				// Success path: a failed lookup makes the comparison pointless,
				// so stop the subtest right away on an unexpected error.
				require.NoError(t, err)
				assert.Equal(t, tc.wantReference, got)
				return
			}
			assert.Error(t, err)
		})
	}
}

func TestGetReferenceOnNil(t *testing.T) {
	assert := assert.New(t)

	var lut *imageInfo
	_, err := lut.getReference("someCSP", "someVariant")
	assert.Error(err)
}

// TestVariant checks the variant string that variant() derives from a cloud
// provider and its config: the region for AWS, "cvm" or "trustedlaunch" for
// Azure depending on ConfidentialVM, "sev-es" for GCP and "default" for QEMU.
// An unknown provider value is expected to produce an error.
func TestVariant(t *testing.T) {
	// boolPtr returns a pointer to a copy of b, for the *bool config field.
	boolPtr := func(b bool) *bool { return &b }

	type testCase struct {
		csp         cloudprovider.Provider
		config      *config.Config
		wantVariant string
		wantErr     bool
	}

	testCases := map[string]testCase{
		"AWS region": {
			csp: cloudprovider.AWS,
			config: &config.Config{
				Image: "someImage",
				Provider: config.ProviderConfig{
					AWS: &config.AWSConfig{Region: "someRegion"},
				},
			},
			wantVariant: "someRegion",
		},
		"Azure cvm": {
			csp: cloudprovider.Azure,
			config: &config.Config{
				Image: "someImage",
				Provider: config.ProviderConfig{
					Azure: &config.AzureConfig{ConfidentialVM: boolPtr(true)},
				},
			},
			wantVariant: "cvm",
		},
		"Azure trustedlaunch": {
			csp: cloudprovider.Azure,
			config: &config.Config{
				Image: "someImage",
				Provider: config.ProviderConfig{
					Azure: &config.AzureConfig{ConfidentialVM: boolPtr(false)},
				},
			},
			wantVariant: "trustedlaunch",
		},
		"GCP": {
			csp: cloudprovider.GCP,
			config: &config.Config{
				Image: "someImage",
				Provider: config.ProviderConfig{
					GCP: &config.GCPConfig{},
				},
			},
			wantVariant: "sev-es",
		},
		"QEMU": {
			csp: cloudprovider.QEMU,
			config: &config.Config{
				Image: "someImage",
				Provider: config.ProviderConfig{
					QEMU: &config.QEMUConfig{},
				},
			},
			wantVariant: "default",
		},
		"invalid": {
			// No config is supplied for the unknown provider value.
			csp:     cloudprovider.Provider(9999),
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			got, err := variant(tc.csp, tc.config)

			if !tc.wantErr {
				require.NoError(t, err)
				assert.Equal(t, tc.wantVariant, got)
				return
			}
			assert.Error(t, err)
		})
	}
}

// TestFetchReference exercises Fetcher.FetchReference with a stubbed HTTP
// client and an in-memory filesystem. The stub client answers requests whose
// URL ends in the known image version's JSON path with the lut constant and
// everything else with a 404. Cases cover a remote fetch, a local override
// file, a malformed override file, an unknown image version and an empty
// config.
func TestFetchReference(t *testing.T) {
	const imageVersionUID = "someImageVersionUID"

	// stubResponse assembles a minimal response carrying the given status and body.
	stubResponse := func(status int, body string) *http.Response {
		return &http.Response{
			StatusCode: status,
			Body:       io.NopCloser(bytes.NewBufferString(body)),
			Header:     make(http.Header),
		}
	}

	client := newTestClient(func(req *http.Request) *http.Response {
		if !strings.HasSuffix(req.URL.String(), "/constellation/v1/images/someImageVersionUID.json") {
			return stubResponse(http.StatusNotFound, "Not found.")
		}
		return stubResponse(http.StatusOK, lut)
	})

	// qemuConfig builds a QEMU provider config that requests the given image.
	qemuConfig := func(image string) *config.Config {
		return &config.Config{
			Image: image,
			Provider: config.ProviderConfig{
				QEMU: &config.QEMUConfig{},
			},
		}
	}

	type testCase struct {
		config        *config.Config
		overrideFile  string
		wantReference string
		wantErr       bool
	}

	testCases := map[string]testCase{
		"reference fetched remotely": {
			config:        qemuConfig(imageVersionUID),
			wantReference: "someReference",
		},
		"reference fetched locally": {
			config:        qemuConfig(imageVersionUID),
			overrideFile:  `{"qemu":{"default":"localOverrideReference"}}`,
			wantReference: "localOverrideReference",
		},
		"lut is invalid": {
			config:       qemuConfig(imageVersionUID),
			overrideFile: `{`,
			wantErr:      true,
		},
		"image version does not exist": {
			config:  qemuConfig("nonExistingImageVersionUID"),
			wantErr: true,
		},
		"invalid config": {
			config:  &config.Config{},
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			// Each subtest gets its own filesystem; the HTTP client is shared.
			stubFs := newImageVersionStubFs(t, imageVersionUID, tc.overrideFile)
			fetcher := &Fetcher{httpc: client, fs: stubFs}

			got, err := fetcher.FetchReference(context.Background(), tc.config)

			if !tc.wantErr {
				require.NoError(t, err)
				assert.Equal(t, tc.wantReference, got)
				return
			}
			assert.Error(t, err)
		})
	}
}

// must fails the calling test immediately via t.Fatal when err is non-nil and
// does nothing otherwise. It is marked as a test helper.
func must(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatal(err)
}

// roundTripFunc adapts a plain function to the http.RoundTripper interface so
// that a test can decide, per request, which response to hand back.
type roundTripFunc func(req *http.Request) *http.Response

// RoundTrip implements http.RoundTripper by calling f with req and returning
// its response. The returned error is always nil.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp := f(req)
	return resp, nil
}

// newTestClient returns an *http.Client whose Transport is fn, so requests are
// answered by fn instead of going out over the network.
func newTestClient(fn roundTripFunc) *http.Client {
	client := new(http.Client)
	client.Transport = fn
	return client
}

// newImageVersionStubFs returns an in-memory filesystem for use in tests.
// When overrideFile is non-empty, its contents are written to
// "<imageVersionUID>.json" in that filesystem; otherwise the filesystem is
// left empty. A failing write aborts the test.
func newImageVersionStubFs(t *testing.T, imageVersionUID string, overrideFile string) *afero.Afero {
	// Mark this function as a helper so that a failed write is reported at
	// the caller's line rather than inside this function.
	t.Helper()

	fs := afero.NewMemMapFs()
	if overrideFile != "" {
		must(t, afero.WriteFile(fs, imageVersionUID+".json", []byte(overrideFile), os.ModePerm))
	}
	return &afero.Afero{Fs: fs}
}

// lut is the JSON body that the stub HTTP client in TestFetchReference serves
// for the known image version. It maps the QEMU "default" variant to
// "someReference". NOTE(review): the name presumably stands for "lookup
// table" — confirm.
const lut = `{"qemu":{"default":"someReference"}}`