/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package setup

import (
	"context"
	"encoding/base64"
	"fmt"
	"net/url"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"
)

func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m,
		// https://github.com/census-instrumentation/opencensus-go/issues/1262
		goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
	)
}

func TestGetStore(t *testing.T) {
	testCases := map[string]struct {
		uri     string
		wantErr bool
	}{
		"no store": {
			uri:     NoStoreURI,
			wantErr: false,
		},
		"aws s3": {
			uri:     fmt.Sprintf(AWSS3URI, ""),
			wantErr: true,
		},
		"azure blob": {
			uri:     fmt.Sprintf(AzureBlobURI, "", ""),
			wantErr: true,
		},
		"gcp storage": {
			uri:     fmt.Sprintf(GCPStorageURI, "", ""),
			wantErr: true,
		},
		"unknown store": {
			uri:     "storage://unknown",
			wantErr: true,
		},
		"invalid scheme": {
			uri:     ClusterKMSURI,
			wantErr: true,
		},
		"not a url": {
			uri:     ":/123",
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			_, err := getStore(context.Background(), tc.uri)
			if tc.wantErr {
				assert.Error(err)
			} else {
				assert.NoError(err)
			}
		})
	}
}

func TestGetKMS(t *testing.T) {
	testCases := map[string]struct {
		uri     string
		wantErr bool
	}{
		"cluster kms": {
			uri:     fmt.Sprintf("%s?salt=%s", ClusterKMSURI, base64.URLEncoding.EncodeToString([]byte("salt"))),
			wantErr: false,
		},
		"aws kms": {
			uri:     fmt.Sprintf(AWSKMSURI, ""),
			wantErr: true,
		},
		"azure kms": {
			uri:     fmt.Sprintf(AzureKMSURI, "", ""),
			wantErr: true,
		},
		"azure hsm": {
			uri:     fmt.Sprintf(AzureHSMURI, ""),
			wantErr: true,
		},
		"gcp kms": {
			uri:     fmt.Sprintf(GCPKMSURI, "", "", "", ""),
			wantErr: true,
		},
		"unknown kms": {
			uri:     "kms://unknown",
			wantErr: true,
		},
		"invalid scheme": {
			uri:     NoStoreURI,
			wantErr: true,
		},
		"not a url": {
			uri:     ":/123",
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			kms, err := getKMS(context.Background(), tc.uri, nil)
			if tc.wantErr {
				assert.Error(err)
			} else {
				assert.NoError(err)
				assert.NotNil(kms)
			}
		})
	}
}

func TestSetUpKMS(t *testing.T) {
	assert := assert.New(t)

	kms, err := KMS(context.Background(), "storage://unknown", "kms://unknown")
	assert.Error(err)
	assert.Nil(kms)

	kms, err = KMS(context.Background(), "storage://no-store", "kms://cluster-kms?salt="+base64.URLEncoding.EncodeToString([]byte("salt")))
	assert.NoError(err)
	assert.NotNil(kms)
}

func TestGetAWSKMSConfig(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	policy := "{keyPolicy: keyPolicy}"
	escapedPolicy := url.QueryEscape(policy)
	uri, err := url.Parse(fmt.Sprintf(AWSKMSURI, escapedPolicy))
	require.NoError(err)
	policyProducer, err := getAWSKMSConfig(uri)
	require.NoError(err)
	keyPolicy, err := policyProducer.CreateKeyPolicy("")
	require.NoError(err)
	assert.Equal(policy, keyPolicy)
}

func TestGetAzureBlobConfig(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	connStr := "DefaultEndpointsProtocol=https;AccountName=test;AccountKey=Q29uc3RlbGxhdGlvbg==;EndpointSuffix=core.windows.net"
	escapedConnStr := url.QueryEscape(connStr)
	container := "test"
	uri, err := url.Parse(fmt.Sprintf(AzureBlobURI, container, escapedConnStr))
	require.NoError(err)
	rContainer, rConnStr, err := getAzureBlobConfig(uri)
	require.NoError(err)
	assert.Equal(container, rContainer)
	assert.Equal(connStr, rConnStr)
}

func TestGetGCPKMSConfig(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	project := "test-project"
	location := "global"
	keyRing := "test-ring"
	protectionLvl := "2"
	uri, err := url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, protectionLvl))
	require.NoError(err)
	rProject, rLocation, rKeyRing, rProtectionLvl, err := getGCPKMSConfig(uri)
	require.NoError(err)
	assert.Equal(project, rProject)
	assert.Equal(location, rLocation)
	assert.Equal(keyRing, rKeyRing)
	assert.Equal(2, rProtectionLvl)

	uri, err = url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, "invalid"))
	require.NoError(err)
	_, _, _, _, err = getGCPKMSConfig(uri)
	assert.Error(err)
}

func TestGetClusterKMSConfig(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	expectedSalt := []byte{
		0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf,
		0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf,
	}

	uri, err := url.Parse(ClusterKMSURI + "?salt=" + base64.URLEncoding.EncodeToString(expectedSalt))
	require.NoError(err)

	salt, err := getClusterKMSConfig(uri)
	assert.NoError(err)
	assert.Equal(expectedSalt, salt)
}

func TestGetConfig(t *testing.T) {
	const testURI = "test://config?name=test-name&data=test-data&value=test-value"

	testCases := map[string]struct {
		uri     string
		keys    []string
		wantErr bool
	}{
		"success": {
			uri:     testURI,
			keys:    []string{"name", "data", "value"},
			wantErr: false,
		},
		"less keys than capture groups": {
			uri:     testURI,
			keys:    []string{"name", "data"},
			wantErr: false,
		},
		"invalid regex": {
			uri:     testURI,
			keys:    []string{"name", "data", "test-value"},
			wantErr: true,
		},
		"missing value": {
			uri:     "test://config?name=test-name&data=test-data&value",
			keys:    []string{"name", "data", "value"},
			wantErr: true,
		},
		"more keys than expected": {
			uri:     testURI,
			keys:    []string{"name", "data", "value", "anotherValue"},
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			uri, err := url.Parse(tc.uri)
			require.NoError(err)

			res, err := getConfig(uri.Query(), tc.keys)
			if tc.wantErr {
				assert.Error(err)
				assert.Len(res, len(tc.keys))
			} else {
				assert.NoError(err)
				require.Len(res, len(tc.keys))
				for i := range tc.keys {
					assert.NotEmpty(res[i])
				}
			}
		})
	}
}