/* Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ package cmd import ( "bytes" "context" "crypto/sha256" "crypto/sha512" "encoding/base64" "encoding/json" "errors" "net" "strconv" "strings" "testing" "github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/variant" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constellation/state" "github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/grpc/testdialer" "github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/verify/verifyproto" tpmProto "github.com/google/go-tpm-tools/proto/tpm" "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/codes" rpcStatus "google.golang.org/grpc/status" ) func TestVerify(t *testing.T) { zeroBase64 := base64.StdEncoding.EncodeToString([]byte("00000000000000000000000000000000")) someErr := errors.New("failed") testCases := map[string]struct { provider cloudprovider.Provider protoClient *stubVerifyClient nodeEndpointFlag string clusterIDFlag string stateFile *state.State wantEndpoint string skipConfigCreation bool wantErr bool }{ "gcp": { provider: cloudprovider.GCP, nodeEndpointFlag: "192.0.2.1:1234", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, stateFile: defaultStateFile(cloudprovider.GCP), wantEndpoint: "192.0.2.1:1234", }, "azure": { provider: cloudprovider.Azure, nodeEndpointFlag: "192.0.2.1:1234", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, stateFile: defaultStateFile(cloudprovider.Azure), wantEndpoint: "192.0.2.1:1234", }, "default port": { provider: cloudprovider.GCP, nodeEndpointFlag: "192.0.2.1", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, stateFile: defaultStateFile(cloudprovider.GCP), wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC), }, "endpoint not set": { provider: cloudprovider.GCP, clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, stateFile: func() *state.State { s := defaultStateFile(cloudprovider.GCP) s.Infrastructure.ClusterEndpoint = "" return s }(), wantErr: true, }, "endpoint from state file": { provider: cloudprovider.GCP, clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, stateFile: func() *state.State { s := defaultStateFile(cloudprovider.GCP) s.Infrastructure.ClusterEndpoint = "192.0.2.1" return s }(), wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC), }, "override endpoint from details file": { provider: cloudprovider.GCP, nodeEndpointFlag: "192.0.2.2:1234", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, stateFile: func() *state.State { s := defaultStateFile(cloudprovider.GCP) s.Infrastructure.ClusterEndpoint = "192.0.2.1" return s }(), wantEndpoint: "192.0.2.2:1234", }, "invalid endpoint": { provider: cloudprovider.GCP, nodeEndpointFlag: ":::::", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, stateFile: defaultStateFile(cloudprovider.GCP), wantErr: true, }, "neither owner id nor cluster id set": { provider: cloudprovider.GCP, nodeEndpointFlag: "192.0.2.1:1234", stateFile: func() *state.State { s := defaultStateFile(cloudprovider.GCP) s.ClusterValues.OwnerID = "" s.ClusterValues.ClusterID = "" return s }(), protoClient: &stubVerifyClient{}, wantErr: true, }, "use owner id from state file": { provider: cloudprovider.GCP, nodeEndpointFlag: "192.0.2.1:1234", protoClient: &stubVerifyClient{}, stateFile: func() *state.State { s := defaultStateFile(cloudprovider.GCP) s.ClusterValues.OwnerID = zeroBase64 return s }(), wantEndpoint: "192.0.2.1:1234", }, "config file not existing": { provider: cloudprovider.GCP, clusterIDFlag: zeroBase64, nodeEndpointFlag: "192.0.2.1:1234", stateFile: defaultStateFile(cloudprovider.GCP), skipConfigCreation: true, wantErr: true, }, "error protoClient GetState": { provider: cloudprovider.Azure, nodeEndpointFlag: "192.0.2.1:1234", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{verifyErr: rpcStatus.Error(codes.Internal, "failed")}, stateFile: defaultStateFile(cloudprovider.Azure), wantErr: true, }, "error protoClient GetState not rpc": { provider: cloudprovider.Azure, nodeEndpointFlag: "192.0.2.1:1234", clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{verifyErr: someErr}, stateFile: defaultStateFile(cloudprovider.Azure), wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) require := require.New(t) cmd := NewVerifyCmd() out := &bytes.Buffer{} cmd.SetErr(out) fileHandler := file.NewHandler(afero.NewMemMapFs()) if !tc.skipConfigCreation { cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.provider) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg)) } require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename)) v := &verifyCmd{ fileHandler: fileHandler, log: logger.NewTest(t), flags: verifyFlags{ clusterID: tc.clusterIDFlag, endpoint: tc.nodeEndpointFlag, output: "raw", }, } err := v.verify(cmd, tc.protoClient, stubAttestationFetcher{}) if tc.wantErr { assert.Error(err) } else { assert.NoError(err) assert.Contains(out.String(), "OK") assert.Equal(tc.wantEndpoint, tc.protoClient.endpoint) } }) } } func TestFormatDefault(t *testing.T) { testCases := map[string]struct { doc []byte attCfg config.AttestationCfg wantErr bool }{ "invalid doc": { doc: []byte("invalid"), attCfg: &config.AzureSEVSNP{}, wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { _, err := formatDefault(context.Background(), tc.doc, tc.attCfg, logger.NewTest(t)) if tc.wantErr { assert.Error(t, err) } else { assert.NoError(t, err) } }) } } func TestVerifyClient(t *testing.T) { testCases := map[string]struct { attestationDoc atls.FakeAttestationDoc nonce []byte attestationErr error wantErr bool }{ "success": { attestationDoc: atls.FakeAttestationDoc{ UserData: []byte(constants.ConstellationVerifyServiceUserData), Nonce: []byte("nonce"), }, nonce: []byte("nonce"), }, "attestation error": { attestationDoc: atls.FakeAttestationDoc{ UserData: []byte(constants.ConstellationVerifyServiceUserData), Nonce: []byte("nonce"), }, nonce: []byte("nonce"), attestationErr: errors.New("error"), wantErr: true, }, "user data does not match": { attestationDoc: atls.FakeAttestationDoc{ UserData: []byte("wrong user data"), Nonce: []byte("nonce"), }, nonce: []byte("nonce"), wantErr: true, }, "nonce does not match": { attestationDoc: atls.FakeAttestationDoc{ UserData: []byte(constants.ConstellationVerifyServiceUserData), Nonce: []byte("wrong nonce"), }, nonce: []byte("nonce"), wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) require := require.New(t) attestation, err := json.Marshal(tc.attestationDoc) require.NoError(err) verifyAPI := &stubVerifyAPI{ attestation: &verifyproto.GetAttestationResponse{Attestation: attestation}, attestationErr: tc.attestationErr, } netDialer := testdialer.NewBufconnDialer() dialer := dialer.New(nil, nil, netDialer) verifyServer := grpc.NewServer() verifyproto.RegisterAPIServer(verifyServer, verifyAPI) addr := net.JoinHostPort("192.0.2.1", strconv.Itoa(constants.VerifyServiceNodePortGRPC)) listener := netDialer.GetListener(addr) go verifyServer.Serve(listener) defer verifyServer.GracefulStop() verifier := &constellationVerifier{dialer: dialer, log: logger.NewTest(t)} request := &verifyproto.GetAttestationRequest{ Nonce: tc.nonce, } _, err = verifier.Verify(context.Background(), addr, request, atls.NewFakeValidator(variant.Dummy{})) if tc.wantErr { assert.Error(err) } else { assert.NoError(err) } }) } } type stubVerifyClient struct { verifyErr error endpoint string } func (c *stubVerifyClient) Verify(_ context.Context, endpoint string, _ *verifyproto.GetAttestationRequest, _ atls.Validator) ([]byte, error) { c.endpoint = endpoint return nil, c.verifyErr } type stubVerifyAPI struct { attestation *verifyproto.GetAttestationResponse attestationErr error verifyproto.UnimplementedAPIServer } func (a stubVerifyAPI) GetAttestation(context.Context, *verifyproto.GetAttestationRequest) (*verifyproto.GetAttestationResponse, error) { return a.attestation, a.attestationErr } func TestAddPortIfMissing(t *testing.T) { testCases := map[string]struct { endpoint string defaultPort int wantResult string wantErr bool }{ "ip and port": { endpoint: "192.0.2.1:2", defaultPort: 3, wantResult: "192.0.2.1:2", }, "hostname and port": { endpoint: "foo:2", defaultPort: 3, wantResult: "foo:2", }, "ip": { endpoint: "192.0.2.1", defaultPort: 3, wantResult: "192.0.2.1:3", }, "hostname": { endpoint: "foo", defaultPort: 3, wantResult: "foo:3", }, "empty endpoint": { endpoint: "", defaultPort: 3, wantErr: true, }, "invalid endpoint": { endpoint: "foo:2:2", defaultPort: 3, wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) require := require.New(t) res, err := addPortIfMissing(tc.endpoint, tc.defaultPort) if tc.wantErr { assert.Error(err) return } require.NoError(err) assert.Equal(tc.wantResult, res) }) } } func TestParseQuotes(t *testing.T) { testCases := map[string]struct { quotes []*tpmProto.Quote expectedPCRs measurements.M wantOutput string wantErr bool }{ "parse quotes in order": { quotes: []*tpmProto.Quote{ { Pcrs: &tpmProto.PCRs{ Hash: tpmProto.HashAlgo_SHA256, Pcrs: map[uint32][]byte{ 0: {0x00}, 1: {0x01}, }, }, }, }, expectedPCRs: measurements.M{ 0: measurements.WithAllBytes(0x00, measurements.Enforce, 1), 1: measurements.WithAllBytes(0x01, measurements.WarnOnly, 1), }, wantOutput: "\tQuote:\n\t\tPCR 0 (Strict: true):\n\t\t\tExpected:\t00\n\t\t\tActual:\t\t00\n\t\tPCR 1 (Strict: false):\n\t\t\tExpected:\t01\n\t\t\tActual:\t\t01\n", }, "additional quotes are skipped": { quotes: []*tpmProto.Quote{ { Pcrs: &tpmProto.PCRs{ Hash: tpmProto.HashAlgo_SHA256, Pcrs: map[uint32][]byte{ 0: {0x00}, 1: {0x01}, 2: {0x02}, 3: {0x03}, }, }, }, }, expectedPCRs: measurements.M{ 0: measurements.WithAllBytes(0x00, measurements.Enforce, 1), 1: measurements.WithAllBytes(0x01, measurements.WarnOnly, 1), }, wantOutput: "\tQuote:\n\t\tPCR 0 (Strict: true):\n\t\t\tExpected:\t00\n\t\t\tActual:\t\t00\n\t\tPCR 1 (Strict: false):\n\t\t\tExpected:\t01\n\t\t\tActual:\t\t01\n", }, "missing quotes error": { quotes: []*tpmProto.Quote{ { Pcrs: &tpmProto.PCRs{ Hash: tpmProto.HashAlgo_SHA256, Pcrs: map[uint32][]byte{ 0: {0x00}, }, }, }, }, expectedPCRs: measurements.M{ 0: measurements.WithAllBytes(0x00, measurements.Enforce, 1), 1: measurements.WithAllBytes(0x01, measurements.WarnOnly, 1), }, wantErr: true, }, "no quotes error": { quotes: []*tpmProto.Quote{}, expectedPCRs: measurements.M{ 0: measurements.WithAllBytes(0x00, measurements.Enforce, 1), 1: measurements.WithAllBytes(0x01, measurements.WarnOnly, 1), }, wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) b := &strings.Builder{} err := parseQuotes(b, tc.quotes, tc.expectedPCRs) if tc.wantErr { assert.Error(err) } else { assert.NoError(err) assert.Equal(tc.wantOutput, b.String()) } }) } } func TestValidatorUpdateInitPCRs(t *testing.T) { zero := measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength) one := measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength) one64 := base64.StdEncoding.EncodeToString(one.Expected[:]) oneHash := sha256.Sum256(one.Expected[:]) pcrZeroUpdatedOne := sha256.Sum256(append(zero.Expected[:], oneHash[:]...)) newTestPCRs := func() measurements.M { return measurements.M{ 0: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 1: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 2: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 3: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 4: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 5: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 6: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 7: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 8: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 9: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 10: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 11: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 12: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 13: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 14: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 15: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 16: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), 17: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), 18: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), 19: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), 20: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), 21: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), 22: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), 23: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), } } testCases := map[string]struct { config config.AttestationCfg ownerID string clusterID string wantErr bool }{ "gcp update owner ID": { config: &config.GCPSEVES{ Measurements: newTestPCRs(), }, ownerID: one64, }, "gcp update cluster ID": { config: &config.GCPSEVES{ Measurements: newTestPCRs(), }, clusterID: one64, }, "gcp update both": { config: &config.GCPSEVES{ Measurements: newTestPCRs(), }, ownerID: one64, clusterID: one64, }, "azure update owner ID": { config: &config.AzureSEVSNP{ Measurements: newTestPCRs(), }, ownerID: one64, }, "azure update cluster ID": { config: &config.AzureSEVSNP{ Measurements: newTestPCRs(), }, clusterID: one64, }, "azure update both": { config: &config.AzureSEVSNP{ Measurements: newTestPCRs(), }, ownerID: one64, clusterID: one64, }, "owner ID and cluster ID empty": { config: &config.AzureSEVSNP{ Measurements: newTestPCRs(), }, }, "invalid encoding": { config: &config.GCPSEVES{ Measurements: newTestPCRs(), }, ownerID: "invalid", wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) err := updateInitMeasurements(tc.config, tc.ownerID, tc.clusterID) if tc.wantErr { assert.Error(err) return } require.NoError(t, err) m := tc.config.GetMeasurements() for i := 0; i < len(m); i++ { switch { case i == int(measurements.PCRIndexClusterID) && tc.clusterID == "": // should be deleted _, ok := m[uint32(i)] assert.False(ok) case i == int(measurements.PCRIndexClusterID): pcr, ok := m[uint32(i)] assert.True(ok) assert.Equal(pcrZeroUpdatedOne[:], pcr.Expected) case i == int(measurements.PCRIndexOwnerID) && tc.ownerID == "": // should be deleted _, ok := m[uint32(i)] assert.False(ok) case i == int(measurements.PCRIndexOwnerID): pcr, ok := m[uint32(i)] assert.True(ok) assert.Equal(pcrZeroUpdatedOne[:], pcr.Expected) default: if i >= 17 && i <= 22 { assert.Equal(one, m[uint32(i)]) } else { assert.Equal(zero, m[uint32(i)]) } } } }) } } func TestValidatorUpdateInitMeasurementsTDX(t *testing.T) { zero := measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength) one := measurements.WithAllBytes(0x11, true, measurements.TDXMeasurementLength) one64 := base64.StdEncoding.EncodeToString(one.Expected[:]) oneHash := sha512.Sum384(one.Expected[:]) tdxZeroUpdatedOne := sha512.Sum384(append(zero.Expected[:], oneHash[:]...)) newTestTDXMeasurements := func() measurements.M { return measurements.M{ 0: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), 1: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), 2: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), 3: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), 4: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), } } testCases := map[string]struct { measurements measurements.M clusterID string wantErr bool }{ "QEMUT TDX update update cluster ID": { measurements: newTestTDXMeasurements(), clusterID: one64, }, "cluster ID empty": { measurements: newTestTDXMeasurements(), }, "invalid encoding": { measurements: newTestTDXMeasurements(), clusterID: "invalid", wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) cfg := &config.QEMUTDX{Measurements: tc.measurements} err := updateInitMeasurements(cfg, "", tc.clusterID) if tc.wantErr { assert.Error(err) return } assert.NoError(err) for i := 0; i < len(tc.measurements); i++ { switch { case i == measurements.TDXIndexClusterID && tc.clusterID == "": // should be deleted _, ok := cfg.Measurements[uint32(i)] assert.False(ok) case i == measurements.TDXIndexClusterID: pcr, ok := cfg.Measurements[uint32(i)] assert.True(ok) assert.Equal(tdxZeroUpdatedOne[:], pcr.Expected) default: assert.Equal(zero, cfg.Measurements[uint32(i)]) } } }) } }