From f8001efbc0d00d5a89ba59ce48e6a34453831c14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Wei=C3=9Fe?= <66256922+daniel-weisse@users.noreply.github.com> Date: Thu, 24 Nov 2022 10:57:58 +0100 Subject: [PATCH] Refactor enforced/expected PCRs (#553) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Merge enforced and expected measurements * Update measurement generation to new format * Write expected measurements hex encoded by default * Allow hex or base64 encoded expected measurements * Allow hex or base64 encoded clusterID * Allow security upgrades to warnOnly flag * Upload signed measurements in JSON format * Fetch measurements either from JSON or YAML * Use yaml.v3 instead of yaml.v2 * Error on invalid enforced selection * Add placeholder measurements to config * Update e2e test to new measurement format Signed-off-by: Daniel Weiße --- .../actions/constellation_create/action.yml | 6 +- .../actions/constellation_measure/action.yml | 41 +- bootstrapper/cmd/bootstrapper/main.go | 33 +- bootstrapper/initproto/init.proto | 6 +- .../internal/kubernetes/kubernetes.go | 59 +- cli/internal/cloudcmd/upgrade.go | 7 + cli/internal/cloudcmd/upgrade_test.go | 38 +- cli/internal/cloudcmd/validators.go | 93 ++- cli/internal/cloudcmd/validators_test.go | 151 ++--- cli/internal/cmd/configfetchmeasurements.go | 4 +- .../cmd/configfetchmeasurements_test.go | 16 +- cli/internal/cmd/init.go | 8 +- cli/internal/cmd/init_test.go | 68 +-- cli/internal/cmd/upgradeplan.go | 4 +- cli/internal/cmd/upgradeplan_test.go | 10 +- .../join-service/templates/configmap.yaml | 1 - .../charts/join-service/values.schema.json | 8 +- cli/internal/helm/loader.go | 16 +- cli/internal/helm/loader_test.go | 46 +- .../join-service/templates/configmap.yaml | 3 +- .../join-service/templates/configmap.yaml | 3 +- .../join-service/templates/configmap.yaml | 3 +- disk-mapper/cmd/main.go | 2 +- go.mod | 2 +- hack/go.mod | 2 +- hack/pcr-reader/main.go | 63 +- hack/pcr-reader/main_test.go | 97 +--- hack/qemu-metadata-api/server/server_test.go | 8 +- internal/attestation/aws/validator.go | 3 +- internal/attestation/azure/snp/validator.go | 3 +- .../azure/trustedlaunch/trustedlaunch_test.go | 2 +- .../azure/trustedlaunch/validator.go | 3 +- internal/attestation/gcp/validator.go | 3 +- .../attestation/measurements/measurements.go | 304 +++++++--- .../measurements/measurements_test.go | 538 +++++++++++++----- internal/attestation/qemu/validator.go | 3 +- internal/attestation/vtpm/attestation.go | 37 +- internal/attestation/vtpm/attestation_test.go | 71 +-- internal/attestation/vtpm/initialize.go | 16 +- internal/attestation/vtpm/initialize_test.go | 5 +- internal/config/config.go | 44 +- internal/config/config_doc.go | 28 +- internal/config/config_test.go | 16 +- internal/config/validation.go | 34 ++ internal/watcher/validator.go | 42 +- internal/watcher/validator_test.go | 31 +- 46 files changed, 1180 insertions(+), 801 deletions(-) diff --git a/.github/actions/constellation_create/action.yml b/.github/actions/constellation_create/action.yml index 8816e43fb..221e8f5ce 100644 --- a/.github/actions/constellation_create/action.yml +++ b/.github/actions/constellation_create/action.yml @@ -75,14 +75,14 @@ runs: (.provider | select(. | has(\"azure\")).azure.resourceGroup) = \"${{ inputs.azureResourceGroup }}\" | (.provider | select(. | has(\"azure\")).azure.appClientID) = \"${{ inputs.azureClientID }}\" | (.provider | select(. | has(\"azure\")).azure.clientSecretValue) = \"${{ inputs.azureClientSecret }}\" | - (.provider | select(. | has(\"azure\")).azure.enforcedMeasurements) = [15]" \ + (.provider | select(. | has(\"azure\")).azure.measurements) = {15:{\"expected\":\"0000000000000000000000000000000000000000000000000000000000000000\",\"warnOnly\":false}}" \ constellation-conf.yaml yq eval -i \ "(.provider | select(. | has(\"gcp\")).gcp.project) = \"${{ inputs.gcpProject }}\" | (.provider | select(. | has(\"gcp\")).gcp.region) = \"europe-west3\" | (.provider | select(. | has(\"gcp\")).gcp.zone) = \"europe-west3-b\" | - (.provider | select(. | has(\"gcp\")).gcp.enforcedMeasurements) = [15] | + (.provider | select(. | has(\"gcp\")).gcp.measurements) = {15:{\"expected\":\"0000000000000000000000000000000000000000000000000000000000000000\",\"warnOnly\":false}} | (.provider | select(. | has(\"gcp\")).gcp.serviceAccountKeyPath) = \"serviceAccountKey.json\"" \ constellation-conf.yaml @@ -91,7 +91,7 @@ runs: (.provider | select(. | has(\"aws\")).aws.zone) = \"eu-central-1a\" | (.provider | select(. | has(\"aws\")).aws.iamProfileControlPlane) = \"e2e_test_control_plane_instance_profile\" | (.provider | select(. | has(\"aws\")).aws.iamProfileWorkerNodes) = \"e2e_test_worker_node_instance_profile\" | - (.provider | select(. | has(\"aws\")).aws.enforcedMeasurements) = [15]" \ + (.provider | select(. | has(\"aws\")).aws.measurements) = {15:{\"expected\":\"0000000000000000000000000000000000000000000000000000000000000000\",\"warnOnly\":false}}" \ constellation-conf.yaml if [ ${{ inputs.kubernetesVersion != '' }} = true ]; then diff --git a/.github/actions/constellation_measure/action.yml b/.github/actions/constellation_measure/action.yml index 8716e7cdd..f703b6af0 100644 --- a/.github/actions/constellation_measure/action.yml +++ b/.github/actions/constellation_measure/action.yml @@ -51,16 +51,35 @@ runs: run: | KUBECONFIG="$PWD/constellation-admin.conf" kubectl rollout status ds/verification-service -n kube-system --timeout=3m CONSTELL_IP=$(jq -r ".ip" constellation-id.json) - pcr-reader --constell-ip ${CONSTELL_IP} -format yaml > measurements.yaml + pcr-reader --constell-ip ${CONSTELL_IP} -format json > measurements.json case $CSP in azure) - yq e 'del(.[0,6,10,11,12,13,14,15,16,17,18,19,20,21,22,23])' -i measurements.yaml + yq e 'del(.[0,6,10,16,17,18,19,20,21,22,23])' -I 0 -o json -i measurements.json + yq '.4.warnOnly = false | + .8.warnOnly = false | + .9.warnOnly = false | + .11.warnOnly = false | + .12.warnOnly = false | + .13.warnOnly = false | + .15.warnOnly = false | + .15.expected = "0000000000000000000000000000000000000000000000000000000000000000"' \ + -I 0 -o json -i measurements.json ;; gcp) - yq e 'del(.[11,12,13,14,15,16,17,18,19,20,21,22,23])' -i measurements.yaml + yq e 'del(.[16,17,18,19,20,21,22,23])' -I 0 -o json -i measurements.json + yq '.0.warnOnly = false | + .4.warnOnly = false | + .8.warnOnly = false | + .9.warnOnly = false | + .11.warnOnly = false | + .12.warnOnly = false | + .13.warnOnly = false | + .15.warnOnly = false | + .15.expected = "0000000000000000000000000000000000000000000000000000000000000000"' \ + -I 0 -o json -i measurements.json ;; esac - cat measurements.yaml + cat measurements.json shell: bash env: CSP: ${{ inputs.cloudProvider }} @@ -81,14 +100,14 @@ runs: run: | echo "$COSIGN_PUBLIC_KEY" > cosign.pub # Enabling experimental mode also publishes signature to Rekor - COSIGN_EXPERIMENTAL=1 cosign sign-blob --key env://COSIGN_PRIVATE_KEY measurements.yaml > measurements.yaml.sig + COSIGN_EXPERIMENTAL=1 cosign sign-blob --key env://COSIGN_PRIVATE_KEY measurements.json > measurements.json.sig # Verify - As documentation & check # Local Signature (input: artifact, key, signature) - cosign verify-blob --key cosign.pub --signature measurements.yaml.sig measurements.yaml + cosign verify-blob --key cosign.pub --signature measurements.json.sig measurements.json # Transparency Log Signature (input: artifact, key) - uuid=$(rekor-cli search --artifact measurements.yaml | tail -n 1) + uuid=$(rekor-cli search --artifact measurements.json | tail -n 1) sig=$(rekor-cli get --uuid=$uuid --format=json | jq -r .Body.HashedRekordObj.signature.content) - cosign verify-blob --key cosign.pub --signature <(echo $sig) measurements.yaml + cosign verify-blob --key cosign.pub --signature <(echo $sig) measurements.json shell: bash env: COSIGN_PUBLIC_KEY: ${{ inputs.cosignPublicKey }} @@ -100,9 +119,9 @@ runs: run: | IMAGE=$(yq e ".provider.${CSP}.image" constellation-conf.yaml) S3_PATH=s3://${PUBLIC_BUCKET_NAME}/${IMAGE,,} - aws s3 cp measurements.yaml ${S3_PATH}/measurements.yaml - if test -f measurements.yaml.sig; then - aws s3 cp measurements.yaml.sig ${S3_PATH}/measurements.yaml.sig + aws s3 cp measurements.json ${S3_PATH}/measurements.json + if test -f measurements.json.sig; then + aws s3 cp measurements.json.sig ${S3_PATH}/measurements.json.sig fi shell: bash env: diff --git a/bootstrapper/cmd/bootstrapper/main.go b/bootstrapper/cmd/bootstrapper/main.go index 877556422..080c5ebba 100644 --- a/bootstrapper/cmd/bootstrapper/main.go +++ b/bootstrapper/cmd/bootstrapper/main.go @@ -8,7 +8,6 @@ package main import ( "context" - "encoding/json" "flag" "io" "os" @@ -80,14 +79,10 @@ func main() { switch cloudprovider.FromString(os.Getenv(constellationCSP)) { case cloudprovider.AWS: - pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, vtpm.AWSPCRSelection) + measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.AWSPCRSelection) if err != nil { log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") } - pcrsJSON, err := json.Marshal(pcrs) - if err != nil { - log.With(zap.Error(err)).Fatalf("Failed to marshal PCRs") - } issuer = initserver.NewIssuerWrapper(aws.NewIssuer(), vmtype.Unknown, nil) @@ -104,13 +99,13 @@ func main() { clusterInitJoiner = kubernetes.New( "aws", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(), - metadata, pcrsJSON, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, + metadata, measurements, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, ) openTPM = vtpm.OpenVTPM fs = afero.NewOsFs() case cloudprovider.GCP: - pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, vtpm.GCPPCRSelection) + measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.GCPPCRSelection) if err != nil { log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") } @@ -129,20 +124,16 @@ func main() { } metadataAPI = metadata - pcrsJSON, err := json.Marshal(pcrs) - if err != nil { - log.With(zap.Error(err)).Fatalf("Failed to marshal PCRs") - } clusterInitJoiner = kubernetes.New( "gcp", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(), - metadata, pcrsJSON, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, + metadata, measurements, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, ) openTPM = vtpm.OpenVTPM fs = afero.NewOsFs() log.Infof("Added load balancer IP to routing table") case cloudprovider.Azure: - pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, vtpm.AzurePCRSelection) + measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.AzurePCRSelection) if err != nil { log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") } @@ -163,20 +154,16 @@ func main() { log.With(zap.Error(err)).Fatalf("Failed to set up cloud logger") } metadataAPI = metadata - pcrsJSON, err := json.Marshal(pcrs) - if err != nil { - log.With(zap.Error(err)).Fatalf("Failed to marshal PCRs") - } clusterInitJoiner = kubernetes.New( "azure", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(), - metadata, pcrsJSON, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, + metadata, measurements, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, ) openTPM = vtpm.OpenVTPM fs = afero.NewOsFs() case cloudprovider.QEMU: - pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, vtpm.QEMUPCRSelection) + measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.QEMUPCRSelection) if err != nil { log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") } @@ -185,13 +172,9 @@ func main() { cloudLogger = qemucloud.NewLogger() metadata := qemucloud.New() - pcrsJSON, err := json.Marshal(pcrs) - if err != nil { - log.With(zap.Error(err)).Fatalf("Failed to marshal PCRs") - } clusterInitJoiner = kubernetes.New( "qemu", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(), - metadata, pcrsJSON, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, + metadata, measurements, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, ) metadataAPI = metadata diff --git a/bootstrapper/initproto/init.proto b/bootstrapper/initproto/init.proto index 39739c21e..e916b463b 100644 --- a/bootstrapper/initproto/init.proto +++ b/bootstrapper/initproto/init.proto @@ -27,9 +27,9 @@ message InitRequest { } message InitResponse { - bytes kubeconfig = 1; - bytes owner_id = 2; - bytes cluster_id = 3; + bytes kubeconfig = 1; + bytes owner_id = 2; + bytes cluster_id = 3; } message KubernetesComponent { diff --git a/bootstrapper/internal/kubernetes/kubernetes.go b/bootstrapper/internal/kubernetes/kubernetes.go index 77cacff72..4f4aefa99 100644 --- a/bootstrapper/internal/kubernetes/kubernetes.go +++ b/bootstrapper/internal/kubernetes/kubernetes.go @@ -20,6 +20,7 @@ import ( "github.com/edgelesssys/constellation/v2/bootstrapper/internal/kubernetes/k8sapi" kubewaiter "github.com/edgelesssys/constellation/v2/bootstrapper/internal/kubernetes/kubeWaiter" + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/cloud/azureshared" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared" @@ -52,33 +53,33 @@ type kubeAPIWaiter interface { // KubeWrapper implements Cluster interface. type KubeWrapper struct { - cloudProvider string - clusterUtil clusterUtil - helmClient helmClient - kubeAPIWaiter kubeAPIWaiter - configProvider configurationProvider - client k8sapi.Client - kubeconfigReader configReader - providerMetadata ProviderMetadata - initialMeasurementsJSON []byte - getIPAddr func() (string, error) + cloudProvider string + clusterUtil clusterUtil + helmClient helmClient + kubeAPIWaiter kubeAPIWaiter + configProvider configurationProvider + client k8sapi.Client + kubeconfigReader configReader + providerMetadata ProviderMetadata + initialMeasurements measurements.M + getIPAddr func() (string, error) } // New creates a new KubeWrapper with real values. func New(cloudProvider string, clusterUtil clusterUtil, configProvider configurationProvider, client k8sapi.Client, - providerMetadata ProviderMetadata, initialMeasurementsJSON []byte, helmClient helmClient, kubeAPIWaiter kubeAPIWaiter, + providerMetadata ProviderMetadata, measurements measurements.M, helmClient helmClient, kubeAPIWaiter kubeAPIWaiter, ) *KubeWrapper { return &KubeWrapper{ - cloudProvider: cloudProvider, - clusterUtil: clusterUtil, - helmClient: helmClient, - kubeAPIWaiter: kubeAPIWaiter, - configProvider: configProvider, - client: client, - kubeconfigReader: &KubeconfigReader{fs: afero.Afero{Fs: afero.NewOsFs()}}, - providerMetadata: providerMetadata, - initialMeasurementsJSON: initialMeasurementsJSON, - getIPAddr: getIPAddr, + cloudProvider: cloudProvider, + clusterUtil: clusterUtil, + helmClient: helmClient, + kubeAPIWaiter: kubeAPIWaiter, + configProvider: configProvider, + client: client, + kubeconfigReader: &KubeconfigReader{fs: afero.Afero{Fs: afero.NewOsFs()}}, + providerMetadata: providerMetadata, + initialMeasurements: measurements, + getIPAddr: getIPAddr, } } @@ -187,7 +188,21 @@ func (k *KubeWrapper) InitCluster( } else { controlPlaneIP = controlPlaneEndpoint } - serviceConfig := constellationServicesConfig{k.initialMeasurementsJSON, idKeyDigest, measurementSalt, subnetworkPodCIDR, cloudServiceAccountURI, controlPlaneIP} + if err := k.initialMeasurements.SetEnforced(enforcedPCRs); err != nil { + return nil, err + } + measurementsJSON, err := json.Marshal(k.initialMeasurements) + if err != nil { + return nil, fmt.Errorf("marshaling initial measurements: %w", err) + } + serviceConfig := constellationServicesConfig{ + initialMeasurementsJSON: measurementsJSON, + idkeydigest: idKeyDigest, + measurementSalt: measurementSalt, + subnetworkPodCIDR: subnetworkPodCIDR, + cloudServiceAccountURI: cloudServiceAccountURI, + loadBalancerIP: controlPlaneIP, + } extraVals, err := k.setupExtraVals(ctx, serviceConfig) if err != nil { return nil, fmt.Errorf("setting up extraVals: %w", err) diff --git a/cli/internal/cloudcmd/upgrade.go b/cli/internal/cloudcmd/upgrade.go index f6afc1101..9ac998040 100644 --- a/cli/internal/cloudcmd/upgrade.go +++ b/cli/internal/cloudcmd/upgrade.go @@ -112,6 +112,13 @@ func (u *Upgrader) updateMeasurements(ctx context.Context, newMeasurements measu return nil } + // don't allow potential security downgrades by setting the warnOnly flag to true + for k, newM := range newMeasurements { + if currentM, ok := currentMeasurements[k]; ok && !currentM.WarnOnly && newM.WarnOnly { + return fmt.Errorf("setting enforced measurement %d to warn only: not allowed", k) + } + } + // backup of previous measurements existingConf.Data["oldMeasurements"] = existingConf.Data[constants.MeasurementsFilename] diff --git a/cli/internal/cloudcmd/upgrade_test.go b/cli/internal/cloudcmd/upgrade_test.go index 905ee1f5b..13469ad88 100644 --- a/cli/internal/cloudcmd/upgrade_test.go +++ b/cli/internal/cloudcmd/upgrade_test.go @@ -33,12 +33,12 @@ func TestUpdateMeasurements(t *testing.T) { updater: &stubMeasurementsUpdater{ oldMeasurements: &corev1.ConfigMap{ Data: map[string]string{ - constants.MeasurementsFilename: `{"0":"AAAAAA=="}`, + constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`, }, }, }, newMeasurements: measurements.M{ - 0: []byte("1"), + 0: measurements.WithAllBytes(0xBB, false), }, wantUpdate: true, }, @@ -46,14 +46,40 @@ func TestUpdateMeasurements(t *testing.T) { updater: &stubMeasurementsUpdater{ oldMeasurements: &corev1.ConfigMap{ Data: map[string]string{ - constants.MeasurementsFilename: `{"0":"MQ=="}`, + constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`, }, }, }, newMeasurements: measurements.M{ - 0: []byte("1"), + 0: measurements.WithAllBytes(0xAA, false), }, }, + "trying to set warnOnly to true results in error": { + updater: &stubMeasurementsUpdater{ + oldMeasurements: &corev1.ConfigMap{ + Data: map[string]string{ + constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`, + }, + }, + }, + newMeasurements: measurements.M{ + 0: measurements.WithAllBytes(0xAA, true), + }, + wantErr: true, + }, + "setting warnOnly to false is allowed": { + updater: &stubMeasurementsUpdater{ + oldMeasurements: &corev1.ConfigMap{ + Data: map[string]string{ + constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":true}}`, + }, + }, + }, + newMeasurements: measurements.M{ + 0: measurements.WithAllBytes(0xAA, false), + }, + wantUpdate: true, + }, "getCurrent error": { updater: &stubMeasurementsUpdater{getErr: someErr}, wantErr: true, @@ -62,7 +88,7 @@ func TestUpdateMeasurements(t *testing.T) { updater: &stubMeasurementsUpdater{ oldMeasurements: &corev1.ConfigMap{ Data: map[string]string{ - constants.MeasurementsFilename: `{"0":"AAAAAA=="}`, + constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`, }, }, updateErr: someErr, @@ -82,7 +108,7 @@ func TestUpdateMeasurements(t *testing.T) { err := upgrader.updateMeasurements(context.Background(), tc.newMeasurements) if tc.wantErr { - assert.ErrorIs(err, someErr) + assert.Error(err) return } diff --git a/cli/internal/cloudcmd/validators.go b/cli/internal/cloudcmd/validators.go index 0019f1a6b..38b7fbc60 100644 --- a/cli/internal/cloudcmd/validators.go +++ b/cli/internal/cloudcmd/validators.go @@ -20,17 +20,16 @@ import ( "github.com/edgelesssys/constellation/v2/internal/attestation/gcp" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/qemu" - "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/spf13/cobra" + "go.uber.org/multierr" ) // Validator validates Platform Configuration Registers (PCRs). type Validator struct { provider cloudprovider.Provider pcrs measurements.M - enforcedPCRs []uint32 idkeydigest []byte enforceIDKeyDigest bool azureCVM bool @@ -65,41 +64,44 @@ func NewValidator(provider cloudprovider.Provider, conf *config.Config) (*Valida // UpdateInitPCRs sets the owner and cluster PCR values. func (v *Validator) UpdateInitPCRs(ownerID, clusterID string) error { - if err := v.updatePCR(uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil { + if err := v.updatePCR(uint32(measurements.PCRIndexOwnerID), ownerID); err != nil { return err } - return v.updatePCR(uint32(vtpm.PCRIndexClusterID), clusterID) + return v.updatePCR(uint32(measurements.PCRIndexClusterID), clusterID) } -// updatePCR adds a new entry to the pcr map of v, or removes the key if the input is an empty string. +// updatePCR adds a new entry to the measurements of v, or removes the key if the input is an empty string. // -// When adding, the input is first decoded from base64. +// When adding, the input is first decoded from hex or base64. // We then calculate the expected PCR by hashing the input using SHA256, // appending expected PCR for initialization, and then hashing once more. func (v *Validator) updatePCR(pcrIndex uint32, encoded string) error { if encoded == "" { delete(v.pcrs, pcrIndex) - - // remove enforced PCR if it exists - for i, enforcedIdx := range v.enforcedPCRs { - if enforcedIdx == pcrIndex { - v.enforcedPCRs[i] = v.enforcedPCRs[len(v.enforcedPCRs)-1] - v.enforcedPCRs = v.enforcedPCRs[:len(v.enforcedPCRs)-1] - break - } - } - return nil } - decoded, err := base64.StdEncoding.DecodeString(encoded) + + // decode from hex or base64 + decoded, err := hex.DecodeString(encoded) if err != nil { - return fmt.Errorf("input [%s] is not base64 encoded: %w", encoded, err) + hexErr := err + decoded, err = base64.StdEncoding.DecodeString(encoded) + if err != nil { + return multierr.Append( + fmt.Errorf("input [%s] is not hex encoded: %w", encoded, hexErr), + fmt.Errorf("input [%s] is not base64 encoded: %w", encoded, err), + ) + } } // new_pcr_value := hash(old_pcr_value || data_to_extend) // Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input hashedInput := sha256.Sum256(decoded) - expectedPcr := sha256.Sum256(append(v.pcrs[pcrIndex], hashedInput[:]...)) - v.pcrs[pcrIndex] = expectedPcr[:] + oldExpected := v.pcrs[pcrIndex].Expected + expectedPcr := sha256.Sum256(append(oldExpected[:], hashedInput[:]...)) + v.pcrs[pcrIndex] = measurements.Measurement{ + Expected: expectedPcr, + WarnOnly: v.pcrs[pcrIndex].WarnOnly, + } return nil } @@ -107,35 +109,27 @@ func (v *Validator) setPCRs(config *config.Config) error { switch v.provider { case cloudprovider.AWS: awsPCRs := config.Provider.AWS.Measurements - enforcedPCRs := config.Provider.AWS.EnforcedMeasurements - if err := v.checkPCRs(awsPCRs, enforcedPCRs); err != nil { - return err + if len(awsPCRs) == 0 { + return errors.New("no expected measurement provided") } - v.enforcedPCRs = enforcedPCRs v.pcrs = awsPCRs case cloudprovider.Azure: azurePCRs := config.Provider.Azure.Measurements - enforcedPCRs := config.Provider.Azure.EnforcedMeasurements - if err := v.checkPCRs(azurePCRs, enforcedPCRs); err != nil { - return err + if len(azurePCRs) == 0 { + return errors.New("no expected measurement provided") } - v.enforcedPCRs = enforcedPCRs v.pcrs = azurePCRs case cloudprovider.GCP: gcpPCRs := config.Provider.GCP.Measurements - enforcedPCRs := config.Provider.GCP.EnforcedMeasurements - if err := v.checkPCRs(gcpPCRs, enforcedPCRs); err != nil { - return err + if len(gcpPCRs) == 0 { + return errors.New("no expected measurement provided") } - v.enforcedPCRs = enforcedPCRs v.pcrs = gcpPCRs case cloudprovider.QEMU: qemuPCRs := config.Provider.QEMU.Measurements - enforcedPCRs := config.Provider.QEMU.EnforcedMeasurements - if err := v.checkPCRs(qemuPCRs, enforcedPCRs); err != nil { - return err + if len(qemuPCRs) == 0 { + return errors.New("no expected measurement provided") } - v.enforcedPCRs = enforcedPCRs v.pcrs = qemuPCRs } return nil @@ -156,37 +150,20 @@ func (v *Validator) updateValidator(cmd *cobra.Command) { log := warnLogger{cmd: cmd} switch v.provider { case cloudprovider.GCP: - v.validator = gcp.NewValidator(v.pcrs, v.enforcedPCRs, log) + v.validator = gcp.NewValidator(v.pcrs, log) case cloudprovider.Azure: if v.azureCVM { - v.validator = snp.NewValidator(v.pcrs, v.enforcedPCRs, v.idkeydigest, v.enforceIDKeyDigest, log) + v.validator = snp.NewValidator(v.pcrs, v.idkeydigest, v.enforceIDKeyDigest, log) } else { - v.validator = trustedlaunch.NewValidator(v.pcrs, v.enforcedPCRs, log) + v.validator = trustedlaunch.NewValidator(v.pcrs, log) } case cloudprovider.AWS: - v.validator = aws.NewValidator(v.pcrs, v.enforcedPCRs, log) + v.validator = aws.NewValidator(v.pcrs, log) case cloudprovider.QEMU: - v.validator = qemu.NewValidator(v.pcrs, v.enforcedPCRs, log) + v.validator = qemu.NewValidator(v.pcrs, log) } } -func (v *Validator) checkPCRs(pcrs measurements.M, enforcedPCRs []uint32) error { - if len(pcrs) == 0 { - return errors.New("no PCR values provided") - } - for k, v := range pcrs { - if len(v) != 32 { - return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v)) - } - } - for _, v := range enforcedPCRs { - if _, ok := pcrs[v]; !ok { - return fmt.Errorf("bad config: PCR[%d] is enforced, but no expected measurement is provided", v) - } - } - return nil -} - // warnLogger implements logging of warnings for validators. type warnLogger struct { cmd *cobra.Command diff --git a/cli/internal/cloudcmd/validators_test.go b/cli/internal/cloudcmd/validators_test.go index 654c684bb..978266514 100644 --- a/cli/internal/cloudcmd/validators_test.go +++ b/cli/internal/cloudcmd/validators_test.go @@ -7,9 +7,9 @@ SPDX-License-Identifier: AGPL-3.0-only package cloudcmd import ( - "bytes" "crypto/sha256" "encoding/base64" + "encoding/hex" "testing" "github.com/edgelesssys/constellation/v2/internal/atls" @@ -18,7 +18,6 @@ import ( "github.com/edgelesssys/constellation/v2/internal/attestation/gcp" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/qemu" - "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/spf13/cobra" @@ -27,12 +26,12 @@ import ( func TestNewValidator(t *testing.T) { testPCRs := measurements.M{ - 0: measurements.PCRWithAllBytes(0x00), - 1: measurements.PCRWithAllBytes(0xFF), - 2: measurements.PCRWithAllBytes(0x00), - 3: measurements.PCRWithAllBytes(0xFF), - 4: measurements.PCRWithAllBytes(0x00), - 5: measurements.PCRWithAllBytes(0x00), + 0: measurements.WithAllBytes(0x00, false), + 1: measurements.WithAllBytes(0xFF, false), + 2: measurements.WithAllBytes(0x00, false), + 3: measurements.WithAllBytes(0xFF, false), + 4: measurements.WithAllBytes(0x00, false), + 5: measurements.WithAllBytes(0x00, false), } testCases := map[string]struct { @@ -67,13 +66,6 @@ func TestNewValidator(t *testing.T) { pcrs: measurements.M{}, wantErr: true, }, - "invalid pcr length": { - provider: cloudprovider.GCP, - pcrs: measurements.M{ - 0: bytes.Repeat([]byte{0x00}, 31), - }, - wantErr: true, - }, "unknown provider": { provider: cloudprovider.Unknown, pcrs: testPCRs, @@ -126,19 +118,19 @@ func TestNewValidator(t *testing.T) { func TestValidatorV(t *testing.T) { newTestPCRs := func() measurements.M { return measurements.M{ - 0: measurements.PCRWithAllBytes(0x00), - 1: measurements.PCRWithAllBytes(0x00), - 2: measurements.PCRWithAllBytes(0x00), - 3: measurements.PCRWithAllBytes(0x00), - 4: measurements.PCRWithAllBytes(0x00), - 5: measurements.PCRWithAllBytes(0x00), - 6: measurements.PCRWithAllBytes(0x00), - 7: measurements.PCRWithAllBytes(0x00), - 8: measurements.PCRWithAllBytes(0x00), - 9: measurements.PCRWithAllBytes(0x00), - 10: measurements.PCRWithAllBytes(0x00), - 11: measurements.PCRWithAllBytes(0x00), - 12: measurements.PCRWithAllBytes(0x00), + 0: measurements.WithAllBytes(0x00, true), + 1: measurements.WithAllBytes(0x00, true), + 2: measurements.WithAllBytes(0x00, true), + 3: measurements.WithAllBytes(0x00, true), + 4: measurements.WithAllBytes(0x00, true), + 5: measurements.WithAllBytes(0x00, true), + 6: measurements.WithAllBytes(0x00, true), + 7: measurements.WithAllBytes(0x00, true), + 8: measurements.WithAllBytes(0x00, true), + 9: measurements.WithAllBytes(0x00, true), + 10: measurements.WithAllBytes(0x00, true), + 11: measurements.WithAllBytes(0x00, true), + 12: measurements.WithAllBytes(0x00, true), } } @@ -151,23 +143,23 @@ func TestValidatorV(t *testing.T) { "gcp": { provider: cloudprovider.GCP, pcrs: newTestPCRs(), - wantVs: gcp.NewValidator(newTestPCRs(), nil, nil), + wantVs: gcp.NewValidator(newTestPCRs(), nil), }, "azure cvm": { provider: cloudprovider.Azure, pcrs: newTestPCRs(), - wantVs: snp.NewValidator(newTestPCRs(), nil, nil, false, nil), + wantVs: snp.NewValidator(newTestPCRs(), nil, false, nil), azureCVM: true, }, "azure trusted launch": { provider: cloudprovider.Azure, pcrs: newTestPCRs(), - wantVs: trustedlaunch.NewValidator(newTestPCRs(), nil, nil), + wantVs: trustedlaunch.NewValidator(newTestPCRs(), nil), }, "qemu": { provider: cloudprovider.QEMU, pcrs: newTestPCRs(), - wantVs: qemu.NewValidator(newTestPCRs(), nil, nil), + wantVs: qemu.NewValidator(newTestPCRs(), nil), }, } @@ -185,37 +177,37 @@ func TestValidatorV(t *testing.T) { } func TestValidatorUpdateInitPCRs(t *testing.T) { - zero := []byte("00000000000000000000000000000000") - one := []byte("11111111111111111111111111111111") - one64 := base64.StdEncoding.EncodeToString(one) - oneHash := sha256.Sum256(one) - pcrZeroUpdatedOne := sha256.Sum256(append(zero, oneHash[:]...)) - newTestPCRs := func() map[uint32][]byte { - return map[uint32][]byte{ - 0: zero, - 1: zero, - 2: zero, - 3: zero, - 4: zero, - 5: zero, - 6: zero, - 7: zero, - 8: zero, - 9: zero, - 10: zero, - 11: zero, - 12: zero, - 13: zero, - 14: zero, - 15: zero, - 16: zero, - 17: one, - 18: one, - 19: one, - 20: one, - 21: one, - 22: one, - 23: zero, + zero := measurements.WithAllBytes(0x00, true) + one := measurements.WithAllBytes(0x11, true) + one64 := base64.StdEncoding.EncodeToString(one.Expected[:]) + oneHash := sha256.Sum256(one.Expected[:]) + pcrZeroUpdatedOne := sha256.Sum256(append(zero.Expected[:], oneHash[:]...)) + newTestPCRs := func() measurements.M { + return measurements.M{ + 0: measurements.WithAllBytes(0x00, true), + 1: measurements.WithAllBytes(0x00, true), + 2: measurements.WithAllBytes(0x00, true), + 3: measurements.WithAllBytes(0x00, true), + 4: measurements.WithAllBytes(0x00, true), + 5: measurements.WithAllBytes(0x00, true), + 6: measurements.WithAllBytes(0x00, true), + 7: measurements.WithAllBytes(0x00, true), + 8: measurements.WithAllBytes(0x00, true), + 9: measurements.WithAllBytes(0x00, true), + 10: measurements.WithAllBytes(0x00, true), + 11: measurements.WithAllBytes(0x00, true), + 12: measurements.WithAllBytes(0x00, true), + 13: measurements.WithAllBytes(0x00, true), + 14: measurements.WithAllBytes(0x00, true), + 15: measurements.WithAllBytes(0x00, true), + 16: measurements.WithAllBytes(0x00, true), + 17: measurements.WithAllBytes(0x11, true), + 18: measurements.WithAllBytes(0x11, true), + 19: measurements.WithAllBytes(0x11, true), + 20: measurements.WithAllBytes(0x11, true), + 21: measurements.WithAllBytes(0x11, true), + 22: measurements.WithAllBytes(0x11, true), + 23: measurements.WithAllBytes(0x00, true), } } @@ -285,25 +277,25 @@ func TestValidatorUpdateInitPCRs(t *testing.T) { assert.NoError(err) for i := 0; i < len(tc.pcrs); i++ { switch { - case i == int(vtpm.PCRIndexClusterID) && tc.clusterID == "": + case i == int(measurements.PCRIndexClusterID) && tc.clusterID == "": // should be deleted _, ok := validators.pcrs[uint32(i)] assert.False(ok) - case i == int(vtpm.PCRIndexClusterID): + case i == int(measurements.PCRIndexClusterID): pcr, ok := validators.pcrs[uint32(i)] assert.True(ok) - assert.Equal(pcrZeroUpdatedOne[:], pcr) + assert.Equal(pcrZeroUpdatedOne, pcr.Expected) - case i == int(vtpm.PCRIndexOwnerID) && tc.ownerID == "": + case i == int(measurements.PCRIndexOwnerID) && tc.ownerID == "": // should be deleted _, ok := validators.pcrs[uint32(i)] assert.False(ok) - case i == int(vtpm.PCRIndexOwnerID): + case i == int(measurements.PCRIndexOwnerID): pcr, ok := validators.pcrs[uint32(i)] assert.True(ok) - assert.Equal(pcrZeroUpdatedOne[:], pcr) + assert.Equal(pcrZeroUpdatedOne, pcr.Expected) default: if i >= 17 && i <= 22 { @@ -320,8 +312,8 @@ func TestValidatorUpdateInitPCRs(t *testing.T) { func TestUpdatePCR(t *testing.T) { emptyMap := measurements.M{} defaultMap := measurements.M{ - 0: measurements.PCRWithAllBytes(0xAA), - 1: measurements.PCRWithAllBytes(0xBB), + 0: measurements.WithAllBytes(0xAA, false), + 1: measurements.WithAllBytes(0xBB, false), } testCases := map[string]struct { @@ -359,6 +351,20 @@ func TestUpdatePCR(t *testing.T) { wantEntries: len(defaultMap) + 1, wantErr: false, }, + "hex input, empty map": { + pcrMap: emptyMap, + pcrIndex: 10, + encoded: hex.EncodeToString([]byte("Constellation")), + wantEntries: 1, + wantErr: false, + }, + "hex input, default map": { + pcrMap: defaultMap, + pcrIndex: 10, + encoded: hex.EncodeToString([]byte("Constellation")), + wantEntries: len(defaultMap) + 1, + wantErr: false, + }, "unencoded input, empty map": { pcrMap: emptyMap, pcrIndex: 10, @@ -403,9 +409,6 @@ func TestUpdatePCR(t *testing.T) { assert.NoError(err) } assert.Len(pcrs, tc.wantEntries) - for _, v := range pcrs { - assert.Len(v, 32) - } }) } } diff --git a/cli/internal/cmd/configfetchmeasurements.go b/cli/internal/cmd/configfetchmeasurements.go index e60c2d08d..ea344757b 100644 --- a/cli/internal/cmd/configfetchmeasurements.go +++ b/cli/internal/cmd/configfetchmeasurements.go @@ -137,7 +137,7 @@ func (f *fetchMeasurementsFlags) updateURLs(ctx context.Context, conf *config.Co if f.measurementsURL == nil { // TODO(AB#2644): resolve image version to reference - parsedURL, err := url.Parse(constants.S3PublicBucket + imageRef + "/measurements.yaml") + parsedURL, err := url.Parse(constants.S3PublicBucket + imageRef + "/measurements.json") if err != nil { return err } @@ -145,7 +145,7 @@ func (f *fetchMeasurementsFlags) updateURLs(ctx context.Context, conf *config.Co } if f.signatureURL == nil { - parsedURL, err := url.Parse(constants.S3PublicBucket + imageRef + "/measurements.yaml.sig") + parsedURL, err := url.Parse(constants.S3PublicBucket + imageRef + "/measurements.json.sig") if err != nil { return err } diff --git a/cli/internal/cmd/configfetchmeasurements_test.go b/cli/internal/cmd/configfetchmeasurements_test.go index aa5645f22..2b56559d9 100644 --- a/cli/internal/cmd/configfetchmeasurements_test.go +++ b/cli/internal/cmd/configfetchmeasurements_test.go @@ -109,17 +109,17 @@ func TestUpdateURLs(t *testing.T) { }, }, flags: &fetchMeasurementsFlags{}, - wantMeasurementsURL: constants.S3PublicBucket + "some/image/path/image-123456/measurements.yaml", - wantMeasurementsSigURL: constants.S3PublicBucket + "some/image/path/image-123456/measurements.yaml.sig", + wantMeasurementsURL: constants.S3PublicBucket + "some/image/path/image-123456/measurements.json", + wantMeasurementsSigURL: constants.S3PublicBucket + "some/image/path/image-123456/measurements.json.sig", }, "both set by user": { conf: &config.Config{}, flags: &fetchMeasurementsFlags{ - measurementsURL: urlMustParse("get.my/measurements.yaml"), - signatureURL: urlMustParse("get.my/measurements.yaml.sig"), + measurementsURL: urlMustParse("get.my/measurements.json"), + signatureURL: urlMustParse("get.my/measurements.json.sig"), }, - wantMeasurementsURL: "get.my/measurements.yaml", - wantMeasurementsSigURL: "get.my/measurements.yaml.sig", + wantMeasurementsURL: "get.my/measurements.json", + wantMeasurementsSigURL: "get.my/measurements.json.sig", }, } @@ -164,14 +164,14 @@ func TestConfigFetchMeasurements(t *testing.T) { signature := "MEUCIFdJ5dH6HDywxQWTUh9Bw77wMrq0mNCUjMQGYP+6QsVmAiEAmazj/L7rFGA4/Gz8y+kI5h5E5cDgc3brihvXBKF6qZA=" client := newTestClient(func(req *http.Request) *http.Response { - if req.URL.String() == "https://public-edgeless-constellation.s3.us-east-2.amazonaws.com/someImage/measurements.yaml" { + if req.URL.String() == "https://public-edgeless-constellation.s3.us-east-2.amazonaws.com/someImage/measurements.json" { return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewBufferString(measurements)), Header: make(http.Header), } } - if req.URL.String() == "https://public-edgeless-constellation.s3.us-east-2.amazonaws.com/someImage/measurements.yaml.sig" { + if req.URL.String() == "https://public-edgeless-constellation.s3.us-east-2.amazonaws.com/someImage/measurements.json.sig" { return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(bytes.NewBufferString(signature)), diff --git a/cli/internal/cmd/init.go b/cli/internal/cmd/init.go index 196c16165..12fc9f02c 100644 --- a/cli/internal/cmd/init.go +++ b/cli/internal/cmd/init.go @@ -8,7 +8,7 @@ package cmd import ( "context" - "encoding/base64" + "encoding/hex" "fmt" "io" "net" @@ -134,7 +134,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator KubernetesVersion: conf.KubernetesVersion, KubernetesComponents: versions.VersionConfigs[k8sVersion].KubernetesComponents.ToProto(), HelmDeployments: helmDeployments, - EnforcedPcrs: conf.GetEnforcedPCRs(), + EnforcedPcrs: conf.EnforcedPCRs(), EnforceIdkeydigest: conf.EnforcesIDKeyDigest(), ConformanceMode: flags.conformance, } @@ -190,8 +190,8 @@ func (d *initDoer) Do(ctx context.Context) error { func writeOutput(idFile clusterid.File, resp *initproto.InitResponse, wr io.Writer, fileHandler file.Handler) error { fmt.Fprint(wr, "Your Constellation cluster was successfully initialized.\n\n") - ownerID := base64.StdEncoding.EncodeToString(resp.OwnerId) - clusterID := base64.StdEncoding.EncodeToString(resp.ClusterId) + ownerID := hex.EncodeToString(resp.OwnerId) + clusterID := hex.EncodeToString(resp.ClusterId) tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0) // writeRow(tw, "Constellation cluster's owner identifier", ownerID) diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index 10629626e..61f6ddefd 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -9,7 +9,7 @@ package cmd import ( "bytes" "context" - "encoding/base64" + "encoding/hex" "encoding/json" "errors" "net" @@ -101,16 +101,6 @@ func TestInitialize(t *testing.T) { initServerAPI: &stubInitServer{initErr: someErr}, wantErr: true, }, - "fail missing enforced PCR": { - provider: cloudprovider.GCP, - idFile: &clusterid.File{IP: "192.0.2.1"}, - configMutator: func(c *config.Config) { - c.Provider.GCP.EnforcedMeasurements = append(c.Provider.GCP.EnforcedMeasurements, 10) - }, - serviceAccKey: gcpServiceAccKey, - initServerAPI: &stubInitServer{initResp: testInitResp}, - wantErr: true, - }, } for name, tc := range testCases { @@ -174,7 +164,7 @@ func TestInitialize(t *testing.T) { } require.NoError(err) // assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID"))) - assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("clusterID"))) + assert.Contains(out.String(), hex.EncodeToString([]byte("clusterID"))) var secret masterSecret assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret)) assert.NotEmpty(secret.Key) @@ -192,8 +182,8 @@ func TestWriteOutput(t *testing.T) { Kubeconfig: []byte("kubeconfig"), } - ownerID := base64.StdEncoding.EncodeToString(resp.OwnerId) - clusterID := base64.StdEncoding.EncodeToString(resp.ClusterId) + ownerID := hex.EncodeToString(resp.OwnerId) + clusterID := hex.EncodeToString(resp.ClusterId) expectedIDFile := clusterid.File{ ClusterID: clusterID, @@ -361,11 +351,11 @@ func TestAttestation(t *testing.T) { issuer := &testIssuer{ Getter: oid.QEMU{}, - pcrs: measurements.M{ - 0: measurements.PCRWithAllBytes(0xFF), - 1: measurements.PCRWithAllBytes(0xFF), - 2: measurements.PCRWithAllBytes(0xFF), - 3: measurements.PCRWithAllBytes(0xFF), + pcrs: map[uint32][]byte{ + 0: bytes.Repeat([]byte{0xFF}, 32), + 1: bytes.Repeat([]byte{0xFF}, 32), + 2: bytes.Repeat([]byte{0xFF}, 32), + 3: bytes.Repeat([]byte{0xFF}, 32), }, } serverCreds := atlscredentials.New(issuer, nil) @@ -390,13 +380,13 @@ func TestAttestation(t *testing.T) { cfg := config.Default() cfg.Image = "image" cfg.RemoveProviderExcept(cloudprovider.QEMU) - cfg.Provider.QEMU.Measurements[0] = measurements.PCRWithAllBytes(0x00) - cfg.Provider.QEMU.Measurements[1] = measurements.PCRWithAllBytes(0x11) - cfg.Provider.QEMU.Measurements[2] = measurements.PCRWithAllBytes(0x22) - cfg.Provider.QEMU.Measurements[3] = measurements.PCRWithAllBytes(0x33) - cfg.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44) - cfg.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x99) - cfg.Provider.QEMU.Measurements[12] = measurements.PCRWithAllBytes(0xcc) + cfg.Provider.QEMU.Measurements[0] = measurements.WithAllBytes(0x00, false) + cfg.Provider.QEMU.Measurements[1] = measurements.WithAllBytes(0x11, false) + cfg.Provider.QEMU.Measurements[2] = measurements.WithAllBytes(0x22, false) + cfg.Provider.QEMU.Measurements[3] = measurements.WithAllBytes(0x33, false) + cfg.Provider.QEMU.Measurements[4] = measurements.WithAllBytes(0x44, false) + cfg.Provider.QEMU.Measurements[9] = measurements.WithAllBytes(0x99, false) + cfg.Provider.QEMU.Measurements[12] = measurements.WithAllBytes(0xcc, false) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone)) ctx := context.Background() @@ -418,14 +408,14 @@ type testValidator struct { func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) { var attestation struct { UserData []byte - PCRs measurements.M + PCRs map[uint32][]byte } if err := json.Unmarshal(attDoc, &attestation); err != nil { return nil, err } for k, pcr := range v.pcrs { - if !bytes.Equal(attestation.PCRs[k], pcr) { + if !bytes.Equal(attestation.PCRs[k], pcr.Expected[:]) { return nil, errors.New("invalid PCR value") } } @@ -434,14 +424,14 @@ func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) { type testIssuer struct { oid.Getter - pcrs measurements.M + pcrs map[uint32][]byte } func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) { return json.Marshal( struct { UserData []byte - PCRs measurements.M + PCRs map[uint32][]byte }{ UserData: userData, PCRs: i.pcrs, @@ -474,21 +464,21 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs conf.Provider.Azure.ResourceGroup = "test-resource-group" conf.Provider.Azure.AppClientID = "01234567-0123-0123-0123-0123456789ab" conf.Provider.Azure.ClientSecretValue = "test-client-secret" - conf.Provider.Azure.Measurements[4] = measurements.PCRWithAllBytes(0x44) - conf.Provider.Azure.Measurements[9] = measurements.PCRWithAllBytes(0x11) - conf.Provider.Azure.Measurements[12] = measurements.PCRWithAllBytes(0xcc) + conf.Provider.Azure.Measurements[4] = measurements.WithAllBytes(0x44, false) + conf.Provider.Azure.Measurements[9] = measurements.WithAllBytes(0x11, false) + conf.Provider.Azure.Measurements[12] = measurements.WithAllBytes(0xcc, false) case cloudprovider.GCP: conf.Provider.GCP.Region = "test-region" conf.Provider.GCP.Project = "test-project" conf.Provider.GCP.Zone = "test-zone" conf.Provider.GCP.ServiceAccountKeyPath = "test-key-path" - conf.Provider.GCP.Measurements[4] = measurements.PCRWithAllBytes(0x44) - conf.Provider.GCP.Measurements[9] = measurements.PCRWithAllBytes(0x11) - conf.Provider.GCP.Measurements[12] = measurements.PCRWithAllBytes(0xcc) + conf.Provider.GCP.Measurements[4] = measurements.WithAllBytes(0x44, false) + conf.Provider.GCP.Measurements[9] = measurements.WithAllBytes(0x11, false) + conf.Provider.GCP.Measurements[12] = measurements.WithAllBytes(0xcc, false) case cloudprovider.QEMU: - conf.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44) - conf.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x11) - conf.Provider.QEMU.Measurements[12] = measurements.PCRWithAllBytes(0xcc) + conf.Provider.QEMU.Measurements[4] = measurements.WithAllBytes(0x44, false) + conf.Provider.QEMU.Measurements[9] = measurements.WithAllBytes(0x11, false) + conf.Provider.QEMU.Measurements[12] = measurements.WithAllBytes(0xcc, false) } conf.RemoveProviderExcept(csp) diff --git a/cli/internal/cmd/upgradeplan.go b/cli/internal/cmd/upgradeplan.go index a59cce489..320e646bd 100644 --- a/cli/internal/cmd/upgradeplan.go +++ b/cli/internal/cmd/upgradeplan.go @@ -181,12 +181,12 @@ func getCompatibleImages(csp cloudprovider.Provider, currentVersion string, imag // getCompatibleImageMeasurements retrieves the expected measurements for each image. func getCompatibleImageMeasurements(ctx context.Context, cmd *cobra.Command, client *http.Client, rekor rekorVerifier, pubK []byte, images map[string]config.UpgradeConfig) error { for idx, img := range images { - measurementsURL, err := url.Parse(constants.S3PublicBucket + strings.ToLower(img.Image) + "/measurements.yaml") + measurementsURL, err := url.Parse(constants.S3PublicBucket + strings.ToLower(img.Image) + "/measurements.json") if err != nil { return err } - signatureURL, err := url.Parse(constants.S3PublicBucket + strings.ToLower(img.Image) + "/measurements.yaml.sig") + signatureURL, err := url.Parse(constants.S3PublicBucket + strings.ToLower(img.Image) + "/measurements.json.sig") if err != nil { return err } diff --git a/cli/internal/cmd/upgradeplan_test.go b/cli/internal/cmd/upgradeplan_test.go index 2c9bd4910..ea3051a87 100644 --- a/cli/internal/cmd/upgradeplan_test.go +++ b/cli/internal/cmd/upgradeplan_test.go @@ -25,7 +25,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/mod/semver" - "gopkg.in/yaml.v2" + "gopkg.in/yaml.v3" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" ) @@ -248,14 +248,14 @@ func TestGetCompatibleImageMeasurements(t *testing.T) { } client := newTestClient(func(req *http.Request) *http.Response { - if strings.HasSuffix(req.URL.String(), "/measurements.yaml") { + if strings.HasSuffix(req.URL.String(), "/measurements.json") { return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n")), Header: make(http.Header), } } - if strings.HasSuffix(req.URL.String(), "/measurements.yaml.sig") { + if strings.HasSuffix(req.URL.String(), "/measurements.json.sig") { return &http.Response{ StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader("MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=")), @@ -470,14 +470,14 @@ func TestUpgradePlan(t *testing.T) { Header: make(http.Header), } } - if strings.HasSuffix(req.URL.String(), "/measurements.yaml") { + if strings.HasSuffix(req.URL.String(), "/measurements.json") { return &http.Response{ StatusCode: tc.measurementsFetchStatus, Body: io.NopCloser(strings.NewReader("0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n")), Header: make(http.Header), } } - if strings.HasSuffix(req.URL.String(), "/measurements.yaml.sig") { + if strings.HasSuffix(req.URL.String(), "/measurements.json.sig") { return &http.Response{ StatusCode: tc.measurementsFetchStatus, Body: io.NopCloser(strings.NewReader("MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=")), diff --git a/cli/internal/helm/charts/edgeless/constellation-services/charts/join-service/templates/configmap.yaml b/cli/internal/helm/charts/edgeless/constellation-services/charts/join-service/templates/configmap.yaml index 7f0493084..41867f260 100644 --- a/cli/internal/helm/charts/edgeless/constellation-services/charts/join-service/templates/configmap.yaml +++ b/cli/internal/helm/charts/edgeless/constellation-services/charts/join-service/templates/configmap.yaml @@ -5,7 +5,6 @@ metadata: namespace: {{ .Release.Namespace }} data: # mustToJson is required so the json-strings passed from go are of type string in the rendered yaml. - enforcedPCRs: {{ .Values.enforcedPCRs | mustToJson }} measurements: {{ .Values.measurements | mustToJson }} {{- if eq .Values.csp "Azure" }} # ConfigMap.data is of type map[string]string. quote will not quote a quoted string. diff --git a/cli/internal/helm/charts/edgeless/constellation-services/charts/join-service/values.schema.json b/cli/internal/helm/charts/edgeless/constellation-services/charts/join-service/values.schema.json index 81c5baab7..a941a66e7 100644 --- a/cli/internal/helm/charts/edgeless/constellation-services/charts/join-service/values.schema.json +++ b/cli/internal/helm/charts/edgeless/constellation-services/charts/join-service/values.schema.json @@ -5,15 +5,10 @@ "description": "CSP to which the chart is deployed.", "enum": ["Azure", "GCP", "AWS", "QEMU"] }, - "enforcedPCRs": { - "description": "JSON-string to describe the enforced PCRs.", - "type": "string", - "examples": ["[1, 15]"] - }, "measurements": { "description": "JSON-string to describe the expected measurements.", "type": "string", - "examples": ["{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}"] + "examples": ["{'1':{'expected':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','warnOnly':true},'15':{'expected':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=','warnOnly':true}}"] }, "enforceIdKeyDigest": { "description": "Whether or not idkeydigest should be enforced during attestation on azure.", @@ -37,7 +32,6 @@ }, "required": [ "csp", - "enforcedPCRs", "measurements", "measurementSalt", "image" diff --git a/cli/internal/helm/loader.go b/cli/internal/helm/loader.go index 121afadfd..420578dae 100644 --- a/cli/internal/helm/loader.go +++ b/cli/internal/helm/loader.go @@ -76,9 +76,7 @@ func New(csp cloudprovider.Provider, k8sVersion versions.ValidK8sVersion) *Chart // Load the embedded helm charts. func (i *ChartLoader) Load(config *config.Config, conformanceMode bool, masterSecret, salt []byte) ([]byte, error) { - csp := config.GetProvider() - - ciliumRelease, err := i.loadCilium(csp, conformanceMode) + ciliumRelease, err := i.loadCilium(config.GetProvider(), conformanceMode) if err != nil { return nil, fmt.Errorf("loading cilium: %w", err) } @@ -88,7 +86,7 @@ func (i *ChartLoader) Load(config *config.Config, conformanceMode bool, masterSe return nil, fmt.Errorf("loading cilium: %w", err) } - operatorRelease, err := i.loadOperators(csp) + operatorRelease, err := i.loadOperators(config.GetProvider()) if err != nil { return nil, fmt.Errorf("loading operators: %w", err) } @@ -350,11 +348,6 @@ func (i *ChartLoader) loadConstellationServicesHelper(config *config.Config, mas return nil, nil, fmt.Errorf("loading constellation-services chart: %w", err) } - enforcedPCRsJSON, err := json.Marshal(config.GetEnforcedPCRs()) - if err != nil { - return nil, nil, fmt.Errorf("marshaling enforcedPCRs: %w", err) - } - csp := config.GetProvider() values := map[string]any{ "global": map[string]any{ @@ -374,9 +367,8 @@ func (i *ChartLoader) loadConstellationServicesHelper(config *config.Config, mas "measurementsFilename": constants.MeasurementsFilename, }, "join-service": map[string]any{ - "csp": csp.String(), - "enforcedPCRs": string(enforcedPCRsJSON), - "image": i.joinServiceImage, + "csp": csp.String(), + "image": i.joinServiceImage, }, "ccm": map[string]any{ "csp": csp.String(), diff --git a/cli/internal/helm/loader_test.go b/cli/internal/helm/loader_test.go index 8b10bf8cf..909e00c0b 100644 --- a/cli/internal/helm/loader_test.go +++ b/cli/internal/helm/loader_test.go @@ -15,6 +15,7 @@ import ( "path" "testing" + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/deploy/helm" @@ -56,8 +57,7 @@ func TestConstellationServices(t *testing.T) { }{ "GCP": { config: &config.Config{Provider: config.ProviderConfig{GCP: &config.GCPConfig{ - DeployCSIDriver: func() *bool { b := true; return &b }(), - EnforcedMeasurements: []uint32{1, 11}, + DeployCSIDriver: func() *bool { b := true; return &b }(), }}}, enforceIDKeyDigest: false, valuesModifier: prepareGCPValues, @@ -65,9 +65,8 @@ func TestConstellationServices(t *testing.T) { }, "Azure": { config: &config.Config{Provider: config.ProviderConfig{Azure: &config.AzureConfig{ - DeployCSIDriver: func() *bool { b := true; return &b }(), - EnforcedMeasurements: []uint32{1, 11}, - EnforceIDKeyDigest: func() *bool { b := true; return &b }(), + DeployCSIDriver: func() *bool { b := true; return &b }(), + EnforceIDKeyDigest: func() *bool { b := true; return &b }(), }}}, enforceIDKeyDigest: true, valuesModifier: prepareAzureValues, @@ -75,9 +74,7 @@ func TestConstellationServices(t *testing.T) { cnmImage: "cnmImageForAzure", }, "QEMU": { - config: &config.Config{Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{ - EnforcedMeasurements: []uint32{1, 11}, - }}}, + config: &config.Config{Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{}}}, enforceIDKeyDigest: false, valuesModifier: prepareQEMUValues, }, @@ -88,7 +85,14 @@ func TestConstellationServices(t *testing.T) { assert := assert.New(t) require := require.New(t) - chartLoader := ChartLoader{joinServiceImage: "joinServiceImage", kmsImage: "kmsImage", ccmImage: tc.ccmImage, cnmImage: tc.cnmImage, autoscalerImage: "autoscalerImage", verificationServiceImage: "verificationImage"} + chartLoader := ChartLoader{ + joinServiceImage: "joinServiceImage", + kmsImage: "kmsImage", + ccmImage: tc.ccmImage, + cnmImage: tc.cnmImage, + autoscalerImage: "autoscalerImage", + verificationServiceImage: "verificationImage", + } chart, values, err := chartLoader.loadConstellationServicesHelper(tc.config, []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")) require.NoError(err) @@ -197,7 +201,15 @@ func prepareGCPValues(values map[string]any) error { if !ok { return errors.New("missing 'join-service' key") } - joinVals["measurements"] = "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}" + + m := measurements.M{ + 1: measurements.WithAllBytes(0xAA, false), + } + mJSON, err := json.Marshal(m) + if err != nil { + return err + } + joinVals["measurements"] = string(mJSON) joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" ccmVals, ok := values["ccm"].(map[string]any) @@ -269,7 +281,12 @@ func prepareAzureValues(values map[string]any) error { return errors.New("missing 'join-service' key") } joinVals["idkeydigest"] = "baaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaad" - joinVals["measurements"] = "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}" + m := measurements.M{1: measurements.WithAllBytes(0xAA, false)} + mJSON, err := json.Marshal(m) + if err != nil { + return err + } + joinVals["measurements"] = string(mJSON) joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" ccmVals, ok := values["ccm"].(map[string]any) @@ -311,7 +328,12 @@ func prepareQEMUValues(values map[string]any) error { if !ok { return errors.New("missing 'join-service' key") } - joinVals["measurements"] = "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}" + m := measurements.M{1: measurements.WithAllBytes(0xAA, false)} + mJSON, err := json.Marshal(m) + if err != nil { + return err + } + joinVals["measurements"] = string(mJSON) joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" verificationVals, ok := values["verification-service"].(map[string]any) diff --git a/cli/internal/helm/testdata/Azure/constellation-services/charts/join-service/templates/configmap.yaml b/cli/internal/helm/testdata/Azure/constellation-services/charts/join-service/templates/configmap.yaml index a6770a452..ce16be564 100644 --- a/cli/internal/helm/testdata/Azure/constellation-services/charts/join-service/templates/configmap.yaml +++ b/cli/internal/helm/testdata/Azure/constellation-services/charts/join-service/templates/configmap.yaml @@ -4,8 +4,7 @@ metadata: name: join-config namespace: testNamespace data: - enforcedPCRs: "[1,11]" - measurements: "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}" + measurements: "{\"1\":{\"expected\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\",\"warnOnly\":false}}" enforceIdKeyDigest: "true" idkeydigest: "baaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaad" binaryData: diff --git a/cli/internal/helm/testdata/GCP/constellation-services/charts/join-service/templates/configmap.yaml b/cli/internal/helm/testdata/GCP/constellation-services/charts/join-service/templates/configmap.yaml index 889dfe8f6..4c445457a 100644 --- a/cli/internal/helm/testdata/GCP/constellation-services/charts/join-service/templates/configmap.yaml +++ b/cli/internal/helm/testdata/GCP/constellation-services/charts/join-service/templates/configmap.yaml @@ -4,7 +4,6 @@ metadata: name: join-config namespace: testNamespace data: - enforcedPCRs: "[1,11]" - measurements: "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}" + measurements: "{\"1\":{\"expected\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\",\"warnOnly\":false}}" binaryData: measurementSalt: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA diff --git a/cli/internal/helm/testdata/QEMU/constellation-services/charts/join-service/templates/configmap.yaml b/cli/internal/helm/testdata/QEMU/constellation-services/charts/join-service/templates/configmap.yaml index 889dfe8f6..4c445457a 100644 --- a/cli/internal/helm/testdata/QEMU/constellation-services/charts/join-service/templates/configmap.yaml +++ b/cli/internal/helm/testdata/QEMU/constellation-services/charts/join-service/templates/configmap.yaml @@ -4,7 +4,6 @@ metadata: name: join-config namespace: testNamespace data: - enforcedPCRs: "[1,11]" - measurements: "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}" + measurements: "{\"1\":{\"expected\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\",\"warnOnly\":false}}" binaryData: measurementSalt: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA diff --git a/disk-mapper/cmd/main.go b/disk-mapper/cmd/main.go index 56f0ae2c5..b1e84a8e6 100644 --- a/disk-mapper/cmd/main.go +++ b/disk-mapper/cmd/main.go @@ -171,7 +171,7 @@ func main() { // We can use this to calculate the PCRs of the image locally. func exportPCRs() error { // get TPM state - pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, tpmClient.FullPcrSel(tpm2.AlgSHA256)) + pcrs, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, tpmClient.FullPcrSel(tpm2.AlgSHA256)) if err != nil { return err } diff --git a/go.mod b/go.mod index 9525bcab2..6d341c312 100644 --- a/go.mod +++ b/go.mod @@ -92,7 +92,6 @@ require ( google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 google.golang.org/grpc v1.51.0 google.golang.org/protobuf v1.28.1 - gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.1 helm.sh/helm v2.17.0+incompatible helm.sh/helm/v3 v3.10.2 @@ -119,6 +118,7 @@ require ( github.com/hashicorp/go-retryablehttp v0.7.1 // indirect github.com/rogpeppe/go-internal v1.8.1 // indirect golang.org/x/text v0.4.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect ) require ( diff --git a/hack/go.mod b/hack/go.mod index 093f0dc50..2bd384618 100644 --- a/hack/go.mod +++ b/hack/go.mod @@ -41,7 +41,6 @@ require ( github.com/go-git/go-git/v5 v5.4.2 github.com/google/go-tpm-tools v0.3.9 github.com/google/uuid v1.3.0 - github.com/spf13/afero v1.9.3 github.com/spf13/cobra v1.6.1 github.com/stretchr/testify v1.8.1 go.uber.org/goleak v1.2.0 @@ -189,6 +188,7 @@ require ( github.com/sigstore/rekor v1.0.1 // indirect github.com/sigstore/sigstore v1.4.5 // indirect github.com/sirupsen/logrus v1.9.0 // indirect + github.com/spf13/afero v1.9.3 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/talos-systems/talos/pkg/machinery v1.2.7 // indirect github.com/tent/canonical-json-go v0.0.0-20130607151641-96e4ba3a7613 // indirect diff --git a/hack/pcr-reader/main.go b/hack/pcr-reader/main.go index 25a032c55..656f857b3 100644 --- a/hack/pcr-reader/main.go +++ b/hack/pcr-reader/main.go @@ -24,24 +24,24 @@ import ( "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/crypto" "github.com/edgelesssys/constellation/v2/verify/verifyproto" - "github.com/spf13/afero" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "gopkg.in/yaml.v3" ) -var ( - coordIP = flag.String("constell-ip", "", "Public IP of the Constellation") - port = flag.String("constell-port", strconv.Itoa(constants.VerifyServiceNodePortGRPC), "NodePort of the Constellation's verification service") - export = flag.String("o", "", "Write PCRs, formatted as Go code, to file") - format = flag.String("format", "json", "Output format: json, yaml (default json)") - quiet = flag.Bool("q", false, "Set to disable output") - timeout = flag.Duration("timeout", 2*time.Minute, "Wait this duration for the verification service to become available") -) - func main() { + coordIP := flag.String("constell-ip", "", "Public IP of the Constellation") + port := flag.String("constell-port", strconv.Itoa(constants.VerifyServiceNodePortGRPC), "NodePort of the Constellation's verification service") + format := flag.String("format", "json", "Output format: json, yaml (default json)") + quiet := flag.Bool("q", false, "Set to disable output") + timeout := flag.Duration("timeout", 2*time.Minute, "Wait this duration for the verification service to become available") flag.Parse() + if *coordIP == "" || *port == "" { + flag.Usage() + os.Exit(1) + } + addr := net.JoinHostPort(*coordIP, *port) ctx, cancel := context.WithTimeout(context.Background(), *timeout) defer cancel() @@ -51,18 +51,13 @@ func main() { log.Fatal(err) } - pcrs, err := validatePCRAttDoc(attDocRaw) + measurements, err := validatePCRAttDoc(attDocRaw) if err != nil { log.Fatal(err) } if !*quiet { - if err := printPCRs(os.Stdout, pcrs, *format); err != nil { - log.Fatal(err) - } - } - if *export != "" { - if err := exportToFile(*export, pcrs, &afero.Afero{Fs: afero.NewOsFs()}); err != nil { + if err := printPCRs(os.Stdout, measurements, *format); err != nil { log.Fatal(err) } } @@ -104,16 +99,23 @@ func validatePCRAttDoc(attDocRaw []byte) (measurements.M, error) { if err != nil { return nil, err } + + m := measurements.M{} for idx, pcr := range attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs { if len(pcr) != 32 { return nil, fmt.Errorf("incomplete PCR at index: %d", idx) } + + m[idx] = measurements.Measurement{ + Expected: *(*[32]byte)(pcr), + WarnOnly: true, + } } - return attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs, nil + return m, nil } -// printPCRs formates and prints PCRs to the given writer. -// format can be one of 'json' or 'yaml'. If it doesnt match defaults to 'json'. +// printPCRs formats and prints PCRs to the given writer. +// format can be one of 'json' or 'yaml'. If it doesn't match defaults to 'json'. func printPCRs(w io.Writer, pcrs measurements.M, format string) error { switch format { case "json": @@ -142,24 +144,3 @@ func printPCRsJSON(w io.Writer, pcrs measurements.M) error { fmt.Fprintf(w, "%s", string(pcrJSON)) return nil } - -// exportToFile writes pcrs to a file, formatted to be valid Go code. -// Validity of the PCR map is not checked, and should be handled by the caller. -func exportToFile(path string, pcrs measurements.M, fs *afero.Afero) error { - goCode := `package pcrs - -var pcrs = map[uint32][]byte{%s -} -` - pcrsFormatted := "" - for i := 0; i < len(pcrs); i++ { - pcrHex := fmt.Sprintf("%#02X", pcrs[uint32(i)][0]) - for j := 1; j < len(pcrs[uint32(i)]); j++ { - pcrHex = fmt.Sprintf("%s, %#02X", pcrHex, pcrs[uint32(i)][j]) - } - - pcrsFormatted = pcrsFormatted + fmt.Sprintf("\n\t%d: {%s},", i, pcrHex) - } - - return fs.WriteFile(path, []byte(fmt.Sprintf(goCode, pcrsFormatted)), 0o644) -} diff --git a/hack/pcr-reader/main_test.go b/hack/pcr-reader/main_test.go index 07d7e3a61..49a801809 100644 --- a/hack/pcr-reader/main_test.go +++ b/hack/pcr-reader/main_test.go @@ -8,7 +8,7 @@ package main import ( "bytes" - "encoding/base64" + "encoding/hex" "encoding/json" "fmt" "testing" @@ -17,67 +17,13 @@ import ( "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/google/go-tpm-tools/proto/attest" "github.com/google/go-tpm-tools/proto/tpm" - "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/goleak" ) func TestMain(m *testing.M) { - goleak.VerifyTestMain(m, - // https://github.com/census-instrumentation/opencensus-go/issues/1262 - goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), - ) -} - -func TestExportToFile(t *testing.T) { - testCases := map[string]struct { - pcrs measurements.M - fs *afero.Afero - wantErr bool - }{ - "file not writeable": { - pcrs: measurements.M{ - 0: {0x1, 0x2, 0x3}, - 1: {0x1, 0x2, 0x3}, - 2: {0x1, 0x2, 0x3}, - }, - fs: &afero.Afero{Fs: afero.NewReadOnlyFs(afero.NewMemMapFs())}, - wantErr: true, - }, - "file writeable": { - pcrs: measurements.M{ - 0: {0x1, 0x2, 0x3}, - 1: {0x1, 0x2, 0x3}, - 2: {0x1, 0x2, 0x3}, - }, - fs: &afero.Afero{Fs: afero.NewMemMapFs()}, - wantErr: false, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - require := require.New(t) - - path := "test-file" - err := exportToFile(path, tc.pcrs, tc.fs) - if tc.wantErr { - assert.Error(err) - } else { - assert.NoError(err) - content, err := tc.fs.ReadFile(path) - require.NoError(err) - - for _, pcr := range tc.pcrs { - for _, register := range pcr { - assert.Contains(string(content), fmt.Sprintf("%#02X", register)) - } - } - } - }) - } + goleak.VerifyTestMain(m) } func TestValidatePCRAttDoc(t *testing.T) { @@ -106,7 +52,7 @@ func TestValidatePCRAttDoc(t *testing.T) { { Pcrs: &tpm.PCRs{ Hash: tpm.HashAlgo_SHA256, - Pcrs: measurements.M{ + Pcrs: map[uint32][]byte{ 0: {0x1, 0x2, 0x3}, }, }, @@ -123,8 +69,8 @@ func TestValidatePCRAttDoc(t *testing.T) { { Pcrs: &tpm.PCRs{ Hash: tpm.HashAlgo_SHA256, - Pcrs: measurements.M{ - 0: measurements.PCRWithAllBytes(0xAA), + Pcrs: map[uint32][]byte{ + 0: bytes.Repeat([]byte{0xAA}, 32), }, }, }, @@ -150,7 +96,10 @@ func TestValidatePCRAttDoc(t *testing.T) { require.NoError(json.Unmarshal(tc.attDocRaw, &attDoc)) qIdx, err := vtpm.GetSHA256QuoteIndex(attDoc.Attestation.Quotes) require.NoError(err) - assert.EqualValues(attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs, pcrs) + + for pcrIdx, pcrVal := range pcrs { + assert.Equal(pcrVal.Expected[:], attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs[pcrIdx]) + } } }) } @@ -164,31 +113,15 @@ func mustMarshalAttDoc(t *testing.T, attDoc vtpm.AttestationDocument) []byte { func TestPrintPCRs(t *testing.T) { testCases := map[string]struct { - pcrs measurements.M format string }{ "json": { - pcrs: measurements.M{ - 0: {0x1, 0x2, 0x3}, - 1: {0x1, 0x2, 0x3}, - 2: {0x1, 0x2, 0x3}, - }, format: "json", }, "empty format": { - pcrs: measurements.M{ - 0: {0x1, 0x2, 0x3}, - 1: {0x1, 0x2, 0x3}, - 2: {0x1, 0x2, 0x3}, - }, format: "", }, "yaml": { - pcrs: measurements.M{ - 0: {0x1, 0x2, 0x3}, - 1: {0x1, 0x2, 0x3}, - 2: {0x1, 0x2, 0x3}, - }, format: "yaml", }, } @@ -197,13 +130,19 @@ func TestPrintPCRs(t *testing.T) { t.Run(name, func(t *testing.T) { assert := assert.New(t) + pcrs := measurements.M{ + 0: measurements.WithAllBytes(0xAA, true), + 1: measurements.WithAllBytes(0xBB, true), + 2: measurements.WithAllBytes(0xCC, true), + } + var out bytes.Buffer - err := printPCRs(&out, tc.pcrs, tc.format) + err := printPCRs(&out, pcrs, tc.format) assert.NoError(err) - for idx, pcr := range tc.pcrs { + for idx, pcr := range pcrs { assert.Contains(out.String(), fmt.Sprintf("%d", idx)) - assert.Contains(out.String(), base64.StdEncoding.EncodeToString(pcr)) + assert.Contains(out.String(), hex.EncodeToString(pcr.Expected[:])) } }) } diff --git a/hack/qemu-metadata-api/server/server_test.go b/hack/qemu-metadata-api/server/server_test.go index 969b1c1fc..0f5bb9b0c 100644 --- a/hack/qemu-metadata-api/server/server_test.go +++ b/hack/qemu-metadata-api/server/server_test.go @@ -307,12 +307,12 @@ func TestExportPCRs(t *testing.T) { remoteAddr: "192.0.100.1:1234", connect: defaultConnect, method: http.MethodPost, - message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), + message: mustMarshal(t, measurements.M{0: measurements.WithAllBytes(0xAA, false)}), }, "incorrect method": { remoteAddr: "192.0.100.1:1234", connect: defaultConnect, - message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), + message: mustMarshal(t, measurements.M{0: measurements.WithAllBytes(0xAA, false)}), method: http.MethodGet, wantErr: true, }, @@ -321,7 +321,7 @@ func TestExportPCRs(t *testing.T) { connect: &stubConnect{ getNetworkErr: errors.New("error"), }, - message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), + message: mustMarshal(t, measurements.M{0: measurements.WithAllBytes(0xAA, false)}), method: http.MethodPost, wantErr: true, }, @@ -336,7 +336,7 @@ func TestExportPCRs(t *testing.T) { remoteAddr: "localhost", connect: defaultConnect, method: http.MethodPost, - message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), + message: mustMarshal(t, measurements.M{0: measurements.WithAllBytes(0xAA, false)}), wantErr: true, }, } diff --git a/internal/attestation/aws/validator.go b/internal/attestation/aws/validator.go index 1dca029b1..b8d182869 100644 --- a/internal/attestation/aws/validator.go +++ b/internal/attestation/aws/validator.go @@ -29,11 +29,10 @@ type Validator struct { } // NewValidator create a new Validator structure and returns it. -func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { +func NewValidator(pcrs measurements.M, log vtpm.AttestationLogger) *Validator { v := &Validator{} v.Validator = vtpm.NewValidator( pcrs, - enforcedPCRs, getTrustedKey, v.tpmEnabled, vtpm.VerifyPKCS1v15, diff --git a/internal/attestation/azure/snp/validator.go b/internal/attestation/azure/snp/validator.go index bdb4d4965..3046f1172 100644 --- a/internal/attestation/azure/snp/validator.go +++ b/internal/attestation/azure/snp/validator.go @@ -42,11 +42,10 @@ type Validator struct { } // NewValidator initializes a new Azure validator with the provided PCR values. -func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, idKeyDigest []byte, enforceIDKeyDigest bool, log vtpm.AttestationLogger) *Validator { +func NewValidator(pcrs measurements.M, idKeyDigest []byte, enforceIDKeyDigest bool, log vtpm.AttestationLogger) *Validator { return &Validator{ Validator: vtpm.NewValidator( pcrs, - enforcedPCRs, getTrustedKey(&azureInstanceInfo{}, idKeyDigest, enforceIDKeyDigest, log), validateCVM, vtpm.VerifyPKCS1v15, diff --git a/internal/attestation/azure/trustedlaunch/trustedlaunch_test.go b/internal/attestation/azure/trustedlaunch/trustedlaunch_test.go index 83e6003eb..c6426523f 100644 --- a/internal/attestation/azure/trustedlaunch/trustedlaunch_test.go +++ b/internal/attestation/azure/trustedlaunch/trustedlaunch_test.go @@ -189,7 +189,7 @@ func TestGetAttestationCert(t *testing.T) { } require.NoError(err) - validator := NewValidator(measurements.M{}, []uint32{}, nil) + validator := NewValidator(measurements.M{}, nil) cert, err := x509.ParseCertificate(rootCert.Raw) require.NoError(err) roots := x509.NewCertPool() diff --git a/internal/attestation/azure/trustedlaunch/validator.go b/internal/attestation/azure/trustedlaunch/validator.go index 802f2f771..f5b4e0d5c 100644 --- a/internal/attestation/azure/trustedlaunch/validator.go +++ b/internal/attestation/azure/trustedlaunch/validator.go @@ -33,13 +33,12 @@ type Validator struct { } // NewValidator initializes a new Azure validator with the provided PCR values. -func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { +func NewValidator(pcrs measurements.M, log vtpm.AttestationLogger) *Validator { rootPool := x509.NewCertPool() rootPool.AddCert(ameRoot) v := &Validator{roots: rootPool} v.Validator = vtpm.NewValidator( pcrs, - enforcedPCRs, v.verifyAttestationKey, validateVM, vtpm.VerifyPKCS1v15, diff --git a/internal/attestation/gcp/validator.go b/internal/attestation/gcp/validator.go index 017274972..85b4c5c38 100644 --- a/internal/attestation/gcp/validator.go +++ b/internal/attestation/gcp/validator.go @@ -35,11 +35,10 @@ type Validator struct { } // NewValidator initializes a new GCP validator with the provided PCR values. -func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { +func NewValidator(pcrs measurements.M, log vtpm.AttestationLogger) *Validator { return &Validator{ Validator: vtpm.NewValidator( pcrs, - enforcedPCRs, trustedKeyFromGCEAPI(newInstanceClient), gceNonHostInfoEvent, vtpm.VerifyPKCS1v15, diff --git a/internal/attestation/measurements/measurements.go b/internal/attestation/measurements/measurements.go index 34b6aa948..b10d3c7d1 100644 --- a/internal/attestation/measurements/measurements.go +++ b/internal/attestation/measurements/measurements.go @@ -12,61 +12,31 @@ import ( "crypto/sha256" "encoding/base64" "encoding/hex" + "encoding/json" "fmt" "io" "net/http" "net/url" - "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/sigstore" - "gopkg.in/yaml.v2" + "github.com/google/go-tpm/tpmutil" + "go.uber.org/multierr" + "gopkg.in/yaml.v3" +) + +const ( + // PCRIndexClusterID is a PCR we extend to mark the node as initialized. + // The value used to extend is a random generated 32 Byte value. + PCRIndexClusterID = tpmutil.Handle(15) + // PCRIndexOwnerID is a PCR we extend to mark the node as initialized. + // The value used to extend is derived from Constellation's master key. + // TODO: move to stable, non-debug PCR before use. + PCRIndexOwnerID = tpmutil.Handle(16) ) // M are Platform Configuration Register (PCR) values that make up the Measurements. -type M map[uint32][]byte - -// PCRWithAllBytes returns a PCR value where all 32 bytes are set to b. -func PCRWithAllBytes(b byte) []byte { - return bytes.Repeat([]byte{b}, 32) -} - -// DefaultsFor provides the default measurements for given cloud provider. -func DefaultsFor(provider cloudprovider.Provider) M { - switch provider { - case cloudprovider.AWS: - return M{ - 8: PCRWithAllBytes(0x00), - 11: PCRWithAllBytes(0x00), - 13: PCRWithAllBytes(0x00), - uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00), - } - case cloudprovider.Azure: - return M{ - 8: PCRWithAllBytes(0x00), - 11: PCRWithAllBytes(0x00), - 13: PCRWithAllBytes(0x00), - uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00), - } - case cloudprovider.GCP: - return M{ - 0: {0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF}, - 8: PCRWithAllBytes(0x00), - 11: PCRWithAllBytes(0x00), - 13: PCRWithAllBytes(0x00), - uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00), - } - case cloudprovider.QEMU: - return M{ - 8: PCRWithAllBytes(0x00), - 11: PCRWithAllBytes(0x00), - 13: PCRWithAllBytes(0x00), - uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00), - } - default: - return nil - } -} +type M map[uint32]Measurement // FetchAndVerify fetches measurement and signature files via provided URLs, // using client for download. The publicKey is used to verify the measurements. @@ -83,8 +53,14 @@ func (m *M) FetchAndVerify(ctx context.Context, client *http.Client, measurement if err := sigstore.VerifySignature(measurements, signature, publicKey); err != nil { return "", err } - if err := yaml.NewDecoder(bytes.NewReader(measurements)).Decode(&m); err != nil { - return "", err + + if err := json.Unmarshal(measurements, m); err != nil { + if yamlErr := yaml.Unmarshal(measurements, m); yamlErr != nil { + return "", multierr.Append( + err, + fmt.Errorf("trying yaml format: %w", yamlErr), + ) + } } shaHash := sha256.Sum256(measurements) @@ -94,58 +70,231 @@ func (m *M) FetchAndVerify(ctx context.Context, client *http.Client, measurement // CopyFrom copies over all values from other. Overwriting existing values, // but keeping not specified values untouched. -func (m M) CopyFrom(other M) { +func (m *M) CopyFrom(other M) { for idx := range other { - m[idx] = other[idx] + (*m)[idx] = other[idx] } } // EqualTo tests whether the provided other Measurements are equal to these // measurements. -func (m M) EqualTo(other M) bool { - if len(m) != len(other) { +func (m *M) EqualTo(other M) bool { + if len(*m) != len(other) { return false } - for k, v := range m { - if !bytes.Equal(v, other[k]) { + for k, v := range *m { + otherExpected := other[k].Expected + if !bytes.Equal(v.Expected[:], otherExpected[:]) { + return false + } + if v.WarnOnly != other[k].WarnOnly { return false } } return true } -// MarshalYAML overwrites the default behaviour of writing out []byte not as -// single bytes, but as a single base64 encoded string. -func (m M) MarshalYAML() (any, error) { - base64Map := make(map[uint32]string) - - for key, value := range m { - base64Map[key] = base64.StdEncoding.EncodeToString(value[:]) +// GetEnforced returns a list of all enforced Measurements, +// i.e. all Measurements that are not marked as WarnOnly. +func (m *M) GetEnforced() []uint32 { + var enforced []uint32 + for idx, measurement := range *m { + if !measurement.WarnOnly { + enforced = append(enforced, idx) + } } - - return base64Map, nil + return enforced } -// UnmarshalYAML overwrites the default behaviour of reading []byte not as -// single bytes, but as a single base64 encoded string. -func (m *M) UnmarshalYAML(unmarshal func(any) error) error { - base64Map := make(map[uint32]string) - err := unmarshal(base64Map) - if err != nil { - return err +// SetEnforced sets the WarnOnly flag to true for all Measurements +// that are NOT included in the provided list of enforced measurements. +func (m *M) SetEnforced(enforced []uint32) error { + newM := make(M) + + // set all measurements to warn only + for idx, measurement := range *m { + newM[idx] = Measurement{ + Expected: measurement.Expected, + WarnOnly: true, + } } - *m = make(M) - for key, value := range base64Map { - measurement, err := base64.StdEncoding.DecodeString(value) - if err != nil { - return err + // set enforced measurements from list + for _, idx := range enforced { + measurement, ok := newM[idx] + if !ok { + return fmt.Errorf("measurement %d not in list, but set to enforced", idx) } - (*m)[key] = measurement + measurement.WarnOnly = false + newM[idx] = measurement + } + + *m = newM + return nil +} + +// Measurement wraps expected PCR value and whether it is enforced. +type Measurement struct { + // Expected measurement value. + Expected [32]byte `json:"expected" yaml:"expected"` + // WarnOnly if set to true, a mismatching measurement will only result in a warning. + WarnOnly bool `json:"warnOnly" yaml:"warnOnly"` +} + +// UnmarshalJSON reads a Measurement either as json object, +// or as a simple hex or base64 encoded string. +func (m *Measurement) UnmarshalJSON(b []byte) error { + var eM encodedMeasurement + if err := json.Unmarshal(b, &eM); err != nil { + // Unmarshalling failed, Measurement might be in legacy format, + // meaning a simple string instead of Measurement struct. + // TODO: remove with v2.4.0 + if legacyErr := json.Unmarshal(b, &eM.Expected); legacyErr != nil { + return multierr.Append( + err, + fmt.Errorf("trying legacy format: %w", legacyErr), + ) + } + } + + if err := m.unmarshal(eM); err != nil { + return fmt.Errorf("unmarshalling json: %w", err) } return nil } +// MarshalJSON writes out a Measurement with Expected encoded as a hex string. +func (m Measurement) MarshalJSON() ([]byte, error) { + return json.Marshal(encodedMeasurement{ + Expected: hex.EncodeToString(m.Expected[:]), + WarnOnly: m.WarnOnly, + }) +} + +// UnmarshalYAML reads a Measurement either as yaml object, +// or as a simple hex or base64 encoded string. +func (m *Measurement) UnmarshalYAML(unmarshal func(any) error) error { + var eM encodedMeasurement + if err := unmarshal(&eM); err != nil { + // Unmarshalling failed, Measurement might be in legacy format, + // meaning a simple string instead of Measurement struct. + // TODO: remove with v2.4.0 + if legacyErr := unmarshal(&eM.Expected); legacyErr != nil { + return multierr.Append( + err, + fmt.Errorf("trying legacy format: %w", legacyErr), + ) + } + } + + if err := m.unmarshal(eM); err != nil { + return fmt.Errorf("unmarshalling yaml: %w", err) + } + return nil +} + +// MarshalYAML writes out a Measurement with Expected encoded as a hex string. +func (m Measurement) MarshalYAML() (any, error) { + return encodedMeasurement{ + Expected: hex.EncodeToString(m.Expected[:]), + WarnOnly: m.WarnOnly, + }, nil +} + +// unmarshal parses a hex or base64 encoded Measurement. +func (m *Measurement) unmarshal(eM encodedMeasurement) error { + expected, err := hex.DecodeString(eM.Expected) + if err != nil { + // expected value might be in base64 legacy format + // TODO: Remove with v2.4.0 + hexErr := err + expected, err = base64.StdEncoding.DecodeString(eM.Expected) + if err != nil { + return multierr.Append( + fmt.Errorf("invalid measurement: not a hex string %w", hexErr), + fmt.Errorf("not a base64 string: %w", err), + ) + } + } + + if len(expected) != 32 { + return fmt.Errorf("invalid measurement: invalid length: %d", len(expected)) + } + + m.Expected = *(*[32]byte)(expected) + m.WarnOnly = eM.WarnOnly + + return nil +} + +// WithAllBytes returns a measurement value where all 32 bytes are set to b. +func WithAllBytes(b byte, warnOnly bool) Measurement { + return Measurement{ + Expected: *(*[32]byte)(bytes.Repeat([]byte{b}, 32)), + WarnOnly: warnOnly, + } +} + +// DefaultsFor provides the default measurements for given cloud provider. +func DefaultsFor(provider cloudprovider.Provider) M { + switch provider { + case cloudprovider.AWS: + return M{ + 4: PlaceHolderMeasurement(), + 8: WithAllBytes(0x00, false), + 9: PlaceHolderMeasurement(), + 11: WithAllBytes(0x00, false), + 12: PlaceHolderMeasurement(), + 13: WithAllBytes(0x00, false), + uint32(PCRIndexClusterID): WithAllBytes(0x00, false), + } + case cloudprovider.Azure: + return M{ + 4: PlaceHolderMeasurement(), + 8: WithAllBytes(0x00, false), + 9: PlaceHolderMeasurement(), + 11: WithAllBytes(0x00, false), + 12: PlaceHolderMeasurement(), + 13: WithAllBytes(0x00, false), + uint32(PCRIndexClusterID): WithAllBytes(0x00, false), + } + case cloudprovider.GCP: + return M{ + 0: { + Expected: [32]byte{0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF}, + WarnOnly: false, + }, + 4: PlaceHolderMeasurement(), + 8: WithAllBytes(0x00, false), + 9: PlaceHolderMeasurement(), + 11: WithAllBytes(0x00, false), + 12: PlaceHolderMeasurement(), + 13: WithAllBytes(0x00, false), + uint32(PCRIndexClusterID): WithAllBytes(0x00, false), + } + case cloudprovider.QEMU: + return M{ + 4: PlaceHolderMeasurement(), + 8: WithAllBytes(0x00, false), + 9: PlaceHolderMeasurement(), + 11: WithAllBytes(0x00, false), + 12: PlaceHolderMeasurement(), + 13: WithAllBytes(0x00, false), + uint32(PCRIndexClusterID): WithAllBytes(0x00, false), + } + default: + return nil + } +} + +// PlaceHolderMeasurement returns a measurement with placeholder values for Expected. +func PlaceHolderMeasurement() Measurement { + return Measurement{ + Expected: *(*[32]byte)(bytes.Repeat([]byte{0x12, 0x34}, 16)), + WarnOnly: false, + } +} + func getFromURL(ctx context.Context, client *http.Client, sourceURL *url.URL) ([]byte, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL.String(), http.NoBody) if err != nil { @@ -166,3 +315,8 @@ func getFromURL(ctx context.Context, client *http.Client, sourceURL *url.URL) ([ } return content, nil } + +type encodedMeasurement struct { + Expected string `json:"expected" yaml:"expected"` + WarnOnly bool `json:"warnOnly" yaml:"warnOnly"` +} diff --git a/internal/attestation/measurements/measurements_test.go b/internal/attestation/measurements/measurements_test.go index 2a7930ac7..07b0d7ef5 100644 --- a/internal/attestation/measurements/measurements_test.go +++ b/internal/attestation/measurements/measurements_test.go @@ -8,7 +8,7 @@ package measurements import ( "context" - "errors" + "encoding/json" "io" "net/http" "net/url" @@ -17,32 +17,29 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" ) -func TestMarshalYAML(t *testing.T) { +func TestMarshal(t *testing.T) { testCases := map[string]struct { - measurements M - wantBase64Map map[uint32]string + m Measurement + wantYAML string + wantJSON string }{ - "valid measurements": { - measurements: M{ - 2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, - 3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158}, - }, - wantBase64Map: map[uint32]string{ - 2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=", - 3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=", + "measurement": { + m: Measurement{ + Expected: [32]byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, }, + wantYAML: "expected: \"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef51239f752ce69dbc600f35\"\nwarnOnly: false", + wantJSON: `{"expected":"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef51239f752ce69dbc600f35","warnOnly":false}`, }, - "omit bytes": { - measurements: M{ - 2: []byte{}, - 3: []byte{1, 2, 3, 4}, - }, - wantBase64Map: map[uint32]string{ - 2: "", - 3: "AQIDBA==", + "warn only": { + m: Measurement{ + Expected: [32]byte{1, 2, 3, 4}, // implicitly padded with 0s + WarnOnly: true, }, + wantYAML: "expected: \"0102030400000000000000000000000000000000000000000000000000000000\"\nwarnOnly: true", + wantJSON: `{"expected":"0102030400000000000000000000000000000000000000000000000000000000","warnOnly":true}`, }, } @@ -51,63 +48,99 @@ func TestMarshalYAML(t *testing.T) { assert := assert.New(t) require := require.New(t) - base64Map, err := tc.measurements.MarshalYAML() - require.NoError(err) + { + // YAML + yaml, err := yaml.Marshal(tc.m) + require.NoError(err) - assert.Equal(tc.wantBase64Map, base64Map) + assert.YAMLEq(tc.wantYAML, string(yaml)) + } + + { + // JSON + json, err := json.Marshal(tc.m) + require.NoError(err) + + assert.JSONEq(tc.wantJSON, string(json)) + } }) } } -func TestUnmarshalYAML(t *testing.T) { +func TestUnmarshal(t *testing.T) { testCases := map[string]struct { - inputBase64Map map[uint32]string - forceUnmarshalError bool - wantMeasurements M - wantErr bool + inputYAML string + inputJSON string + wantMeasurements M + wantErr bool }{ - "valid measurements": { - inputBase64Map: map[uint32]string{ - 2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=", - 3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=", - }, + "valid measurements base64": { + inputYAML: "2:\n expected: \"/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=\"\n3:\n expected: \"1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=\"", + inputJSON: `{"2":{"expected":"/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU="},"3":{"expected":"1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754="}}`, wantMeasurements: M{ - 2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, - 3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158}, + 2: { + Expected: [32]byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, + }, + 3: { + Expected: [32]byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158}, + }, + }, + }, + "valid measurements hex": { + inputYAML: "2:\n expected: \"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef51239f752ce69dbc600f35\"\n3:\n expected: \"d5a4496d21dec9a5258ddb19c6feb53bb4d3c0463fe607f2488ddf4f1006ef9e\"", + inputJSON: `{"2":{"expected":"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef51239f752ce69dbc600f35"},"3":{"expected":"d5a4496d21dec9a5258ddb19c6feb53bb4d3c0463fe607f2488ddf4f1006ef9e"}}`, + wantMeasurements: M{ + 2: { + Expected: [32]byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, + }, + 3: { + Expected: [32]byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158}, + }, }, }, "empty bytes": { - inputBase64Map: map[uint32]string{ - 2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", - 3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", - }, + inputYAML: "2:\n expected: \"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\"\n3:\n expected: \"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\"", + inputJSON: `{"2":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="},"3":{"expected":"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="}}`, wantMeasurements: M{ - 2: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - 3: []byte{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + 2: { + Expected: [32]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, + 3: { + Expected: [32]byte{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + }, }, }, "invalid base64": { - inputBase64Map: map[uint32]string{ - 2: "This is not base64", - 3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", - }, - wantMeasurements: M{ - 2: []byte{}, - 3: []byte{1, 2, 3, 4}, - }, - wantErr: true, + inputYAML: "2:\n expected: \"This is not base64\"\n3:\n expected: \"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\"", + inputJSON: `{"2":{"expected":"This is not base64"},"3":{"expected":"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="}}`, + wantErr: true, }, - "simulated unmarshal error": { - inputBase64Map: map[uint32]string{ - 2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", - 3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", - }, - forceUnmarshalError: true, + "legacy format": { + inputYAML: "2: \"/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=\"\n3: \"1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=\"", + inputJSON: `{"2":"/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=","3":"1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754="}`, wantMeasurements: M{ - 2: []byte{}, - 3: []byte{1, 2, 3, 4}, + 2: { + Expected: [32]byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, + }, + 3: { + Expected: [32]byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158}, + }, }, - wantErr: true, + }, + "invalid length hex": { + inputYAML: "2:\n expected: \"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef\"\n3:\n expected: \"d5a4496d21dec9a5258ddb19c6feb53bb4d3c0463f\"", + inputJSON: `{"2":{"expected":"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef"},"3":{"expected":"d5a4496d21dec9a5258ddb19c6feb53bb4d3c0463f"}}`, + wantErr: true, + }, + "invalid length base64": { + inputYAML: "2:\n expected: \"AA==\"\n3:\n expected: \"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==\"", + inputJSON: `{"2":{"expected":"AA=="},"3":{"expected":"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="}}`, + wantErr: true, + }, + "invalid format": { + inputYAML: "1:\n expected:\n someKey: 12\n anotherKey: 34", + inputJSON: `{"1":{"expected":{"someKey":12,"anotherKey":34}}}`, + wantErr: true, }, } @@ -116,24 +149,30 @@ func TestUnmarshalYAML(t *testing.T) { assert := assert.New(t) require := require.New(t) - var m M - err := m.UnmarshalYAML(func(i any) error { - if base64Map, ok := i.(map[uint32]string); ok { - for key, value := range tc.inputBase64Map { - base64Map[key] = value - } - } - if tc.forceUnmarshalError { - return errors.New("unmarshal error") - } - return nil - }) + { + // YAML + var m M + err := yaml.Unmarshal([]byte(tc.inputYAML), &m) - if tc.wantErr { - assert.Error(err) - } else { - require.NoError(err) - assert.Equal(tc.wantMeasurements, m) + if tc.wantErr { + assert.Error(err, "yaml.Unmarshal should have failed") + } else { + require.NoError(err, "yaml.Unmarshal failed") + assert.Equal(tc.wantMeasurements, m) + } + } + + { + // JSON + var m M + err := json.Unmarshal([]byte(tc.inputJSON), &m) + + if tc.wantErr { + assert.Error(err, "json.Unmarshal should have failed") + } else { + require.NoError(err, "json.Unmarshal failed") + assert.Equal(tc.wantMeasurements, m) + } } }) } @@ -148,48 +187,48 @@ func TestMeasurementsCopyFrom(t *testing.T) { "add to empty": { current: M{}, newMeasurements: M{ - 1: PCRWithAllBytes(0x00), - 2: PCRWithAllBytes(0x01), - 3: PCRWithAllBytes(0x02), + 1: WithAllBytes(0x00, true), + 2: WithAllBytes(0x01, true), + 3: WithAllBytes(0x02, true), }, wantMeasurements: M{ - 1: PCRWithAllBytes(0x00), - 2: PCRWithAllBytes(0x01), - 3: PCRWithAllBytes(0x02), + 1: WithAllBytes(0x00, true), + 2: WithAllBytes(0x01, true), + 3: WithAllBytes(0x02, true), }, }, "keep existing": { current: M{ - 4: PCRWithAllBytes(0x01), - 5: PCRWithAllBytes(0x02), + 4: WithAllBytes(0x01, false), + 5: WithAllBytes(0x02, true), }, newMeasurements: M{ - 1: PCRWithAllBytes(0x00), - 2: PCRWithAllBytes(0x01), - 3: PCRWithAllBytes(0x02), + 1: WithAllBytes(0x00, true), + 2: WithAllBytes(0x01, true), + 3: WithAllBytes(0x02, true), }, wantMeasurements: M{ - 1: PCRWithAllBytes(0x00), - 2: PCRWithAllBytes(0x01), - 3: PCRWithAllBytes(0x02), - 4: PCRWithAllBytes(0x01), - 5: PCRWithAllBytes(0x02), + 1: WithAllBytes(0x00, true), + 2: WithAllBytes(0x01, true), + 3: WithAllBytes(0x02, true), + 4: WithAllBytes(0x01, false), + 5: WithAllBytes(0x02, true), }, }, "overwrite existing": { current: M{ - 2: PCRWithAllBytes(0x04), - 3: PCRWithAllBytes(0x05), + 2: WithAllBytes(0x04, false), + 3: WithAllBytes(0x05, false), }, newMeasurements: M{ - 1: PCRWithAllBytes(0x00), - 2: PCRWithAllBytes(0x01), - 3: PCRWithAllBytes(0x02), + 1: WithAllBytes(0x00, true), + 2: WithAllBytes(0x01, true), + 3: WithAllBytes(0x02, true), }, wantMeasurements: M{ - 1: PCRWithAllBytes(0x00), - 2: PCRWithAllBytes(0x01), - 3: PCRWithAllBytes(0x02), + 1: WithAllBytes(0x00, true), + 2: WithAllBytes(0x01, true), + 3: WithAllBytes(0x02, true), }, }, } @@ -224,6 +263,22 @@ func urlMustParse(raw string) *url.URL { } func TestMeasurementsFetchAndVerify(t *testing.T) { + // Cosign private key used to sign the measurements. + // Generated with: cosign generate-key-pair + // Password left empty. + // + // -----BEGIN ENCRYPTED COSIGN PRIVATE KEY----- + // eyJrZGYiOnsibmFtZSI6InNjcnlwdCIsInBhcmFtcyI6eyJOIjozMjc2OCwiciI6 + // OCwicCI6MX0sInNhbHQiOiJlRHVYMWRQMGtIWVRnK0xkbjcxM0tjbFVJaU92eFVX + // VXgvNi9BbitFVk5BPSJ9LCJjaXBoZXIiOnsibmFtZSI6Im5hY2wvc2VjcmV0Ym94 + // Iiwibm9uY2UiOiJwaWhLL2txNmFXa2hqSVVHR3RVUzhTVkdHTDNIWWp4TCJ9LCJj + // aXBoZXJ0ZXh0Ijoidm81SHVWRVFWcUZ2WFlQTTVPaTVaWHM5a255bndZU2dvcyth + // VklIeHcrOGFPamNZNEtvVjVmL3lHRHR0K3BHV2toanJPR1FLOWdBbmtsazFpQ0c5 + // a2czUXpPQTZsU2JRaHgvZlowRVRZQ0hLeElncEdPRVRyTDlDenZDemhPZXVSOXJ6 + // TDcvRjBBVy9vUDVqZXR3dmJMNmQxOEhjck9kWE8yVmYxY2w0YzNLZjVRcnFSZzlN + // dlRxQWFsNXJCNHNpY1JaMVhpUUJjb0YwNHc9PSJ9 + // -----END ENCRYPTED COSIGN PRIVATE KEY----- + testCases := map[string]struct { measurements string measurementsStatus int @@ -237,44 +292,66 @@ func TestMeasurementsFetchAndVerify(t *testing.T) { "simple": { measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n", measurementsStatus: http.StatusOK, - signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", + signature: "MEUCIQDcHS2bLls7OrLHpQKuiFGXhPrTcehPDwgVyERHl4V02wIgeIxK4J9oJpXWRBjokbog2lgifRXuJK8ljlAID26MbHk=", signatureStatus: http.StatusOK, - publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), + publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"), wantMeasurements: M{ - 0: PCRWithAllBytes(0x00), + 0: WithAllBytes(0x00, false), }, wantSHA: "4cd9d6ed8d9322150dff7738994c5e2fabff35f3bae6f5c993412d13249a5e87", }, - "404 measurements": { - measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n", - measurementsStatus: http.StatusNotFound, - signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", + "json measurements": { + measurements: `{"0":{"expected":"0000000000000000000000000000000000000000000000000000000000000000","warnOnly":false}}`, + measurementsStatus: http.StatusOK, + signature: "MEUCIQDh3nCgrdTiYWiV4NkiaZ6vxovj79Pk8V90mdWAnmCEOwIgMAVWAx5dW0saut+8X15SgtBEiKqEixYiSICSqqhxUMg=", signatureStatus: http.StatusOK, - publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), + publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"), + wantMeasurements: M{ + 0: WithAllBytes(0x00, false), + }, + wantSHA: "1da09758c89537946496358f80b892e508563fcbbc695c90b6c16bf158e69c11", + }, + "yaml measurements": { + measurements: "0:\n expected: \"0000000000000000000000000000000000000000000000000000000000000000\"\n warnOnly: false\n", + measurementsStatus: http.StatusOK, + signature: "MEUCIFzQdwBS92aJjY0bcIag1uQRl42lUSBmmjEvO0tM/N0ZAiEAvuWaP744qYMw5uEmc7BY4mm4Ij3TEqAWFgxNhFkckp4=", + signatureStatus: http.StatusOK, + publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"), + wantMeasurements: M{ + 0: WithAllBytes(0x00, false), + }, + wantSHA: "c651cd419fd536c63cfc5349ad44da140a09987465e31192660059d383413807", + }, + "404 measurements": { + measurements: `{"0":{"expected":"0000000000000000000000000000000000000000000000000000000000000000","warnOnly":false}}`, + measurementsStatus: http.StatusNotFound, + signature: "MEUCIQDh3nCgrdTiYWiV4NkiaZ6vxovj79Pk8V90mdWAnmCEOwIgMAVWAx5dW0saut+8X15SgtBEiKqEixYiSICSqqhxUMg=", + signatureStatus: http.StatusOK, + publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"), wantError: true, }, "404 signature": { - measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n", + measurements: `{"0":{"expected":"0000000000000000000000000000000000000000000000000000000000000000","warnOnly":false}}`, measurementsStatus: http.StatusOK, - signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", + signature: "MEUCIQDh3nCgrdTiYWiV4NkiaZ6vxovj79Pk8V90mdWAnmCEOwIgMAVWAx5dW0saut+8X15SgtBEiKqEixYiSICSqqhxUMg=", signatureStatus: http.StatusNotFound, - publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), + publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"), wantError: true, }, "broken signature": { - measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n", + measurements: `{"0":{"expected":"0000000000000000000000000000000000000000000000000000000000000000","warnOnly":false}}`, measurementsStatus: http.StatusOK, - signature: "AAAAAAs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", + signature: "AAAAAAAA3nCgrdTiYWiV4NkiaZ6vxovj79Pk8V90mdWAnmCEOwIgMAVWAx5dW0saut+8X15SgtBEiKqEixYiSICSqqhxUMg=", signatureStatus: http.StatusOK, - publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), + publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"), wantError: true, }, - "not yaml": { + "not yaml or json": { measurements: "This is some content to be signed!\n", measurementsStatus: http.StatusOK, - signature: "MEUCIQDzMN3yaiO9sxLGAaSA9YD8rLwzvOaZKWa/bzkcjImUFAIgXLLGzClYUd1dGbuEiY3O/g/eiwQYlyxqLQalxjFmz+8=", + signature: "MEUCIQCGA/lSu5qCJgNNvgMaTKJ9rj6vQMecUDaQo3ukaiAfUgIgWoxXRoDKLY9naN7YgxokM7r2fwnyYk3M2WKJJO1g6yo=", signatureStatus: http.StatusOK, - publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAElWUhon39eAqzEC+/GP03oY4/MQg+\ngCDlEzkuOCybCHf+q766bve799L7Y5y5oRsHY1MrUCUwYF/tL7Sg7EYMsA==\n-----END PUBLIC KEY-----"), + publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"), wantError: true, }, } @@ -322,30 +399,196 @@ func TestMeasurementsFetchAndVerify(t *testing.T) { } } -func TestPCRWithAllBytes(t *testing.T) { +func TestGetEnforced(t *testing.T) { testCases := map[string]struct { - b byte - wantPCR []byte + input M + want map[uint32]struct{} }{ - "0x00": { - b: 0x00, - wantPCR: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + "only warnings": { + input: M{ + 0: WithAllBytes(0x00, true), + 1: WithAllBytes(0x01, true), + }, + want: map[uint32]struct{}{}, }, - "0x01": { - b: 0x01, - wantPCR: []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}, + "all enforced": { + input: M{ + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0x01, false), + }, + want: map[uint32]struct{}{ + 0: {}, + 1: {}, + }, }, - "0xFF": { - b: 0xFF, - wantPCR: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + "mixed": { + input: M{ + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0x01, true), + 2: WithAllBytes(0x02, false), + }, + want: map[uint32]struct{}{ + 0: {}, + 2: {}, + }, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) - pcr := PCRWithAllBytes(tc.b) - assert.Equal(tc.wantPCR, pcr) + + got := tc.input.GetEnforced() + enforced := map[uint32]struct{}{} + for _, id := range got { + enforced[id] = struct{}{} + } + assert.Equal(tc.want, enforced) + }) + } +} + +func TestSetEnforced(t *testing.T) { + testCases := map[string]struct { + input M + enforced []uint32 + wantM M + wantErr bool + }{ + "no enforced measurements": { + input: M{ + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0x01, false), + }, + enforced: []uint32{}, + wantM: M{ + 0: WithAllBytes(0x00, true), + 1: WithAllBytes(0x01, true), + }, + }, + "all enforced measurements": { + input: M{ + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0x01, false), + }, + enforced: []uint32{0, 1}, + wantM: M{ + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0x01, false), + }, + }, + "mixed": { + input: M{ + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0x01, false), + 2: WithAllBytes(0x02, false), + 3: WithAllBytes(0x03, false), + }, + enforced: []uint32{0, 2}, + wantM: M{ + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0x01, true), + 2: WithAllBytes(0x02, false), + 3: WithAllBytes(0x03, true), + }, + }, + "warn only to enforced": { + input: M{ + 0: WithAllBytes(0x00, true), + 1: WithAllBytes(0x01, true), + }, + enforced: []uint32{0, 1}, + wantM: M{ + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0x01, false), + }, + }, + "more enforced than measurements": { + input: M{ + 0: WithAllBytes(0x00, true), + 1: WithAllBytes(0x01, true), + }, + enforced: []uint32{0, 1, 2}, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + err := tc.input.SetEnforced(tc.enforced) + if tc.wantErr { + assert.Error(err) + return + } + assert.NoError(err) + assert.True(tc.input.EqualTo(tc.wantM)) + }) + } +} + +func TestWithAllBytes(t *testing.T) { + testCases := map[string]struct { + b byte + warnOnly bool + wantMeasurement Measurement + }{ + "0x00 warnOnly": { + b: 0x00, + warnOnly: true, + wantMeasurement: Measurement{ + Expected: [32]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + WarnOnly: true, + }, + }, + "0x00": { + b: 0x00, + warnOnly: false, + wantMeasurement: Measurement{ + Expected: [32]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + WarnOnly: false, + }, + }, + "0x01 warnOnly": { + b: 0x01, + warnOnly: true, + wantMeasurement: Measurement{ + Expected: [32]byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}, + WarnOnly: true, + }, + }, + "0x01": { + b: 0x01, + warnOnly: false, + wantMeasurement: Measurement{ + Expected: [32]byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}, + WarnOnly: false, + }, + }, + "0xFF warnOnly": { + b: 0xFF, + warnOnly: true, + wantMeasurement: Measurement{ + Expected: [32]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + WarnOnly: true, + }, + }, + "0xFF": { + b: 0xFF, + warnOnly: false, + wantMeasurement: Measurement{ + Expected: [32]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + WarnOnly: false, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + measurement := WithAllBytes(tc.b, tc.warnOnly) + assert.Equal(tc.wantMeasurement, measurement) }) } } @@ -358,33 +601,44 @@ func TestEqualTo(t *testing.T) { }{ "same values": { given: M{ - 0: PCRWithAllBytes(0x00), - 1: PCRWithAllBytes(0xFF), + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0xFF, false), }, other: M{ - 0: PCRWithAllBytes(0x00), - 1: PCRWithAllBytes(0xFF), + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0xFF, false), }, wantEqual: true, }, "different number of elements": { given: M{ - 0: PCRWithAllBytes(0x00), - 1: PCRWithAllBytes(0xFF), + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0xFF, false), }, other: M{ - 0: PCRWithAllBytes(0x00), + 0: WithAllBytes(0x00, false), }, wantEqual: false, }, "different values": { given: M{ - 0: PCRWithAllBytes(0x00), - 1: PCRWithAllBytes(0xFF), + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0xFF, false), }, other: M{ - 0: PCRWithAllBytes(0xFF), - 1: PCRWithAllBytes(0x00), + 0: WithAllBytes(0xFF, false), + 1: WithAllBytes(0x00, false), + }, + wantEqual: false, + }, + "different warn settings": { + given: M{ + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0xFF, false), + }, + other: M{ + 0: WithAllBytes(0x00, false), + 1: WithAllBytes(0xFF, true), }, wantEqual: false, }, diff --git a/internal/attestation/qemu/validator.go b/internal/attestation/qemu/validator.go index c1ac719aa..1db26b2e5 100644 --- a/internal/attestation/qemu/validator.go +++ b/internal/attestation/qemu/validator.go @@ -22,11 +22,10 @@ type Validator struct { } // NewValidator initializes a new QEMU validator with the provided PCR values. -func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { +func NewValidator(pcrs measurements.M, log vtpm.AttestationLogger) *Validator { return &Validator{ Validator: vtpm.NewValidator( pcrs, - enforcedPCRs, unconditionalTrust, func(attestation vtpm.AttestationDocument) error { return nil }, vtpm.VerifyPKCS1v15, diff --git a/internal/attestation/vtpm/attestation.go b/internal/attestation/vtpm/attestation.go index c79a5a7c6..59425e470 100644 --- a/internal/attestation/vtpm/attestation.go +++ b/internal/attestation/vtpm/attestation.go @@ -16,6 +16,7 @@ import ( "fmt" "io" + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" tpmClient "github.com/google/go-tpm-tools/client" "github.com/google/go-tpm-tools/proto/attest" tpmProto "github.com/google/go-tpm-tools/proto/tpm" @@ -144,8 +145,7 @@ func (i *Issuer) Issue(userData []byte, nonce []byte) ([]byte, error) { // Validator handles validation of TPM based attestation. type Validator struct { - expectedPCRs map[uint32][]byte - enforcedPCRs map[uint32]struct{} + expected measurements.M getTrustedKey GetTPMTrustedAttestationPublicKey validateCVM ValidateCVM verifyUserData VerifyUserData @@ -154,18 +154,11 @@ type Validator struct { } // NewValidator returns a new Validator. -func NewValidator(expectedPCRs map[uint32][]byte, enforcedPCRs []uint32, getTrustedKey GetTPMTrustedAttestationPublicKey, +func NewValidator(expected measurements.M, getTrustedKey GetTPMTrustedAttestationPublicKey, validateCVM ValidateCVM, verifyUserData VerifyUserData, log AttestationLogger, ) *Validator { - // Convert the enforced PCR list to a map for convenient and fast lookup - enforcedMap := make(map[uint32]struct{}) - for _, pcr := range enforcedPCRs { - enforcedMap[pcr] = struct{}{} - } - return &Validator{ - expectedPCRs: expectedPCRs, - enforcedPCRs: enforcedMap, + expected: expected, getTrustedKey: getTrustedKey, validateCVM: validateCVM, verifyUserData: verifyUserData, @@ -212,9 +205,9 @@ func (v *Validator) Validate(attDocRaw []byte, nonce []byte) ([]byte, error) { if err != nil { return nil, err } - for idx, pcr := range v.expectedPCRs { - if !bytes.Equal(pcr, attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs[idx]) { - if _, ok := v.enforcedPCRs[idx]; ok { + for idx, pcr := range v.expected { + if !bytes.Equal(pcr.Expected[:], attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs[idx]) { + if !pcr.WarnOnly { return nil, fmt.Errorf("untrusted PCR value at PCR index %d", idx) } if v.log != nil { @@ -263,8 +256,8 @@ func VerifyPKCS1v15(pub crypto.PublicKey, hash crypto.Hash, hashed, sig []byte) return rsa.VerifyPKCS1v15(key, hash, hashed, sig) } -// GetSelectedPCRs returns a map of the selected PCR hashes. -func GetSelectedPCRs(open TPMOpenFunc, selection tpm2.PCRSelection) (map[uint32][]byte, error) { +// GetSelectedMeasurements returns a map of Measurments for the PCRs in selection. +func GetSelectedMeasurements(open TPMOpenFunc, selection tpm2.PCRSelection) (measurements.M, error) { tpm, err := open() if err != nil { return nil, err @@ -276,5 +269,15 @@ func GetSelectedPCRs(open TPMOpenFunc, selection tpm2.PCRSelection) (map[uint32] return nil, err } - return pcrList.Pcrs, nil + m := make(measurements.M) + for i, pcr := range pcrList.Pcrs { + if len(pcr) != 32 { + return nil, fmt.Errorf("invalid measurement: invalid length: %d", len(pcr)) + } + m[i] = measurements.Measurement{ + Expected: *(*[32]byte)(pcr), + } + } + + return m, nil } diff --git a/internal/attestation/vtpm/attestation_test.go b/internal/attestation/vtpm/attestation_test.go index 93f1d62e6..d5c62c894 100644 --- a/internal/attestation/vtpm/attestation_test.go +++ b/internal/attestation/vtpm/attestation_test.go @@ -14,6 +14,7 @@ import ( "io" "testing" + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" tpmsim "github.com/edgelesssys/constellation/v2/internal/attestation/simulator" tpmclient "github.com/google/go-tpm-tools/client" "github.com/google/go-tpm-tools/proto/attest" @@ -64,14 +65,14 @@ func TestValidate(t *testing.T) { return pubArea.Key() } - testExpectedPCRs := map[uint32][]byte{ - 0: {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - 1: {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + testExpectedPCRs := measurements.M{ + 0: measurements.WithAllBytes(0x00, true), + 1: measurements.WithAllBytes(0x00, true), } warnLog := &testAttestationLogger{} issuer := NewIssuer(newSimTPMWithEventLog, tpmclient.AttestationKeyRSA, fakeGetInstanceInfo) - validator := NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog) + validator := NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog) nonce := []byte{1, 2, 3, 4} challenge := []byte("Constellation") @@ -89,18 +90,28 @@ func TestValidate(t *testing.T) { require.NoError(err) require.Equal(challenge, out) - enforcedPCRs := []uint32{0, 1} - expectedPCRs := map[uint32][]byte{ - 0: {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - 1: {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, - 2: {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20}, - 3: {0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40}, - 4: {0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, 0x60}, - 5: {0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80}, + expectedPCRs := measurements.M{ + 0: measurements.WithAllBytes(0x00, true), + 1: measurements.WithAllBytes(0x00, true), + 2: measurements.Measurement{ + Expected: [32]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20}, + WarnOnly: true, + }, + 3: measurements.Measurement{ + Expected: [32]byte{0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40}, + WarnOnly: true, + }, + 4: measurements.Measurement{ + Expected: [32]byte{0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, 0x60}, + WarnOnly: true, + }, + 5: measurements.Measurement{ + Expected: [32]byte{0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80}, + WarnOnly: true, + }, } warningValidator := NewValidator( expectedPCRs, - enforcedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, @@ -109,7 +120,7 @@ func TestValidate(t *testing.T) { out, err = warningValidator.Validate(attDocRaw, nonce) require.NoError(err) assert.Equal(t, challenge, out) - assert.Len(t, warnLog.warnings, len(expectedPCRs)-len(enforcedPCRs)) + assert.Len(t, warnLog.warnings, 4) testCases := map[string]struct { validator *Validator @@ -118,13 +129,13 @@ func TestValidate(t *testing.T) { wantErr bool }{ "invalid nonce": { - validator: NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), + validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), attDoc: mustMarshalAttestation(attDoc, require), nonce: []byte{4, 3, 2, 1}, wantErr: true, }, "invalid signature": { - validator: NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), + validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), attDoc: mustMarshalAttestation(AttestationDocument{ Attestation: attDoc.Attestation, InstanceInfo: attDoc.InstanceInfo, @@ -137,7 +148,6 @@ func TestValidate(t *testing.T) { "untrusted attestation public key": { validator: NewValidator( testExpectedPCRs, - []uint32{0, 1}, func(akPub, instanceInfo []byte) (crypto.PublicKey, error) { return nil, errors.New("untrusted") }, @@ -149,7 +159,6 @@ func TestValidate(t *testing.T) { "not a CVM": { validator: NewValidator( testExpectedPCRs, - []uint32{0, 1}, fakeGetTrustedKey, func(attestation AttestationDocument) error { return errors.New("untrusted") @@ -161,10 +170,12 @@ func TestValidate(t *testing.T) { }, "untrusted PCRs": { validator: NewValidator( - map[uint32][]byte{ - 0: {0xFF}, + measurements.M{ + 0: measurements.Measurement{ + Expected: [32]byte{0xFF}, + WarnOnly: false, + }, }, - []uint32{0}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), @@ -173,7 +184,7 @@ func TestValidate(t *testing.T) { wantErr: true, }, "no sha256 quote": { - validator: NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), + validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), attDoc: mustMarshalAttestation(AttestationDocument{ Attestation: &attest.Attestation{ AkPub: attDoc.Attestation.AkPub, @@ -191,7 +202,7 @@ func TestValidate(t *testing.T) { wantErr: true, }, "invalid attestation document": { - validator: NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), + validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), attDoc: []byte("invalid attestation"), nonce: nonce, wantErr: true, @@ -350,7 +361,7 @@ func TestGetSHA256QuoteIndex(t *testing.T) { } } -func TestGetSelectedPCRs(t *testing.T) { +func TestGetSelectedMeasurements(t *testing.T) { testCases := map[string]struct { openFunc TPMOpenFunc pcrSelection tpm2.PCRSelection @@ -386,17 +397,13 @@ func TestGetSelectedPCRs(t *testing.T) { require := require.New(t) assert := assert.New(t) - pcrs, err := GetSelectedPCRs(tc.openFunc, tc.pcrSelection) + pcrs, err := GetSelectedMeasurements(tc.openFunc, tc.pcrSelection) if tc.wantErr { assert.Error(err) - } else { - require.NoError(err) - - assert.Equal(len(pcrs), len(tc.pcrSelection.PCRs)) - for _, pcr := range pcrs { - assert.Len(pcr, 32) - } + return } + require.NoError(err) + assert.Len(pcrs, len(tc.pcrSelection.PCRs)) }) } } diff --git a/internal/attestation/vtpm/initialize.go b/internal/attestation/vtpm/initialize.go index 171f73685..427f23647 100644 --- a/internal/attestation/vtpm/initialize.go +++ b/internal/attestation/vtpm/initialize.go @@ -9,18 +9,8 @@ package vtpm import ( "errors" + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/google/go-tpm/tpm2" - "github.com/google/go-tpm/tpmutil" -) - -const ( - // PCRIndexClusterID is a PCR we extend to mark the node as initialized. - // The value used to extend is a random generated 32 Byte value. - PCRIndexClusterID = tpmutil.Handle(15) - // PCRIndexOwnerID is a PCR we extend to mark the node as initialized. - // The value used to extend is derived from Constellation's master key. - // TODO: move to stable, non-debug PCR before use. - PCRIndexOwnerID = tpmutil.Handle(16) ) // MarkNodeAsBootstrapped marks a node as initialized by extending PCRs. @@ -32,7 +22,7 @@ func MarkNodeAsBootstrapped(openTPM TPMOpenFunc, clusterID []byte) error { defer tpm.Close() // clusterID is used to uniquely identify this running instance of Constellation - return tpm2.PCREvent(tpm, PCRIndexClusterID, clusterID) + return tpm2.PCREvent(tpm, measurements.PCRIndexClusterID, clusterID) } // IsNodeBootstrapped checks if a node is already bootstrapped by reading PCRs. @@ -43,7 +33,7 @@ func IsNodeBootstrapped(openTPM TPMOpenFunc) (bool, error) { } defer tpm.Close() - idxClusterID := int(PCRIndexClusterID) + idxClusterID := int(measurements.PCRIndexClusterID) pcrs, err := tpm2.ReadPCRs(tpm, tpm2.PCRSelection{ Hash: tpm2.AlgSHA256, PCRs: []int{idxClusterID}, diff --git a/internal/attestation/vtpm/initialize_test.go b/internal/attestation/vtpm/initialize_test.go index 69b39580a..b9b2a8496 100644 --- a/internal/attestation/vtpm/initialize_test.go +++ b/internal/attestation/vtpm/initialize_test.go @@ -11,6 +11,7 @@ import ( "io" "testing" + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/simulator" "github.com/google/go-tpm-tools/client" "github.com/google/go-tpm/tpm2" @@ -45,7 +46,7 @@ func TestMarkNodeAsBootstrapped(t *testing.T) { require.NoError(err) for i := range pcrs { - assert.NotEqual(pcrs[i].Pcrs[uint32(PCRIndexClusterID)], pcrsInitialized[i].Pcrs[uint32(PCRIndexClusterID)]) + assert.NotEqual(pcrs[i].Pcrs[uint32(measurements.PCRIndexClusterID)], pcrsInitialized[i].Pcrs[uint32(measurements.PCRIndexClusterID)]) } } @@ -76,7 +77,7 @@ func TestIsNodeInitialized(t *testing.T) { require.NoError(err) defer tpm.Close() if tc.pcrValueClusterID != nil { - require.NoError(tpm2.PCREvent(tpm, PCRIndexClusterID, tc.pcrValueClusterID)) + require.NoError(tpm2.PCREvent(tpm, measurements.PCRIndexClusterID, tc.pcrValueClusterID)) } initialized, err := IsNodeBootstrapped(func() (io.ReadWriteCloser, error) { return &simTPMNOPCloser{tpm}, nil diff --git a/internal/config/config.go b/internal/config/config.go index 9a82244e7..fb7910a90 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -138,10 +138,7 @@ type AWSConfig struct { IAMProfileWorkerNodes string `yaml:"iamProfileWorkerNodes" validate:"required"` // description: | // Expected VM measurements. - Measurements Measurements `yaml:"measurements"` - // description: | - // List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning. - EnforcedMeasurements []uint32 `yaml:"enforcedMeasurements"` + Measurements Measurements `yaml:"measurements" validate:"required,no_placeholders"` } // AzureConfig are Azure specific configuration values used by the CLI. @@ -190,10 +187,7 @@ type AzureConfig struct { EnforceIDKeyDigest *bool `yaml:"enforceIdKeyDigest" validate:"required"` // description: | // Expected confidential VM measurements. - Measurements Measurements `yaml:"measurements"` - // description: | - // List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning. - EnforcedMeasurements []uint32 `yaml:"enforcedMeasurements"` + Measurements Measurements `yaml:"measurements" validate:"required,no_placeholders"` } // GCPConfig are GCP specific configuration values used by the CLI. @@ -221,10 +215,7 @@ type GCPConfig struct { DeployCSIDriver *bool `yaml:"deployCSIDriver" validate:"required"` // description: | // Expected confidential VM measurements. - Measurements Measurements `yaml:"measurements"` - // description: | - // List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning. - EnforcedMeasurements []uint32 `yaml:"enforcedMeasurements"` + Measurements Measurements `yaml:"measurements" validate:"required,no_placeholders"` } // QEMUConfig holds config information for QEMU based Constellation deployments. @@ -255,10 +246,7 @@ type QEMUConfig struct { Firmware string `yaml:"firmware"` // description: | // Measurement used to enable measured boot. - Measurements Measurements `yaml:"measurements"` - // description: | - // List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning. - EnforcedMeasurements []uint32 `yaml:"enforcedMeasurements"` + Measurements Measurements `yaml:"measurements" validate:"required,no_placeholders"` } // Default returns a struct with the default config. @@ -276,7 +264,6 @@ func Default() *Config { IAMProfileControlPlane: "", IAMProfileWorkerNodes: "", Measurements: measurements.DefaultsFor(cloudprovider.AWS), - EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15}, }, Azure: &AzureConfig{ SubscriptionID: "", @@ -292,7 +279,6 @@ func Default() *Config { ConfidentialVM: func() *bool { b := true; return &b }(), SecureBoot: func() *bool { b := false; return &b }(), Measurements: measurements.DefaultsFor(cloudprovider.Azure), - EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15}, }, GCP: &GCPConfig{ Project: "", @@ -303,7 +289,6 @@ func Default() *Config { StateDiskType: "pd-ssd", DeployCSIDriver: func() *bool { b := true; return &b }(), Measurements: measurements.DefaultsFor(cloudprovider.GCP), - EnforcedMeasurements: []uint32{0, 4, 8, 9, 11, 12, 13, 15}, }, QEMU: &QEMUConfig{ ImageFormat: "raw", @@ -314,7 +299,6 @@ func Default() *Config { LibvirtContainerImage: versions.LibvirtImage, NVRAM: "production", Measurements: measurements.DefaultsFor(cloudprovider.QEMU), - EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15}, }, }, KubernetesVersion: string(versions.Default), @@ -446,18 +430,18 @@ func (c *Config) EnforcesIDKeyDigest() bool { return c.Provider.Azure != nil && c.Provider.Azure.EnforceIDKeyDigest != nil && *c.Provider.Azure.EnforceIDKeyDigest } -// GetEnforcedPCRs returns the list of enforced PCRs for the configured cloud provider. -func (c *Config) GetEnforcedPCRs() []uint32 { +// EnforcedPCRs returns the list of enforced PCRs for the configured cloud provider. +func (c *Config) EnforcedPCRs() []uint32 { provider := c.GetProvider() switch provider { case cloudprovider.AWS: - return c.Provider.AWS.EnforcedMeasurements + return c.Provider.AWS.Measurements.GetEnforced() case cloudprovider.Azure: - return c.Provider.Azure.EnforcedMeasurements + return c.Provider.Azure.Measurements.GetEnforced() case cloudprovider.GCP: - return c.Provider.GCP.EnforcedMeasurements + return c.Provider.GCP.Measurements.GetEnforced() case cloudprovider.QEMU: - return c.Provider.QEMU.EnforcedMeasurements + return c.Provider.QEMU.Measurements.GetEnforced() default: return nil } @@ -499,6 +483,14 @@ func (c *Config) Validate() error { return err } + if err := validate.RegisterTranslation("no_placeholders", trans, registerContainsPlaceholderError, translateContainsPlaceholderError); err != nil { + return err + } + + if err := validate.RegisterValidation("no_placeholders", validateNoPlaceholder); err != nil { + return err + } + if err := validate.RegisterValidation("safe_image", validateImage); err != nil { return err } diff --git a/internal/config/config_doc.go b/internal/config/config_doc.go index 19521340e..ba8b2307c 100644 --- a/internal/config/config_doc.go +++ b/internal/config/config_doc.go @@ -157,7 +157,7 @@ func init() { FieldName: "aws", }, } - AWSConfigDoc.Fields = make([]encoder.Doc, 8) + AWSConfigDoc.Fields = make([]encoder.Doc, 7) AWSConfigDoc.Fields[0].Name = "region" AWSConfigDoc.Fields[0].Type = "string" AWSConfigDoc.Fields[0].Note = "" @@ -193,11 +193,6 @@ func init() { AWSConfigDoc.Fields[6].Note = "" AWSConfigDoc.Fields[6].Description = "Expected VM measurements." AWSConfigDoc.Fields[6].Comments[encoder.LineComment] = "Expected VM measurements." - AWSConfigDoc.Fields[7].Name = "enforcedMeasurements" - AWSConfigDoc.Fields[7].Type = "[]uint32" - AWSConfigDoc.Fields[7].Note = "" - AWSConfigDoc.Fields[7].Description = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning." - AWSConfigDoc.Fields[7].Comments[encoder.LineComment] = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning." AzureConfigDoc.Type = "AzureConfig" AzureConfigDoc.Comments[encoder.LineComment] = "AzureConfig are Azure specific configuration values used by the CLI." @@ -208,7 +203,7 @@ func init() { FieldName: "azure", }, } - AzureConfigDoc.Fields = make([]encoder.Doc, 16) + AzureConfigDoc.Fields = make([]encoder.Doc, 15) AzureConfigDoc.Fields[0].Name = "subscription" AzureConfigDoc.Fields[0].Type = "string" AzureConfigDoc.Fields[0].Note = "" @@ -284,11 +279,6 @@ func init() { AzureConfigDoc.Fields[14].Note = "" AzureConfigDoc.Fields[14].Description = "Expected confidential VM measurements." AzureConfigDoc.Fields[14].Comments[encoder.LineComment] = "Expected confidential VM measurements." - AzureConfigDoc.Fields[15].Name = "enforcedMeasurements" - AzureConfigDoc.Fields[15].Type = "[]uint32" - AzureConfigDoc.Fields[15].Note = "" - AzureConfigDoc.Fields[15].Description = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning." - AzureConfigDoc.Fields[15].Comments[encoder.LineComment] = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning." GCPConfigDoc.Type = "GCPConfig" GCPConfigDoc.Comments[encoder.LineComment] = "GCPConfig are GCP specific configuration values used by the CLI." @@ -299,7 +289,7 @@ func init() { FieldName: "gcp", }, } - GCPConfigDoc.Fields = make([]encoder.Doc, 9) + GCPConfigDoc.Fields = make([]encoder.Doc, 8) GCPConfigDoc.Fields[0].Name = "project" GCPConfigDoc.Fields[0].Type = "string" GCPConfigDoc.Fields[0].Note = "" @@ -340,11 +330,6 @@ func init() { GCPConfigDoc.Fields[7].Note = "" GCPConfigDoc.Fields[7].Description = "Expected confidential VM measurements." GCPConfigDoc.Fields[7].Comments[encoder.LineComment] = "Expected confidential VM measurements." - GCPConfigDoc.Fields[8].Name = "enforcedMeasurements" - GCPConfigDoc.Fields[8].Type = "[]uint32" - GCPConfigDoc.Fields[8].Note = "" - GCPConfigDoc.Fields[8].Description = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning." - GCPConfigDoc.Fields[8].Comments[encoder.LineComment] = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning." QEMUConfigDoc.Type = "QEMUConfig" QEMUConfigDoc.Comments[encoder.LineComment] = "QEMUConfig holds config information for QEMU based Constellation deployments." @@ -355,7 +340,7 @@ func init() { FieldName: "qemu", }, } - QEMUConfigDoc.Fields = make([]encoder.Doc, 10) + QEMUConfigDoc.Fields = make([]encoder.Doc, 9) QEMUConfigDoc.Fields[0].Name = "imageFormat" QEMUConfigDoc.Fields[0].Type = "string" QEMUConfigDoc.Fields[0].Note = "" @@ -401,11 +386,6 @@ func init() { QEMUConfigDoc.Fields[8].Note = "" QEMUConfigDoc.Fields[8].Description = "Measurement used to enable measured boot." QEMUConfigDoc.Fields[8].Comments[encoder.LineComment] = "Measurement used to enable measured boot." - QEMUConfigDoc.Fields[9].Name = "enforcedMeasurements" - QEMUConfigDoc.Fields[9].Type = "[]uint32" - QEMUConfigDoc.Fields[9].Note = "" - QEMUConfigDoc.Fields[9].Description = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning." - QEMUConfigDoc.Fields[9].Comments[encoder.LineComment] = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning." } func (_ Config) Doc() *encoder.Doc { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 2ca777ee2..1b8ad12ff 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -128,6 +128,7 @@ func TestNewWithDefaultOptions(t *testing.T) { c.Provider.Azure.ResourceGroup = "test" c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity" c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e" + c.Provider.Azure.Measurements = measurements.M{15: measurements.WithAllBytes(0x00, false)} return c }(), envToSet: map[string]string{ @@ -147,6 +148,7 @@ func TestNewWithDefaultOptions(t *testing.T) { c.Provider.Azure.ClientSecretValue = "other-value" // < Note secret set in config, as well. c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity" c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e" + c.Provider.Azure.Measurements = measurements.M{15: measurements.WithAllBytes(0x00, false)} return c }(), envToSet: map[string]string{ @@ -182,9 +184,9 @@ func TestNewWithDefaultOptions(t *testing.T) { } func TestValidate(t *testing.T) { - const defaultErrCount = 17 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default - const azErrCount = 8 - const gcpErrCount = 5 + const defaultErrCount = 21 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default + const azErrCount = 9 + const gcpErrCount = 6 testCases := map[string]struct { cnf *Config @@ -240,6 +242,7 @@ func TestValidate(t *testing.T) { az.ClientSecretValue = "test-client-secret" cnf.Provider = ProviderConfig{} cnf.Provider.Azure = az + cnf.Provider.Azure.Measurements = measurements.M{15: measurements.WithAllBytes(0x00, false)} return cnf }(), }, @@ -265,6 +268,7 @@ func TestValidate(t *testing.T) { gcp.ServiceAccountKeyPath = "test-key-path" cnf.Provider = ProviderConfig{} cnf.Provider.GCP = gcp + cnf.Provider.GCP.Measurements = measurements.M{15: measurements.WithAllBytes(0x00, false)} return cnf }(), }, @@ -364,9 +368,9 @@ func TestConfigGeneratedDocsFresh(t *testing.T) { func TestConfig_UpdateMeasurements(t *testing.T) { assert := assert.New(t) newMeasurements := measurements.M{ - 1: measurements.PCRWithAllBytes(0x00), - 2: measurements.PCRWithAllBytes(0x01), - 3: measurements.PCRWithAllBytes(0x02), + 1: measurements.WithAllBytes(0x00, false), + 2: measurements.WithAllBytes(0x01, false), + 3: measurements.WithAllBytes(0x02, false), } { // AWS diff --git a/internal/config/validation.go b/internal/config/validation.go index c062bab4a..0633559d2 100644 --- a/internal/config/validation.go +++ b/internal/config/validation.go @@ -7,9 +7,11 @@ SPDX-License-Identifier: AGPL-3.0-only package config import ( + "bytes" "fmt" "strings" + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/config/instancetypes" "github.com/edgelesssys/constellation/v2/internal/versions" @@ -223,3 +225,35 @@ func (c *Config) translateAzureInstanceTypeError(ut ut.Translator, fe validator. return t } + +func validateNoPlaceholder(fl validator.FieldLevel) bool { + return len(getPlaceholderEntries(fl.Field().Interface().(Measurements))) == 0 +} + +func registerContainsPlaceholderError(ut ut.Translator) error { + return ut.Add("no_placeholders", "{0} placeholder values (repeated 1234...)", true) +} + +func translateContainsPlaceholderError(ut ut.Translator, fe validator.FieldError) string { + placeholders := getPlaceholderEntries(fe.Value().(Measurements)) + msg := fmt.Sprintf("Measurements %v contain", placeholders) + if len(placeholders) == 1 { + msg = fmt.Sprintf("Measurement %v contains", placeholders) + } + + t, _ := ut.T("no_placeholders", msg) + return t +} + +func getPlaceholderEntries(m Measurements) []uint32 { + var placeholders []uint32 + placeholder := measurements.PlaceHolderMeasurement() + + for idx, measurement := range m { + if bytes.Equal(measurement.Expected[:], placeholder.Expected[:]) { + placeholders = append(placeholders, idx) + } + } + + return placeholders +} diff --git a/internal/watcher/validator.go b/internal/watcher/validator.go index ffd49db13..06bcafc48 100644 --- a/internal/watcher/validator.go +++ b/internal/watcher/validator.go @@ -9,7 +9,9 @@ package watcher import ( "encoding/asn1" "encoding/hex" + "errors" "fmt" + "os" "path/filepath" "strconv" "sync" @@ -19,6 +21,7 @@ import ( "github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp" "github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch" "github.com/edgelesssys/constellation/v2/internal/attestation/gcp" + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/qemu" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/constants" @@ -42,26 +45,26 @@ func NewValidator(log *logger.Logger, csp string, fileHandler file.Handler, azur var newValidator newValidatorFunc switch cloudprovider.FromString(csp) { case cloudprovider.AWS: - newValidator = func(m map[uint32][]byte, e []uint32, _ []byte, _ bool, log *logger.Logger) atls.Validator { - return aws.NewValidator(m, e, log) + newValidator = func(m measurements.M, _ []byte, _ bool, log *logger.Logger) atls.Validator { + return aws.NewValidator(m, log) } case cloudprovider.Azure: if azureCVM { - newValidator = func(m map[uint32][]byte, e []uint32, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator { - return snp.NewValidator(m, e, idkeydigest, enforceIdKeyDigest, log) + newValidator = func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator { + return snp.NewValidator(m, idkeydigest, enforceIdKeyDigest, log) } } else { - newValidator = func(m map[uint32][]byte, e []uint32, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator { - return trustedlaunch.NewValidator(m, e, log) + newValidator = func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator { + return trustedlaunch.NewValidator(m, log) } } case cloudprovider.GCP: - newValidator = func(m map[uint32][]byte, e []uint32, _ []byte, _ bool, log *logger.Logger) atls.Validator { - return gcp.NewValidator(m, e, log) + newValidator = func(m measurements.M, _ []byte, _ bool, log *logger.Logger) atls.Validator { + return gcp.NewValidator(m, log) } case cloudprovider.QEMU: - newValidator = func(m map[uint32][]byte, e []uint32, _ []byte, _ bool, log *logger.Logger) atls.Validator { - return qemu.NewValidator(m, e, log) + newValidator = func(m measurements.M, _ []byte, _ bool, log *logger.Logger) atls.Validator { + return qemu.NewValidator(m, log) } default: return nil, fmt.Errorf("unknown cloud service provider: %q", csp) @@ -100,17 +103,24 @@ func (u *Updatable) Update() error { u.log.Infof("Updating expected measurements") - var measurements map[uint32][]byte + var measurements measurements.M if err := u.fileHandler.ReadJSON(filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), &measurements); err != nil { return err } - u.log.Debugf("New measurements: %v", measurements) + u.log.Debugf("New measurements: %+v", measurements) + // handle legacy measurement format, where expected measurements and enforced measurements were stored in separate data structures + // TODO: remove with v2.4.0 var enforced []uint32 - if err := u.fileHandler.ReadJSON(filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename), &enforced); err != nil { + if err := u.fileHandler.ReadJSON(filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename), &enforced); err == nil { + u.log.Debugf("Detected legacy format. Loading enforced PCRs...") + if err := measurements.SetEnforced(enforced); err != nil { + return err + } + u.log.Debugf("Merged measurements with enforced values: %+v", measurements) + } else if !errors.Is(err, os.ErrNotExist) { return err } - u.log.Debugf("Enforced PCRs: %v", enforced) var idkeydigest []byte var enforceIDKeyDigest bool @@ -138,9 +148,9 @@ func (u *Updatable) Update() error { u.log.Debugf("New idkeydigest: %x", idkeydigest) } - u.Validator = u.newValidator(measurements, enforced, idkeydigest, enforceIDKeyDigest, u.log) + u.Validator = u.newValidator(measurements, idkeydigest, enforceIDKeyDigest, u.log) return nil } -type newValidatorFunc func(measurements map[uint32][]byte, enforcedPCRs []uint32, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator +type newValidatorFunc func(measurements measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator diff --git a/internal/watcher/validator_test.go b/internal/watcher/validator_test.go index 2cb1b91aa..58badd528 100644 --- a/internal/watcher/validator_test.go +++ b/internal/watcher/validator_test.go @@ -20,6 +20,7 @@ import ( "testing" "github.com/edgelesssys/constellation/v2/internal/atls" + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/logger" @@ -117,7 +118,7 @@ func TestUpdate(t *testing.T) { require := require.New(t) oid := fakeOID{1, 3, 9900, 1} - newValidator := func(m map[uint32][]byte, e []uint32, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator { + newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator { return fakeValidator{fakeOID: oid} } handler := file.NewHandler(afero.NewMemMapFs()) @@ -135,14 +136,7 @@ func TestUpdate(t *testing.T) { // write measurement config require.NoError(handler.WriteJSON( filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), - map[uint32][]byte{ - 11: {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, - }, - file.OptNone, - )) - require.NoError(handler.WriteJSON( - filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename), - []uint32{11}, + measurements.M{11: measurements.WithAllBytes(0x00, false)}, )) require.NoError(handler.Write( filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename), @@ -189,6 +183,23 @@ func TestUpdate(t *testing.T) { defer resp.Body.Close() } assert.Error(err) + + // update should work for legacy measurement format + // TODO: remove with v2.4.0 + require.NoError(handler.WriteJSON( + filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), + map[uint32][]byte{ + 11: bytes.Repeat([]byte{0x0}, 32), + 12: bytes.Repeat([]byte{0x1}, 32), + }, + file.OptOverwrite, + )) + require.NoError(handler.WriteJSON( + filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename), + []uint32{11}, + )) + + assert.NoError(validator.Update()) } func TestUpdateConcurrency(t *testing.T) { @@ -199,7 +210,7 @@ func TestUpdateConcurrency(t *testing.T) { validator := &Updatable{ log: logger.NewTest(t), fileHandler: handler, - newValidator: func(m map[uint32][]byte, e []uint32, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator { + newValidator: func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator { return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}} }, }