From e130188ecd3f64dace712efd54d5a93d4e347cb0 Mon Sep 17 00:00:00 2001 From: Nils Hanke Date: Wed, 22 Mar 2023 14:56:51 +0100 Subject: [PATCH] cli: add verify support for TDX --- cli/internal/cloudcmd/validators.go | 73 ++++++++++++++++++------ cli/internal/cloudcmd/validators_test.go | 69 ++++++++++++++++++++++ 2 files changed, 124 insertions(+), 18 deletions(-) diff --git a/cli/internal/cloudcmd/validators.go b/cli/internal/cloudcmd/validators.go index 47dc12bb1..0450fed71 100644 --- a/cli/internal/cloudcmd/validators.go +++ b/cli/internal/cloudcmd/validators.go @@ -8,6 +8,7 @@ package cloudcmd import ( "crypto/sha256" + "crypto/sha512" "encoding/base64" "encoding/hex" "fmt" @@ -16,6 +17,7 @@ import ( "github.com/edgelesssys/constellation/v2/internal/attestation/choose" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/config" + "github.com/edgelesssys/constellation/v2/internal/variant" "github.com/spf13/cobra" ) @@ -27,33 +29,56 @@ func NewValidator(cmd *cobra.Command, config config.AttestationCfg, log debugLog // UpdateInitMeasurements sets the owner and cluster measurement values. func UpdateInitMeasurements(config config.AttestationCfg, ownerID, clusterID string) error { m := config.GetMeasurements() - if err := updateMeasurement(m, uint32(measurements.PCRIndexOwnerID), ownerID); err != nil { - return err + + switch config.GetVariant() { + case variant.AWSNitroTPM{}, variant.AzureTrustedLaunch{}, variant.AzureSEVSNP{}, variant.GCPSEVES{}, variant.QEMUVTPM{}: + if err := updateMeasurementTPM(m, uint32(measurements.PCRIndexOwnerID), ownerID); err != nil { + return err + } + return updateMeasurementTPM(m, uint32(measurements.PCRIndexClusterID), clusterID) + case variant.QEMUTDX{}: + // Measuring ownerID is currently not implemented for Constellation + // Since adding support for measuring ownerID to TDX would require additional code changes, + // the current implementation does not support it, but can be changed if we decide to add support in the future + return updateMeasurementTDX(m, uint32(measurements.TDXIndexClusterID), clusterID) + default: + return fmt.Errorf("UpdateInitMeasurements: unknown attestation variant") } - return updateMeasurement(m, uint32(measurements.PCRIndexClusterID), clusterID) } -// updateMeasurement adds a new entry to the measurements of v, or removes the key if the input is an empty string. -// -// When adding, the input is first decoded from hex or base64. -// We then calculate the expected measurement by hashing the input using SHA256, -// appending expected measurement for initialization, and then hashing once more. -func updateMeasurement(m measurements.M, measurementIdx uint32, encoded string) error { +func updateMeasurementTDX(m measurements.M, measurementIdx uint32, encoded string) error { if encoded == "" { delete(m, measurementIdx) return nil } - - // decode from hex or base64 - decoded, err := hex.DecodeString(encoded) + decoded, err := decodeMeasurement(encoded) if err != nil { - hexErr := err - decoded, err = base64.StdEncoding.DecodeString(encoded) - if err != nil { - return fmt.Errorf("input [%s] could neither be hex decoded (%w) nor base64 decoded (%w)", encoded, hexErr, err) - } + return err } - // new_measurement_value := hash(old_pcr_value || data_to_extend) + + // new_measurement_value := hash(old_measurement_value || data_to_extend) + // Since we use the DG.MR.RTMR.EXTEND call to extend the register, data_to_extend is the hash of our input + hashedInput := sha512.Sum384(decoded) + oldExpected := m[measurementIdx].Expected + expectedMeasurementSum := sha512.Sum384(append(oldExpected[:], hashedInput[:]...)) + m[measurementIdx] = measurements.Measurement{ + Expected: expectedMeasurementSum[:], + ValidationOpt: m[measurementIdx].ValidationOpt, + } + return nil +} + +func updateMeasurementTPM(m measurements.M, measurementIdx uint32, encoded string) error { + if encoded == "" { + delete(m, measurementIdx) + return nil + } + decoded, err := decodeMeasurement(encoded) + if err != nil { + return err + } + + // new_pcr_value := hash(old_pcr_value || data_to_extend) // Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input hashedInput := sha256.Sum256(decoded) oldExpected := m[measurementIdx].Expected @@ -65,6 +90,18 @@ func updateMeasurement(m measurements.M, measurementIdx uint32, encoded string) return nil } +func decodeMeasurement(encoded string) ([]byte, error) { + decoded, err := hex.DecodeString(encoded) + if err != nil { + hexErr := err + decoded, err = base64.StdEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("input [%s] could neither be hex decoded (%w) nor base64 decoded (%w)", encoded, hexErr, err) + } + } + return decoded, nil +} + // warnLogger implements logging of warnings for validators. type warnLogger struct { cmd *cobra.Command diff --git a/cli/internal/cloudcmd/validators_test.go b/cli/internal/cloudcmd/validators_test.go index 62b4a297e..980a8326b 100644 --- a/cli/internal/cloudcmd/validators_test.go +++ b/cli/internal/cloudcmd/validators_test.go @@ -8,6 +8,7 @@ package cloudcmd import ( "crypto/sha256" + "crypto/sha512" "encoding/base64" "testing" @@ -155,3 +156,71 @@ func TestValidatorUpdateInitPCRs(t *testing.T) { }) } } + +func TestValidatorUpdateInitMeasurementsTDX(t *testing.T) { + zero := measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength) + one := measurements.WithAllBytes(0x11, true, measurements.TDXMeasurementLength) + one64 := base64.StdEncoding.EncodeToString(one.Expected[:]) + oneHash := sha512.Sum384(one.Expected[:]) + tdxZeroUpdatedOne := sha512.Sum384(append(zero.Expected[:], oneHash[:]...)) + newTestTDXMeasurements := func() measurements.M { + return measurements.M{ + 0: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), + 1: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), + 2: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), + 3: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), + 4: measurements.WithAllBytes(0x00, true, measurements.TDXMeasurementLength), + } + } + + testCases := map[string]struct { + measurements measurements.M + clusterID string + wantErr bool + }{ + "QEMUT TDX update update cluster ID": { + measurements: newTestTDXMeasurements(), + clusterID: one64, + }, + "cluster ID empty": { + measurements: newTestTDXMeasurements(), + }, + "invalid encoding": { + measurements: newTestTDXMeasurements(), + clusterID: "invalid", + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + cfg := &config.QEMUTDX{Measurements: tc.measurements} + + err := UpdateInitMeasurements(cfg, "", tc.clusterID) + + if tc.wantErr { + assert.Error(err) + return + } + assert.NoError(err) + for i := 0; i < len(tc.measurements); i++ { + switch { + case i == measurements.TDXIndexClusterID && tc.clusterID == "": + // should be deleted + _, ok := cfg.Measurements[uint32(i)] + assert.False(ok) + + case i == measurements.TDXIndexClusterID: + pcr, ok := cfg.Measurements[uint32(i)] + assert.True(ok) + assert.Equal(tdxZeroUpdatedOne[:], pcr.Expected) + + default: + assert.Equal(zero, cfg.Measurements[uint32(i)]) + } + } + }) + } +}