Refactor enforced/expected PCRs (#553)

* Merge enforced and expected measurements

* Update measurement generation to new format

* Write expected measurements hex encoded by default

* Allow hex or base64 encoded expected measurements

* Allow hex or base64 encoded clusterID

* Allow security upgrades to warnOnly flag

* Upload signed measurements in JSON format

* Fetch measurements either from JSON or YAML

* Use yaml.v3 instead of yaml.v2

* Error on invalid enforced selection

* Add placeholder measurements to config

* Update e2e test to new measurement format

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-11-24 10:57:58 +01:00 committed by GitHub
parent 8ce954e012
commit f8001efbc0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 1180 additions and 801 deletions

View file

@ -75,14 +75,14 @@ runs:
(.provider | select(. | has(\"azure\")).azure.resourceGroup) = \"${{ inputs.azureResourceGroup }}\" | (.provider | select(. | has(\"azure\")).azure.resourceGroup) = \"${{ inputs.azureResourceGroup }}\" |
(.provider | select(. | has(\"azure\")).azure.appClientID) = \"${{ inputs.azureClientID }}\" | (.provider | select(. | has(\"azure\")).azure.appClientID) = \"${{ inputs.azureClientID }}\" |
(.provider | select(. | has(\"azure\")).azure.clientSecretValue) = \"${{ inputs.azureClientSecret }}\" | (.provider | select(. | has(\"azure\")).azure.clientSecretValue) = \"${{ inputs.azureClientSecret }}\" |
(.provider | select(. | has(\"azure\")).azure.enforcedMeasurements) = [15]" \ (.provider | select(. | has(\"azure\")).azure.measurements) = {15:{\"expected\":\"0000000000000000000000000000000000000000000000000000000000000000\",\"warnOnly\":false}}" \
constellation-conf.yaml constellation-conf.yaml
yq eval -i \ yq eval -i \
"(.provider | select(. | has(\"gcp\")).gcp.project) = \"${{ inputs.gcpProject }}\" | "(.provider | select(. | has(\"gcp\")).gcp.project) = \"${{ inputs.gcpProject }}\" |
(.provider | select(. | has(\"gcp\")).gcp.region) = \"europe-west3\" | (.provider | select(. | has(\"gcp\")).gcp.region) = \"europe-west3\" |
(.provider | select(. | has(\"gcp\")).gcp.zone) = \"europe-west3-b\" | (.provider | select(. | has(\"gcp\")).gcp.zone) = \"europe-west3-b\" |
(.provider | select(. | has(\"gcp\")).gcp.enforcedMeasurements) = [15] | (.provider | select(. | has(\"gcp\")).gcp.measurements) = {15:{\"expected\":\"0000000000000000000000000000000000000000000000000000000000000000\",\"warnOnly\":false}} |
(.provider | select(. | has(\"gcp\")).gcp.serviceAccountKeyPath) = \"serviceAccountKey.json\"" \ (.provider | select(. | has(\"gcp\")).gcp.serviceAccountKeyPath) = \"serviceAccountKey.json\"" \
constellation-conf.yaml constellation-conf.yaml
@ -91,7 +91,7 @@ runs:
(.provider | select(. | has(\"aws\")).aws.zone) = \"eu-central-1a\" | (.provider | select(. | has(\"aws\")).aws.zone) = \"eu-central-1a\" |
(.provider | select(. | has(\"aws\")).aws.iamProfileControlPlane) = \"e2e_test_control_plane_instance_profile\" | (.provider | select(. | has(\"aws\")).aws.iamProfileControlPlane) = \"e2e_test_control_plane_instance_profile\" |
(.provider | select(. | has(\"aws\")).aws.iamProfileWorkerNodes) = \"e2e_test_worker_node_instance_profile\" | (.provider | select(. | has(\"aws\")).aws.iamProfileWorkerNodes) = \"e2e_test_worker_node_instance_profile\" |
(.provider | select(. | has(\"aws\")).aws.enforcedMeasurements) = [15]" \ (.provider | select(. | has(\"aws\")).aws.measurements) = {15:{\"expected\":\"0000000000000000000000000000000000000000000000000000000000000000\",\"warnOnly\":false}}" \
constellation-conf.yaml constellation-conf.yaml
if [ ${{ inputs.kubernetesVersion != '' }} = true ]; then if [ ${{ inputs.kubernetesVersion != '' }} = true ]; then

View file

@ -51,16 +51,35 @@ runs:
run: | run: |
KUBECONFIG="$PWD/constellation-admin.conf" kubectl rollout status ds/verification-service -n kube-system --timeout=3m KUBECONFIG="$PWD/constellation-admin.conf" kubectl rollout status ds/verification-service -n kube-system --timeout=3m
CONSTELL_IP=$(jq -r ".ip" constellation-id.json) CONSTELL_IP=$(jq -r ".ip" constellation-id.json)
pcr-reader --constell-ip ${CONSTELL_IP} -format yaml > measurements.yaml pcr-reader --constell-ip ${CONSTELL_IP} -format json > measurements.json
case $CSP in case $CSP in
azure) azure)
yq e 'del(.[0,6,10,11,12,13,14,15,16,17,18,19,20,21,22,23])' -i measurements.yaml yq e 'del(.[0,6,10,16,17,18,19,20,21,22,23])' -I 0 -o json -i measurements.json
yq '.4.warnOnly = false |
.8.warnOnly = false |
.9.warnOnly = false |
.11.warnOnly = false |
.12.warnOnly = false |
.13.warnOnly = false |
.15.warnOnly = false |
.15.expected = "0000000000000000000000000000000000000000000000000000000000000000"' \
-I 0 -o json -i measurements.json
;; ;;
gcp) gcp)
yq e 'del(.[11,12,13,14,15,16,17,18,19,20,21,22,23])' -i measurements.yaml yq e 'del(.[16,17,18,19,20,21,22,23])' -I 0 -o json -i measurements.json
yq '.0.warnOnly = false |
.4.warnOnly = false |
.8.warnOnly = false |
.9.warnOnly = false |
.11.warnOnly = false |
.12.warnOnly = false |
.13.warnOnly = false |
.15.warnOnly = false |
.15.expected = "0000000000000000000000000000000000000000000000000000000000000000"' \
-I 0 -o json -i measurements.json
;; ;;
esac esac
cat measurements.yaml cat measurements.json
shell: bash shell: bash
env: env:
CSP: ${{ inputs.cloudProvider }} CSP: ${{ inputs.cloudProvider }}
@ -81,14 +100,14 @@ runs:
run: | run: |
echo "$COSIGN_PUBLIC_KEY" > cosign.pub echo "$COSIGN_PUBLIC_KEY" > cosign.pub
# Enabling experimental mode also publishes signature to Rekor # Enabling experimental mode also publishes signature to Rekor
COSIGN_EXPERIMENTAL=1 cosign sign-blob --key env://COSIGN_PRIVATE_KEY measurements.yaml > measurements.yaml.sig COSIGN_EXPERIMENTAL=1 cosign sign-blob --key env://COSIGN_PRIVATE_KEY measurements.json > measurements.json.sig
# Verify - As documentation & check # Verify - As documentation & check
# Local Signature (input: artifact, key, signature) # Local Signature (input: artifact, key, signature)
cosign verify-blob --key cosign.pub --signature measurements.yaml.sig measurements.yaml cosign verify-blob --key cosign.pub --signature measurements.json.sig measurements.json
# Transparency Log Signature (input: artifact, key) # Transparency Log Signature (input: artifact, key)
uuid=$(rekor-cli search --artifact measurements.yaml | tail -n 1) uuid=$(rekor-cli search --artifact measurements.json | tail -n 1)
sig=$(rekor-cli get --uuid=$uuid --format=json | jq -r .Body.HashedRekordObj.signature.content) sig=$(rekor-cli get --uuid=$uuid --format=json | jq -r .Body.HashedRekordObj.signature.content)
cosign verify-blob --key cosign.pub --signature <(echo $sig) measurements.yaml cosign verify-blob --key cosign.pub --signature <(echo $sig) measurements.json
shell: bash shell: bash
env: env:
COSIGN_PUBLIC_KEY: ${{ inputs.cosignPublicKey }} COSIGN_PUBLIC_KEY: ${{ inputs.cosignPublicKey }}
@ -100,9 +119,9 @@ runs:
run: | run: |
IMAGE=$(yq e ".provider.${CSP}.image" constellation-conf.yaml) IMAGE=$(yq e ".provider.${CSP}.image" constellation-conf.yaml)
S3_PATH=s3://${PUBLIC_BUCKET_NAME}/${IMAGE,,} S3_PATH=s3://${PUBLIC_BUCKET_NAME}/${IMAGE,,}
aws s3 cp measurements.yaml ${S3_PATH}/measurements.yaml aws s3 cp measurements.json ${S3_PATH}/measurements.json
if test -f measurements.yaml.sig; then if test -f measurements.json.sig; then
aws s3 cp measurements.yaml.sig ${S3_PATH}/measurements.yaml.sig aws s3 cp measurements.json.sig ${S3_PATH}/measurements.json.sig
fi fi
shell: bash shell: bash
env: env:

View file

@ -8,7 +8,6 @@ package main
import ( import (
"context" "context"
"encoding/json"
"flag" "flag"
"io" "io"
"os" "os"
@ -80,14 +79,10 @@ func main() {
switch cloudprovider.FromString(os.Getenv(constellationCSP)) { switch cloudprovider.FromString(os.Getenv(constellationCSP)) {
case cloudprovider.AWS: case cloudprovider.AWS:
pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, vtpm.AWSPCRSelection) measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.AWSPCRSelection)
if err != nil { if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs")
} }
pcrsJSON, err := json.Marshal(pcrs)
if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to marshal PCRs")
}
issuer = initserver.NewIssuerWrapper(aws.NewIssuer(), vmtype.Unknown, nil) issuer = initserver.NewIssuerWrapper(aws.NewIssuer(), vmtype.Unknown, nil)
@ -104,13 +99,13 @@ func main() {
clusterInitJoiner = kubernetes.New( clusterInitJoiner = kubernetes.New(
"aws", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(), "aws", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(),
metadata, pcrsJSON, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, metadata, measurements, helmClient, &kubewaiter.CloudKubeAPIWaiter{},
) )
openTPM = vtpm.OpenVTPM openTPM = vtpm.OpenVTPM
fs = afero.NewOsFs() fs = afero.NewOsFs()
case cloudprovider.GCP: case cloudprovider.GCP:
pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, vtpm.GCPPCRSelection) measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.GCPPCRSelection)
if err != nil { if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs")
} }
@ -129,20 +124,16 @@ func main() {
} }
metadataAPI = metadata metadataAPI = metadata
pcrsJSON, err := json.Marshal(pcrs)
if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to marshal PCRs")
}
clusterInitJoiner = kubernetes.New( clusterInitJoiner = kubernetes.New(
"gcp", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(), "gcp", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(),
metadata, pcrsJSON, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, metadata, measurements, helmClient, &kubewaiter.CloudKubeAPIWaiter{},
) )
openTPM = vtpm.OpenVTPM openTPM = vtpm.OpenVTPM
fs = afero.NewOsFs() fs = afero.NewOsFs()
log.Infof("Added load balancer IP to routing table") log.Infof("Added load balancer IP to routing table")
case cloudprovider.Azure: case cloudprovider.Azure:
pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, vtpm.AzurePCRSelection) measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.AzurePCRSelection)
if err != nil { if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs")
} }
@ -163,20 +154,16 @@ func main() {
log.With(zap.Error(err)).Fatalf("Failed to set up cloud logger") log.With(zap.Error(err)).Fatalf("Failed to set up cloud logger")
} }
metadataAPI = metadata metadataAPI = metadata
pcrsJSON, err := json.Marshal(pcrs)
if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to marshal PCRs")
}
clusterInitJoiner = kubernetes.New( clusterInitJoiner = kubernetes.New(
"azure", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(), "azure", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(),
metadata, pcrsJSON, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, metadata, measurements, helmClient, &kubewaiter.CloudKubeAPIWaiter{},
) )
openTPM = vtpm.OpenVTPM openTPM = vtpm.OpenVTPM
fs = afero.NewOsFs() fs = afero.NewOsFs()
case cloudprovider.QEMU: case cloudprovider.QEMU:
pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, vtpm.QEMUPCRSelection) measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.QEMUPCRSelection)
if err != nil { if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs")
} }
@ -185,13 +172,9 @@ func main() {
cloudLogger = qemucloud.NewLogger() cloudLogger = qemucloud.NewLogger()
metadata := qemucloud.New() metadata := qemucloud.New()
pcrsJSON, err := json.Marshal(pcrs)
if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to marshal PCRs")
}
clusterInitJoiner = kubernetes.New( clusterInitJoiner = kubernetes.New(
"qemu", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(), "qemu", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(),
metadata, pcrsJSON, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, metadata, measurements, helmClient, &kubewaiter.CloudKubeAPIWaiter{},
) )
metadataAPI = metadata metadataAPI = metadata

View file

@ -27,9 +27,9 @@ message InitRequest {
} }
message InitResponse { message InitResponse {
bytes kubeconfig = 1; bytes kubeconfig = 1;
bytes owner_id = 2; bytes owner_id = 2;
bytes cluster_id = 3; bytes cluster_id = 3;
} }
message KubernetesComponent { message KubernetesComponent {

View file

@ -20,6 +20,7 @@ import (
"github.com/edgelesssys/constellation/v2/bootstrapper/internal/kubernetes/k8sapi" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/kubernetes/k8sapi"
kubewaiter "github.com/edgelesssys/constellation/v2/bootstrapper/internal/kubernetes/kubeWaiter" kubewaiter "github.com/edgelesssys/constellation/v2/bootstrapper/internal/kubernetes/kubeWaiter"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/azureshared" "github.com/edgelesssys/constellation/v2/internal/cloud/azureshared"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared" "github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared"
@ -52,33 +53,33 @@ type kubeAPIWaiter interface {
// KubeWrapper implements Cluster interface. // KubeWrapper implements Cluster interface.
type KubeWrapper struct { type KubeWrapper struct {
cloudProvider string cloudProvider string
clusterUtil clusterUtil clusterUtil clusterUtil
helmClient helmClient helmClient helmClient
kubeAPIWaiter kubeAPIWaiter kubeAPIWaiter kubeAPIWaiter
configProvider configurationProvider configProvider configurationProvider
client k8sapi.Client client k8sapi.Client
kubeconfigReader configReader kubeconfigReader configReader
providerMetadata ProviderMetadata providerMetadata ProviderMetadata
initialMeasurementsJSON []byte initialMeasurements measurements.M
getIPAddr func() (string, error) getIPAddr func() (string, error)
} }
// New creates a new KubeWrapper with real values. // New creates a new KubeWrapper with real values.
func New(cloudProvider string, clusterUtil clusterUtil, configProvider configurationProvider, client k8sapi.Client, func New(cloudProvider string, clusterUtil clusterUtil, configProvider configurationProvider, client k8sapi.Client,
providerMetadata ProviderMetadata, initialMeasurementsJSON []byte, helmClient helmClient, kubeAPIWaiter kubeAPIWaiter, providerMetadata ProviderMetadata, measurements measurements.M, helmClient helmClient, kubeAPIWaiter kubeAPIWaiter,
) *KubeWrapper { ) *KubeWrapper {
return &KubeWrapper{ return &KubeWrapper{
cloudProvider: cloudProvider, cloudProvider: cloudProvider,
clusterUtil: clusterUtil, clusterUtil: clusterUtil,
helmClient: helmClient, helmClient: helmClient,
kubeAPIWaiter: kubeAPIWaiter, kubeAPIWaiter: kubeAPIWaiter,
configProvider: configProvider, configProvider: configProvider,
client: client, client: client,
kubeconfigReader: &KubeconfigReader{fs: afero.Afero{Fs: afero.NewOsFs()}}, kubeconfigReader: &KubeconfigReader{fs: afero.Afero{Fs: afero.NewOsFs()}},
providerMetadata: providerMetadata, providerMetadata: providerMetadata,
initialMeasurementsJSON: initialMeasurementsJSON, initialMeasurements: measurements,
getIPAddr: getIPAddr, getIPAddr: getIPAddr,
} }
} }
@ -187,7 +188,21 @@ func (k *KubeWrapper) InitCluster(
} else { } else {
controlPlaneIP = controlPlaneEndpoint controlPlaneIP = controlPlaneEndpoint
} }
serviceConfig := constellationServicesConfig{k.initialMeasurementsJSON, idKeyDigest, measurementSalt, subnetworkPodCIDR, cloudServiceAccountURI, controlPlaneIP} if err := k.initialMeasurements.SetEnforced(enforcedPCRs); err != nil {
return nil, err
}
measurementsJSON, err := json.Marshal(k.initialMeasurements)
if err != nil {
return nil, fmt.Errorf("marshaling initial measurements: %w", err)
}
serviceConfig := constellationServicesConfig{
initialMeasurementsJSON: measurementsJSON,
idkeydigest: idKeyDigest,
measurementSalt: measurementSalt,
subnetworkPodCIDR: subnetworkPodCIDR,
cloudServiceAccountURI: cloudServiceAccountURI,
loadBalancerIP: controlPlaneIP,
}
extraVals, err := k.setupExtraVals(ctx, serviceConfig) extraVals, err := k.setupExtraVals(ctx, serviceConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("setting up extraVals: %w", err) return nil, fmt.Errorf("setting up extraVals: %w", err)

View file

@ -112,6 +112,13 @@ func (u *Upgrader) updateMeasurements(ctx context.Context, newMeasurements measu
return nil return nil
} }
// don't allow potential security downgrades by setting the warnOnly flag to true
for k, newM := range newMeasurements {
if currentM, ok := currentMeasurements[k]; ok && !currentM.WarnOnly && newM.WarnOnly {
return fmt.Errorf("setting enforced measurement %d to warn only: not allowed", k)
}
}
// backup of previous measurements // backup of previous measurements
existingConf.Data["oldMeasurements"] = existingConf.Data[constants.MeasurementsFilename] existingConf.Data["oldMeasurements"] = existingConf.Data[constants.MeasurementsFilename]

View file

@ -33,12 +33,12 @@ func TestUpdateMeasurements(t *testing.T) {
updater: &stubMeasurementsUpdater{ updater: &stubMeasurementsUpdater{
oldMeasurements: &corev1.ConfigMap{ oldMeasurements: &corev1.ConfigMap{
Data: map[string]string{ Data: map[string]string{
constants.MeasurementsFilename: `{"0":"AAAAAA=="}`, constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`,
}, },
}, },
}, },
newMeasurements: measurements.M{ newMeasurements: measurements.M{
0: []byte("1"), 0: measurements.WithAllBytes(0xBB, false),
}, },
wantUpdate: true, wantUpdate: true,
}, },
@ -46,14 +46,40 @@ func TestUpdateMeasurements(t *testing.T) {
updater: &stubMeasurementsUpdater{ updater: &stubMeasurementsUpdater{
oldMeasurements: &corev1.ConfigMap{ oldMeasurements: &corev1.ConfigMap{
Data: map[string]string{ Data: map[string]string{
constants.MeasurementsFilename: `{"0":"MQ=="}`, constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`,
}, },
}, },
}, },
newMeasurements: measurements.M{ newMeasurements: measurements.M{
0: []byte("1"), 0: measurements.WithAllBytes(0xAA, false),
}, },
}, },
"trying to set warnOnly to true results in error": {
updater: &stubMeasurementsUpdater{
oldMeasurements: &corev1.ConfigMap{
Data: map[string]string{
constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`,
},
},
},
newMeasurements: measurements.M{
0: measurements.WithAllBytes(0xAA, true),
},
wantErr: true,
},
"setting warnOnly to false is allowed": {
updater: &stubMeasurementsUpdater{
oldMeasurements: &corev1.ConfigMap{
Data: map[string]string{
constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":true}}`,
},
},
},
newMeasurements: measurements.M{
0: measurements.WithAllBytes(0xAA, false),
},
wantUpdate: true,
},
"getCurrent error": { "getCurrent error": {
updater: &stubMeasurementsUpdater{getErr: someErr}, updater: &stubMeasurementsUpdater{getErr: someErr},
wantErr: true, wantErr: true,
@ -62,7 +88,7 @@ func TestUpdateMeasurements(t *testing.T) {
updater: &stubMeasurementsUpdater{ updater: &stubMeasurementsUpdater{
oldMeasurements: &corev1.ConfigMap{ oldMeasurements: &corev1.ConfigMap{
Data: map[string]string{ Data: map[string]string{
constants.MeasurementsFilename: `{"0":"AAAAAA=="}`, constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`,
}, },
}, },
updateErr: someErr, updateErr: someErr,
@ -82,7 +108,7 @@ func TestUpdateMeasurements(t *testing.T) {
err := upgrader.updateMeasurements(context.Background(), tc.newMeasurements) err := upgrader.updateMeasurements(context.Background(), tc.newMeasurements)
if tc.wantErr { if tc.wantErr {
assert.ErrorIs(err, someErr) assert.Error(err)
return return
} }

View file

@ -20,17 +20,16 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp" "github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu" "github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"go.uber.org/multierr"
) )
// Validator validates Platform Configuration Registers (PCRs). // Validator validates Platform Configuration Registers (PCRs).
type Validator struct { type Validator struct {
provider cloudprovider.Provider provider cloudprovider.Provider
pcrs measurements.M pcrs measurements.M
enforcedPCRs []uint32
idkeydigest []byte idkeydigest []byte
enforceIDKeyDigest bool enforceIDKeyDigest bool
azureCVM bool azureCVM bool
@ -65,41 +64,44 @@ func NewValidator(provider cloudprovider.Provider, conf *config.Config) (*Valida
// UpdateInitPCRs sets the owner and cluster PCR values. // UpdateInitPCRs sets the owner and cluster PCR values.
func (v *Validator) UpdateInitPCRs(ownerID, clusterID string) error { func (v *Validator) UpdateInitPCRs(ownerID, clusterID string) error {
if err := v.updatePCR(uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil { if err := v.updatePCR(uint32(measurements.PCRIndexOwnerID), ownerID); err != nil {
return err return err
} }
return v.updatePCR(uint32(vtpm.PCRIndexClusterID), clusterID) return v.updatePCR(uint32(measurements.PCRIndexClusterID), clusterID)
} }
// updatePCR adds a new entry to the pcr map of v, or removes the key if the input is an empty string. // updatePCR adds a new entry to the measurements of v, or removes the key if the input is an empty string.
// //
// When adding, the input is first decoded from base64. // When adding, the input is first decoded from hex or base64.
// We then calculate the expected PCR by hashing the input using SHA256, // We then calculate the expected PCR by hashing the input using SHA256,
// appending expected PCR for initialization, and then hashing once more. // appending expected PCR for initialization, and then hashing once more.
func (v *Validator) updatePCR(pcrIndex uint32, encoded string) error { func (v *Validator) updatePCR(pcrIndex uint32, encoded string) error {
if encoded == "" { if encoded == "" {
delete(v.pcrs, pcrIndex) delete(v.pcrs, pcrIndex)
// remove enforced PCR if it exists
for i, enforcedIdx := range v.enforcedPCRs {
if enforcedIdx == pcrIndex {
v.enforcedPCRs[i] = v.enforcedPCRs[len(v.enforcedPCRs)-1]
v.enforcedPCRs = v.enforcedPCRs[:len(v.enforcedPCRs)-1]
break
}
}
return nil return nil
} }
decoded, err := base64.StdEncoding.DecodeString(encoded)
// decode from hex or base64
decoded, err := hex.DecodeString(encoded)
if err != nil { if err != nil {
return fmt.Errorf("input [%s] is not base64 encoded: %w", encoded, err) hexErr := err
decoded, err = base64.StdEncoding.DecodeString(encoded)
if err != nil {
return multierr.Append(
fmt.Errorf("input [%s] is not hex encoded: %w", encoded, hexErr),
fmt.Errorf("input [%s] is not base64 encoded: %w", encoded, err),
)
}
} }
// new_pcr_value := hash(old_pcr_value || data_to_extend) // new_pcr_value := hash(old_pcr_value || data_to_extend)
// Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input // Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input
hashedInput := sha256.Sum256(decoded) hashedInput := sha256.Sum256(decoded)
expectedPcr := sha256.Sum256(append(v.pcrs[pcrIndex], hashedInput[:]...)) oldExpected := v.pcrs[pcrIndex].Expected
v.pcrs[pcrIndex] = expectedPcr[:] expectedPcr := sha256.Sum256(append(oldExpected[:], hashedInput[:]...))
v.pcrs[pcrIndex] = measurements.Measurement{
Expected: expectedPcr,
WarnOnly: v.pcrs[pcrIndex].WarnOnly,
}
return nil return nil
} }
@ -107,35 +109,27 @@ func (v *Validator) setPCRs(config *config.Config) error {
switch v.provider { switch v.provider {
case cloudprovider.AWS: case cloudprovider.AWS:
awsPCRs := config.Provider.AWS.Measurements awsPCRs := config.Provider.AWS.Measurements
enforcedPCRs := config.Provider.AWS.EnforcedMeasurements if len(awsPCRs) == 0 {
if err := v.checkPCRs(awsPCRs, enforcedPCRs); err != nil { return errors.New("no expected measurement provided")
return err
} }
v.enforcedPCRs = enforcedPCRs
v.pcrs = awsPCRs v.pcrs = awsPCRs
case cloudprovider.Azure: case cloudprovider.Azure:
azurePCRs := config.Provider.Azure.Measurements azurePCRs := config.Provider.Azure.Measurements
enforcedPCRs := config.Provider.Azure.EnforcedMeasurements if len(azurePCRs) == 0 {
if err := v.checkPCRs(azurePCRs, enforcedPCRs); err != nil { return errors.New("no expected measurement provided")
return err
} }
v.enforcedPCRs = enforcedPCRs
v.pcrs = azurePCRs v.pcrs = azurePCRs
case cloudprovider.GCP: case cloudprovider.GCP:
gcpPCRs := config.Provider.GCP.Measurements gcpPCRs := config.Provider.GCP.Measurements
enforcedPCRs := config.Provider.GCP.EnforcedMeasurements if len(gcpPCRs) == 0 {
if err := v.checkPCRs(gcpPCRs, enforcedPCRs); err != nil { return errors.New("no expected measurement provided")
return err
} }
v.enforcedPCRs = enforcedPCRs
v.pcrs = gcpPCRs v.pcrs = gcpPCRs
case cloudprovider.QEMU: case cloudprovider.QEMU:
qemuPCRs := config.Provider.QEMU.Measurements qemuPCRs := config.Provider.QEMU.Measurements
enforcedPCRs := config.Provider.QEMU.EnforcedMeasurements if len(qemuPCRs) == 0 {
if err := v.checkPCRs(qemuPCRs, enforcedPCRs); err != nil { return errors.New("no expected measurement provided")
return err
} }
v.enforcedPCRs = enforcedPCRs
v.pcrs = qemuPCRs v.pcrs = qemuPCRs
} }
return nil return nil
@ -156,37 +150,20 @@ func (v *Validator) updateValidator(cmd *cobra.Command) {
log := warnLogger{cmd: cmd} log := warnLogger{cmd: cmd}
switch v.provider { switch v.provider {
case cloudprovider.GCP: case cloudprovider.GCP:
v.validator = gcp.NewValidator(v.pcrs, v.enforcedPCRs, log) v.validator = gcp.NewValidator(v.pcrs, log)
case cloudprovider.Azure: case cloudprovider.Azure:
if v.azureCVM { if v.azureCVM {
v.validator = snp.NewValidator(v.pcrs, v.enforcedPCRs, v.idkeydigest, v.enforceIDKeyDigest, log) v.validator = snp.NewValidator(v.pcrs, v.idkeydigest, v.enforceIDKeyDigest, log)
} else { } else {
v.validator = trustedlaunch.NewValidator(v.pcrs, v.enforcedPCRs, log) v.validator = trustedlaunch.NewValidator(v.pcrs, log)
} }
case cloudprovider.AWS: case cloudprovider.AWS:
v.validator = aws.NewValidator(v.pcrs, v.enforcedPCRs, log) v.validator = aws.NewValidator(v.pcrs, log)
case cloudprovider.QEMU: case cloudprovider.QEMU:
v.validator = qemu.NewValidator(v.pcrs, v.enforcedPCRs, log) v.validator = qemu.NewValidator(v.pcrs, log)
} }
} }
func (v *Validator) checkPCRs(pcrs measurements.M, enforcedPCRs []uint32) error {
if len(pcrs) == 0 {
return errors.New("no PCR values provided")
}
for k, v := range pcrs {
if len(v) != 32 {
return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v))
}
}
for _, v := range enforcedPCRs {
if _, ok := pcrs[v]; !ok {
return fmt.Errorf("bad config: PCR[%d] is enforced, but no expected measurement is provided", v)
}
}
return nil
}
// warnLogger implements logging of warnings for validators. // warnLogger implements logging of warnings for validators.
type warnLogger struct { type warnLogger struct {
cmd *cobra.Command cmd *cobra.Command

View file

@ -7,9 +7,9 @@ SPDX-License-Identifier: AGPL-3.0-only
package cloudcmd package cloudcmd
import ( import (
"bytes"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/hex"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/atls"
@ -18,7 +18,6 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp" "github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu" "github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -27,12 +26,12 @@ import (
func TestNewValidator(t *testing.T) { func TestNewValidator(t *testing.T) {
testPCRs := measurements.M{ testPCRs := measurements.M{
0: measurements.PCRWithAllBytes(0x00), 0: measurements.WithAllBytes(0x00, false),
1: measurements.PCRWithAllBytes(0xFF), 1: measurements.WithAllBytes(0xFF, false),
2: measurements.PCRWithAllBytes(0x00), 2: measurements.WithAllBytes(0x00, false),
3: measurements.PCRWithAllBytes(0xFF), 3: measurements.WithAllBytes(0xFF, false),
4: measurements.PCRWithAllBytes(0x00), 4: measurements.WithAllBytes(0x00, false),
5: measurements.PCRWithAllBytes(0x00), 5: measurements.WithAllBytes(0x00, false),
} }
testCases := map[string]struct { testCases := map[string]struct {
@ -67,13 +66,6 @@ func TestNewValidator(t *testing.T) {
pcrs: measurements.M{}, pcrs: measurements.M{},
wantErr: true, wantErr: true,
}, },
"invalid pcr length": {
provider: cloudprovider.GCP,
pcrs: measurements.M{
0: bytes.Repeat([]byte{0x00}, 31),
},
wantErr: true,
},
"unknown provider": { "unknown provider": {
provider: cloudprovider.Unknown, provider: cloudprovider.Unknown,
pcrs: testPCRs, pcrs: testPCRs,
@ -126,19 +118,19 @@ func TestNewValidator(t *testing.T) {
func TestValidatorV(t *testing.T) { func TestValidatorV(t *testing.T) {
newTestPCRs := func() measurements.M { newTestPCRs := func() measurements.M {
return measurements.M{ return measurements.M{
0: measurements.PCRWithAllBytes(0x00), 0: measurements.WithAllBytes(0x00, true),
1: measurements.PCRWithAllBytes(0x00), 1: measurements.WithAllBytes(0x00, true),
2: measurements.PCRWithAllBytes(0x00), 2: measurements.WithAllBytes(0x00, true),
3: measurements.PCRWithAllBytes(0x00), 3: measurements.WithAllBytes(0x00, true),
4: measurements.PCRWithAllBytes(0x00), 4: measurements.WithAllBytes(0x00, true),
5: measurements.PCRWithAllBytes(0x00), 5: measurements.WithAllBytes(0x00, true),
6: measurements.PCRWithAllBytes(0x00), 6: measurements.WithAllBytes(0x00, true),
7: measurements.PCRWithAllBytes(0x00), 7: measurements.WithAllBytes(0x00, true),
8: measurements.PCRWithAllBytes(0x00), 8: measurements.WithAllBytes(0x00, true),
9: measurements.PCRWithAllBytes(0x00), 9: measurements.WithAllBytes(0x00, true),
10: measurements.PCRWithAllBytes(0x00), 10: measurements.WithAllBytes(0x00, true),
11: measurements.PCRWithAllBytes(0x00), 11: measurements.WithAllBytes(0x00, true),
12: measurements.PCRWithAllBytes(0x00), 12: measurements.WithAllBytes(0x00, true),
} }
} }
@ -151,23 +143,23 @@ func TestValidatorV(t *testing.T) {
"gcp": { "gcp": {
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
pcrs: newTestPCRs(), pcrs: newTestPCRs(),
wantVs: gcp.NewValidator(newTestPCRs(), nil, nil), wantVs: gcp.NewValidator(newTestPCRs(), nil),
}, },
"azure cvm": { "azure cvm": {
provider: cloudprovider.Azure, provider: cloudprovider.Azure,
pcrs: newTestPCRs(), pcrs: newTestPCRs(),
wantVs: snp.NewValidator(newTestPCRs(), nil, nil, false, nil), wantVs: snp.NewValidator(newTestPCRs(), nil, false, nil),
azureCVM: true, azureCVM: true,
}, },
"azure trusted launch": { "azure trusted launch": {
provider: cloudprovider.Azure, provider: cloudprovider.Azure,
pcrs: newTestPCRs(), pcrs: newTestPCRs(),
wantVs: trustedlaunch.NewValidator(newTestPCRs(), nil, nil), wantVs: trustedlaunch.NewValidator(newTestPCRs(), nil),
}, },
"qemu": { "qemu": {
provider: cloudprovider.QEMU, provider: cloudprovider.QEMU,
pcrs: newTestPCRs(), pcrs: newTestPCRs(),
wantVs: qemu.NewValidator(newTestPCRs(), nil, nil), wantVs: qemu.NewValidator(newTestPCRs(), nil),
}, },
} }
@ -185,37 +177,37 @@ func TestValidatorV(t *testing.T) {
} }
func TestValidatorUpdateInitPCRs(t *testing.T) { func TestValidatorUpdateInitPCRs(t *testing.T) {
zero := []byte("00000000000000000000000000000000") zero := measurements.WithAllBytes(0x00, true)
one := []byte("11111111111111111111111111111111") one := measurements.WithAllBytes(0x11, true)
one64 := base64.StdEncoding.EncodeToString(one) one64 := base64.StdEncoding.EncodeToString(one.Expected[:])
oneHash := sha256.Sum256(one) oneHash := sha256.Sum256(one.Expected[:])
pcrZeroUpdatedOne := sha256.Sum256(append(zero, oneHash[:]...)) pcrZeroUpdatedOne := sha256.Sum256(append(zero.Expected[:], oneHash[:]...))
newTestPCRs := func() map[uint32][]byte { newTestPCRs := func() measurements.M {
return map[uint32][]byte{ return measurements.M{
0: zero, 0: measurements.WithAllBytes(0x00, true),
1: zero, 1: measurements.WithAllBytes(0x00, true),
2: zero, 2: measurements.WithAllBytes(0x00, true),
3: zero, 3: measurements.WithAllBytes(0x00, true),
4: zero, 4: measurements.WithAllBytes(0x00, true),
5: zero, 5: measurements.WithAllBytes(0x00, true),
6: zero, 6: measurements.WithAllBytes(0x00, true),
7: zero, 7: measurements.WithAllBytes(0x00, true),
8: zero, 8: measurements.WithAllBytes(0x00, true),
9: zero, 9: measurements.WithAllBytes(0x00, true),
10: zero, 10: measurements.WithAllBytes(0x00, true),
11: zero, 11: measurements.WithAllBytes(0x00, true),
12: zero, 12: measurements.WithAllBytes(0x00, true),
13: zero, 13: measurements.WithAllBytes(0x00, true),
14: zero, 14: measurements.WithAllBytes(0x00, true),
15: zero, 15: measurements.WithAllBytes(0x00, true),
16: zero, 16: measurements.WithAllBytes(0x00, true),
17: one, 17: measurements.WithAllBytes(0x11, true),
18: one, 18: measurements.WithAllBytes(0x11, true),
19: one, 19: measurements.WithAllBytes(0x11, true),
20: one, 20: measurements.WithAllBytes(0x11, true),
21: one, 21: measurements.WithAllBytes(0x11, true),
22: one, 22: measurements.WithAllBytes(0x11, true),
23: zero, 23: measurements.WithAllBytes(0x00, true),
} }
} }
@ -285,25 +277,25 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
assert.NoError(err) assert.NoError(err)
for i := 0; i < len(tc.pcrs); i++ { for i := 0; i < len(tc.pcrs); i++ {
switch { switch {
case i == int(vtpm.PCRIndexClusterID) && tc.clusterID == "": case i == int(measurements.PCRIndexClusterID) && tc.clusterID == "":
// should be deleted // should be deleted
_, ok := validators.pcrs[uint32(i)] _, ok := validators.pcrs[uint32(i)]
assert.False(ok) assert.False(ok)
case i == int(vtpm.PCRIndexClusterID): case i == int(measurements.PCRIndexClusterID):
pcr, ok := validators.pcrs[uint32(i)] pcr, ok := validators.pcrs[uint32(i)]
assert.True(ok) assert.True(ok)
assert.Equal(pcrZeroUpdatedOne[:], pcr) assert.Equal(pcrZeroUpdatedOne, pcr.Expected)
case i == int(vtpm.PCRIndexOwnerID) && tc.ownerID == "": case i == int(measurements.PCRIndexOwnerID) && tc.ownerID == "":
// should be deleted // should be deleted
_, ok := validators.pcrs[uint32(i)] _, ok := validators.pcrs[uint32(i)]
assert.False(ok) assert.False(ok)
case i == int(vtpm.PCRIndexOwnerID): case i == int(measurements.PCRIndexOwnerID):
pcr, ok := validators.pcrs[uint32(i)] pcr, ok := validators.pcrs[uint32(i)]
assert.True(ok) assert.True(ok)
assert.Equal(pcrZeroUpdatedOne[:], pcr) assert.Equal(pcrZeroUpdatedOne, pcr.Expected)
default: default:
if i >= 17 && i <= 22 { if i >= 17 && i <= 22 {
@ -320,8 +312,8 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
func TestUpdatePCR(t *testing.T) { func TestUpdatePCR(t *testing.T) {
emptyMap := measurements.M{} emptyMap := measurements.M{}
defaultMap := measurements.M{ defaultMap := measurements.M{
0: measurements.PCRWithAllBytes(0xAA), 0: measurements.WithAllBytes(0xAA, false),
1: measurements.PCRWithAllBytes(0xBB), 1: measurements.WithAllBytes(0xBB, false),
} }
testCases := map[string]struct { testCases := map[string]struct {
@ -359,6 +351,20 @@ func TestUpdatePCR(t *testing.T) {
wantEntries: len(defaultMap) + 1, wantEntries: len(defaultMap) + 1,
wantErr: false, wantErr: false,
}, },
"hex input, empty map": {
pcrMap: emptyMap,
pcrIndex: 10,
encoded: hex.EncodeToString([]byte("Constellation")),
wantEntries: 1,
wantErr: false,
},
"hex input, default map": {
pcrMap: defaultMap,
pcrIndex: 10,
encoded: hex.EncodeToString([]byte("Constellation")),
wantEntries: len(defaultMap) + 1,
wantErr: false,
},
"unencoded input, empty map": { "unencoded input, empty map": {
pcrMap: emptyMap, pcrMap: emptyMap,
pcrIndex: 10, pcrIndex: 10,
@ -403,9 +409,6 @@ func TestUpdatePCR(t *testing.T) {
assert.NoError(err) assert.NoError(err)
} }
assert.Len(pcrs, tc.wantEntries) assert.Len(pcrs, tc.wantEntries)
for _, v := range pcrs {
assert.Len(v, 32)
}
}) })
} }
} }

View file

@ -137,7 +137,7 @@ func (f *fetchMeasurementsFlags) updateURLs(ctx context.Context, conf *config.Co
if f.measurementsURL == nil { if f.measurementsURL == nil {
// TODO(AB#2644): resolve image version to reference // TODO(AB#2644): resolve image version to reference
parsedURL, err := url.Parse(constants.S3PublicBucket + imageRef + "/measurements.yaml") parsedURL, err := url.Parse(constants.S3PublicBucket + imageRef + "/measurements.json")
if err != nil { if err != nil {
return err return err
} }
@ -145,7 +145,7 @@ func (f *fetchMeasurementsFlags) updateURLs(ctx context.Context, conf *config.Co
} }
if f.signatureURL == nil { if f.signatureURL == nil {
parsedURL, err := url.Parse(constants.S3PublicBucket + imageRef + "/measurements.yaml.sig") parsedURL, err := url.Parse(constants.S3PublicBucket + imageRef + "/measurements.json.sig")
if err != nil { if err != nil {
return err return err
} }

View file

@ -109,17 +109,17 @@ func TestUpdateURLs(t *testing.T) {
}, },
}, },
flags: &fetchMeasurementsFlags{}, flags: &fetchMeasurementsFlags{},
wantMeasurementsURL: constants.S3PublicBucket + "some/image/path/image-123456/measurements.yaml", wantMeasurementsURL: constants.S3PublicBucket + "some/image/path/image-123456/measurements.json",
wantMeasurementsSigURL: constants.S3PublicBucket + "some/image/path/image-123456/measurements.yaml.sig", wantMeasurementsSigURL: constants.S3PublicBucket + "some/image/path/image-123456/measurements.json.sig",
}, },
"both set by user": { "both set by user": {
conf: &config.Config{}, conf: &config.Config{},
flags: &fetchMeasurementsFlags{ flags: &fetchMeasurementsFlags{
measurementsURL: urlMustParse("get.my/measurements.yaml"), measurementsURL: urlMustParse("get.my/measurements.json"),
signatureURL: urlMustParse("get.my/measurements.yaml.sig"), signatureURL: urlMustParse("get.my/measurements.json.sig"),
}, },
wantMeasurementsURL: "get.my/measurements.yaml", wantMeasurementsURL: "get.my/measurements.json",
wantMeasurementsSigURL: "get.my/measurements.yaml.sig", wantMeasurementsSigURL: "get.my/measurements.json.sig",
}, },
} }
@ -164,14 +164,14 @@ func TestConfigFetchMeasurements(t *testing.T) {
signature := "MEUCIFdJ5dH6HDywxQWTUh9Bw77wMrq0mNCUjMQGYP+6QsVmAiEAmazj/L7rFGA4/Gz8y+kI5h5E5cDgc3brihvXBKF6qZA=" signature := "MEUCIFdJ5dH6HDywxQWTUh9Bw77wMrq0mNCUjMQGYP+6QsVmAiEAmazj/L7rFGA4/Gz8y+kI5h5E5cDgc3brihvXBKF6qZA="
client := newTestClient(func(req *http.Request) *http.Response { client := newTestClient(func(req *http.Request) *http.Response {
if req.URL.String() == "https://public-edgeless-constellation.s3.us-east-2.amazonaws.com/someImage/measurements.yaml" { if req.URL.String() == "https://public-edgeless-constellation.s3.us-east-2.amazonaws.com/someImage/measurements.json" {
return &http.Response{ return &http.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(measurements)), Body: io.NopCloser(bytes.NewBufferString(measurements)),
Header: make(http.Header), Header: make(http.Header),
} }
} }
if req.URL.String() == "https://public-edgeless-constellation.s3.us-east-2.amazonaws.com/someImage/measurements.yaml.sig" { if req.URL.String() == "https://public-edgeless-constellation.s3.us-east-2.amazonaws.com/someImage/measurements.json.sig" {
return &http.Response{ return &http.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(signature)), Body: io.NopCloser(bytes.NewBufferString(signature)),

View file

@ -8,7 +8,7 @@ package cmd
import ( import (
"context" "context"
"encoding/base64" "encoding/hex"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -134,7 +134,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
KubernetesVersion: conf.KubernetesVersion, KubernetesVersion: conf.KubernetesVersion,
KubernetesComponents: versions.VersionConfigs[k8sVersion].KubernetesComponents.ToProto(), KubernetesComponents: versions.VersionConfigs[k8sVersion].KubernetesComponents.ToProto(),
HelmDeployments: helmDeployments, HelmDeployments: helmDeployments,
EnforcedPcrs: conf.GetEnforcedPCRs(), EnforcedPcrs: conf.EnforcedPCRs(),
EnforceIdkeydigest: conf.EnforcesIDKeyDigest(), EnforceIdkeydigest: conf.EnforcesIDKeyDigest(),
ConformanceMode: flags.conformance, ConformanceMode: flags.conformance,
} }
@ -190,8 +190,8 @@ func (d *initDoer) Do(ctx context.Context) error {
func writeOutput(idFile clusterid.File, resp *initproto.InitResponse, wr io.Writer, fileHandler file.Handler) error { func writeOutput(idFile clusterid.File, resp *initproto.InitResponse, wr io.Writer, fileHandler file.Handler) error {
fmt.Fprint(wr, "Your Constellation cluster was successfully initialized.\n\n") fmt.Fprint(wr, "Your Constellation cluster was successfully initialized.\n\n")
ownerID := base64.StdEncoding.EncodeToString(resp.OwnerId) ownerID := hex.EncodeToString(resp.OwnerId)
clusterID := base64.StdEncoding.EncodeToString(resp.ClusterId) clusterID := hex.EncodeToString(resp.ClusterId)
tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0) tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0)
// writeRow(tw, "Constellation cluster's owner identifier", ownerID) // writeRow(tw, "Constellation cluster's owner identifier", ownerID)

View file

@ -9,7 +9,7 @@ package cmd
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/base64" "encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"net" "net"
@ -101,16 +101,6 @@ func TestInitialize(t *testing.T) {
initServerAPI: &stubInitServer{initErr: someErr}, initServerAPI: &stubInitServer{initErr: someErr},
wantErr: true, wantErr: true,
}, },
"fail missing enforced PCR": {
provider: cloudprovider.GCP,
idFile: &clusterid.File{IP: "192.0.2.1"},
configMutator: func(c *config.Config) {
c.Provider.GCP.EnforcedMeasurements = append(c.Provider.GCP.EnforcedMeasurements, 10)
},
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{initResp: testInitResp},
wantErr: true,
},
} }
for name, tc := range testCases { for name, tc := range testCases {
@ -174,7 +164,7 @@ func TestInitialize(t *testing.T) {
} }
require.NoError(err) require.NoError(err)
// assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID"))) // assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID")))
assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("clusterID"))) assert.Contains(out.String(), hex.EncodeToString([]byte("clusterID")))
var secret masterSecret var secret masterSecret
assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret)) assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret))
assert.NotEmpty(secret.Key) assert.NotEmpty(secret.Key)
@ -192,8 +182,8 @@ func TestWriteOutput(t *testing.T) {
Kubeconfig: []byte("kubeconfig"), Kubeconfig: []byte("kubeconfig"),
} }
ownerID := base64.StdEncoding.EncodeToString(resp.OwnerId) ownerID := hex.EncodeToString(resp.OwnerId)
clusterID := base64.StdEncoding.EncodeToString(resp.ClusterId) clusterID := hex.EncodeToString(resp.ClusterId)
expectedIDFile := clusterid.File{ expectedIDFile := clusterid.File{
ClusterID: clusterID, ClusterID: clusterID,
@ -361,11 +351,11 @@ func TestAttestation(t *testing.T) {
issuer := &testIssuer{ issuer := &testIssuer{
Getter: oid.QEMU{}, Getter: oid.QEMU{},
pcrs: measurements.M{ pcrs: map[uint32][]byte{
0: measurements.PCRWithAllBytes(0xFF), 0: bytes.Repeat([]byte{0xFF}, 32),
1: measurements.PCRWithAllBytes(0xFF), 1: bytes.Repeat([]byte{0xFF}, 32),
2: measurements.PCRWithAllBytes(0xFF), 2: bytes.Repeat([]byte{0xFF}, 32),
3: measurements.PCRWithAllBytes(0xFF), 3: bytes.Repeat([]byte{0xFF}, 32),
}, },
} }
serverCreds := atlscredentials.New(issuer, nil) serverCreds := atlscredentials.New(issuer, nil)
@ -390,13 +380,13 @@ func TestAttestation(t *testing.T) {
cfg := config.Default() cfg := config.Default()
cfg.Image = "image" cfg.Image = "image"
cfg.RemoveProviderExcept(cloudprovider.QEMU) cfg.RemoveProviderExcept(cloudprovider.QEMU)
cfg.Provider.QEMU.Measurements[0] = measurements.PCRWithAllBytes(0x00) cfg.Provider.QEMU.Measurements[0] = measurements.WithAllBytes(0x00, false)
cfg.Provider.QEMU.Measurements[1] = measurements.PCRWithAllBytes(0x11) cfg.Provider.QEMU.Measurements[1] = measurements.WithAllBytes(0x11, false)
cfg.Provider.QEMU.Measurements[2] = measurements.PCRWithAllBytes(0x22) cfg.Provider.QEMU.Measurements[2] = measurements.WithAllBytes(0x22, false)
cfg.Provider.QEMU.Measurements[3] = measurements.PCRWithAllBytes(0x33) cfg.Provider.QEMU.Measurements[3] = measurements.WithAllBytes(0x33, false)
cfg.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44) cfg.Provider.QEMU.Measurements[4] = measurements.WithAllBytes(0x44, false)
cfg.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x99) cfg.Provider.QEMU.Measurements[9] = measurements.WithAllBytes(0x99, false)
cfg.Provider.QEMU.Measurements[12] = measurements.PCRWithAllBytes(0xcc) cfg.Provider.QEMU.Measurements[12] = measurements.WithAllBytes(0xcc, false)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone)) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone))
ctx := context.Background() ctx := context.Background()
@ -418,14 +408,14 @@ type testValidator struct {
func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) { func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
var attestation struct { var attestation struct {
UserData []byte UserData []byte
PCRs measurements.M PCRs map[uint32][]byte
} }
if err := json.Unmarshal(attDoc, &attestation); err != nil { if err := json.Unmarshal(attDoc, &attestation); err != nil {
return nil, err return nil, err
} }
for k, pcr := range v.pcrs { for k, pcr := range v.pcrs {
if !bytes.Equal(attestation.PCRs[k], pcr) { if !bytes.Equal(attestation.PCRs[k], pcr.Expected[:]) {
return nil, errors.New("invalid PCR value") return nil, errors.New("invalid PCR value")
} }
} }
@ -434,14 +424,14 @@ func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
type testIssuer struct { type testIssuer struct {
oid.Getter oid.Getter
pcrs measurements.M pcrs map[uint32][]byte
} }
func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) { func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal( return json.Marshal(
struct { struct {
UserData []byte UserData []byte
PCRs measurements.M PCRs map[uint32][]byte
}{ }{
UserData: userData, UserData: userData,
PCRs: i.pcrs, PCRs: i.pcrs,
@ -474,21 +464,21 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
conf.Provider.Azure.ResourceGroup = "test-resource-group" conf.Provider.Azure.ResourceGroup = "test-resource-group"
conf.Provider.Azure.AppClientID = "01234567-0123-0123-0123-0123456789ab" conf.Provider.Azure.AppClientID = "01234567-0123-0123-0123-0123456789ab"
conf.Provider.Azure.ClientSecretValue = "test-client-secret" conf.Provider.Azure.ClientSecretValue = "test-client-secret"
conf.Provider.Azure.Measurements[4] = measurements.PCRWithAllBytes(0x44) conf.Provider.Azure.Measurements[4] = measurements.WithAllBytes(0x44, false)
conf.Provider.Azure.Measurements[9] = measurements.PCRWithAllBytes(0x11) conf.Provider.Azure.Measurements[9] = measurements.WithAllBytes(0x11, false)
conf.Provider.Azure.Measurements[12] = measurements.PCRWithAllBytes(0xcc) conf.Provider.Azure.Measurements[12] = measurements.WithAllBytes(0xcc, false)
case cloudprovider.GCP: case cloudprovider.GCP:
conf.Provider.GCP.Region = "test-region" conf.Provider.GCP.Region = "test-region"
conf.Provider.GCP.Project = "test-project" conf.Provider.GCP.Project = "test-project"
conf.Provider.GCP.Zone = "test-zone" conf.Provider.GCP.Zone = "test-zone"
conf.Provider.GCP.ServiceAccountKeyPath = "test-key-path" conf.Provider.GCP.ServiceAccountKeyPath = "test-key-path"
conf.Provider.GCP.Measurements[4] = measurements.PCRWithAllBytes(0x44) conf.Provider.GCP.Measurements[4] = measurements.WithAllBytes(0x44, false)
conf.Provider.GCP.Measurements[9] = measurements.PCRWithAllBytes(0x11) conf.Provider.GCP.Measurements[9] = measurements.WithAllBytes(0x11, false)
conf.Provider.GCP.Measurements[12] = measurements.PCRWithAllBytes(0xcc) conf.Provider.GCP.Measurements[12] = measurements.WithAllBytes(0xcc, false)
case cloudprovider.QEMU: case cloudprovider.QEMU:
conf.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44) conf.Provider.QEMU.Measurements[4] = measurements.WithAllBytes(0x44, false)
conf.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x11) conf.Provider.QEMU.Measurements[9] = measurements.WithAllBytes(0x11, false)
conf.Provider.QEMU.Measurements[12] = measurements.PCRWithAllBytes(0xcc) conf.Provider.QEMU.Measurements[12] = measurements.WithAllBytes(0xcc, false)
} }
conf.RemoveProviderExcept(csp) conf.RemoveProviderExcept(csp)

View file

@ -181,12 +181,12 @@ func getCompatibleImages(csp cloudprovider.Provider, currentVersion string, imag
// getCompatibleImageMeasurements retrieves the expected measurements for each image. // getCompatibleImageMeasurements retrieves the expected measurements for each image.
func getCompatibleImageMeasurements(ctx context.Context, cmd *cobra.Command, client *http.Client, rekor rekorVerifier, pubK []byte, images map[string]config.UpgradeConfig) error { func getCompatibleImageMeasurements(ctx context.Context, cmd *cobra.Command, client *http.Client, rekor rekorVerifier, pubK []byte, images map[string]config.UpgradeConfig) error {
for idx, img := range images { for idx, img := range images {
measurementsURL, err := url.Parse(constants.S3PublicBucket + strings.ToLower(img.Image) + "/measurements.yaml") measurementsURL, err := url.Parse(constants.S3PublicBucket + strings.ToLower(img.Image) + "/measurements.json")
if err != nil { if err != nil {
return err return err
} }
signatureURL, err := url.Parse(constants.S3PublicBucket + strings.ToLower(img.Image) + "/measurements.yaml.sig") signatureURL, err := url.Parse(constants.S3PublicBucket + strings.ToLower(img.Image) + "/measurements.json.sig")
if err != nil { if err != nil {
return err return err
} }

View file

@ -25,7 +25,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v3"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
) )
@ -248,14 +248,14 @@ func TestGetCompatibleImageMeasurements(t *testing.T) {
} }
client := newTestClient(func(req *http.Request) *http.Response { client := newTestClient(func(req *http.Request) *http.Response {
if strings.HasSuffix(req.URL.String(), "/measurements.yaml") { if strings.HasSuffix(req.URL.String(), "/measurements.json") {
return &http.Response{ return &http.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n")), Body: io.NopCloser(strings.NewReader("0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n")),
Header: make(http.Header), Header: make(http.Header),
} }
} }
if strings.HasSuffix(req.URL.String(), "/measurements.yaml.sig") { if strings.HasSuffix(req.URL.String(), "/measurements.json.sig") {
return &http.Response{ return &http.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=")), Body: io.NopCloser(strings.NewReader("MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=")),
@ -470,14 +470,14 @@ func TestUpgradePlan(t *testing.T) {
Header: make(http.Header), Header: make(http.Header),
} }
} }
if strings.HasSuffix(req.URL.String(), "/measurements.yaml") { if strings.HasSuffix(req.URL.String(), "/measurements.json") {
return &http.Response{ return &http.Response{
StatusCode: tc.measurementsFetchStatus, StatusCode: tc.measurementsFetchStatus,
Body: io.NopCloser(strings.NewReader("0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n")), Body: io.NopCloser(strings.NewReader("0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n")),
Header: make(http.Header), Header: make(http.Header),
} }
} }
if strings.HasSuffix(req.URL.String(), "/measurements.yaml.sig") { if strings.HasSuffix(req.URL.String(), "/measurements.json.sig") {
return &http.Response{ return &http.Response{
StatusCode: tc.measurementsFetchStatus, StatusCode: tc.measurementsFetchStatus,
Body: io.NopCloser(strings.NewReader("MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=")), Body: io.NopCloser(strings.NewReader("MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=")),

View file

@ -5,7 +5,6 @@ metadata:
namespace: {{ .Release.Namespace }} namespace: {{ .Release.Namespace }}
data: data:
# mustToJson is required so the json-strings passed from go are of type string in the rendered yaml. # mustToJson is required so the json-strings passed from go are of type string in the rendered yaml.
enforcedPCRs: {{ .Values.enforcedPCRs | mustToJson }}
measurements: {{ .Values.measurements | mustToJson }} measurements: {{ .Values.measurements | mustToJson }}
{{- if eq .Values.csp "Azure" }} {{- if eq .Values.csp "Azure" }}
# ConfigMap.data is of type map[string]string. quote will not quote a quoted string. # ConfigMap.data is of type map[string]string. quote will not quote a quoted string.

View file

@ -5,15 +5,10 @@
"description": "CSP to which the chart is deployed.", "description": "CSP to which the chart is deployed.",
"enum": ["Azure", "GCP", "AWS", "QEMU"] "enum": ["Azure", "GCP", "AWS", "QEMU"]
}, },
"enforcedPCRs": {
"description": "JSON-string to describe the enforced PCRs.",
"type": "string",
"examples": ["[1, 15]"]
},
"measurements": { "measurements": {
"description": "JSON-string to describe the expected measurements.", "description": "JSON-string to describe the expected measurements.",
"type": "string", "type": "string",
"examples": ["{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}"] "examples": ["{'1':{'expected':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','warnOnly':true},'15':{'expected':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=','warnOnly':true}}"]
}, },
"enforceIdKeyDigest": { "enforceIdKeyDigest": {
"description": "Whether or not idkeydigest should be enforced during attestation on azure.", "description": "Whether or not idkeydigest should be enforced during attestation on azure.",
@ -37,7 +32,6 @@
}, },
"required": [ "required": [
"csp", "csp",
"enforcedPCRs",
"measurements", "measurements",
"measurementSalt", "measurementSalt",
"image" "image"

View file

@ -76,9 +76,7 @@ func New(csp cloudprovider.Provider, k8sVersion versions.ValidK8sVersion) *Chart
// Load the embedded helm charts. // Load the embedded helm charts.
func (i *ChartLoader) Load(config *config.Config, conformanceMode bool, masterSecret, salt []byte) ([]byte, error) { func (i *ChartLoader) Load(config *config.Config, conformanceMode bool, masterSecret, salt []byte) ([]byte, error) {
csp := config.GetProvider() ciliumRelease, err := i.loadCilium(config.GetProvider(), conformanceMode)
ciliumRelease, err := i.loadCilium(csp, conformanceMode)
if err != nil { if err != nil {
return nil, fmt.Errorf("loading cilium: %w", err) return nil, fmt.Errorf("loading cilium: %w", err)
} }
@ -88,7 +86,7 @@ func (i *ChartLoader) Load(config *config.Config, conformanceMode bool, masterSe
return nil, fmt.Errorf("loading cilium: %w", err) return nil, fmt.Errorf("loading cilium: %w", err)
} }
operatorRelease, err := i.loadOperators(csp) operatorRelease, err := i.loadOperators(config.GetProvider())
if err != nil { if err != nil {
return nil, fmt.Errorf("loading operators: %w", err) return nil, fmt.Errorf("loading operators: %w", err)
} }
@ -350,11 +348,6 @@ func (i *ChartLoader) loadConstellationServicesHelper(config *config.Config, mas
return nil, nil, fmt.Errorf("loading constellation-services chart: %w", err) return nil, nil, fmt.Errorf("loading constellation-services chart: %w", err)
} }
enforcedPCRsJSON, err := json.Marshal(config.GetEnforcedPCRs())
if err != nil {
return nil, nil, fmt.Errorf("marshaling enforcedPCRs: %w", err)
}
csp := config.GetProvider() csp := config.GetProvider()
values := map[string]any{ values := map[string]any{
"global": map[string]any{ "global": map[string]any{
@ -374,9 +367,8 @@ func (i *ChartLoader) loadConstellationServicesHelper(config *config.Config, mas
"measurementsFilename": constants.MeasurementsFilename, "measurementsFilename": constants.MeasurementsFilename,
}, },
"join-service": map[string]any{ "join-service": map[string]any{
"csp": csp.String(), "csp": csp.String(),
"enforcedPCRs": string(enforcedPCRsJSON), "image": i.joinServiceImage,
"image": i.joinServiceImage,
}, },
"ccm": map[string]any{ "ccm": map[string]any{
"csp": csp.String(), "csp": csp.String(),

View file

@ -15,6 +15,7 @@ import (
"path" "path"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/deploy/helm" "github.com/edgelesssys/constellation/v2/internal/deploy/helm"
@ -56,8 +57,7 @@ func TestConstellationServices(t *testing.T) {
}{ }{
"GCP": { "GCP": {
config: &config.Config{Provider: config.ProviderConfig{GCP: &config.GCPConfig{ config: &config.Config{Provider: config.ProviderConfig{GCP: &config.GCPConfig{
DeployCSIDriver: func() *bool { b := true; return &b }(), DeployCSIDriver: func() *bool { b := true; return &b }(),
EnforcedMeasurements: []uint32{1, 11},
}}}, }}},
enforceIDKeyDigest: false, enforceIDKeyDigest: false,
valuesModifier: prepareGCPValues, valuesModifier: prepareGCPValues,
@ -65,9 +65,8 @@ func TestConstellationServices(t *testing.T) {
}, },
"Azure": { "Azure": {
config: &config.Config{Provider: config.ProviderConfig{Azure: &config.AzureConfig{ config: &config.Config{Provider: config.ProviderConfig{Azure: &config.AzureConfig{
DeployCSIDriver: func() *bool { b := true; return &b }(), DeployCSIDriver: func() *bool { b := true; return &b }(),
EnforcedMeasurements: []uint32{1, 11}, EnforceIDKeyDigest: func() *bool { b := true; return &b }(),
EnforceIDKeyDigest: func() *bool { b := true; return &b }(),
}}}, }}},
enforceIDKeyDigest: true, enforceIDKeyDigest: true,
valuesModifier: prepareAzureValues, valuesModifier: prepareAzureValues,
@ -75,9 +74,7 @@ func TestConstellationServices(t *testing.T) {
cnmImage: "cnmImageForAzure", cnmImage: "cnmImageForAzure",
}, },
"QEMU": { "QEMU": {
config: &config.Config{Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{ config: &config.Config{Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{}}},
EnforcedMeasurements: []uint32{1, 11},
}}},
enforceIDKeyDigest: false, enforceIDKeyDigest: false,
valuesModifier: prepareQEMUValues, valuesModifier: prepareQEMUValues,
}, },
@ -88,7 +85,14 @@ func TestConstellationServices(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
chartLoader := ChartLoader{joinServiceImage: "joinServiceImage", kmsImage: "kmsImage", ccmImage: tc.ccmImage, cnmImage: tc.cnmImage, autoscalerImage: "autoscalerImage", verificationServiceImage: "verificationImage"} chartLoader := ChartLoader{
joinServiceImage: "joinServiceImage",
kmsImage: "kmsImage",
ccmImage: tc.ccmImage,
cnmImage: tc.cnmImage,
autoscalerImage: "autoscalerImage",
verificationServiceImage: "verificationImage",
}
chart, values, err := chartLoader.loadConstellationServicesHelper(tc.config, []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")) chart, values, err := chartLoader.loadConstellationServicesHelper(tc.config, []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"))
require.NoError(err) require.NoError(err)
@ -197,7 +201,15 @@ func prepareGCPValues(values map[string]any) error {
if !ok { if !ok {
return errors.New("missing 'join-service' key") return errors.New("missing 'join-service' key")
} }
joinVals["measurements"] = "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}"
m := measurements.M{
1: measurements.WithAllBytes(0xAA, false),
}
mJSON, err := json.Marshal(m)
if err != nil {
return err
}
joinVals["measurements"] = string(mJSON)
joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
ccmVals, ok := values["ccm"].(map[string]any) ccmVals, ok := values["ccm"].(map[string]any)
@ -269,7 +281,12 @@ func prepareAzureValues(values map[string]any) error {
return errors.New("missing 'join-service' key") return errors.New("missing 'join-service' key")
} }
joinVals["idkeydigest"] = "baaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaad" joinVals["idkeydigest"] = "baaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaad"
joinVals["measurements"] = "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}" m := measurements.M{1: measurements.WithAllBytes(0xAA, false)}
mJSON, err := json.Marshal(m)
if err != nil {
return err
}
joinVals["measurements"] = string(mJSON)
joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
ccmVals, ok := values["ccm"].(map[string]any) ccmVals, ok := values["ccm"].(map[string]any)
@ -311,7 +328,12 @@ func prepareQEMUValues(values map[string]any) error {
if !ok { if !ok {
return errors.New("missing 'join-service' key") return errors.New("missing 'join-service' key")
} }
joinVals["measurements"] = "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}" m := measurements.M{1: measurements.WithAllBytes(0xAA, false)}
mJSON, err := json.Marshal(m)
if err != nil {
return err
}
joinVals["measurements"] = string(mJSON)
joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
verificationVals, ok := values["verification-service"].(map[string]any) verificationVals, ok := values["verification-service"].(map[string]any)

View file

@ -4,8 +4,7 @@ metadata:
name: join-config name: join-config
namespace: testNamespace namespace: testNamespace
data: data:
enforcedPCRs: "[1,11]" measurements: "{\"1\":{\"expected\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\",\"warnOnly\":false}}"
measurements: "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}"
enforceIdKeyDigest: "true" enforceIdKeyDigest: "true"
idkeydigest: "baaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaad" idkeydigest: "baaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaad"
binaryData: binaryData:

View file

@ -4,7 +4,6 @@ metadata:
name: join-config name: join-config
namespace: testNamespace namespace: testNamespace
data: data:
enforcedPCRs: "[1,11]" measurements: "{\"1\":{\"expected\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\",\"warnOnly\":false}}"
measurements: "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}"
binaryData: binaryData:
measurementSalt: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA measurementSalt: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA

View file

@ -4,7 +4,6 @@ metadata:
name: join-config name: join-config
namespace: testNamespace namespace: testNamespace
data: data:
enforcedPCRs: "[1,11]" measurements: "{\"1\":{\"expected\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\",\"warnOnly\":false}}"
measurements: "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}"
binaryData: binaryData:
measurementSalt: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA measurementSalt: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA

View file

@ -171,7 +171,7 @@ func main() {
// We can use this to calculate the PCRs of the image locally. // We can use this to calculate the PCRs of the image locally.
func exportPCRs() error { func exportPCRs() error {
// get TPM state // get TPM state
pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, tpmClient.FullPcrSel(tpm2.AlgSHA256)) pcrs, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, tpmClient.FullPcrSel(tpm2.AlgSHA256))
if err != nil { if err != nil {
return err return err
} }

2
go.mod
View file

@ -92,7 +92,6 @@ require (
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6
google.golang.org/grpc v1.51.0 google.golang.org/grpc v1.51.0
google.golang.org/protobuf v1.28.1 google.golang.org/protobuf v1.28.1
gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
helm.sh/helm v2.17.0+incompatible helm.sh/helm v2.17.0+incompatible
helm.sh/helm/v3 v3.10.2 helm.sh/helm/v3 v3.10.2
@ -119,6 +118,7 @@ require (
github.com/hashicorp/go-retryablehttp v0.7.1 // indirect github.com/hashicorp/go-retryablehttp v0.7.1 // indirect
github.com/rogpeppe/go-internal v1.8.1 // indirect github.com/rogpeppe/go-internal v1.8.1 // indirect
golang.org/x/text v0.4.0 // indirect golang.org/x/text v0.4.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
) )
require ( require (

View file

@ -41,7 +41,6 @@ require (
github.com/go-git/go-git/v5 v5.4.2 github.com/go-git/go-git/v5 v5.4.2
github.com/google/go-tpm-tools v0.3.9 github.com/google/go-tpm-tools v0.3.9
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/spf13/afero v1.9.3
github.com/spf13/cobra v1.6.1 github.com/spf13/cobra v1.6.1
github.com/stretchr/testify v1.8.1 github.com/stretchr/testify v1.8.1
go.uber.org/goleak v1.2.0 go.uber.org/goleak v1.2.0
@ -189,6 +188,7 @@ require (
github.com/sigstore/rekor v1.0.1 // indirect github.com/sigstore/rekor v1.0.1 // indirect
github.com/sigstore/sigstore v1.4.5 // indirect github.com/sigstore/sigstore v1.4.5 // indirect
github.com/sirupsen/logrus v1.9.0 // indirect github.com/sirupsen/logrus v1.9.0 // indirect
github.com/spf13/afero v1.9.3 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/talos-systems/talos/pkg/machinery v1.2.7 // indirect github.com/talos-systems/talos/pkg/machinery v1.2.7 // indirect
github.com/tent/canonical-json-go v0.0.0-20130607151641-96e4ba3a7613 // indirect github.com/tent/canonical-json-go v0.0.0-20130607151641-96e4ba3a7613 // indirect

View file

@ -24,24 +24,24 @@ import (
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto" "github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/verify/verifyproto" "github.com/edgelesssys/constellation/v2/verify/verifyproto"
"github.com/spf13/afero"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
var (
coordIP = flag.String("constell-ip", "", "Public IP of the Constellation")
port = flag.String("constell-port", strconv.Itoa(constants.VerifyServiceNodePortGRPC), "NodePort of the Constellation's verification service")
export = flag.String("o", "", "Write PCRs, formatted as Go code, to file")
format = flag.String("format", "json", "Output format: json, yaml (default json)")
quiet = flag.Bool("q", false, "Set to disable output")
timeout = flag.Duration("timeout", 2*time.Minute, "Wait this duration for the verification service to become available")
)
func main() { func main() {
coordIP := flag.String("constell-ip", "", "Public IP of the Constellation")
port := flag.String("constell-port", strconv.Itoa(constants.VerifyServiceNodePortGRPC), "NodePort of the Constellation's verification service")
format := flag.String("format", "json", "Output format: json, yaml (default json)")
quiet := flag.Bool("q", false, "Set to disable output")
timeout := flag.Duration("timeout", 2*time.Minute, "Wait this duration for the verification service to become available")
flag.Parse() flag.Parse()
if *coordIP == "" || *port == "" {
flag.Usage()
os.Exit(1)
}
addr := net.JoinHostPort(*coordIP, *port) addr := net.JoinHostPort(*coordIP, *port)
ctx, cancel := context.WithTimeout(context.Background(), *timeout) ctx, cancel := context.WithTimeout(context.Background(), *timeout)
defer cancel() defer cancel()
@ -51,18 +51,13 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
pcrs, err := validatePCRAttDoc(attDocRaw) measurements, err := validatePCRAttDoc(attDocRaw)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if !*quiet { if !*quiet {
if err := printPCRs(os.Stdout, pcrs, *format); err != nil { if err := printPCRs(os.Stdout, measurements, *format); err != nil {
log.Fatal(err)
}
}
if *export != "" {
if err := exportToFile(*export, pcrs, &afero.Afero{Fs: afero.NewOsFs()}); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }
@ -104,16 +99,23 @@ func validatePCRAttDoc(attDocRaw []byte) (measurements.M, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := measurements.M{}
for idx, pcr := range attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs { for idx, pcr := range attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs {
if len(pcr) != 32 { if len(pcr) != 32 {
return nil, fmt.Errorf("incomplete PCR at index: %d", idx) return nil, fmt.Errorf("incomplete PCR at index: %d", idx)
} }
m[idx] = measurements.Measurement{
Expected: *(*[32]byte)(pcr),
WarnOnly: true,
}
} }
return attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs, nil return m, nil
} }
// printPCRs formates and prints PCRs to the given writer. // printPCRs formats and prints PCRs to the given writer.
// format can be one of 'json' or 'yaml'. If it doesnt match defaults to 'json'. // format can be one of 'json' or 'yaml'. If it doesn't match defaults to 'json'.
func printPCRs(w io.Writer, pcrs measurements.M, format string) error { func printPCRs(w io.Writer, pcrs measurements.M, format string) error {
switch format { switch format {
case "json": case "json":
@ -142,24 +144,3 @@ func printPCRsJSON(w io.Writer, pcrs measurements.M) error {
fmt.Fprintf(w, "%s", string(pcrJSON)) fmt.Fprintf(w, "%s", string(pcrJSON))
return nil return nil
} }
// exportToFile writes pcrs to a file, formatted to be valid Go code.
// Validity of the PCR map is not checked, and should be handled by the caller.
func exportToFile(path string, pcrs measurements.M, fs *afero.Afero) error {
goCode := `package pcrs
var pcrs = map[uint32][]byte{%s
}
`
pcrsFormatted := ""
for i := 0; i < len(pcrs); i++ {
pcrHex := fmt.Sprintf("%#02X", pcrs[uint32(i)][0])
for j := 1; j < len(pcrs[uint32(i)]); j++ {
pcrHex = fmt.Sprintf("%s, %#02X", pcrHex, pcrs[uint32(i)][j])
}
pcrsFormatted = pcrsFormatted + fmt.Sprintf("\n\t%d: {%s},", i, pcrHex)
}
return fs.WriteFile(path, []byte(fmt.Sprintf(goCode, pcrsFormatted)), 0o644)
}

View file

@ -8,7 +8,7 @@ package main
import ( import (
"bytes" "bytes"
"encoding/base64" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"testing" "testing"
@ -17,67 +17,13 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/google/go-tpm-tools/proto/attest" "github.com/google/go-tpm-tools/proto/attest"
"github.com/google/go-tpm-tools/proto/tpm" "github.com/google/go-tpm-tools/proto/tpm"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
goleak.VerifyTestMain(m, goleak.VerifyTestMain(m)
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
}
func TestExportToFile(t *testing.T) {
testCases := map[string]struct {
pcrs measurements.M
fs *afero.Afero
wantErr bool
}{
"file not writeable": {
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
},
fs: &afero.Afero{Fs: afero.NewReadOnlyFs(afero.NewMemMapFs())},
wantErr: true,
},
"file writeable": {
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
},
fs: &afero.Afero{Fs: afero.NewMemMapFs()},
wantErr: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
path := "test-file"
err := exportToFile(path, tc.pcrs, tc.fs)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
content, err := tc.fs.ReadFile(path)
require.NoError(err)
for _, pcr := range tc.pcrs {
for _, register := range pcr {
assert.Contains(string(content), fmt.Sprintf("%#02X", register))
}
}
}
})
}
} }
func TestValidatePCRAttDoc(t *testing.T) { func TestValidatePCRAttDoc(t *testing.T) {
@ -106,7 +52,7 @@ func TestValidatePCRAttDoc(t *testing.T) {
{ {
Pcrs: &tpm.PCRs{ Pcrs: &tpm.PCRs{
Hash: tpm.HashAlgo_SHA256, Hash: tpm.HashAlgo_SHA256,
Pcrs: measurements.M{ Pcrs: map[uint32][]byte{
0: {0x1, 0x2, 0x3}, 0: {0x1, 0x2, 0x3},
}, },
}, },
@ -123,8 +69,8 @@ func TestValidatePCRAttDoc(t *testing.T) {
{ {
Pcrs: &tpm.PCRs{ Pcrs: &tpm.PCRs{
Hash: tpm.HashAlgo_SHA256, Hash: tpm.HashAlgo_SHA256,
Pcrs: measurements.M{ Pcrs: map[uint32][]byte{
0: measurements.PCRWithAllBytes(0xAA), 0: bytes.Repeat([]byte{0xAA}, 32),
}, },
}, },
}, },
@ -150,7 +96,10 @@ func TestValidatePCRAttDoc(t *testing.T) {
require.NoError(json.Unmarshal(tc.attDocRaw, &attDoc)) require.NoError(json.Unmarshal(tc.attDocRaw, &attDoc))
qIdx, err := vtpm.GetSHA256QuoteIndex(attDoc.Attestation.Quotes) qIdx, err := vtpm.GetSHA256QuoteIndex(attDoc.Attestation.Quotes)
require.NoError(err) require.NoError(err)
assert.EqualValues(attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs, pcrs)
for pcrIdx, pcrVal := range pcrs {
assert.Equal(pcrVal.Expected[:], attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs[pcrIdx])
}
} }
}) })
} }
@ -164,31 +113,15 @@ func mustMarshalAttDoc(t *testing.T, attDoc vtpm.AttestationDocument) []byte {
func TestPrintPCRs(t *testing.T) { func TestPrintPCRs(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
pcrs measurements.M
format string format string
}{ }{
"json": { "json": {
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
},
format: "json", format: "json",
}, },
"empty format": { "empty format": {
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
},
format: "", format: "",
}, },
"yaml": { "yaml": {
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
},
format: "yaml", format: "yaml",
}, },
} }
@ -197,13 +130,19 @@ func TestPrintPCRs(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
pcrs := measurements.M{
0: measurements.WithAllBytes(0xAA, true),
1: measurements.WithAllBytes(0xBB, true),
2: measurements.WithAllBytes(0xCC, true),
}
var out bytes.Buffer var out bytes.Buffer
err := printPCRs(&out, tc.pcrs, tc.format) err := printPCRs(&out, pcrs, tc.format)
assert.NoError(err) assert.NoError(err)
for idx, pcr := range tc.pcrs { for idx, pcr := range pcrs {
assert.Contains(out.String(), fmt.Sprintf("%d", idx)) assert.Contains(out.String(), fmt.Sprintf("%d", idx))
assert.Contains(out.String(), base64.StdEncoding.EncodeToString(pcr)) assert.Contains(out.String(), hex.EncodeToString(pcr.Expected[:]))
} }
}) })
} }

View file

@ -307,12 +307,12 @@ func TestExportPCRs(t *testing.T) {
remoteAddr: "192.0.100.1:1234", remoteAddr: "192.0.100.1:1234",
connect: defaultConnect, connect: defaultConnect,
method: http.MethodPost, method: http.MethodPost,
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), message: mustMarshal(t, measurements.M{0: measurements.WithAllBytes(0xAA, false)}),
}, },
"incorrect method": { "incorrect method": {
remoteAddr: "192.0.100.1:1234", remoteAddr: "192.0.100.1:1234",
connect: defaultConnect, connect: defaultConnect,
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), message: mustMarshal(t, measurements.M{0: measurements.WithAllBytes(0xAA, false)}),
method: http.MethodGet, method: http.MethodGet,
wantErr: true, wantErr: true,
}, },
@ -321,7 +321,7 @@ func TestExportPCRs(t *testing.T) {
connect: &stubConnect{ connect: &stubConnect{
getNetworkErr: errors.New("error"), getNetworkErr: errors.New("error"),
}, },
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), message: mustMarshal(t, measurements.M{0: measurements.WithAllBytes(0xAA, false)}),
method: http.MethodPost, method: http.MethodPost,
wantErr: true, wantErr: true,
}, },
@ -336,7 +336,7 @@ func TestExportPCRs(t *testing.T) {
remoteAddr: "localhost", remoteAddr: "localhost",
connect: defaultConnect, connect: defaultConnect,
method: http.MethodPost, method: http.MethodPost,
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), message: mustMarshal(t, measurements.M{0: measurements.WithAllBytes(0xAA, false)}),
wantErr: true, wantErr: true,
}, },
} }

View file

@ -29,11 +29,10 @@ type Validator struct {
} }
// NewValidator create a new Validator structure and returns it. // NewValidator create a new Validator structure and returns it.
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, log vtpm.AttestationLogger) *Validator {
v := &Validator{} v := &Validator{}
v.Validator = vtpm.NewValidator( v.Validator = vtpm.NewValidator(
pcrs, pcrs,
enforcedPCRs,
getTrustedKey, getTrustedKey,
v.tpmEnabled, v.tpmEnabled,
vtpm.VerifyPKCS1v15, vtpm.VerifyPKCS1v15,

View file

@ -42,11 +42,10 @@ type Validator struct {
} }
// NewValidator initializes a new Azure validator with the provided PCR values. // NewValidator initializes a new Azure validator with the provided PCR values.
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, idKeyDigest []byte, enforceIDKeyDigest bool, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, idKeyDigest []byte, enforceIDKeyDigest bool, log vtpm.AttestationLogger) *Validator {
return &Validator{ return &Validator{
Validator: vtpm.NewValidator( Validator: vtpm.NewValidator(
pcrs, pcrs,
enforcedPCRs,
getTrustedKey(&azureInstanceInfo{}, idKeyDigest, enforceIDKeyDigest, log), getTrustedKey(&azureInstanceInfo{}, idKeyDigest, enforceIDKeyDigest, log),
validateCVM, validateCVM,
vtpm.VerifyPKCS1v15, vtpm.VerifyPKCS1v15,

View file

@ -189,7 +189,7 @@ func TestGetAttestationCert(t *testing.T) {
} }
require.NoError(err) require.NoError(err)
validator := NewValidator(measurements.M{}, []uint32{}, nil) validator := NewValidator(measurements.M{}, nil)
cert, err := x509.ParseCertificate(rootCert.Raw) cert, err := x509.ParseCertificate(rootCert.Raw)
require.NoError(err) require.NoError(err)
roots := x509.NewCertPool() roots := x509.NewCertPool()

View file

@ -33,13 +33,12 @@ type Validator struct {
} }
// NewValidator initializes a new Azure validator with the provided PCR values. // NewValidator initializes a new Azure validator with the provided PCR values.
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, log vtpm.AttestationLogger) *Validator {
rootPool := x509.NewCertPool() rootPool := x509.NewCertPool()
rootPool.AddCert(ameRoot) rootPool.AddCert(ameRoot)
v := &Validator{roots: rootPool} v := &Validator{roots: rootPool}
v.Validator = vtpm.NewValidator( v.Validator = vtpm.NewValidator(
pcrs, pcrs,
enforcedPCRs,
v.verifyAttestationKey, v.verifyAttestationKey,
validateVM, validateVM,
vtpm.VerifyPKCS1v15, vtpm.VerifyPKCS1v15,

View file

@ -35,11 +35,10 @@ type Validator struct {
} }
// NewValidator initializes a new GCP validator with the provided PCR values. // NewValidator initializes a new GCP validator with the provided PCR values.
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, log vtpm.AttestationLogger) *Validator {
return &Validator{ return &Validator{
Validator: vtpm.NewValidator( Validator: vtpm.NewValidator(
pcrs, pcrs,
enforcedPCRs,
trustedKeyFromGCEAPI(newInstanceClient), trustedKeyFromGCEAPI(newInstanceClient),
gceNonHostInfoEvent, gceNonHostInfoEvent,
vtpm.VerifyPKCS1v15, vtpm.VerifyPKCS1v15,

View file

@ -12,61 +12,31 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/sigstore" "github.com/edgelesssys/constellation/v2/internal/sigstore"
"gopkg.in/yaml.v2" "github.com/google/go-tpm/tpmutil"
"go.uber.org/multierr"
"gopkg.in/yaml.v3"
)
const (
// PCRIndexClusterID is a PCR we extend to mark the node as initialized.
// The value used to extend is a random generated 32 Byte value.
PCRIndexClusterID = tpmutil.Handle(15)
// PCRIndexOwnerID is a PCR we extend to mark the node as initialized.
// The value used to extend is derived from Constellation's master key.
// TODO: move to stable, non-debug PCR before use.
PCRIndexOwnerID = tpmutil.Handle(16)
) )
// M are Platform Configuration Register (PCR) values that make up the Measurements. // M are Platform Configuration Register (PCR) values that make up the Measurements.
type M map[uint32][]byte type M map[uint32]Measurement
// PCRWithAllBytes returns a PCR value where all 32 bytes are set to b.
func PCRWithAllBytes(b byte) []byte {
return bytes.Repeat([]byte{b}, 32)
}
// DefaultsFor provides the default measurements for given cloud provider.
func DefaultsFor(provider cloudprovider.Provider) M {
switch provider {
case cloudprovider.AWS:
return M{
8: PCRWithAllBytes(0x00),
11: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
case cloudprovider.Azure:
return M{
8: PCRWithAllBytes(0x00),
11: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
case cloudprovider.GCP:
return M{
0: {0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF},
8: PCRWithAllBytes(0x00),
11: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
case cloudprovider.QEMU:
return M{
8: PCRWithAllBytes(0x00),
11: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
default:
return nil
}
}
// FetchAndVerify fetches measurement and signature files via provided URLs, // FetchAndVerify fetches measurement and signature files via provided URLs,
// using client for download. The publicKey is used to verify the measurements. // using client for download. The publicKey is used to verify the measurements.
@ -83,8 +53,14 @@ func (m *M) FetchAndVerify(ctx context.Context, client *http.Client, measurement
if err := sigstore.VerifySignature(measurements, signature, publicKey); err != nil { if err := sigstore.VerifySignature(measurements, signature, publicKey); err != nil {
return "", err return "", err
} }
if err := yaml.NewDecoder(bytes.NewReader(measurements)).Decode(&m); err != nil {
return "", err if err := json.Unmarshal(measurements, m); err != nil {
if yamlErr := yaml.Unmarshal(measurements, m); yamlErr != nil {
return "", multierr.Append(
err,
fmt.Errorf("trying yaml format: %w", yamlErr),
)
}
} }
shaHash := sha256.Sum256(measurements) shaHash := sha256.Sum256(measurements)
@ -94,58 +70,231 @@ func (m *M) FetchAndVerify(ctx context.Context, client *http.Client, measurement
// CopyFrom copies over all values from other. Overwriting existing values, // CopyFrom copies over all values from other. Overwriting existing values,
// but keeping not specified values untouched. // but keeping not specified values untouched.
func (m M) CopyFrom(other M) { func (m *M) CopyFrom(other M) {
for idx := range other { for idx := range other {
m[idx] = other[idx] (*m)[idx] = other[idx]
} }
} }
// EqualTo tests whether the provided other Measurements are equal to these // EqualTo tests whether the provided other Measurements are equal to these
// measurements. // measurements.
func (m M) EqualTo(other M) bool { func (m *M) EqualTo(other M) bool {
if len(m) != len(other) { if len(*m) != len(other) {
return false return false
} }
for k, v := range m { for k, v := range *m {
if !bytes.Equal(v, other[k]) { otherExpected := other[k].Expected
if !bytes.Equal(v.Expected[:], otherExpected[:]) {
return false
}
if v.WarnOnly != other[k].WarnOnly {
return false return false
} }
} }
return true return true
} }
// MarshalYAML overwrites the default behaviour of writing out []byte not as // GetEnforced returns a list of all enforced Measurements,
// single bytes, but as a single base64 encoded string. // i.e. all Measurements that are not marked as WarnOnly.
func (m M) MarshalYAML() (any, error) { func (m *M) GetEnforced() []uint32 {
base64Map := make(map[uint32]string) var enforced []uint32
for idx, measurement := range *m {
for key, value := range m { if !measurement.WarnOnly {
base64Map[key] = base64.StdEncoding.EncodeToString(value[:]) enforced = append(enforced, idx)
}
} }
return enforced
return base64Map, nil
} }
// UnmarshalYAML overwrites the default behaviour of reading []byte not as // SetEnforced sets the WarnOnly flag to true for all Measurements
// single bytes, but as a single base64 encoded string. // that are NOT included in the provided list of enforced measurements.
func (m *M) UnmarshalYAML(unmarshal func(any) error) error { func (m *M) SetEnforced(enforced []uint32) error {
base64Map := make(map[uint32]string) newM := make(M)
err := unmarshal(base64Map)
if err != nil { // set all measurements to warn only
return err for idx, measurement := range *m {
newM[idx] = Measurement{
Expected: measurement.Expected,
WarnOnly: true,
}
} }
*m = make(M) // set enforced measurements from list
for key, value := range base64Map { for _, idx := range enforced {
measurement, err := base64.StdEncoding.DecodeString(value) measurement, ok := newM[idx]
if err != nil { if !ok {
return err return fmt.Errorf("measurement %d not in list, but set to enforced", idx)
} }
(*m)[key] = measurement measurement.WarnOnly = false
newM[idx] = measurement
}
*m = newM
return nil
}
// Measurement wraps expected PCR value and whether it is enforced.
type Measurement struct {
// Expected measurement value.
Expected [32]byte `json:"expected" yaml:"expected"`
// WarnOnly if set to true, a mismatching measurement will only result in a warning.
WarnOnly bool `json:"warnOnly" yaml:"warnOnly"`
}
// UnmarshalJSON reads a Measurement either as json object,
// or as a simple hex or base64 encoded string.
func (m *Measurement) UnmarshalJSON(b []byte) error {
var eM encodedMeasurement
if err := json.Unmarshal(b, &eM); err != nil {
// Unmarshalling failed, Measurement might be in legacy format,
// meaning a simple string instead of Measurement struct.
// TODO: remove with v2.4.0
if legacyErr := json.Unmarshal(b, &eM.Expected); legacyErr != nil {
return multierr.Append(
err,
fmt.Errorf("trying legacy format: %w", legacyErr),
)
}
}
if err := m.unmarshal(eM); err != nil {
return fmt.Errorf("unmarshalling json: %w", err)
} }
return nil return nil
} }
// MarshalJSON writes out a Measurement with Expected encoded as a hex string.
func (m Measurement) MarshalJSON() ([]byte, error) {
return json.Marshal(encodedMeasurement{
Expected: hex.EncodeToString(m.Expected[:]),
WarnOnly: m.WarnOnly,
})
}
// UnmarshalYAML reads a Measurement either as yaml object,
// or as a simple hex or base64 encoded string.
func (m *Measurement) UnmarshalYAML(unmarshal func(any) error) error {
var eM encodedMeasurement
if err := unmarshal(&eM); err != nil {
// Unmarshalling failed, Measurement might be in legacy format,
// meaning a simple string instead of Measurement struct.
// TODO: remove with v2.4.0
if legacyErr := unmarshal(&eM.Expected); legacyErr != nil {
return multierr.Append(
err,
fmt.Errorf("trying legacy format: %w", legacyErr),
)
}
}
if err := m.unmarshal(eM); err != nil {
return fmt.Errorf("unmarshalling yaml: %w", err)
}
return nil
}
// MarshalYAML writes out a Measurement with Expected encoded as a hex string.
func (m Measurement) MarshalYAML() (any, error) {
return encodedMeasurement{
Expected: hex.EncodeToString(m.Expected[:]),
WarnOnly: m.WarnOnly,
}, nil
}
// unmarshal parses a hex or base64 encoded Measurement.
func (m *Measurement) unmarshal(eM encodedMeasurement) error {
expected, err := hex.DecodeString(eM.Expected)
if err != nil {
// expected value might be in base64 legacy format
// TODO: Remove with v2.4.0
hexErr := err
expected, err = base64.StdEncoding.DecodeString(eM.Expected)
if err != nil {
return multierr.Append(
fmt.Errorf("invalid measurement: not a hex string %w", hexErr),
fmt.Errorf("not a base64 string: %w", err),
)
}
}
if len(expected) != 32 {
return fmt.Errorf("invalid measurement: invalid length: %d", len(expected))
}
m.Expected = *(*[32]byte)(expected)
m.WarnOnly = eM.WarnOnly
return nil
}
// WithAllBytes returns a measurement value where all 32 bytes are set to b.
func WithAllBytes(b byte, warnOnly bool) Measurement {
return Measurement{
Expected: *(*[32]byte)(bytes.Repeat([]byte{b}, 32)),
WarnOnly: warnOnly,
}
}
// DefaultsFor provides the default measurements for given cloud provider.
func DefaultsFor(provider cloudprovider.Provider) M {
switch provider {
case cloudprovider.AWS:
return M{
4: PlaceHolderMeasurement(),
8: WithAllBytes(0x00, false),
9: PlaceHolderMeasurement(),
11: WithAllBytes(0x00, false),
12: PlaceHolderMeasurement(),
13: WithAllBytes(0x00, false),
uint32(PCRIndexClusterID): WithAllBytes(0x00, false),
}
case cloudprovider.Azure:
return M{
4: PlaceHolderMeasurement(),
8: WithAllBytes(0x00, false),
9: PlaceHolderMeasurement(),
11: WithAllBytes(0x00, false),
12: PlaceHolderMeasurement(),
13: WithAllBytes(0x00, false),
uint32(PCRIndexClusterID): WithAllBytes(0x00, false),
}
case cloudprovider.GCP:
return M{
0: {
Expected: [32]byte{0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF},
WarnOnly: false,
},
4: PlaceHolderMeasurement(),
8: WithAllBytes(0x00, false),
9: PlaceHolderMeasurement(),
11: WithAllBytes(0x00, false),
12: PlaceHolderMeasurement(),
13: WithAllBytes(0x00, false),
uint32(PCRIndexClusterID): WithAllBytes(0x00, false),
}
case cloudprovider.QEMU:
return M{
4: PlaceHolderMeasurement(),
8: WithAllBytes(0x00, false),
9: PlaceHolderMeasurement(),
11: WithAllBytes(0x00, false),
12: PlaceHolderMeasurement(),
13: WithAllBytes(0x00, false),
uint32(PCRIndexClusterID): WithAllBytes(0x00, false),
}
default:
return nil
}
}
// PlaceHolderMeasurement returns a measurement with placeholder values for Expected.
func PlaceHolderMeasurement() Measurement {
return Measurement{
Expected: *(*[32]byte)(bytes.Repeat([]byte{0x12, 0x34}, 16)),
WarnOnly: false,
}
}
func getFromURL(ctx context.Context, client *http.Client, sourceURL *url.URL) ([]byte, error) { func getFromURL(ctx context.Context, client *http.Client, sourceURL *url.URL) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL.String(), http.NoBody) req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL.String(), http.NoBody)
if err != nil { if err != nil {
@ -166,3 +315,8 @@ func getFromURL(ctx context.Context, client *http.Client, sourceURL *url.URL) ([
} }
return content, nil return content, nil
} }
type encodedMeasurement struct {
Expected string `json:"expected" yaml:"expected"`
WarnOnly bool `json:"warnOnly" yaml:"warnOnly"`
}

View file

@ -8,7 +8,7 @@ package measurements
import ( import (
"context" "context"
"errors" "encoding/json"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@ -17,32 +17,29 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
) )
func TestMarshalYAML(t *testing.T) { func TestMarshal(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
measurements M m Measurement
wantBase64Map map[uint32]string wantYAML string
wantJSON string
}{ }{
"valid measurements": { "measurement": {
measurements: M{ m: Measurement{
2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, Expected: [32]byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
},
wantBase64Map: map[uint32]string{
2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=",
3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=",
}, },
wantYAML: "expected: \"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef51239f752ce69dbc600f35\"\nwarnOnly: false",
wantJSON: `{"expected":"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef51239f752ce69dbc600f35","warnOnly":false}`,
}, },
"omit bytes": { "warn only": {
measurements: M{ m: Measurement{
2: []byte{}, Expected: [32]byte{1, 2, 3, 4}, // implicitly padded with 0s
3: []byte{1, 2, 3, 4}, WarnOnly: true,
},
wantBase64Map: map[uint32]string{
2: "",
3: "AQIDBA==",
}, },
wantYAML: "expected: \"0102030400000000000000000000000000000000000000000000000000000000\"\nwarnOnly: true",
wantJSON: `{"expected":"0102030400000000000000000000000000000000000000000000000000000000","warnOnly":true}`,
}, },
} }
@ -51,63 +48,99 @@ func TestMarshalYAML(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
base64Map, err := tc.measurements.MarshalYAML() {
require.NoError(err) // YAML
yaml, err := yaml.Marshal(tc.m)
require.NoError(err)
assert.Equal(tc.wantBase64Map, base64Map) assert.YAMLEq(tc.wantYAML, string(yaml))
}
{
// JSON
json, err := json.Marshal(tc.m)
require.NoError(err)
assert.JSONEq(tc.wantJSON, string(json))
}
}) })
} }
} }
func TestUnmarshalYAML(t *testing.T) { func TestUnmarshal(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
inputBase64Map map[uint32]string inputYAML string
forceUnmarshalError bool inputJSON string
wantMeasurements M wantMeasurements M
wantErr bool wantErr bool
}{ }{
"valid measurements": { "valid measurements base64": {
inputBase64Map: map[uint32]string{ inputYAML: "2:\n expected: \"/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=\"\n3:\n expected: \"1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=\"",
2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=", inputJSON: `{"2":{"expected":"/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU="},"3":{"expected":"1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754="}}`,
3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=",
},
wantMeasurements: M{ wantMeasurements: M{
2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, 2: {
3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158}, Expected: [32]byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
},
3: {
Expected: [32]byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
},
},
},
"valid measurements hex": {
inputYAML: "2:\n expected: \"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef51239f752ce69dbc600f35\"\n3:\n expected: \"d5a4496d21dec9a5258ddb19c6feb53bb4d3c0463fe607f2488ddf4f1006ef9e\"",
inputJSON: `{"2":{"expected":"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef51239f752ce69dbc600f35"},"3":{"expected":"d5a4496d21dec9a5258ddb19c6feb53bb4d3c0463fe607f2488ddf4f1006ef9e"}}`,
wantMeasurements: M{
2: {
Expected: [32]byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
},
3: {
Expected: [32]byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
},
}, },
}, },
"empty bytes": { "empty bytes": {
inputBase64Map: map[uint32]string{ inputYAML: "2:\n expected: \"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\"\n3:\n expected: \"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\"",
2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", inputJSON: `{"2":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="},"3":{"expected":"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="}}`,
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
},
wantMeasurements: M{ wantMeasurements: M{
2: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 2: {
3: []byte{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, Expected: [32]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
3: {
Expected: [32]byte{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
}, },
}, },
"invalid base64": { "invalid base64": {
inputBase64Map: map[uint32]string{ inputYAML: "2:\n expected: \"This is not base64\"\n3:\n expected: \"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\"",
2: "This is not base64", inputJSON: `{"2":{"expected":"This is not base64"},"3":{"expected":"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="}}`,
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", wantErr: true,
},
wantMeasurements: M{
2: []byte{},
3: []byte{1, 2, 3, 4},
},
wantErr: true,
}, },
"simulated unmarshal error": { "legacy format": {
inputBase64Map: map[uint32]string{ inputYAML: "2: \"/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=\"\n3: \"1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=\"",
2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", inputJSON: `{"2":"/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=","3":"1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754="}`,
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
},
forceUnmarshalError: true,
wantMeasurements: M{ wantMeasurements: M{
2: []byte{}, 2: {
3: []byte{1, 2, 3, 4}, Expected: [32]byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
},
3: {
Expected: [32]byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
},
}, },
wantErr: true, },
"invalid length hex": {
inputYAML: "2:\n expected: \"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef\"\n3:\n expected: \"d5a4496d21dec9a5258ddb19c6feb53bb4d3c0463f\"",
inputJSON: `{"2":{"expected":"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef"},"3":{"expected":"d5a4496d21dec9a5258ddb19c6feb53bb4d3c0463f"}}`,
wantErr: true,
},
"invalid length base64": {
inputYAML: "2:\n expected: \"AA==\"\n3:\n expected: \"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==\"",
inputJSON: `{"2":{"expected":"AA=="},"3":{"expected":"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="}}`,
wantErr: true,
},
"invalid format": {
inputYAML: "1:\n expected:\n someKey: 12\n anotherKey: 34",
inputJSON: `{"1":{"expected":{"someKey":12,"anotherKey":34}}}`,
wantErr: true,
}, },
} }
@ -116,24 +149,30 @@ func TestUnmarshalYAML(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
var m M {
err := m.UnmarshalYAML(func(i any) error { // YAML
if base64Map, ok := i.(map[uint32]string); ok { var m M
for key, value := range tc.inputBase64Map { err := yaml.Unmarshal([]byte(tc.inputYAML), &m)
base64Map[key] = value
}
}
if tc.forceUnmarshalError {
return errors.New("unmarshal error")
}
return nil
})
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err, "yaml.Unmarshal should have failed")
} else { } else {
require.NoError(err) require.NoError(err, "yaml.Unmarshal failed")
assert.Equal(tc.wantMeasurements, m) assert.Equal(tc.wantMeasurements, m)
}
}
{
// JSON
var m M
err := json.Unmarshal([]byte(tc.inputJSON), &m)
if tc.wantErr {
assert.Error(err, "json.Unmarshal should have failed")
} else {
require.NoError(err, "json.Unmarshal failed")
assert.Equal(tc.wantMeasurements, m)
}
} }
}) })
} }
@ -148,48 +187,48 @@ func TestMeasurementsCopyFrom(t *testing.T) {
"add to empty": { "add to empty": {
current: M{}, current: M{},
newMeasurements: M{ newMeasurements: M{
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, true),
2: PCRWithAllBytes(0x01), 2: WithAllBytes(0x01, true),
3: PCRWithAllBytes(0x02), 3: WithAllBytes(0x02, true),
}, },
wantMeasurements: M{ wantMeasurements: M{
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, true),
2: PCRWithAllBytes(0x01), 2: WithAllBytes(0x01, true),
3: PCRWithAllBytes(0x02), 3: WithAllBytes(0x02, true),
}, },
}, },
"keep existing": { "keep existing": {
current: M{ current: M{
4: PCRWithAllBytes(0x01), 4: WithAllBytes(0x01, false),
5: PCRWithAllBytes(0x02), 5: WithAllBytes(0x02, true),
}, },
newMeasurements: M{ newMeasurements: M{
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, true),
2: PCRWithAllBytes(0x01), 2: WithAllBytes(0x01, true),
3: PCRWithAllBytes(0x02), 3: WithAllBytes(0x02, true),
}, },
wantMeasurements: M{ wantMeasurements: M{
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, true),
2: PCRWithAllBytes(0x01), 2: WithAllBytes(0x01, true),
3: PCRWithAllBytes(0x02), 3: WithAllBytes(0x02, true),
4: PCRWithAllBytes(0x01), 4: WithAllBytes(0x01, false),
5: PCRWithAllBytes(0x02), 5: WithAllBytes(0x02, true),
}, },
}, },
"overwrite existing": { "overwrite existing": {
current: M{ current: M{
2: PCRWithAllBytes(0x04), 2: WithAllBytes(0x04, false),
3: PCRWithAllBytes(0x05), 3: WithAllBytes(0x05, false),
}, },
newMeasurements: M{ newMeasurements: M{
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, true),
2: PCRWithAllBytes(0x01), 2: WithAllBytes(0x01, true),
3: PCRWithAllBytes(0x02), 3: WithAllBytes(0x02, true),
}, },
wantMeasurements: M{ wantMeasurements: M{
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, true),
2: PCRWithAllBytes(0x01), 2: WithAllBytes(0x01, true),
3: PCRWithAllBytes(0x02), 3: WithAllBytes(0x02, true),
}, },
}, },
} }
@ -224,6 +263,22 @@ func urlMustParse(raw string) *url.URL {
} }
func TestMeasurementsFetchAndVerify(t *testing.T) { func TestMeasurementsFetchAndVerify(t *testing.T) {
// Cosign private key used to sign the measurements.
// Generated with: cosign generate-key-pair
// Password left empty.
//
// -----BEGIN ENCRYPTED COSIGN PRIVATE KEY-----
// eyJrZGYiOnsibmFtZSI6InNjcnlwdCIsInBhcmFtcyI6eyJOIjozMjc2OCwiciI6
// OCwicCI6MX0sInNhbHQiOiJlRHVYMWRQMGtIWVRnK0xkbjcxM0tjbFVJaU92eFVX
// VXgvNi9BbitFVk5BPSJ9LCJjaXBoZXIiOnsibmFtZSI6Im5hY2wvc2VjcmV0Ym94
// Iiwibm9uY2UiOiJwaWhLL2txNmFXa2hqSVVHR3RVUzhTVkdHTDNIWWp4TCJ9LCJj
// aXBoZXJ0ZXh0Ijoidm81SHVWRVFWcUZ2WFlQTTVPaTVaWHM5a255bndZU2dvcyth
// VklIeHcrOGFPamNZNEtvVjVmL3lHRHR0K3BHV2toanJPR1FLOWdBbmtsazFpQ0c5
// a2czUXpPQTZsU2JRaHgvZlowRVRZQ0hLeElncEdPRVRyTDlDenZDemhPZXVSOXJ6
// TDcvRjBBVy9vUDVqZXR3dmJMNmQxOEhjck9kWE8yVmYxY2w0YzNLZjVRcnFSZzlN
// dlRxQWFsNXJCNHNpY1JaMVhpUUJjb0YwNHc9PSJ9
// -----END ENCRYPTED COSIGN PRIVATE KEY-----
testCases := map[string]struct { testCases := map[string]struct {
measurements string measurements string
measurementsStatus int measurementsStatus int
@ -237,44 +292,66 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
"simple": { "simple": {
measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n", measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n",
measurementsStatus: http.StatusOK, measurementsStatus: http.StatusOK,
signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", signature: "MEUCIQDcHS2bLls7OrLHpQKuiFGXhPrTcehPDwgVyERHl4V02wIgeIxK4J9oJpXWRBjokbog2lgifRXuJK8ljlAID26MbHk=",
signatureStatus: http.StatusOK, signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantMeasurements: M{ wantMeasurements: M{
0: PCRWithAllBytes(0x00), 0: WithAllBytes(0x00, false),
}, },
wantSHA: "4cd9d6ed8d9322150dff7738994c5e2fabff35f3bae6f5c993412d13249a5e87", wantSHA: "4cd9d6ed8d9322150dff7738994c5e2fabff35f3bae6f5c993412d13249a5e87",
}, },
"404 measurements": { "json measurements": {
measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n", measurements: `{"0":{"expected":"0000000000000000000000000000000000000000000000000000000000000000","warnOnly":false}}`,
measurementsStatus: http.StatusNotFound, measurementsStatus: http.StatusOK,
signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", signature: "MEUCIQDh3nCgrdTiYWiV4NkiaZ6vxovj79Pk8V90mdWAnmCEOwIgMAVWAx5dW0saut+8X15SgtBEiKqEixYiSICSqqhxUMg=",
signatureStatus: http.StatusOK, signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantMeasurements: M{
0: WithAllBytes(0x00, false),
},
wantSHA: "1da09758c89537946496358f80b892e508563fcbbc695c90b6c16bf158e69c11",
},
"yaml measurements": {
measurements: "0:\n expected: \"0000000000000000000000000000000000000000000000000000000000000000\"\n warnOnly: false\n",
measurementsStatus: http.StatusOK,
signature: "MEUCIFzQdwBS92aJjY0bcIag1uQRl42lUSBmmjEvO0tM/N0ZAiEAvuWaP744qYMw5uEmc7BY4mm4Ij3TEqAWFgxNhFkckp4=",
signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantMeasurements: M{
0: WithAllBytes(0x00, false),
},
wantSHA: "c651cd419fd536c63cfc5349ad44da140a09987465e31192660059d383413807",
},
"404 measurements": {
measurements: `{"0":{"expected":"0000000000000000000000000000000000000000000000000000000000000000","warnOnly":false}}`,
measurementsStatus: http.StatusNotFound,
signature: "MEUCIQDh3nCgrdTiYWiV4NkiaZ6vxovj79Pk8V90mdWAnmCEOwIgMAVWAx5dW0saut+8X15SgtBEiKqEixYiSICSqqhxUMg=",
signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantError: true, wantError: true,
}, },
"404 signature": { "404 signature": {
measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n", measurements: `{"0":{"expected":"0000000000000000000000000000000000000000000000000000000000000000","warnOnly":false}}`,
measurementsStatus: http.StatusOK, measurementsStatus: http.StatusOK,
signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", signature: "MEUCIQDh3nCgrdTiYWiV4NkiaZ6vxovj79Pk8V90mdWAnmCEOwIgMAVWAx5dW0saut+8X15SgtBEiKqEixYiSICSqqhxUMg=",
signatureStatus: http.StatusNotFound, signatureStatus: http.StatusNotFound,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantError: true, wantError: true,
}, },
"broken signature": { "broken signature": {
measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n", measurements: `{"0":{"expected":"0000000000000000000000000000000000000000000000000000000000000000","warnOnly":false}}`,
measurementsStatus: http.StatusOK, measurementsStatus: http.StatusOK,
signature: "AAAAAAs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", signature: "AAAAAAAA3nCgrdTiYWiV4NkiaZ6vxovj79Pk8V90mdWAnmCEOwIgMAVWAx5dW0saut+8X15SgtBEiKqEixYiSICSqqhxUMg=",
signatureStatus: http.StatusOK, signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantError: true, wantError: true,
}, },
"not yaml": { "not yaml or json": {
measurements: "This is some content to be signed!\n", measurements: "This is some content to be signed!\n",
measurementsStatus: http.StatusOK, measurementsStatus: http.StatusOK,
signature: "MEUCIQDzMN3yaiO9sxLGAaSA9YD8rLwzvOaZKWa/bzkcjImUFAIgXLLGzClYUd1dGbuEiY3O/g/eiwQYlyxqLQalxjFmz+8=", signature: "MEUCIQCGA/lSu5qCJgNNvgMaTKJ9rj6vQMecUDaQo3ukaiAfUgIgWoxXRoDKLY9naN7YgxokM7r2fwnyYk3M2WKJJO1g6yo=",
signatureStatus: http.StatusOK, signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAElWUhon39eAqzEC+/GP03oY4/MQg+\ngCDlEzkuOCybCHf+q766bve799L7Y5y5oRsHY1MrUCUwYF/tL7Sg7EYMsA==\n-----END PUBLIC KEY-----"), publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantError: true, wantError: true,
}, },
} }
@ -322,30 +399,196 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
} }
} }
func TestPCRWithAllBytes(t *testing.T) { func TestGetEnforced(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
b byte input M
wantPCR []byte want map[uint32]struct{}
}{ }{
"0x00": { "only warnings": {
b: 0x00, input: M{
wantPCR: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 0: WithAllBytes(0x00, true),
1: WithAllBytes(0x01, true),
},
want: map[uint32]struct{}{},
}, },
"0x01": { "all enforced": {
b: 0x01, input: M{
wantPCR: []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}, 0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, false),
},
want: map[uint32]struct{}{
0: {},
1: {},
},
}, },
"0xFF": { "mixed": {
b: 0xFF, input: M{
wantPCR: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, 0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, true),
2: WithAllBytes(0x02, false),
},
want: map[uint32]struct{}{
0: {},
2: {},
},
}, },
} }
for name, tc := range testCases { for name, tc := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
pcr := PCRWithAllBytes(tc.b)
assert.Equal(tc.wantPCR, pcr) got := tc.input.GetEnforced()
enforced := map[uint32]struct{}{}
for _, id := range got {
enforced[id] = struct{}{}
}
assert.Equal(tc.want, enforced)
})
}
}
func TestSetEnforced(t *testing.T) {
testCases := map[string]struct {
input M
enforced []uint32
wantM M
wantErr bool
}{
"no enforced measurements": {
input: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, false),
},
enforced: []uint32{},
wantM: M{
0: WithAllBytes(0x00, true),
1: WithAllBytes(0x01, true),
},
},
"all enforced measurements": {
input: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, false),
},
enforced: []uint32{0, 1},
wantM: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, false),
},
},
"mixed": {
input: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, false),
2: WithAllBytes(0x02, false),
3: WithAllBytes(0x03, false),
},
enforced: []uint32{0, 2},
wantM: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, true),
2: WithAllBytes(0x02, false),
3: WithAllBytes(0x03, true),
},
},
"warn only to enforced": {
input: M{
0: WithAllBytes(0x00, true),
1: WithAllBytes(0x01, true),
},
enforced: []uint32{0, 1},
wantM: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, false),
},
},
"more enforced than measurements": {
input: M{
0: WithAllBytes(0x00, true),
1: WithAllBytes(0x01, true),
},
enforced: []uint32{0, 1, 2},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := tc.input.SetEnforced(tc.enforced)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.True(tc.input.EqualTo(tc.wantM))
})
}
}
func TestWithAllBytes(t *testing.T) {
testCases := map[string]struct {
b byte
warnOnly bool
wantMeasurement Measurement
}{
"0x00 warnOnly": {
b: 0x00,
warnOnly: true,
wantMeasurement: Measurement{
Expected: [32]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
WarnOnly: true,
},
},
"0x00": {
b: 0x00,
warnOnly: false,
wantMeasurement: Measurement{
Expected: [32]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
WarnOnly: false,
},
},
"0x01 warnOnly": {
b: 0x01,
warnOnly: true,
wantMeasurement: Measurement{
Expected: [32]byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
WarnOnly: true,
},
},
"0x01": {
b: 0x01,
warnOnly: false,
wantMeasurement: Measurement{
Expected: [32]byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
WarnOnly: false,
},
},
"0xFF warnOnly": {
b: 0xFF,
warnOnly: true,
wantMeasurement: Measurement{
Expected: [32]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
WarnOnly: true,
},
},
"0xFF": {
b: 0xFF,
warnOnly: false,
wantMeasurement: Measurement{
Expected: [32]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
WarnOnly: false,
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
measurement := WithAllBytes(tc.b, tc.warnOnly)
assert.Equal(tc.wantMeasurement, measurement)
}) })
} }
} }
@ -358,33 +601,44 @@ func TestEqualTo(t *testing.T) {
}{ }{
"same values": { "same values": {
given: M{ given: M{
0: PCRWithAllBytes(0x00), 0: WithAllBytes(0x00, false),
1: PCRWithAllBytes(0xFF), 1: WithAllBytes(0xFF, false),
}, },
other: M{ other: M{
0: PCRWithAllBytes(0x00), 0: WithAllBytes(0x00, false),
1: PCRWithAllBytes(0xFF), 1: WithAllBytes(0xFF, false),
}, },
wantEqual: true, wantEqual: true,
}, },
"different number of elements": { "different number of elements": {
given: M{ given: M{
0: PCRWithAllBytes(0x00), 0: WithAllBytes(0x00, false),
1: PCRWithAllBytes(0xFF), 1: WithAllBytes(0xFF, false),
}, },
other: M{ other: M{
0: PCRWithAllBytes(0x00), 0: WithAllBytes(0x00, false),
}, },
wantEqual: false, wantEqual: false,
}, },
"different values": { "different values": {
given: M{ given: M{
0: PCRWithAllBytes(0x00), 0: WithAllBytes(0x00, false),
1: PCRWithAllBytes(0xFF), 1: WithAllBytes(0xFF, false),
}, },
other: M{ other: M{
0: PCRWithAllBytes(0xFF), 0: WithAllBytes(0xFF, false),
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, false),
},
wantEqual: false,
},
"different warn settings": {
given: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0xFF, false),
},
other: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0xFF, true),
}, },
wantEqual: false, wantEqual: false,
}, },

View file

@ -22,11 +22,10 @@ type Validator struct {
} }
// NewValidator initializes a new QEMU validator with the provided PCR values. // NewValidator initializes a new QEMU validator with the provided PCR values.
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, log vtpm.AttestationLogger) *Validator {
return &Validator{ return &Validator{
Validator: vtpm.NewValidator( Validator: vtpm.NewValidator(
pcrs, pcrs,
enforcedPCRs,
unconditionalTrust, unconditionalTrust,
func(attestation vtpm.AttestationDocument) error { return nil }, func(attestation vtpm.AttestationDocument) error { return nil },
vtpm.VerifyPKCS1v15, vtpm.VerifyPKCS1v15,

View file

@ -16,6 +16,7 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
tpmClient "github.com/google/go-tpm-tools/client" tpmClient "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/proto/attest" "github.com/google/go-tpm-tools/proto/attest"
tpmProto "github.com/google/go-tpm-tools/proto/tpm" tpmProto "github.com/google/go-tpm-tools/proto/tpm"
@ -144,8 +145,7 @@ func (i *Issuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
// Validator handles validation of TPM based attestation. // Validator handles validation of TPM based attestation.
type Validator struct { type Validator struct {
expectedPCRs map[uint32][]byte expected measurements.M
enforcedPCRs map[uint32]struct{}
getTrustedKey GetTPMTrustedAttestationPublicKey getTrustedKey GetTPMTrustedAttestationPublicKey
validateCVM ValidateCVM validateCVM ValidateCVM
verifyUserData VerifyUserData verifyUserData VerifyUserData
@ -154,18 +154,11 @@ type Validator struct {
} }
// NewValidator returns a new Validator. // NewValidator returns a new Validator.
func NewValidator(expectedPCRs map[uint32][]byte, enforcedPCRs []uint32, getTrustedKey GetTPMTrustedAttestationPublicKey, func NewValidator(expected measurements.M, getTrustedKey GetTPMTrustedAttestationPublicKey,
validateCVM ValidateCVM, verifyUserData VerifyUserData, log AttestationLogger, validateCVM ValidateCVM, verifyUserData VerifyUserData, log AttestationLogger,
) *Validator { ) *Validator {
// Convert the enforced PCR list to a map for convenient and fast lookup
enforcedMap := make(map[uint32]struct{})
for _, pcr := range enforcedPCRs {
enforcedMap[pcr] = struct{}{}
}
return &Validator{ return &Validator{
expectedPCRs: expectedPCRs, expected: expected,
enforcedPCRs: enforcedMap,
getTrustedKey: getTrustedKey, getTrustedKey: getTrustedKey,
validateCVM: validateCVM, validateCVM: validateCVM,
verifyUserData: verifyUserData, verifyUserData: verifyUserData,
@ -212,9 +205,9 @@ func (v *Validator) Validate(attDocRaw []byte, nonce []byte) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
for idx, pcr := range v.expectedPCRs { for idx, pcr := range v.expected {
if !bytes.Equal(pcr, attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs[idx]) { if !bytes.Equal(pcr.Expected[:], attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs[idx]) {
if _, ok := v.enforcedPCRs[idx]; ok { if !pcr.WarnOnly {
return nil, fmt.Errorf("untrusted PCR value at PCR index %d", idx) return nil, fmt.Errorf("untrusted PCR value at PCR index %d", idx)
} }
if v.log != nil { if v.log != nil {
@ -263,8 +256,8 @@ func VerifyPKCS1v15(pub crypto.PublicKey, hash crypto.Hash, hashed, sig []byte)
return rsa.VerifyPKCS1v15(key, hash, hashed, sig) return rsa.VerifyPKCS1v15(key, hash, hashed, sig)
} }
// GetSelectedPCRs returns a map of the selected PCR hashes. // GetSelectedMeasurements returns a map of Measurments for the PCRs in selection.
func GetSelectedPCRs(open TPMOpenFunc, selection tpm2.PCRSelection) (map[uint32][]byte, error) { func GetSelectedMeasurements(open TPMOpenFunc, selection tpm2.PCRSelection) (measurements.M, error) {
tpm, err := open() tpm, err := open()
if err != nil { if err != nil {
return nil, err return nil, err
@ -276,5 +269,15 @@ func GetSelectedPCRs(open TPMOpenFunc, selection tpm2.PCRSelection) (map[uint32]
return nil, err return nil, err
} }
return pcrList.Pcrs, nil m := make(measurements.M)
for i, pcr := range pcrList.Pcrs {
if len(pcr) != 32 {
return nil, fmt.Errorf("invalid measurement: invalid length: %d", len(pcr))
}
m[i] = measurements.Measurement{
Expected: *(*[32]byte)(pcr),
}
}
return m, nil
} }

View file

@ -14,6 +14,7 @@ import (
"io" "io"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
tpmsim "github.com/edgelesssys/constellation/v2/internal/attestation/simulator" tpmsim "github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
tpmclient "github.com/google/go-tpm-tools/client" tpmclient "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/proto/attest" "github.com/google/go-tpm-tools/proto/attest"
@ -64,14 +65,14 @@ func TestValidate(t *testing.T) {
return pubArea.Key() return pubArea.Key()
} }
testExpectedPCRs := map[uint32][]byte{ testExpectedPCRs := measurements.M{
0: {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 0: measurements.WithAllBytes(0x00, true),
1: {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 1: measurements.WithAllBytes(0x00, true),
} }
warnLog := &testAttestationLogger{} warnLog := &testAttestationLogger{}
issuer := NewIssuer(newSimTPMWithEventLog, tpmclient.AttestationKeyRSA, fakeGetInstanceInfo) issuer := NewIssuer(newSimTPMWithEventLog, tpmclient.AttestationKeyRSA, fakeGetInstanceInfo)
validator := NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog) validator := NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog)
nonce := []byte{1, 2, 3, 4} nonce := []byte{1, 2, 3, 4}
challenge := []byte("Constellation") challenge := []byte("Constellation")
@ -89,18 +90,28 @@ func TestValidate(t *testing.T) {
require.NoError(err) require.NoError(err)
require.Equal(challenge, out) require.Equal(challenge, out)
enforcedPCRs := []uint32{0, 1} expectedPCRs := measurements.M{
expectedPCRs := map[uint32][]byte{ 0: measurements.WithAllBytes(0x00, true),
0: {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 1: measurements.WithAllBytes(0x00, true),
1: {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 2: measurements.Measurement{
2: {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20}, Expected: [32]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20},
3: {0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40}, WarnOnly: true,
4: {0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, 0x60}, },
5: {0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80}, 3: measurements.Measurement{
Expected: [32]byte{0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40},
WarnOnly: true,
},
4: measurements.Measurement{
Expected: [32]byte{0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, 0x60},
WarnOnly: true,
},
5: measurements.Measurement{
Expected: [32]byte{0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80},
WarnOnly: true,
},
} }
warningValidator := NewValidator( warningValidator := NewValidator(
expectedPCRs, expectedPCRs,
enforcedPCRs,
fakeGetTrustedKey, fakeGetTrustedKey,
fakeValidateCVM, fakeValidateCVM,
VerifyPKCS1v15, VerifyPKCS1v15,
@ -109,7 +120,7 @@ func TestValidate(t *testing.T) {
out, err = warningValidator.Validate(attDocRaw, nonce) out, err = warningValidator.Validate(attDocRaw, nonce)
require.NoError(err) require.NoError(err)
assert.Equal(t, challenge, out) assert.Equal(t, challenge, out)
assert.Len(t, warnLog.warnings, len(expectedPCRs)-len(enforcedPCRs)) assert.Len(t, warnLog.warnings, 4)
testCases := map[string]struct { testCases := map[string]struct {
validator *Validator validator *Validator
@ -118,13 +129,13 @@ func TestValidate(t *testing.T) {
wantErr bool wantErr bool
}{ }{
"invalid nonce": { "invalid nonce": {
validator: NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog),
attDoc: mustMarshalAttestation(attDoc, require), attDoc: mustMarshalAttestation(attDoc, require),
nonce: []byte{4, 3, 2, 1}, nonce: []byte{4, 3, 2, 1},
wantErr: true, wantErr: true,
}, },
"invalid signature": { "invalid signature": {
validator: NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog),
attDoc: mustMarshalAttestation(AttestationDocument{ attDoc: mustMarshalAttestation(AttestationDocument{
Attestation: attDoc.Attestation, Attestation: attDoc.Attestation,
InstanceInfo: attDoc.InstanceInfo, InstanceInfo: attDoc.InstanceInfo,
@ -137,7 +148,6 @@ func TestValidate(t *testing.T) {
"untrusted attestation public key": { "untrusted attestation public key": {
validator: NewValidator( validator: NewValidator(
testExpectedPCRs, testExpectedPCRs,
[]uint32{0, 1},
func(akPub, instanceInfo []byte) (crypto.PublicKey, error) { func(akPub, instanceInfo []byte) (crypto.PublicKey, error) {
return nil, errors.New("untrusted") return nil, errors.New("untrusted")
}, },
@ -149,7 +159,6 @@ func TestValidate(t *testing.T) {
"not a CVM": { "not a CVM": {
validator: NewValidator( validator: NewValidator(
testExpectedPCRs, testExpectedPCRs,
[]uint32{0, 1},
fakeGetTrustedKey, fakeGetTrustedKey,
func(attestation AttestationDocument) error { func(attestation AttestationDocument) error {
return errors.New("untrusted") return errors.New("untrusted")
@ -161,10 +170,12 @@ func TestValidate(t *testing.T) {
}, },
"untrusted PCRs": { "untrusted PCRs": {
validator: NewValidator( validator: NewValidator(
map[uint32][]byte{ measurements.M{
0: {0xFF}, 0: measurements.Measurement{
Expected: [32]byte{0xFF},
WarnOnly: false,
},
}, },
[]uint32{0},
fakeGetTrustedKey, fakeGetTrustedKey,
fakeValidateCVM, fakeValidateCVM,
VerifyPKCS1v15, warnLog), VerifyPKCS1v15, warnLog),
@ -173,7 +184,7 @@ func TestValidate(t *testing.T) {
wantErr: true, wantErr: true,
}, },
"no sha256 quote": { "no sha256 quote": {
validator: NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog),
attDoc: mustMarshalAttestation(AttestationDocument{ attDoc: mustMarshalAttestation(AttestationDocument{
Attestation: &attest.Attestation{ Attestation: &attest.Attestation{
AkPub: attDoc.Attestation.AkPub, AkPub: attDoc.Attestation.AkPub,
@ -191,7 +202,7 @@ func TestValidate(t *testing.T) {
wantErr: true, wantErr: true,
}, },
"invalid attestation document": { "invalid attestation document": {
validator: NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog),
attDoc: []byte("invalid attestation"), attDoc: []byte("invalid attestation"),
nonce: nonce, nonce: nonce,
wantErr: true, wantErr: true,
@ -350,7 +361,7 @@ func TestGetSHA256QuoteIndex(t *testing.T) {
} }
} }
func TestGetSelectedPCRs(t *testing.T) { func TestGetSelectedMeasurements(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
openFunc TPMOpenFunc openFunc TPMOpenFunc
pcrSelection tpm2.PCRSelection pcrSelection tpm2.PCRSelection
@ -386,17 +397,13 @@ func TestGetSelectedPCRs(t *testing.T) {
require := require.New(t) require := require.New(t)
assert := assert.New(t) assert := assert.New(t)
pcrs, err := GetSelectedPCRs(tc.openFunc, tc.pcrSelection) pcrs, err := GetSelectedMeasurements(tc.openFunc, tc.pcrSelection)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
} else { return
require.NoError(err)
assert.Equal(len(pcrs), len(tc.pcrSelection.PCRs))
for _, pcr := range pcrs {
assert.Len(pcr, 32)
}
} }
require.NoError(err)
assert.Len(pcrs, len(tc.pcrSelection.PCRs))
}) })
} }
} }

View file

@ -9,18 +9,8 @@ package vtpm
import ( import (
"errors" "errors"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2"
"github.com/google/go-tpm/tpmutil"
)
const (
// PCRIndexClusterID is a PCR we extend to mark the node as initialized.
// The value used to extend is a random generated 32 Byte value.
PCRIndexClusterID = tpmutil.Handle(15)
// PCRIndexOwnerID is a PCR we extend to mark the node as initialized.
// The value used to extend is derived from Constellation's master key.
// TODO: move to stable, non-debug PCR before use.
PCRIndexOwnerID = tpmutil.Handle(16)
) )
// MarkNodeAsBootstrapped marks a node as initialized by extending PCRs. // MarkNodeAsBootstrapped marks a node as initialized by extending PCRs.
@ -32,7 +22,7 @@ func MarkNodeAsBootstrapped(openTPM TPMOpenFunc, clusterID []byte) error {
defer tpm.Close() defer tpm.Close()
// clusterID is used to uniquely identify this running instance of Constellation // clusterID is used to uniquely identify this running instance of Constellation
return tpm2.PCREvent(tpm, PCRIndexClusterID, clusterID) return tpm2.PCREvent(tpm, measurements.PCRIndexClusterID, clusterID)
} }
// IsNodeBootstrapped checks if a node is already bootstrapped by reading PCRs. // IsNodeBootstrapped checks if a node is already bootstrapped by reading PCRs.
@ -43,7 +33,7 @@ func IsNodeBootstrapped(openTPM TPMOpenFunc) (bool, error) {
} }
defer tpm.Close() defer tpm.Close()
idxClusterID := int(PCRIndexClusterID) idxClusterID := int(measurements.PCRIndexClusterID)
pcrs, err := tpm2.ReadPCRs(tpm, tpm2.PCRSelection{ pcrs, err := tpm2.ReadPCRs(tpm, tpm2.PCRSelection{
Hash: tpm2.AlgSHA256, Hash: tpm2.AlgSHA256,
PCRs: []int{idxClusterID}, PCRs: []int{idxClusterID},

View file

@ -11,6 +11,7 @@ import (
"io" "io"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/simulator" "github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
"github.com/google/go-tpm-tools/client" "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2"
@ -45,7 +46,7 @@ func TestMarkNodeAsBootstrapped(t *testing.T) {
require.NoError(err) require.NoError(err)
for i := range pcrs { for i := range pcrs {
assert.NotEqual(pcrs[i].Pcrs[uint32(PCRIndexClusterID)], pcrsInitialized[i].Pcrs[uint32(PCRIndexClusterID)]) assert.NotEqual(pcrs[i].Pcrs[uint32(measurements.PCRIndexClusterID)], pcrsInitialized[i].Pcrs[uint32(measurements.PCRIndexClusterID)])
} }
} }
@ -76,7 +77,7 @@ func TestIsNodeInitialized(t *testing.T) {
require.NoError(err) require.NoError(err)
defer tpm.Close() defer tpm.Close()
if tc.pcrValueClusterID != nil { if tc.pcrValueClusterID != nil {
require.NoError(tpm2.PCREvent(tpm, PCRIndexClusterID, tc.pcrValueClusterID)) require.NoError(tpm2.PCREvent(tpm, measurements.PCRIndexClusterID, tc.pcrValueClusterID))
} }
initialized, err := IsNodeBootstrapped(func() (io.ReadWriteCloser, error) { initialized, err := IsNodeBootstrapped(func() (io.ReadWriteCloser, error) {
return &simTPMNOPCloser{tpm}, nil return &simTPMNOPCloser{tpm}, nil

View file

@ -138,10 +138,7 @@ type AWSConfig struct {
IAMProfileWorkerNodes string `yaml:"iamProfileWorkerNodes" validate:"required"` IAMProfileWorkerNodes string `yaml:"iamProfileWorkerNodes" validate:"required"`
// description: | // description: |
// Expected VM measurements. // Expected VM measurements.
Measurements Measurements `yaml:"measurements"` Measurements Measurements `yaml:"measurements" validate:"required,no_placeholders"`
// description: |
// List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning.
EnforcedMeasurements []uint32 `yaml:"enforcedMeasurements"`
} }
// AzureConfig are Azure specific configuration values used by the CLI. // AzureConfig are Azure specific configuration values used by the CLI.
@ -190,10 +187,7 @@ type AzureConfig struct {
EnforceIDKeyDigest *bool `yaml:"enforceIdKeyDigest" validate:"required"` EnforceIDKeyDigest *bool `yaml:"enforceIdKeyDigest" validate:"required"`
// description: | // description: |
// Expected confidential VM measurements. // Expected confidential VM measurements.
Measurements Measurements `yaml:"measurements"` Measurements Measurements `yaml:"measurements" validate:"required,no_placeholders"`
// description: |
// List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning.
EnforcedMeasurements []uint32 `yaml:"enforcedMeasurements"`
} }
// GCPConfig are GCP specific configuration values used by the CLI. // GCPConfig are GCP specific configuration values used by the CLI.
@ -221,10 +215,7 @@ type GCPConfig struct {
DeployCSIDriver *bool `yaml:"deployCSIDriver" validate:"required"` DeployCSIDriver *bool `yaml:"deployCSIDriver" validate:"required"`
// description: | // description: |
// Expected confidential VM measurements. // Expected confidential VM measurements.
Measurements Measurements `yaml:"measurements"` Measurements Measurements `yaml:"measurements" validate:"required,no_placeholders"`
// description: |
// List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning.
EnforcedMeasurements []uint32 `yaml:"enforcedMeasurements"`
} }
// QEMUConfig holds config information for QEMU based Constellation deployments. // QEMUConfig holds config information for QEMU based Constellation deployments.
@ -255,10 +246,7 @@ type QEMUConfig struct {
Firmware string `yaml:"firmware"` Firmware string `yaml:"firmware"`
// description: | // description: |
// Measurement used to enable measured boot. // Measurement used to enable measured boot.
Measurements Measurements `yaml:"measurements"` Measurements Measurements `yaml:"measurements" validate:"required,no_placeholders"`
// description: |
// List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning.
EnforcedMeasurements []uint32 `yaml:"enforcedMeasurements"`
} }
// Default returns a struct with the default config. // Default returns a struct with the default config.
@ -276,7 +264,6 @@ func Default() *Config {
IAMProfileControlPlane: "", IAMProfileControlPlane: "",
IAMProfileWorkerNodes: "", IAMProfileWorkerNodes: "",
Measurements: measurements.DefaultsFor(cloudprovider.AWS), Measurements: measurements.DefaultsFor(cloudprovider.AWS),
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
}, },
Azure: &AzureConfig{ Azure: &AzureConfig{
SubscriptionID: "", SubscriptionID: "",
@ -292,7 +279,6 @@ func Default() *Config {
ConfidentialVM: func() *bool { b := true; return &b }(), ConfidentialVM: func() *bool { b := true; return &b }(),
SecureBoot: func() *bool { b := false; return &b }(), SecureBoot: func() *bool { b := false; return &b }(),
Measurements: measurements.DefaultsFor(cloudprovider.Azure), Measurements: measurements.DefaultsFor(cloudprovider.Azure),
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
}, },
GCP: &GCPConfig{ GCP: &GCPConfig{
Project: "", Project: "",
@ -303,7 +289,6 @@ func Default() *Config {
StateDiskType: "pd-ssd", StateDiskType: "pd-ssd",
DeployCSIDriver: func() *bool { b := true; return &b }(), DeployCSIDriver: func() *bool { b := true; return &b }(),
Measurements: measurements.DefaultsFor(cloudprovider.GCP), Measurements: measurements.DefaultsFor(cloudprovider.GCP),
EnforcedMeasurements: []uint32{0, 4, 8, 9, 11, 12, 13, 15},
}, },
QEMU: &QEMUConfig{ QEMU: &QEMUConfig{
ImageFormat: "raw", ImageFormat: "raw",
@ -314,7 +299,6 @@ func Default() *Config {
LibvirtContainerImage: versions.LibvirtImage, LibvirtContainerImage: versions.LibvirtImage,
NVRAM: "production", NVRAM: "production",
Measurements: measurements.DefaultsFor(cloudprovider.QEMU), Measurements: measurements.DefaultsFor(cloudprovider.QEMU),
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
}, },
}, },
KubernetesVersion: string(versions.Default), KubernetesVersion: string(versions.Default),
@ -446,18 +430,18 @@ func (c *Config) EnforcesIDKeyDigest() bool {
return c.Provider.Azure != nil && c.Provider.Azure.EnforceIDKeyDigest != nil && *c.Provider.Azure.EnforceIDKeyDigest return c.Provider.Azure != nil && c.Provider.Azure.EnforceIDKeyDigest != nil && *c.Provider.Azure.EnforceIDKeyDigest
} }
// GetEnforcedPCRs returns the list of enforced PCRs for the configured cloud provider. // EnforcedPCRs returns the list of enforced PCRs for the configured cloud provider.
func (c *Config) GetEnforcedPCRs() []uint32 { func (c *Config) EnforcedPCRs() []uint32 {
provider := c.GetProvider() provider := c.GetProvider()
switch provider { switch provider {
case cloudprovider.AWS: case cloudprovider.AWS:
return c.Provider.AWS.EnforcedMeasurements return c.Provider.AWS.Measurements.GetEnforced()
case cloudprovider.Azure: case cloudprovider.Azure:
return c.Provider.Azure.EnforcedMeasurements return c.Provider.Azure.Measurements.GetEnforced()
case cloudprovider.GCP: case cloudprovider.GCP:
return c.Provider.GCP.EnforcedMeasurements return c.Provider.GCP.Measurements.GetEnforced()
case cloudprovider.QEMU: case cloudprovider.QEMU:
return c.Provider.QEMU.EnforcedMeasurements return c.Provider.QEMU.Measurements.GetEnforced()
default: default:
return nil return nil
} }
@ -499,6 +483,14 @@ func (c *Config) Validate() error {
return err return err
} }
if err := validate.RegisterTranslation("no_placeholders", trans, registerContainsPlaceholderError, translateContainsPlaceholderError); err != nil {
return err
}
if err := validate.RegisterValidation("no_placeholders", validateNoPlaceholder); err != nil {
return err
}
if err := validate.RegisterValidation("safe_image", validateImage); err != nil { if err := validate.RegisterValidation("safe_image", validateImage); err != nil {
return err return err
} }

View file

@ -157,7 +157,7 @@ func init() {
FieldName: "aws", FieldName: "aws",
}, },
} }
AWSConfigDoc.Fields = make([]encoder.Doc, 8) AWSConfigDoc.Fields = make([]encoder.Doc, 7)
AWSConfigDoc.Fields[0].Name = "region" AWSConfigDoc.Fields[0].Name = "region"
AWSConfigDoc.Fields[0].Type = "string" AWSConfigDoc.Fields[0].Type = "string"
AWSConfigDoc.Fields[0].Note = "" AWSConfigDoc.Fields[0].Note = ""
@ -193,11 +193,6 @@ func init() {
AWSConfigDoc.Fields[6].Note = "" AWSConfigDoc.Fields[6].Note = ""
AWSConfigDoc.Fields[6].Description = "Expected VM measurements." AWSConfigDoc.Fields[6].Description = "Expected VM measurements."
AWSConfigDoc.Fields[6].Comments[encoder.LineComment] = "Expected VM measurements." AWSConfigDoc.Fields[6].Comments[encoder.LineComment] = "Expected VM measurements."
AWSConfigDoc.Fields[7].Name = "enforcedMeasurements"
AWSConfigDoc.Fields[7].Type = "[]uint32"
AWSConfigDoc.Fields[7].Note = ""
AWSConfigDoc.Fields[7].Description = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
AWSConfigDoc.Fields[7].Comments[encoder.LineComment] = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
AzureConfigDoc.Type = "AzureConfig" AzureConfigDoc.Type = "AzureConfig"
AzureConfigDoc.Comments[encoder.LineComment] = "AzureConfig are Azure specific configuration values used by the CLI." AzureConfigDoc.Comments[encoder.LineComment] = "AzureConfig are Azure specific configuration values used by the CLI."
@ -208,7 +203,7 @@ func init() {
FieldName: "azure", FieldName: "azure",
}, },
} }
AzureConfigDoc.Fields = make([]encoder.Doc, 16) AzureConfigDoc.Fields = make([]encoder.Doc, 15)
AzureConfigDoc.Fields[0].Name = "subscription" AzureConfigDoc.Fields[0].Name = "subscription"
AzureConfigDoc.Fields[0].Type = "string" AzureConfigDoc.Fields[0].Type = "string"
AzureConfigDoc.Fields[0].Note = "" AzureConfigDoc.Fields[0].Note = ""
@ -284,11 +279,6 @@ func init() {
AzureConfigDoc.Fields[14].Note = "" AzureConfigDoc.Fields[14].Note = ""
AzureConfigDoc.Fields[14].Description = "Expected confidential VM measurements." AzureConfigDoc.Fields[14].Description = "Expected confidential VM measurements."
AzureConfigDoc.Fields[14].Comments[encoder.LineComment] = "Expected confidential VM measurements." AzureConfigDoc.Fields[14].Comments[encoder.LineComment] = "Expected confidential VM measurements."
AzureConfigDoc.Fields[15].Name = "enforcedMeasurements"
AzureConfigDoc.Fields[15].Type = "[]uint32"
AzureConfigDoc.Fields[15].Note = ""
AzureConfigDoc.Fields[15].Description = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
AzureConfigDoc.Fields[15].Comments[encoder.LineComment] = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
GCPConfigDoc.Type = "GCPConfig" GCPConfigDoc.Type = "GCPConfig"
GCPConfigDoc.Comments[encoder.LineComment] = "GCPConfig are GCP specific configuration values used by the CLI." GCPConfigDoc.Comments[encoder.LineComment] = "GCPConfig are GCP specific configuration values used by the CLI."
@ -299,7 +289,7 @@ func init() {
FieldName: "gcp", FieldName: "gcp",
}, },
} }
GCPConfigDoc.Fields = make([]encoder.Doc, 9) GCPConfigDoc.Fields = make([]encoder.Doc, 8)
GCPConfigDoc.Fields[0].Name = "project" GCPConfigDoc.Fields[0].Name = "project"
GCPConfigDoc.Fields[0].Type = "string" GCPConfigDoc.Fields[0].Type = "string"
GCPConfigDoc.Fields[0].Note = "" GCPConfigDoc.Fields[0].Note = ""
@ -340,11 +330,6 @@ func init() {
GCPConfigDoc.Fields[7].Note = "" GCPConfigDoc.Fields[7].Note = ""
GCPConfigDoc.Fields[7].Description = "Expected confidential VM measurements." GCPConfigDoc.Fields[7].Description = "Expected confidential VM measurements."
GCPConfigDoc.Fields[7].Comments[encoder.LineComment] = "Expected confidential VM measurements." GCPConfigDoc.Fields[7].Comments[encoder.LineComment] = "Expected confidential VM measurements."
GCPConfigDoc.Fields[8].Name = "enforcedMeasurements"
GCPConfigDoc.Fields[8].Type = "[]uint32"
GCPConfigDoc.Fields[8].Note = ""
GCPConfigDoc.Fields[8].Description = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
GCPConfigDoc.Fields[8].Comments[encoder.LineComment] = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
QEMUConfigDoc.Type = "QEMUConfig" QEMUConfigDoc.Type = "QEMUConfig"
QEMUConfigDoc.Comments[encoder.LineComment] = "QEMUConfig holds config information for QEMU based Constellation deployments." QEMUConfigDoc.Comments[encoder.LineComment] = "QEMUConfig holds config information for QEMU based Constellation deployments."
@ -355,7 +340,7 @@ func init() {
FieldName: "qemu", FieldName: "qemu",
}, },
} }
QEMUConfigDoc.Fields = make([]encoder.Doc, 10) QEMUConfigDoc.Fields = make([]encoder.Doc, 9)
QEMUConfigDoc.Fields[0].Name = "imageFormat" QEMUConfigDoc.Fields[0].Name = "imageFormat"
QEMUConfigDoc.Fields[0].Type = "string" QEMUConfigDoc.Fields[0].Type = "string"
QEMUConfigDoc.Fields[0].Note = "" QEMUConfigDoc.Fields[0].Note = ""
@ -401,11 +386,6 @@ func init() {
QEMUConfigDoc.Fields[8].Note = "" QEMUConfigDoc.Fields[8].Note = ""
QEMUConfigDoc.Fields[8].Description = "Measurement used to enable measured boot." QEMUConfigDoc.Fields[8].Description = "Measurement used to enable measured boot."
QEMUConfigDoc.Fields[8].Comments[encoder.LineComment] = "Measurement used to enable measured boot." QEMUConfigDoc.Fields[8].Comments[encoder.LineComment] = "Measurement used to enable measured boot."
QEMUConfigDoc.Fields[9].Name = "enforcedMeasurements"
QEMUConfigDoc.Fields[9].Type = "[]uint32"
QEMUConfigDoc.Fields[9].Note = ""
QEMUConfigDoc.Fields[9].Description = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
QEMUConfigDoc.Fields[9].Comments[encoder.LineComment] = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
} }
func (_ Config) Doc() *encoder.Doc { func (_ Config) Doc() *encoder.Doc {

View file

@ -128,6 +128,7 @@ func TestNewWithDefaultOptions(t *testing.T) {
c.Provider.Azure.ResourceGroup = "test" c.Provider.Azure.ResourceGroup = "test"
c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity" c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity"
c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e" c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e"
c.Provider.Azure.Measurements = measurements.M{15: measurements.WithAllBytes(0x00, false)}
return c return c
}(), }(),
envToSet: map[string]string{ envToSet: map[string]string{
@ -147,6 +148,7 @@ func TestNewWithDefaultOptions(t *testing.T) {
c.Provider.Azure.ClientSecretValue = "other-value" // < Note secret set in config, as well. c.Provider.Azure.ClientSecretValue = "other-value" // < Note secret set in config, as well.
c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity" c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity"
c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e" c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e"
c.Provider.Azure.Measurements = measurements.M{15: measurements.WithAllBytes(0x00, false)}
return c return c
}(), }(),
envToSet: map[string]string{ envToSet: map[string]string{
@ -182,9 +184,9 @@ func TestNewWithDefaultOptions(t *testing.T) {
} }
func TestValidate(t *testing.T) { func TestValidate(t *testing.T) {
const defaultErrCount = 17 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default const defaultErrCount = 21 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default
const azErrCount = 8 const azErrCount = 9
const gcpErrCount = 5 const gcpErrCount = 6
testCases := map[string]struct { testCases := map[string]struct {
cnf *Config cnf *Config
@ -240,6 +242,7 @@ func TestValidate(t *testing.T) {
az.ClientSecretValue = "test-client-secret" az.ClientSecretValue = "test-client-secret"
cnf.Provider = ProviderConfig{} cnf.Provider = ProviderConfig{}
cnf.Provider.Azure = az cnf.Provider.Azure = az
cnf.Provider.Azure.Measurements = measurements.M{15: measurements.WithAllBytes(0x00, false)}
return cnf return cnf
}(), }(),
}, },
@ -265,6 +268,7 @@ func TestValidate(t *testing.T) {
gcp.ServiceAccountKeyPath = "test-key-path" gcp.ServiceAccountKeyPath = "test-key-path"
cnf.Provider = ProviderConfig{} cnf.Provider = ProviderConfig{}
cnf.Provider.GCP = gcp cnf.Provider.GCP = gcp
cnf.Provider.GCP.Measurements = measurements.M{15: measurements.WithAllBytes(0x00, false)}
return cnf return cnf
}(), }(),
}, },
@ -364,9 +368,9 @@ func TestConfigGeneratedDocsFresh(t *testing.T) {
func TestConfig_UpdateMeasurements(t *testing.T) { func TestConfig_UpdateMeasurements(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
newMeasurements := measurements.M{ newMeasurements := measurements.M{
1: measurements.PCRWithAllBytes(0x00), 1: measurements.WithAllBytes(0x00, false),
2: measurements.PCRWithAllBytes(0x01), 2: measurements.WithAllBytes(0x01, false),
3: measurements.PCRWithAllBytes(0x02), 3: measurements.WithAllBytes(0x02, false),
} }
{ // AWS { // AWS

View file

@ -7,9 +7,11 @@ SPDX-License-Identifier: AGPL-3.0-only
package config package config
import ( import (
"bytes"
"fmt" "fmt"
"strings" "strings"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config/instancetypes" "github.com/edgelesssys/constellation/v2/internal/config/instancetypes"
"github.com/edgelesssys/constellation/v2/internal/versions" "github.com/edgelesssys/constellation/v2/internal/versions"
@ -223,3 +225,35 @@ func (c *Config) translateAzureInstanceTypeError(ut ut.Translator, fe validator.
return t return t
} }
func validateNoPlaceholder(fl validator.FieldLevel) bool {
return len(getPlaceholderEntries(fl.Field().Interface().(Measurements))) == 0
}
func registerContainsPlaceholderError(ut ut.Translator) error {
return ut.Add("no_placeholders", "{0} placeholder values (repeated 1234...)", true)
}
func translateContainsPlaceholderError(ut ut.Translator, fe validator.FieldError) string {
placeholders := getPlaceholderEntries(fe.Value().(Measurements))
msg := fmt.Sprintf("Measurements %v contain", placeholders)
if len(placeholders) == 1 {
msg = fmt.Sprintf("Measurement %v contains", placeholders)
}
t, _ := ut.T("no_placeholders", msg)
return t
}
func getPlaceholderEntries(m Measurements) []uint32 {
var placeholders []uint32
placeholder := measurements.PlaceHolderMeasurement()
for idx, measurement := range m {
if bytes.Equal(measurement.Expected[:], placeholder.Expected[:]) {
placeholders = append(placeholders, idx)
}
}
return placeholders
}

View file

@ -9,7 +9,9 @@ package watcher
import ( import (
"encoding/asn1" "encoding/asn1"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"sync" "sync"
@ -19,6 +21,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp" "github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp"
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch" "github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch"
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp" "github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu" "github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
@ -42,26 +45,26 @@ func NewValidator(log *logger.Logger, csp string, fileHandler file.Handler, azur
var newValidator newValidatorFunc var newValidator newValidatorFunc
switch cloudprovider.FromString(csp) { switch cloudprovider.FromString(csp) {
case cloudprovider.AWS: case cloudprovider.AWS:
newValidator = func(m map[uint32][]byte, e []uint32, _ []byte, _ bool, log *logger.Logger) atls.Validator { newValidator = func(m measurements.M, _ []byte, _ bool, log *logger.Logger) atls.Validator {
return aws.NewValidator(m, e, log) return aws.NewValidator(m, log)
} }
case cloudprovider.Azure: case cloudprovider.Azure:
if azureCVM { if azureCVM {
newValidator = func(m map[uint32][]byte, e []uint32, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator { newValidator = func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator {
return snp.NewValidator(m, e, idkeydigest, enforceIdKeyDigest, log) return snp.NewValidator(m, idkeydigest, enforceIdKeyDigest, log)
} }
} else { } else {
newValidator = func(m map[uint32][]byte, e []uint32, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator { newValidator = func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator {
return trustedlaunch.NewValidator(m, e, log) return trustedlaunch.NewValidator(m, log)
} }
} }
case cloudprovider.GCP: case cloudprovider.GCP:
newValidator = func(m map[uint32][]byte, e []uint32, _ []byte, _ bool, log *logger.Logger) atls.Validator { newValidator = func(m measurements.M, _ []byte, _ bool, log *logger.Logger) atls.Validator {
return gcp.NewValidator(m, e, log) return gcp.NewValidator(m, log)
} }
case cloudprovider.QEMU: case cloudprovider.QEMU:
newValidator = func(m map[uint32][]byte, e []uint32, _ []byte, _ bool, log *logger.Logger) atls.Validator { newValidator = func(m measurements.M, _ []byte, _ bool, log *logger.Logger) atls.Validator {
return qemu.NewValidator(m, e, log) return qemu.NewValidator(m, log)
} }
default: default:
return nil, fmt.Errorf("unknown cloud service provider: %q", csp) return nil, fmt.Errorf("unknown cloud service provider: %q", csp)
@ -100,17 +103,24 @@ func (u *Updatable) Update() error {
u.log.Infof("Updating expected measurements") u.log.Infof("Updating expected measurements")
var measurements map[uint32][]byte var measurements measurements.M
if err := u.fileHandler.ReadJSON(filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), &measurements); err != nil { if err := u.fileHandler.ReadJSON(filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), &measurements); err != nil {
return err return err
} }
u.log.Debugf("New measurements: %v", measurements) u.log.Debugf("New measurements: %+v", measurements)
// handle legacy measurement format, where expected measurements and enforced measurements were stored in separate data structures
// TODO: remove with v2.4.0
var enforced []uint32 var enforced []uint32
if err := u.fileHandler.ReadJSON(filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename), &enforced); err != nil { if err := u.fileHandler.ReadJSON(filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename), &enforced); err == nil {
u.log.Debugf("Detected legacy format. Loading enforced PCRs...")
if err := measurements.SetEnforced(enforced); err != nil {
return err
}
u.log.Debugf("Merged measurements with enforced values: %+v", measurements)
} else if !errors.Is(err, os.ErrNotExist) {
return err return err
} }
u.log.Debugf("Enforced PCRs: %v", enforced)
var idkeydigest []byte var idkeydigest []byte
var enforceIDKeyDigest bool var enforceIDKeyDigest bool
@ -138,9 +148,9 @@ func (u *Updatable) Update() error {
u.log.Debugf("New idkeydigest: %x", idkeydigest) u.log.Debugf("New idkeydigest: %x", idkeydigest)
} }
u.Validator = u.newValidator(measurements, enforced, idkeydigest, enforceIDKeyDigest, u.log) u.Validator = u.newValidator(measurements, idkeydigest, enforceIDKeyDigest, u.log)
return nil return nil
} }
type newValidatorFunc func(measurements map[uint32][]byte, enforcedPCRs []uint32, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator type newValidatorFunc func(measurements measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator

View file

@ -20,6 +20,7 @@ import (
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
@ -117,7 +118,7 @@ func TestUpdate(t *testing.T) {
require := require.New(t) require := require.New(t)
oid := fakeOID{1, 3, 9900, 1} oid := fakeOID{1, 3, 9900, 1}
newValidator := func(m map[uint32][]byte, e []uint32, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator { newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
return fakeValidator{fakeOID: oid} return fakeValidator{fakeOID: oid}
} }
handler := file.NewHandler(afero.NewMemMapFs()) handler := file.NewHandler(afero.NewMemMapFs())
@ -135,14 +136,7 @@ func TestUpdate(t *testing.T) {
// write measurement config // write measurement config
require.NoError(handler.WriteJSON( require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
map[uint32][]byte{ measurements.M{11: measurements.WithAllBytes(0x00, false)},
11: {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
},
file.OptNone,
))
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename),
[]uint32{11},
)) ))
require.NoError(handler.Write( require.NoError(handler.Write(
filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename), filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename),
@ -189,6 +183,23 @@ func TestUpdate(t *testing.T) {
defer resp.Body.Close() defer resp.Body.Close()
} }
assert.Error(err) assert.Error(err)
// update should work for legacy measurement format
// TODO: remove with v2.4.0
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
map[uint32][]byte{
11: bytes.Repeat([]byte{0x0}, 32),
12: bytes.Repeat([]byte{0x1}, 32),
},
file.OptOverwrite,
))
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename),
[]uint32{11},
))
assert.NoError(validator.Update())
} }
func TestUpdateConcurrency(t *testing.T) { func TestUpdateConcurrency(t *testing.T) {
@ -199,7 +210,7 @@ func TestUpdateConcurrency(t *testing.T) {
validator := &Updatable{ validator := &Updatable{
log: logger.NewTest(t), log: logger.NewTest(t),
fileHandler: handler, fileHandler: handler,
newValidator: func(m map[uint32][]byte, e []uint32, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator { newValidator: func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}} return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}}
}, },
} }