diff --git a/bootstrapper/cmd/bootstrapper/main.go b/bootstrapper/cmd/bootstrapper/main.go index ebae1ec8e..4fe908e15 100644 --- a/bootstrapper/cmd/bootstrapper/main.go +++ b/bootstrapper/cmd/bootstrapper/main.go @@ -18,12 +18,7 @@ import ( "github.com/edgelesssys/constellation/v2/bootstrapper/internal/kubernetes/k8sapi" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/kubernetes/kubewaiter" "github.com/edgelesssys/constellation/v2/bootstrapper/internal/logging" - "github.com/edgelesssys/constellation/v2/internal/atls" - "github.com/edgelesssys/constellation/v2/internal/attestation/aws" - "github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp" - "github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch" - "github.com/edgelesssys/constellation/v2/internal/attestation/gcp" - "github.com/edgelesssys/constellation/v2/internal/attestation/qemu" + "github.com/edgelesssys/constellation/v2/internal/attestation/choose" "github.com/edgelesssys/constellation/v2/internal/attestation/simulator" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" awscloud "github.com/edgelesssys/constellation/v2/internal/cloud/aws" @@ -41,7 +36,7 @@ import ( ) const ( - // ConstellationCSP is the environment variable stating which Cloud Service Provider Constellation is running on. + // constellationCSP is the environment variable stating which Cloud Service Provider Constellation is running on. constellationCSP = "CONSTEL_CSP" ) @@ -66,7 +61,6 @@ func main() { var clusterInitJoiner clusterInitJoiner var metadataAPI metadataAPI var cloudLogger logging.CloudLogger - var issuer atls.Issuer var openTPM vtpm.TPMOpenFunc var fs afero.Fs @@ -75,6 +69,15 @@ func main() { log.With(zap.Error(err)).Fatalf("Helm client could not be initialized") } + attestVariant, err := oid.FromString(os.Getenv(constants.AttestationVariant)) + if err != nil { + log.With(zap.Error(err)).Fatalf("Failed to parse attestation variant") + } + issuer, err := choose.Issuer(attestVariant, log) + if err != nil { + log.With(zap.Error(err)).Fatalf("Failed to select issuer") + } + switch cloudprovider.FromString(os.Getenv(constellationCSP)) { case cloudprovider.AWS: measurements, err := vtpm.GetSelectedMeasurements(vtpm.OpenVTPM, vtpm.AWSPCRSelection) @@ -82,8 +85,6 @@ func main() { log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") } - issuer = aws.NewIssuer(log) - metadata, err := awscloud.New(ctx) if err != nil { log.With(zap.Error(err)).Fatalf("Failed to set up AWS metadata API") @@ -108,8 +109,6 @@ func main() { log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") } - issuer = gcp.NewIssuer(log) - metadata, err := gcpcloud.New(ctx) if err != nil { log.With(zap.Error(err)).Fatalf("Failed to create GCP metadata client") @@ -136,13 +135,6 @@ func main() { log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") } - if _, err := snp.GetIDKeyDigest(vtpm.OpenVTPM); err == nil { - issuer = snp.NewIssuer(log) - } else { - // assume we are running in a trusted-launch VM - issuer = trustedlaunch.NewIssuer(log) - } - metadata, err := azurecloud.New(ctx) if err != nil { log.With(zap.Error(err)).Fatalf("Failed to create Azure metadata client") @@ -166,8 +158,6 @@ func main() { log.With(zap.Error(err)).Fatalf("Failed to get selected PCRs") } - issuer = qemu.NewIssuer(log) - cloudLogger = qemucloud.NewLogger() metadata := qemucloud.New() clusterInitJoiner = kubernetes.New( @@ -179,7 +169,6 @@ func main() { openTPM = vtpm.OpenVTPM fs = afero.NewOsFs() default: - issuer = atls.NewFakeIssuer(oid.Dummy{}) clusterInitJoiner = &clusterFake{} metadataAPI = &providerMetadataFake{} cloudLogger = &logging.NopLogger{} diff --git a/disk-mapper/cmd/main.go b/disk-mapper/cmd/main.go index 9b923a0c1..a610be2ff 100644 --- a/disk-mapper/cmd/main.go +++ b/disk-mapper/cmd/main.go @@ -14,17 +14,14 @@ import ( "net" "net/http" "net/url" + "os" "path/filepath" "github.com/edgelesssys/constellation/v2/disk-mapper/internal/mapper" "github.com/edgelesssys/constellation/v2/disk-mapper/internal/recoveryserver" "github.com/edgelesssys/constellation/v2/disk-mapper/internal/rejoinclient" "github.com/edgelesssys/constellation/v2/disk-mapper/internal/setup" - "github.com/edgelesssys/constellation/v2/internal/atls" - "github.com/edgelesssys/constellation/v2/internal/attestation/aws" - "github.com/edgelesssys/constellation/v2/internal/attestation/azure" - "github.com/edgelesssys/constellation/v2/internal/attestation/gcp" - "github.com/edgelesssys/constellation/v2/internal/attestation/qemu" + "github.com/edgelesssys/constellation/v2/internal/attestation/choose" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" awscloud "github.com/edgelesssys/constellation/v2/internal/cloud/aws" azurecloud "github.com/edgelesssys/constellation/v2/internal/cloud/azure" @@ -37,6 +34,7 @@ import ( "github.com/edgelesssys/constellation/v2/internal/grpc/dialer" kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup" "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/internal/oid" "github.com/edgelesssys/constellation/v2/internal/role" tpmClient "github.com/google/go-tpm-tools/client" "github.com/google/go-tpm/tpm2" @@ -61,10 +59,18 @@ func main() { log.With(zap.String("version", constants.VersionInfo()), zap.String("cloudProvider", *csp)). Infof("Starting disk-mapper") - // set up metadata API and quote issuer for aTLS connections - var err error + // set up quote issuer for aTLS connections + attestVariant, err := oid.FromString(os.Getenv(constants.AttestationVariant)) + if err != nil { + log.With(zap.Error(err)).Fatalf("Failed to parse attestation variant") + } + issuer, err := choose.Issuer(attestVariant, log) + if err != nil { + log.With(zap.Error(err)).Fatalf("Failed to select issuer") + } + + // set up metadata API var diskPath string - var issuer atls.Issuer var metadataClient setup.MetadataAPI switch cloudprovider.FromString(*csp) { case cloudprovider.AWS: @@ -80,8 +86,6 @@ func main() { log.With(zap.Error(err)).Fatalf("Failed to set up AWS metadata client") } - issuer = aws.NewIssuer(log) - case cloudprovider.Azure: diskPath, err = filepath.EvalSymlinks(azureStateDiskPath) if err != nil { @@ -93,15 +97,12 @@ func main() { log.With(zap.Error).Fatalf("Failed to set up Azure metadata client") } - issuer = azure.NewIssuer(log) - case cloudprovider.GCP: diskPath, err = filepath.EvalSymlinks(gcpStateDiskPath) if err != nil { _ = exportPCRs() log.With(zap.Error(err)).Fatalf("Unable to resolve GCP state disk path") } - issuer = gcp.NewIssuer(log) gcpMeta, err := gcpcloud.New(context.Background()) if err != nil { log.With(zap.Error).Fatalf("Failed to create GCP metadata client") @@ -115,13 +116,10 @@ func main() { if err != nil { log.With(zap.Error).Fatalf("Failed to create OpenStack metadata client") } - // TODO(malt3): implement OpenStack quote issuer - issuer = qemu.NewIssuer(log) _ = exportPCRs() case cloudprovider.QEMU: diskPath = qemuStateDiskPath - issuer = qemu.NewIssuer(log) metadataClient = qemucloud.New() _ = exportPCRs() diff --git a/internal/attestation/azure/azure.go b/internal/attestation/azure/azure.go index abac62368..53e6a9466 100644 --- a/internal/attestation/azure/azure.go +++ b/internal/attestation/azure/azure.go @@ -18,19 +18,3 @@ Constellation supports multiple attestation technologies on Azure. Basic TPM attestation. */ package azure - -import ( - "github.com/edgelesssys/constellation/v2/internal/atls" - "github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp" - "github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch" - "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" -) - -// NewIssuer returns an SNP issuer if it can successfully read the idkeydigest from the TPM. -// Otherwise returns a Trusted Launch issuer. -func NewIssuer(log vtpm.AttestationLogger) atls.Issuer { - if _, err := snp.GetIDKeyDigest(vtpm.OpenVTPM); err == nil { - return snp.NewIssuer(log) - } - return trustedlaunch.NewIssuer(log) -} diff --git a/internal/attestation/choose/choose.go b/internal/attestation/choose/choose.go new file mode 100644 index 000000000..3541b0814 --- /dev/null +++ b/internal/attestation/choose/choose.go @@ -0,0 +1,66 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package choose + +import ( + "fmt" + + "github.com/edgelesssys/constellation/v2/internal/atls" + "github.com/edgelesssys/constellation/v2/internal/attestation/aws" + "github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp" + "github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch" + "github.com/edgelesssys/constellation/v2/internal/attestation/gcp" + "github.com/edgelesssys/constellation/v2/internal/attestation/idkeydigest" + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" + "github.com/edgelesssys/constellation/v2/internal/attestation/qemu" + "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" + "github.com/edgelesssys/constellation/v2/internal/oid" +) + +// Issuer returns the issuer for the given variant. +func Issuer(variant oid.Getter, log vtpm.AttestationLogger) (atls.Issuer, error) { + switch variant { + case oid.AWSNitroTPM{}: + return aws.NewIssuer(log), nil + case oid.AzureTrustedLaunch{}: + return trustedlaunch.NewIssuer(log), nil + case oid.AzureSEVSNP{}: + return snp.NewIssuer(log), nil + case oid.GCPSEVES{}: + return gcp.NewIssuer(log), nil + case oid.QEMUVTPM{}: + return qemu.NewIssuer(log), nil + case oid.Dummy{}: + return atls.NewFakeIssuer(oid.Dummy{}), nil + default: + return nil, fmt.Errorf("unknown attestation variant: %s", variant) + } +} + +// Validator returns the validator for the given variant. +func Validator( + variant oid.Getter, measurements measurements.M, + idKeyDigest idkeydigest.IDKeyDigests, enfoceIDKeyDigest bool, + log vtpm.AttestationLogger, +) (atls.Validator, error) { + switch variant { + case oid.AWSNitroTPM{}: + return aws.NewValidator(measurements, log), nil + case oid.AzureTrustedLaunch{}: + return trustedlaunch.NewValidator(measurements, log), nil + case oid.AzureSEVSNP{}: + return snp.NewValidator(measurements, idKeyDigest, enfoceIDKeyDigest, log), nil + case oid.GCPSEVES{}: + return gcp.NewValidator(measurements, log), nil + case oid.QEMUVTPM{}: + return qemu.NewValidator(measurements, log), nil + case oid.Dummy{}: + return atls.NewFakeValidator(oid.Dummy{}), nil + default: + return nil, fmt.Errorf("unknown attestation variant: %s", variant) + } +} diff --git a/internal/attestation/choose/choose_test.go b/internal/attestation/choose/choose_test.go new file mode 100644 index 000000000..9d23e6740 --- /dev/null +++ b/internal/attestation/choose/choose_test.go @@ -0,0 +1,114 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package choose + +import ( + "encoding/asn1" + "testing" + + "github.com/edgelesssys/constellation/v2/internal/oid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIssuer(t *testing.T) { + testCases := map[string]struct { + variant oid.Getter + wantErr bool + }{ + "aws-nitro-tpm": { + variant: oid.AWSNitroTPM{}, + }, + "azure-sev-snp": { + variant: oid.AzureSEVSNP{}, + }, + "azure-trusted-launch": { + variant: oid.AzureTrustedLaunch{}, + }, + "gcp-sev-es": { + variant: oid.GCPSEVES{}, + }, + "qemu-vtpm": { + variant: oid.QEMUVTPM{}, + }, + "dummy": { + variant: oid.Dummy{}, + }, + "unknown": { + variant: unknownVariant{}, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + issuer, err := Issuer(tc.variant, nil) + + if tc.wantErr { + assert.Error(err) + return + } + require.NoError(err) + assert.True(issuer.OID().Equal(tc.variant.OID())) + }) + } +} + +func TestValidator(t *testing.T) { + testCases := map[string]struct { + variant oid.Getter + wantErr bool + }{ + "aws-nitro-tpm": { + variant: oid.AWSNitroTPM{}, + }, + "azure-sev-snp": { + variant: oid.AzureSEVSNP{}, + }, + "azure-trusted-launch": { + variant: oid.AzureTrustedLaunch{}, + }, + "gcp-sev-es": { + variant: oid.GCPSEVES{}, + }, + "qemu-vtpm": { + variant: oid.QEMUVTPM{}, + }, + "dummy": { + variant: oid.Dummy{}, + }, + "unknown": { + variant: unknownVariant{}, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + validator, err := Validator(tc.variant, nil, nil, false, nil) + + if tc.wantErr { + assert.Error(err) + return + } + require.NoError(err) + assert.True(validator.OID().Equal(tc.variant.OID())) + }) + } +} + +type unknownVariant struct{} + +func (unknownVariant) OID() asn1.ObjectIdentifier { + return asn1.ObjectIdentifier{1, 3, 9900, 9999, 9999} +} diff --git a/internal/constants/constants.go b/internal/constants/constants.go index a24e8403b..154aaa1ce 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -31,6 +31,8 @@ const ( ConstellationSaltKey = "salt" // ConstellationVerifyServiceUserData is the user data that the verification service includes in the attestation. ConstellationVerifyServiceUserData = "VerifyService" + // AttestationVariant is the name of the environment variable that contains the attestation variant. + AttestationVariant = "CONSTEL_ATTESTATION_VARIANT" // // Ports. diff --git a/verify/cmd/main.go b/verify/cmd/main.go index 7e1965163..265e81898 100644 --- a/verify/cmd/main.go +++ b/verify/cmd/main.go @@ -12,7 +12,7 @@ import ( "strconv" "github.com/edgelesssys/constellation/v2/internal/attestation/aws" - "github.com/edgelesssys/constellation/v2/internal/attestation/azure" + "github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp" "github.com/edgelesssys/constellation/v2/internal/attestation/gcp" "github.com/edgelesssys/constellation/v2/internal/attestation/qemu" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" @@ -39,7 +39,7 @@ func main() { case cloudprovider.GCP: issuer = gcp.NewIssuer(log) case cloudprovider.Azure: - issuer = azure.NewIssuer(log) + issuer = snp.NewIssuer(log) // TODO: dynamic selection case cloudprovider.QEMU: issuer = qemu.NewIssuer(log) default: