mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-07-29 01:58:34 -04:00
cli: new flag to set the attestation type for config generate
(#1769)
* add attestation flag to specify type in config
This commit is contained in:
parent
e7b7a544f0
commit
f99e06b63b
11 changed files with 336 additions and 42 deletions
|
@ -447,6 +447,12 @@ func (c *Config) UpdateMeasurements(newMeasurements measurements.M) {
|
|||
}
|
||||
}
|
||||
|
||||
// RemoveProviderAndAttestationExcept calls RemoveProviderExcept and sets the default attestations for the provider (only used for convenience in tests).
|
||||
func (c *Config) RemoveProviderAndAttestationExcept(provider cloudprovider.Provider) {
|
||||
c.RemoveProviderExcept(provider)
|
||||
c.SetAttestation(variant.GetDefaultAttestation(provider))
|
||||
}
|
||||
|
||||
// RemoveProviderExcept removes all provider specific configurations, i.e.,
|
||||
// sets them to nil, except the one specified.
|
||||
// If an unknown provider is passed, the same configuration is returned.
|
||||
|
@ -454,29 +460,37 @@ func (c *Config) RemoveProviderExcept(provider cloudprovider.Provider) {
|
|||
currentProviderConfigs := c.Provider
|
||||
c.Provider = ProviderConfig{}
|
||||
|
||||
// TODO(AB#2976): Replace attestation replacement
|
||||
// with custom function for attestation selection
|
||||
currentAttetationConfigs := c.Attestation
|
||||
c.Attestation = AttestationConfig{}
|
||||
switch provider {
|
||||
case cloudprovider.AWS:
|
||||
c.Provider.AWS = currentProviderConfigs.AWS
|
||||
c.Attestation.AWSNitroTPM = currentAttetationConfigs.AWSNitroTPM
|
||||
case cloudprovider.Azure:
|
||||
c.Provider.Azure = currentProviderConfigs.Azure
|
||||
c.Attestation.AzureSEVSNP = currentAttetationConfigs.AzureSEVSNP
|
||||
case cloudprovider.GCP:
|
||||
c.Provider.GCP = currentProviderConfigs.GCP
|
||||
c.Attestation.GCPSEVES = currentAttetationConfigs.GCPSEVES
|
||||
case cloudprovider.OpenStack:
|
||||
c.Provider.OpenStack = currentProviderConfigs.OpenStack
|
||||
c.Attestation.QEMUVTPM = currentAttetationConfigs.QEMUVTPM
|
||||
case cloudprovider.QEMU:
|
||||
c.Provider.QEMU = currentProviderConfigs.QEMU
|
||||
c.Attestation.QEMUVTPM = currentAttetationConfigs.QEMUVTPM
|
||||
default:
|
||||
c.Provider = currentProviderConfigs
|
||||
c.Attestation = currentAttetationConfigs
|
||||
}
|
||||
}
|
||||
|
||||
// SetAttestation sets the attestation config for the given attestation variant and removes all other attestation configs.
|
||||
func (c *Config) SetAttestation(attestation variant.Variant) {
|
||||
currentAttetationConfigs := c.Attestation
|
||||
c.Attestation = AttestationConfig{}
|
||||
switch attestation.(type) {
|
||||
case variant.AzureSEVSNP:
|
||||
c.Attestation = AttestationConfig{AzureSEVSNP: currentAttetationConfigs.AzureSEVSNP}
|
||||
case variant.AWSNitroTPM:
|
||||
c.Attestation = AttestationConfig{AWSNitroTPM: currentAttetationConfigs.AWSNitroTPM}
|
||||
case variant.AzureTrustedLaunch:
|
||||
c.Attestation = AttestationConfig{AzureTrustedLaunch: currentAttetationConfigs.AzureTrustedLaunch}
|
||||
case variant.GCPSEVES:
|
||||
c.Attestation = AttestationConfig{GCPSEVES: currentAttetationConfigs.GCPSEVES}
|
||||
case variant.QEMUVTPM:
|
||||
c.Attestation = AttestationConfig{QEMUVTPM: currentAttetationConfigs.QEMUVTPM}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -121,7 +121,7 @@ func TestNewWithDefaultOptions(t *testing.T) {
|
|||
"set env works": {
|
||||
confToWrite: func() *Config { // valid config with all, but clientSecretValue
|
||||
c := Default()
|
||||
c.RemoveProviderExcept(cloudprovider.Azure)
|
||||
c.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
|
||||
c.Image = "v" + constants.VersionInfo()
|
||||
c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5"
|
||||
c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa"
|
||||
|
@ -142,7 +142,7 @@ func TestNewWithDefaultOptions(t *testing.T) {
|
|||
"set env overwrites": {
|
||||
confToWrite: func() *Config {
|
||||
c := Default()
|
||||
c.RemoveProviderExcept(cloudprovider.Azure)
|
||||
c.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
|
||||
c.Image = "v" + constants.VersionInfo()
|
||||
c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5"
|
||||
c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa"
|
||||
|
@ -231,7 +231,7 @@ func TestValidate(t *testing.T) {
|
|||
"default Azure config is not valid": {
|
||||
cnf: func() *Config {
|
||||
cnf := Default()
|
||||
cnf.RemoveProviderExcept(cloudprovider.Azure)
|
||||
cnf.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
|
||||
return cnf
|
||||
}(),
|
||||
wantErr: true,
|
||||
|
@ -240,7 +240,7 @@ func TestValidate(t *testing.T) {
|
|||
"Azure config with all required fields is valid": {
|
||||
cnf: func() *Config {
|
||||
cnf := Default()
|
||||
cnf.RemoveProviderExcept(cloudprovider.Azure)
|
||||
cnf.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
|
||||
cnf.Image = "v" + constants.VersionInfo()
|
||||
az := cnf.Provider.Azure
|
||||
az.SubscriptionID = "01234567-0123-0123-0123-0123456789ab"
|
||||
|
@ -261,7 +261,7 @@ func TestValidate(t *testing.T) {
|
|||
"default GCP config is not valid": {
|
||||
cnf: func() *Config {
|
||||
cnf := Default()
|
||||
cnf.RemoveProviderExcept(cloudprovider.GCP)
|
||||
cnf.RemoveProviderAndAttestationExcept(cloudprovider.GCP)
|
||||
return cnf
|
||||
}(),
|
||||
wantErr: true,
|
||||
|
@ -270,7 +270,7 @@ func TestValidate(t *testing.T) {
|
|||
"GCP config with all required fields is valid": {
|
||||
cnf: func() *Config {
|
||||
cnf := Default()
|
||||
cnf.RemoveProviderExcept(cloudprovider.GCP)
|
||||
cnf.RemoveProviderAndAttestationExcept(cloudprovider.GCP)
|
||||
cnf.Image = "v" + constants.VersionInfo()
|
||||
gcp := cnf.Provider.GCP
|
||||
gcp.Region = "test-region"
|
||||
|
@ -379,7 +379,7 @@ func TestConfigRemoveProviderExcept(t *testing.T) {
|
|||
assert := assert.New(t)
|
||||
|
||||
conf := Default()
|
||||
conf.RemoveProviderExcept(tc.removeExcept)
|
||||
conf.RemoveProviderAndAttestationExcept(tc.removeExcept)
|
||||
|
||||
assert.Equal(tc.wantAWS, conf.Provider.AWS)
|
||||
assert.Equal(tc.wantAzure, conf.Provider.Azure)
|
||||
|
@ -411,7 +411,7 @@ func TestConfig_UpdateMeasurements(t *testing.T) {
|
|||
|
||||
{ // AWS
|
||||
conf := Default()
|
||||
conf.RemoveProviderExcept(cloudprovider.AWS)
|
||||
conf.RemoveProviderAndAttestationExcept(cloudprovider.AWS)
|
||||
for k := range conf.Attestation.AWSNitroTPM.Measurements {
|
||||
delete(conf.Attestation.AWSNitroTPM.Measurements, k)
|
||||
}
|
||||
|
@ -420,7 +420,7 @@ func TestConfig_UpdateMeasurements(t *testing.T) {
|
|||
}
|
||||
{ // Azure
|
||||
conf := Default()
|
||||
conf.RemoveProviderExcept(cloudprovider.Azure)
|
||||
conf.RemoveProviderAndAttestationExcept(cloudprovider.Azure)
|
||||
for k := range conf.Attestation.AzureSEVSNP.Measurements {
|
||||
delete(conf.Attestation.AzureSEVSNP.Measurements, k)
|
||||
}
|
||||
|
@ -429,7 +429,7 @@ func TestConfig_UpdateMeasurements(t *testing.T) {
|
|||
}
|
||||
{ // GCP
|
||||
conf := Default()
|
||||
conf.RemoveProviderExcept(cloudprovider.GCP)
|
||||
conf.RemoveProviderAndAttestationExcept(cloudprovider.GCP)
|
||||
for k := range conf.Attestation.GCPSEVES.Measurements {
|
||||
delete(conf.Attestation.GCPSEVES.Measurements, k)
|
||||
}
|
||||
|
@ -438,7 +438,7 @@ func TestConfig_UpdateMeasurements(t *testing.T) {
|
|||
}
|
||||
{ // QEMU
|
||||
conf := Default()
|
||||
conf.RemoveProviderExcept(cloudprovider.QEMU)
|
||||
conf.RemoveProviderAndAttestationExcept(cloudprovider.QEMU)
|
||||
for k := range conf.Attestation.QEMUVTPM.Measurements {
|
||||
delete(conf.Attestation.QEMUVTPM.Measurements, k)
|
||||
}
|
||||
|
|
|
@ -5,4 +5,5 @@ go_library(
|
|||
srcs = ["variant.go"],
|
||||
importpath = "github.com/edgelesssys/constellation/v2/internal/variant",
|
||||
visibility = ["//:__subpackages__"],
|
||||
deps = ["//internal/cloud/cloudprovider"],
|
||||
)
|
||||
|
|
|
@ -34,6 +34,9 @@ package variant
|
|||
import (
|
||||
"encoding/asn1"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -46,6 +49,42 @@ const (
|
|||
qemuTDX = "qemu-tdx"
|
||||
)
|
||||
|
||||
var providerAttestationMapping = map[cloudprovider.Provider][]Variant{
|
||||
cloudprovider.AWS: {AWSNitroTPM{}},
|
||||
cloudprovider.Azure: {AzureSEVSNP{}, AzureTrustedLaunch{}},
|
||||
cloudprovider.GCP: {GCPSEVES{}},
|
||||
cloudprovider.QEMU: {QEMUVTPM{}},
|
||||
cloudprovider.OpenStack: {QEMUVTPM{}},
|
||||
}
|
||||
|
||||
// GetDefaultAttestation returns the default attestation type for the given provider. If not found, it returns the default variant.
|
||||
func GetDefaultAttestation(provider cloudprovider.Provider) Variant {
|
||||
res, ok := providerAttestationMapping[provider]
|
||||
if ok {
|
||||
return res[0]
|
||||
}
|
||||
return Dummy{}
|
||||
}
|
||||
|
||||
// GetAvailableAttestationTypes returns the available attestation types.
|
||||
func GetAvailableAttestationTypes() []Variant {
|
||||
var res []Variant
|
||||
|
||||
// assumes that cloudprovider.Provider is a uint32 to sort the providers and get a consistent order
|
||||
var keys []cloudprovider.Provider
|
||||
for k := range providerAttestationMapping {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
return uint(keys[i]) < uint(keys[j])
|
||||
})
|
||||
|
||||
for _, k := range keys {
|
||||
res = append(res, providerAttestationMapping[k]...)
|
||||
}
|
||||
return removeDuplicate(res)
|
||||
}
|
||||
|
||||
// Getter returns an ASN.1 Object Identifier.
|
||||
type Getter interface {
|
||||
OID() asn1.ObjectIdentifier
|
||||
|
@ -79,7 +118,20 @@ func FromString(oid string) (Variant, error) {
|
|||
return nil, fmt.Errorf("unknown OID: %q", oid)
|
||||
}
|
||||
|
||||
// Dummy OID for testing.
|
||||
// ValidProvider returns true if the attestation type is valid for the given provider.
|
||||
func ValidProvider(provider cloudprovider.Provider, variant Variant) bool {
|
||||
validTypes, ok := providerAttestationMapping[provider]
|
||||
if ok {
|
||||
for _, aType := range validTypes {
|
||||
if variant.Equal(aType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Dummy OID for testfing.
|
||||
type Dummy struct{}
|
||||
|
||||
// OID returns the struct's object identifier.
|
||||
|
@ -92,7 +144,7 @@ func (Dummy) String() string {
|
|||
return dummy
|
||||
}
|
||||
|
||||
// Equal returns true if the other variant is also a Dummy.
|
||||
// Equal returns true if the other variant is also a Default.
|
||||
func (Dummy) Equal(other Getter) bool {
|
||||
return other.OID().Equal(Dummy{}.OID())
|
||||
}
|
||||
|
@ -206,3 +258,15 @@ func (QEMUTDX) String() string {
|
|||
func (QEMUTDX) Equal(other Getter) bool {
|
||||
return other.OID().Equal(QEMUTDX{}.OID())
|
||||
}
|
||||
|
||||
func removeDuplicate(sliceList []Variant) []Variant {
|
||||
allKeys := make(map[Variant]bool)
|
||||
list := []Variant{}
|
||||
for _, item := range sliceList {
|
||||
if _, value := allKeys[item]; !value {
|
||||
allKeys[item] = true
|
||||
list = append(list, item)
|
||||
}
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue