diff --git a/cli/internal/cloudcmd/create_test.go b/cli/internal/cloudcmd/create_test.go index ffaba934b..eded26230 100644 --- a/cli/internal/cloudcmd/create_test.go +++ b/cli/internal/cloudcmd/create_test.go @@ -62,7 +62,7 @@ func TestCreator(t *testing.T) { provider: cloudprovider.Azure, config: func() *config.Config { cfg := config.Default() - cfg.RemoveProviderExcept(cloudprovider.Azure) + cfg.RemoveProviderAndAttestationExcept(cloudprovider.Azure) return cfg }(), policyPatcher: &stubPolicyPatcher{}, @@ -84,7 +84,7 @@ func TestCreator(t *testing.T) { provider: cloudprovider.Azure, config: func() *config.Config { cfg := config.Default() - cfg.RemoveProviderExcept(cloudprovider.Azure) + cfg.RemoveProviderAndAttestationExcept(cloudprovider.Azure) return cfg }(), policyPatcher: &stubPolicyPatcher{someErr}, @@ -95,7 +95,7 @@ func TestCreator(t *testing.T) { provider: cloudprovider.Azure, config: func() *config.Config { cfg := config.Default() - cfg.RemoveProviderExcept(cloudprovider.Azure) + cfg.RemoveProviderAndAttestationExcept(cloudprovider.Azure) return cfg }(), policyPatcher: &stubPolicyPatcher{}, @@ -106,7 +106,7 @@ func TestCreator(t *testing.T) { provider: cloudprovider.Azure, config: func() *config.Config { cfg := config.Default() - cfg.RemoveProviderExcept(cloudprovider.Azure) + cfg.RemoveProviderAndAttestationExcept(cloudprovider.Azure) return cfg }(), policyPatcher: &stubPolicyPatcher{}, diff --git a/cli/internal/cmd/configgenerate.go b/cli/internal/cmd/configgenerate.go index 27fc38e91..6fad71ba8 100644 --- a/cli/internal/cmd/configgenerate.go +++ b/cli/internal/cmd/configgenerate.go @@ -15,6 +15,7 @@ import ( "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/file" + "github.com/edgelesssys/constellation/v2/internal/variant" "github.com/edgelesssys/constellation/v2/internal/versions" "github.com/siderolabs/talos/pkg/machinery/config/encoder" "github.com/spf13/afero" @@ -36,13 +37,15 @@ func newConfigGenerateCmd() *cobra.Command { } cmd.Flags().StringP("file", "f", constants.ConfigFilename, "path to output file, or '-' for stdout") cmd.Flags().StringP("kubernetes", "k", semver.MajorMinor(config.Default().KubernetesVersion), "Kubernetes version to use in format MAJOR.MINOR") + cmd.Flags().StringP("attestation", "a", "", fmt.Sprintf("attestation variant to use %s. If not specified, the default for the cloud provider is used", printFormattedSlice(variant.GetAvailableAttestationTypes()))) return cmd } type generateFlags struct { - file string - k8sVersion string + file string + k8sVersion string + attestationVariant variant.Variant } type configGenerateCmd struct { @@ -69,7 +72,10 @@ func (cg *configGenerateCmd) configGenerate(cmd *cobra.Command, fileHandler file cg.log.Debugf("Parsed flags as %v", flags) cg.log.Debugf("Using cloud provider %s", provider.String()) - conf := createConfig(provider) + conf, err := createConfigWithAttestationType(provider, flags.attestationVariant) + if err != nil { + return fmt.Errorf("creating config: %w", err) + } conf.KubernetesVersion = flags.k8sVersion if flags.file == "-" { content, err := encoder.NewEncoder(conf).Encode() @@ -96,7 +102,7 @@ func (cg *configGenerateCmd) configGenerate(cmd *cobra.Command, fileHandler file } // createConfig creates a config file for the given provider. -func createConfig(provider cloudprovider.Provider) *config.Config { +func createConfigWithAttestationType(provider cloudprovider.Provider, attestationVariant variant.Variant) (*config.Config, error) { conf := config.Default() conf.RemoveProviderExcept(provider) @@ -105,7 +111,25 @@ func createConfig(provider cloudprovider.Provider) *config.Config { conf.StateDiskSizeGB = 10 } - return conf + if provider == cloudprovider.Unknown { + return conf, nil + } + if attestationVariant.Equal(variant.Dummy{}) { + attestationVariant = variant.GetDefaultAttestation(provider) + if attestationVariant.Equal(variant.Dummy{}) { + return nil, fmt.Errorf("provider %s does not have a default attestation variant", provider) + } + } else if !variant.ValidProvider(provider, attestationVariant) { + return nil, fmt.Errorf("provider %s does not support attestation type %s", provider, attestationVariant) + } + conf.SetAttestation(attestationVariant) + return conf, nil +} + +// createConfig creates a config file for the given provider. +func createConfig(provider cloudprovider.Provider) *config.Config { + res, _ := createConfigWithAttestationType(provider, variant.Dummy{}) + return res } // supportedVersions prints the supported version without v prefix and without patch version. @@ -135,13 +159,29 @@ func parseGenerateFlags(cmd *cobra.Command) (generateFlags, error) { return generateFlags{}, fmt.Errorf("resolving kuberentes version from flag: %w", err) } + attestationString, err := cmd.Flags().GetString("attestation") + if err != nil { + return generateFlags{}, fmt.Errorf("parsing attestation flag: %w", err) + } + + var attestationType variant.Variant + // if no attestation type is specified, use the default for the cloud provider + if attestationString == "" { + attestationType = variant.Dummy{} + } else { + attestationType, err = variant.FromString(attestationString) + if err != nil { + return generateFlags{}, fmt.Errorf("invalid attestation variant: %s", attestationString) + } + } return generateFlags{ - file: file, - k8sVersion: resolvedVersion, + file: file, + k8sVersion: resolvedVersion, + attestationVariant: attestationType, }, nil } -// createCompletion handles the completion of the create command. It is frequently called +// generateCompletion handles the completion of the create command. It is frequently called // while the user types arguments of the command to suggest completion. func generateCompletion(_ *cobra.Command, args []string, _ string) ([]string, cobra.ShellCompDirective) { switch len(args) { @@ -167,3 +207,15 @@ func resolveK8sVersion(k8sVersion string) (string, error) { return extendedVersion, nil } + +func printFormattedSlice[T any](input []T) string { + return fmt.Sprintf("{%s}", strings.Join(toString(input), "|")) +} + +func toString[T any](t []T) []string { + var res []string + for _, v := range t { + res = append(res, fmt.Sprintf("%v", v)) + } + return res +} diff --git a/cli/internal/cmd/configgenerate_test.go b/cli/internal/cmd/configgenerate_test.go index 71c593a38..343fa12da 100644 --- a/cli/internal/cmd/configgenerate_test.go +++ b/cli/internal/cmd/configgenerate_test.go @@ -8,6 +8,7 @@ package cmd import ( "bytes" + "fmt" "testing" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" @@ -15,8 +16,10 @@ import ( "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/edgelesssys/constellation/v2/internal/variant" "github.com/edgelesssys/constellation/v2/internal/versions" "github.com/spf13/afero" + "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/mod/semver" @@ -87,7 +90,7 @@ func TestConfigGenerateDefaultGCPSpecific(t *testing.T) { cmd := newConfigGenerateCmd() wantConf := config.Default() - wantConf.RemoveProviderExcept(cloudprovider.GCP) + wantConf.RemoveProviderAndAttestationExcept(cloudprovider.GCP) cg := &configGenerateCmd{log: logger.NewTest(t)} require.NoError(cg.configGenerate(cmd, fileHandler, cloudprovider.GCP)) @@ -139,3 +142,133 @@ func TestConfigGenerateStdOut(t *testing.T) { assert.Equal(*config.Default(), readConfig) } + +func TestNoValidProviderAttestationCombination(t *testing.T) { + assert := assert.New(t) + tests := []struct { + provider cloudprovider.Provider + attestation variant.Variant + }{ + {cloudprovider.Azure, variant.AWSNitroTPM{}}, + {cloudprovider.AWS, variant.AzureTrustedLaunch{}}, + {cloudprovider.GCP, variant.AWSNitroTPM{}}, + {cloudprovider.QEMU, variant.GCPSEVES{}}, + } + for _, test := range tests { + t.Run("", func(t *testing.T) { + _, err := createConfigWithAttestationType(test.provider, test.attestation) + assert.Error(err) + }) + } +} + +func TestValidProviderAttestationCombination(t *testing.T) { + defaultAttestation := config.Default().Attestation + tests := []struct { + provider cloudprovider.Provider + attestation variant.Variant + expected config.AttestationConfig + }{ + { + cloudprovider.Azure, + variant.AzureTrustedLaunch{}, + config.AttestationConfig{AzureTrustedLaunch: defaultAttestation.AzureTrustedLaunch}, + }, + { + cloudprovider.Azure, + variant.AzureSEVSNP{}, + config.AttestationConfig{AzureSEVSNP: defaultAttestation.AzureSEVSNP}, + }, + + { + cloudprovider.AWS, + variant.AWSNitroTPM{}, + config.AttestationConfig{AWSNitroTPM: defaultAttestation.AWSNitroTPM}, + }, + { + cloudprovider.GCP, + variant.GCPSEVES{}, + config.AttestationConfig{GCPSEVES: defaultAttestation.GCPSEVES}, + }, + + { + cloudprovider.QEMU, + variant.QEMUVTPM{}, + config.AttestationConfig{QEMUVTPM: defaultAttestation.QEMUVTPM}, + }, + { + cloudprovider.OpenStack, + variant.QEMUVTPM{}, + config.AttestationConfig{QEMUVTPM: defaultAttestation.QEMUVTPM}, + }, + } + for _, test := range tests { + t.Run(fmt.Sprintf("Provider:%s,Attestation:%s", test.provider, test.attestation), func(t *testing.T) { + sut, err := createConfigWithAttestationType(test.provider, test.attestation) + assert := assert.New(t) + assert.NoError(err) + assert.Equal(test.expected, sut.Attestation) + }) + } +} + +func TestAttestationArgument(t *testing.T) { + defaultAttestation := config.Default().Attestation + tests := []struct { + name string + provider cloudprovider.Provider + expectErr bool + expectedCfg config.AttestationConfig + setFlag func(*cobra.Command) error + }{ + { + name: "InvalidAttestationArgument", + provider: cloudprovider.Unknown, + expectErr: true, + setFlag: func(cmd *cobra.Command) error { + return cmd.Flags().Set("attestation", "unknown") + }, + }, + { + name: "ValidAttestationArgument", + provider: cloudprovider.Azure, + expectErr: false, + setFlag: func(cmd *cobra.Command) error { + return cmd.Flags().Set("attestation", "azure-trustedlaunch") + }, + expectedCfg: config.AttestationConfig{AzureTrustedLaunch: defaultAttestation.AzureTrustedLaunch}, + }, + { + name: "WithoutAttestationArgument", + provider: cloudprovider.Azure, + expectErr: false, + setFlag: func(cmd *cobra.Command) error { + return nil + }, + expectedCfg: config.AttestationConfig{AzureSEVSNP: defaultAttestation.AzureSEVSNP}, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := assert.New(t) + assert := assert.New(t) + + cmd := newConfigGenerateCmd() + require.NoError(test.setFlag(cmd)) + + fileHandler := file.NewHandler(afero.NewMemMapFs()) + + cg := &configGenerateCmd{log: logger.NewTest(t)} + err := cg.configGenerate(cmd, fileHandler, test.provider) + if test.expectErr { + assert.Error(err) + } else { + assert.NoError(err) + var readConfig config.Config + require.NoError(fileHandler.ReadYAML(constants.ConfigFilename, &readConfig)) + + assert.Equal(test.expectedCfg, readConfig.Attestation) + } + }) + } +} diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index 7c341d82e..35b276839 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -429,7 +429,7 @@ func TestAttestation(t *testing.T) { cfg := config.Default() cfg.Image = "image" - cfg.RemoveProviderExcept(cloudprovider.QEMU) + cfg.RemoveProviderAndAttestationExcept(cloudprovider.QEMU) cfg.Attestation.QEMUVTPM.Measurements[0] = measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength) cfg.Attestation.QEMUVTPM.Measurements[1] = measurements.WithAllBytes(0x11, measurements.Enforce, measurements.PCRMeasurementLength) cfg.Attestation.QEMUVTPM.Measurements[2] = measurements.WithAllBytes(0x22, measurements.Enforce, measurements.PCRMeasurementLength) @@ -554,7 +554,7 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs conf.Attestation.QEMUVTPM.Measurements[12] = measurements.WithAllBytes(0xcc, measurements.Enforce, measurements.PCRMeasurementLength) } - conf.RemoveProviderExcept(csp) + conf.RemoveProviderAndAttestationExcept(csp) return conf } diff --git a/cli/internal/cmd/miniup.go b/cli/internal/cmd/miniup.go index 33ea61c38..ae287d0bd 100644 --- a/cli/internal/cmd/miniup.go +++ b/cli/internal/cmd/miniup.go @@ -216,7 +216,7 @@ func (m *miniUpCmd) prepareConfig(cmd *cobra.Command, fileHandler file.Handler, config := config.Default() config.Name = constants.MiniConstellationUID - config.RemoveProviderExcept(cloudprovider.QEMU) + config.RemoveProviderAndAttestationExcept(cloudprovider.QEMU) config.StateDiskSizeGB = 8 // only release images (e.g. v2.7.0) use the production NVRAM diff --git a/dev-docs/workflows/build-test-run.md b/dev-docs/workflows/build-test-run.md index 5613c7b25..5d7355d50 100644 --- a/dev-docs/workflows/build-test-run.md +++ b/dev-docs/workflows/build-test-run.md @@ -9,6 +9,8 @@ Prerequisites: * [Bazelisk installed as `bazel` in your path](https://github.com/bazelbuild/bazelisk/releases). * [Docker](https://docs.docker.com/engine/install/). Can be installed with these commands on Ubuntu 22.04: `sudo apt update && sudo apt install docker.io`. As the build spawns docker containers your user account either needs to be in the `docker` group (Add with `sudo usermod -a -G docker $USER`) or you have to run builds with `sudo`. When using `sudo` remember that your root user might (depending on your distro and local config) not have the go binary in it's PATH. The current PATH can be forwarded to the root env with `sudo env PATH=$PATH `. +--- +### On Linux * Packages on Ubuntu: ```sh @@ -21,6 +23,20 @@ Prerequisites: sudo dnf install @development-tools pkg-config cmake openssl-devel cryptsetup-libs cryptsetup-devel ``` +### On Mac + +``` +brew install bash +``` +to fix unsupported shell options used in some build script. + +To troubleshoot potential problems with bazel on ARM architecture when running it for the first time, it might help to purge and retry: +``` +bazel clean --expunge +``` + +--- + Developer workspace: ```sh @@ -83,6 +99,19 @@ Running unit tests with Bazel: bazel test //... ``` +# Opening a PR +Before opening a PR, please run the tests and +``` +bazel run //:generate && bazel run //:tidy +bazel run //:check +``` + +The linter check doesn't work on Mac at the moment, but you can run the linter directly: +``` +golangci-lint run +``` +Furthermore, the PR titles are used for the changelog, so please stick to our [conventions](https://github.com/edgelesssys/constellation/blob/main/dev-docs/conventions.md#pr-conventions). + # Deploy > :warning: Debug images are not safe to use in production environments. :warning: diff --git a/docs/docs/reference/cli.md b/docs/docs/reference/cli.md index 6fd2d21fd..6bd6f11f5 100644 --- a/docs/docs/reference/cli.md +++ b/docs/docs/reference/cli.md @@ -75,9 +75,10 @@ constellation config generate {aws|azure|gcp|openstack|qemu} [flags] ### Options ``` - -f, --file string path to output file, or '-' for stdout (default "constellation-conf.yaml") - -h, --help help for generate - -k, --kubernetes string Kubernetes version to use in format MAJOR.MINOR (default "v1.26") + -a, --attestation string attestation variant to use {aws-nitro-tpm|azure-sev-snp|azure-trustedlaunch|gcp-sev-es|qemu-vtpm}. If not specified, the default for the cloud provider is used + -f, --file string path to output file, or '-' for stdout (default "constellation-conf.yaml") + -h, --help help for generate + -k, --kubernetes string Kubernetes version to use in format MAJOR.MINOR (default "v1.26") ``` ### Options inherited from parent commands diff --git a/internal/config/config.go b/internal/config/config.go index f79506633..03db68449 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -447,6 +447,12 @@ func (c *Config) UpdateMeasurements(newMeasurements measurements.M) { } } +// RemoveProviderAndAttestationExcept calls RemoveProviderExcept and sets the default attestations for the provider (only used for convenience in tests). +func (c *Config) RemoveProviderAndAttestationExcept(provider cloudprovider.Provider) { + c.RemoveProviderExcept(provider) + c.SetAttestation(variant.GetDefaultAttestation(provider)) +} + // RemoveProviderExcept removes all provider specific configurations, i.e., // sets them to nil, except the one specified. // If an unknown provider is passed, the same configuration is returned. @@ -454,29 +460,37 @@ func (c *Config) RemoveProviderExcept(provider cloudprovider.Provider) { currentProviderConfigs := c.Provider c.Provider = ProviderConfig{} - // TODO(AB#2976): Replace attestation replacement - // with custom function for attestation selection - currentAttetationConfigs := c.Attestation - c.Attestation = AttestationConfig{} switch provider { case cloudprovider.AWS: c.Provider.AWS = currentProviderConfigs.AWS - c.Attestation.AWSNitroTPM = currentAttetationConfigs.AWSNitroTPM case cloudprovider.Azure: c.Provider.Azure = currentProviderConfigs.Azure - c.Attestation.AzureSEVSNP = currentAttetationConfigs.AzureSEVSNP case cloudprovider.GCP: c.Provider.GCP = currentProviderConfigs.GCP - c.Attestation.GCPSEVES = currentAttetationConfigs.GCPSEVES case cloudprovider.OpenStack: c.Provider.OpenStack = currentProviderConfigs.OpenStack - c.Attestation.QEMUVTPM = currentAttetationConfigs.QEMUVTPM case cloudprovider.QEMU: c.Provider.QEMU = currentProviderConfigs.QEMU - c.Attestation.QEMUVTPM = currentAttetationConfigs.QEMUVTPM default: c.Provider = currentProviderConfigs - c.Attestation = currentAttetationConfigs + } +} + +// SetAttestation sets the attestation config for the given attestation variant and removes all other attestation configs. +func (c *Config) SetAttestation(attestation variant.Variant) { + currentAttetationConfigs := c.Attestation + c.Attestation = AttestationConfig{} + switch attestation.(type) { + case variant.AzureSEVSNP: + c.Attestation = AttestationConfig{AzureSEVSNP: currentAttetationConfigs.AzureSEVSNP} + case variant.AWSNitroTPM: + c.Attestation = AttestationConfig{AWSNitroTPM: currentAttetationConfigs.AWSNitroTPM} + case variant.AzureTrustedLaunch: + c.Attestation = AttestationConfig{AzureTrustedLaunch: currentAttetationConfigs.AzureTrustedLaunch} + case variant.GCPSEVES: + c.Attestation = AttestationConfig{GCPSEVES: currentAttetationConfigs.GCPSEVES} + case variant.QEMUVTPM: + c.Attestation = AttestationConfig{QEMUVTPM: currentAttetationConfigs.QEMUVTPM} } } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index c4cdd7d1c..c2e0bc27f 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -121,7 +121,7 @@ func TestNewWithDefaultOptions(t *testing.T) { "set env works": { confToWrite: func() *Config { // valid config with all, but clientSecretValue c := Default() - c.RemoveProviderExcept(cloudprovider.Azure) + c.RemoveProviderAndAttestationExcept(cloudprovider.Azure) c.Image = "v" + constants.VersionInfo() c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5" c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa" @@ -142,7 +142,7 @@ func TestNewWithDefaultOptions(t *testing.T) { "set env overwrites": { confToWrite: func() *Config { c := Default() - c.RemoveProviderExcept(cloudprovider.Azure) + c.RemoveProviderAndAttestationExcept(cloudprovider.Azure) c.Image = "v" + constants.VersionInfo() c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5" c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa" @@ -231,7 +231,7 @@ func TestValidate(t *testing.T) { "default Azure config is not valid": { cnf: func() *Config { cnf := Default() - cnf.RemoveProviderExcept(cloudprovider.Azure) + cnf.RemoveProviderAndAttestationExcept(cloudprovider.Azure) return cnf }(), wantErr: true, @@ -240,7 +240,7 @@ func TestValidate(t *testing.T) { "Azure config with all required fields is valid": { cnf: func() *Config { cnf := Default() - cnf.RemoveProviderExcept(cloudprovider.Azure) + cnf.RemoveProviderAndAttestationExcept(cloudprovider.Azure) cnf.Image = "v" + constants.VersionInfo() az := cnf.Provider.Azure az.SubscriptionID = "01234567-0123-0123-0123-0123456789ab" @@ -261,7 +261,7 @@ func TestValidate(t *testing.T) { "default GCP config is not valid": { cnf: func() *Config { cnf := Default() - cnf.RemoveProviderExcept(cloudprovider.GCP) + cnf.RemoveProviderAndAttestationExcept(cloudprovider.GCP) return cnf }(), wantErr: true, @@ -270,7 +270,7 @@ func TestValidate(t *testing.T) { "GCP config with all required fields is valid": { cnf: func() *Config { cnf := Default() - cnf.RemoveProviderExcept(cloudprovider.GCP) + cnf.RemoveProviderAndAttestationExcept(cloudprovider.GCP) cnf.Image = "v" + constants.VersionInfo() gcp := cnf.Provider.GCP gcp.Region = "test-region" @@ -379,7 +379,7 @@ func TestConfigRemoveProviderExcept(t *testing.T) { assert := assert.New(t) conf := Default() - conf.RemoveProviderExcept(tc.removeExcept) + conf.RemoveProviderAndAttestationExcept(tc.removeExcept) assert.Equal(tc.wantAWS, conf.Provider.AWS) assert.Equal(tc.wantAzure, conf.Provider.Azure) @@ -411,7 +411,7 @@ func TestConfig_UpdateMeasurements(t *testing.T) { { // AWS conf := Default() - conf.RemoveProviderExcept(cloudprovider.AWS) + conf.RemoveProviderAndAttestationExcept(cloudprovider.AWS) for k := range conf.Attestation.AWSNitroTPM.Measurements { delete(conf.Attestation.AWSNitroTPM.Measurements, k) } @@ -420,7 +420,7 @@ func TestConfig_UpdateMeasurements(t *testing.T) { } { // Azure conf := Default() - conf.RemoveProviderExcept(cloudprovider.Azure) + conf.RemoveProviderAndAttestationExcept(cloudprovider.Azure) for k := range conf.Attestation.AzureSEVSNP.Measurements { delete(conf.Attestation.AzureSEVSNP.Measurements, k) } @@ -429,7 +429,7 @@ func TestConfig_UpdateMeasurements(t *testing.T) { } { // GCP conf := Default() - conf.RemoveProviderExcept(cloudprovider.GCP) + conf.RemoveProviderAndAttestationExcept(cloudprovider.GCP) for k := range conf.Attestation.GCPSEVES.Measurements { delete(conf.Attestation.GCPSEVES.Measurements, k) } @@ -438,7 +438,7 @@ func TestConfig_UpdateMeasurements(t *testing.T) { } { // QEMU conf := Default() - conf.RemoveProviderExcept(cloudprovider.QEMU) + conf.RemoveProviderAndAttestationExcept(cloudprovider.QEMU) for k := range conf.Attestation.QEMUVTPM.Measurements { delete(conf.Attestation.QEMUVTPM.Measurements, k) } diff --git a/internal/variant/BUILD.bazel b/internal/variant/BUILD.bazel index 916444f3c..1251c7192 100644 --- a/internal/variant/BUILD.bazel +++ b/internal/variant/BUILD.bazel @@ -5,4 +5,5 @@ go_library( srcs = ["variant.go"], importpath = "github.com/edgelesssys/constellation/v2/internal/variant", visibility = ["//:__subpackages__"], + deps = ["//internal/cloud/cloudprovider"], ) diff --git a/internal/variant/variant.go b/internal/variant/variant.go index 3db996b0e..3c70e764b 100644 --- a/internal/variant/variant.go +++ b/internal/variant/variant.go @@ -34,6 +34,9 @@ package variant import ( "encoding/asn1" "fmt" + "sort" + + "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" ) const ( @@ -46,6 +49,42 @@ const ( qemuTDX = "qemu-tdx" ) +var providerAttestationMapping = map[cloudprovider.Provider][]Variant{ + cloudprovider.AWS: {AWSNitroTPM{}}, + cloudprovider.Azure: {AzureSEVSNP{}, AzureTrustedLaunch{}}, + cloudprovider.GCP: {GCPSEVES{}}, + cloudprovider.QEMU: {QEMUVTPM{}}, + cloudprovider.OpenStack: {QEMUVTPM{}}, +} + +// GetDefaultAttestation returns the default attestation type for the given provider. If not found, it returns the default variant. +func GetDefaultAttestation(provider cloudprovider.Provider) Variant { + res, ok := providerAttestationMapping[provider] + if ok { + return res[0] + } + return Dummy{} +} + +// GetAvailableAttestationTypes returns the available attestation types. +func GetAvailableAttestationTypes() []Variant { + var res []Variant + + // assumes that cloudprovider.Provider is a uint32 to sort the providers and get a consistent order + var keys []cloudprovider.Provider + for k := range providerAttestationMapping { + keys = append(keys, k) + } + sort.Slice(keys, func(i, j int) bool { + return uint(keys[i]) < uint(keys[j]) + }) + + for _, k := range keys { + res = append(res, providerAttestationMapping[k]...) + } + return removeDuplicate(res) +} + // Getter returns an ASN.1 Object Identifier. type Getter interface { OID() asn1.ObjectIdentifier @@ -79,7 +118,20 @@ func FromString(oid string) (Variant, error) { return nil, fmt.Errorf("unknown OID: %q", oid) } -// Dummy OID for testing. +// ValidProvider returns true if the attestation type is valid for the given provider. +func ValidProvider(provider cloudprovider.Provider, variant Variant) bool { + validTypes, ok := providerAttestationMapping[provider] + if ok { + for _, aType := range validTypes { + if variant.Equal(aType) { + return true + } + } + } + return false +} + +// Dummy OID for testfing. type Dummy struct{} // OID returns the struct's object identifier. @@ -92,7 +144,7 @@ func (Dummy) String() string { return dummy } -// Equal returns true if the other variant is also a Dummy. +// Equal returns true if the other variant is also a Default. func (Dummy) Equal(other Getter) bool { return other.OID().Equal(Dummy{}.OID()) } @@ -206,3 +258,15 @@ func (QEMUTDX) String() string { func (QEMUTDX) Equal(other Getter) bool { return other.OID().Equal(QEMUTDX{}.OID()) } + +func removeDuplicate(sliceList []Variant) []Variant { + allKeys := make(map[Variant]bool) + list := []Variant{} + for _, item := range sliceList { + if _, value := allKeys[item]; !value { + allKeys[item] = true + list = append(list, item) + } + } + return list +}