cli: new flag to set the attestation type for config generate (#1769)

* add attestation flag to specify type in config
This commit is contained in:
Adrian Stobbe 2023-05-17 16:53:56 +02:00 committed by GitHub
parent e7b7a544f0
commit f99e06b63b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 336 additions and 42 deletions

View File

@ -62,7 +62,7 @@ func TestCreator(t *testing.T) {
provider: cloudprovider.Azure, provider: cloudprovider.Azure,
config: func() *config.Config { config: func() *config.Config {
cfg := config.Default() cfg := config.Default()
cfg.RemoveProviderExcept(cloudprovider.Azure) cfg.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
return cfg return cfg
}(), }(),
policyPatcher: &stubPolicyPatcher{}, policyPatcher: &stubPolicyPatcher{},
@ -84,7 +84,7 @@ func TestCreator(t *testing.T) {
provider: cloudprovider.Azure, provider: cloudprovider.Azure,
config: func() *config.Config { config: func() *config.Config {
cfg := config.Default() cfg := config.Default()
cfg.RemoveProviderExcept(cloudprovider.Azure) cfg.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
return cfg return cfg
}(), }(),
policyPatcher: &stubPolicyPatcher{someErr}, policyPatcher: &stubPolicyPatcher{someErr},
@ -95,7 +95,7 @@ func TestCreator(t *testing.T) {
provider: cloudprovider.Azure, provider: cloudprovider.Azure,
config: func() *config.Config { config: func() *config.Config {
cfg := config.Default() cfg := config.Default()
cfg.RemoveProviderExcept(cloudprovider.Azure) cfg.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
return cfg return cfg
}(), }(),
policyPatcher: &stubPolicyPatcher{}, policyPatcher: &stubPolicyPatcher{},
@ -106,7 +106,7 @@ func TestCreator(t *testing.T) {
provider: cloudprovider.Azure, provider: cloudprovider.Azure,
config: func() *config.Config { config: func() *config.Config {
cfg := config.Default() cfg := config.Default()
cfg.RemoveProviderExcept(cloudprovider.Azure) cfg.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
return cfg return cfg
}(), }(),
policyPatcher: &stubPolicyPatcher{}, policyPatcher: &stubPolicyPatcher{},

View File

@ -15,6 +15,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/variant"
"github.com/edgelesssys/constellation/v2/internal/versions" "github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/siderolabs/talos/pkg/machinery/config/encoder" "github.com/siderolabs/talos/pkg/machinery/config/encoder"
"github.com/spf13/afero" "github.com/spf13/afero"
@ -36,6 +37,7 @@ func newConfigGenerateCmd() *cobra.Command {
} }
cmd.Flags().StringP("file", "f", constants.ConfigFilename, "path to output file, or '-' for stdout") cmd.Flags().StringP("file", "f", constants.ConfigFilename, "path to output file, or '-' for stdout")
cmd.Flags().StringP("kubernetes", "k", semver.MajorMinor(config.Default().KubernetesVersion), "Kubernetes version to use in format MAJOR.MINOR") cmd.Flags().StringP("kubernetes", "k", semver.MajorMinor(config.Default().KubernetesVersion), "Kubernetes version to use in format MAJOR.MINOR")
cmd.Flags().StringP("attestation", "a", "", fmt.Sprintf("attestation variant to use %s. If not specified, the default for the cloud provider is used", printFormattedSlice(variant.GetAvailableAttestationTypes())))
return cmd return cmd
} }
@ -43,6 +45,7 @@ func newConfigGenerateCmd() *cobra.Command {
type generateFlags struct { type generateFlags struct {
file string file string
k8sVersion string k8sVersion string
attestationVariant variant.Variant
} }
type configGenerateCmd struct { type configGenerateCmd struct {
@ -69,7 +72,10 @@ func (cg *configGenerateCmd) configGenerate(cmd *cobra.Command, fileHandler file
cg.log.Debugf("Parsed flags as %v", flags) cg.log.Debugf("Parsed flags as %v", flags)
cg.log.Debugf("Using cloud provider %s", provider.String()) cg.log.Debugf("Using cloud provider %s", provider.String())
conf := createConfig(provider) conf, err := createConfigWithAttestationType(provider, flags.attestationVariant)
if err != nil {
return fmt.Errorf("creating config: %w", err)
}
conf.KubernetesVersion = flags.k8sVersion conf.KubernetesVersion = flags.k8sVersion
if flags.file == "-" { if flags.file == "-" {
content, err := encoder.NewEncoder(conf).Encode() content, err := encoder.NewEncoder(conf).Encode()
@ -96,7 +102,7 @@ func (cg *configGenerateCmd) configGenerate(cmd *cobra.Command, fileHandler file
} }
// createConfig creates a config file for the given provider. // createConfig creates a config file for the given provider.
func createConfig(provider cloudprovider.Provider) *config.Config { func createConfigWithAttestationType(provider cloudprovider.Provider, attestationVariant variant.Variant) (*config.Config, error) {
conf := config.Default() conf := config.Default()
conf.RemoveProviderExcept(provider) conf.RemoveProviderExcept(provider)
@ -105,7 +111,25 @@ func createConfig(provider cloudprovider.Provider) *config.Config {
conf.StateDiskSizeGB = 10 conf.StateDiskSizeGB = 10
} }
return conf if provider == cloudprovider.Unknown {
return conf, nil
}
if attestationVariant.Equal(variant.Dummy{}) {
attestationVariant = variant.GetDefaultAttestation(provider)
if attestationVariant.Equal(variant.Dummy{}) {
return nil, fmt.Errorf("provider %s does not have a default attestation variant", provider)
}
} else if !variant.ValidProvider(provider, attestationVariant) {
return nil, fmt.Errorf("provider %s does not support attestation type %s", provider, attestationVariant)
}
conf.SetAttestation(attestationVariant)
return conf, nil
}
// createConfig creates a config file for the given provider.
func createConfig(provider cloudprovider.Provider) *config.Config {
res, _ := createConfigWithAttestationType(provider, variant.Dummy{})
return res
} }
// supportedVersions prints the supported version without v prefix and without patch version. // supportedVersions prints the supported version without v prefix and without patch version.
@ -135,13 +159,29 @@ func parseGenerateFlags(cmd *cobra.Command) (generateFlags, error) {
return generateFlags{}, fmt.Errorf("resolving kuberentes version from flag: %w", err) return generateFlags{}, fmt.Errorf("resolving kuberentes version from flag: %w", err)
} }
attestationString, err := cmd.Flags().GetString("attestation")
if err != nil {
return generateFlags{}, fmt.Errorf("parsing attestation flag: %w", err)
}
var attestationType variant.Variant
// if no attestation type is specified, use the default for the cloud provider
if attestationString == "" {
attestationType = variant.Dummy{}
} else {
attestationType, err = variant.FromString(attestationString)
if err != nil {
return generateFlags{}, fmt.Errorf("invalid attestation variant: %s", attestationString)
}
}
return generateFlags{ return generateFlags{
file: file, file: file,
k8sVersion: resolvedVersion, k8sVersion: resolvedVersion,
attestationVariant: attestationType,
}, nil }, nil
} }
// createCompletion handles the completion of the create command. It is frequently called // generateCompletion handles the completion of the create command. It is frequently called
// while the user types arguments of the command to suggest completion. // while the user types arguments of the command to suggest completion.
func generateCompletion(_ *cobra.Command, args []string, _ string) ([]string, cobra.ShellCompDirective) { func generateCompletion(_ *cobra.Command, args []string, _ string) ([]string, cobra.ShellCompDirective) {
switch len(args) { switch len(args) {
@ -167,3 +207,15 @@ func resolveK8sVersion(k8sVersion string) (string, error) {
return extendedVersion, nil return extendedVersion, nil
} }
func printFormattedSlice[T any](input []T) string {
return fmt.Sprintf("{%s}", strings.Join(toString(input), "|"))
}
func toString[T any](t []T) []string {
var res []string
for _, v := range t {
res = append(res, fmt.Sprintf("%v", v))
}
return res
}

View File

@ -8,6 +8,7 @@ package cmd
import ( import (
"bytes" "bytes"
"fmt"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
@ -15,8 +16,10 @@ import (
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/variant"
"github.com/edgelesssys/constellation/v2/internal/versions" "github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
@ -87,7 +90,7 @@ func TestConfigGenerateDefaultGCPSpecific(t *testing.T) {
cmd := newConfigGenerateCmd() cmd := newConfigGenerateCmd()
wantConf := config.Default() wantConf := config.Default()
wantConf.RemoveProviderExcept(cloudprovider.GCP) wantConf.RemoveProviderAndAttestationExcept(cloudprovider.GCP)
cg := &configGenerateCmd{log: logger.NewTest(t)} cg := &configGenerateCmd{log: logger.NewTest(t)}
require.NoError(cg.configGenerate(cmd, fileHandler, cloudprovider.GCP)) require.NoError(cg.configGenerate(cmd, fileHandler, cloudprovider.GCP))
@ -139,3 +142,133 @@ func TestConfigGenerateStdOut(t *testing.T) {
assert.Equal(*config.Default(), readConfig) assert.Equal(*config.Default(), readConfig)
} }
func TestNoValidProviderAttestationCombination(t *testing.T) {
assert := assert.New(t)
tests := []struct {
provider cloudprovider.Provider
attestation variant.Variant
}{
{cloudprovider.Azure, variant.AWSNitroTPM{}},
{cloudprovider.AWS, variant.AzureTrustedLaunch{}},
{cloudprovider.GCP, variant.AWSNitroTPM{}},
{cloudprovider.QEMU, variant.GCPSEVES{}},
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
_, err := createConfigWithAttestationType(test.provider, test.attestation)
assert.Error(err)
})
}
}
func TestValidProviderAttestationCombination(t *testing.T) {
defaultAttestation := config.Default().Attestation
tests := []struct {
provider cloudprovider.Provider
attestation variant.Variant
expected config.AttestationConfig
}{
{
cloudprovider.Azure,
variant.AzureTrustedLaunch{},
config.AttestationConfig{AzureTrustedLaunch: defaultAttestation.AzureTrustedLaunch},
},
{
cloudprovider.Azure,
variant.AzureSEVSNP{},
config.AttestationConfig{AzureSEVSNP: defaultAttestation.AzureSEVSNP},
},
{
cloudprovider.AWS,
variant.AWSNitroTPM{},
config.AttestationConfig{AWSNitroTPM: defaultAttestation.AWSNitroTPM},
},
{
cloudprovider.GCP,
variant.GCPSEVES{},
config.AttestationConfig{GCPSEVES: defaultAttestation.GCPSEVES},
},
{
cloudprovider.QEMU,
variant.QEMUVTPM{},
config.AttestationConfig{QEMUVTPM: defaultAttestation.QEMUVTPM},
},
{
cloudprovider.OpenStack,
variant.QEMUVTPM{},
config.AttestationConfig{QEMUVTPM: defaultAttestation.QEMUVTPM},
},
}
for _, test := range tests {
t.Run(fmt.Sprintf("Provider:%s,Attestation:%s", test.provider, test.attestation), func(t *testing.T) {
sut, err := createConfigWithAttestationType(test.provider, test.attestation)
assert := assert.New(t)
assert.NoError(err)
assert.Equal(test.expected, sut.Attestation)
})
}
}
func TestAttestationArgument(t *testing.T) {
defaultAttestation := config.Default().Attestation
tests := []struct {
name string
provider cloudprovider.Provider
expectErr bool
expectedCfg config.AttestationConfig
setFlag func(*cobra.Command) error
}{
{
name: "InvalidAttestationArgument",
provider: cloudprovider.Unknown,
expectErr: true,
setFlag: func(cmd *cobra.Command) error {
return cmd.Flags().Set("attestation", "unknown")
},
},
{
name: "ValidAttestationArgument",
provider: cloudprovider.Azure,
expectErr: false,
setFlag: func(cmd *cobra.Command) error {
return cmd.Flags().Set("attestation", "azure-trustedlaunch")
},
expectedCfg: config.AttestationConfig{AzureTrustedLaunch: defaultAttestation.AzureTrustedLaunch},
},
{
name: "WithoutAttestationArgument",
provider: cloudprovider.Azure,
expectErr: false,
setFlag: func(cmd *cobra.Command) error {
return nil
},
expectedCfg: config.AttestationConfig{AzureSEVSNP: defaultAttestation.AzureSEVSNP},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
require := assert.New(t)
assert := assert.New(t)
cmd := newConfigGenerateCmd()
require.NoError(test.setFlag(cmd))
fileHandler := file.NewHandler(afero.NewMemMapFs())
cg := &configGenerateCmd{log: logger.NewTest(t)}
err := cg.configGenerate(cmd, fileHandler, test.provider)
if test.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
var readConfig config.Config
require.NoError(fileHandler.ReadYAML(constants.ConfigFilename, &readConfig))
assert.Equal(test.expectedCfg, readConfig.Attestation)
}
})
}
}

View File

@ -429,7 +429,7 @@ func TestAttestation(t *testing.T) {
cfg := config.Default() cfg := config.Default()
cfg.Image = "image" cfg.Image = "image"
cfg.RemoveProviderExcept(cloudprovider.QEMU) cfg.RemoveProviderAndAttestationExcept(cloudprovider.QEMU)
cfg.Attestation.QEMUVTPM.Measurements[0] = measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength) cfg.Attestation.QEMUVTPM.Measurements[0] = measurements.WithAllBytes(0x00, measurements.Enforce, measurements.PCRMeasurementLength)
cfg.Attestation.QEMUVTPM.Measurements[1] = measurements.WithAllBytes(0x11, measurements.Enforce, measurements.PCRMeasurementLength) cfg.Attestation.QEMUVTPM.Measurements[1] = measurements.WithAllBytes(0x11, measurements.Enforce, measurements.PCRMeasurementLength)
cfg.Attestation.QEMUVTPM.Measurements[2] = measurements.WithAllBytes(0x22, measurements.Enforce, measurements.PCRMeasurementLength) cfg.Attestation.QEMUVTPM.Measurements[2] = measurements.WithAllBytes(0x22, measurements.Enforce, measurements.PCRMeasurementLength)
@ -554,7 +554,7 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
conf.Attestation.QEMUVTPM.Measurements[12] = measurements.WithAllBytes(0xcc, measurements.Enforce, measurements.PCRMeasurementLength) conf.Attestation.QEMUVTPM.Measurements[12] = measurements.WithAllBytes(0xcc, measurements.Enforce, measurements.PCRMeasurementLength)
} }
conf.RemoveProviderExcept(csp) conf.RemoveProviderAndAttestationExcept(csp)
return conf return conf
} }

View File

@ -216,7 +216,7 @@ func (m *miniUpCmd) prepareConfig(cmd *cobra.Command, fileHandler file.Handler,
config := config.Default() config := config.Default()
config.Name = constants.MiniConstellationUID config.Name = constants.MiniConstellationUID
config.RemoveProviderExcept(cloudprovider.QEMU) config.RemoveProviderAndAttestationExcept(cloudprovider.QEMU)
config.StateDiskSizeGB = 8 config.StateDiskSizeGB = 8
// only release images (e.g. v2.7.0) use the production NVRAM // only release images (e.g. v2.7.0) use the production NVRAM

View File

@ -9,6 +9,8 @@ Prerequisites:
* [Bazelisk installed as `bazel` in your path](https://github.com/bazelbuild/bazelisk/releases). * [Bazelisk installed as `bazel` in your path](https://github.com/bazelbuild/bazelisk/releases).
* [Docker](https://docs.docker.com/engine/install/). Can be installed with these commands on Ubuntu 22.04: `sudo apt update && sudo apt install docker.io`. As the build spawns docker containers your user account either needs to be in the `docker` group (Add with `sudo usermod -a -G docker $USER`) or you have to run builds with `sudo`. When using `sudo` remember that your root user might (depending on your distro and local config) not have the go binary in it's PATH. The current PATH can be forwarded to the root env with `sudo env PATH=$PATH <cmd>`. * [Docker](https://docs.docker.com/engine/install/). Can be installed with these commands on Ubuntu 22.04: `sudo apt update && sudo apt install docker.io`. As the build spawns docker containers your user account either needs to be in the `docker` group (Add with `sudo usermod -a -G docker $USER`) or you have to run builds with `sudo`. When using `sudo` remember that your root user might (depending on your distro and local config) not have the go binary in it's PATH. The current PATH can be forwarded to the root env with `sudo env PATH=$PATH <cmd>`.
---
### On Linux
* Packages on Ubuntu: * Packages on Ubuntu:
```sh ```sh
@ -21,6 +23,20 @@ Prerequisites:
sudo dnf install @development-tools pkg-config cmake openssl-devel cryptsetup-libs cryptsetup-devel sudo dnf install @development-tools pkg-config cmake openssl-devel cryptsetup-libs cryptsetup-devel
``` ```
### On Mac
```
brew install bash
```
to fix unsupported shell options used in some build script.
To troubleshoot potential problems with bazel on ARM architecture when running it for the first time, it might help to purge and retry:
```
bazel clean --expunge
```
---
Developer workspace: Developer workspace:
```sh ```sh
@ -83,6 +99,19 @@ Running unit tests with Bazel:
bazel test //... bazel test //...
``` ```
# Opening a PR
Before opening a PR, please run the tests and
```
bazel run //:generate && bazel run //:tidy
bazel run //:check
```
The linter check doesn't work on Mac at the moment, but you can run the linter directly:
```
golangci-lint run
```
Furthermore, the PR titles are used for the changelog, so please stick to our [conventions](https://github.com/edgelesssys/constellation/blob/main/dev-docs/conventions.md#pr-conventions).
# Deploy # Deploy
> :warning: Debug images are not safe to use in production environments. :warning: > :warning: Debug images are not safe to use in production environments. :warning:

View File

@ -75,6 +75,7 @@ constellation config generate {aws|azure|gcp|openstack|qemu} [flags]
### Options ### Options
``` ```
-a, --attestation string attestation variant to use {aws-nitro-tpm|azure-sev-snp|azure-trustedlaunch|gcp-sev-es|qemu-vtpm}. If not specified, the default for the cloud provider is used
-f, --file string path to output file, or '-' for stdout (default "constellation-conf.yaml") -f, --file string path to output file, or '-' for stdout (default "constellation-conf.yaml")
-h, --help help for generate -h, --help help for generate
-k, --kubernetes string Kubernetes version to use in format MAJOR.MINOR (default "v1.26") -k, --kubernetes string Kubernetes version to use in format MAJOR.MINOR (default "v1.26")

View File

@ -447,6 +447,12 @@ func (c *Config) UpdateMeasurements(newMeasurements measurements.M) {
} }
} }
// RemoveProviderAndAttestationExcept calls RemoveProviderExcept and sets the default attestations for the provider (only used for convenience in tests).
func (c *Config) RemoveProviderAndAttestationExcept(provider cloudprovider.Provider) {
c.RemoveProviderExcept(provider)
c.SetAttestation(variant.GetDefaultAttestation(provider))
}
// RemoveProviderExcept removes all provider specific configurations, i.e., // RemoveProviderExcept removes all provider specific configurations, i.e.,
// sets them to nil, except the one specified. // sets them to nil, except the one specified.
// If an unknown provider is passed, the same configuration is returned. // If an unknown provider is passed, the same configuration is returned.
@ -454,29 +460,37 @@ func (c *Config) RemoveProviderExcept(provider cloudprovider.Provider) {
currentProviderConfigs := c.Provider currentProviderConfigs := c.Provider
c.Provider = ProviderConfig{} c.Provider = ProviderConfig{}
// TODO(AB#2976): Replace attestation replacement
// with custom function for attestation selection
currentAttetationConfigs := c.Attestation
c.Attestation = AttestationConfig{}
switch provider { switch provider {
case cloudprovider.AWS: case cloudprovider.AWS:
c.Provider.AWS = currentProviderConfigs.AWS c.Provider.AWS = currentProviderConfigs.AWS
c.Attestation.AWSNitroTPM = currentAttetationConfigs.AWSNitroTPM
case cloudprovider.Azure: case cloudprovider.Azure:
c.Provider.Azure = currentProviderConfigs.Azure c.Provider.Azure = currentProviderConfigs.Azure
c.Attestation.AzureSEVSNP = currentAttetationConfigs.AzureSEVSNP
case cloudprovider.GCP: case cloudprovider.GCP:
c.Provider.GCP = currentProviderConfigs.GCP c.Provider.GCP = currentProviderConfigs.GCP
c.Attestation.GCPSEVES = currentAttetationConfigs.GCPSEVES
case cloudprovider.OpenStack: case cloudprovider.OpenStack:
c.Provider.OpenStack = currentProviderConfigs.OpenStack c.Provider.OpenStack = currentProviderConfigs.OpenStack
c.Attestation.QEMUVTPM = currentAttetationConfigs.QEMUVTPM
case cloudprovider.QEMU: case cloudprovider.QEMU:
c.Provider.QEMU = currentProviderConfigs.QEMU c.Provider.QEMU = currentProviderConfigs.QEMU
c.Attestation.QEMUVTPM = currentAttetationConfigs.QEMUVTPM
default: default:
c.Provider = currentProviderConfigs c.Provider = currentProviderConfigs
c.Attestation = currentAttetationConfigs }
}
// SetAttestation sets the attestation config for the given attestation variant and removes all other attestation configs.
func (c *Config) SetAttestation(attestation variant.Variant) {
currentAttetationConfigs := c.Attestation
c.Attestation = AttestationConfig{}
switch attestation.(type) {
case variant.AzureSEVSNP:
c.Attestation = AttestationConfig{AzureSEVSNP: currentAttetationConfigs.AzureSEVSNP}
case variant.AWSNitroTPM:
c.Attestation = AttestationConfig{AWSNitroTPM: currentAttetationConfigs.AWSNitroTPM}
case variant.AzureTrustedLaunch:
c.Attestation = AttestationConfig{AzureTrustedLaunch: currentAttetationConfigs.AzureTrustedLaunch}
case variant.GCPSEVES:
c.Attestation = AttestationConfig{GCPSEVES: currentAttetationConfigs.GCPSEVES}
case variant.QEMUVTPM:
c.Attestation = AttestationConfig{QEMUVTPM: currentAttetationConfigs.QEMUVTPM}
} }
} }

View File

@ -121,7 +121,7 @@ func TestNewWithDefaultOptions(t *testing.T) {
"set env works": { "set env works": {
confToWrite: func() *Config { // valid config with all, but clientSecretValue confToWrite: func() *Config { // valid config with all, but clientSecretValue
c := Default() c := Default()
c.RemoveProviderExcept(cloudprovider.Azure) c.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
c.Image = "v" + constants.VersionInfo() c.Image = "v" + constants.VersionInfo()
c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5" c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5"
c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa" c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa"
@ -142,7 +142,7 @@ func TestNewWithDefaultOptions(t *testing.T) {
"set env overwrites": { "set env overwrites": {
confToWrite: func() *Config { confToWrite: func() *Config {
c := Default() c := Default()
c.RemoveProviderExcept(cloudprovider.Azure) c.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
c.Image = "v" + constants.VersionInfo() c.Image = "v" + constants.VersionInfo()
c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5" c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5"
c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa" c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa"
@ -231,7 +231,7 @@ func TestValidate(t *testing.T) {
"default Azure config is not valid": { "default Azure config is not valid": {
cnf: func() *Config { cnf: func() *Config {
cnf := Default() cnf := Default()
cnf.RemoveProviderExcept(cloudprovider.Azure) cnf.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
return cnf return cnf
}(), }(),
wantErr: true, wantErr: true,
@ -240,7 +240,7 @@ func TestValidate(t *testing.T) {
"Azure config with all required fields is valid": { "Azure config with all required fields is valid": {
cnf: func() *Config { cnf: func() *Config {
cnf := Default() cnf := Default()
cnf.RemoveProviderExcept(cloudprovider.Azure) cnf.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
cnf.Image = "v" + constants.VersionInfo() cnf.Image = "v" + constants.VersionInfo()
az := cnf.Provider.Azure az := cnf.Provider.Azure
az.SubscriptionID = "01234567-0123-0123-0123-0123456789ab" az.SubscriptionID = "01234567-0123-0123-0123-0123456789ab"
@ -261,7 +261,7 @@ func TestValidate(t *testing.T) {
"default GCP config is not valid": { "default GCP config is not valid": {
cnf: func() *Config { cnf: func() *Config {
cnf := Default() cnf := Default()
cnf.RemoveProviderExcept(cloudprovider.GCP) cnf.RemoveProviderAndAttestationExcept(cloudprovider.GCP)
return cnf return cnf
}(), }(),
wantErr: true, wantErr: true,
@ -270,7 +270,7 @@ func TestValidate(t *testing.T) {
"GCP config with all required fields is valid": { "GCP config with all required fields is valid": {
cnf: func() *Config { cnf: func() *Config {
cnf := Default() cnf := Default()
cnf.RemoveProviderExcept(cloudprovider.GCP) cnf.RemoveProviderAndAttestationExcept(cloudprovider.GCP)
cnf.Image = "v" + constants.VersionInfo() cnf.Image = "v" + constants.VersionInfo()
gcp := cnf.Provider.GCP gcp := cnf.Provider.GCP
gcp.Region = "test-region" gcp.Region = "test-region"
@ -379,7 +379,7 @@ func TestConfigRemoveProviderExcept(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
conf := Default() conf := Default()
conf.RemoveProviderExcept(tc.removeExcept) conf.RemoveProviderAndAttestationExcept(tc.removeExcept)
assert.Equal(tc.wantAWS, conf.Provider.AWS) assert.Equal(tc.wantAWS, conf.Provider.AWS)
assert.Equal(tc.wantAzure, conf.Provider.Azure) assert.Equal(tc.wantAzure, conf.Provider.Azure)
@ -411,7 +411,7 @@ func TestConfig_UpdateMeasurements(t *testing.T) {
{ // AWS { // AWS
conf := Default() conf := Default()
conf.RemoveProviderExcept(cloudprovider.AWS) conf.RemoveProviderAndAttestationExcept(cloudprovider.AWS)
for k := range conf.Attestation.AWSNitroTPM.Measurements { for k := range conf.Attestation.AWSNitroTPM.Measurements {
delete(conf.Attestation.AWSNitroTPM.Measurements, k) delete(conf.Attestation.AWSNitroTPM.Measurements, k)
} }
@ -420,7 +420,7 @@ func TestConfig_UpdateMeasurements(t *testing.T) {
} }
{ // Azure { // Azure
conf := Default() conf := Default()
conf.RemoveProviderExcept(cloudprovider.Azure) conf.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
for k := range conf.Attestation.AzureSEVSNP.Measurements { for k := range conf.Attestation.AzureSEVSNP.Measurements {
delete(conf.Attestation.AzureSEVSNP.Measurements, k) delete(conf.Attestation.AzureSEVSNP.Measurements, k)
} }
@ -429,7 +429,7 @@ func TestConfig_UpdateMeasurements(t *testing.T) {
} }
{ // GCP { // GCP
conf := Default() conf := Default()
conf.RemoveProviderExcept(cloudprovider.GCP) conf.RemoveProviderAndAttestationExcept(cloudprovider.GCP)
for k := range conf.Attestation.GCPSEVES.Measurements { for k := range conf.Attestation.GCPSEVES.Measurements {
delete(conf.Attestation.GCPSEVES.Measurements, k) delete(conf.Attestation.GCPSEVES.Measurements, k)
} }
@ -438,7 +438,7 @@ func TestConfig_UpdateMeasurements(t *testing.T) {
} }
{ // QEMU { // QEMU
conf := Default() conf := Default()
conf.RemoveProviderExcept(cloudprovider.QEMU) conf.RemoveProviderAndAttestationExcept(cloudprovider.QEMU)
for k := range conf.Attestation.QEMUVTPM.Measurements { for k := range conf.Attestation.QEMUVTPM.Measurements {
delete(conf.Attestation.QEMUVTPM.Measurements, k) delete(conf.Attestation.QEMUVTPM.Measurements, k)
} }

View File

@ -5,4 +5,5 @@ go_library(
srcs = ["variant.go"], srcs = ["variant.go"],
importpath = "github.com/edgelesssys/constellation/v2/internal/variant", importpath = "github.com/edgelesssys/constellation/v2/internal/variant",
visibility = ["//:__subpackages__"], visibility = ["//:__subpackages__"],
deps = ["//internal/cloud/cloudprovider"],
) )

View File

@ -34,6 +34,9 @@ package variant
import ( import (
"encoding/asn1" "encoding/asn1"
"fmt" "fmt"
"sort"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
) )
const ( const (
@ -46,6 +49,42 @@ const (
qemuTDX = "qemu-tdx" qemuTDX = "qemu-tdx"
) )
var providerAttestationMapping = map[cloudprovider.Provider][]Variant{
cloudprovider.AWS: {AWSNitroTPM{}},
cloudprovider.Azure: {AzureSEVSNP{}, AzureTrustedLaunch{}},
cloudprovider.GCP: {GCPSEVES{}},
cloudprovider.QEMU: {QEMUVTPM{}},
cloudprovider.OpenStack: {QEMUVTPM{}},
}
// GetDefaultAttestation returns the default attestation type for the given provider. If not found, it returns the default variant.
func GetDefaultAttestation(provider cloudprovider.Provider) Variant {
res, ok := providerAttestationMapping[provider]
if ok {
return res[0]
}
return Dummy{}
}
// GetAvailableAttestationTypes returns the available attestation types.
func GetAvailableAttestationTypes() []Variant {
var res []Variant
// assumes that cloudprovider.Provider is a uint32 to sort the providers and get a consistent order
var keys []cloudprovider.Provider
for k := range providerAttestationMapping {
keys = append(keys, k)
}
sort.Slice(keys, func(i, j int) bool {
return uint(keys[i]) < uint(keys[j])
})
for _, k := range keys {
res = append(res, providerAttestationMapping[k]...)
}
return removeDuplicate(res)
}
// Getter returns an ASN.1 Object Identifier. // Getter returns an ASN.1 Object Identifier.
type Getter interface { type Getter interface {
OID() asn1.ObjectIdentifier OID() asn1.ObjectIdentifier
@ -79,7 +118,20 @@ func FromString(oid string) (Variant, error) {
return nil, fmt.Errorf("unknown OID: %q", oid) return nil, fmt.Errorf("unknown OID: %q", oid)
} }
// Dummy OID for testing. // ValidProvider returns true if the attestation type is valid for the given provider.
func ValidProvider(provider cloudprovider.Provider, variant Variant) bool {
validTypes, ok := providerAttestationMapping[provider]
if ok {
for _, aType := range validTypes {
if variant.Equal(aType) {
return true
}
}
}
return false
}
// Dummy OID for testfing.
type Dummy struct{} type Dummy struct{}
// OID returns the struct's object identifier. // OID returns the struct's object identifier.
@ -92,7 +144,7 @@ func (Dummy) String() string {
return dummy return dummy
} }
// Equal returns true if the other variant is also a Dummy. // Equal returns true if the other variant is also a Default.
func (Dummy) Equal(other Getter) bool { func (Dummy) Equal(other Getter) bool {
return other.OID().Equal(Dummy{}.OID()) return other.OID().Equal(Dummy{}.OID())
} }
@ -206,3 +258,15 @@ func (QEMUTDX) String() string {
func (QEMUTDX) Equal(other Getter) bool { func (QEMUTDX) Equal(other Getter) bool {
return other.OID().Equal(QEMUTDX{}.OID()) return other.OID().Equal(QEMUTDX{}.OID())
} }
func removeDuplicate(sliceList []Variant) []Variant {
allKeys := make(map[Variant]bool)
list := []Variant{}
for _, item := range sliceList {
if _, value := allKeys[item]; !value {
allKeys[item] = true
list = append(list, item)
}
}
return list
}