/* Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ package crypto import ( "testing" "github.com/edgelesssys/constellation/v2/internal/crypto/testvector" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" ) func TestMain(m *testing.M) { goleak.VerifyTestMain(m) } func TestDeriveKey(t *testing.T) { assert := assert.New(t) require := require.New(t) key, err := DeriveKey([]byte("secret"), []byte("salt"), nil, 32) assert.NoError(err) assert.Len(key, 32) key1, err := DeriveKey([]byte("secret"), []byte("salt"), []byte("first"), 32) require.NoError(err) key2, err := DeriveKey([]byte("secret"), []byte("salt"), []byte("first"), 32) require.NoError(err) assert.Equal(key1, key2) key3, err := DeriveKey([]byte("secret"), []byte("salt"), []byte("second"), 32) require.NoError(err) assert.NotEqual(key1, key3) zeroInput := testvector.HKDFZero out, err := DeriveKey(zeroInput.Secret, zeroInput.Salt, []byte(zeroInput.InfoPrefix+zeroInput.Info), zeroInput.Length) require.NoError(err) assert.Equal(zeroInput.Output, out) fInput := testvector.HKDF0xFF out, err = DeriveKey(fInput.Secret, fInput.Salt, []byte(fInput.InfoPrefix+fInput.Info), fInput.Length) require.NoError(err) assert.Equal(fInput.Output, out) } func TestVectorsHKDF(t *testing.T) { testCases := map[string]struct { secret []byte salt []byte info []byte length uint wantKey []byte }{ "rfc Test Case 1": { secret: testvector.HKDFrfc1.Secret, salt: testvector.HKDFrfc1.Salt, info: []byte(testvector.HKDFrfc1.Info), length: testvector.HKDFrfc1.Length, wantKey: testvector.HKDFrfc1.Output, }, "rfc Test Case 2": { secret: testvector.HKDFrfc2.Secret, salt: testvector.HKDFrfc2.Salt, info: []byte(testvector.HKDFrfc2.Info), length: testvector.HKDFrfc2.Length, wantKey: testvector.HKDFrfc2.Output, }, "rfc Test Case 3": { secret: testvector.HKDFrfc3.Secret, salt: testvector.HKDFrfc3.Salt, info: []byte(testvector.HKDFrfc3.Info), length: testvector.HKDFrfc3.Length, wantKey: testvector.HKDFrfc3.Output, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) require := require.New(t) out, err := DeriveKey(tc.secret, tc.salt, tc.info, tc.length) require.NoError(err) assert.Equal(tc.wantKey, out) }) } } func TestGenerateCertificateSerialNumber(t *testing.T) { assert := assert.New(t) require := require.New(t) s1, err := GenerateCertificateSerialNumber() require.NoError(err) s2, err := GenerateCertificateSerialNumber() require.NoError(err) assert.NotEqual(s1, s2) } func TestGenerateRandomBytes(t *testing.T) { assert := assert.New(t) require := require.New(t) n1, err := GenerateRandomBytes(32) require.NoError(err) assert.Len(n1, 32) n2, err := GenerateRandomBytes(32) require.NoError(err) assert.Equal(len(n1), len(n2)) assert.NotEqual(n1, n2) n3, err := GenerateRandomBytes(16) require.NoError(err) assert.Len(n3, 16) }