From 781ac85711a1edd3b543b5b8bbb97ce37f7000d7 Mon Sep 17 00:00:00 2001 From: Moritz Sanft <58110325+msanft@users.noreply.github.com> Date: Tue, 5 Dec 2023 12:28:40 +0100 Subject: [PATCH] cli: move `cloudcmd/validators` to `cmd` (#2679) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * cli: refactor `cloudcmd/validators` Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * make struct fields private Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * use errors.New Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * make struct fields private in usage * fix casing --------- Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> --- cli/internal/cloudcmd/BUILD.bazel | 7 - cli/internal/cloudcmd/validators.go | 125 ------------- cli/internal/cloudcmd/validators_test.go | 226 ----------------------- cli/internal/cmd/apply.go | 19 +- cli/internal/cmd/recover.go | 4 +- cli/internal/cmd/verify.go | 90 ++++++++- cli/internal/cmd/verify_test.go | 209 +++++++++++++++++++++ 7 files changed, 316 insertions(+), 364 deletions(-) delete mode 100644 cli/internal/cloudcmd/validators.go delete mode 100644 cli/internal/cloudcmd/validators_test.go diff --git a/cli/internal/cloudcmd/BUILD.bazel b/cli/internal/cloudcmd/BUILD.bazel index e03cacfe3..41aff9ad9 100644 --- a/cli/internal/cloudcmd/BUILD.bazel +++ b/cli/internal/cloudcmd/BUILD.bazel @@ -14,16 +14,12 @@ go_library( "terminate.go", "tfplan.go", "tfvars.go", - "validators.go", ], importpath = "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd", visibility = ["//cli:__subpackages__"], deps = [ "//cli/internal/libvirt", "//cli/internal/terraform", - "//internal/atls", - "//internal/attestation/choose", - "//internal/attestation/measurements", "//internal/attestation/variant", "//internal/cloud/azureshared", "//internal/cloud/cloudprovider", @@ -36,7 +32,6 @@ go_library( "//internal/maa", "//internal/role", "//internal/state", - "@com_github_spf13_cobra//:cobra", ], ) @@ -50,12 +45,10 @@ go_test( "terminate_test.go", "tfplan_test.go", "tfvars_test.go", - "validators_test.go", ], embed = [":cloudcmd"], deps = [ "//cli/internal/terraform", - "//internal/attestation/measurements", "//internal/attestation/variant", "//internal/cloud/cloudprovider", "//internal/cloud/gcpshared", diff --git a/cli/internal/cloudcmd/validators.go b/cli/internal/cloudcmd/validators.go deleted file mode 100644 index a5989af45..000000000 --- a/cli/internal/cloudcmd/validators.go +++ /dev/null @@ -1,125 +0,0 @@ -/* -Copyright (c) Edgeless Systems GmbH - -SPDX-License-Identifier: AGPL-3.0-only -*/ - -package cloudcmd - -import ( - "crypto/sha256" - "crypto/sha512" - "encoding/base64" - "encoding/hex" - "fmt" - - "github.com/edgelesssys/constellation/v2/internal/atls" - "github.com/edgelesssys/constellation/v2/internal/attestation/choose" - "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" - "github.com/edgelesssys/constellation/v2/internal/attestation/variant" - "github.com/edgelesssys/constellation/v2/internal/config" - "github.com/spf13/cobra" -) - -// NewValidator creates a new Validator. -func NewValidator(cmd *cobra.Command, config config.AttestationCfg, log debugLog) (atls.Validator, error) { - return choose.Validator(config, WarnLogger{Cmd: cmd, Log: log}) -} - -// UpdateInitMeasurements sets the owner and cluster measurement values. -func UpdateInitMeasurements(config config.AttestationCfg, ownerID, clusterID string) error { - m := config.GetMeasurements() - - switch config.GetVariant() { - case variant.AWSNitroTPM{}, variant.AWSSEVSNP{}, variant.AzureTrustedLaunch{}, variant.AzureSEVSNP{}, variant.GCPSEVES{}, variant.QEMUVTPM{}: - if err := updateMeasurementTPM(m, uint32(measurements.PCRIndexOwnerID), ownerID); err != nil { - return err - } - return updateMeasurementTPM(m, uint32(measurements.PCRIndexClusterID), clusterID) - case variant.QEMUTDX{}: - // Measuring ownerID is currently not implemented for Constellation - // Since adding support for measuring ownerID to TDX would require additional code changes, - // the current implementation does not support it, but can be changed if we decide to add support in the future - return updateMeasurementTDX(m, uint32(measurements.TDXIndexClusterID), clusterID) - default: - return fmt.Errorf("selecting attestation variant: unknown attestation variant") - } -} - -func updateMeasurementTDX(m measurements.M, measurementIdx uint32, encoded string) error { - if encoded == "" { - delete(m, measurementIdx) - return nil - } - decoded, err := decodeMeasurement(encoded) - if err != nil { - return err - } - - // new_measurement_value := hash(old_measurement_value || data_to_extend) - // Since we use the DG.MR.RTMR.EXTEND call to extend the register, data_to_extend is the hash of our input - hashedInput := sha512.Sum384(decoded) - oldExpected := m[measurementIdx].Expected - expectedMeasurementSum := sha512.Sum384(append(oldExpected[:], hashedInput[:]...)) - m[measurementIdx] = measurements.Measurement{ - Expected: expectedMeasurementSum[:], - ValidationOpt: m[measurementIdx].ValidationOpt, - } - return nil -} - -func updateMeasurementTPM(m measurements.M, measurementIdx uint32, encoded string) error { - if encoded == "" { - delete(m, measurementIdx) - return nil - } - decoded, err := decodeMeasurement(encoded) - if err != nil { - return err - } - - // new_pcr_value := hash(old_pcr_value || data_to_extend) - // Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input - hashedInput := sha256.Sum256(decoded) - oldExpected := m[measurementIdx].Expected - expectedMeasurement := sha256.Sum256(append(oldExpected[:], hashedInput[:]...)) - m[measurementIdx] = measurements.Measurement{ - Expected: expectedMeasurement[:], - ValidationOpt: m[measurementIdx].ValidationOpt, - } - return nil -} - -func decodeMeasurement(encoded string) ([]byte, error) { - decoded, err := hex.DecodeString(encoded) - if err != nil { - hexErr := err - decoded, err = base64.StdEncoding.DecodeString(encoded) - if err != nil { - return nil, fmt.Errorf("input [%s] could neither be hex decoded (%w) nor base64 decoded (%w)", encoded, hexErr, err) - } - } - return decoded, nil -} - -// WarnLogger implements logging of warnings for validators. -type WarnLogger struct { - Cmd *cobra.Command - Log debugLog -} - -// Infof messages are reduced to debug messages, since we don't want -// the extra info when using the CLI without setting the debug flag. -func (wl WarnLogger) Infof(fmtStr string, args ...any) { - wl.Log.Debugf(fmtStr, args...) -} - -// Warnf prints a formatted warning from the validator. -func (wl WarnLogger) Warnf(fmtStr string, args ...any) { - wl.Cmd.PrintErrf("Warning: %s\n", fmt.Sprintf(fmtStr, args...)) -} - -type debugLog interface { - Debugf(format string, args ...any) - Sync() -} diff --git a/cli/internal/cloudcmd/validators_test.go b/cli/internal/cloudcmd/validators_test.go deleted file mode 100644 index 980a8326b..000000000 --- a/cli/internal/cloudcmd/validators_test.go +++ /dev/null @@ -1,226 +0,0 @@ -/* -Copyright (c) Edgeless Systems GmbH - -SPDX-License-Identifier: AGPL-3.0-only -*/ - -package cloudcmd - -import ( - "crypto/sha256" - "crypto/sha512" - "encoding/base64" - "testing" - - "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" - "github.com/edgelesssys/constellation/v2/internal/config" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestValidatorUpdateInitPCRs(t *testing.T) { - zero := measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength) - one := measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength) - one64 := base64.StdEncoding.EncodeToString(one.Expected[:]) - oneHash := sha256.Sum256(one.Expected[:]) - pcrZeroUpdatedOne := sha256.Sum256(append(zero.Expected[:], oneHash[:]...)) - newTestPCRs := func() measurements.M { - return measurements.M{ - 0: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 1: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 2: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 3: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 4: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 5: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 6: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 7: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 8: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 9: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 10: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 11: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 12: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 13: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 14: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 15: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 16: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - 17: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), - 18: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), - 19: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), - 20: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), - 21: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), - 22: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), - 23: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), - } - } - - testCases := map[string]struct { - config config.AttestationCfg - ownerID string - clusterID string - wantErr bool - }{ - "gcp update owner ID": { - config: &config.GCPSEVES{ - Measurements: newTestPCRs(), - }, - ownerID: one64, - }, - "gcp update cluster ID": { - config: &config.GCPSEVES{ - Measurements: newTestPCRs(), - }, - clusterID: one64, - }, - "gcp update both": { - config: &config.GCPSEVES{ - Measurements: newTestPCRs(), - }, - ownerID: one64, - clusterID: one64, - }, - "azure update owner ID": { - config: &config.AzureSEVSNP{ - Measurements: newTestPCRs(), - }, - ownerID: one64, - }, - "azure update cluster ID": { - config: &config.AzureSEVSNP{ - Measurements: newTestPCRs(), - }, - clusterID: one64, - }, - "azure update both": { - config: &config.AzureSEVSNP{ - Measurements: newTestPCRs(), - }, - ownerID: one64, - clusterID: one64, - }, - "owner ID and cluster ID empty": { - config: &config.AzureSEVSNP{ - Measurements: newTestPCRs(), - }, - }, - "invalid encoding": { - config: &config.GCPSEVES{ - Measurements: newTestPCRs(), - }, - ownerID: "invalid", - wantErr: true, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - - err := UpdateInitMeasurements(tc.config, tc.ownerID, tc.clusterID) - - if tc.wantErr { - assert.Error(err) - return - } - require.NoError(t, err) - m := tc.config.GetMeasurements() - for i := 0; i < len(m); i++ { - switch { - case i == int(measurements.PCRIndexClusterID) && tc.clusterID == "": - // should be deleted - _, ok := m[uint32(i)] - assert.False(ok) - - case i == int(measurements.PCRIndexClusterID): - pcr, ok := m[uint32(i)] - assert.True(ok) - assert.Equal(pcrZeroUpdatedOne[:], pcr.Expected) - - case i == int(measurements.PCRIndexOwnerID) && tc.ownerID == "": - // should be deleted - _, ok := m[uint32(i)] - assert.False(ok) - - case i == int(measurements.PCRIndexOwnerID): - pcr, ok := m[uint32(i)] - assert.True(ok) - assert.Equal(pcrZeroUpdatedOne[:], pcr.Expected) - - default: - if i >= 17 && i <= 22 { - assert.Equal(one, m[uint32(i)]) - } else { - assert.Equal(zero, m[uint32(i)]) - } - } - } - }) - } -} - -func TestValidatorUpdateInitMeasurementsTDX(t *testing.T) { - zero := measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength) - one := measurements.WithAllBytes(0x11, true, measurements.TDXMeasurementLength) - one64 := base64.StdEncoding.EncodeToString(one.Expected[:]) - oneHash := sha512.Sum384(one.Expected[:]) - tdxZeroUpdatedOne := sha512.Sum384(append(zero.Expected[:], oneHash[:]...)) - newTestTDXMeasurements := func() measurements.M { - return measurements.M{ - 0: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), - 1: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), - 2: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), - 3: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), - 4: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), - } - } - - testCases := map[string]struct { - measurements measurements.M - clusterID string - wantErr bool - }{ - "QEMUT TDX update update cluster ID": { - measurements: newTestTDXMeasurements(), - clusterID: one64, - }, - "cluster ID empty": { - measurements: newTestTDXMeasurements(), - }, - "invalid encoding": { - measurements: newTestTDXMeasurements(), - clusterID: "invalid", - wantErr: true, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - - cfg := &config.QEMUTDX{Measurements: tc.measurements} - - err := UpdateInitMeasurements(cfg, "", tc.clusterID) - - if tc.wantErr { - assert.Error(err) - return - } - assert.NoError(err) - for i := 0; i < len(tc.measurements); i++ { - switch { - case i == measurements.TDXIndexClusterID && tc.clusterID == "": - // should be deleted - _, ok := cfg.Measurements[uint32(i)] - assert.False(ok) - - case i == measurements.TDXIndexClusterID: - pcr, ok := cfg.Measurements[uint32(i)] - assert.True(ok) - assert.Equal(tdxZeroUpdatedOne[:], pcr.Expected) - - default: - assert.Equal(zero, cfg.Measurements[uint32(i)]) - } - } - }) - } -} diff --git a/cli/internal/cmd/apply.go b/cli/internal/cmd/apply.go index 3c63898db..cb7a810d6 100644 --- a/cli/internal/cmd/apply.go +++ b/cli/internal/cmd/apply.go @@ -260,7 +260,7 @@ func runApply(cmd *cobra.Command, _ []string) error { fileHandler: fileHandler, flags: flags, log: log, - wLog: &cloudcmd.WarnLogger{Cmd: cmd, Log: log}, + wLog: &warnLogger{cmd: cmd, log: log}, spinner: spinner, merger: &kubeconfigMerger{log: log}, newHelmClient: newHelmClient, @@ -786,3 +786,20 @@ func skipPhasesCompletion(_ *cobra.Command, _ []string, toComplete string) ([]st return suggestions, cobra.ShellCompDirectiveNoFileComp } + +// warnLogger implements logging of warnings for validators. +type warnLogger struct { + cmd *cobra.Command + log debugLog +} + +// Infof messages are reduced to debug messages, since we don't want +// the extra info when using the CLI without setting the debug flag. +func (wl warnLogger) Infof(fmtStr string, args ...any) { + wl.log.Debugf(fmtStr, args...) +} + +// Warnf prints a formatted warning from the validator. +func (wl warnLogger) Warnf(fmtStr string, args ...any) { + wl.cmd.PrintErrf("Warning: %s\n", fmt.Sprintf(fmtStr, args...)) +} diff --git a/cli/internal/cmd/recover.go b/cli/internal/cmd/recover.go index 735f97524..1ba8ff9d4 100644 --- a/cli/internal/cmd/recover.go +++ b/cli/internal/cmd/recover.go @@ -15,10 +15,10 @@ import ( "sync" "time" - "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/atls" + "github.com/edgelesssys/constellation/v2/internal/attestation/choose" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/constants" @@ -132,7 +132,7 @@ func (r *recoverCmd) recover( } r.log.Debugf("Creating aTLS Validator for %s", conf.GetAttestationConfig().GetVariant()) - validator, err := cloudcmd.NewValidator(cmd, conf.GetAttestationConfig(), r.log) + validator, err := choose.Validator(conf.GetAttestationConfig(), warnLogger{cmd: cmd, log: r.log}) if err != nil { return fmt.Errorf("creating new validator: %w", err) } diff --git a/cli/internal/cmd/verify.go b/cli/internal/cmd/verify.go index 8c0b465e0..99068ec0c 100644 --- a/cli/internal/cmd/verify.go +++ b/cli/internal/cmd/verify.go @@ -9,7 +9,10 @@ package cmd import ( "bytes" "context" + "crypto/sha256" + "crypto/sha512" "encoding/base64" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -20,11 +23,12 @@ import ( tpmProto "github.com/google/go-tpm-tools/proto/tpm" - "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/atls" + "github.com/edgelesssys/constellation/v2/internal/attestation/choose" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/snp" + "github.com/edgelesssys/constellation/v2/internal/attestation/variant" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/config" @@ -169,12 +173,12 @@ func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factor c.log.Debugf("Updating expected PCRs") attConfig := conf.GetAttestationConfig() - if err := cloudcmd.UpdateInitMeasurements(attConfig, ownerID, clusterID); err != nil { + if err := updateInitMeasurements(attConfig, ownerID, clusterID); err != nil { return fmt.Errorf("updating expected PCRs: %w", err) } c.log.Debugf("Creating aTLS Validator for %s", conf.GetAttestationConfig().GetVariant()) - validator, err := cloudcmd.NewValidator(cmd, attConfig, c.log) + validator, err := choose.Validator(attConfig, warnLogger{cmd: cmd, log: c.log}) if err != nil { return fmt.Errorf("creating aTLS validator: %w", err) } @@ -456,3 +460,83 @@ func addPortIfMissing(endpoint string, defaultPort int) (string, error) { return "", err } + +// UpdateInitMeasurements sets the owner and cluster measurement values in the attestation config depending on the +// attestation variant. +func updateInitMeasurements(config config.AttestationCfg, ownerID, clusterID string) error { + m := config.GetMeasurements() + + switch config.GetVariant() { + case variant.AWSNitroTPM{}, variant.AWSSEVSNP{}, variant.AzureTrustedLaunch{}, variant.AzureSEVSNP{}, variant.GCPSEVES{}, variant.QEMUVTPM{}: + if err := updateMeasurementTPM(m, uint32(measurements.PCRIndexOwnerID), ownerID); err != nil { + return err + } + return updateMeasurementTPM(m, uint32(measurements.PCRIndexClusterID), clusterID) + case variant.QEMUTDX{}: + // Measuring ownerID is currently not implemented for Constellation + // Since adding support for measuring ownerID to TDX would require additional code changes, + // the current implementation does not support it, but can be changed if we decide to add support in the future + return updateMeasurementTDX(m, uint32(measurements.TDXIndexClusterID), clusterID) + default: + return errors.New("selecting attestation variant: unknown attestation variant") + } +} + +// updateMeasurementTDX updates the TDX measurement value in the attestation config for the given measurement index. +func updateMeasurementTDX(m measurements.M, measurementIdx uint32, encoded string) error { + if encoded == "" { + delete(m, measurementIdx) + return nil + } + decoded, err := decodeMeasurement(encoded) + if err != nil { + return err + } + + // new_measurement_value := hash(old_measurement_value || data_to_extend) + // Since we use the DG.MR.RTMR.EXTEND call to extend the register, data_to_extend is the hash of our input + hashedInput := sha512.Sum384(decoded) + oldExpected := m[measurementIdx].Expected + expectedMeasurementSum := sha512.Sum384(append(oldExpected[:], hashedInput[:]...)) + m[measurementIdx] = measurements.Measurement{ + Expected: expectedMeasurementSum[:], + ValidationOpt: m[measurementIdx].ValidationOpt, + } + return nil +} + +// updateMeasurementTPM updates the TPM measurement value in the attestation config for the given measurement index. +func updateMeasurementTPM(m measurements.M, measurementIdx uint32, encoded string) error { + if encoded == "" { + delete(m, measurementIdx) + return nil + } + decoded, err := decodeMeasurement(encoded) + if err != nil { + return err + } + + // new_pcr_value := hash(old_pcr_value || data_to_extend) + // Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input + hashedInput := sha256.Sum256(decoded) + oldExpected := m[measurementIdx].Expected + expectedMeasurement := sha256.Sum256(append(oldExpected[:], hashedInput[:]...)) + m[measurementIdx] = measurements.Measurement{ + Expected: expectedMeasurement[:], + ValidationOpt: m[measurementIdx].ValidationOpt, + } + return nil +} + +// decodeMeasurement is a utility function that decodes the given string as hex or base64. +func decodeMeasurement(encoded string) ([]byte, error) { + decoded, err := hex.DecodeString(encoded) + if err != nil { + hexErr := err + decoded, err = base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("input [%s] could neither be hex decoded (%w) nor base64 decoded (%w)", encoded, hexErr, err) + } + } + return decoded, nil +} diff --git a/cli/internal/cmd/verify_test.go b/cli/internal/cmd/verify_test.go index 7896eb567..2719154e6 100644 --- a/cli/internal/cmd/verify_test.go +++ b/cli/internal/cmd/verify_test.go @@ -9,6 +9,8 @@ package cmd import ( "bytes" "context" + "crypto/sha256" + "crypto/sha512" "encoding/base64" "encoding/json" "errors" @@ -512,3 +514,210 @@ func TestParseQuotes(t *testing.T) { }) } } + +func TestValidatorUpdateInitPCRs(t *testing.T) { + zero := measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength) + one := measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength) + one64 := base64.StdEncoding.EncodeToString(one.Expected[:]) + oneHash := sha256.Sum256(one.Expected[:]) + pcrZeroUpdatedOne := sha256.Sum256(append(zero.Expected[:], oneHash[:]...)) + newTestPCRs := func() measurements.M { + return measurements.M{ + 0: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 1: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 2: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 3: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 4: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 5: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 6: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 7: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 8: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 9: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 10: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 11: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 12: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 13: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 14: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 15: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 16: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + 17: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), + 18: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), + 19: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), + 20: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), + 21: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), + 22: measurements.WithAllBytes(0x11, measurements.WarnOnly, measurements.PCRMeasurementLength), + 23: measurements.WithAllBytes(0x00, measurements.WarnOnly, measurements.PCRMeasurementLength), + } + } + + testCases := map[string]struct { + config config.AttestationCfg + ownerID string + clusterID string + wantErr bool + }{ + "gcp update owner ID": { + config: &config.GCPSEVES{ + Measurements: newTestPCRs(), + }, + ownerID: one64, + }, + "gcp update cluster ID": { + config: &config.GCPSEVES{ + Measurements: newTestPCRs(), + }, + clusterID: one64, + }, + "gcp update both": { + config: &config.GCPSEVES{ + Measurements: newTestPCRs(), + }, + ownerID: one64, + clusterID: one64, + }, + "azure update owner ID": { + config: &config.AzureSEVSNP{ + Measurements: newTestPCRs(), + }, + ownerID: one64, + }, + "azure update cluster ID": { + config: &config.AzureSEVSNP{ + Measurements: newTestPCRs(), + }, + clusterID: one64, + }, + "azure update both": { + config: &config.AzureSEVSNP{ + Measurements: newTestPCRs(), + }, + ownerID: one64, + clusterID: one64, + }, + "owner ID and cluster ID empty": { + config: &config.AzureSEVSNP{ + Measurements: newTestPCRs(), + }, + }, + "invalid encoding": { + config: &config.GCPSEVES{ + Measurements: newTestPCRs(), + }, + ownerID: "invalid", + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + err := updateInitMeasurements(tc.config, tc.ownerID, tc.clusterID) + + if tc.wantErr { + assert.Error(err) + return + } + require.NoError(t, err) + m := tc.config.GetMeasurements() + for i := 0; i < len(m); i++ { + switch { + case i == int(measurements.PCRIndexClusterID) && tc.clusterID == "": + // should be deleted + _, ok := m[uint32(i)] + assert.False(ok) + + case i == int(measurements.PCRIndexClusterID): + pcr, ok := m[uint32(i)] + assert.True(ok) + assert.Equal(pcrZeroUpdatedOne[:], pcr.Expected) + + case i == int(measurements.PCRIndexOwnerID) && tc.ownerID == "": + // should be deleted + _, ok := m[uint32(i)] + assert.False(ok) + + case i == int(measurements.PCRIndexOwnerID): + pcr, ok := m[uint32(i)] + assert.True(ok) + assert.Equal(pcrZeroUpdatedOne[:], pcr.Expected) + + default: + if i >= 17 && i <= 22 { + assert.Equal(one, m[uint32(i)]) + } else { + assert.Equal(zero, m[uint32(i)]) + } + } + } + }) + } +} + +func TestValidatorUpdateInitMeasurementsTDX(t *testing.T) { + zero := measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength) + one := measurements.WithAllBytes(0x11, true, measurements.TDXMeasurementLength) + one64 := base64.StdEncoding.EncodeToString(one.Expected[:]) + oneHash := sha512.Sum384(one.Expected[:]) + tdxZeroUpdatedOne := sha512.Sum384(append(zero.Expected[:], oneHash[:]...)) + newTestTDXMeasurements := func() measurements.M { + return measurements.M{ + 0: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), + 1: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), + 2: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), + 3: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), + 4: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), + } + } + + testCases := map[string]struct { + measurements measurements.M + clusterID string + wantErr bool + }{ + "QEMUT TDX update update cluster ID": { + measurements: newTestTDXMeasurements(), + clusterID: one64, + }, + "cluster ID empty": { + measurements: newTestTDXMeasurements(), + }, + "invalid encoding": { + measurements: newTestTDXMeasurements(), + clusterID: "invalid", + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + cfg := &config.QEMUTDX{Measurements: tc.measurements} + + err := updateInitMeasurements(cfg, "", tc.clusterID) + + if tc.wantErr { + assert.Error(err) + return + } + assert.NoError(err) + for i := 0; i < len(tc.measurements); i++ { + switch { + case i == measurements.TDXIndexClusterID && tc.clusterID == "": + // should be deleted + _, ok := cfg.Measurements[uint32(i)] + assert.False(ok) + + case i == measurements.TDXIndexClusterID: + pcr, ok := cfg.Measurements[uint32(i)] + assert.True(ok) + assert.Equal(tdxZeroUpdatedOne[:], pcr.Expected) + + default: + assert.Equal(zero, cfg.Measurements[uint32(i)]) + } + } + }) + } +}