/* Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ package setup import ( "context" "encoding/base64" "fmt" "net/url" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" ) const constellationKekID = "Constellation" func TestMain(m *testing.M) { goleak.VerifyTestMain(m, // https://github.com/census-instrumentation/opencensus-go/issues/1262 goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), ) } func TestGetStore(t *testing.T) { testCases := map[string]struct { uri string wantErr bool }{ "no store": { uri: NoStoreURI, wantErr: false, }, "aws s3": { uri: fmt.Sprintf(AWSS3URI, ""), wantErr: true, }, "azure blob": { uri: fmt.Sprintf(AzureBlobURI, "", ""), wantErr: true, }, "gcp storage": { uri: fmt.Sprintf(GCPStorageURI, "", ""), wantErr: true, }, "unknown store": { uri: "storage://unknown", wantErr: true, }, "invalid scheme": { uri: ClusterKMSURI, wantErr: true, }, "not a url": { uri: ":/123", wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) _, err := getStore(context.Background(), tc.uri) if tc.wantErr { assert.Error(err) } else { assert.NoError(err) } }) } } func TestGetKMS(t *testing.T) { testCases := map[string]struct { uri string wantErr bool }{ "cluster kms": { uri: fmt.Sprintf(ClusterKMSURI, base64.URLEncoding.EncodeToString([]byte("key")), base64.URLEncoding.EncodeToString([]byte("salt"))), wantErr: false, }, "aws kms": { uri: fmt.Sprintf(AWSKMSURI, "", ""), wantErr: true, }, "azure kms": { uri: fmt.Sprintf(AzureKMSURI, "", "", ""), wantErr: true, }, "azure hsm": { uri: fmt.Sprintf(AzureHSMURI, "", ""), wantErr: true, }, "gcp kms": { uri: fmt.Sprintf(GCPKMSURI, "", "", "", "", ""), wantErr: true, }, "unknown kms": { uri: "kms://unknown", wantErr: true, }, "invalid scheme": { uri: NoStoreURI, wantErr: true, }, "not a url": { uri: ":/123", wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) kms, err := getKMS(context.Background(), tc.uri, nil) if tc.wantErr { assert.Error(err) } else { assert.NoError(err) assert.NotNil(kms) } }) } } func TestSetUpKMS(t *testing.T) { assert := assert.New(t) kms, err := KMS(context.Background(), "storage://unknown", "kms://unknown") assert.Error(err) assert.Nil(kms) masterSecret := MasterSecret{Key: []byte("key"), Salt: []byte("salt")} kms, err = KMS(context.Background(), "storage://no-store", masterSecret.EncodeToURI()) assert.NoError(err) assert.NotNil(kms) } func TestGetAWSKMSConfig(t *testing.T) { assert := assert.New(t) require := require.New(t) policy := "{keyPolicy: keyPolicy}" escapedPolicy := url.QueryEscape(policy) kekID := base64.URLEncoding.EncodeToString([]byte(constellationKekID)) uri, err := url.Parse(fmt.Sprintf(AWSKMSURI, escapedPolicy, kekID)) require.NoError(err) policyProducer, rKekID, err := getAWSKMSConfig(uri) require.NoError(err) keyPolicy, err := policyProducer.CreateKeyPolicy("") require.NoError(err) assert.Equal(policy, keyPolicy) assert.Equal(constellationKekID, rKekID) } func TestGetAzureBlobConfig(t *testing.T) { assert := assert.New(t) require := require.New(t) connStr := "DefaultEndpointsProtocol=https;AccountName=test;AccountKey=Q29uc3RlbGxhdGlvbg==;EndpointSuffix=core.windows.net" escapedConnStr := url.QueryEscape(connStr) container := "test" uri, err := url.Parse(fmt.Sprintf(AzureBlobURI, container, escapedConnStr)) require.NoError(err) rContainer, rConnStr, err := getAzureBlobConfig(uri) require.NoError(err) assert.Equal(container, rContainer) assert.Equal(connStr, rConnStr) } func TestGetGCPKMSConfig(t *testing.T) { assert := assert.New(t) require := require.New(t) project := "test-project" location := "global" keyRing := "test-ring" protectionLvl := "2" kekID := base64.URLEncoding.EncodeToString([]byte(constellationKekID)) uri, err := url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, protectionLvl, kekID)) require.NoError(err) rProject, rLocation, rKeyRing, rProtectionLvl, rKekID, err := getGCPKMSConfig(uri) require.NoError(err) assert.Equal(project, rProject) assert.Equal(location, rLocation) assert.Equal(keyRing, rKeyRing) assert.Equal(int32(2), rProtectionLvl) assert.Equal(constellationKekID, rKekID) uri, err = url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, "invalid", kekID)) require.NoError(err) _, _, _, _, _, err = getGCPKMSConfig(uri) assert.Error(err) } func TestGetClusterKMSConfig(t *testing.T) { assert := assert.New(t) require := require.New(t) expectedSalt := []byte{ 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf, } masterSecretIn := MasterSecret{Key: []byte("key"), Salt: expectedSalt} uri, err := url.Parse(masterSecretIn.EncodeToURI()) require.NoError(err) masterSecretOut, err := getClusterKMSConfig(uri) assert.NoError(err) assert.Equal(expectedSalt, masterSecretOut.Salt) } func TestGetConfig(t *testing.T) { const testURI = "test://config?name=test-name&data=test-data&value=test-value" testCases := map[string]struct { uri string keys []string wantErr bool }{ "success": { uri: testURI, keys: []string{"name", "data", "value"}, wantErr: false, }, "less keys than capture groups": { uri: testURI, keys: []string{"name", "data"}, wantErr: false, }, "invalid regex": { uri: testURI, keys: []string{"name", "data", "test-value"}, wantErr: true, }, "missing value": { uri: "test://config?name=test-name&data=test-data&value", keys: []string{"name", "data", "value"}, wantErr: true, }, "more keys than expected": { uri: testURI, keys: []string{"name", "data", "value", "anotherValue"}, wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) require := require.New(t) uri, err := url.Parse(tc.uri) require.NoError(err) res, err := getConfig(uri.Query(), tc.keys) if tc.wantErr { assert.Error(err) assert.Len(res, len(tc.keys)) } else { assert.NoError(err) require.Len(res, len(tc.keys)) for i := range tc.keys { assert.NotEmpty(res[i]) } } }) } }