Refactor init/recovery to use kms URI

So far the masterSecret was sent to the initial bootstrapper
on init/recovery. With this commit this information is encoded
in the kmsURI that is sent during init.
For recover, the communication with the recoveryserver is
changed. Before a streaming gRPC call was used to
exchanges UUID for measurementSecret and state disk key.
Now a standard gRPC is made that includes the same kmsURI &
storageURI that are sent during init.
This commit is contained in:
Otto Bittner 2023-01-16 11:19:03 +01:00
parent 0e71322e2e
commit 9a1f52e94e
35 changed files with 466 additions and 623 deletions

View file

@ -56,13 +56,14 @@ type KMSClient struct {
awsClient ClientAPI
policyProducer KeyPolicyProducer
storage kmsInterface.Storage
kekID string
}
// New creates and initializes a new KMSClient for AWS.
//
// The parameter client needs to be initialized with valid AWS credentials (https://aws.github.io/aws-sdk-go-v2/docs/getting-started).
// If storage is nil, the default MemMapStorage is used.
func New(ctx context.Context, policyProducer KeyPolicyProducer, store kmsInterface.Storage, optFns ...func(*awsconfig.LoadOptions) error) (*KMSClient, error) {
func New(ctx context.Context, policyProducer KeyPolicyProducer, store kmsInterface.Storage, kekID string, optFns ...func(*awsconfig.LoadOptions) error) (*KMSClient, error) {
if store == nil {
store = storage.NewMemMapStorage()
}
@ -77,6 +78,7 @@ func New(ctx context.Context, policyProducer KeyPolicyProducer, store kmsInterfa
awsClient: client,
policyProducer: policyProducer,
storage: store,
kekID: kekID,
}, nil
}
@ -206,9 +208,9 @@ func (c *KMSClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
}
// GetDEK returns the DEK for dekID and kekID from the KMS.
func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int) ([]byte, error) {
func (c *KMSClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
// The KEK should be identified by its alias. The alias always has the same scheme: 'alias/<kekId>'
kekID = "alias/" + kekID
kekID := "alias/" + c.kekID
// If a key for keyID exists in the storage, decrypt the key using the KEK.
dek, err := c.decryptDEKFromStorage(ctx, kekID, keyID)

View file

@ -253,22 +253,23 @@ func TestAWSKMSClient(t *testing.T) {
awsClient := &fakeAWSClient{kekPool: make(map[string][]byte, 2)}
client := &KMSClient{
awsClient: awsClient,
policyProducer: &stubKeyPolicyProducer{},
storage: storage.NewMemMapStorage(),
}
awsClient.keyIDCount = -1
testKEK1ID := "testKEK1"
testKEK1 := []byte("test KEK")
testKEK2ID := "testKEK2"
testKEK2 := []byte("more test KEK")
client := &KMSClient{
awsClient: awsClient,
policyProducer: &stubKeyPolicyProducer{},
storage: storage.NewMemMapStorage(),
kekID: testKEK1ID,
}
awsClient.keyIDCount = -1
ctx := context.Background()
// try to get a DEK before setting the KEK
_, err := client.GetDEK(ctx, testKEK1ID, "volume01", config.SymmetricKeyLength)
_, err := client.GetDEK(ctx, "volume01", config.SymmetricKeyLength)
assert.Error(err)
assert.ErrorIs(err, kmsInterface.ErrKEKUnknown)
@ -285,16 +286,16 @@ func TestAWSKMSClient(t *testing.T) {
assert.Equal(testKEK2, awsClient.kekPool[strconv.Itoa(awsClient.keyIDCount)])
// test GetDEK method
dek1, err := client.GetDEK(ctx, testKEK1ID, "volume01", config.SymmetricKeyLength)
dek1, err := client.GetDEK(ctx, "volume01", config.SymmetricKeyLength)
assert.NoError(err)
dek2, err := client.GetDEK(ctx, testKEK2ID, "volume02", config.SymmetricKeyLength)
dek2, err := client.GetDEK(ctx, "volume02", config.SymmetricKeyLength)
assert.NoError(err)
// make sure that GetDEK is idempotent
dek1Copy, err := client.GetDEK(ctx, testKEK1ID, "volume01", config.SymmetricKeyLength)
dek1Copy, err := client.GetDEK(ctx, "volume01", config.SymmetricKeyLength)
assert.NoError(err)
assert.Equal(dek1, dek1Copy)
dek2Copy, err := client.GetDEK(ctx, testKEK2ID, "volume02", config.SymmetricKeyLength)
dek2Copy, err := client.GetDEK(ctx, "volume02", config.SymmetricKeyLength)
assert.NoError(err)
assert.Equal(dek2, dek2Copy)
}

View file

@ -47,6 +47,7 @@ type kmsClientAPI interface {
type KMSClient struct {
client kmsClientAPI
storage kms.Storage
kekID string
}
// Opts are optional settings for AKV clients.
@ -57,7 +58,7 @@ type Opts struct {
}
// New initializes a KMS client for Azure Key Vault.
func New(ctx context.Context, vaultName string, vaultType VaultSuffix, store kms.Storage, opts *Opts) (*KMSClient, error) {
func New(ctx context.Context, vaultName string, vaultType VaultSuffix, store kms.Storage, kekID string, opts *Opts) (*KMSClient, error) {
if opts == nil {
opts = &Opts{}
}
@ -80,7 +81,7 @@ func New(ctx context.Context, vaultName string, vaultType VaultSuffix, store kms
if store == nil {
store = storage.NewMemMapStorage()
}
return &KMSClient{client: client, storage: store}, nil
return &KMSClient{client: client, storage: store, kekID: kekID}, nil
}
// CreateKEK saves a new Key Encryption Key using Azure Key Vault.
@ -111,8 +112,8 @@ func (c *KMSClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
}
// GetDEK decrypts a DEK from storage.
func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int) ([]byte, error) {
kek, err := c.getKEK(ctx, kekID)
func (c *KMSClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
kek, err := c.getKEK(ctx, c.kekID)
if err != nil {
return nil, fmt.Errorf("loading KEK from key vault: %w", err)
}

View file

@ -155,7 +155,7 @@ func TestKMSGetDEK(t *testing.T) {
storage: tc.storage,
}
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32)
dek, err := client.GetDEK(context.Background(), "volume-01", 32)
if tc.wantErr {
assert.Error(err)
} else {

View file

@ -38,10 +38,11 @@ type HSMClient struct {
client hsmClientAPI
storage kms.Storage
vaultURL string
kekID string
}
// NewHSM initializes a KMS client for Azure manged HSM Key Vault.
func NewHSM(ctx context.Context, vaultName string, store kms.Storage, opts *Opts) (*HSMClient, error) {
func NewHSM(ctx context.Context, vaultName string, store kms.Storage, kekID string, opts *Opts) (*HSMClient, error) {
if opts == nil {
opts = &Opts{}
}
@ -72,6 +73,7 @@ func NewHSM(ctx context.Context, vaultName string, store kms.Storage, opts *Opts
client: client,
credentials: cred,
storage: store,
kekID: kekID,
}, nil
}
@ -114,7 +116,7 @@ func (c *HSMClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
}
// GetDEK loads an encrypted DEK from storage and unwraps it using an HSM-backed key.
func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekSize int) ([]byte, error) {
func (c *HSMClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
encryptedDEK, err := c.storage.Get(ctx, keyID)
if err != nil {
if !errors.Is(err, storage.ErrDEKUnset) {
@ -126,7 +128,7 @@ func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekS
if err != nil {
return nil, fmt.Errorf("key generation: %w", err)
}
if err := c.putDEK(ctx, kekID, keyID, newDEK); err != nil {
if err := c.putDEK(ctx, c.kekID, keyID, newDEK); err != nil {
return nil, fmt.Errorf("creating new DEK: %w", err)
}
@ -137,7 +139,7 @@ func (c *HSMClient) GetDEK(ctx context.Context, kekID string, keyID string, dekS
Algorithm: to.Ptr(azkeys.JSONWebKeyEncryptionAlgorithmA256KW),
Value: encryptedDEK,
}
res, err := c.client.UnwrapKey(ctx, kekID, "", params, &azkeys.UnwrapKeyOptions{})
res, err := c.client.UnwrapKey(ctx, c.kekID, "", params, &azkeys.UnwrapKeyOptions{})
if err != nil {
return nil, fmt.Errorf("unwrapping key: %w", err)
}

View file

@ -165,7 +165,7 @@ func TestHSMGetNewDEK(t *testing.T) {
storage: tc.storage,
}
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32)
dek, err := client.GetDEK(context.Background(), "volume-01", 32)
if tc.wantErr {
assert.Error(err)
} else {
@ -208,7 +208,7 @@ func TestHSMGetExistingDEK(t *testing.T) {
storage: storage,
}
dek, err := client.GetDEK(context.Background(), "test-key", keyID, len(testKey))
dek, err := client.GetDEK(context.Background(), keyID, len(testKey))
if tc.wantErr {
assert.Error(err)
} else {

View file

@ -20,8 +20,15 @@ type KMS struct {
}
// New creates a new ClusterKMS.
func New(salt []byte) *KMS {
return &KMS{salt: salt}
func New(key []byte, salt []byte) (*KMS, error) {
if len(key) == 0 {
return nil, errors.New("missing master key")
}
if len(salt) == 0 {
return nil, errors.New("missing salt")
}
return &KMS{masterKey: key, salt: salt}, nil
}
// CreateKEK sets the ClusterKMS masterKey.
@ -31,7 +38,7 @@ func (c *KMS) CreateKEK(ctx context.Context, keyID string, kek []byte) error {
}
// GetDEK derives a key from the KMS masterKey.
func (c *KMS) GetDEK(ctx context.Context, kekID string, dekID string, dekSize int) ([]byte, error) {
func (c *KMS) GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error) {
if len(c.masterKey) == 0 {
return nil, errors.New("master key not set for Constellation KMS")
}

View file

@ -24,19 +24,12 @@ func TestMain(m *testing.M) {
func TestClusterKMS(t *testing.T) {
testVector := testvector.HKDF0xFF
assert := assert.New(t)
kms := New(testVector.Salt)
key, err := kms.GetDEK(context.Background(), "", "key-1", 32)
assert.Error(err)
assert.Nil(key)
err = kms.CreateKEK(context.Background(), "", testVector.Secret)
assert.NoError(err)
assert.Equal(testVector.Secret, kms.masterKey)
require := require.New(t)
kms, err := New(testVector.Secret, testVector.Salt)
require.NoError(err)
keyLower, err := kms.GetDEK(
context.Background(),
"",
strings.ToLower(testVector.InfoPrefix+testVector.Info),
int(testVector.Length),
)
@ -46,12 +39,11 @@ func TestClusterKMS(t *testing.T) {
// output of the KMS should be case sensitive
keyUpper, err := kms.GetDEK(
context.Background(),
"",
strings.ToUpper(testVector.InfoPrefix+testVector.Info),
int(testVector.Length),
)
assert.NoError(err)
assert.NotEqual(key, keyUpper)
assert.NotEqual(keyLower, keyUpper)
}
func TestVectorsHKDF(t *testing.T) {
@ -61,6 +53,7 @@ func TestVectorsHKDF(t *testing.T) {
dekID string
dekSize uint
wantKey []byte
wantErr bool
}{
"rfc Test Case 1": {
kek: testvector.HKDFrfc1.Secret,
@ -82,6 +75,7 @@ func TestVectorsHKDF(t *testing.T) {
dekID: testvector.HKDFrfc3.Info,
dekSize: testvector.HKDFrfc3.Length,
wantKey: testvector.HKDFrfc3.Output,
wantErr: true,
},
"HKDF zero": {
kek: testvector.HKDFZero.Secret,
@ -104,10 +98,15 @@ func TestVectorsHKDF(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
kms := New(tc.salt)
kms, err := New(tc.kek, tc.salt)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
require.NoError(kms.CreateKEK(context.Background(), "", tc.kek))
out, err := kms.GetDEK(context.Background(), "", tc.dekID, int(tc.dekSize))
out, err := kms.GetDEK(context.Background(), tc.dekID, int(tc.dekSize))
require.NoError(err)
assert.Equal(tc.wantKey, out)
})

View file

@ -50,11 +50,12 @@ type KMSClient struct {
waitBackoffLimit int
storage kmsInterface.Storage
protectionLevel kmspb.ProtectionLevel
kekID string
opts []gax.CallOption
}
// New initializes a KMS client for Google Cloud Platform.
func New(ctx context.Context, projectID, locationID, keyRingID string, store kmsInterface.Storage, protectionLvl kmspb.ProtectionLevel, opts ...gax.CallOption) (*KMSClient, error) {
func New(ctx context.Context, projectID, locationID, keyRingID string, store kmsInterface.Storage, protectionLvl kmspb.ProtectionLevel, kekID string, opts ...gax.CallOption) (*KMSClient, error) {
if store == nil {
store = storage.NewMemMapStorage()
}
@ -71,6 +72,7 @@ func New(ctx context.Context, projectID, locationID, keyRingID string, store kms
waitBackoffLimit: 10,
storage: store,
protectionLevel: protectionLvl,
kekID: kekID,
opts: opts,
}
@ -108,7 +110,7 @@ func (c *KMSClient) CreateKEK(ctx context.Context, keyID string, key []byte) err
}
// GetDEK fetches an encrypted Data Encryption Key from storage and decrypts it using a KEK stored in Google's KMS.
func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int) ([]byte, error) {
func (c *KMSClient) GetDEK(ctx context.Context, keyID string, dekSize int) ([]byte, error) {
client, err := c.newClient(ctx)
if err != nil {
return nil, err
@ -126,11 +128,11 @@ func (c *KMSClient) GetDEK(ctx context.Context, kekID, keyID string, dekSize int
if err != nil {
return nil, fmt.Errorf("key generation: %w", err)
}
return newDEK, c.putDEK(ctx, client, kekID, keyID, newDEK)
return newDEK, c.putDEK(ctx, client, c.kekID, keyID, newDEK)
}
request := &kmspb.DecryptRequest{
Name: fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", c.projectID, c.locationID, c.keyRingID, kekID),
Name: fmt.Sprintf("projects/%s/locations/%s/keyRings/%s/cryptoKeys/%s", c.projectID, c.locationID, c.keyRingID, c.kekID),
Ciphertext: encryptedDEK,
}

View file

@ -321,7 +321,7 @@ func TestGetDEK(t *testing.T) {
storage: tc.storage,
}
dek, err := client.GetDEK(context.Background(), "test-key", "volume-01", 32)
dek, err := client.GetDEK(context.Background(), "volume-01", 32)
if tc.wantErr {
assert.Error(err)
} else {

View file

@ -17,7 +17,7 @@ type CloudKMS interface {
CreateKEK(ctx context.Context, keyID string, kek []byte) error
// GetDEK returns the DEK for dekID and kekID from the KMS.
// If the DEK does not exist, a new one is created and saved to storage.
GetDEK(ctx context.Context, kekID string, dekID string, dekSize int) ([]byte, error)
GetDEK(ctx context.Context, dekID string, dekSize int) ([]byte, error)
}
// Storage provides an abstract interface for the storage backend used for DEKs.

View file

@ -24,10 +24,10 @@ import (
// Well known endpoints for KMS services.
const (
AWSKMSURI = "kms://aws?keyPolicy=%s"
AzureKMSURI = "kms://azure-kms?name=%s&type=%s"
AzureHSMURI = "kms://azure-hsm?name=%s"
GCPKMSURI = "kms://gcp?project=%s&location=%s&keyRing=%s&protectionLvl=%s"
AWSKMSURI = "kms://aws?keyPolicy=%s&kekID=%s"
AzureKMSURI = "kms://azure-kms?name=%s&type=%s&kekID=%s"
AzureHSMURI = "kms://azure-hsm?name=%s&kekID=%s"
GCPKMSURI = "kms://gcp?project=%s&location=%s&keyRing=%s&protectionLvl=%s&kekID=%s"
ClusterKMSURI = "kms://cluster-kms?key=%s&salt=%s"
AWSS3URI = "storage://aws?bucket=%s"
AzureBlobURI = "storage://azure?container=%s&connectionString=%s"
@ -35,6 +35,21 @@ const (
NoStoreURI = "storage://no-store"
)
// MasterSecret holds the master key and salt for deriving keys.
type MasterSecret struct {
Key []byte `json:"key"`
Salt []byte `json:"salt"`
}
// EncodeToURI returns an URI encoding the master secret.
func (m *MasterSecret) EncodeToURI() string {
return fmt.Sprintf(
ClusterKMSURI,
base64.URLEncoding.EncodeToString(m.Key),
base64.URLEncoding.EncodeToString(m.Salt),
)
}
// KMSInformation about an existing KMS.
type KMSInformation struct {
KMSURI string
@ -104,39 +119,39 @@ func getKMS(ctx context.Context, kmsURI string, store kms.Storage) (kms.CloudKMS
switch uri.Host {
case "aws":
poliyProducer, err := getAWSKMSConfig(uri)
poliyProducer, kekID, err := getAWSKMSConfig(uri)
if err != nil {
return nil, err
}
return aws.New(ctx, poliyProducer, store)
return aws.New(ctx, poliyProducer, store, kekID)
case "azure-kms":
vaultName, vaultType, err := getAzureKMSConfig(uri)
vaultName, vaultType, kekID, err := getAzureKMSConfig(uri)
if err != nil {
return nil, err
}
return azure.New(ctx, vaultName, azure.VaultSuffix(vaultType), store, nil)
return azure.New(ctx, vaultName, azure.VaultSuffix(vaultType), store, kekID, nil)
case "azure-hsm":
vaultName, err := getAzureHSMConfig(uri)
vaultName, kekID, err := getAzureHSMConfig(uri)
if err != nil {
return nil, err
}
return azure.NewHSM(ctx, vaultName, store, nil)
return azure.NewHSM(ctx, vaultName, store, kekID, nil)
case "gcp":
project, location, keyRing, protectionLvl, err := getGCPKMSConfig(uri)
project, location, keyRing, protectionLvl, kekID, err := getGCPKMSConfig(uri)
if err != nil {
return nil, err
}
return gcp.New(ctx, project, location, keyRing, store, kmspb.ProtectionLevel(protectionLvl))
return gcp.New(ctx, project, location, keyRing, store, kmspb.ProtectionLevel(protectionLvl), kekID)
case "cluster-kms":
salt, err := getClusterKMSConfig(uri)
masterSecret, err := getClusterKMSConfig(uri)
if err != nil {
return nil, err
}
return cluster.New(salt), nil
return cluster.New(masterSecret.Key, masterSecret.Salt)
default:
return nil, fmt.Errorf("unknown KMS type: %s", uri.Host)
@ -156,22 +171,56 @@ func getAWSS3Config(uri *url.URL) (string, error) {
return r[0], err
}
func getAWSKMSConfig(uri *url.URL) (*defaultPolicyProducer, error) {
r, err := getConfig(uri.Query(), []string{"keyPolicy"})
func getAWSKMSConfig(uri *url.URL) (*defaultPolicyProducer, string, error) {
r, err := getConfig(uri.Query(), []string{"keyPolicy", "kekID"})
if err != nil {
return nil, err
return nil, "", err
}
return &defaultPolicyProducer{policy: r[0]}, err
if len(r) != 2 {
return nil, "", fmt.Errorf("expected 2 KmsURI args, got %d", len(r))
}
kekID, err := base64.URLEncoding.DecodeString(r[1])
if err != nil {
return nil, "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
}
return &defaultPolicyProducer{policy: r[0]}, string(kekID), err
}
func getAzureKMSConfig(uri *url.URL) (string, string, error) {
r, err := getConfig(uri.Query(), []string{"name", "type"})
return r[0], r[1], err
func getAzureKMSConfig(uri *url.URL) (string, string, string, error) {
r, err := getConfig(uri.Query(), []string{"name", "type", "kekID"})
if err != nil {
return "", "", "", fmt.Errorf("getting config: %w", err)
}
if len(r) != 3 {
return "", "", "", fmt.Errorf("expected 3 KmsURI args, got %d", len(r))
}
kekID, err := base64.URLEncoding.DecodeString(r[2])
if err != nil {
return "", "", "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
}
return r[0], r[1], string(kekID), err
}
func getAzureHSMConfig(uri *url.URL) (string, error) {
r, err := getConfig(uri.Query(), []string{"name"})
return r[0], err
func getAzureHSMConfig(uri *url.URL) (string, string, error) {
r, err := getConfig(uri.Query(), []string{"name", "kekID"})
if err != nil {
return "", "", fmt.Errorf("getting config: %w", err)
}
if len(r) != 2 {
return "", "", fmt.Errorf("expected 2 KmsURI args, got %d", len(r))
}
kekID, err := base64.URLEncoding.DecodeString(r[1])
if err != nil {
return "", "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
}
return r[0], string(kekID), err
}
func getAzureBlobConfig(uri *url.URL) (string, string, error) {
@ -182,16 +231,26 @@ func getAzureBlobConfig(uri *url.URL) (string, string, error) {
return r[0], r[1], nil
}
func getGCPKMSConfig(uri *url.URL) (string, string, string, int32, error) {
r, err := getConfig(uri.Query(), []string{"project", "location", "keyRing", "protectionLvl"})
func getGCPKMSConfig(uri *url.URL) (project string, location string, keyRing string, protectionLvl int32, kekID string, err error) {
r, err := getConfig(uri.Query(), []string{"project", "location", "keyRing", "protectionLvl", "kekID"})
if err != nil {
return "", "", "", 0, err
return "", "", "", 0, "", err
}
protectionLvl, err := strconv.ParseInt(r[3], 10, 32)
if len(r) != 5 {
return "", "", "", 0, "", fmt.Errorf("expected 5 KmsURI args, got %d", len(r))
}
kekIDByte, err := base64.URLEncoding.DecodeString(r[4])
if err != nil {
return "", "", "", 0, err
return "", "", "", 0, "", fmt.Errorf("parsing kekID from kmsUri: %w", err)
}
return r[0], r[1], r[2], int32(protectionLvl), nil
protectionLvl32, err := strconv.ParseInt(r[3], 10, 32)
if err != nil {
return "", "", "", 0, "", err
}
return r[0], r[1], r[2], int32(protectionLvl32), string(kekIDByte), nil
}
func getGCPStorageConfig(uri *url.URL) (string, string, error) {
@ -199,12 +258,26 @@ func getGCPStorageConfig(uri *url.URL) (string, string, error) {
return r[0], r[1], err
}
func getClusterKMSConfig(uri *url.URL) ([]byte, error) {
r, err := getConfig(uri.Query(), []string{"salt"})
func getClusterKMSConfig(uri *url.URL) (MasterSecret, error) {
r, err := getConfig(uri.Query(), []string{"key", "salt"})
if err != nil {
return nil, err
return MasterSecret{}, err
}
return base64.URLEncoding.DecodeString(r[0])
if len(r) != 2 {
return MasterSecret{}, fmt.Errorf("expected 2 KmsURI args, got %d", len(r))
}
key, err := base64.URLEncoding.DecodeString(r[0])
if err != nil {
return MasterSecret{}, fmt.Errorf("parsing key from kmsUri: %w", err)
}
salt, err := base64.URLEncoding.DecodeString(r[1])
if err != nil {
return MasterSecret{}, fmt.Errorf("parsing salt from kmsUri: %w", err)
}
return MasterSecret{Key: key, Salt: salt}, nil
}
// getConfig parses url query values, returning a map of the requested values.

View file

@ -18,6 +18,8 @@ import (
"go.uber.org/goleak"
)
const constellationKekID = "Constellation"
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262
@ -80,23 +82,23 @@ func TestGetKMS(t *testing.T) {
wantErr bool
}{
"cluster kms": {
uri: fmt.Sprintf("%s?salt=%s", ClusterKMSURI, base64.URLEncoding.EncodeToString([]byte("salt"))),
uri: fmt.Sprintf(ClusterKMSURI, base64.URLEncoding.EncodeToString([]byte("key")), base64.URLEncoding.EncodeToString([]byte("salt"))),
wantErr: false,
},
"aws kms": {
uri: fmt.Sprintf(AWSKMSURI, ""),
uri: fmt.Sprintf(AWSKMSURI, "", ""),
wantErr: true,
},
"azure kms": {
uri: fmt.Sprintf(AzureKMSURI, "", ""),
uri: fmt.Sprintf(AzureKMSURI, "", "", ""),
wantErr: true,
},
"azure hsm": {
uri: fmt.Sprintf(AzureHSMURI, ""),
uri: fmt.Sprintf(AzureHSMURI, "", ""),
wantErr: true,
},
"gcp kms": {
uri: fmt.Sprintf(GCPKMSURI, "", "", "", ""),
uri: fmt.Sprintf(GCPKMSURI, "", "", "", "", ""),
wantErr: true,
},
"unknown kms": {
@ -135,7 +137,8 @@ func TestSetUpKMS(t *testing.T) {
assert.Error(err)
assert.Nil(kms)
kms, err = KMS(context.Background(), "storage://no-store", "kms://cluster-kms?salt="+base64.URLEncoding.EncodeToString([]byte("salt")))
masterSecret := MasterSecret{Key: []byte("key"), Salt: []byte("salt")}
kms, err = KMS(context.Background(), "storage://no-store", masterSecret.EncodeToURI())
assert.NoError(err)
assert.NotNil(kms)
}
@ -146,13 +149,15 @@ func TestGetAWSKMSConfig(t *testing.T) {
policy := "{keyPolicy: keyPolicy}"
escapedPolicy := url.QueryEscape(policy)
uri, err := url.Parse(fmt.Sprintf(AWSKMSURI, escapedPolicy))
kekID := base64.URLEncoding.EncodeToString([]byte(constellationKekID))
uri, err := url.Parse(fmt.Sprintf(AWSKMSURI, escapedPolicy, kekID))
require.NoError(err)
policyProducer, err := getAWSKMSConfig(uri)
policyProducer, rKekID, err := getAWSKMSConfig(uri)
require.NoError(err)
keyPolicy, err := policyProducer.CreateKeyPolicy("")
require.NoError(err)
assert.Equal(policy, keyPolicy)
assert.Equal(constellationKekID, rKekID)
}
func TestGetAzureBlobConfig(t *testing.T) {
@ -178,18 +183,20 @@ func TestGetGCPKMSConfig(t *testing.T) {
location := "global"
keyRing := "test-ring"
protectionLvl := "2"
uri, err := url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, protectionLvl))
kekID := base64.URLEncoding.EncodeToString([]byte(constellationKekID))
uri, err := url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, protectionLvl, kekID))
require.NoError(err)
rProject, rLocation, rKeyRing, rProtectionLvl, err := getGCPKMSConfig(uri)
rProject, rLocation, rKeyRing, rProtectionLvl, rKekID, err := getGCPKMSConfig(uri)
require.NoError(err)
assert.Equal(project, rProject)
assert.Equal(location, rLocation)
assert.Equal(keyRing, rKeyRing)
assert.Equal(int32(2), rProtectionLvl)
assert.Equal(constellationKekID, rKekID)
uri, err = url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, "invalid"))
uri, err = url.Parse(fmt.Sprintf(GCPKMSURI, project, location, keyRing, "invalid", kekID))
require.NoError(err)
_, _, _, _, err = getGCPKMSConfig(uri)
_, _, _, _, _, err = getGCPKMSConfig(uri)
assert.Error(err)
}
@ -202,12 +209,13 @@ func TestGetClusterKMSConfig(t *testing.T) {
0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf,
}
uri, err := url.Parse(ClusterKMSURI + "?salt=" + base64.URLEncoding.EncodeToString(expectedSalt))
masterSecretIn := MasterSecret{Key: []byte("key"), Salt: expectedSalt}
uri, err := url.Parse(masterSecretIn.EncodeToURI())
require.NoError(err)
salt, err := getClusterKMSConfig(uri)
masterSecretOut, err := getClusterKMSConfig(uri)
assert.NoError(err)
assert.Equal(expectedSalt, salt)
assert.Equal(expectedSalt, masterSecretOut.Salt)
}
func TestGetConfig(t *testing.T) {