Refactor enforced/expected PCRs (#553)

* Merge enforced and expected measurements

* Update measurement generation to new format

* Write expected measurements hex encoded by default

* Allow hex or base64 encoded expected measurements

* Allow hex or base64 encoded clusterID

* Allow security upgrades to warnOnly flag

* Upload signed measurements in JSON format

* Fetch measurements either from JSON or YAML

* Use yaml.v3 instead of yaml.v2

* Error on invalid enforced selection

* Add placeholder measurements to config

* Update e2e test to new measurement format

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-11-24 10:57:58 +01:00 committed by GitHub
parent 8ce954e012
commit f8001efbc0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 1180 additions and 801 deletions

View file

@ -75,14 +75,14 @@ runs:
(.provider | select(. | has(\"azure\")).azure.resourceGroup) = \"${{ inputs.azureResourceGroup }}\" | (.provider | select(. | has(\"azure\")).azure.resourceGroup) = \"${{ inputs.azureResourceGroup }}\" |
(.provider | select(. | has(\"azure\")).azure.appClientID) = \"${{ inputs.azureClientID }}\" | (.provider | select(. | has(\"azure\")).azure.appClientID) = \"${{ inputs.azureClientID }}\" |
(.provider | select(. | has(\"azure\")).azure.clientSecretValue) = \"${{ inputs.azureClientSecret }}\" | (.provider | select(. | has(\"azure\")).azure.clientSecretValue) = \"${{ inputs.azureClientSecret }}\" |
(.provider | select(. | has(\"azure\")).azure.enforcedMeasurements) = [15]" \ (.provider | select(. | has(\"azure\")).azure.measurements) = {15:{\"expected\":\"0000000000000000000000000000000000000000000000000000000000000000\",\"warnOnly\":false}}" \
constellation-conf.yaml constellation-conf.yaml
yq eval -i \ yq eval -i \
"(.provider | select(. | has(\"gcp\")).gcp.project) = \"${{ inputs.gcpProject }}\" | "(.provider | select(. | has(\"gcp\")).gcp.project) = \"${{ inputs.gcpProject }}\" |
(.provider | select(. | has(\"gcp\")).gcp.region) = \"europe-west3\" | (.provider | select(. | has(\"gcp\")).gcp.region) = \"europe-west3\" |
(.provider | select(. | has(\"gcp\")).gcp.zone) = \"europe-west3-b\" | (.provider | select(. | has(\"gcp\")).gcp.zone) = \"europe-west3-b\" |
(.provider | select(. | has(\"gcp\")).gcp.enforcedMeasurements) = [15] | (.provider | select(. | has(\"gcp\")).gcp.measurements) = {15:{\"expected\":\"0000000000000000000000000000000000000000000000000000000000000000\",\"warnOnly\":false}} |
(.provider | select(. | has(\"gcp\")).gcp.serviceAccountKeyPath) = \"serviceAccountKey.json\"" \ (.provider | select(. | has(\"gcp\")).gcp.serviceAccountKeyPath) = \"serviceAccountKey.json\"" \
constellation-conf.yaml constellation-conf.yaml
@ -91,7 +91,7 @@ runs:
(.provider | select(. | has(\"aws\")).aws.zone) = \"eu-central-1a\" | (.provider | select(. | has(\"aws\")).aws.zone) = \"eu-central-1a\" |
(.provider | select(. | has(\"aws\")).aws.iamProfileControlPlane) = \"e2e_test_control_plane_instance_profile\" | (.provider | select(. | has(\"aws\")).aws.iamProfileControlPlane) = \"e2e_test_control_plane_instance_profile\" |
(.provider | select(. | has(\"aws\")).aws.iamProfileWorkerNodes) = \"e2e_test_worker_node_instance_profile\" | (.provider | select(. | has(\"aws\")).aws.iamProfileWorkerNodes) = \"e2e_test_worker_node_instance_profile\" |
(.provider | select(. | has(\"aws\")).aws.enforcedMeasurements) = [15]" \ (.provider | select(. | has(\"aws\")).aws.measurements) = {15:{\"expected\":\"0000000000000000000000000000000000000000000000000000000000000000\",\"warnOnly\":false}}" \
constellation-conf.yaml constellation-conf.yaml
if [ ${{ inputs.kubernetesVersion != '' }} = true ]; then if [ ${{ inputs.kubernetesVersion != '' }} = true ]; then

View file

@ -51,16 +51,35 @@ runs:
run: | run: |
KUBECONFIG="$PWD/constellation-admin.conf" kubectl rollout status ds/verification-service -n kube-system --timeout=3m KUBECONFIG="$PWD/constellation-admin.conf" kubectl rollout status ds/verification-service -n kube-system --timeout=3m
CONSTELL_IP=$(jq -r ".ip" constellation-id.json) CONSTELL_IP=$(jq -r ".ip" constellation-id.json)
pcr-reader --constell-ip ${CONSTELL_IP} -format yaml > measurements.yaml pcr-reader --constell-ip ${CONSTELL_IP} -format json > measurements.json
case $CSP in case $CSP in
azure) azure)
yq e 'del(.[0,6,10,11,12,13,14,15,16,17,18,19,20,21,22,23])' -i measurements.yaml yq e 'del(.[0,6,10,16,17,18,19,20,21,22,23])' -I 0 -o json -i measurements.json
yq '.4.warnOnly = false |
.8.warnOnly = false |
.9.warnOnly = false |
.11.warnOnly = false |
.12.warnOnly = false |
.13.warnOnly = false |
.15.warnOnly = false |
.15.expected = "0000000000000000000000000000000000000000000000000000000000000000"' \
-I 0 -o json -i measurements.json
;; ;;
gcp) gcp)
yq e 'del(.[11,12,13,14,15,16,17,18,19,20,21,22,23])' -i measurements.yaml yq e 'del(.[16,17,18,19,20,21,22,23])' -I 0 -o json -i measurements.json
yq '.0.warnOnly = false |
.4.warnOnly = false |
.8.warnOnly = false |
.9.warnOnly = false |
.11.warnOnly = false |
.12.warnOnly = false |
.13.warnOnly = false |
.15.warnOnly = false |
.15.expected = "0000000000000000000000000000000000000000000000000000000000000000"' \
-I 0 -o json -i measurements.json
;; ;;
esac esac
cat measurements.yaml cat measurements.json
shell: bash shell: bash
env: env:
CSP: ${{ inputs.cloudProvider }} CSP: ${{ inputs.cloudProvider }}
@ -81,14 +100,14 @@ runs:
run: | run: |
echo "$COSIGN_PUBLIC_KEY" > cosign.pub echo "$COSIGN_PUBLIC_KEY" > cosign.pub
# Enabling experimental mode also publishes signature to Rekor # Enabling experimental mode also publishes signature to Rekor
COSIGN_EXPERIMENTAL=1 cosign sign-blob --key env://COSIGN_PRIVATE_KEY measurements.yaml > measurements.yaml.sig COSIGN_EXPERIMENTAL=1 cosign sign-blob --key env://COSIGN_PRIVATE_KEY measurements.json > measurements.json.sig
# Verify - As documentation & check # Verify - As documentation & check
# Local Signature (input: artifact, key, signature) # Local Signature (input: artifact, key, signature)
cosign verify-blob --key cosign.pub --signature measurements.yaml.sig measurements.yaml cosign verify-blob --key cosign.pub --signature measurements.json.sig measurements.json
# Transparency Log Signature (input: artifact, key) # Transparency Log Signature (input: artifact, key)
uuid=$(rekor-cli search --artifact measurements.yaml | tail -n 1) uuid=$(rekor-cli search --artifact measurements.json | tail -n 1)
sig=$(rekor-cli get --uuid=$uuid --format=json | jq -r .Body.HashedRekordObj.signature.content) sig=$(rekor-cli get --uuid=$uuid --format=json | jq -r .Body.HashedRekordObj.signature.content)
cosign verify-blob --key cosign.pub --signature <(echo $sig) measurements.yaml cosign verify-blob --key cosign.pub --signature <(echo $sig) measurements.json
shell: bash shell: bash
env: env:
COSIGN_PUBLIC_KEY: ${{ inputs.cosignPublicKey }} COSIGN_PUBLIC_KEY: ${{ inputs.cosignPublicKey }}
@ -100,9 +119,9 @@ runs:
run: | run: |
IMAGE=$(yq e ".provider.${CSP}.image" constellation-conf.yaml) IMAGE=$(yq e ".provider.${CSP}.image" constellation-conf.yaml)
S3_PATH=s3://${PUBLIC_BUCKET_NAME}/${IMAGE,,} S3_PATH=s3://${PUBLIC_BUCKET_NAME}/${IMAGE,,}
aws s3 cp measurements.yaml ${S3_PATH}/measurements.yaml aws s3 cp measurements.json ${S3_PATH}/measurements.json
if test -f measurements.yaml.sig; then if test -f measurements.json.sig; then
aws s3 cp measurements.yaml.sig ${S3_PATH}/measurements.yaml.sig aws s3 cp measurements.json.sig ${S3_PATH}/measurements.json.sig
fi fi
shell: bash shell: bash
env: env:

View file

@ -8,7 +8,6 @@ package main
import ( import (
"context" "context"
"encoding/json"
"flag" "flag"
"io" "io"
"os" "os"
@ -80,14 +79,10 @@ func main() {
switch cloudprovider.FromString(os.Getenv(constellationCSP)) { switch cloudprovider.FromString(os.Getenv(constellationCSP)) {
case cloudprovider.AWS: case cloudprovider.AWS:
pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, vtpm.AWSPCRSelection) measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.AWSPCRSelection)
if err != nil { if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs")
} }
pcrsJSON, err := json.Marshal(pcrs)
if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to marshal PCRs")
}
issuer = initserver.NewIssuerWrapper(aws.NewIssuer(), vmtype.Unknown, nil) issuer = initserver.NewIssuerWrapper(aws.NewIssuer(), vmtype.Unknown, nil)
@ -104,13 +99,13 @@ func main() {
clusterInitJoiner = kubernetes.New( clusterInitJoiner = kubernetes.New(
"aws", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(), "aws", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(),
metadata, pcrsJSON, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, metadata, measurements, helmClient, &kubewaiter.CloudKubeAPIWaiter{},
) )
openTPM = vtpm.OpenVTPM openTPM = vtpm.OpenVTPM
fs = afero.NewOsFs() fs = afero.NewOsFs()
case cloudprovider.GCP: case cloudprovider.GCP:
pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, vtpm.GCPPCRSelection) measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.GCPPCRSelection)
if err != nil { if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs")
} }
@ -129,20 +124,16 @@ func main() {
} }
metadataAPI = metadata metadataAPI = metadata
pcrsJSON, err := json.Marshal(pcrs)
if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to marshal PCRs")
}
clusterInitJoiner = kubernetes.New( clusterInitJoiner = kubernetes.New(
"gcp", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(), "gcp", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(),
metadata, pcrsJSON, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, metadata, measurements, helmClient, &kubewaiter.CloudKubeAPIWaiter{},
) )
openTPM = vtpm.OpenVTPM openTPM = vtpm.OpenVTPM
fs = afero.NewOsFs() fs = afero.NewOsFs()
log.Infof("Added load balancer IP to routing table") log.Infof("Added load balancer IP to routing table")
case cloudprovider.Azure: case cloudprovider.Azure:
pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, vtpm.AzurePCRSelection) measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.AzurePCRSelection)
if err != nil { if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs")
} }
@ -163,20 +154,16 @@ func main() {
log.With(zap.Error(err)).Fatalf("Failed to set up cloud logger") log.With(zap.Error(err)).Fatalf("Failed to set up cloud logger")
} }
metadataAPI = metadata metadataAPI = metadata
pcrsJSON, err := json.Marshal(pcrs)
if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to marshal PCRs")
}
clusterInitJoiner = kubernetes.New( clusterInitJoiner = kubernetes.New(
"azure", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(), "azure", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(),
metadata, pcrsJSON, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, metadata, measurements, helmClient, &kubewaiter.CloudKubeAPIWaiter{},
) )
openTPM = vtpm.OpenVTPM openTPM = vtpm.OpenVTPM
fs = afero.NewOsFs() fs = afero.NewOsFs()
case cloudprovider.QEMU: case cloudprovider.QEMU:
pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, vtpm.QEMUPCRSelection) measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.QEMUPCRSelection)
if err != nil { if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs")
} }
@ -185,13 +172,9 @@ func main() {
cloudLogger = qemucloud.NewLogger() cloudLogger = qemucloud.NewLogger()
metadata := qemucloud.New() metadata := qemucloud.New()
pcrsJSON, err := json.Marshal(pcrs)
if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to marshal PCRs")
}
clusterInitJoiner = kubernetes.New( clusterInitJoiner = kubernetes.New(
"qemu", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(), "qemu", k8sapi.NewKubernetesUtil(), &k8sapi.KubdeadmConfiguration{}, kubectl.New(),
metadata, pcrsJSON, helmClient, &kubewaiter.CloudKubeAPIWaiter{}, metadata, measurements, helmClient, &kubewaiter.CloudKubeAPIWaiter{},
) )
metadataAPI = metadata metadataAPI = metadata

View file

@ -20,6 +20,7 @@ import (
"github.com/edgelesssys/constellation/v2/bootstrapper/internal/kubernetes/k8sapi" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/kubernetes/k8sapi"
kubewaiter "github.com/edgelesssys/constellation/v2/bootstrapper/internal/kubernetes/kubeWaiter" kubewaiter "github.com/edgelesssys/constellation/v2/bootstrapper/internal/kubernetes/kubeWaiter"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/azureshared" "github.com/edgelesssys/constellation/v2/internal/cloud/azureshared"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared" "github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared"
@ -60,13 +61,13 @@ type KubeWrapper struct {
client k8sapi.Client client k8sapi.Client
kubeconfigReader configReader kubeconfigReader configReader
providerMetadata ProviderMetadata providerMetadata ProviderMetadata
initialMeasurementsJSON []byte initialMeasurements measurements.M
getIPAddr func() (string, error) getIPAddr func() (string, error)
} }
// New creates a new KubeWrapper with real values. // New creates a new KubeWrapper with real values.
func New(cloudProvider string, clusterUtil clusterUtil, configProvider configurationProvider, client k8sapi.Client, func New(cloudProvider string, clusterUtil clusterUtil, configProvider configurationProvider, client k8sapi.Client,
providerMetadata ProviderMetadata, initialMeasurementsJSON []byte, helmClient helmClient, kubeAPIWaiter kubeAPIWaiter, providerMetadata ProviderMetadata, measurements measurements.M, helmClient helmClient, kubeAPIWaiter kubeAPIWaiter,
) *KubeWrapper { ) *KubeWrapper {
return &KubeWrapper{ return &KubeWrapper{
cloudProvider: cloudProvider, cloudProvider: cloudProvider,
@ -77,7 +78,7 @@ func New(cloudProvider string, clusterUtil clusterUtil, configProvider configura
client: client, client: client,
kubeconfigReader: &KubeconfigReader{fs: afero.Afero{Fs: afero.NewOsFs()}}, kubeconfigReader: &KubeconfigReader{fs: afero.Afero{Fs: afero.NewOsFs()}},
providerMetadata: providerMetadata, providerMetadata: providerMetadata,
initialMeasurementsJSON: initialMeasurementsJSON, initialMeasurements: measurements,
getIPAddr: getIPAddr, getIPAddr: getIPAddr,
} }
} }
@ -187,7 +188,21 @@ func (k *KubeWrapper) InitCluster(
} else { } else {
controlPlaneIP = controlPlaneEndpoint controlPlaneIP = controlPlaneEndpoint
} }
serviceConfig := constellationServicesConfig{k.initialMeasurementsJSON, idKeyDigest, measurementSalt, subnetworkPodCIDR, cloudServiceAccountURI, controlPlaneIP} if err := k.initialMeasurements.SetEnforced(enforcedPCRs); err != nil {
return nil, err
}
measurementsJSON, err := json.Marshal(k.initialMeasurements)
if err != nil {
return nil, fmt.Errorf("marshaling initial measurements: %w", err)
}
serviceConfig := constellationServicesConfig{
initialMeasurementsJSON: measurementsJSON,
idkeydigest: idKeyDigest,
measurementSalt: measurementSalt,
subnetworkPodCIDR: subnetworkPodCIDR,
cloudServiceAccountURI: cloudServiceAccountURI,
loadBalancerIP: controlPlaneIP,
}
extraVals, err := k.setupExtraVals(ctx, serviceConfig) extraVals, err := k.setupExtraVals(ctx, serviceConfig)
if err != nil { if err != nil {
return nil, fmt.Errorf("setting up extraVals: %w", err) return nil, fmt.Errorf("setting up extraVals: %w", err)

View file

@ -112,6 +112,13 @@ func (u *Upgrader) updateMeasurements(ctx context.Context, newMeasurements measu
return nil return nil
} }
// don't allow potential security downgrades by setting the warnOnly flag to true
for k, newM := range newMeasurements {
if currentM, ok := currentMeasurements[k]; ok && !currentM.WarnOnly && newM.WarnOnly {
return fmt.Errorf("setting enforced measurement %d to warn only: not allowed", k)
}
}
// backup of previous measurements // backup of previous measurements
existingConf.Data["oldMeasurements"] = existingConf.Data[constants.MeasurementsFilename] existingConf.Data["oldMeasurements"] = existingConf.Data[constants.MeasurementsFilename]

View file

@ -33,12 +33,12 @@ func TestUpdateMeasurements(t *testing.T) {
updater: &stubMeasurementsUpdater{ updater: &stubMeasurementsUpdater{
oldMeasurements: &corev1.ConfigMap{ oldMeasurements: &corev1.ConfigMap{
Data: map[string]string{ Data: map[string]string{
constants.MeasurementsFilename: `{"0":"AAAAAA=="}`, constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`,
}, },
}, },
}, },
newMeasurements: measurements.M{ newMeasurements: measurements.M{
0: []byte("1"), 0: measurements.WithAllBytes(0xBB, false),
}, },
wantUpdate: true, wantUpdate: true,
}, },
@ -46,14 +46,40 @@ func TestUpdateMeasurements(t *testing.T) {
updater: &stubMeasurementsUpdater{ updater: &stubMeasurementsUpdater{
oldMeasurements: &corev1.ConfigMap{ oldMeasurements: &corev1.ConfigMap{
Data: map[string]string{ Data: map[string]string{
constants.MeasurementsFilename: `{"0":"MQ=="}`, constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`,
}, },
}, },
}, },
newMeasurements: measurements.M{ newMeasurements: measurements.M{
0: []byte("1"), 0: measurements.WithAllBytes(0xAA, false),
}, },
}, },
"trying to set warnOnly to true results in error": {
updater: &stubMeasurementsUpdater{
oldMeasurements: &corev1.ConfigMap{
Data: map[string]string{
constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`,
},
},
},
newMeasurements: measurements.M{
0: measurements.WithAllBytes(0xAA, true),
},
wantErr: true,
},
"setting warnOnly to false is allowed": {
updater: &stubMeasurementsUpdater{
oldMeasurements: &corev1.ConfigMap{
Data: map[string]string{
constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":true}}`,
},
},
},
newMeasurements: measurements.M{
0: measurements.WithAllBytes(0xAA, false),
},
wantUpdate: true,
},
"getCurrent error": { "getCurrent error": {
updater: &stubMeasurementsUpdater{getErr: someErr}, updater: &stubMeasurementsUpdater{getErr: someErr},
wantErr: true, wantErr: true,
@ -62,7 +88,7 @@ func TestUpdateMeasurements(t *testing.T) {
updater: &stubMeasurementsUpdater{ updater: &stubMeasurementsUpdater{
oldMeasurements: &corev1.ConfigMap{ oldMeasurements: &corev1.ConfigMap{
Data: map[string]string{ Data: map[string]string{
constants.MeasurementsFilename: `{"0":"AAAAAA=="}`, constants.MeasurementsFilename: `{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`,
}, },
}, },
updateErr: someErr, updateErr: someErr,
@ -82,7 +108,7 @@ func TestUpdateMeasurements(t *testing.T) {
err := upgrader.updateMeasurements(context.Background(), tc.newMeasurements) err := upgrader.updateMeasurements(context.Background(), tc.newMeasurements)
if tc.wantErr { if tc.wantErr {
assert.ErrorIs(err, someErr) assert.Error(err)
return return
} }

View file

@ -20,17 +20,16 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp" "github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu" "github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"go.uber.org/multierr"
) )
// Validator validates Platform Configuration Registers (PCRs). // Validator validates Platform Configuration Registers (PCRs).
type Validator struct { type Validator struct {
provider cloudprovider.Provider provider cloudprovider.Provider
pcrs measurements.M pcrs measurements.M
enforcedPCRs []uint32
idkeydigest []byte idkeydigest []byte
enforceIDKeyDigest bool enforceIDKeyDigest bool
azureCVM bool azureCVM bool
@ -65,41 +64,44 @@ func NewValidator(provider cloudprovider.Provider, conf *config.Config) (*Valida
// UpdateInitPCRs sets the owner and cluster PCR values. // UpdateInitPCRs sets the owner and cluster PCR values.
func (v *Validator) UpdateInitPCRs(ownerID, clusterID string) error { func (v *Validator) UpdateInitPCRs(ownerID, clusterID string) error {
if err := v.updatePCR(uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil { if err := v.updatePCR(uint32(measurements.PCRIndexOwnerID), ownerID); err != nil {
return err return err
} }
return v.updatePCR(uint32(vtpm.PCRIndexClusterID), clusterID) return v.updatePCR(uint32(measurements.PCRIndexClusterID), clusterID)
} }
// updatePCR adds a new entry to the pcr map of v, or removes the key if the input is an empty string. // updatePCR adds a new entry to the measurements of v, or removes the key if the input is an empty string.
// //
// When adding, the input is first decoded from base64. // When adding, the input is first decoded from hex or base64.
// We then calculate the expected PCR by hashing the input using SHA256, // We then calculate the expected PCR by hashing the input using SHA256,
// appending expected PCR for initialization, and then hashing once more. // appending expected PCR for initialization, and then hashing once more.
func (v *Validator) updatePCR(pcrIndex uint32, encoded string) error { func (v *Validator) updatePCR(pcrIndex uint32, encoded string) error {
if encoded == "" { if encoded == "" {
delete(v.pcrs, pcrIndex) delete(v.pcrs, pcrIndex)
// remove enforced PCR if it exists
for i, enforcedIdx := range v.enforcedPCRs {
if enforcedIdx == pcrIndex {
v.enforcedPCRs[i] = v.enforcedPCRs[len(v.enforcedPCRs)-1]
v.enforcedPCRs = v.enforcedPCRs[:len(v.enforcedPCRs)-1]
break
}
}
return nil return nil
} }
decoded, err := base64.StdEncoding.DecodeString(encoded)
// decode from hex or base64
decoded, err := hex.DecodeString(encoded)
if err != nil { if err != nil {
return fmt.Errorf("input [%s] is not base64 encoded: %w", encoded, err) hexErr := err
decoded, err = base64.StdEncoding.DecodeString(encoded)
if err != nil {
return multierr.Append(
fmt.Errorf("input [%s] is not hex encoded: %w", encoded, hexErr),
fmt.Errorf("input [%s] is not base64 encoded: %w", encoded, err),
)
}
} }
// new_pcr_value := hash(old_pcr_value || data_to_extend) // new_pcr_value := hash(old_pcr_value || data_to_extend)
// Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input // Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input
hashedInput := sha256.Sum256(decoded) hashedInput := sha256.Sum256(decoded)
expectedPcr := sha256.Sum256(append(v.pcrs[pcrIndex], hashedInput[:]...)) oldExpected := v.pcrs[pcrIndex].Expected
v.pcrs[pcrIndex] = expectedPcr[:] expectedPcr := sha256.Sum256(append(oldExpected[:], hashedInput[:]...))
v.pcrs[pcrIndex] = measurements.Measurement{
Expected: expectedPcr,
WarnOnly: v.pcrs[pcrIndex].WarnOnly,
}
return nil return nil
} }
@ -107,35 +109,27 @@ func (v *Validator) setPCRs(config *config.Config) error {
switch v.provider { switch v.provider {
case cloudprovider.AWS: case cloudprovider.AWS:
awsPCRs := config.Provider.AWS.Measurements awsPCRs := config.Provider.AWS.Measurements
enforcedPCRs := config.Provider.AWS.EnforcedMeasurements if len(awsPCRs) == 0 {
if err := v.checkPCRs(awsPCRs, enforcedPCRs); err != nil { return errors.New("no expected measurement provided")
return err
} }
v.enforcedPCRs = enforcedPCRs
v.pcrs = awsPCRs v.pcrs = awsPCRs
case cloudprovider.Azure: case cloudprovider.Azure:
azurePCRs := config.Provider.Azure.Measurements azurePCRs := config.Provider.Azure.Measurements
enforcedPCRs := config.Provider.Azure.EnforcedMeasurements if len(azurePCRs) == 0 {
if err := v.checkPCRs(azurePCRs, enforcedPCRs); err != nil { return errors.New("no expected measurement provided")
return err
} }
v.enforcedPCRs = enforcedPCRs
v.pcrs = azurePCRs v.pcrs = azurePCRs
case cloudprovider.GCP: case cloudprovider.GCP:
gcpPCRs := config.Provider.GCP.Measurements gcpPCRs := config.Provider.GCP.Measurements
enforcedPCRs := config.Provider.GCP.EnforcedMeasurements if len(gcpPCRs) == 0 {
if err := v.checkPCRs(gcpPCRs, enforcedPCRs); err != nil { return errors.New("no expected measurement provided")
return err
} }
v.enforcedPCRs = enforcedPCRs
v.pcrs = gcpPCRs v.pcrs = gcpPCRs
case cloudprovider.QEMU: case cloudprovider.QEMU:
qemuPCRs := config.Provider.QEMU.Measurements qemuPCRs := config.Provider.QEMU.Measurements
enforcedPCRs := config.Provider.QEMU.EnforcedMeasurements if len(qemuPCRs) == 0 {
if err := v.checkPCRs(qemuPCRs, enforcedPCRs); err != nil { return errors.New("no expected measurement provided")
return err
} }
v.enforcedPCRs = enforcedPCRs
v.pcrs = qemuPCRs v.pcrs = qemuPCRs
} }
return nil return nil
@ -156,37 +150,20 @@ func (v *Validator) updateValidator(cmd *cobra.Command) {
log := warnLogger{cmd: cmd} log := warnLogger{cmd: cmd}
switch v.provider { switch v.provider {
case cloudprovider.GCP: case cloudprovider.GCP:
v.validator = gcp.NewValidator(v.pcrs, v.enforcedPCRs, log) v.validator = gcp.NewValidator(v.pcrs, log)
case cloudprovider.Azure: case cloudprovider.Azure:
if v.azureCVM { if v.azureCVM {
v.validator = snp.NewValidator(v.pcrs, v.enforcedPCRs, v.idkeydigest, v.enforceIDKeyDigest, log) v.validator = snp.NewValidator(v.pcrs, v.idkeydigest, v.enforceIDKeyDigest, log)
} else { } else {
v.validator = trustedlaunch.NewValidator(v.pcrs, v.enforcedPCRs, log) v.validator = trustedlaunch.NewValidator(v.pcrs, log)
} }
case cloudprovider.AWS: case cloudprovider.AWS:
v.validator = aws.NewValidator(v.pcrs, v.enforcedPCRs, log) v.validator = aws.NewValidator(v.pcrs, log)
case cloudprovider.QEMU: case cloudprovider.QEMU:
v.validator = qemu.NewValidator(v.pcrs, v.enforcedPCRs, log) v.validator = qemu.NewValidator(v.pcrs, log)
} }
} }
func (v *Validator) checkPCRs(pcrs measurements.M, enforcedPCRs []uint32) error {
if len(pcrs) == 0 {
return errors.New("no PCR values provided")
}
for k, v := range pcrs {
if len(v) != 32 {
return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v))
}
}
for _, v := range enforcedPCRs {
if _, ok := pcrs[v]; !ok {
return fmt.Errorf("bad config: PCR[%d] is enforced, but no expected measurement is provided", v)
}
}
return nil
}
// warnLogger implements logging of warnings for validators. // warnLogger implements logging of warnings for validators.
type warnLogger struct { type warnLogger struct {
cmd *cobra.Command cmd *cobra.Command

View file

@ -7,9 +7,9 @@ SPDX-License-Identifier: AGPL-3.0-only
package cloudcmd package cloudcmd
import ( import (
"bytes"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/hex"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/atls"
@ -18,7 +18,6 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp" "github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu" "github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -27,12 +26,12 @@ import (
func TestNewValidator(t *testing.T) { func TestNewValidator(t *testing.T) {
testPCRs := measurements.M{ testPCRs := measurements.M{
0: measurements.PCRWithAllBytes(0x00), 0: measurements.WithAllBytes(0x00, false),
1: measurements.PCRWithAllBytes(0xFF), 1: measurements.WithAllBytes(0xFF, false),
2: measurements.PCRWithAllBytes(0x00), 2: measurements.WithAllBytes(0x00, false),
3: measurements.PCRWithAllBytes(0xFF), 3: measurements.WithAllBytes(0xFF, false),
4: measurements.PCRWithAllBytes(0x00), 4: measurements.WithAllBytes(0x00, false),
5: measurements.PCRWithAllBytes(0x00), 5: measurements.WithAllBytes(0x00, false),
} }
testCases := map[string]struct { testCases := map[string]struct {
@ -67,13 +66,6 @@ func TestNewValidator(t *testing.T) {
pcrs: measurements.M{}, pcrs: measurements.M{},
wantErr: true, wantErr: true,
}, },
"invalid pcr length": {
provider: cloudprovider.GCP,
pcrs: measurements.M{
0: bytes.Repeat([]byte{0x00}, 31),
},
wantErr: true,
},
"unknown provider": { "unknown provider": {
provider: cloudprovider.Unknown, provider: cloudprovider.Unknown,
pcrs: testPCRs, pcrs: testPCRs,
@ -126,19 +118,19 @@ func TestNewValidator(t *testing.T) {
func TestValidatorV(t *testing.T) { func TestValidatorV(t *testing.T) {
newTestPCRs := func() measurements.M { newTestPCRs := func() measurements.M {
return measurements.M{ return measurements.M{
0: measurements.PCRWithAllBytes(0x00), 0: measurements.WithAllBytes(0x00, true),
1: measurements.PCRWithAllBytes(0x00), 1: measurements.WithAllBytes(0x00, true),
2: measurements.PCRWithAllBytes(0x00), 2: measurements.WithAllBytes(0x00, true),
3: measurements.PCRWithAllBytes(0x00), 3: measurements.WithAllBytes(0x00, true),
4: measurements.PCRWithAllBytes(0x00), 4: measurements.WithAllBytes(0x00, true),
5: measurements.PCRWithAllBytes(0x00), 5: measurements.WithAllBytes(0x00, true),
6: measurements.PCRWithAllBytes(0x00), 6: measurements.WithAllBytes(0x00, true),
7: measurements.PCRWithAllBytes(0x00), 7: measurements.WithAllBytes(0x00, true),
8: measurements.PCRWithAllBytes(0x00), 8: measurements.WithAllBytes(0x00, true),
9: measurements.PCRWithAllBytes(0x00), 9: measurements.WithAllBytes(0x00, true),
10: measurements.PCRWithAllBytes(0x00), 10: measurements.WithAllBytes(0x00, true),
11: measurements.PCRWithAllBytes(0x00), 11: measurements.WithAllBytes(0x00, true),
12: measurements.PCRWithAllBytes(0x00), 12: measurements.WithAllBytes(0x00, true),
} }
} }
@ -151,23 +143,23 @@ func TestValidatorV(t *testing.T) {
"gcp": { "gcp": {
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
pcrs: newTestPCRs(), pcrs: newTestPCRs(),
wantVs: gcp.NewValidator(newTestPCRs(), nil, nil), wantVs: gcp.NewValidator(newTestPCRs(), nil),
}, },
"azure cvm": { "azure cvm": {
provider: cloudprovider.Azure, provider: cloudprovider.Azure,
pcrs: newTestPCRs(), pcrs: newTestPCRs(),
wantVs: snp.NewValidator(newTestPCRs(), nil, nil, false, nil), wantVs: snp.NewValidator(newTestPCRs(), nil, false, nil),
azureCVM: true, azureCVM: true,
}, },
"azure trusted launch": { "azure trusted launch": {
provider: cloudprovider.Azure, provider: cloudprovider.Azure,
pcrs: newTestPCRs(), pcrs: newTestPCRs(),
wantVs: trustedlaunch.NewValidator(newTestPCRs(), nil, nil), wantVs: trustedlaunch.NewValidator(newTestPCRs(), nil),
}, },
"qemu": { "qemu": {
provider: cloudprovider.QEMU, provider: cloudprovider.QEMU,
pcrs: newTestPCRs(), pcrs: newTestPCRs(),
wantVs: qemu.NewValidator(newTestPCRs(), nil, nil), wantVs: qemu.NewValidator(newTestPCRs(), nil),
}, },
} }
@ -185,37 +177,37 @@ func TestValidatorV(t *testing.T) {
} }
func TestValidatorUpdateInitPCRs(t *testing.T) { func TestValidatorUpdateInitPCRs(t *testing.T) {
zero := []byte("00000000000000000000000000000000") zero := measurements.WithAllBytes(0x00, true)
one := []byte("11111111111111111111111111111111") one := measurements.WithAllBytes(0x11, true)
one64 := base64.StdEncoding.EncodeToString(one) one64 := base64.StdEncoding.EncodeToString(one.Expected[:])
oneHash := sha256.Sum256(one) oneHash := sha256.Sum256(one.Expected[:])
pcrZeroUpdatedOne := sha256.Sum256(append(zero, oneHash[:]...)) pcrZeroUpdatedOne := sha256.Sum256(append(zero.Expected[:], oneHash[:]...))
newTestPCRs := func() map[uint32][]byte { newTestPCRs := func() measurements.M {
return map[uint32][]byte{ return measurements.M{
0: zero, 0: measurements.WithAllBytes(0x00, true),
1: zero, 1: measurements.WithAllBytes(0x00, true),
2: zero, 2: measurements.WithAllBytes(0x00, true),
3: zero, 3: measurements.WithAllBytes(0x00, true),
4: zero, 4: measurements.WithAllBytes(0x00, true),
5: zero, 5: measurements.WithAllBytes(0x00, true),
6: zero, 6: measurements.WithAllBytes(0x00, true),
7: zero, 7: measurements.WithAllBytes(0x00, true),
8: zero, 8: measurements.WithAllBytes(0x00, true),
9: zero, 9: measurements.WithAllBytes(0x00, true),
10: zero, 10: measurements.WithAllBytes(0x00, true),
11: zero, 11: measurements.WithAllBytes(0x00, true),
12: zero, 12: measurements.WithAllBytes(0x00, true),
13: zero, 13: measurements.WithAllBytes(0x00, true),
14: zero, 14: measurements.WithAllBytes(0x00, true),
15: zero, 15: measurements.WithAllBytes(0x00, true),
16: zero, 16: measurements.WithAllBytes(0x00, true),
17: one, 17: measurements.WithAllBytes(0x11, true),
18: one, 18: measurements.WithAllBytes(0x11, true),
19: one, 19: measurements.WithAllBytes(0x11, true),
20: one, 20: measurements.WithAllBytes(0x11, true),
21: one, 21: measurements.WithAllBytes(0x11, true),
22: one, 22: measurements.WithAllBytes(0x11, true),
23: zero, 23: measurements.WithAllBytes(0x00, true),
} }
} }
@ -285,25 +277,25 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
assert.NoError(err) assert.NoError(err)
for i := 0; i < len(tc.pcrs); i++ { for i := 0; i < len(tc.pcrs); i++ {
switch { switch {
case i == int(vtpm.PCRIndexClusterID) && tc.clusterID == "": case i == int(measurements.PCRIndexClusterID) && tc.clusterID == "":
// should be deleted // should be deleted
_, ok := validators.pcrs[uint32(i)] _, ok := validators.pcrs[uint32(i)]
assert.False(ok) assert.False(ok)
case i == int(vtpm.PCRIndexClusterID): case i == int(measurements.PCRIndexClusterID):
pcr, ok := validators.pcrs[uint32(i)] pcr, ok := validators.pcrs[uint32(i)]
assert.True(ok) assert.True(ok)
assert.Equal(pcrZeroUpdatedOne[:], pcr) assert.Equal(pcrZeroUpdatedOne, pcr.Expected)
case i == int(vtpm.PCRIndexOwnerID) && tc.ownerID == "": case i == int(measurements.PCRIndexOwnerID) && tc.ownerID == "":
// should be deleted // should be deleted
_, ok := validators.pcrs[uint32(i)] _, ok := validators.pcrs[uint32(i)]
assert.False(ok) assert.False(ok)
case i == int(vtpm.PCRIndexOwnerID): case i == int(measurements.PCRIndexOwnerID):
pcr, ok := validators.pcrs[uint32(i)] pcr, ok := validators.pcrs[uint32(i)]
assert.True(ok) assert.True(ok)
assert.Equal(pcrZeroUpdatedOne[:], pcr) assert.Equal(pcrZeroUpdatedOne, pcr.Expected)
default: default:
if i >= 17 && i <= 22 { if i >= 17 && i <= 22 {
@ -320,8 +312,8 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
func TestUpdatePCR(t *testing.T) { func TestUpdatePCR(t *testing.T) {
emptyMap := measurements.M{} emptyMap := measurements.M{}
defaultMap := measurements.M{ defaultMap := measurements.M{
0: measurements.PCRWithAllBytes(0xAA), 0: measurements.WithAllBytes(0xAA, false),
1: measurements.PCRWithAllBytes(0xBB), 1: measurements.WithAllBytes(0xBB, false),
} }
testCases := map[string]struct { testCases := map[string]struct {
@ -359,6 +351,20 @@ func TestUpdatePCR(t *testing.T) {
wantEntries: len(defaultMap) + 1, wantEntries: len(defaultMap) + 1,
wantErr: false, wantErr: false,
}, },
"hex input, empty map": {
pcrMap: emptyMap,
pcrIndex: 10,
encoded: hex.EncodeToString([]byte("Constellation")),
wantEntries: 1,
wantErr: false,
},
"hex input, default map": {
pcrMap: defaultMap,
pcrIndex: 10,
encoded: hex.EncodeToString([]byte("Constellation")),
wantEntries: len(defaultMap) + 1,
wantErr: false,
},
"unencoded input, empty map": { "unencoded input, empty map": {
pcrMap: emptyMap, pcrMap: emptyMap,
pcrIndex: 10, pcrIndex: 10,
@ -403,9 +409,6 @@ func TestUpdatePCR(t *testing.T) {
assert.NoError(err) assert.NoError(err)
} }
assert.Len(pcrs, tc.wantEntries) assert.Len(pcrs, tc.wantEntries)
for _, v := range pcrs {
assert.Len(v, 32)
}
}) })
} }
} }

View file

@ -137,7 +137,7 @@ func (f *fetchMeasurementsFlags) updateURLs(ctx context.Context, conf *config.Co
if f.measurementsURL == nil { if f.measurementsURL == nil {
// TODO(AB#2644): resolve image version to reference // TODO(AB#2644): resolve image version to reference
parsedURL, err := url.Parse(constants.S3PublicBucket + imageRef + "/measurements.yaml") parsedURL, err := url.Parse(constants.S3PublicBucket + imageRef + "/measurements.json")
if err != nil { if err != nil {
return err return err
} }
@ -145,7 +145,7 @@ func (f *fetchMeasurementsFlags) updateURLs(ctx context.Context, conf *config.Co
} }
if f.signatureURL == nil { if f.signatureURL == nil {
parsedURL, err := url.Parse(constants.S3PublicBucket + imageRef + "/measurements.yaml.sig") parsedURL, err := url.Parse(constants.S3PublicBucket + imageRef + "/measurements.json.sig")
if err != nil { if err != nil {
return err return err
} }

View file

@ -109,17 +109,17 @@ func TestUpdateURLs(t *testing.T) {
}, },
}, },
flags: &fetchMeasurementsFlags{}, flags: &fetchMeasurementsFlags{},
wantMeasurementsURL: constants.S3PublicBucket + "some/image/path/image-123456/measurements.yaml", wantMeasurementsURL: constants.S3PublicBucket + "some/image/path/image-123456/measurements.json",
wantMeasurementsSigURL: constants.S3PublicBucket + "some/image/path/image-123456/measurements.yaml.sig", wantMeasurementsSigURL: constants.S3PublicBucket + "some/image/path/image-123456/measurements.json.sig",
}, },
"both set by user": { "both set by user": {
conf: &config.Config{}, conf: &config.Config{},
flags: &fetchMeasurementsFlags{ flags: &fetchMeasurementsFlags{
measurementsURL: urlMustParse("get.my/measurements.yaml"), measurementsURL: urlMustParse("get.my/measurements.json"),
signatureURL: urlMustParse("get.my/measurements.yaml.sig"), signatureURL: urlMustParse("get.my/measurements.json.sig"),
}, },
wantMeasurementsURL: "get.my/measurements.yaml", wantMeasurementsURL: "get.my/measurements.json",
wantMeasurementsSigURL: "get.my/measurements.yaml.sig", wantMeasurementsSigURL: "get.my/measurements.json.sig",
}, },
} }
@ -164,14 +164,14 @@ func TestConfigFetchMeasurements(t *testing.T) {
signature := "MEUCIFdJ5dH6HDywxQWTUh9Bw77wMrq0mNCUjMQGYP+6QsVmAiEAmazj/L7rFGA4/Gz8y+kI5h5E5cDgc3brihvXBKF6qZA=" signature := "MEUCIFdJ5dH6HDywxQWTUh9Bw77wMrq0mNCUjMQGYP+6QsVmAiEAmazj/L7rFGA4/Gz8y+kI5h5E5cDgc3brihvXBKF6qZA="
client := newTestClient(func(req *http.Request) *http.Response { client := newTestClient(func(req *http.Request) *http.Response {
if req.URL.String() == "https://public-edgeless-constellation.s3.us-east-2.amazonaws.com/someImage/measurements.yaml" { if req.URL.String() == "https://public-edgeless-constellation.s3.us-east-2.amazonaws.com/someImage/measurements.json" {
return &http.Response{ return &http.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(measurements)), Body: io.NopCloser(bytes.NewBufferString(measurements)),
Header: make(http.Header), Header: make(http.Header),
} }
} }
if req.URL.String() == "https://public-edgeless-constellation.s3.us-east-2.amazonaws.com/someImage/measurements.yaml.sig" { if req.URL.String() == "https://public-edgeless-constellation.s3.us-east-2.amazonaws.com/someImage/measurements.json.sig" {
return &http.Response{ return &http.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewBufferString(signature)), Body: io.NopCloser(bytes.NewBufferString(signature)),

View file

@ -8,7 +8,7 @@ package cmd
import ( import (
"context" "context"
"encoding/base64" "encoding/hex"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -134,7 +134,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
KubernetesVersion: conf.KubernetesVersion, KubernetesVersion: conf.KubernetesVersion,
KubernetesComponents: versions.VersionConfigs[k8sVersion].KubernetesComponents.ToProto(), KubernetesComponents: versions.VersionConfigs[k8sVersion].KubernetesComponents.ToProto(),
HelmDeployments: helmDeployments, HelmDeployments: helmDeployments,
EnforcedPcrs: conf.GetEnforcedPCRs(), EnforcedPcrs: conf.EnforcedPCRs(),
EnforceIdkeydigest: conf.EnforcesIDKeyDigest(), EnforceIdkeydigest: conf.EnforcesIDKeyDigest(),
ConformanceMode: flags.conformance, ConformanceMode: flags.conformance,
} }
@ -190,8 +190,8 @@ func (d *initDoer) Do(ctx context.Context) error {
func writeOutput(idFile clusterid.File, resp *initproto.InitResponse, wr io.Writer, fileHandler file.Handler) error { func writeOutput(idFile clusterid.File, resp *initproto.InitResponse, wr io.Writer, fileHandler file.Handler) error {
fmt.Fprint(wr, "Your Constellation cluster was successfully initialized.\n\n") fmt.Fprint(wr, "Your Constellation cluster was successfully initialized.\n\n")
ownerID := base64.StdEncoding.EncodeToString(resp.OwnerId) ownerID := hex.EncodeToString(resp.OwnerId)
clusterID := base64.StdEncoding.EncodeToString(resp.ClusterId) clusterID := hex.EncodeToString(resp.ClusterId)
tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0) tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0)
// writeRow(tw, "Constellation cluster's owner identifier", ownerID) // writeRow(tw, "Constellation cluster's owner identifier", ownerID)

View file

@ -9,7 +9,7 @@ package cmd
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/base64" "encoding/hex"
"encoding/json" "encoding/json"
"errors" "errors"
"net" "net"
@ -101,16 +101,6 @@ func TestInitialize(t *testing.T) {
initServerAPI: &stubInitServer{initErr: someErr}, initServerAPI: &stubInitServer{initErr: someErr},
wantErr: true, wantErr: true,
}, },
"fail missing enforced PCR": {
provider: cloudprovider.GCP,
idFile: &clusterid.File{IP: "192.0.2.1"},
configMutator: func(c *config.Config) {
c.Provider.GCP.EnforcedMeasurements = append(c.Provider.GCP.EnforcedMeasurements, 10)
},
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{initResp: testInitResp},
wantErr: true,
},
} }
for name, tc := range testCases { for name, tc := range testCases {
@ -174,7 +164,7 @@ func TestInitialize(t *testing.T) {
} }
require.NoError(err) require.NoError(err)
// assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID"))) // assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID")))
assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("clusterID"))) assert.Contains(out.String(), hex.EncodeToString([]byte("clusterID")))
var secret masterSecret var secret masterSecret
assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret)) assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret))
assert.NotEmpty(secret.Key) assert.NotEmpty(secret.Key)
@ -192,8 +182,8 @@ func TestWriteOutput(t *testing.T) {
Kubeconfig: []byte("kubeconfig"), Kubeconfig: []byte("kubeconfig"),
} }
ownerID := base64.StdEncoding.EncodeToString(resp.OwnerId) ownerID := hex.EncodeToString(resp.OwnerId)
clusterID := base64.StdEncoding.EncodeToString(resp.ClusterId) clusterID := hex.EncodeToString(resp.ClusterId)
expectedIDFile := clusterid.File{ expectedIDFile := clusterid.File{
ClusterID: clusterID, ClusterID: clusterID,
@ -361,11 +351,11 @@ func TestAttestation(t *testing.T) {
issuer := &testIssuer{ issuer := &testIssuer{
Getter: oid.QEMU{}, Getter: oid.QEMU{},
pcrs: measurements.M{ pcrs: map[uint32][]byte{
0: measurements.PCRWithAllBytes(0xFF), 0: bytes.Repeat([]byte{0xFF}, 32),
1: measurements.PCRWithAllBytes(0xFF), 1: bytes.Repeat([]byte{0xFF}, 32),
2: measurements.PCRWithAllBytes(0xFF), 2: bytes.Repeat([]byte{0xFF}, 32),
3: measurements.PCRWithAllBytes(0xFF), 3: bytes.Repeat([]byte{0xFF}, 32),
}, },
} }
serverCreds := atlscredentials.New(issuer, nil) serverCreds := atlscredentials.New(issuer, nil)
@ -390,13 +380,13 @@ func TestAttestation(t *testing.T) {
cfg := config.Default() cfg := config.Default()
cfg.Image = "image" cfg.Image = "image"
cfg.RemoveProviderExcept(cloudprovider.QEMU) cfg.RemoveProviderExcept(cloudprovider.QEMU)
cfg.Provider.QEMU.Measurements[0] = measurements.PCRWithAllBytes(0x00) cfg.Provider.QEMU.Measurements[0] = measurements.WithAllBytes(0x00, false)
cfg.Provider.QEMU.Measurements[1] = measurements.PCRWithAllBytes(0x11) cfg.Provider.QEMU.Measurements[1] = measurements.WithAllBytes(0x11, false)
cfg.Provider.QEMU.Measurements[2] = measurements.PCRWithAllBytes(0x22) cfg.Provider.QEMU.Measurements[2] = measurements.WithAllBytes(0x22, false)
cfg.Provider.QEMU.Measurements[3] = measurements.PCRWithAllBytes(0x33) cfg.Provider.QEMU.Measurements[3] = measurements.WithAllBytes(0x33, false)
cfg.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44) cfg.Provider.QEMU.Measurements[4] = measurements.WithAllBytes(0x44, false)
cfg.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x99) cfg.Provider.QEMU.Measurements[9] = measurements.WithAllBytes(0x99, false)
cfg.Provider.QEMU.Measurements[12] = measurements.PCRWithAllBytes(0xcc) cfg.Provider.QEMU.Measurements[12] = measurements.WithAllBytes(0xcc, false)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone)) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone))
ctx := context.Background() ctx := context.Background()
@ -418,14 +408,14 @@ type testValidator struct {
func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) { func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
var attestation struct { var attestation struct {
UserData []byte UserData []byte
PCRs measurements.M PCRs map[uint32][]byte
} }
if err := json.Unmarshal(attDoc, &attestation); err != nil { if err := json.Unmarshal(attDoc, &attestation); err != nil {
return nil, err return nil, err
} }
for k, pcr := range v.pcrs { for k, pcr := range v.pcrs {
if !bytes.Equal(attestation.PCRs[k], pcr) { if !bytes.Equal(attestation.PCRs[k], pcr.Expected[:]) {
return nil, errors.New("invalid PCR value") return nil, errors.New("invalid PCR value")
} }
} }
@ -434,14 +424,14 @@ func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
type testIssuer struct { type testIssuer struct {
oid.Getter oid.Getter
pcrs measurements.M pcrs map[uint32][]byte
} }
func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) { func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal( return json.Marshal(
struct { struct {
UserData []byte UserData []byte
PCRs measurements.M PCRs map[uint32][]byte
}{ }{
UserData: userData, UserData: userData,
PCRs: i.pcrs, PCRs: i.pcrs,
@ -474,21 +464,21 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
conf.Provider.Azure.ResourceGroup = "test-resource-group" conf.Provider.Azure.ResourceGroup = "test-resource-group"
conf.Provider.Azure.AppClientID = "01234567-0123-0123-0123-0123456789ab" conf.Provider.Azure.AppClientID = "01234567-0123-0123-0123-0123456789ab"
conf.Provider.Azure.ClientSecretValue = "test-client-secret" conf.Provider.Azure.ClientSecretValue = "test-client-secret"
conf.Provider.Azure.Measurements[4] = measurements.PCRWithAllBytes(0x44) conf.Provider.Azure.Measurements[4] = measurements.WithAllBytes(0x44, false)
conf.Provider.Azure.Measurements[9] = measurements.PCRWithAllBytes(0x11) conf.Provider.Azure.Measurements[9] = measurements.WithAllBytes(0x11, false)
conf.Provider.Azure.Measurements[12] = measurements.PCRWithAllBytes(0xcc) conf.Provider.Azure.Measurements[12] = measurements.WithAllBytes(0xcc, false)
case cloudprovider.GCP: case cloudprovider.GCP:
conf.Provider.GCP.Region = "test-region" conf.Provider.GCP.Region = "test-region"
conf.Provider.GCP.Project = "test-project" conf.Provider.GCP.Project = "test-project"
conf.Provider.GCP.Zone = "test-zone" conf.Provider.GCP.Zone = "test-zone"
conf.Provider.GCP.ServiceAccountKeyPath = "test-key-path" conf.Provider.GCP.ServiceAccountKeyPath = "test-key-path"
conf.Provider.GCP.Measurements[4] = measurements.PCRWithAllBytes(0x44) conf.Provider.GCP.Measurements[4] = measurements.WithAllBytes(0x44, false)
conf.Provider.GCP.Measurements[9] = measurements.PCRWithAllBytes(0x11) conf.Provider.GCP.Measurements[9] = measurements.WithAllBytes(0x11, false)
conf.Provider.GCP.Measurements[12] = measurements.PCRWithAllBytes(0xcc) conf.Provider.GCP.Measurements[12] = measurements.WithAllBytes(0xcc, false)
case cloudprovider.QEMU: case cloudprovider.QEMU:
conf.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44) conf.Provider.QEMU.Measurements[4] = measurements.WithAllBytes(0x44, false)
conf.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x11) conf.Provider.QEMU.Measurements[9] = measurements.WithAllBytes(0x11, false)
conf.Provider.QEMU.Measurements[12] = measurements.PCRWithAllBytes(0xcc) conf.Provider.QEMU.Measurements[12] = measurements.WithAllBytes(0xcc, false)
} }
conf.RemoveProviderExcept(csp) conf.RemoveProviderExcept(csp)

View file

@ -181,12 +181,12 @@ func getCompatibleImages(csp cloudprovider.Provider, currentVersion string, imag
// getCompatibleImageMeasurements retrieves the expected measurements for each image. // getCompatibleImageMeasurements retrieves the expected measurements for each image.
func getCompatibleImageMeasurements(ctx context.Context, cmd *cobra.Command, client *http.Client, rekor rekorVerifier, pubK []byte, images map[string]config.UpgradeConfig) error { func getCompatibleImageMeasurements(ctx context.Context, cmd *cobra.Command, client *http.Client, rekor rekorVerifier, pubK []byte, images map[string]config.UpgradeConfig) error {
for idx, img := range images { for idx, img := range images {
measurementsURL, err := url.Parse(constants.S3PublicBucket + strings.ToLower(img.Image) + "/measurements.yaml") measurementsURL, err := url.Parse(constants.S3PublicBucket + strings.ToLower(img.Image) + "/measurements.json")
if err != nil { if err != nil {
return err return err
} }
signatureURL, err := url.Parse(constants.S3PublicBucket + strings.ToLower(img.Image) + "/measurements.yaml.sig") signatureURL, err := url.Parse(constants.S3PublicBucket + strings.ToLower(img.Image) + "/measurements.json.sig")
if err != nil { if err != nil {
return err return err
} }

View file

@ -25,7 +25,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v3"
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured"
) )
@ -248,14 +248,14 @@ func TestGetCompatibleImageMeasurements(t *testing.T) {
} }
client := newTestClient(func(req *http.Request) *http.Response { client := newTestClient(func(req *http.Request) *http.Response {
if strings.HasSuffix(req.URL.String(), "/measurements.yaml") { if strings.HasSuffix(req.URL.String(), "/measurements.json") {
return &http.Response{ return &http.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n")), Body: io.NopCloser(strings.NewReader("0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n")),
Header: make(http.Header), Header: make(http.Header),
} }
} }
if strings.HasSuffix(req.URL.String(), "/measurements.yaml.sig") { if strings.HasSuffix(req.URL.String(), "/measurements.json.sig") {
return &http.Response{ return &http.Response{
StatusCode: http.StatusOK, StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader("MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=")), Body: io.NopCloser(strings.NewReader("MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=")),
@ -470,14 +470,14 @@ func TestUpgradePlan(t *testing.T) {
Header: make(http.Header), Header: make(http.Header),
} }
} }
if strings.HasSuffix(req.URL.String(), "/measurements.yaml") { if strings.HasSuffix(req.URL.String(), "/measurements.json") {
return &http.Response{ return &http.Response{
StatusCode: tc.measurementsFetchStatus, StatusCode: tc.measurementsFetchStatus,
Body: io.NopCloser(strings.NewReader("0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n")), Body: io.NopCloser(strings.NewReader("0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n")),
Header: make(http.Header), Header: make(http.Header),
} }
} }
if strings.HasSuffix(req.URL.String(), "/measurements.yaml.sig") { if strings.HasSuffix(req.URL.String(), "/measurements.json.sig") {
return &http.Response{ return &http.Response{
StatusCode: tc.measurementsFetchStatus, StatusCode: tc.measurementsFetchStatus,
Body: io.NopCloser(strings.NewReader("MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=")), Body: io.NopCloser(strings.NewReader("MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=")),

View file

@ -5,7 +5,6 @@ metadata:
namespace: {{ .Release.Namespace }} namespace: {{ .Release.Namespace }}
data: data:
# mustToJson is required so the json-strings passed from go are of type string in the rendered yaml. # mustToJson is required so the json-strings passed from go are of type string in the rendered yaml.
enforcedPCRs: {{ .Values.enforcedPCRs | mustToJson }}
measurements: {{ .Values.measurements | mustToJson }} measurements: {{ .Values.measurements | mustToJson }}
{{- if eq .Values.csp "Azure" }} {{- if eq .Values.csp "Azure" }}
# ConfigMap.data is of type map[string]string. quote will not quote a quoted string. # ConfigMap.data is of type map[string]string. quote will not quote a quoted string.

View file

@ -5,15 +5,10 @@
"description": "CSP to which the chart is deployed.", "description": "CSP to which the chart is deployed.",
"enum": ["Azure", "GCP", "AWS", "QEMU"] "enum": ["Azure", "GCP", "AWS", "QEMU"]
}, },
"enforcedPCRs": {
"description": "JSON-string to describe the enforced PCRs.",
"type": "string",
"examples": ["[1, 15]"]
},
"measurements": { "measurements": {
"description": "JSON-string to describe the expected measurements.", "description": "JSON-string to describe the expected measurements.",
"type": "string", "type": "string",
"examples": ["{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}"] "examples": ["{'1':{'expected':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','warnOnly':true},'15':{'expected':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=','warnOnly':true}}"]
}, },
"enforceIdKeyDigest": { "enforceIdKeyDigest": {
"description": "Whether or not idkeydigest should be enforced during attestation on azure.", "description": "Whether or not idkeydigest should be enforced during attestation on azure.",
@ -37,7 +32,6 @@
}, },
"required": [ "required": [
"csp", "csp",
"enforcedPCRs",
"measurements", "measurements",
"measurementSalt", "measurementSalt",
"image" "image"

View file

@ -76,9 +76,7 @@ func New(csp cloudprovider.Provider, k8sVersion versions.ValidK8sVersion) *Chart
// Load the embedded helm charts. // Load the embedded helm charts.
func (i *ChartLoader) Load(config *config.Config, conformanceMode bool, masterSecret, salt []byte) ([]byte, error) { func (i *ChartLoader) Load(config *config.Config, conformanceMode bool, masterSecret, salt []byte) ([]byte, error) {
csp := config.GetProvider() ciliumRelease, err := i.loadCilium(config.GetProvider(), conformanceMode)
ciliumRelease, err := i.loadCilium(csp, conformanceMode)
if err != nil { if err != nil {
return nil, fmt.Errorf("loading cilium: %w", err) return nil, fmt.Errorf("loading cilium: %w", err)
} }
@ -88,7 +86,7 @@ func (i *ChartLoader) Load(config *config.Config, conformanceMode bool, masterSe
return nil, fmt.Errorf("loading cilium: %w", err) return nil, fmt.Errorf("loading cilium: %w", err)
} }
operatorRelease, err := i.loadOperators(csp) operatorRelease, err := i.loadOperators(config.GetProvider())
if err != nil { if err != nil {
return nil, fmt.Errorf("loading operators: %w", err) return nil, fmt.Errorf("loading operators: %w", err)
} }
@ -350,11 +348,6 @@ func (i *ChartLoader) loadConstellationServicesHelper(config *config.Config, mas
return nil, nil, fmt.Errorf("loading constellation-services chart: %w", err) return nil, nil, fmt.Errorf("loading constellation-services chart: %w", err)
} }
enforcedPCRsJSON, err := json.Marshal(config.GetEnforcedPCRs())
if err != nil {
return nil, nil, fmt.Errorf("marshaling enforcedPCRs: %w", err)
}
csp := config.GetProvider() csp := config.GetProvider()
values := map[string]any{ values := map[string]any{
"global": map[string]any{ "global": map[string]any{
@ -375,7 +368,6 @@ func (i *ChartLoader) loadConstellationServicesHelper(config *config.Config, mas
}, },
"join-service": map[string]any{ "join-service": map[string]any{
"csp": csp.String(), "csp": csp.String(),
"enforcedPCRs": string(enforcedPCRsJSON),
"image": i.joinServiceImage, "image": i.joinServiceImage,
}, },
"ccm": map[string]any{ "ccm": map[string]any{

View file

@ -15,6 +15,7 @@ import (
"path" "path"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/deploy/helm" "github.com/edgelesssys/constellation/v2/internal/deploy/helm"
@ -57,7 +58,6 @@ func TestConstellationServices(t *testing.T) {
"GCP": { "GCP": {
config: &config.Config{Provider: config.ProviderConfig{GCP: &config.GCPConfig{ config: &config.Config{Provider: config.ProviderConfig{GCP: &config.GCPConfig{
DeployCSIDriver: func() *bool { b := true; return &b }(), DeployCSIDriver: func() *bool { b := true; return &b }(),
EnforcedMeasurements: []uint32{1, 11},
}}}, }}},
enforceIDKeyDigest: false, enforceIDKeyDigest: false,
valuesModifier: prepareGCPValues, valuesModifier: prepareGCPValues,
@ -66,7 +66,6 @@ func TestConstellationServices(t *testing.T) {
"Azure": { "Azure": {
config: &config.Config{Provider: config.ProviderConfig{Azure: &config.AzureConfig{ config: &config.Config{Provider: config.ProviderConfig{Azure: &config.AzureConfig{
DeployCSIDriver: func() *bool { b := true; return &b }(), DeployCSIDriver: func() *bool { b := true; return &b }(),
EnforcedMeasurements: []uint32{1, 11},
EnforceIDKeyDigest: func() *bool { b := true; return &b }(), EnforceIDKeyDigest: func() *bool { b := true; return &b }(),
}}}, }}},
enforceIDKeyDigest: true, enforceIDKeyDigest: true,
@ -75,9 +74,7 @@ func TestConstellationServices(t *testing.T) {
cnmImage: "cnmImageForAzure", cnmImage: "cnmImageForAzure",
}, },
"QEMU": { "QEMU": {
config: &config.Config{Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{ config: &config.Config{Provider: config.ProviderConfig{QEMU: &config.QEMUConfig{}}},
EnforcedMeasurements: []uint32{1, 11},
}}},
enforceIDKeyDigest: false, enforceIDKeyDigest: false,
valuesModifier: prepareQEMUValues, valuesModifier: prepareQEMUValues,
}, },
@ -88,7 +85,14 @@ func TestConstellationServices(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
chartLoader := ChartLoader{joinServiceImage: "joinServiceImage", kmsImage: "kmsImage", ccmImage: tc.ccmImage, cnmImage: tc.cnmImage, autoscalerImage: "autoscalerImage", verificationServiceImage: "verificationImage"} chartLoader := ChartLoader{
joinServiceImage: "joinServiceImage",
kmsImage: "kmsImage",
ccmImage: tc.ccmImage,
cnmImage: tc.cnmImage,
autoscalerImage: "autoscalerImage",
verificationServiceImage: "verificationImage",
}
chart, values, err := chartLoader.loadConstellationServicesHelper(tc.config, []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")) chart, values, err := chartLoader.loadConstellationServicesHelper(tc.config, []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"), []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"))
require.NoError(err) require.NoError(err)
@ -197,7 +201,15 @@ func prepareGCPValues(values map[string]any) error {
if !ok { if !ok {
return errors.New("missing 'join-service' key") return errors.New("missing 'join-service' key")
} }
joinVals["measurements"] = "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}"
m := measurements.M{
1: measurements.WithAllBytes(0xAA, false),
}
mJSON, err := json.Marshal(m)
if err != nil {
return err
}
joinVals["measurements"] = string(mJSON)
joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
ccmVals, ok := values["ccm"].(map[string]any) ccmVals, ok := values["ccm"].(map[string]any)
@ -269,7 +281,12 @@ func prepareAzureValues(values map[string]any) error {
return errors.New("missing 'join-service' key") return errors.New("missing 'join-service' key")
} }
joinVals["idkeydigest"] = "baaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaad" joinVals["idkeydigest"] = "baaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaad"
joinVals["measurements"] = "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}" m := measurements.M{1: measurements.WithAllBytes(0xAA, false)}
mJSON, err := json.Marshal(m)
if err != nil {
return err
}
joinVals["measurements"] = string(mJSON)
joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
ccmVals, ok := values["ccm"].(map[string]any) ccmVals, ok := values["ccm"].(map[string]any)
@ -311,7 +328,12 @@ func prepareQEMUValues(values map[string]any) error {
if !ok { if !ok {
return errors.New("missing 'join-service' key") return errors.New("missing 'join-service' key")
} }
joinVals["measurements"] = "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}" m := measurements.M{1: measurements.WithAllBytes(0xAA, false)}
mJSON, err := json.Marshal(m)
if err != nil {
return err
}
joinVals["measurements"] = string(mJSON)
joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" joinVals["measurementSalt"] = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
verificationVals, ok := values["verification-service"].(map[string]any) verificationVals, ok := values["verification-service"].(map[string]any)

View file

@ -4,8 +4,7 @@ metadata:
name: join-config name: join-config
namespace: testNamespace namespace: testNamespace
data: data:
enforcedPCRs: "[1,11]" measurements: "{\"1\":{\"expected\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\",\"warnOnly\":false}}"
measurements: "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}"
enforceIdKeyDigest: "true" enforceIdKeyDigest: "true"
idkeydigest: "baaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaad" idkeydigest: "baaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaadbaaaaaad"
binaryData: binaryData:

View file

@ -4,7 +4,6 @@ metadata:
name: join-config name: join-config
namespace: testNamespace namespace: testNamespace
data: data:
enforcedPCRs: "[1,11]" measurements: "{\"1\":{\"expected\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\",\"warnOnly\":false}}"
measurements: "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}"
binaryData: binaryData:
measurementSalt: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA measurementSalt: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA

View file

@ -4,7 +4,6 @@ metadata:
name: join-config name: join-config
namespace: testNamespace namespace: testNamespace
data: data:
enforcedPCRs: "[1,11]" measurements: "{\"1\":{\"expected\":\"aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\",\"warnOnly\":false}}"
measurements: "{'1':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA','15':'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA='}"
binaryData: binaryData:
measurementSalt: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA measurementSalt: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA

View file

@ -171,7 +171,7 @@ func main() {
// We can use this to calculate the PCRs of the image locally. // We can use this to calculate the PCRs of the image locally.
func exportPCRs() error { func exportPCRs() error {
// get TPM state // get TPM state
pcrs, err := vtpm.GetSelectedPCRs(vtpm.OpenVTPM, tpmClient.FullPcrSel(tpm2.AlgSHA256)) pcrs, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, tpmClient.FullPcrSel(tpm2.AlgSHA256))
if err != nil { if err != nil {
return err return err
} }

2
go.mod
View file

@ -92,7 +92,6 @@ require (
google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6 google.golang.org/genproto v0.0.0-20221118155620-16455021b5e6
google.golang.org/grpc v1.51.0 google.golang.org/grpc v1.51.0
google.golang.org/protobuf v1.28.1 google.golang.org/protobuf v1.28.1
gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
helm.sh/helm v2.17.0+incompatible helm.sh/helm v2.17.0+incompatible
helm.sh/helm/v3 v3.10.2 helm.sh/helm/v3 v3.10.2
@ -119,6 +118,7 @@ require (
github.com/hashicorp/go-retryablehttp v0.7.1 // indirect github.com/hashicorp/go-retryablehttp v0.7.1 // indirect
github.com/rogpeppe/go-internal v1.8.1 // indirect github.com/rogpeppe/go-internal v1.8.1 // indirect
golang.org/x/text v0.4.0 // indirect golang.org/x/text v0.4.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
) )
require ( require (

View file

@ -41,7 +41,6 @@ require (
github.com/go-git/go-git/v5 v5.4.2 github.com/go-git/go-git/v5 v5.4.2
github.com/google/go-tpm-tools v0.3.9 github.com/google/go-tpm-tools v0.3.9
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/spf13/afero v1.9.3
github.com/spf13/cobra v1.6.1 github.com/spf13/cobra v1.6.1
github.com/stretchr/testify v1.8.1 github.com/stretchr/testify v1.8.1
go.uber.org/goleak v1.2.0 go.uber.org/goleak v1.2.0
@ -189,6 +188,7 @@ require (
github.com/sigstore/rekor v1.0.1 // indirect github.com/sigstore/rekor v1.0.1 // indirect
github.com/sigstore/sigstore v1.4.5 // indirect github.com/sigstore/sigstore v1.4.5 // indirect
github.com/sirupsen/logrus v1.9.0 // indirect github.com/sirupsen/logrus v1.9.0 // indirect
github.com/spf13/afero v1.9.3 // indirect
github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/pflag v1.0.5 // indirect
github.com/talos-systems/talos/pkg/machinery v1.2.7 // indirect github.com/talos-systems/talos/pkg/machinery v1.2.7 // indirect
github.com/tent/canonical-json-go v0.0.0-20130607151641-96e4ba3a7613 // indirect github.com/tent/canonical-json-go v0.0.0-20130607151641-96e4ba3a7613 // indirect

View file

@ -24,24 +24,24 @@ import (
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto" "github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/verify/verifyproto" "github.com/edgelesssys/constellation/v2/verify/verifyproto"
"github.com/spf13/afero"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
var (
coordIP = flag.String("constell-ip", "", "Public IP of the Constellation")
port = flag.String("constell-port", strconv.Itoa(constants.VerifyServiceNodePortGRPC), "NodePort of the Constellation's verification service")
export = flag.String("o", "", "Write PCRs, formatted as Go code, to file")
format = flag.String("format", "json", "Output format: json, yaml (default json)")
quiet = flag.Bool("q", false, "Set to disable output")
timeout = flag.Duration("timeout", 2*time.Minute, "Wait this duration for the verification service to become available")
)
func main() { func main() {
coordIP := flag.String("constell-ip", "", "Public IP of the Constellation")
port := flag.String("constell-port", strconv.Itoa(constants.VerifyServiceNodePortGRPC), "NodePort of the Constellation's verification service")
format := flag.String("format", "json", "Output format: json, yaml (default json)")
quiet := flag.Bool("q", false, "Set to disable output")
timeout := flag.Duration("timeout", 2*time.Minute, "Wait this duration for the verification service to become available")
flag.Parse() flag.Parse()
if *coordIP == "" || *port == "" {
flag.Usage()
os.Exit(1)
}
addr := net.JoinHostPort(*coordIP, *port) addr := net.JoinHostPort(*coordIP, *port)
ctx, cancel := context.WithTimeout(context.Background(), *timeout) ctx, cancel := context.WithTimeout(context.Background(), *timeout)
defer cancel() defer cancel()
@ -51,18 +51,13 @@ func main() {
log.Fatal(err) log.Fatal(err)
} }
pcrs, err := validatePCRAttDoc(attDocRaw) measurements, err := validatePCRAttDoc(attDocRaw)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
if !*quiet { if !*quiet {
if err := printPCRs(os.Stdout, pcrs, *format); err != nil { if err := printPCRs(os.Stdout, measurements, *format); err != nil {
log.Fatal(err)
}
}
if *export != "" {
if err := exportToFile(*export, pcrs, &afero.Afero{Fs: afero.NewOsFs()}); err != nil {
log.Fatal(err) log.Fatal(err)
} }
} }
@ -104,16 +99,23 @@ func validatePCRAttDoc(attDocRaw []byte) (measurements.M, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
m := measurements.M{}
for idx, pcr := range attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs { for idx, pcr := range attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs {
if len(pcr) != 32 { if len(pcr) != 32 {
return nil, fmt.Errorf("incomplete PCR at index: %d", idx) return nil, fmt.Errorf("incomplete PCR at index: %d", idx)
} }
m[idx] = measurements.Measurement{
Expected: *(*[32]byte)(pcr),
WarnOnly: true,
} }
return attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs, nil }
return m, nil
} }
// printPCRs formates and prints PCRs to the given writer. // printPCRs formats and prints PCRs to the given writer.
// format can be one of 'json' or 'yaml'. If it doesnt match defaults to 'json'. // format can be one of 'json' or 'yaml'. If it doesn't match defaults to 'json'.
func printPCRs(w io.Writer, pcrs measurements.M, format string) error { func printPCRs(w io.Writer, pcrs measurements.M, format string) error {
switch format { switch format {
case "json": case "json":
@ -142,24 +144,3 @@ func printPCRsJSON(w io.Writer, pcrs measurements.M) error {
fmt.Fprintf(w, "%s", string(pcrJSON)) fmt.Fprintf(w, "%s", string(pcrJSON))
return nil return nil
} }
// exportToFile writes pcrs to a file, formatted to be valid Go code.
// Validity of the PCR map is not checked, and should be handled by the caller.
func exportToFile(path string, pcrs measurements.M, fs *afero.Afero) error {
goCode := `package pcrs
var pcrs = map[uint32][]byte{%s
}
`
pcrsFormatted := ""
for i := 0; i < len(pcrs); i++ {
pcrHex := fmt.Sprintf("%#02X", pcrs[uint32(i)][0])
for j := 1; j < len(pcrs[uint32(i)]); j++ {
pcrHex = fmt.Sprintf("%s, %#02X", pcrHex, pcrs[uint32(i)][j])
}
pcrsFormatted = pcrsFormatted + fmt.Sprintf("\n\t%d: {%s},", i, pcrHex)
}
return fs.WriteFile(path, []byte(fmt.Sprintf(goCode, pcrsFormatted)), 0o644)
}

View file

@ -8,7 +8,7 @@ package main
import ( import (
"bytes" "bytes"
"encoding/base64" "encoding/hex"
"encoding/json" "encoding/json"
"fmt" "fmt"
"testing" "testing"
@ -17,67 +17,13 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/google/go-tpm-tools/proto/attest" "github.com/google/go-tpm-tools/proto/attest"
"github.com/google/go-tpm-tools/proto/tpm" "github.com/google/go-tpm-tools/proto/tpm"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
goleak.VerifyTestMain(m, goleak.VerifyTestMain(m)
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
}
func TestExportToFile(t *testing.T) {
testCases := map[string]struct {
pcrs measurements.M
fs *afero.Afero
wantErr bool
}{
"file not writeable": {
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
},
fs: &afero.Afero{Fs: afero.NewReadOnlyFs(afero.NewMemMapFs())},
wantErr: true,
},
"file writeable": {
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
},
fs: &afero.Afero{Fs: afero.NewMemMapFs()},
wantErr: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
path := "test-file"
err := exportToFile(path, tc.pcrs, tc.fs)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
content, err := tc.fs.ReadFile(path)
require.NoError(err)
for _, pcr := range tc.pcrs {
for _, register := range pcr {
assert.Contains(string(content), fmt.Sprintf("%#02X", register))
}
}
}
})
}
} }
func TestValidatePCRAttDoc(t *testing.T) { func TestValidatePCRAttDoc(t *testing.T) {
@ -106,7 +52,7 @@ func TestValidatePCRAttDoc(t *testing.T) {
{ {
Pcrs: &tpm.PCRs{ Pcrs: &tpm.PCRs{
Hash: tpm.HashAlgo_SHA256, Hash: tpm.HashAlgo_SHA256,
Pcrs: measurements.M{ Pcrs: map[uint32][]byte{
0: {0x1, 0x2, 0x3}, 0: {0x1, 0x2, 0x3},
}, },
}, },
@ -123,8 +69,8 @@ func TestValidatePCRAttDoc(t *testing.T) {
{ {
Pcrs: &tpm.PCRs{ Pcrs: &tpm.PCRs{
Hash: tpm.HashAlgo_SHA256, Hash: tpm.HashAlgo_SHA256,
Pcrs: measurements.M{ Pcrs: map[uint32][]byte{
0: measurements.PCRWithAllBytes(0xAA), 0: bytes.Repeat([]byte{0xAA}, 32),
}, },
}, },
}, },
@ -150,7 +96,10 @@ func TestValidatePCRAttDoc(t *testing.T) {
require.NoError(json.Unmarshal(tc.attDocRaw, &attDoc)) require.NoError(json.Unmarshal(tc.attDocRaw, &attDoc))
qIdx, err := vtpm.GetSHA256QuoteIndex(attDoc.Attestation.Quotes) qIdx, err := vtpm.GetSHA256QuoteIndex(attDoc.Attestation.Quotes)
require.NoError(err) require.NoError(err)
assert.EqualValues(attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs, pcrs)
for pcrIdx, pcrVal := range pcrs {
assert.Equal(pcrVal.Expected[:], attDoc.Attestation.Quotes[qIdx].Pcrs.Pcrs[pcrIdx])
}
} }
}) })
} }
@ -164,31 +113,15 @@ func mustMarshalAttDoc(t *testing.T, attDoc vtpm.AttestationDocument) []byte {
func TestPrintPCRs(t *testing.T) { func TestPrintPCRs(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
pcrs measurements.M
format string format string
}{ }{
"json": { "json": {
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
},
format: "json", format: "json",
}, },
"empty format": { "empty format": {
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
},
format: "", format: "",
}, },
"yaml": { "yaml": {
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
},
format: "yaml", format: "yaml",
}, },
} }
@ -197,13 +130,19 @@ func TestPrintPCRs(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
pcrs := measurements.M{
0: measurements.WithAllBytes(0xAA, true),
1: measurements.WithAllBytes(0xBB, true),
2: measurements.WithAllBytes(0xCC, true),
}
var out bytes.Buffer var out bytes.Buffer
err := printPCRs(&out, tc.pcrs, tc.format) err := printPCRs(&out, pcrs, tc.format)
assert.NoError(err) assert.NoError(err)
for idx, pcr := range tc.pcrs { for idx, pcr := range pcrs {
assert.Contains(out.String(), fmt.Sprintf("%d", idx)) assert.Contains(out.String(), fmt.Sprintf("%d", idx))
assert.Contains(out.String(), base64.StdEncoding.EncodeToString(pcr)) assert.Contains(out.String(), hex.EncodeToString(pcr.Expected[:]))
} }
}) })
} }

View file

@ -307,12 +307,12 @@ func TestExportPCRs(t *testing.T) {
remoteAddr: "192.0.100.1:1234", remoteAddr: "192.0.100.1:1234",
connect: defaultConnect, connect: defaultConnect,
method: http.MethodPost, method: http.MethodPost,
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), message: mustMarshal(t, measurements.M{0: measurements.WithAllBytes(0xAA, false)}),
}, },
"incorrect method": { "incorrect method": {
remoteAddr: "192.0.100.1:1234", remoteAddr: "192.0.100.1:1234",
connect: defaultConnect, connect: defaultConnect,
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), message: mustMarshal(t, measurements.M{0: measurements.WithAllBytes(0xAA, false)}),
method: http.MethodGet, method: http.MethodGet,
wantErr: true, wantErr: true,
}, },
@ -321,7 +321,7 @@ func TestExportPCRs(t *testing.T) {
connect: &stubConnect{ connect: &stubConnect{
getNetworkErr: errors.New("error"), getNetworkErr: errors.New("error"),
}, },
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), message: mustMarshal(t, measurements.M{0: measurements.WithAllBytes(0xAA, false)}),
method: http.MethodPost, method: http.MethodPost,
wantErr: true, wantErr: true,
}, },
@ -336,7 +336,7 @@ func TestExportPCRs(t *testing.T) {
remoteAddr: "localhost", remoteAddr: "localhost",
connect: defaultConnect, connect: defaultConnect,
method: http.MethodPost, method: http.MethodPost,
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), message: mustMarshal(t, measurements.M{0: measurements.WithAllBytes(0xAA, false)}),
wantErr: true, wantErr: true,
}, },
} }

View file

@ -29,11 +29,10 @@ type Validator struct {
} }
// NewValidator create a new Validator structure and returns it. // NewValidator create a new Validator structure and returns it.
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, log vtpm.AttestationLogger) *Validator {
v := &Validator{} v := &Validator{}
v.Validator = vtpm.NewValidator( v.Validator = vtpm.NewValidator(
pcrs, pcrs,
enforcedPCRs,
getTrustedKey, getTrustedKey,
v.tpmEnabled, v.tpmEnabled,
vtpm.VerifyPKCS1v15, vtpm.VerifyPKCS1v15,

View file

@ -42,11 +42,10 @@ type Validator struct {
} }
// NewValidator initializes a new Azure validator with the provided PCR values. // NewValidator initializes a new Azure validator with the provided PCR values.
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, idKeyDigest []byte, enforceIDKeyDigest bool, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, idKeyDigest []byte, enforceIDKeyDigest bool, log vtpm.AttestationLogger) *Validator {
return &Validator{ return &Validator{
Validator: vtpm.NewValidator( Validator: vtpm.NewValidator(
pcrs, pcrs,
enforcedPCRs,
getTrustedKey(&azureInstanceInfo{}, idKeyDigest, enforceIDKeyDigest, log), getTrustedKey(&azureInstanceInfo{}, idKeyDigest, enforceIDKeyDigest, log),
validateCVM, validateCVM,
vtpm.VerifyPKCS1v15, vtpm.VerifyPKCS1v15,

View file

@ -189,7 +189,7 @@ func TestGetAttestationCert(t *testing.T) {
} }
require.NoError(err) require.NoError(err)
validator := NewValidator(measurements.M{}, []uint32{}, nil) validator := NewValidator(measurements.M{}, nil)
cert, err := x509.ParseCertificate(rootCert.Raw) cert, err := x509.ParseCertificate(rootCert.Raw)
require.NoError(err) require.NoError(err)
roots := x509.NewCertPool() roots := x509.NewCertPool()

View file

@ -33,13 +33,12 @@ type Validator struct {
} }
// NewValidator initializes a new Azure validator with the provided PCR values. // NewValidator initializes a new Azure validator with the provided PCR values.
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, log vtpm.AttestationLogger) *Validator {
rootPool := x509.NewCertPool() rootPool := x509.NewCertPool()
rootPool.AddCert(ameRoot) rootPool.AddCert(ameRoot)
v := &Validator{roots: rootPool} v := &Validator{roots: rootPool}
v.Validator = vtpm.NewValidator( v.Validator = vtpm.NewValidator(
pcrs, pcrs,
enforcedPCRs,
v.verifyAttestationKey, v.verifyAttestationKey,
validateVM, validateVM,
vtpm.VerifyPKCS1v15, vtpm.VerifyPKCS1v15,

View file

@ -35,11 +35,10 @@ type Validator struct {
} }
// NewValidator initializes a new GCP validator with the provided PCR values. // NewValidator initializes a new GCP validator with the provided PCR values.
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, log vtpm.AttestationLogger) *Validator {
return &Validator{ return &Validator{
Validator: vtpm.NewValidator( Validator: vtpm.NewValidator(
pcrs, pcrs,
enforcedPCRs,
trustedKeyFromGCEAPI(newInstanceClient), trustedKeyFromGCEAPI(newInstanceClient),
gceNonHostInfoEvent, gceNonHostInfoEvent,
vtpm.VerifyPKCS1v15, vtpm.VerifyPKCS1v15,

View file

@ -12,61 +12,31 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/sigstore" "github.com/edgelesssys/constellation/v2/internal/sigstore"
"gopkg.in/yaml.v2" "github.com/google/go-tpm/tpmutil"
"go.uber.org/multierr"
"gopkg.in/yaml.v3"
)
const (
// PCRIndexClusterID is a PCR we extend to mark the node as initialized.
// The value used to extend is a random generated 32 Byte value.
PCRIndexClusterID = tpmutil.Handle(15)
// PCRIndexOwnerID is a PCR we extend to mark the node as initialized.
// The value used to extend is derived from Constellation's master key.
// TODO: move to stable, non-debug PCR before use.
PCRIndexOwnerID = tpmutil.Handle(16)
) )
// M are Platform Configuration Register (PCR) values that make up the Measurements. // M are Platform Configuration Register (PCR) values that make up the Measurements.
type M map[uint32][]byte type M map[uint32]Measurement
// PCRWithAllBytes returns a PCR value where all 32 bytes are set to b.
func PCRWithAllBytes(b byte) []byte {
return bytes.Repeat([]byte{b}, 32)
}
// DefaultsFor provides the default measurements for given cloud provider.
func DefaultsFor(provider cloudprovider.Provider) M {
switch provider {
case cloudprovider.AWS:
return M{
8: PCRWithAllBytes(0x00),
11: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
case cloudprovider.Azure:
return M{
8: PCRWithAllBytes(0x00),
11: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
case cloudprovider.GCP:
return M{
0: {0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF},
8: PCRWithAllBytes(0x00),
11: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
case cloudprovider.QEMU:
return M{
8: PCRWithAllBytes(0x00),
11: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
default:
return nil
}
}
// FetchAndVerify fetches measurement and signature files via provided URLs, // FetchAndVerify fetches measurement and signature files via provided URLs,
// using client for download. The publicKey is used to verify the measurements. // using client for download. The publicKey is used to verify the measurements.
@ -83,8 +53,14 @@ func (m *M) FetchAndVerify(ctx context.Context, client *http.Client, measurement
if err := sigstore.VerifySignature(measurements, signature, publicKey); err != nil { if err := sigstore.VerifySignature(measurements, signature, publicKey); err != nil {
return "", err return "", err
} }
if err := yaml.NewDecoder(bytes.NewReader(measurements)).Decode(&m); err != nil {
return "", err if err := json.Unmarshal(measurements, m); err != nil {
if yamlErr := yaml.Unmarshal(measurements, m); yamlErr != nil {
return "", multierr.Append(
err,
fmt.Errorf("trying yaml format: %w", yamlErr),
)
}
} }
shaHash := sha256.Sum256(measurements) shaHash := sha256.Sum256(measurements)
@ -94,58 +70,231 @@ func (m *M) FetchAndVerify(ctx context.Context, client *http.Client, measurement
// CopyFrom copies over all values from other. Overwriting existing values, // CopyFrom copies over all values from other. Overwriting existing values,
// but keeping not specified values untouched. // but keeping not specified values untouched.
func (m M) CopyFrom(other M) { func (m *M) CopyFrom(other M) {
for idx := range other { for idx := range other {
m[idx] = other[idx] (*m)[idx] = other[idx]
} }
} }
// EqualTo tests whether the provided other Measurements are equal to these // EqualTo tests whether the provided other Measurements are equal to these
// measurements. // measurements.
func (m M) EqualTo(other M) bool { func (m *M) EqualTo(other M) bool {
if len(m) != len(other) { if len(*m) != len(other) {
return false return false
} }
for k, v := range m { for k, v := range *m {
if !bytes.Equal(v, other[k]) { otherExpected := other[k].Expected
if !bytes.Equal(v.Expected[:], otherExpected[:]) {
return false
}
if v.WarnOnly != other[k].WarnOnly {
return false return false
} }
} }
return true return true
} }
// MarshalYAML overwrites the default behaviour of writing out []byte not as // GetEnforced returns a list of all enforced Measurements,
// single bytes, but as a single base64 encoded string. // i.e. all Measurements that are not marked as WarnOnly.
func (m M) MarshalYAML() (any, error) { func (m *M) GetEnforced() []uint32 {
base64Map := make(map[uint32]string) var enforced []uint32
for idx, measurement := range *m {
for key, value := range m { if !measurement.WarnOnly {
base64Map[key] = base64.StdEncoding.EncodeToString(value[:]) enforced = append(enforced, idx)
}
}
return enforced
} }
return base64Map, nil // SetEnforced sets the WarnOnly flag to true for all Measurements
// that are NOT included in the provided list of enforced measurements.
func (m *M) SetEnforced(enforced []uint32) error {
newM := make(M)
// set all measurements to warn only
for idx, measurement := range *m {
newM[idx] = Measurement{
Expected: measurement.Expected,
WarnOnly: true,
}
} }
// UnmarshalYAML overwrites the default behaviour of reading []byte not as // set enforced measurements from list
// single bytes, but as a single base64 encoded string. for _, idx := range enforced {
func (m *M) UnmarshalYAML(unmarshal func(any) error) error { measurement, ok := newM[idx]
base64Map := make(map[uint32]string) if !ok {
err := unmarshal(base64Map) return fmt.Errorf("measurement %d not in list, but set to enforced", idx)
if err != nil { }
return err measurement.WarnOnly = false
newM[idx] = measurement
} }
*m = make(M) *m = newM
for key, value := range base64Map { return nil
measurement, err := base64.StdEncoding.DecodeString(value)
if err != nil {
return err
} }
(*m)[key] = measurement
// Measurement wraps expected PCR value and whether it is enforced.
type Measurement struct {
// Expected measurement value.
Expected [32]byte `json:"expected" yaml:"expected"`
// WarnOnly if set to true, a mismatching measurement will only result in a warning.
WarnOnly bool `json:"warnOnly" yaml:"warnOnly"`
}
// UnmarshalJSON reads a Measurement either as json object,
// or as a simple hex or base64 encoded string.
func (m *Measurement) UnmarshalJSON(b []byte) error {
var eM encodedMeasurement
if err := json.Unmarshal(b, &eM); err != nil {
// Unmarshalling failed, Measurement might be in legacy format,
// meaning a simple string instead of Measurement struct.
// TODO: remove with v2.4.0
if legacyErr := json.Unmarshal(b, &eM.Expected); legacyErr != nil {
return multierr.Append(
err,
fmt.Errorf("trying legacy format: %w", legacyErr),
)
}
}
if err := m.unmarshal(eM); err != nil {
return fmt.Errorf("unmarshalling json: %w", err)
} }
return nil return nil
} }
// MarshalJSON writes out a Measurement with Expected encoded as a hex string.
func (m Measurement) MarshalJSON() ([]byte, error) {
return json.Marshal(encodedMeasurement{
Expected: hex.EncodeToString(m.Expected[:]),
WarnOnly: m.WarnOnly,
})
}
// UnmarshalYAML reads a Measurement either as yaml object,
// or as a simple hex or base64 encoded string.
func (m *Measurement) UnmarshalYAML(unmarshal func(any) error) error {
var eM encodedMeasurement
if err := unmarshal(&eM); err != nil {
// Unmarshalling failed, Measurement might be in legacy format,
// meaning a simple string instead of Measurement struct.
// TODO: remove with v2.4.0
if legacyErr := unmarshal(&eM.Expected); legacyErr != nil {
return multierr.Append(
err,
fmt.Errorf("trying legacy format: %w", legacyErr),
)
}
}
if err := m.unmarshal(eM); err != nil {
return fmt.Errorf("unmarshalling yaml: %w", err)
}
return nil
}
// MarshalYAML writes out a Measurement with Expected encoded as a hex string.
func (m Measurement) MarshalYAML() (any, error) {
return encodedMeasurement{
Expected: hex.EncodeToString(m.Expected[:]),
WarnOnly: m.WarnOnly,
}, nil
}
// unmarshal parses a hex or base64 encoded Measurement.
func (m *Measurement) unmarshal(eM encodedMeasurement) error {
expected, err := hex.DecodeString(eM.Expected)
if err != nil {
// expected value might be in base64 legacy format
// TODO: Remove with v2.4.0
hexErr := err
expected, err = base64.StdEncoding.DecodeString(eM.Expected)
if err != nil {
return multierr.Append(
fmt.Errorf("invalid measurement: not a hex string %w", hexErr),
fmt.Errorf("not a base64 string: %w", err),
)
}
}
if len(expected) != 32 {
return fmt.Errorf("invalid measurement: invalid length: %d", len(expected))
}
m.Expected = *(*[32]byte)(expected)
m.WarnOnly = eM.WarnOnly
return nil
}
// WithAllBytes returns a measurement value where all 32 bytes are set to b.
func WithAllBytes(b byte, warnOnly bool) Measurement {
return Measurement{
Expected: *(*[32]byte)(bytes.Repeat([]byte{b}, 32)),
WarnOnly: warnOnly,
}
}
// DefaultsFor provides the default measurements for given cloud provider.
func DefaultsFor(provider cloudprovider.Provider) M {
switch provider {
case cloudprovider.AWS:
return M{
4: PlaceHolderMeasurement(),
8: WithAllBytes(0x00, false),
9: PlaceHolderMeasurement(),
11: WithAllBytes(0x00, false),
12: PlaceHolderMeasurement(),
13: WithAllBytes(0x00, false),
uint32(PCRIndexClusterID): WithAllBytes(0x00, false),
}
case cloudprovider.Azure:
return M{
4: PlaceHolderMeasurement(),
8: WithAllBytes(0x00, false),
9: PlaceHolderMeasurement(),
11: WithAllBytes(0x00, false),
12: PlaceHolderMeasurement(),
13: WithAllBytes(0x00, false),
uint32(PCRIndexClusterID): WithAllBytes(0x00, false),
}
case cloudprovider.GCP:
return M{
0: {
Expected: [32]byte{0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF},
WarnOnly: false,
},
4: PlaceHolderMeasurement(),
8: WithAllBytes(0x00, false),
9: PlaceHolderMeasurement(),
11: WithAllBytes(0x00, false),
12: PlaceHolderMeasurement(),
13: WithAllBytes(0x00, false),
uint32(PCRIndexClusterID): WithAllBytes(0x00, false),
}
case cloudprovider.QEMU:
return M{
4: PlaceHolderMeasurement(),
8: WithAllBytes(0x00, false),
9: PlaceHolderMeasurement(),
11: WithAllBytes(0x00, false),
12: PlaceHolderMeasurement(),
13: WithAllBytes(0x00, false),
uint32(PCRIndexClusterID): WithAllBytes(0x00, false),
}
default:
return nil
}
}
// PlaceHolderMeasurement returns a measurement with placeholder values for Expected.
func PlaceHolderMeasurement() Measurement {
return Measurement{
Expected: *(*[32]byte)(bytes.Repeat([]byte{0x12, 0x34}, 16)),
WarnOnly: false,
}
}
func getFromURL(ctx context.Context, client *http.Client, sourceURL *url.URL) ([]byte, error) { func getFromURL(ctx context.Context, client *http.Client, sourceURL *url.URL) ([]byte, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL.String(), http.NoBody) req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL.String(), http.NoBody)
if err != nil { if err != nil {
@ -166,3 +315,8 @@ func getFromURL(ctx context.Context, client *http.Client, sourceURL *url.URL) ([
} }
return content, nil return content, nil
} }
type encodedMeasurement struct {
Expected string `json:"expected" yaml:"expected"`
WarnOnly bool `json:"warnOnly" yaml:"warnOnly"`
}

View file

@ -8,7 +8,7 @@ package measurements
import ( import (
"context" "context"
"errors" "encoding/json"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
@ -17,32 +17,29 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
) )
func TestMarshalYAML(t *testing.T) { func TestMarshal(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
measurements M m Measurement
wantBase64Map map[uint32]string wantYAML string
wantJSON string
}{ }{
"valid measurements": { "measurement": {
measurements: M{ m: Measurement{
2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, Expected: [32]byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
}, },
wantBase64Map: map[uint32]string{ wantYAML: "expected: \"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef51239f752ce69dbc600f35\"\nwarnOnly: false",
2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=", wantJSON: `{"expected":"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef51239f752ce69dbc600f35","warnOnly":false}`,
3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=",
}, },
"warn only": {
m: Measurement{
Expected: [32]byte{1, 2, 3, 4}, // implicitly padded with 0s
WarnOnly: true,
}, },
"omit bytes": { wantYAML: "expected: \"0102030400000000000000000000000000000000000000000000000000000000\"\nwarnOnly: true",
measurements: M{ wantJSON: `{"expected":"0102030400000000000000000000000000000000000000000000000000000000","warnOnly":true}`,
2: []byte{},
3: []byte{1, 2, 3, 4},
},
wantBase64Map: map[uint32]string{
2: "",
3: "AQIDBA==",
},
}, },
} }
@ -51,62 +48,98 @@ func TestMarshalYAML(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
base64Map, err := tc.measurements.MarshalYAML() {
// YAML
yaml, err := yaml.Marshal(tc.m)
require.NoError(err) require.NoError(err)
assert.Equal(tc.wantBase64Map, base64Map) assert.YAMLEq(tc.wantYAML, string(yaml))
}
{
// JSON
json, err := json.Marshal(tc.m)
require.NoError(err)
assert.JSONEq(tc.wantJSON, string(json))
}
}) })
} }
} }
func TestUnmarshalYAML(t *testing.T) { func TestUnmarshal(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
inputBase64Map map[uint32]string inputYAML string
forceUnmarshalError bool inputJSON string
wantMeasurements M wantMeasurements M
wantErr bool wantErr bool
}{ }{
"valid measurements": { "valid measurements base64": {
inputBase64Map: map[uint32]string{ inputYAML: "2:\n expected: \"/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=\"\n3:\n expected: \"1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=\"",
2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=", inputJSON: `{"2":{"expected":"/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU="},"3":{"expected":"1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754="}}`,
3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=",
},
wantMeasurements: M{ wantMeasurements: M{
2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, 2: {
3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158}, Expected: [32]byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
},
3: {
Expected: [32]byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
},
},
},
"valid measurements hex": {
inputYAML: "2:\n expected: \"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef51239f752ce69dbc600f35\"\n3:\n expected: \"d5a4496d21dec9a5258ddb19c6feb53bb4d3c0463fe607f2488ddf4f1006ef9e\"",
inputJSON: `{"2":{"expected":"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef51239f752ce69dbc600f35"},"3":{"expected":"d5a4496d21dec9a5258ddb19c6feb53bb4d3c0463fe607f2488ddf4f1006ef9e"}}`,
wantMeasurements: M{
2: {
Expected: [32]byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
},
3: {
Expected: [32]byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
},
}, },
}, },
"empty bytes": { "empty bytes": {
inputBase64Map: map[uint32]string{ inputYAML: "2:\n expected: \"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\"\n3:\n expected: \"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\"",
2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", inputJSON: `{"2":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="},"3":{"expected":"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="}}`,
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
},
wantMeasurements: M{ wantMeasurements: M{
2: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 2: {
3: []byte{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, Expected: [32]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
3: {
Expected: [32]byte{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
}, },
}, },
"invalid base64": { "invalid base64": {
inputBase64Map: map[uint32]string{ inputYAML: "2:\n expected: \"This is not base64\"\n3:\n expected: \"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\"",
2: "This is not base64", inputJSON: `{"2":{"expected":"This is not base64"},"3":{"expected":"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA="}}`,
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
},
wantMeasurements: M{
2: []byte{},
3: []byte{1, 2, 3, 4},
},
wantErr: true, wantErr: true,
}, },
"simulated unmarshal error": { "legacy format": {
inputBase64Map: map[uint32]string{ inputYAML: "2: \"/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=\"\n3: \"1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=\"",
2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", inputJSON: `{"2":"/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=","3":"1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754="}`,
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
},
forceUnmarshalError: true,
wantMeasurements: M{ wantMeasurements: M{
2: []byte{}, 2: {
3: []byte{1, 2, 3, 4}, Expected: [32]byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
}, },
3: {
Expected: [32]byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
},
},
},
"invalid length hex": {
inputYAML: "2:\n expected: \"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef\"\n3:\n expected: \"d5a4496d21dec9a5258ddb19c6feb53bb4d3c0463f\"",
inputJSON: `{"2":{"expected":"fd5de9df350e3bc4410ac06bbfe5ccdeb93f53b9ef"},"3":{"expected":"d5a4496d21dec9a5258ddb19c6feb53bb4d3c0463f"}}`,
wantErr: true,
},
"invalid length base64": {
inputYAML: "2:\n expected: \"AA==\"\n3:\n expected: \"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA==\"",
inputJSON: `{"2":{"expected":"AA=="},"3":{"expected":"AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=="}}`,
wantErr: true,
},
"invalid format": {
inputYAML: "1:\n expected:\n someKey: 12\n anotherKey: 34",
inputJSON: `{"1":{"expected":{"someKey":12,"anotherKey":34}}}`,
wantErr: true, wantErr: true,
}, },
} }
@ -116,25 +149,31 @@ func TestUnmarshalYAML(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
{
// YAML
var m M var m M
err := m.UnmarshalYAML(func(i any) error { err := yaml.Unmarshal([]byte(tc.inputYAML), &m)
if base64Map, ok := i.(map[uint32]string); ok {
for key, value := range tc.inputBase64Map {
base64Map[key] = value
}
}
if tc.forceUnmarshalError {
return errors.New("unmarshal error")
}
return nil
})
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err, "yaml.Unmarshal should have failed")
} else { } else {
require.NoError(err) require.NoError(err, "yaml.Unmarshal failed")
assert.Equal(tc.wantMeasurements, m) assert.Equal(tc.wantMeasurements, m)
} }
}
{
// JSON
var m M
err := json.Unmarshal([]byte(tc.inputJSON), &m)
if tc.wantErr {
assert.Error(err, "json.Unmarshal should have failed")
} else {
require.NoError(err, "json.Unmarshal failed")
assert.Equal(tc.wantMeasurements, m)
}
}
}) })
} }
} }
@ -148,48 +187,48 @@ func TestMeasurementsCopyFrom(t *testing.T) {
"add to empty": { "add to empty": {
current: M{}, current: M{},
newMeasurements: M{ newMeasurements: M{
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, true),
2: PCRWithAllBytes(0x01), 2: WithAllBytes(0x01, true),
3: PCRWithAllBytes(0x02), 3: WithAllBytes(0x02, true),
}, },
wantMeasurements: M{ wantMeasurements: M{
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, true),
2: PCRWithAllBytes(0x01), 2: WithAllBytes(0x01, true),
3: PCRWithAllBytes(0x02), 3: WithAllBytes(0x02, true),
}, },
}, },
"keep existing": { "keep existing": {
current: M{ current: M{
4: PCRWithAllBytes(0x01), 4: WithAllBytes(0x01, false),
5: PCRWithAllBytes(0x02), 5: WithAllBytes(0x02, true),
}, },
newMeasurements: M{ newMeasurements: M{
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, true),
2: PCRWithAllBytes(0x01), 2: WithAllBytes(0x01, true),
3: PCRWithAllBytes(0x02), 3: WithAllBytes(0x02, true),
}, },
wantMeasurements: M{ wantMeasurements: M{
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, true),
2: PCRWithAllBytes(0x01), 2: WithAllBytes(0x01, true),
3: PCRWithAllBytes(0x02), 3: WithAllBytes(0x02, true),
4: PCRWithAllBytes(0x01), 4: WithAllBytes(0x01, false),
5: PCRWithAllBytes(0x02), 5: WithAllBytes(0x02, true),
}, },
}, },
"overwrite existing": { "overwrite existing": {
current: M{ current: M{
2: PCRWithAllBytes(0x04), 2: WithAllBytes(0x04, false),
3: PCRWithAllBytes(0x05), 3: WithAllBytes(0x05, false),
}, },
newMeasurements: M{ newMeasurements: M{
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, true),
2: PCRWithAllBytes(0x01), 2: WithAllBytes(0x01, true),
3: PCRWithAllBytes(0x02), 3: WithAllBytes(0x02, true),
}, },
wantMeasurements: M{ wantMeasurements: M{
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, true),
2: PCRWithAllBytes(0x01), 2: WithAllBytes(0x01, true),
3: PCRWithAllBytes(0x02), 3: WithAllBytes(0x02, true),
}, },
}, },
} }
@ -224,6 +263,22 @@ func urlMustParse(raw string) *url.URL {
} }
func TestMeasurementsFetchAndVerify(t *testing.T) { func TestMeasurementsFetchAndVerify(t *testing.T) {
// Cosign private key used to sign the measurements.
// Generated with: cosign generate-key-pair
// Password left empty.
//
// -----BEGIN ENCRYPTED COSIGN PRIVATE KEY-----
// eyJrZGYiOnsibmFtZSI6InNjcnlwdCIsInBhcmFtcyI6eyJOIjozMjc2OCwiciI6
// OCwicCI6MX0sInNhbHQiOiJlRHVYMWRQMGtIWVRnK0xkbjcxM0tjbFVJaU92eFVX
// VXgvNi9BbitFVk5BPSJ9LCJjaXBoZXIiOnsibmFtZSI6Im5hY2wvc2VjcmV0Ym94
// Iiwibm9uY2UiOiJwaWhLL2txNmFXa2hqSVVHR3RVUzhTVkdHTDNIWWp4TCJ9LCJj
// aXBoZXJ0ZXh0Ijoidm81SHVWRVFWcUZ2WFlQTTVPaTVaWHM5a255bndZU2dvcyth
// VklIeHcrOGFPamNZNEtvVjVmL3lHRHR0K3BHV2toanJPR1FLOWdBbmtsazFpQ0c5
// a2czUXpPQTZsU2JRaHgvZlowRVRZQ0hLeElncEdPRVRyTDlDenZDemhPZXVSOXJ6
// TDcvRjBBVy9vUDVqZXR3dmJMNmQxOEhjck9kWE8yVmYxY2w0YzNLZjVRcnFSZzlN
// dlRxQWFsNXJCNHNpY1JaMVhpUUJjb0YwNHc9PSJ9
// -----END ENCRYPTED COSIGN PRIVATE KEY-----
testCases := map[string]struct { testCases := map[string]struct {
measurements string measurements string
measurementsStatus int measurementsStatus int
@ -237,44 +292,66 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
"simple": { "simple": {
measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n", measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n",
measurementsStatus: http.StatusOK, measurementsStatus: http.StatusOK,
signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", signature: "MEUCIQDcHS2bLls7OrLHpQKuiFGXhPrTcehPDwgVyERHl4V02wIgeIxK4J9oJpXWRBjokbog2lgifRXuJK8ljlAID26MbHk=",
signatureStatus: http.StatusOK, signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantMeasurements: M{ wantMeasurements: M{
0: PCRWithAllBytes(0x00), 0: WithAllBytes(0x00, false),
}, },
wantSHA: "4cd9d6ed8d9322150dff7738994c5e2fabff35f3bae6f5c993412d13249a5e87", wantSHA: "4cd9d6ed8d9322150dff7738994c5e2fabff35f3bae6f5c993412d13249a5e87",
}, },
"404 measurements": { "json measurements": {
measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n", measurements: `{"0":{"expected":"0000000000000000000000000000000000000000000000000000000000000000","warnOnly":false}}`,
measurementsStatus: http.StatusNotFound, measurementsStatus: http.StatusOK,
signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", signature: "MEUCIQDh3nCgrdTiYWiV4NkiaZ6vxovj79Pk8V90mdWAnmCEOwIgMAVWAx5dW0saut+8X15SgtBEiKqEixYiSICSqqhxUMg=",
signatureStatus: http.StatusOK, signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantMeasurements: M{
0: WithAllBytes(0x00, false),
},
wantSHA: "1da09758c89537946496358f80b892e508563fcbbc695c90b6c16bf158e69c11",
},
"yaml measurements": {
measurements: "0:\n expected: \"0000000000000000000000000000000000000000000000000000000000000000\"\n warnOnly: false\n",
measurementsStatus: http.StatusOK,
signature: "MEUCIFzQdwBS92aJjY0bcIag1uQRl42lUSBmmjEvO0tM/N0ZAiEAvuWaP744qYMw5uEmc7BY4mm4Ij3TEqAWFgxNhFkckp4=",
signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantMeasurements: M{
0: WithAllBytes(0x00, false),
},
wantSHA: "c651cd419fd536c63cfc5349ad44da140a09987465e31192660059d383413807",
},
"404 measurements": {
measurements: `{"0":{"expected":"0000000000000000000000000000000000000000000000000000000000000000","warnOnly":false}}`,
measurementsStatus: http.StatusNotFound,
signature: "MEUCIQDh3nCgrdTiYWiV4NkiaZ6vxovj79Pk8V90mdWAnmCEOwIgMAVWAx5dW0saut+8X15SgtBEiKqEixYiSICSqqhxUMg=",
signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantError: true, wantError: true,
}, },
"404 signature": { "404 signature": {
measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n", measurements: `{"0":{"expected":"0000000000000000000000000000000000000000000000000000000000000000","warnOnly":false}}`,
measurementsStatus: http.StatusOK, measurementsStatus: http.StatusOK,
signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", signature: "MEUCIQDh3nCgrdTiYWiV4NkiaZ6vxovj79Pk8V90mdWAnmCEOwIgMAVWAx5dW0saut+8X15SgtBEiKqEixYiSICSqqhxUMg=",
signatureStatus: http.StatusNotFound, signatureStatus: http.StatusNotFound,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantError: true, wantError: true,
}, },
"broken signature": { "broken signature": {
measurements: "0: AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=\n", measurements: `{"0":{"expected":"0000000000000000000000000000000000000000000000000000000000000000","warnOnly":false}}`,
measurementsStatus: http.StatusOK, measurementsStatus: http.StatusOK,
signature: "AAAAAAs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", signature: "AAAAAAAA3nCgrdTiYWiV4NkiaZ6vxovj79Pk8V90mdWAnmCEOwIgMAVWAx5dW0saut+8X15SgtBEiKqEixYiSICSqqhxUMg=",
signatureStatus: http.StatusOK, signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantError: true, wantError: true,
}, },
"not yaml": { "not yaml or json": {
measurements: "This is some content to be signed!\n", measurements: "This is some content to be signed!\n",
measurementsStatus: http.StatusOK, measurementsStatus: http.StatusOK,
signature: "MEUCIQDzMN3yaiO9sxLGAaSA9YD8rLwzvOaZKWa/bzkcjImUFAIgXLLGzClYUd1dGbuEiY3O/g/eiwQYlyxqLQalxjFmz+8=", signature: "MEUCIQCGA/lSu5qCJgNNvgMaTKJ9rj6vQMecUDaQo3ukaiAfUgIgWoxXRoDKLY9naN7YgxokM7r2fwnyYk3M2WKJJO1g6yo=",
signatureStatus: http.StatusOK, signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAElWUhon39eAqzEC+/GP03oY4/MQg+\ngCDlEzkuOCybCHf+q766bve799L7Y5y5oRsHY1MrUCUwYF/tL7Sg7EYMsA==\n-----END PUBLIC KEY-----"), publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEu78QgxOOcao6U91CSzEXxrKhvFTt\nJHNy+eX6EMePtDm8CnDF9HSwnTlD0itGJ/XHPQA5YX10fJAqI1y+ehlFMw==\n-----END PUBLIC KEY-----"),
wantError: true, wantError: true,
}, },
} }
@ -322,30 +399,196 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
} }
} }
func TestPCRWithAllBytes(t *testing.T) { func TestGetEnforced(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
b byte input M
wantPCR []byte want map[uint32]struct{}
}{ }{
"0x00": { "only warnings": {
b: 0x00, input: M{
wantPCR: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 0: WithAllBytes(0x00, true),
1: WithAllBytes(0x01, true),
}, },
"0x01": { want: map[uint32]struct{}{},
b: 0x01, },
wantPCR: []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}, "all enforced": {
input: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, false),
},
want: map[uint32]struct{}{
0: {},
1: {},
},
},
"mixed": {
input: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, true),
2: WithAllBytes(0x02, false),
},
want: map[uint32]struct{}{
0: {},
2: {},
}, },
"0xFF": {
b: 0xFF,
wantPCR: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
}, },
} }
for name, tc := range testCases { for name, tc := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
pcr := PCRWithAllBytes(tc.b)
assert.Equal(tc.wantPCR, pcr) got := tc.input.GetEnforced()
enforced := map[uint32]struct{}{}
for _, id := range got {
enforced[id] = struct{}{}
}
assert.Equal(tc.want, enforced)
})
}
}
func TestSetEnforced(t *testing.T) {
testCases := map[string]struct {
input M
enforced []uint32
wantM M
wantErr bool
}{
"no enforced measurements": {
input: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, false),
},
enforced: []uint32{},
wantM: M{
0: WithAllBytes(0x00, true),
1: WithAllBytes(0x01, true),
},
},
"all enforced measurements": {
input: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, false),
},
enforced: []uint32{0, 1},
wantM: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, false),
},
},
"mixed": {
input: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, false),
2: WithAllBytes(0x02, false),
3: WithAllBytes(0x03, false),
},
enforced: []uint32{0, 2},
wantM: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, true),
2: WithAllBytes(0x02, false),
3: WithAllBytes(0x03, true),
},
},
"warn only to enforced": {
input: M{
0: WithAllBytes(0x00, true),
1: WithAllBytes(0x01, true),
},
enforced: []uint32{0, 1},
wantM: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0x01, false),
},
},
"more enforced than measurements": {
input: M{
0: WithAllBytes(0x00, true),
1: WithAllBytes(0x01, true),
},
enforced: []uint32{0, 1, 2},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := tc.input.SetEnforced(tc.enforced)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.True(tc.input.EqualTo(tc.wantM))
})
}
}
func TestWithAllBytes(t *testing.T) {
testCases := map[string]struct {
b byte
warnOnly bool
wantMeasurement Measurement
}{
"0x00 warnOnly": {
b: 0x00,
warnOnly: true,
wantMeasurement: Measurement{
Expected: [32]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
WarnOnly: true,
},
},
"0x00": {
b: 0x00,
warnOnly: false,
wantMeasurement: Measurement{
Expected: [32]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
WarnOnly: false,
},
},
"0x01 warnOnly": {
b: 0x01,
warnOnly: true,
wantMeasurement: Measurement{
Expected: [32]byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
WarnOnly: true,
},
},
"0x01": {
b: 0x01,
warnOnly: false,
wantMeasurement: Measurement{
Expected: [32]byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
WarnOnly: false,
},
},
"0xFF warnOnly": {
b: 0xFF,
warnOnly: true,
wantMeasurement: Measurement{
Expected: [32]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
WarnOnly: true,
},
},
"0xFF": {
b: 0xFF,
warnOnly: false,
wantMeasurement: Measurement{
Expected: [32]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
WarnOnly: false,
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
measurement := WithAllBytes(tc.b, tc.warnOnly)
assert.Equal(tc.wantMeasurement, measurement)
}) })
} }
} }
@ -358,33 +601,44 @@ func TestEqualTo(t *testing.T) {
}{ }{
"same values": { "same values": {
given: M{ given: M{
0: PCRWithAllBytes(0x00), 0: WithAllBytes(0x00, false),
1: PCRWithAllBytes(0xFF), 1: WithAllBytes(0xFF, false),
}, },
other: M{ other: M{
0: PCRWithAllBytes(0x00), 0: WithAllBytes(0x00, false),
1: PCRWithAllBytes(0xFF), 1: WithAllBytes(0xFF, false),
}, },
wantEqual: true, wantEqual: true,
}, },
"different number of elements": { "different number of elements": {
given: M{ given: M{
0: PCRWithAllBytes(0x00), 0: WithAllBytes(0x00, false),
1: PCRWithAllBytes(0xFF), 1: WithAllBytes(0xFF, false),
}, },
other: M{ other: M{
0: PCRWithAllBytes(0x00), 0: WithAllBytes(0x00, false),
}, },
wantEqual: false, wantEqual: false,
}, },
"different values": { "different values": {
given: M{ given: M{
0: PCRWithAllBytes(0x00), 0: WithAllBytes(0x00, false),
1: PCRWithAllBytes(0xFF), 1: WithAllBytes(0xFF, false),
}, },
other: M{ other: M{
0: PCRWithAllBytes(0xFF), 0: WithAllBytes(0xFF, false),
1: PCRWithAllBytes(0x00), 1: WithAllBytes(0x00, false),
},
wantEqual: false,
},
"different warn settings": {
given: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0xFF, false),
},
other: M{
0: WithAllBytes(0x00, false),
1: WithAllBytes(0xFF, true),
}, },
wantEqual: false, wantEqual: false,
}, },

View file

@ -22,11 +22,10 @@ type Validator struct {
} }
// NewValidator initializes a new QEMU validator with the provided PCR values. // NewValidator initializes a new QEMU validator with the provided PCR values.
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, log vtpm.AttestationLogger) *Validator {
return &Validator{ return &Validator{
Validator: vtpm.NewValidator( Validator: vtpm.NewValidator(
pcrs, pcrs,
enforcedPCRs,
unconditionalTrust, unconditionalTrust,
func(attestation vtpm.AttestationDocument) error { return nil }, func(attestation vtpm.AttestationDocument) error { return nil },
vtpm.VerifyPKCS1v15, vtpm.VerifyPKCS1v15,

View file

@ -16,6 +16,7 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
tpmClient "github.com/google/go-tpm-tools/client" tpmClient "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/proto/attest" "github.com/google/go-tpm-tools/proto/attest"
tpmProto "github.com/google/go-tpm-tools/proto/tpm" tpmProto "github.com/google/go-tpm-tools/proto/tpm"
@ -144,8 +145,7 @@ func (i *Issuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
// Validator handles validation of TPM based attestation. // Validator handles validation of TPM based attestation.
type Validator struct { type Validator struct {
expectedPCRs map[uint32][]byte expected measurements.M
enforcedPCRs map[uint32]struct{}
getTrustedKey GetTPMTrustedAttestationPublicKey getTrustedKey GetTPMTrustedAttestationPublicKey
validateCVM ValidateCVM validateCVM ValidateCVM
verifyUserData VerifyUserData verifyUserData VerifyUserData
@ -154,18 +154,11 @@ type Validator struct {
} }
// NewValidator returns a new Validator. // NewValidator returns a new Validator.
func NewValidator(expectedPCRs map[uint32][]byte, enforcedPCRs []uint32, getTrustedKey GetTPMTrustedAttestationPublicKey, func NewValidator(expected measurements.M, getTrustedKey GetTPMTrustedAttestationPublicKey,
validateCVM ValidateCVM, verifyUserData VerifyUserData, log AttestationLogger, validateCVM ValidateCVM, verifyUserData VerifyUserData, log AttestationLogger,
) *Validator { ) *Validator {
// Convert the enforced PCR list to a map for convenient and fast lookup
enforcedMap := make(map[uint32]struct{})
for _, pcr := range enforcedPCRs {
enforcedMap[pcr] = struct{}{}
}
return &Validator{ return &Validator{
expectedPCRs: expectedPCRs, expected: expected,
enforcedPCRs: enforcedMap,
getTrustedKey: getTrustedKey, getTrustedKey: getTrustedKey,
validateCVM: validateCVM, validateCVM: validateCVM,
verifyUserData: verifyUserData, verifyUserData: verifyUserData,
@ -212,9 +205,9 @@ func (v *Validator) Validate(attDocRaw []byte, nonce []byte) ([]byte, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
for idx, pcr := range v.expectedPCRs { for idx, pcr := range v.expected {
if !bytes.Equal(pcr, attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs[idx]) { if !bytes.Equal(pcr.Expected[:], attDoc.Attestation.Quotes[quoteIdx].Pcrs.Pcrs[idx]) {
if _, ok := v.enforcedPCRs[idx]; ok { if !pcr.WarnOnly {
return nil, fmt.Errorf("untrusted PCR value at PCR index %d", idx) return nil, fmt.Errorf("untrusted PCR value at PCR index %d", idx)
} }
if v.log != nil { if v.log != nil {
@ -263,8 +256,8 @@ func VerifyPKCS1v15(pub crypto.PublicKey, hash crypto.Hash, hashed, sig []byte)
return rsa.VerifyPKCS1v15(key, hash, hashed, sig) return rsa.VerifyPKCS1v15(key, hash, hashed, sig)
} }
// GetSelectedPCRs returns a map of the selected PCR hashes. // GetSelectedMeasurements returns a map of Measurments for the PCRs in selection.
func GetSelectedPCRs(open TPMOpenFunc, selection tpm2.PCRSelection) (map[uint32][]byte, error) { func GetSelectedMeasurements(open TPMOpenFunc, selection tpm2.PCRSelection) (measurements.M, error) {
tpm, err := open() tpm, err := open()
if err != nil { if err != nil {
return nil, err return nil, err
@ -276,5 +269,15 @@ func GetSelectedPCRs(open TPMOpenFunc, selection tpm2.PCRSelection) (map[uint32]
return nil, err return nil, err
} }
return pcrList.Pcrs, nil m := make(measurements.M)
for i, pcr := range pcrList.Pcrs {
if len(pcr) != 32 {
return nil, fmt.Errorf("invalid measurement: invalid length: %d", len(pcr))
}
m[i] = measurements.Measurement{
Expected: *(*[32]byte)(pcr),
}
}
return m, nil
} }

View file

@ -14,6 +14,7 @@ import (
"io" "io"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
tpmsim "github.com/edgelesssys/constellation/v2/internal/attestation/simulator" tpmsim "github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
tpmclient "github.com/google/go-tpm-tools/client" tpmclient "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm-tools/proto/attest" "github.com/google/go-tpm-tools/proto/attest"
@ -64,14 +65,14 @@ func TestValidate(t *testing.T) {
return pubArea.Key() return pubArea.Key()
} }
testExpectedPCRs := map[uint32][]byte{ testExpectedPCRs := measurements.M{
0: {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 0: measurements.WithAllBytes(0x00, true),
1: {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 1: measurements.WithAllBytes(0x00, true),
} }
warnLog := &testAttestationLogger{} warnLog := &testAttestationLogger{}
issuer := NewIssuer(newSimTPMWithEventLog, tpmclient.AttestationKeyRSA, fakeGetInstanceInfo) issuer := NewIssuer(newSimTPMWithEventLog, tpmclient.AttestationKeyRSA, fakeGetInstanceInfo)
validator := NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog) validator := NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog)
nonce := []byte{1, 2, 3, 4} nonce := []byte{1, 2, 3, 4}
challenge := []byte("Constellation") challenge := []byte("Constellation")
@ -89,18 +90,28 @@ func TestValidate(t *testing.T) {
require.NoError(err) require.NoError(err)
require.Equal(challenge, out) require.Equal(challenge, out)
enforcedPCRs := []uint32{0, 1} expectedPCRs := measurements.M{
expectedPCRs := map[uint32][]byte{ 0: measurements.WithAllBytes(0x00, true),
0: {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 1: measurements.WithAllBytes(0x00, true),
1: {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, 2: measurements.Measurement{
2: {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20}, Expected: [32]byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20},
3: {0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40}, WarnOnly: true,
4: {0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, 0x60}, },
5: {0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80}, 3: measurements.Measurement{
Expected: [32]byte{0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40},
WarnOnly: true,
},
4: measurements.Measurement{
Expected: [32]byte{0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, 0x60},
WarnOnly: true,
},
5: measurements.Measurement{
Expected: [32]byte{0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80},
WarnOnly: true,
},
} }
warningValidator := NewValidator( warningValidator := NewValidator(
expectedPCRs, expectedPCRs,
enforcedPCRs,
fakeGetTrustedKey, fakeGetTrustedKey,
fakeValidateCVM, fakeValidateCVM,
VerifyPKCS1v15, VerifyPKCS1v15,
@ -109,7 +120,7 @@ func TestValidate(t *testing.T) {
out, err = warningValidator.Validate(attDocRaw, nonce) out, err = warningValidator.Validate(attDocRaw, nonce)
require.NoError(err) require.NoError(err)
assert.Equal(t, challenge, out) assert.Equal(t, challenge, out)
assert.Len(t, warnLog.warnings, len(expectedPCRs)-len(enforcedPCRs)) assert.Len(t, warnLog.warnings, 4)
testCases := map[string]struct { testCases := map[string]struct {
validator *Validator validator *Validator
@ -118,13 +129,13 @@ func TestValidate(t *testing.T) {
wantErr bool wantErr bool
}{ }{
"invalid nonce": { "invalid nonce": {
validator: NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog),
attDoc: mustMarshalAttestation(attDoc, require), attDoc: mustMarshalAttestation(attDoc, require),
nonce: []byte{4, 3, 2, 1}, nonce: []byte{4, 3, 2, 1},
wantErr: true, wantErr: true,
}, },
"invalid signature": { "invalid signature": {
validator: NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog),
attDoc: mustMarshalAttestation(AttestationDocument{ attDoc: mustMarshalAttestation(AttestationDocument{
Attestation: attDoc.Attestation, Attestation: attDoc.Attestation,
InstanceInfo: attDoc.InstanceInfo, InstanceInfo: attDoc.InstanceInfo,
@ -137,7 +148,6 @@ func TestValidate(t *testing.T) {
"untrusted attestation public key": { "untrusted attestation public key": {
validator: NewValidator( validator: NewValidator(
testExpectedPCRs, testExpectedPCRs,
[]uint32{0, 1},
func(akPub, instanceInfo []byte) (crypto.PublicKey, error) { func(akPub, instanceInfo []byte) (crypto.PublicKey, error) {
return nil, errors.New("untrusted") return nil, errors.New("untrusted")
}, },
@ -149,7 +159,6 @@ func TestValidate(t *testing.T) {
"not a CVM": { "not a CVM": {
validator: NewValidator( validator: NewValidator(
testExpectedPCRs, testExpectedPCRs,
[]uint32{0, 1},
fakeGetTrustedKey, fakeGetTrustedKey,
func(attestation AttestationDocument) error { func(attestation AttestationDocument) error {
return errors.New("untrusted") return errors.New("untrusted")
@ -161,10 +170,12 @@ func TestValidate(t *testing.T) {
}, },
"untrusted PCRs": { "untrusted PCRs": {
validator: NewValidator( validator: NewValidator(
map[uint32][]byte{ measurements.M{
0: {0xFF}, 0: measurements.Measurement{
Expected: [32]byte{0xFF},
WarnOnly: false,
},
}, },
[]uint32{0},
fakeGetTrustedKey, fakeGetTrustedKey,
fakeValidateCVM, fakeValidateCVM,
VerifyPKCS1v15, warnLog), VerifyPKCS1v15, warnLog),
@ -173,7 +184,7 @@ func TestValidate(t *testing.T) {
wantErr: true, wantErr: true,
}, },
"no sha256 quote": { "no sha256 quote": {
validator: NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog),
attDoc: mustMarshalAttestation(AttestationDocument{ attDoc: mustMarshalAttestation(AttestationDocument{
Attestation: &attest.Attestation{ Attestation: &attest.Attestation{
AkPub: attDoc.Attestation.AkPub, AkPub: attDoc.Attestation.AkPub,
@ -191,7 +202,7 @@ func TestValidate(t *testing.T) {
wantErr: true, wantErr: true,
}, },
"invalid attestation document": { "invalid attestation document": {
validator: NewValidator(testExpectedPCRs, []uint32{0, 1}, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog), validator: NewValidator(testExpectedPCRs, fakeGetTrustedKey, fakeValidateCVM, VerifyPKCS1v15, warnLog),
attDoc: []byte("invalid attestation"), attDoc: []byte("invalid attestation"),
nonce: nonce, nonce: nonce,
wantErr: true, wantErr: true,
@ -350,7 +361,7 @@ func TestGetSHA256QuoteIndex(t *testing.T) {
} }
} }
func TestGetSelectedPCRs(t *testing.T) { func TestGetSelectedMeasurements(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
openFunc TPMOpenFunc openFunc TPMOpenFunc
pcrSelection tpm2.PCRSelection pcrSelection tpm2.PCRSelection
@ -386,17 +397,13 @@ func TestGetSelectedPCRs(t *testing.T) {
require := require.New(t) require := require.New(t)
assert := assert.New(t) assert := assert.New(t)
pcrs, err := GetSelectedPCRs(tc.openFunc, tc.pcrSelection) pcrs, err := GetSelectedMeasurements(tc.openFunc, tc.pcrSelection)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
} else { return
}
require.NoError(err) require.NoError(err)
assert.Len(pcrs, len(tc.pcrSelection.PCRs))
assert.Equal(len(pcrs), len(tc.pcrSelection.PCRs))
for _, pcr := range pcrs {
assert.Len(pcr, 32)
}
}
}) })
} }
} }

View file

@ -9,18 +9,8 @@ package vtpm
import ( import (
"errors" "errors"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2"
"github.com/google/go-tpm/tpmutil"
)
const (
// PCRIndexClusterID is a PCR we extend to mark the node as initialized.
// The value used to extend is a random generated 32 Byte value.
PCRIndexClusterID = tpmutil.Handle(15)
// PCRIndexOwnerID is a PCR we extend to mark the node as initialized.
// The value used to extend is derived from Constellation's master key.
// TODO: move to stable, non-debug PCR before use.
PCRIndexOwnerID = tpmutil.Handle(16)
) )
// MarkNodeAsBootstrapped marks a node as initialized by extending PCRs. // MarkNodeAsBootstrapped marks a node as initialized by extending PCRs.
@ -32,7 +22,7 @@ func MarkNodeAsBootstrapped(openTPM TPMOpenFunc, clusterID []byte) error {
defer tpm.Close() defer tpm.Close()
// clusterID is used to uniquely identify this running instance of Constellation // clusterID is used to uniquely identify this running instance of Constellation
return tpm2.PCREvent(tpm, PCRIndexClusterID, clusterID) return tpm2.PCREvent(tpm, measurements.PCRIndexClusterID, clusterID)
} }
// IsNodeBootstrapped checks if a node is already bootstrapped by reading PCRs. // IsNodeBootstrapped checks if a node is already bootstrapped by reading PCRs.
@ -43,7 +33,7 @@ func IsNodeBootstrapped(openTPM TPMOpenFunc) (bool, error) {
} }
defer tpm.Close() defer tpm.Close()
idxClusterID := int(PCRIndexClusterID) idxClusterID := int(measurements.PCRIndexClusterID)
pcrs, err := tpm2.ReadPCRs(tpm, tpm2.PCRSelection{ pcrs, err := tpm2.ReadPCRs(tpm, tpm2.PCRSelection{
Hash: tpm2.AlgSHA256, Hash: tpm2.AlgSHA256,
PCRs: []int{idxClusterID}, PCRs: []int{idxClusterID},

View file

@ -11,6 +11,7 @@ import (
"io" "io"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/simulator" "github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
"github.com/google/go-tpm-tools/client" "github.com/google/go-tpm-tools/client"
"github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2"
@ -45,7 +46,7 @@ func TestMarkNodeAsBootstrapped(t *testing.T) {
require.NoError(err) require.NoError(err)
for i := range pcrs { for i := range pcrs {
assert.NotEqual(pcrs[i].Pcrs[uint32(PCRIndexClusterID)], pcrsInitialized[i].Pcrs[uint32(PCRIndexClusterID)]) assert.NotEqual(pcrs[i].Pcrs[uint32(measurements.PCRIndexClusterID)], pcrsInitialized[i].Pcrs[uint32(measurements.PCRIndexClusterID)])
} }
} }
@ -76,7 +77,7 @@ func TestIsNodeInitialized(t *testing.T) {
require.NoError(err) require.NoError(err)
defer tpm.Close() defer tpm.Close()
if tc.pcrValueClusterID != nil { if tc.pcrValueClusterID != nil {
require.NoError(tpm2.PCREvent(tpm, PCRIndexClusterID, tc.pcrValueClusterID)) require.NoError(tpm2.PCREvent(tpm, measurements.PCRIndexClusterID, tc.pcrValueClusterID))
} }
initialized, err := IsNodeBootstrapped(func() (io.ReadWriteCloser, error) { initialized, err := IsNodeBootstrapped(func() (io.ReadWriteCloser, error) {
return &simTPMNOPCloser{tpm}, nil return &simTPMNOPCloser{tpm}, nil

View file

@ -138,10 +138,7 @@ type AWSConfig struct {
IAMProfileWorkerNodes string `yaml:"iamProfileWorkerNodes" validate:"required"` IAMProfileWorkerNodes string `yaml:"iamProfileWorkerNodes" validate:"required"`
// description: | // description: |
// Expected VM measurements. // Expected VM measurements.
Measurements Measurements `yaml:"measurements"` Measurements Measurements `yaml:"measurements" validate:"required,no_placeholders"`
// description: |
// List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning.
EnforcedMeasurements []uint32 `yaml:"enforcedMeasurements"`
} }
// AzureConfig are Azure specific configuration values used by the CLI. // AzureConfig are Azure specific configuration values used by the CLI.
@ -190,10 +187,7 @@ type AzureConfig struct {
EnforceIDKeyDigest *bool `yaml:"enforceIdKeyDigest" validate:"required"` EnforceIDKeyDigest *bool `yaml:"enforceIdKeyDigest" validate:"required"`
// description: | // description: |
// Expected confidential VM measurements. // Expected confidential VM measurements.
Measurements Measurements `yaml:"measurements"` Measurements Measurements `yaml:"measurements" validate:"required,no_placeholders"`
// description: |
// List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning.
EnforcedMeasurements []uint32 `yaml:"enforcedMeasurements"`
} }
// GCPConfig are GCP specific configuration values used by the CLI. // GCPConfig are GCP specific configuration values used by the CLI.
@ -221,10 +215,7 @@ type GCPConfig struct {
DeployCSIDriver *bool `yaml:"deployCSIDriver" validate:"required"` DeployCSIDriver *bool `yaml:"deployCSIDriver" validate:"required"`
// description: | // description: |
// Expected confidential VM measurements. // Expected confidential VM measurements.
Measurements Measurements `yaml:"measurements"` Measurements Measurements `yaml:"measurements" validate:"required,no_placeholders"`
// description: |
// List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning.
EnforcedMeasurements []uint32 `yaml:"enforcedMeasurements"`
} }
// QEMUConfig holds config information for QEMU based Constellation deployments. // QEMUConfig holds config information for QEMU based Constellation deployments.
@ -255,10 +246,7 @@ type QEMUConfig struct {
Firmware string `yaml:"firmware"` Firmware string `yaml:"firmware"`
// description: | // description: |
// Measurement used to enable measured boot. // Measurement used to enable measured boot.
Measurements Measurements `yaml:"measurements"` Measurements Measurements `yaml:"measurements" validate:"required,no_placeholders"`
// description: |
// List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning.
EnforcedMeasurements []uint32 `yaml:"enforcedMeasurements"`
} }
// Default returns a struct with the default config. // Default returns a struct with the default config.
@ -276,7 +264,6 @@ func Default() *Config {
IAMProfileControlPlane: "", IAMProfileControlPlane: "",
IAMProfileWorkerNodes: "", IAMProfileWorkerNodes: "",
Measurements: measurements.DefaultsFor(cloudprovider.AWS), Measurements: measurements.DefaultsFor(cloudprovider.AWS),
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
}, },
Azure: &AzureConfig{ Azure: &AzureConfig{
SubscriptionID: "", SubscriptionID: "",
@ -292,7 +279,6 @@ func Default() *Config {
ConfidentialVM: func() *bool { b := true; return &b }(), ConfidentialVM: func() *bool { b := true; return &b }(),
SecureBoot: func() *bool { b := false; return &b }(), SecureBoot: func() *bool { b := false; return &b }(),
Measurements: measurements.DefaultsFor(cloudprovider.Azure), Measurements: measurements.DefaultsFor(cloudprovider.Azure),
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
}, },
GCP: &GCPConfig{ GCP: &GCPConfig{
Project: "", Project: "",
@ -303,7 +289,6 @@ func Default() *Config {
StateDiskType: "pd-ssd", StateDiskType: "pd-ssd",
DeployCSIDriver: func() *bool { b := true; return &b }(), DeployCSIDriver: func() *bool { b := true; return &b }(),
Measurements: measurements.DefaultsFor(cloudprovider.GCP), Measurements: measurements.DefaultsFor(cloudprovider.GCP),
EnforcedMeasurements: []uint32{0, 4, 8, 9, 11, 12, 13, 15},
}, },
QEMU: &QEMUConfig{ QEMU: &QEMUConfig{
ImageFormat: "raw", ImageFormat: "raw",
@ -314,7 +299,6 @@ func Default() *Config {
LibvirtContainerImage: versions.LibvirtImage, LibvirtContainerImage: versions.LibvirtImage,
NVRAM: "production", NVRAM: "production",
Measurements: measurements.DefaultsFor(cloudprovider.QEMU), Measurements: measurements.DefaultsFor(cloudprovider.QEMU),
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
}, },
}, },
KubernetesVersion: string(versions.Default), KubernetesVersion: string(versions.Default),
@ -446,18 +430,18 @@ func (c *Config) EnforcesIDKeyDigest() bool {
return c.Provider.Azure != nil && c.Provider.Azure.EnforceIDKeyDigest != nil && *c.Provider.Azure.EnforceIDKeyDigest return c.Provider.Azure != nil && c.Provider.Azure.EnforceIDKeyDigest != nil && *c.Provider.Azure.EnforceIDKeyDigest
} }
// GetEnforcedPCRs returns the list of enforced PCRs for the configured cloud provider. // EnforcedPCRs returns the list of enforced PCRs for the configured cloud provider.
func (c *Config) GetEnforcedPCRs() []uint32 { func (c *Config) EnforcedPCRs() []uint32 {
provider := c.GetProvider() provider := c.GetProvider()
switch provider { switch provider {
case cloudprovider.AWS: case cloudprovider.AWS:
return c.Provider.AWS.EnforcedMeasurements return c.Provider.AWS.Measurements.GetEnforced()
case cloudprovider.Azure: case cloudprovider.Azure:
return c.Provider.Azure.EnforcedMeasurements return c.Provider.Azure.Measurements.GetEnforced()
case cloudprovider.GCP: case cloudprovider.GCP:
return c.Provider.GCP.EnforcedMeasurements return c.Provider.GCP.Measurements.GetEnforced()
case cloudprovider.QEMU: case cloudprovider.QEMU:
return c.Provider.QEMU.EnforcedMeasurements return c.Provider.QEMU.Measurements.GetEnforced()
default: default:
return nil return nil
} }
@ -499,6 +483,14 @@ func (c *Config) Validate() error {
return err return err
} }
if err := validate.RegisterTranslation("no_placeholders", trans, registerContainsPlaceholderError, translateContainsPlaceholderError); err != nil {
return err
}
if err := validate.RegisterValidation("no_placeholders", validateNoPlaceholder); err != nil {
return err
}
if err := validate.RegisterValidation("safe_image", validateImage); err != nil { if err := validate.RegisterValidation("safe_image", validateImage); err != nil {
return err return err
} }

View file

@ -157,7 +157,7 @@ func init() {
FieldName: "aws", FieldName: "aws",
}, },
} }
AWSConfigDoc.Fields = make([]encoder.Doc, 8) AWSConfigDoc.Fields = make([]encoder.Doc, 7)
AWSConfigDoc.Fields[0].Name = "region" AWSConfigDoc.Fields[0].Name = "region"
AWSConfigDoc.Fields[0].Type = "string" AWSConfigDoc.Fields[0].Type = "string"
AWSConfigDoc.Fields[0].Note = "" AWSConfigDoc.Fields[0].Note = ""
@ -193,11 +193,6 @@ func init() {
AWSConfigDoc.Fields[6].Note = "" AWSConfigDoc.Fields[6].Note = ""
AWSConfigDoc.Fields[6].Description = "Expected VM measurements." AWSConfigDoc.Fields[6].Description = "Expected VM measurements."
AWSConfigDoc.Fields[6].Comments[encoder.LineComment] = "Expected VM measurements." AWSConfigDoc.Fields[6].Comments[encoder.LineComment] = "Expected VM measurements."
AWSConfigDoc.Fields[7].Name = "enforcedMeasurements"
AWSConfigDoc.Fields[7].Type = "[]uint32"
AWSConfigDoc.Fields[7].Note = ""
AWSConfigDoc.Fields[7].Description = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
AWSConfigDoc.Fields[7].Comments[encoder.LineComment] = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
AzureConfigDoc.Type = "AzureConfig" AzureConfigDoc.Type = "AzureConfig"
AzureConfigDoc.Comments[encoder.LineComment] = "AzureConfig are Azure specific configuration values used by the CLI." AzureConfigDoc.Comments[encoder.LineComment] = "AzureConfig are Azure specific configuration values used by the CLI."
@ -208,7 +203,7 @@ func init() {
FieldName: "azure", FieldName: "azure",
}, },
} }
AzureConfigDoc.Fields = make([]encoder.Doc, 16) AzureConfigDoc.Fields = make([]encoder.Doc, 15)
AzureConfigDoc.Fields[0].Name = "subscription" AzureConfigDoc.Fields[0].Name = "subscription"
AzureConfigDoc.Fields[0].Type = "string" AzureConfigDoc.Fields[0].Type = "string"
AzureConfigDoc.Fields[0].Note = "" AzureConfigDoc.Fields[0].Note = ""
@ -284,11 +279,6 @@ func init() {
AzureConfigDoc.Fields[14].Note = "" AzureConfigDoc.Fields[14].Note = ""
AzureConfigDoc.Fields[14].Description = "Expected confidential VM measurements." AzureConfigDoc.Fields[14].Description = "Expected confidential VM measurements."
AzureConfigDoc.Fields[14].Comments[encoder.LineComment] = "Expected confidential VM measurements." AzureConfigDoc.Fields[14].Comments[encoder.LineComment] = "Expected confidential VM measurements."
AzureConfigDoc.Fields[15].Name = "enforcedMeasurements"
AzureConfigDoc.Fields[15].Type = "[]uint32"
AzureConfigDoc.Fields[15].Note = ""
AzureConfigDoc.Fields[15].Description = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
AzureConfigDoc.Fields[15].Comments[encoder.LineComment] = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
GCPConfigDoc.Type = "GCPConfig" GCPConfigDoc.Type = "GCPConfig"
GCPConfigDoc.Comments[encoder.LineComment] = "GCPConfig are GCP specific configuration values used by the CLI." GCPConfigDoc.Comments[encoder.LineComment] = "GCPConfig are GCP specific configuration values used by the CLI."
@ -299,7 +289,7 @@ func init() {
FieldName: "gcp", FieldName: "gcp",
}, },
} }
GCPConfigDoc.Fields = make([]encoder.Doc, 9) GCPConfigDoc.Fields = make([]encoder.Doc, 8)
GCPConfigDoc.Fields[0].Name = "project" GCPConfigDoc.Fields[0].Name = "project"
GCPConfigDoc.Fields[0].Type = "string" GCPConfigDoc.Fields[0].Type = "string"
GCPConfigDoc.Fields[0].Note = "" GCPConfigDoc.Fields[0].Note = ""
@ -340,11 +330,6 @@ func init() {
GCPConfigDoc.Fields[7].Note = "" GCPConfigDoc.Fields[7].Note = ""
GCPConfigDoc.Fields[7].Description = "Expected confidential VM measurements." GCPConfigDoc.Fields[7].Description = "Expected confidential VM measurements."
GCPConfigDoc.Fields[7].Comments[encoder.LineComment] = "Expected confidential VM measurements." GCPConfigDoc.Fields[7].Comments[encoder.LineComment] = "Expected confidential VM measurements."
GCPConfigDoc.Fields[8].Name = "enforcedMeasurements"
GCPConfigDoc.Fields[8].Type = "[]uint32"
GCPConfigDoc.Fields[8].Note = ""
GCPConfigDoc.Fields[8].Description = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
GCPConfigDoc.Fields[8].Comments[encoder.LineComment] = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
QEMUConfigDoc.Type = "QEMUConfig" QEMUConfigDoc.Type = "QEMUConfig"
QEMUConfigDoc.Comments[encoder.LineComment] = "QEMUConfig holds config information for QEMU based Constellation deployments." QEMUConfigDoc.Comments[encoder.LineComment] = "QEMUConfig holds config information for QEMU based Constellation deployments."
@ -355,7 +340,7 @@ func init() {
FieldName: "qemu", FieldName: "qemu",
}, },
} }
QEMUConfigDoc.Fields = make([]encoder.Doc, 10) QEMUConfigDoc.Fields = make([]encoder.Doc, 9)
QEMUConfigDoc.Fields[0].Name = "imageFormat" QEMUConfigDoc.Fields[0].Name = "imageFormat"
QEMUConfigDoc.Fields[0].Type = "string" QEMUConfigDoc.Fields[0].Type = "string"
QEMUConfigDoc.Fields[0].Note = "" QEMUConfigDoc.Fields[0].Note = ""
@ -401,11 +386,6 @@ func init() {
QEMUConfigDoc.Fields[8].Note = "" QEMUConfigDoc.Fields[8].Note = ""
QEMUConfigDoc.Fields[8].Description = "Measurement used to enable measured boot." QEMUConfigDoc.Fields[8].Description = "Measurement used to enable measured boot."
QEMUConfigDoc.Fields[8].Comments[encoder.LineComment] = "Measurement used to enable measured boot." QEMUConfigDoc.Fields[8].Comments[encoder.LineComment] = "Measurement used to enable measured boot."
QEMUConfigDoc.Fields[9].Name = "enforcedMeasurements"
QEMUConfigDoc.Fields[9].Type = "[]uint32"
QEMUConfigDoc.Fields[9].Note = ""
QEMUConfigDoc.Fields[9].Description = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
QEMUConfigDoc.Fields[9].Comments[encoder.LineComment] = "List of values that should be enforced to be equal to the ones from the measurement list. Any non-equal values not in this list will only result in a warning."
} }
func (_ Config) Doc() *encoder.Doc { func (_ Config) Doc() *encoder.Doc {

View file

@ -128,6 +128,7 @@ func TestNewWithDefaultOptions(t *testing.T) {
c.Provider.Azure.ResourceGroup = "test" c.Provider.Azure.ResourceGroup = "test"
c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity" c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity"
c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e" c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e"
c.Provider.Azure.Measurements = measurements.M{15: measurements.WithAllBytes(0x00, false)}
return c return c
}(), }(),
envToSet: map[string]string{ envToSet: map[string]string{
@ -147,6 +148,7 @@ func TestNewWithDefaultOptions(t *testing.T) {
c.Provider.Azure.ClientSecretValue = "other-value" // < Note secret set in config, as well. c.Provider.Azure.ClientSecretValue = "other-value" // < Note secret set in config, as well.
c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity" c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity"
c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e" c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e"
c.Provider.Azure.Measurements = measurements.M{15: measurements.WithAllBytes(0x00, false)}
return c return c
}(), }(),
envToSet: map[string]string{ envToSet: map[string]string{
@ -182,9 +184,9 @@ func TestNewWithDefaultOptions(t *testing.T) {
} }
func TestValidate(t *testing.T) { func TestValidate(t *testing.T) {
const defaultErrCount = 17 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default const defaultErrCount = 21 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default
const azErrCount = 8 const azErrCount = 9
const gcpErrCount = 5 const gcpErrCount = 6
testCases := map[string]struct { testCases := map[string]struct {
cnf *Config cnf *Config
@ -240,6 +242,7 @@ func TestValidate(t *testing.T) {
az.ClientSecretValue = "test-client-secret" az.ClientSecretValue = "test-client-secret"
cnf.Provider = ProviderConfig{} cnf.Provider = ProviderConfig{}
cnf.Provider.Azure = az cnf.Provider.Azure = az
cnf.Provider.Azure.Measurements = measurements.M{15: measurements.WithAllBytes(0x00, false)}
return cnf return cnf
}(), }(),
}, },
@ -265,6 +268,7 @@ func TestValidate(t *testing.T) {
gcp.ServiceAccountKeyPath = "test-key-path" gcp.ServiceAccountKeyPath = "test-key-path"
cnf.Provider = ProviderConfig{} cnf.Provider = ProviderConfig{}
cnf.Provider.GCP = gcp cnf.Provider.GCP = gcp
cnf.Provider.GCP.Measurements = measurements.M{15: measurements.WithAllBytes(0x00, false)}
return cnf return cnf
}(), }(),
}, },
@ -364,9 +368,9 @@ func TestConfigGeneratedDocsFresh(t *testing.T) {
func TestConfig_UpdateMeasurements(t *testing.T) { func TestConfig_UpdateMeasurements(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
newMeasurements := measurements.M{ newMeasurements := measurements.M{
1: measurements.PCRWithAllBytes(0x00), 1: measurements.WithAllBytes(0x00, false),
2: measurements.PCRWithAllBytes(0x01), 2: measurements.WithAllBytes(0x01, false),
3: measurements.PCRWithAllBytes(0x02), 3: measurements.WithAllBytes(0x02, false),
} }
{ // AWS { // AWS

View file

@ -7,9 +7,11 @@ SPDX-License-Identifier: AGPL-3.0-only
package config package config
import ( import (
"bytes"
"fmt" "fmt"
"strings" "strings"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config/instancetypes" "github.com/edgelesssys/constellation/v2/internal/config/instancetypes"
"github.com/edgelesssys/constellation/v2/internal/versions" "github.com/edgelesssys/constellation/v2/internal/versions"
@ -223,3 +225,35 @@ func (c *Config) translateAzureInstanceTypeError(ut ut.Translator, fe validator.
return t return t
} }
func validateNoPlaceholder(fl validator.FieldLevel) bool {
return len(getPlaceholderEntries(fl.Field().Interface().(Measurements))) == 0
}
func registerContainsPlaceholderError(ut ut.Translator) error {
return ut.Add("no_placeholders", "{0} placeholder values (repeated 1234...)", true)
}
func translateContainsPlaceholderError(ut ut.Translator, fe validator.FieldError) string {
placeholders := getPlaceholderEntries(fe.Value().(Measurements))
msg := fmt.Sprintf("Measurements %v contain", placeholders)
if len(placeholders) == 1 {
msg = fmt.Sprintf("Measurement %v contains", placeholders)
}
t, _ := ut.T("no_placeholders", msg)
return t
}
func getPlaceholderEntries(m Measurements) []uint32 {
var placeholders []uint32
placeholder := measurements.PlaceHolderMeasurement()
for idx, measurement := range m {
if bytes.Equal(measurement.Expected[:], placeholder.Expected[:]) {
placeholders = append(placeholders, idx)
}
}
return placeholders
}

View file

@ -9,7 +9,9 @@ package watcher
import ( import (
"encoding/asn1" "encoding/asn1"
"encoding/hex" "encoding/hex"
"errors"
"fmt" "fmt"
"os"
"path/filepath" "path/filepath"
"strconv" "strconv"
"sync" "sync"
@ -19,6 +21,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp" "github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp"
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch" "github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch"
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp" "github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu" "github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
@ -42,26 +45,26 @@ func NewValidator(log *logger.Logger, csp string, fileHandler file.Handler, azur
var newValidator newValidatorFunc var newValidator newValidatorFunc
switch cloudprovider.FromString(csp) { switch cloudprovider.FromString(csp) {
case cloudprovider.AWS: case cloudprovider.AWS:
newValidator = func(m map[uint32][]byte, e []uint32, _ []byte, _ bool, log *logger.Logger) atls.Validator { newValidator = func(m measurements.M, _ []byte, _ bool, log *logger.Logger) atls.Validator {
return aws.NewValidator(m, e, log) return aws.NewValidator(m, log)
} }
case cloudprovider.Azure: case cloudprovider.Azure:
if azureCVM { if azureCVM {
newValidator = func(m map[uint32][]byte, e []uint32, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator { newValidator = func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator {
return snp.NewValidator(m, e, idkeydigest, enforceIdKeyDigest, log) return snp.NewValidator(m, idkeydigest, enforceIdKeyDigest, log)
} }
} else { } else {
newValidator = func(m map[uint32][]byte, e []uint32, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator { newValidator = func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator {
return trustedlaunch.NewValidator(m, e, log) return trustedlaunch.NewValidator(m, log)
} }
} }
case cloudprovider.GCP: case cloudprovider.GCP:
newValidator = func(m map[uint32][]byte, e []uint32, _ []byte, _ bool, log *logger.Logger) atls.Validator { newValidator = func(m measurements.M, _ []byte, _ bool, log *logger.Logger) atls.Validator {
return gcp.NewValidator(m, e, log) return gcp.NewValidator(m, log)
} }
case cloudprovider.QEMU: case cloudprovider.QEMU:
newValidator = func(m map[uint32][]byte, e []uint32, _ []byte, _ bool, log *logger.Logger) atls.Validator { newValidator = func(m measurements.M, _ []byte, _ bool, log *logger.Logger) atls.Validator {
return qemu.NewValidator(m, e, log) return qemu.NewValidator(m, log)
} }
default: default:
return nil, fmt.Errorf("unknown cloud service provider: %q", csp) return nil, fmt.Errorf("unknown cloud service provider: %q", csp)
@ -100,17 +103,24 @@ func (u *Updatable) Update() error {
u.log.Infof("Updating expected measurements") u.log.Infof("Updating expected measurements")
var measurements map[uint32][]byte var measurements measurements.M
if err := u.fileHandler.ReadJSON(filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), &measurements); err != nil { if err := u.fileHandler.ReadJSON(filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), &measurements); err != nil {
return err return err
} }
u.log.Debugf("New measurements: %v", measurements) u.log.Debugf("New measurements: %+v", measurements)
// handle legacy measurement format, where expected measurements and enforced measurements were stored in separate data structures
// TODO: remove with v2.4.0
var enforced []uint32 var enforced []uint32
if err := u.fileHandler.ReadJSON(filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename), &enforced); err != nil { if err := u.fileHandler.ReadJSON(filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename), &enforced); err == nil {
u.log.Debugf("Detected legacy format. Loading enforced PCRs...")
if err := measurements.SetEnforced(enforced); err != nil {
return err
}
u.log.Debugf("Merged measurements with enforced values: %+v", measurements)
} else if !errors.Is(err, os.ErrNotExist) {
return err return err
} }
u.log.Debugf("Enforced PCRs: %v", enforced)
var idkeydigest []byte var idkeydigest []byte
var enforceIDKeyDigest bool var enforceIDKeyDigest bool
@ -138,9 +148,9 @@ func (u *Updatable) Update() error {
u.log.Debugf("New idkeydigest: %x", idkeydigest) u.log.Debugf("New idkeydigest: %x", idkeydigest)
} }
u.Validator = u.newValidator(measurements, enforced, idkeydigest, enforceIDKeyDigest, u.log) u.Validator = u.newValidator(measurements, idkeydigest, enforceIDKeyDigest, u.log)
return nil return nil
} }
type newValidatorFunc func(measurements map[uint32][]byte, enforcedPCRs []uint32, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator type newValidatorFunc func(measurements measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator

View file

@ -20,6 +20,7 @@ import (
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
@ -117,7 +118,7 @@ func TestUpdate(t *testing.T) {
require := require.New(t) require := require.New(t)
oid := fakeOID{1, 3, 9900, 1} oid := fakeOID{1, 3, 9900, 1}
newValidator := func(m map[uint32][]byte, e []uint32, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator { newValidator := func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
return fakeValidator{fakeOID: oid} return fakeValidator{fakeOID: oid}
} }
handler := file.NewHandler(afero.NewMemMapFs()) handler := file.NewHandler(afero.NewMemMapFs())
@ -135,14 +136,7 @@ func TestUpdate(t *testing.T) {
// write measurement config // write measurement config
require.NoError(handler.WriteJSON( require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename), filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
map[uint32][]byte{ measurements.M{11: measurements.WithAllBytes(0x00, false)},
11: {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0},
},
file.OptNone,
))
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename),
[]uint32{11},
)) ))
require.NoError(handler.Write( require.NoError(handler.Write(
filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename), filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename),
@ -189,6 +183,23 @@ func TestUpdate(t *testing.T) {
defer resp.Body.Close() defer resp.Body.Close()
} }
assert.Error(err) assert.Error(err)
// update should work for legacy measurement format
// TODO: remove with v2.4.0
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename),
map[uint32][]byte{
11: bytes.Repeat([]byte{0x0}, 32),
12: bytes.Repeat([]byte{0x1}, 32),
},
file.OptOverwrite,
))
require.NoError(handler.WriteJSON(
filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename),
[]uint32{11},
))
assert.NoError(validator.Update())
} }
func TestUpdateConcurrency(t *testing.T) { func TestUpdateConcurrency(t *testing.T) {
@ -199,7 +210,7 @@ func TestUpdateConcurrency(t *testing.T) {
validator := &Updatable{ validator := &Updatable{
log: logger.NewTest(t), log: logger.NewTest(t),
fileHandler: handler, fileHandler: handler,
newValidator: func(m map[uint32][]byte, e []uint32, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator { newValidator: func(m measurements.M, idkeydigest []byte, enforceIdKeyDigest bool, _ *logger.Logger) atls.Validator {
return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}} return fakeValidator{fakeOID: fakeOID{1, 3, 9900, 1}}
}, },
} }