AB#2512 Config secrets via env var & config refactoring (#544)

* refactor measurements to use consistent types and less byte pushing
* refactor: only rely on a single multierr dependency
* extend config creation with envar support
* document changes
Signed-off-by: Fabian Kammel <fk@edgeless.systems>
This commit is contained in:
Fabian Kammel 2022-11-15 15:40:49 +01:00 committed by GitHub
parent 80a801629e
commit bb76a4e4c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 932 additions and 791 deletions

View File

@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- Environment variable `CONSTELL_AZURE_CLIENT_SECRET_VALUE` as an alternative way to provide the configuration value `provider.azure.clientSecretValue`.
### Changed
<!-- For changes in existing functionality. -->

View File

@ -22,10 +22,10 @@ import (
"testing"
"time"
"github.com/hashicorp/go-multierror"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/multierr"
"google.golang.org/grpc/test/bufconn"
testclock "k8s.io/utils/clock/testing"
)
@ -713,10 +713,10 @@ func (w *tarGzWriter) Bytes() []byte {
func (w *tarGzWriter) Close() (result error) {
if err := w.tarWriter.Close(); err != nil {
result = multierror.Append(result, err)
result = multierr.Append(result, err)
}
if err := w.gzWriter.Close(); err != nil {
result = multierror.Append(result, err)
result = multierr.Append(result, err)
}
return result
}

View File

@ -7,13 +7,13 @@ SPDX-License-Identifier: AGPL-3.0-only
package cloudcmd
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/constants"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -58,7 +58,7 @@ func NewUpgrader(outWriter io.Writer) (*Upgrader, error) {
}
// Upgrade upgrades the cluster to the given measurements and image.
func (u *Upgrader) Upgrade(ctx context.Context, image string, measurements map[uint32][]byte) error {
func (u *Upgrader) Upgrade(ctx context.Context, image string, measurements measurements.M) error {
if err := u.updateMeasurements(ctx, measurements); err != nil {
return fmt.Errorf("updating measurements: %w", err)
}
@ -97,36 +97,25 @@ func (u *Upgrader) GetCurrentImage(ctx context.Context) (*unstructured.Unstructu
return imageStruct, imageDefinition, nil
}
func (u *Upgrader) updateMeasurements(ctx context.Context, measurements map[uint32][]byte) error {
func (u *Upgrader) updateMeasurements(ctx context.Context, newMeasurements measurements.M) error {
existingConf, err := u.measurementsUpdater.getCurrent(ctx, constants.JoinConfigMap)
if err != nil {
return fmt.Errorf("retrieving current measurements: %w", err)
}
var currentMeasurements map[uint32][]byte
var currentMeasurements measurements.M
if err := json.Unmarshal([]byte(existingConf.Data[constants.MeasurementsFilename]), &currentMeasurements); err != nil {
return fmt.Errorf("retrieving current measurements: %w", err)
}
if len(currentMeasurements) == len(measurements) {
changed := false
for k, v := range currentMeasurements {
if !bytes.Equal(v, measurements[k]) {
// measurements have changed
changed = true
break
}
}
if !changed {
// measurements are the same, nothing to be done
fmt.Fprintln(u.outWriter, "Cluster is already using the chosen measurements, skipping measurements upgrade")
return nil
}
if currentMeasurements.EqualTo(newMeasurements) {
fmt.Fprintln(u.outWriter, "Cluster is already using the chosen measurements, skipping measurements upgrade")
return nil
}
// backup of previous measurements
existingConf.Data["oldMeasurements"] = existingConf.Data[constants.MeasurementsFilename]
measurementsJSON, err := json.Marshal(measurements)
measurementsJSON, err := json.Marshal(newMeasurements)
if err != nil {
return fmt.Errorf("marshaling measurements: %w", err)
}

View File

@ -13,6 +13,7 @@ import (
"errors"
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -24,7 +25,7 @@ func TestUpdateMeasurements(t *testing.T) {
someErr := errors.New("error")
testCases := map[string]struct {
updater *stubMeasurementsUpdater
newMeasurements map[uint32][]byte
newMeasurements measurements.M
wantUpdate bool
wantErr bool
}{
@ -36,7 +37,7 @@ func TestUpdateMeasurements(t *testing.T) {
},
},
},
newMeasurements: map[uint32][]byte{
newMeasurements: measurements.M{
0: []byte("1"),
},
wantUpdate: true,
@ -49,7 +50,7 @@ func TestUpdateMeasurements(t *testing.T) {
},
},
},
newMeasurements: map[uint32][]byte{
newMeasurements: measurements.M{
0: []byte("1"),
},
},

View File

@ -18,6 +18,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp"
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch"
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
@ -28,7 +29,7 @@ import (
// Validator validates Platform Configuration Registers (PCRs).
type Validator struct {
provider cloudprovider.Provider
pcrs map[uint32][]byte
pcrs measurements.M
enforcedPCRs []uint32
idkeydigest []byte
enforceIDKeyDigest bool
@ -147,7 +148,7 @@ func (v *Validator) V(cmd *cobra.Command) atls.Validator {
}
// PCRS returns the validator's PCR map.
func (v *Validator) PCRS() map[uint32][]byte {
func (v *Validator) PCRS() measurements.M {
return v.pcrs
}
@ -169,7 +170,7 @@ func (v *Validator) updateValidator(cmd *cobra.Command) {
}
}
func (v *Validator) checkPCRs(pcrs map[uint32][]byte, enforcedPCRs []uint32) error {
func (v *Validator) checkPCRs(pcrs measurements.M, enforcedPCRs []uint32) error {
if len(pcrs) == 0 {
return errors.New("no PCR values provided")
}

View File

@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only
package cloudcmd
import (
"bytes"
"crypto/sha256"
"encoding/base64"
"testing"
@ -15,6 +16,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp"
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch"
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
@ -24,21 +26,19 @@ import (
)
func TestNewValidator(t *testing.T) {
zero := []byte("00000000000000000000000000000000")
one := []byte("11111111111111111111111111111111")
testPCRs := map[uint32][]byte{
0: zero,
1: one,
2: zero,
3: one,
4: zero,
5: zero,
testPCRs := measurements.M{
0: measurements.PCRWithAllBytes(0x00),
1: measurements.PCRWithAllBytes(0xFF),
2: measurements.PCRWithAllBytes(0x00),
3: measurements.PCRWithAllBytes(0xFF),
4: measurements.PCRWithAllBytes(0x00),
5: measurements.PCRWithAllBytes(0x00),
}
testCases := map[string]struct {
provider cloudprovider.Provider
config *config.Config
pcrs map[uint32][]byte
pcrs measurements.M
enforceIDKeyDigest bool
idKeyDigest string
azureCVM bool
@ -64,13 +64,15 @@ func TestNewValidator(t *testing.T) {
},
"no pcrs provided": {
provider: cloudprovider.Azure,
pcrs: map[uint32][]byte{},
pcrs: measurements.M{},
wantErr: true,
},
"invalid pcr length": {
provider: cloudprovider.GCP,
pcrs: map[uint32][]byte{0: []byte("0000000000000000000000000000000")},
wantErr: true,
pcrs: measurements.M{
0: bytes.Repeat([]byte{0x00}, 31),
},
wantErr: true,
},
"unknown provider": {
provider: cloudprovider.Unknown,
@ -99,16 +101,13 @@ func TestNewValidator(t *testing.T) {
conf := &config.Config{Provider: config.ProviderConfig{}}
if tc.provider == cloudprovider.GCP {
measurements := config.Measurements(tc.pcrs)
conf.Provider.GCP = &config.GCPConfig{Measurements: measurements}
conf.Provider.GCP = &config.GCPConfig{Measurements: tc.pcrs}
}
if tc.provider == cloudprovider.Azure {
measurements := config.Measurements(tc.pcrs)
conf.Provider.Azure = &config.AzureConfig{Measurements: measurements, EnforceIDKeyDigest: &tc.enforceIDKeyDigest, IDKeyDigest: tc.idKeyDigest, ConfidentialVM: &tc.azureCVM}
conf.Provider.Azure = &config.AzureConfig{Measurements: tc.pcrs, EnforceIDKeyDigest: &tc.enforceIDKeyDigest, IDKeyDigest: tc.idKeyDigest, ConfidentialVM: &tc.azureCVM}
}
if tc.provider == cloudprovider.QEMU {
measurements := config.Measurements(tc.pcrs)
conf.Provider.QEMU = &config.QEMUConfig{Measurements: measurements}
conf.Provider.QEMU = &config.QEMUConfig{Measurements: tc.pcrs}
}
validators, err := NewValidator(tc.provider, conf)
@ -125,29 +124,27 @@ func TestNewValidator(t *testing.T) {
}
func TestValidatorV(t *testing.T) {
zero := []byte("00000000000000000000000000000000")
newTestPCRs := func() map[uint32][]byte {
return map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
6: zero,
7: zero,
8: zero,
9: zero,
10: zero,
11: zero,
12: zero,
newTestPCRs := func() measurements.M {
return measurements.M{
0: measurements.PCRWithAllBytes(0x00),
1: measurements.PCRWithAllBytes(0x00),
2: measurements.PCRWithAllBytes(0x00),
3: measurements.PCRWithAllBytes(0x00),
4: measurements.PCRWithAllBytes(0x00),
5: measurements.PCRWithAllBytes(0x00),
6: measurements.PCRWithAllBytes(0x00),
7: measurements.PCRWithAllBytes(0x00),
8: measurements.PCRWithAllBytes(0x00),
9: measurements.PCRWithAllBytes(0x00),
10: measurements.PCRWithAllBytes(0x00),
11: measurements.PCRWithAllBytes(0x00),
12: measurements.PCRWithAllBytes(0x00),
}
}
testCases := map[string]struct {
provider cloudprovider.Provider
pcrs map[uint32][]byte
pcrs measurements.M
wantVs atls.Validator
azureCVM bool
}{
@ -224,7 +221,7 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
testCases := map[string]struct {
provider cloudprovider.Provider
pcrs map[uint32][]byte
pcrs measurements.M
ownerID string
clusterID string
wantErr bool
@ -321,14 +318,14 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
}
func TestUpdatePCR(t *testing.T) {
emptyMap := map[uint32][]byte{}
defaultMap := map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
emptyMap := measurements.M{}
defaultMap := measurements.M{
0: measurements.PCRWithAllBytes(0xAA),
1: measurements.PCRWithAllBytes(0xBB),
}
testCases := map[string]struct {
pcrMap map[uint32][]byte
pcrMap measurements.M
pcrIndex uint32
encoded string
wantEntries int
@ -389,7 +386,7 @@ func TestUpdatePCR(t *testing.T) {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
pcrs := make(map[uint32][]byte)
pcrs := make(measurements.M)
for k, v := range tc.pcrMap {
pcrs[k] = v
}

View File

@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
@ -39,7 +40,7 @@ func newConfigFetchMeasurementsCmd() *cobra.Command {
type fetchMeasurementsFlags struct {
measurementsURL *url.URL
signatureURL *url.URL
config string
configPath string
}
func runConfigFetchMeasurements(cmd *cobra.Command, args []string) error {
@ -57,9 +58,9 @@ func configFetchMeasurements(cmd *cobra.Command, verifier rekorVerifier, fileHan
return err
}
conf, err := config.FromFile(fileHandler, flags.config)
conf, err := config.New(fileHandler, flags.configPath)
if err != nil {
return err
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
if conf.IsDebugImage() {
@ -72,7 +73,7 @@ func configFetchMeasurements(cmd *cobra.Command, verifier rekorVerifier, fileHan
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var fetchedMeasurements config.Measurements
var fetchedMeasurements measurements.M
hash, err := fetchedMeasurements.FetchAndVerify(ctx, client, flags.measurementsURL, flags.signatureURL, []byte(constants.CosignPublicKey))
if err != nil {
return err
@ -84,7 +85,7 @@ func configFetchMeasurements(cmd *cobra.Command, verifier rekorVerifier, fileHan
}
conf.UpdateMeasurements(fetchedMeasurements)
if err := fileHandler.WriteYAML(flags.config, conf, file.OptOverwrite); err != nil {
if err := fileHandler.WriteYAML(flags.configPath, conf, file.OptOverwrite); err != nil {
return err
}
@ -123,7 +124,7 @@ func parseFetchMeasurementsFlags(cmd *cobra.Command) (*fetchMeasurementsFlags, e
return &fetchMeasurementsFlags{
measurementsURL: measurementsURL,
signatureURL: measurementsSignatureURL,
config: config,
configPath: config,
}, nil
}

View File

@ -40,7 +40,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
wantFlags: &fetchMeasurementsFlags{
measurementsURL: nil,
signatureURL: nil,
config: constants.ConfigFilename,
configPath: constants.ConfigFilename,
},
},
"url": {
@ -49,7 +49,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
wantFlags: &fetchMeasurementsFlags{
measurementsURL: urlMustParse("https://some.other.url/with/path"),
signatureURL: urlMustParse("https://some.other.url/with/path.sig"),
config: constants.ConfigFilename,
configPath: constants.ConfigFilename,
},
},
"broken url": {
@ -59,7 +59,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
"config": {
configFlag: "someOtherConfig.yaml",
wantFlags: &fetchMeasurementsFlags{
config: "someOtherConfig.yaml",
configPath: "someOtherConfig.yaml",
},
},
}
@ -212,8 +212,7 @@ func TestConfigFetchMeasurements(t *testing.T) {
cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually
fileHandler := file.NewHandler(afero.NewMemMapFs())
gcpConfig := config.Default()
gcpConfig.RemoveProviderExcept(cloudprovider.GCP)
gcpConfig := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP)
gcpConfig.Provider.GCP.Image = "projects/constellation-images/global/images/constellation-coreos-1658216163"
err := fileHandler.WriteYAML(constants.ConfigFilename, gcpConfig, file.OptMkdirAll)

View File

@ -0,0 +1,28 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"errors"
"fmt"
"io"
"go.uber.org/multierr"
)
func displayConfigValidationErrors(errWriter io.Writer, configError error) error {
errs := multierr.Errors(configError)
if errs != nil {
fmt.Fprintln(errWriter, "Problems validating config file:")
for _, err := range errs {
fmt.Fprintln(errWriter, "\t"+err.Error())
}
fmt.Fprintln(errWriter, "Fix the invalid entries or generate a new configuration using `constellation config generate`")
return errors.New("invalid configuration")
}
return nil
}

View File

@ -13,6 +13,7 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero"
@ -59,27 +60,27 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler,
return err
}
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath)
conf, err := config.New(fileHandler, flags.configPath)
if err != nil {
return fmt.Errorf("reading and validating config: %w", err)
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
var printedAWarning bool
if config.IsDebugImage() {
if conf.IsDebugImage() {
cmd.PrintErrln("Configured image doesn't look like a released production image. Double check image before deploying to production.")
printedAWarning = true
}
if config.IsDebugCluster() {
if conf.IsDebugCluster() {
cmd.PrintErrln("WARNING: Creating a debug cluster. This cluster is not secure and should only be used for debugging purposes.")
cmd.PrintErrln("DO NOT USE THIS CLUSTER IN PRODUCTION.")
printedAWarning = true
}
if config.IsAzureNonCVM() {
if conf.IsAzureNonCVM() {
cmd.PrintErrln("Disabling Confidential VMs is insecure. Use only for evaluation purposes.")
printedAWarning = true
if config.EnforcesIDKeyDigest() {
if conf.EnforcesIDKeyDigest() {
cmd.PrintErrln("Your config asks for enforcing the idkeydigest. This is only available on Confidential VMs. It will not be enforced.")
}
}
@ -89,20 +90,20 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler,
cmd.PrintErrln("")
}
provider := config.GetProvider()
provider := conf.GetProvider()
var instanceType string
switch provider {
case cloudprovider.AWS:
instanceType = config.Provider.AWS.InstanceType
instanceType = conf.Provider.AWS.InstanceType
if len(flags.name) > 10 {
return fmt.Errorf("cluster name on AWS must not be longer than 10 characters")
}
case cloudprovider.Azure:
instanceType = config.Provider.Azure.InstanceType
instanceType = conf.Provider.Azure.InstanceType
case cloudprovider.GCP:
instanceType = config.Provider.GCP.InstanceType
instanceType = conf.Provider.GCP.InstanceType
case cloudprovider.QEMU:
cpus := config.Provider.QEMU.VCPUs
cpus := conf.Provider.QEMU.VCPUs
instanceType = fmt.Sprintf("%d-vCPU", cpus)
}
@ -122,7 +123,7 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler,
}
spinner.Start("Creating", false)
idFile, err := creator.Create(cmd.Context(), provider, config, flags.name, instanceType, flags.controllerCount, flags.workerCount)
idFile, err := creator.Create(cmd.Context(), provider, conf, flags.name, instanceType, flags.controllerCount, flags.workerCount)
spinner.Stop()
if err != nil {
return err

View File

@ -78,9 +78,9 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return err
}
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath)
conf, err := config.New(fileHandler, flags.configPath)
if err != nil {
return fmt.Errorf("reading and validating config: %w", err)
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
var idFile clusterid.File
@ -88,7 +88,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return fmt.Errorf("reading cluster ID file: %w", err)
}
k8sVersion, err := versions.NewValidK8sVersion(config.KubernetesVersion)
k8sVersion, err := versions.NewValidK8sVersion(conf.KubernetesVersion)
if err != nil {
return fmt.Errorf("validating kubernetes version: %w", err)
}
@ -96,18 +96,18 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
cmd.PrintErrf("Warning: Constellation with Kubernetes %v is still in preview. Use only for evaluation purposes.\n", k8sVersion)
}
provider := config.GetProvider()
provider := conf.GetProvider()
checker := license.NewChecker(quotaChecker, fileHandler)
if err := checker.CheckLicense(cmd.Context(), provider, config.Provider, cmd.Printf); err != nil {
if err := checker.CheckLicense(cmd.Context(), provider, conf.Provider, cmd.Printf); err != nil {
cmd.PrintErrf("License check failed: %v", err)
}
validator, err := cloudcmd.NewValidator(provider, config)
validator, err := cloudcmd.NewValidator(provider, conf)
if err != nil {
return err
}
serviceAccURI, err := getMarshaledServiceAccountURI(provider, config, fileHandler)
serviceAccURI, err := getMarshaledServiceAccountURI(provider, conf, fileHandler)
if err != nil {
return err
}
@ -117,7 +117,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return fmt.Errorf("parsing or generating master secret from file %s: %w", flags.masterSecretPath, err)
}
helmLoader := helm.New(provider, k8sVersion)
helmDeployments, err := helmLoader.Load(provider, flags.conformance, masterSecret.Key, masterSecret.Salt, getEnforcedPCRs(provider, config), getEnforceIDKeyDigest(provider, config))
helmDeployments, err := helmLoader.Load(provider, flags.conformance, masterSecret.Key, masterSecret.Salt, getEnforcedPCRs(provider, conf), getEnforceIDKeyDigest(provider, conf))
if err != nil {
return fmt.Errorf("loading Helm charts: %w", err)
}
@ -131,10 +131,10 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
KeyEncryptionKeyId: "",
UseExistingKek: false,
CloudServiceAccountUri: serviceAccURI,
KubernetesVersion: config.KubernetesVersion,
KubernetesVersion: conf.KubernetesVersion,
HelmDeployments: helmDeployments,
EnforcedPcrs: getEnforcedPCRs(provider, config),
EnforceIdkeydigest: getEnforceIDKeyDigest(provider, config),
EnforcedPcrs: getEnforcedPCRs(provider, conf),
EnforceIdkeydigest: getEnforceIDKeyDigest(provider, conf),
ConformanceMode: flags.conformance,
}
resp, err := initCall(cmd.Context(), newDialer(validator), idFile.IP, req)

View File

@ -21,6 +21,7 @@ import (
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared"
"github.com/edgelesssys/constellation/v2/internal/config"
@ -360,11 +361,11 @@ func TestAttestation(t *testing.T) {
issuer := &testIssuer{
Getter: oid.QEMU{},
pcrs: map[uint32][]byte{
0: []byte("ffffffffffffffffffffffffffffffff"),
1: []byte("ffffffffffffffffffffffffffffffff"),
2: []byte("ffffffffffffffffffffffffffffffff"),
3: []byte("ffffffffffffffffffffffffffffffff"),
pcrs: measurements.M{
0: measurements.PCRWithAllBytes(0xFF),
1: measurements.PCRWithAllBytes(0xFF),
2: measurements.PCRWithAllBytes(0xFF),
3: measurements.PCRWithAllBytes(0xFF),
},
}
serverCreds := atlscredentials.New(issuer, nil)
@ -389,13 +390,13 @@ func TestAttestation(t *testing.T) {
cfg := config.Default()
cfg.RemoveProviderExcept(cloudprovider.QEMU)
cfg.Provider.QEMU.Image = "some/image/location"
cfg.Provider.QEMU.Measurements[0] = []byte("00000000000000000000000000000000")
cfg.Provider.QEMU.Measurements[1] = []byte("11111111111111111111111111111111")
cfg.Provider.QEMU.Measurements[2] = []byte("22222222222222222222222222222222")
cfg.Provider.QEMU.Measurements[3] = []byte("33333333333333333333333333333333")
cfg.Provider.QEMU.Measurements[4] = []byte("44444444444444444444444444444444")
cfg.Provider.QEMU.Measurements[8] = []byte("88888888888888888888888888888888")
cfg.Provider.QEMU.Measurements[9] = []byte("99999999999999999999999999999999")
cfg.Provider.QEMU.Measurements[0] = measurements.PCRWithAllBytes(0x00)
cfg.Provider.QEMU.Measurements[1] = measurements.PCRWithAllBytes(0x11)
cfg.Provider.QEMU.Measurements[2] = measurements.PCRWithAllBytes(0x22)
cfg.Provider.QEMU.Measurements[3] = measurements.PCRWithAllBytes(0x33)
cfg.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44)
cfg.Provider.QEMU.Measurements[8] = measurements.PCRWithAllBytes(0x88)
cfg.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x99)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone))
ctx := context.Background()
@ -411,13 +412,13 @@ func TestAttestation(t *testing.T) {
type testValidator struct {
oid.Getter
pcrs map[uint32][]byte
pcrs measurements.M
}
func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
var attestation struct {
UserData []byte
PCRs map[uint32][]byte
PCRs measurements.M
}
if err := json.Unmarshal(attDoc, &attestation); err != nil {
return nil, err
@ -433,14 +434,14 @@ func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
type testIssuer struct {
oid.Getter
pcrs map[uint32][]byte
pcrs measurements.M
}
func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal(
struct {
UserData []byte
PCRs map[uint32][]byte
PCRs measurements.M
}{
UserData: userData,
PCRs: i.pcrs,
@ -472,23 +473,23 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
conf.Provider.Azure.ResourceGroup = "test-resource-group"
conf.Provider.Azure.AppClientID = "01234567-0123-0123-0123-0123456789ab"
conf.Provider.Azure.ClientSecretValue = "test-client-secret"
conf.Provider.Azure.Measurements[4] = []byte("44444444444444444444444444444444")
conf.Provider.Azure.Measurements[8] = []byte("00000000000000000000000000000000")
conf.Provider.Azure.Measurements[9] = []byte("11111111111111111111111111111111")
conf.Provider.Azure.Measurements[4] = measurements.PCRWithAllBytes(0x44)
conf.Provider.Azure.Measurements[8] = measurements.PCRWithAllBytes(0x00)
conf.Provider.Azure.Measurements[9] = measurements.PCRWithAllBytes(0x11)
case cloudprovider.GCP:
conf.Provider.GCP.Region = "test-region"
conf.Provider.GCP.Project = "test-project"
conf.Provider.GCP.Image = "some/image/location"
conf.Provider.GCP.Zone = "test-zone"
conf.Provider.GCP.ServiceAccountKeyPath = "test-key-path"
conf.Provider.GCP.Measurements[4] = []byte("44444444444444444444444444444444")
conf.Provider.GCP.Measurements[8] = []byte("00000000000000000000000000000000")
conf.Provider.GCP.Measurements[9] = []byte("11111111111111111111111111111111")
conf.Provider.GCP.Measurements[4] = measurements.PCRWithAllBytes(0x44)
conf.Provider.GCP.Measurements[8] = measurements.PCRWithAllBytes(0x00)
conf.Provider.GCP.Measurements[9] = measurements.PCRWithAllBytes(0x11)
case cloudprovider.QEMU:
conf.Provider.QEMU.Image = "some/image/location"
conf.Provider.QEMU.Measurements[4] = []byte("44444444444444444444444444444444")
conf.Provider.QEMU.Measurements[8] = []byte("00000000000000000000000000000000")
conf.Provider.QEMU.Measurements[9] = []byte("11111111111111111111111111111111")
conf.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44)
conf.Provider.QEMU.Measurements[8] = measurements.PCRWithAllBytes(0x00)
conf.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x11)
}
conf.RemoveProviderExcept(csp)

View File

@ -163,14 +163,14 @@ func prepareConfig(cmd *cobra.Command, fileHandler file.Handler) (*config.Config
// check for existing config
if configPath != "" {
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, configPath)
conf, err := config.New(fileHandler, configPath)
if err != nil {
return nil, err
return nil, displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
if config.GetProvider() != cloudprovider.QEMU {
if conf.GetProvider() != cloudprovider.QEMU {
return nil, errors.New("invalid provider for MiniConstellation cluster")
}
return config, nil
return conf, nil
}
if err := cmd.Flags().Set("config", constants.ConfigFilename); err != nil {
return nil, err

View File

@ -1,46 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"errors"
"fmt"
"io"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file"
)
func readConfig(errWriter io.Writer, fileHandler file.Handler, name string) (*config.Config, error) {
cnf, err := config.FromFile(fileHandler, name)
if err != nil {
return nil, err
}
if err := validateConfig(errWriter, cnf); err != nil {
return nil, err
}
return cnf, nil
}
func validateConfig(errWriter io.Writer, cnf *config.Config) error {
msgs, err := cnf.Validate()
if err != nil {
return fmt.Errorf("performing config validation: %w", err)
}
if len(msgs) > 0 {
fmt.Fprintln(errWriter, "Invalid fields in config file:")
for _, m := range msgs {
fmt.Fprintln(errWriter, "\t"+m)
}
fmt.Fprintln(errWriter, "Fix the invalid entries or generate a new configuration using `constellation config generate`")
return errors.New("invalid configuration")
}
return nil
}

View File

@ -1,126 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"bytes"
"testing"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateConfig(t *testing.T) {
testCases := map[string]struct {
cnf *config.Config
provider cloudprovider.Provider
wantOutput bool
wantErr bool
}{
"default config is not valid": {
cnf: config.Default(),
wantOutput: true,
wantErr: true,
},
"default Azure config is not valid": {
cnf: func() *config.Config {
cnf := config.Default()
az := cnf.Provider.Azure
cnf.Provider = config.ProviderConfig{}
cnf.Provider.Azure = az
return cnf
}(),
provider: cloudprovider.Azure,
wantOutput: true,
wantErr: true,
},
"Azure config with all required fields is valid": {
cnf: defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure),
provider: cloudprovider.Azure,
},
"default GCP config is not valid": {
cnf: func() *config.Config {
cnf := config.Default()
gcp := cnf.Provider.GCP
cnf.Provider = config.ProviderConfig{}
cnf.Provider.GCP = gcp
return cnf
}(),
provider: cloudprovider.GCP,
wantOutput: true,
wantErr: true,
},
"GCP config with all required fields is valid": {
cnf: defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP),
provider: cloudprovider.GCP,
},
"default QEMU config is not valid": {
cnf: func() *config.Config {
cnf := config.Default()
qemu := cnf.Provider.QEMU
cnf.Provider = config.ProviderConfig{}
cnf.Provider.QEMU = qemu
return cnf
}(),
provider: cloudprovider.QEMU,
wantOutput: true,
wantErr: true,
},
"QEMU config with all required fields is valid": {
cnf: defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.QEMU),
provider: cloudprovider.QEMU,
},
"config with an error": {
cnf: func() *config.Config {
cnf := config.Default()
cnf.Version = "v0"
return cnf
}(),
wantOutput: true,
wantErr: true,
},
"config without provider is not ok": {
cnf: func() *config.Config {
cnf := config.Default()
cnf.Provider = config.ProviderConfig{}
return cnf
}(),
wantErr: true,
},
"config without required provider": {
cnf: func() *config.Config {
cnf := config.Default()
cnf.Provider.Azure = nil
return cnf
}(),
provider: cloudprovider.Azure,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
out := &bytes.Buffer{}
err := validateConfig(out, tc.cnf)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantOutput, out.Len() > 0)
})
}
}

View File

@ -19,6 +19,7 @@ import (
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/v2/internal/attestation"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/file"
@ -66,16 +67,16 @@ func recover(
return err
}
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath)
conf, err := config.New(fileHandler, flags.configPath)
if err != nil {
return fmt.Errorf("reading and validating config: %w", err)
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
provider := config.GetProvider()
provider := conf.GetProvider()
if provider == cloudprovider.Azure {
interval = 20 * time.Second // Azure LB takes a while to remove unhealthy instances
}
validator, err := cloudcmd.NewValidator(provider, config)
validator, err := cloudcmd.NewValidator(provider, conf)
if err != nil {
return err
}

View File

@ -10,6 +10,7 @@ import (
"context"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero"
@ -43,17 +44,17 @@ func upgradeExecute(cmd *cobra.Command, upgrader cloudUpgrader, fileHandler file
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, configPath)
conf, err := config.New(fileHandler, configPath)
if err != nil {
return err
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
// TODO: validate upgrade config? Should be basic things like checking image is not an empty string
// More sophisticated validation, like making sure we don't downgrade the cluster, should be done by `constellation upgrade plan`
return upgrader.Upgrade(cmd.Context(), config.Upgrade.Image, config.Upgrade.Measurements)
return upgrader.Upgrade(cmd.Context(), conf.Upgrade.Image, conf.Upgrade.Measurements)
}
type cloudUpgrader interface {
Upgrade(ctx context.Context, image string, measurements map[uint32][]byte) error
Upgrade(ctx context.Context, image string, measurements measurements.M) error
}

View File

@ -11,6 +11,8 @@ import (
"errors"
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
@ -41,7 +43,8 @@ func TestUpgradeExecute(t *testing.T) {
cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually
handler := file.NewHandler(afero.NewMemMapFs())
require.NoError(handler.WriteYAML(constants.ConfigFilename, config.Default()))
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure)
require.NoError(handler.WriteYAML(constants.ConfigFilename, cfg))
err := upgradeExecute(cmd, tc.upgrader, handler)
if tc.wantErr {
@ -57,6 +60,6 @@ type stubUpgrader struct {
err error
}
func (u stubUpgrader) Upgrade(context.Context, string, map[uint32][]byte) error {
func (u stubUpgrader) Upgrade(context.Context, string, measurements.M) error {
return u.err
}

View File

@ -73,13 +73,13 @@ func runUpgradePlan(cmd *cobra.Command, args []string) error {
func upgradePlan(cmd *cobra.Command, planner upgradePlanner,
fileHandler file.Handler, client *http.Client, rekor rekorVerifier, flags upgradePlanFlags,
) error {
config, err := config.FromFile(fileHandler, flags.configPath)
conf, err := config.New(fileHandler, flags.configPath)
if err != nil {
return err
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
// get current image version of the cluster
csp := config.GetProvider()
csp := conf.GetProvider()
version, err := getCurrentImageVersion(cmd.Context(), planner, csp)
if err != nil {
@ -108,7 +108,7 @@ func upgradePlan(cmd *cobra.Command, planner upgradePlanner,
return upgradePlanInteractive(
&nopWriteCloser{cmd.OutOrStdout()},
io.NopCloser(cmd.InOrStdin()),
flags.configPath, config, fileHandler,
flags.configPath, conf, fileHandler,
compatibleImages,
)
}

View File

@ -451,8 +451,8 @@ func TestUpgradePlan(t *testing.T) {
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
cfg := config.Default()
cfg.RemoveProviderExcept(tc.csp)
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.csp)
require.NoError(fileHandler.WriteYAML(tc.flags.configPath, cfg))
cmd := newUpgradePlanCmd()

View File

@ -19,6 +19,7 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/file"
@ -59,13 +60,13 @@ func verify(cmd *cobra.Command, fileHandler file.Handler, verifyClient verifyCli
return err
}
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath)
conf, err := config.New(fileHandler, flags.configPath)
if err != nil {
return fmt.Errorf("reading and validating config: %w", err)
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
provider := config.GetProvider()
validators, err := cloudcmd.NewValidator(provider, config)
provider := conf.GetProvider()
validators, err := cloudcmd.NewValidator(provider, conf)
if err != nil {
return err
}

View File

@ -67,6 +67,12 @@ If you don't have a cloud subscription, check out [MiniConstellation](first-step
Fill the values produced by the script into your configuration file.
:::tip
Alternatively, you can leave `clientSecretValue` empty and provide the secret via the `CONSTELL_AZURE_CLIENT_SECRET_VALUE` environment variable.
:::
By default, Constellation uses `Standard_DC4as_v5` CVMs (4 vCPUs, 16 GB RAM) to create your cluster. Optionally, you can switch to a different VM type by modifying **instanceType** in the configuration file. For CVMs, any VM type with a minimum of 4 vCPUs from the [DCasv5 & DCadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/dcasv5-dcadsv5-series) or [ECasv5 & ECadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/ecasv5-ecadsv5-series) families is supported.
Run `constellation config instance-types` to get the list of all supported options.
@ -112,6 +118,12 @@ If you don't have a cloud subscription, check out [MiniConstellation](first-step
Set the configuration value to the secret value.
:::tip
Alternatively, you can leave `clientSecretValue` empty and provide the secret via the `CONSTELL_AZURE_CLIENT_SECRET_VALUE` environment variable.
:::
* **instanceType**: The VM type you want to use for your Constellation nodes.
For CVMs, any type with a minimum of 4 vCPUs from the [DCasv5 & DCadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/dcasv5-dcadsv5-series) or [ECasv5 & ECadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/ecasv5-ecadsv5-series) families is supported. It defaults to `Standard_DC4as_v5` (4 vCPUs, 16 GB RAM).

2
go.mod
View File

@ -66,7 +66,6 @@ require (
github.com/google/tink/go v1.7.0
github.com/googleapis/gax-go/v2 v2.7.0
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-version v1.6.0
github.com/hashicorp/hc-install v0.4.0
github.com/hashicorp/terraform-exec v0.17.3
@ -117,6 +116,7 @@ require (
github.com/golang-jwt/jwt/v4 v4.4.2 // indirect
github.com/google/logger v1.1.1 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hashicorp/go-retryablehttp v0.7.1 // indirect
github.com/rogpeppe/go-internal v1.8.1 // indirect
golang.org/x/text v0.4.0 // indirect

View File

@ -16,7 +16,7 @@ import (
"syscall"
"github.com/edgelesssys/constellation/v2/hack/image-measurement/server"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/logger"
"go.uber.org/multierr"
"go.uber.org/zap"
@ -282,7 +282,7 @@ func (l *libvirtInstance) deleteLibvirtInstance() error {
return err
}
func (l *libvirtInstance) obtainMeasurements() (measurements config.Measurements, err error) {
func (l *libvirtInstance) obtainMeasurements() (measurements measurements.M, err error) {
// sanity check
if err := l.deleteLibvirtInstance(); err != nil {
return nil, err

View File

@ -12,6 +12,7 @@ import (
"net"
"net/http"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/logger"
"go.uber.org/zap"
)
@ -20,7 +21,7 @@ import (
type Server struct {
log *logger.Logger
server http.Server
measurements map[uint32][]byte
measurements measurements.M
done chan<- struct{}
}
@ -72,7 +73,7 @@ func (s *Server) logPCRs(w http.ResponseWriter, r *http.Request) {
}
// unmarshal the request body into a map of PCRs
var pcrs map[uint32][]byte
var pcrs measurements.M
if err := json.NewDecoder(r.Body).Decode(&pcrs); err != nil {
log.With(zap.Error(err)).Errorf("Failed to read request body")
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -89,6 +90,6 @@ func (s *Server) logPCRs(w http.ResponseWriter, r *http.Request) {
}
// GetMeasurements returns the static measurements for QEMU environment.
func (s *Server) GetMeasurements() map[uint32][]byte {
func (s *Server) GetMeasurements() measurements.M {
return s.measurements
}

View File

@ -8,7 +8,6 @@ package main
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"flag"
@ -20,6 +19,7 @@ import (
"strconv"
"time"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto"
@ -68,23 +68,6 @@ func main() {
}
}
// Measurements contains all PCR values.
type Measurements map[uint32][]byte
var _ yaml.Marshaler = Measurements{}
// MarshalYAML forces that measurements are written as base64. Default would
// be to print list of bytes.
func (m Measurements) MarshalYAML() (any, error) {
base64Map := make(map[uint32]string)
for key, value := range m {
base64Map[key] = base64.StdEncoding.EncodeToString(value[:])
}
return base64Map, nil
}
// getAttestation connects to the Constellation verification service and returns its attestation document.
func getAttestation(ctx context.Context, addr string) ([]byte, error) {
conn, err := grpc.DialContext(
@ -109,7 +92,7 @@ func getAttestation(ctx context.Context, addr string) ([]byte, error) {
}
// validatePCRAttDoc parses and validates PCRs of an attestation document.
func validatePCRAttDoc(attDocRaw []byte) (map[uint32][]byte, error) {
func validatePCRAttDoc(attDocRaw []byte) (measurements.M, error) {
attDoc := vtpm.AttestationDocument{}
if err := json.Unmarshal(attDocRaw, &attDoc); err != nil {
return nil, err
@ -131,7 +114,7 @@ func validatePCRAttDoc(attDocRaw []byte) (map[uint32][]byte, error) {
// printPCRs formates and prints PCRs to the given writer.
// format can be one of 'json' or 'yaml'. If it doesnt match defaults to 'json'.
func printPCRs(w io.Writer, pcrs map[uint32][]byte, format string) error {
func printPCRs(w io.Writer, pcrs measurements.M, format string) error {
switch format {
case "json":
return printPCRsJSON(w, pcrs)
@ -142,7 +125,7 @@ func printPCRs(w io.Writer, pcrs map[uint32][]byte, format string) error {
}
}
func printPCRsYAML(w io.Writer, pcrs Measurements) error {
func printPCRsYAML(w io.Writer, pcrs measurements.M) error {
pcrYAML, err := yaml.Marshal(pcrs)
if err != nil {
return err
@ -151,7 +134,7 @@ func printPCRsYAML(w io.Writer, pcrs Measurements) error {
return nil
}
func printPCRsJSON(w io.Writer, pcrs map[uint32][]byte) error {
func printPCRsJSON(w io.Writer, pcrs measurements.M) error {
pcrJSON, err := json.MarshalIndent(pcrs, "", " ")
if err != nil {
return err
@ -162,7 +145,7 @@ func printPCRsJSON(w io.Writer, pcrs map[uint32][]byte) error {
// exportToFile writes pcrs to a file, formatted to be valid Go code.
// Validity of the PCR map is not checked, and should be handled by the caller.
func exportToFile(path string, pcrs map[uint32][]byte, fs *afero.Afero) error {
func exportToFile(path string, pcrs measurements.M, fs *afero.Afero) error {
goCode := `package pcrs
var pcrs = map[uint32][]byte{%s

View File

@ -13,6 +13,7 @@ import (
"fmt"
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/google/go-tpm-tools/proto/attest"
"github.com/google/go-tpm-tools/proto/tpm"
@ -31,12 +32,12 @@ func TestMain(m *testing.M) {
func TestExportToFile(t *testing.T) {
testCases := map[string]struct {
pcrs map[uint32][]byte
pcrs measurements.M
fs *afero.Afero
wantErr bool
}{
"file not writeable": {
pcrs: map[uint32][]byte{
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
@ -45,7 +46,7 @@ func TestExportToFile(t *testing.T) {
wantErr: true,
},
"file writeable": {
pcrs: map[uint32][]byte{
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
@ -105,7 +106,7 @@ func TestValidatePCRAttDoc(t *testing.T) {
{
Pcrs: &tpm.PCRs{
Hash: tpm.HashAlgo_SHA256,
Pcrs: map[uint32][]byte{
Pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
},
},
@ -122,8 +123,8 @@ func TestValidatePCRAttDoc(t *testing.T) {
{
Pcrs: &tpm.PCRs{
Hash: tpm.HashAlgo_SHA256,
Pcrs: map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
Pcrs: measurements.M{
0: measurements.PCRWithAllBytes(0xAA),
},
},
},
@ -163,11 +164,11 @@ func mustMarshalAttDoc(t *testing.T, attDoc vtpm.AttestationDocument) []byte {
func TestPrintPCRs(t *testing.T) {
testCases := map[string]struct {
pcrs map[uint32][]byte
pcrs measurements.M
format string
}{
"json": {
pcrs: map[uint32][]byte{
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
@ -175,7 +176,7 @@ func TestPrintPCRs(t *testing.T) {
format: "json",
},
"empty format": {
pcrs: map[uint32][]byte{
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},
@ -183,7 +184,7 @@ func TestPrintPCRs(t *testing.T) {
format: "",
},
"yaml": {
pcrs: map[uint32][]byte{
pcrs: measurements.M{
0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3},

View File

@ -15,6 +15,7 @@ import (
"strings"
"github.com/edgelesssys/constellation/v2/hack/qemu-metadata-api/virtwrapper"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/role"
@ -197,7 +198,7 @@ func (s *Server) exportPCRs(w http.ResponseWriter, r *http.Request) {
}
// unmarshal the request body into a map of PCRs
var pcrs map[uint32][]byte
var pcrs measurements.M
if err := json.NewDecoder(r.Body).Decode(&pcrs); err != nil {
log.With(zap.Error(err)).Errorf("Failed to read request body")
http.Error(w, err.Error(), http.StatusInternalServerError)

View File

@ -17,6 +17,7 @@ import (
"testing"
"github.com/edgelesssys/constellation/v2/hack/qemu-metadata-api/virtwrapper"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/stretchr/testify/assert"
@ -306,12 +307,12 @@ func TestExportPCRs(t *testing.T) {
remoteAddr: "192.0.100.1:1234",
connect: defaultConnect,
method: http.MethodPost,
message: mustMarshal(t, map[uint32][]byte{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
},
"incorrect method": {
remoteAddr: "192.0.100.1:1234",
connect: defaultConnect,
message: mustMarshal(t, map[uint32][]byte{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
method: http.MethodGet,
wantErr: true,
},
@ -320,7 +321,7 @@ func TestExportPCRs(t *testing.T) {
connect: &stubConnect{
getNetworkErr: errors.New("error"),
},
message: mustMarshal(t, map[uint32][]byte{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
method: http.MethodPost,
wantErr: true,
},
@ -335,7 +336,7 @@ func TestExportPCRs(t *testing.T) {
remoteAddr: "localhost",
connect: defaultConnect,
method: http.MethodPost,
message: mustMarshal(t, map[uint32][]byte{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
wantErr: true,
},
}

View File

@ -12,9 +12,10 @@ import (
"encoding/json"
"fmt"
"github.com/aws/aws-sdk-go-v2/config"
awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/oid"
"github.com/google/go-tpm/tpm2"
@ -28,7 +29,7 @@ type Validator struct {
}
// NewValidator create a new Validator structure and returns it.
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
v := &Validator{}
v.Validator = vtpm.NewValidator(
pcrs,
@ -88,7 +89,7 @@ func (v *Validator) tpmEnabled(attestation vtpm.AttestationDocument) error {
}
func getEC2Client(ctx context.Context, region string) (awsMetadataAPI, error) {
client, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
client, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithRegion(region))
if err != nil {
return nil, err
}

View File

@ -19,6 +19,7 @@ import (
"fmt"
"math/big"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
internalCrypto "github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/oid"
@ -41,7 +42,7 @@ type Validator struct {
}
// NewValidator initializes a new Azure validator with the provided PCR values.
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, idKeyDigest []byte, enforceIDKeyDigest bool, log vtpm.AttestationLogger) *Validator {
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, idKeyDigest []byte, enforceIDKeyDigest bool, log vtpm.AttestationLogger) *Validator {
return &Validator{
Validator: vtpm.NewValidator(
pcrs,

View File

@ -17,6 +17,7 @@ import (
"testing"
"time"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
"github.com/edgelesssys/constellation/v2/internal/crypto"
tpmclient "github.com/google/go-tpm-tools/client"
@ -188,7 +189,7 @@ func TestGetAttestationCert(t *testing.T) {
}
require.NoError(err)
validator := NewValidator(map[uint32][]byte{}, []uint32{}, nil)
validator := NewValidator(measurements.M{}, []uint32{}, nil)
cert, err := x509.ParseCertificate(rootCert.Raw)
require.NoError(err)
roots := x509.NewCertPool()

View File

@ -14,6 +14,7 @@ import (
"errors"
"fmt"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
certutil "github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/oid"
@ -32,7 +33,7 @@ type Validator struct {
}
// NewValidator initializes a new Azure validator with the provided PCR values.
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
rootPool := x509.NewCertPool()
rootPool.AddCert(ameRoot)
v := &Validator{roots: rootPool}

View File

@ -18,6 +18,7 @@ import (
"time"
compute "cloud.google.com/go/compute/apiv1"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/oid"
"github.com/google/go-tpm-tools/proto/attest"
@ -34,7 +35,7 @@ type Validator struct {
}
// NewValidator initializes a new GCP validator with the provided PCR values.
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
return &Validator{
Validator: vtpm.NewValidator(
pcrs,

View File

@ -4,7 +4,7 @@ Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package config
package measurements
import (
"bytes"
@ -18,52 +18,60 @@ import (
"net/url"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/sigstore"
"gopkg.in/yaml.v2"
)
// Measurements are Platform Configuration Register (PCR) values.
type Measurements map[uint32][]byte
// M are Platform Configuration Register (PCR) values that make up the Measurements.
type M map[uint32][]byte
var (
zero = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}
// gcpPCRs are the PCR values for a GCP Constellation node that are initially set in a generated config file.
gcpPCRs = Measurements{
0: {0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF},
11: zero,
12: zero,
13: zero,
uint32(vtpm.PCRIndexClusterID): zero,
}
// PCRWithAllBytes returns a PCR value where all 32 bytes are set to b.
func PCRWithAllBytes(b byte) []byte {
return bytes.Repeat([]byte{b}, 32)
}
// azurePCRs are the PCR values for an Azure Constellation node that are initially set in a generated config file.
azurePCRs = Measurements{
11: zero,
12: zero,
13: zero,
uint32(vtpm.PCRIndexClusterID): zero,
// DefaultsFor provides the default measurements for given cloud provider.
func DefaultsFor(provider cloudprovider.Provider) M {
switch provider {
case cloudprovider.AWS:
return M{
11: PCRWithAllBytes(0x00),
12: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
case cloudprovider.Azure:
return M{
11: PCRWithAllBytes(0x00),
12: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
case cloudprovider.GCP:
return M{
0: {0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF},
11: PCRWithAllBytes(0x00),
12: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
case cloudprovider.QEMU:
return M{
11: PCRWithAllBytes(0x00),
12: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
default:
return nil
}
// awsPCRs are the PCR values for an AWS Nitro Constellation node that are initially set in a generated config file.
awsPCRs = Measurements{
11: zero,
12: zero,
13: zero,
uint32(vtpm.PCRIndexClusterID): zero,
}
qemuPCRs = Measurements{
11: zero,
12: zero,
13: zero,
uint32(vtpm.PCRIndexClusterID): zero,
}
)
}
// FetchAndVerify fetches measurement and signature files via provided URLs,
// using client for download. The publicKey is used to verify the measurements.
// The hash of the fetched measurements is returned.
func (m *Measurements) FetchAndVerify(ctx context.Context, client *http.Client, measurementsURL *url.URL, signatureURL *url.URL, publicKey []byte) (string, error) {
func (m *M) FetchAndVerify(ctx context.Context, client *http.Client, measurementsURL *url.URL, signatureURL *url.URL, publicKey []byte) (string, error) {
measurements, err := getFromURL(ctx, client, measurementsURL)
if err != nil {
return "", fmt.Errorf("failed to fetch measurements: %w", err)
@ -86,15 +94,29 @@ func (m *Measurements) FetchAndVerify(ctx context.Context, client *http.Client,
// CopyFrom copies over all values from other. Overwriting existing values,
// but keeping not specified values untouched.
func (m Measurements) CopyFrom(other Measurements) {
func (m M) CopyFrom(other M) {
for idx := range other {
m[idx] = other[idx]
}
}
// EqualTo tests whether the provided other Measurements are equal to these
// measurements.
func (m M) EqualTo(other M) bool {
if len(m) != len(other) {
return false
}
for k, v := range m {
if !bytes.Equal(v, other[k]) {
return false
}
}
return true
}
// MarshalYAML overwrites the default behaviour of writing out []byte not as
// single bytes, but as a single base64 encoded string.
func (m Measurements) MarshalYAML() (any, error) {
func (m M) MarshalYAML() (any, error) {
base64Map := make(map[uint32]string)
for key, value := range m {
@ -106,14 +128,14 @@ func (m Measurements) MarshalYAML() (any, error) {
// UnmarshalYAML overwrites the default behaviour of reading []byte not as
// single bytes, but as a single base64 encoded string.
func (m *Measurements) UnmarshalYAML(unmarshal func(any) error) error {
func (m *M) UnmarshalYAML(unmarshal func(any) error) error {
base64Map := make(map[uint32]string)
err := unmarshal(base64Map)
if err != nil {
return err
}
*m = make(Measurements)
*m = make(M)
for key, value := range base64Map {
measurement, err := base64.StdEncoding.DecodeString(value)
if err != nil {

View File

@ -4,7 +4,7 @@ Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package config
package measurements
import (
"context"
@ -21,11 +21,11 @@ import (
func TestMarshalYAML(t *testing.T) {
testCases := map[string]struct {
measurements Measurements
measurements M
wantBase64Map map[uint32]string
}{
"valid measurements": {
measurements: Measurements{
measurements: M{
2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
},
@ -35,7 +35,7 @@ func TestMarshalYAML(t *testing.T) {
},
},
"omit bytes": {
measurements: Measurements{
measurements: M{
2: []byte{},
3: []byte{1, 2, 3, 4},
},
@ -63,7 +63,7 @@ func TestUnmarshalYAML(t *testing.T) {
testCases := map[string]struct {
inputBase64Map map[uint32]string
forceUnmarshalError bool
wantMeasurements Measurements
wantMeasurements M
wantErr bool
}{
"valid measurements": {
@ -71,7 +71,7 @@ func TestUnmarshalYAML(t *testing.T) {
2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=",
3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=",
},
wantMeasurements: Measurements{
wantMeasurements: M{
2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
},
@ -81,7 +81,7 @@ func TestUnmarshalYAML(t *testing.T) {
2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
},
wantMeasurements: Measurements{
wantMeasurements: M{
2: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
3: []byte{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
},
@ -91,7 +91,7 @@ func TestUnmarshalYAML(t *testing.T) {
2: "This is not base64",
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
},
wantMeasurements: Measurements{
wantMeasurements: M{
2: []byte{},
3: []byte{1, 2, 3, 4},
},
@ -103,7 +103,7 @@ func TestUnmarshalYAML(t *testing.T) {
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
},
forceUnmarshalError: true,
wantMeasurements: Measurements{
wantMeasurements: M{
2: []byte{},
3: []byte{1, 2, 3, 4},
},
@ -116,7 +116,7 @@ func TestUnmarshalYAML(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
var m Measurements
var m M
err := m.UnmarshalYAML(func(i any) error {
if base64Map, ok := i.(map[uint32]string); ok {
for key, value := range tc.inputBase64Map {
@ -141,55 +141,55 @@ func TestUnmarshalYAML(t *testing.T) {
func TestMeasurementsCopyFrom(t *testing.T) {
testCases := map[string]struct {
current Measurements
newMeasurements Measurements
wantMeasurements Measurements
current M
newMeasurements M
wantMeasurements M
}{
"add to empty": {
current: Measurements{},
newMeasurements: Measurements{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
current: M{},
newMeasurements: M{
1: PCRWithAllBytes(0x00),
2: PCRWithAllBytes(0x01),
3: PCRWithAllBytes(0x02),
},
wantMeasurements: Measurements{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
wantMeasurements: M{
1: PCRWithAllBytes(0x00),
2: PCRWithAllBytes(0x01),
3: PCRWithAllBytes(0x02),
},
},
"keep existing": {
current: Measurements{
4: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
5: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
current: M{
4: PCRWithAllBytes(0x01),
5: PCRWithAllBytes(0x02),
},
newMeasurements: Measurements{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
newMeasurements: M{
1: PCRWithAllBytes(0x00),
2: PCRWithAllBytes(0x01),
3: PCRWithAllBytes(0x02),
},
wantMeasurements: Measurements{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
4: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
5: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
wantMeasurements: M{
1: PCRWithAllBytes(0x00),
2: PCRWithAllBytes(0x01),
3: PCRWithAllBytes(0x02),
4: PCRWithAllBytes(0x01),
5: PCRWithAllBytes(0x02),
},
},
"overwrite existing": {
current: Measurements{
2: []byte{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4},
3: []byte{5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5},
current: M{
2: PCRWithAllBytes(0x04),
3: PCRWithAllBytes(0x05),
},
newMeasurements: Measurements{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
newMeasurements: M{
1: PCRWithAllBytes(0x00),
2: PCRWithAllBytes(0x01),
3: PCRWithAllBytes(0x02),
},
wantMeasurements: Measurements{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
wantMeasurements: M{
1: PCRWithAllBytes(0x00),
2: PCRWithAllBytes(0x01),
3: PCRWithAllBytes(0x02),
},
},
}
@ -230,7 +230,7 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
signature string
signatureStatus int
publicKey []byte
wantMeasurements Measurements
wantMeasurements M
wantSHA string
wantError bool
}{
@ -240,8 +240,8 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=",
signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"),
wantMeasurements: Measurements{
0: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
wantMeasurements: M{
0: PCRWithAllBytes(0x00),
},
wantSHA: "4cd9d6ed8d9322150dff7738994c5e2fabff35f3bae6f5c993412d13249a5e87",
},
@ -308,7 +308,7 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
}
})
m := Measurements{}
m := M{}
hash, err := m.FetchAndVerify(context.Background(), client, measurementsURL, signatureURL, tc.publicKey)
if tc.wantError {
@ -321,3 +321,83 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
})
}
}
func TestPCRWithAllBytes(t *testing.T) {
testCases := map[string]struct {
b byte
wantPCR []byte
}{
"0x00": {
b: 0x00,
wantPCR: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
},
"0x01": {
b: 0x01,
wantPCR: []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
},
"0xFF": {
b: 0xFF,
wantPCR: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
pcr := PCRWithAllBytes(tc.b)
assert.Equal(tc.wantPCR, pcr)
})
}
}
func TestEqualTo(t *testing.T) {
testCases := map[string]struct {
given M
other M
wantEqual bool
}{
"same values": {
given: M{
0: PCRWithAllBytes(0x00),
1: PCRWithAllBytes(0xFF),
},
other: M{
0: PCRWithAllBytes(0x00),
1: PCRWithAllBytes(0xFF),
},
wantEqual: true,
},
"different number of elements": {
given: M{
0: PCRWithAllBytes(0x00),
1: PCRWithAllBytes(0xFF),
},
other: M{
0: PCRWithAllBytes(0x00),
},
wantEqual: false,
},
"different values": {
given: M{
0: PCRWithAllBytes(0x00),
1: PCRWithAllBytes(0xFF),
},
other: M{
0: PCRWithAllBytes(0xFF),
1: PCRWithAllBytes(0x00),
},
wantEqual: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
if tc.wantEqual {
assert.True(tc.given.EqualTo(tc.other))
} else {
assert.False(tc.given.EqualTo(tc.other))
}
})
}
}

View File

@ -9,6 +9,7 @@ package qemu
import (
"crypto"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/oid"
"github.com/google/go-tpm/tpm2"
@ -21,7 +22,7 @@ type Validator struct {
}
// NewValidator initializes a new QEMU validator with the provided PCR values.
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
return &Validator{
Validator: vtpm.NewValidator(
pcrs,

View File

@ -14,19 +14,25 @@ import (
"errors"
"fmt"
"io/fs"
"os"
"regexp"
"strings"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config/instancetypes"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/go-playground/locales/en"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
en_translations "github.com/go-playground/validator/v10/translations/en"
"go.uber.org/multierr"
)
// Measurements is a required alias since docgen is not able to work with
// types in other packages.
type Measurements = measurements.M
const (
// Version1 is the first version number for Constellation config file.
Version1 = "v1"
@ -164,7 +170,7 @@ type AzureConfig struct {
// Application client ID of the Active Directory app registration.
AppClientID string `yaml:"appClientID" validate:"uuid"`
// description: |
// Client secret value of the Active Directory app registration credentials.
// Client secret value of the Active Directory app registration credentials. Alternatively leave empty and pass value via CONSTELL_AZURE_CLIENT_SECRET_VALUE environment variable.
ClientSecretValue string `yaml:"clientSecretValue" validate:"required"`
// description: |
// Machine image used to create Constellation nodes.
@ -277,7 +283,7 @@ func Default() *Config {
StateDiskType: "gp3",
IAMProfileControlPlane: "",
IAMProfileWorkerNodes: "",
Measurements: copyPCRMap(awsPCRs),
Measurements: measurements.DefaultsFor(cloudprovider.AWS),
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
},
Azure: &AzureConfig{
@ -289,7 +295,7 @@ func Default() *Config {
Image: DefaultImageAzure,
InstanceType: "Standard_DC4as_v5",
StateDiskType: "Premium_LRS",
Measurements: copyPCRMap(azurePCRs),
Measurements: measurements.DefaultsFor(cloudprovider.Azure),
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
IDKeyDigest: "57486a447ec0f1958002a22a06b7673b9fd27d11e1c6527498056054c5fa92d23c50f9de44072760fe2b6fb89740b696",
EnforceIDKeyDigest: func() *bool { b := true; return &b }(),
@ -304,7 +310,7 @@ func Default() *Config {
InstanceType: "n2d-standard-4",
StateDiskType: "pd-ssd",
ServiceAccountKeyPath: "",
Measurements: copyPCRMap(gcpPCRs),
Measurements: measurements.DefaultsFor(cloudprovider.GCP),
EnforcedMeasurements: []uint32{0, 4, 8, 9, 11, 12, 13, 15},
},
QEMU: &QEMUConfig{
@ -314,7 +320,7 @@ func Default() *Config {
MetadataAPIImage: versions.QEMUMetadataImage,
LibvirtURI: "",
LibvirtContainerImage: versions.LibvirtImage,
Measurements: copyPCRMap(qemuPCRs),
Measurements: measurements.DefaultsFor(cloudprovider.QEMU),
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
NVRAM: "production",
},
@ -323,198 +329,38 @@ func Default() *Config {
}
}
func validateK8sVersion(fl validator.FieldLevel) bool {
return versions.IsSupportedK8sVersion(fl.Field().String())
// FromFile returns config file with `name` read from `fileHandler` by parsing
// it as YAML. You should prefer config.New to read env vars and validate
// config in a consistent manner.
func FromFile(fileHandler file.Handler, name string) (*Config, error) {
var conf Config
if err := fileHandler.ReadYAMLStrict(name, &conf); err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("unable to find %s - use `constellation config generate` to generate it first", name)
}
return nil, fmt.Errorf("could not load config from file %s: %w", name, err)
}
return &conf, nil
}
func validateAWSInstanceType(fl validator.FieldLevel) bool {
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.AWS)
}
func validateAzureInstanceType(fl validator.FieldLevel) bool {
azureConfig := fl.Parent().Interface().(AzureConfig)
var acceptNonCVM bool
if azureConfig.ConfidentialVM != nil {
// This is the inverse of the config value (acceptNonCVMs is true if confidentialVM is false).
// We could make the validator the other way around, but this should be an explicit bypass rather than checking if CVMs are "allowed".
acceptNonCVM = !*azureConfig.ConfidentialVM
}
return validInstanceTypeForProvider(fl.Field().String(), acceptNonCVM, cloudprovider.Azure)
}
func validateGCPInstanceType(fl validator.FieldLevel) bool {
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.GCP)
}
// validateProvider checks if zero or more than one providers are defined in the config.
func validateProvider(sl validator.StructLevel) {
provider := sl.Current().Interface().(ProviderConfig)
providerCount := 0
if provider.AWS != nil {
providerCount++
}
if provider.Azure != nil {
providerCount++
}
if provider.GCP != nil {
providerCount++
}
if provider.QEMU != nil {
providerCount++
}
if providerCount < 1 {
sl.ReportError(provider, "Provider", "Provider", "no_provider", "")
} else if providerCount > 1 {
sl.ReportError(provider, "Provider", "Provider", "more_than_one_provider", "")
}
}
// Validate checks the config values and returns validation error messages.
// The function only returns an error if the validation itself fails.
func (c *Config) Validate() ([]string, error) {
trans := ut.New(en.New()).GetFallback()
validate := validator.New()
if err := en_translations.RegisterDefaultTranslations(validate, trans); err != nil {
// New creates a new config by:
// 1. Reading config file via provided fileHandler from file with name.
// 2. Read secrets from environment variables.
// 3. Validate config.
func New(fileHandler file.Handler, name string) (*Config, error) {
// Read config file
c, err := FromFile(fileHandler, name)
if err != nil {
return nil, err
}
// Register AWS, Azure & GCP InstanceType validation error types
if err := validate.RegisterTranslation("aws_instance_type", trans, registerTranslateAWSInstanceTypeError, translateAWSInstanceTypeError); err != nil {
return nil, err
// Read secrets from env-vars.
clientSecretValue := os.Getenv(constants.EnvVarAzureClientSecretValue)
if clientSecretValue != "" && c.Provider.Azure != nil {
c.Provider.Azure.ClientSecretValue = clientSecretValue
}
if err := validate.RegisterTranslation("azure_instance_type", trans, registerTranslateAzureInstanceTypeError, c.translateAzureInstanceTypeError); err != nil {
return nil, err
}
if err := validate.RegisterTranslation("gcp_instance_type", trans, registerTranslateGCPInstanceTypeError, translateGCPInstanceTypeError); err != nil {
return nil, err
}
// Register Provider validation error types
if err := validate.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError); err != nil {
return nil, err
}
if err := validate.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, c.translateMoreThanOneProviderError); err != nil {
return nil, err
}
// register custom validator with label supported_k8s_version to validate version based on available versionConfigs.
if err := validate.RegisterValidation("supported_k8s_version", validateK8sVersion); err != nil {
return nil, err
}
// register custom validator with label aws_instance_type to validate the AWS instance type from config input.
if err := validate.RegisterValidation("aws_instance_type", validateAWSInstanceType); err != nil {
return nil, err
}
// register custom validator with label azure_instance_type to validate the Azure instance type from config input.
if err := validate.RegisterValidation("azure_instance_type", validateAzureInstanceType); err != nil {
return nil, err
}
// register custom validator with label gcp_instance_type to validate the GCP instance type from config input.
if err := validate.RegisterValidation("gcp_instance_type", validateGCPInstanceType); err != nil {
return nil, err
}
// Register provider validation
validate.RegisterStructValidation(validateProvider, ProviderConfig{})
err := validate.Struct(c)
if err == nil {
return nil, nil
}
var errs validator.ValidationErrors
if !errors.As(err, &errs) {
return nil, err
}
var msgs []string
for _, e := range errs {
msgs = append(msgs, e.Translate(trans))
}
return msgs, nil
}
// Validation translation functions for Azure & GCP instance type errors.
func registerTranslateAzureInstanceTypeError(ut ut.Translator) error {
return ut.Add("azure_instance_type", "{0} must be one of {1}", true)
}
func (c *Config) translateAzureInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
// Suggest trusted launch VMs if confidential VMs have been specifically disabled
var t string
if c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM {
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureTrustedLaunchInstanceTypes))
} else {
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureCVMInstanceTypes))
}
return t
}
func registerTranslateAWSInstanceTypeError(ut ut.Translator) error {
return ut.Add("aws_instance_type", fmt.Sprintf("{0} must be an instance from one of the following families types with size xlarge or higher: %v", instancetypes.AWSSupportedInstanceFamilies), true)
}
func translateAWSInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("aws_instance_type", fe.Field())
return t
}
func registerTranslateGCPInstanceTypeError(ut ut.Translator) error {
return ut.Add("gcp_instance_type", fmt.Sprintf("{0} must be one of %v", instancetypes.GCPInstanceTypes), true)
}
func translateGCPInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("gcp_instance_type", fe.Field())
return t
}
// Validation translation functions for Provider errors.
func registerNoProviderError(ut ut.Translator) error {
return ut.Add("no_provider", "{0}: No provider has been defined (requires either Azure, GCP or QEMU)", true)
}
func translateNoProviderError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("no_provider", fe.Field())
return t
}
func registerMoreThanOneProviderError(ut ut.Translator) error {
return ut.Add("more_than_one_provider", "{0}: Only one provider can be defined ({1} are defined)", true)
}
func (c *Config) translateMoreThanOneProviderError(ut ut.Translator, fe validator.FieldError) string {
definedProviders := make([]string, 0)
// c.Provider should not be nil as Provider would need to be defined for the validation to fail in this place.
if c.Provider.AWS != nil {
definedProviders = append(definedProviders, "AWS")
}
if c.Provider.Azure != nil {
definedProviders = append(definedProviders, "Azure")
}
if c.Provider.GCP != nil {
definedProviders = append(definedProviders, "GCP")
}
if c.Provider.QEMU != nil {
definedProviders = append(definedProviders, "QEMU")
}
// Show single string if only one other provider is defined, show list with brackets if multiple are defined.
t, _ := ut.T("more_than_one_provider", fe.Field(), strings.Join(definedProviders, ", "))
return t
return c, c.Validate()
}
// HasProvider checks whether the config contains the provider.
@ -584,6 +430,19 @@ func (c *Config) RemoveProviderExcept(provider cloudprovider.Provider) {
}
}
// IsAzureNonCVM checks whether the chosen provider is azure and confidential VMs are disabled.
func (c *Config) IsAzureNonCVM() bool {
return c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM
}
// IsDebugCluster checks whether the cluster is configured as a debug cluster.
func (c *Config) IsDebugCluster() bool {
if c.DebugCluster != nil && *c.DebugCluster {
return true
}
return false
}
// IsDebugImage checks whether image name looks like a release image, if not it is
// probably a debug image. In the end we do not if bootstrapper or debugd
// was put inside an image just by looking at its name.
@ -618,110 +477,80 @@ func (c *Config) GetProvider() cloudprovider.Provider {
return cloudprovider.Unknown
}
// IsAzureNonCVM checks whether the chosen provider is azure and confidential VMs are disabled.
func (c *Config) IsAzureNonCVM() bool {
return c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM
}
// EnforcesIDKeyDigest checks whether ID Key Digest should be enforced for respective cloud provider.
func (c *Config) EnforcesIDKeyDigest() bool {
return c.Provider.Azure != nil && c.Provider.Azure.EnforceIDKeyDigest != nil && *c.Provider.Azure.EnforceIDKeyDigest
}
// FromFile returns config file with `name` read from `fileHandler` by parsing
// it as YAML.
func FromFile(fileHandler file.Handler, name string) (*Config, error) {
var conf Config
if err := fileHandler.ReadYAMLStrict(name, &conf); err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("unable to find %s - use `constellation config generate` to generate it first", name)
}
return nil, fmt.Errorf("could not load config from file %s: %w", name, err)
// Validate checks the config values and returns validation errors.
func (c *Config) Validate() error {
trans := ut.New(en.New()).GetFallback()
validate := validator.New()
if err := en_translations.RegisterDefaultTranslations(validate, trans); err != nil {
return err
}
return &conf, nil
}
func copyPCRMap(m map[uint32][]byte) map[uint32][]byte {
res := make(Measurements)
res.CopyFrom(m)
return res
}
func validInstanceTypeForProvider(insType string, acceptNonCVM bool, provider cloudprovider.Provider) bool {
switch provider {
case cloudprovider.AWS:
return checkIfAWSInstanceTypeIsValid(insType)
case cloudprovider.Azure:
if acceptNonCVM {
for _, instanceType := range instancetypes.AzureTrustedLaunchInstanceTypes {
if insType == instanceType {
return true
}
}
} else {
for _, instanceType := range instancetypes.AzureCVMInstanceTypes {
if insType == instanceType {
return true
}
}
}
return false
case cloudprovider.GCP:
for _, instanceType := range instancetypes.GCPInstanceTypes {
if insType == instanceType {
return true
}
}
return false
default:
return false
}
}
// checkIfAWSInstanceTypeIsValid checks if an AWS instance type passed as user input is in one of the instance families supporting NitroTPM.
func checkIfAWSInstanceTypeIsValid(userInput string) bool {
// Check if user or code does anything weird and tries to pass multiple strings as one
if strings.Contains(userInput, " ") {
return false
}
if strings.Contains(userInput, ",") {
return false
}
if strings.Contains(userInput, ";") {
return false
}
splitInstanceType := strings.Split(userInput, ".")
if len(splitInstanceType) != 2 {
return false
}
userDefinedFamily := splitInstanceType[0]
userDefinedSize := splitInstanceType[1]
// Check if instace type has at least 4 vCPUs (= contains "xlarge" in its name)
hasEnoughVCPUs := strings.Contains(userDefinedSize, "xlarge")
if !hasEnoughVCPUs {
return false
}
// Now check if the user input is a supported family
// Note that we cannot directly use the family split from the Graviton check above, as some instances are directly specified by their full name and not just the family in general
for _, supportedFamily := range instancetypes.AWSSupportedInstanceFamilies {
supportedFamilyLowercase := strings.ToLower(supportedFamily)
if userDefinedFamily == supportedFamilyLowercase {
return true
}
}
return false
}
// IsDebugCluster checks whether the cluster is configured as a debug cluster.
func (c *Config) IsDebugCluster() bool {
if c.DebugCluster != nil && *c.DebugCluster {
return true
}
return false
// Register AWS, Azure & GCP InstanceType validation error types
if err := validate.RegisterTranslation("aws_instance_type", trans, registerTranslateAWSInstanceTypeError, translateAWSInstanceTypeError); err != nil {
return err
}
if err := validate.RegisterTranslation("azure_instance_type", trans, registerTranslateAzureInstanceTypeError, c.translateAzureInstanceTypeError); err != nil {
return err
}
if err := validate.RegisterTranslation("gcp_instance_type", trans, registerTranslateGCPInstanceTypeError, translateGCPInstanceTypeError); err != nil {
return err
}
// Register Provider validation error types
if err := validate.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError); err != nil {
return err
}
if err := validate.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, c.translateMoreThanOneProviderError); err != nil {
return err
}
// register custom validator with label supported_k8s_version to validate version based on available versionConfigs.
if err := validate.RegisterValidation("supported_k8s_version", validateK8sVersion); err != nil {
return err
}
// register custom validator with label aws_instance_type to validate the AWS instance type from config input.
if err := validate.RegisterValidation("aws_instance_type", validateAWSInstanceType); err != nil {
return err
}
// register custom validator with label azure_instance_type to validate the Azure instance type from config input.
if err := validate.RegisterValidation("azure_instance_type", validateAzureInstanceType); err != nil {
return err
}
// register custom validator with label gcp_instance_type to validate the GCP instance type from config input.
if err := validate.RegisterValidation("gcp_instance_type", validateGCPInstanceType); err != nil {
return err
}
// Register provider validation
validate.RegisterStructValidation(validateProvider, ProviderConfig{})
err := validate.Struct(c)
if err == nil {
return nil
}
var errs validator.ValidationErrors
if !errors.As(err, &errs) {
return err
}
var validationErrors error
for _, e := range errs {
validationErrors = multierr.Append(
validationErrors,
errors.New(e.Translate(trans)),
)
}
return validationErrors
}

View File

@ -242,8 +242,8 @@ func init() {
AzureConfigDoc.Fields[6].Name = "clientSecretValue"
AzureConfigDoc.Fields[6].Type = "string"
AzureConfigDoc.Fields[6].Note = ""
AzureConfigDoc.Fields[6].Description = "Client secret value of the Active Directory app registration credentials."
AzureConfigDoc.Fields[6].Comments[encoder.LineComment] = "Client secret value of the Active Directory app registration credentials."
AzureConfigDoc.Fields[6].Description = "Client secret value of the Active Directory app registration credentials. Alternatively leave empty and pass value via CONSTELL_AZURE_CLIENT_SECRET_VALUE environment variable."
AzureConfigDoc.Fields[6].Comments[encoder.LineComment] = "Client secret value of the Active Directory app registration credentials. Alternatively leave empty and pass value via CONSTELL_AZURE_CLIENT_SECRET_VALUE environment variable."
AzureConfigDoc.Fields[7].Name = "image"
AzureConfigDoc.Fields[7].Type = "string"
AzureConfigDoc.Fields[7].Note = ""

View File

@ -10,6 +10,7 @@ import (
"reflect"
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config/instancetypes"
"github.com/edgelesssys/constellation/v2/internal/constants"
@ -21,6 +22,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"go.uber.org/multierr"
)
func TestMain(m *testing.M) {
@ -108,33 +110,49 @@ func TestFromFile(t *testing.T) {
}
}
func TestValidate(t *testing.T) {
const defaultMsgCount = 20 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default
func TestNewWithDefaultOptions(t *testing.T) {
testCases := map[string]struct {
cnf *Config
wantMsgCount int
confToWrite *Config
envToSet map[string]string
wantErr bool
wantClientSecretValue string
}{
"default config is valid": {
cnf: Default(),
wantMsgCount: defaultMsgCount,
},
"config with 1 error": {
cnf: func() *Config {
cnf := Default()
cnf.Version = "v0"
return cnf
"set env works": {
confToWrite: func() *Config { // valid config with all, but clientSecretValue
c := Default()
c.RemoveProviderExcept(cloudprovider.Azure)
c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5"
c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa"
c.Provider.Azure.Location = "westus"
c.Provider.Azure.ResourceGroup = "test"
c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity"
c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e"
c.Provider.Azure.Image = "/communityGalleries/ConstellationCVM-b3782fa0-0df7-4f2f-963e-fc7fc42663df/images/constellation/versions/2.2.0"
return c
}(),
wantMsgCount: defaultMsgCount + 1,
envToSet: map[string]string{
constants.EnvVarAzureClientSecretValue: "some-secret",
},
wantClientSecretValue: "some-secret",
},
"config with 2 errors": {
cnf: func() *Config {
cnf := Default()
cnf.Version = "v0"
cnf.StateDiskSizeGB = -1
return cnf
"set env overwrites": {
confToWrite: func() *Config {
c := Default()
c.RemoveProviderExcept(cloudprovider.Azure)
c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5"
c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa"
c.Provider.Azure.Location = "westus"
c.Provider.Azure.ResourceGroup = "test"
c.Provider.Azure.ClientSecretValue = "other-value" // < Note secret set in config, as well.
c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity"
c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e"
c.Provider.Azure.Image = "/communityGalleries/ConstellationCVM-b3782fa0-0df7-4f2f-963e-fc7fc42663df/images/constellation/versions/2.2.0"
return c
}(),
wantMsgCount: defaultMsgCount + 2,
envToSet: map[string]string{
constants.EnvVarAzureClientSecretValue: "some-secret",
},
wantClientSecretValue: "some-secret",
},
}
@ -143,9 +161,126 @@ func TestValidate(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
msgs, err := tc.cnf.Validate()
// Setup
fileHandler := file.NewHandler(afero.NewMemMapFs())
err := fileHandler.WriteYAML(constants.ConfigFilename, tc.confToWrite)
require.NoError(err)
assert.Len(msgs, tc.wantMsgCount)
for envKey, envValue := range tc.envToSet {
t.Setenv(envKey, envValue)
}
// Test
c, err := New(fileHandler, constants.ConfigFilename)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(c.Provider.Azure.ClientSecretValue, tc.wantClientSecretValue)
})
}
}
func TestValidate(t *testing.T) {
const defaultErrCount = 20 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default
const azErrCount = 8
const gcpErrCount = 5
testCases := map[string]struct {
cnf *Config
wantErr bool
wantErrCount int
}{
"default config is not valid": {
cnf: Default(),
wantErr: true,
wantErrCount: defaultErrCount,
},
"v0 is one error": {
cnf: func() *Config {
cnf := Default()
cnf.Version = "v0"
return cnf
}(),
wantErr: true,
wantErrCount: defaultErrCount + 1,
},
"v0 and negative state disk are two errors": {
cnf: func() *Config {
cnf := Default()
cnf.Version = "v0"
cnf.StateDiskSizeGB = -1
return cnf
}(),
wantErr: true,
wantErrCount: defaultErrCount + 2,
},
"default Azure config is not valid": {
cnf: func() *Config {
cnf := Default()
az := cnf.Provider.Azure
cnf.Provider = ProviderConfig{}
cnf.Provider.Azure = az
return cnf
}(),
wantErr: true,
wantErrCount: azErrCount,
},
"Azure config with all required fields is valid": {
cnf: func() *Config {
cnf := Default()
az := cnf.Provider.Azure
az.SubscriptionID = "01234567-0123-0123-0123-0123456789ab"
az.TenantID = "01234567-0123-0123-0123-0123456789ab"
az.Location = "test-location"
az.UserAssignedIdentity = "test-identity"
az.Image = "some/image/location"
az.ResourceGroup = "test-resource-group"
az.AppClientID = "01234567-0123-0123-0123-0123456789ab"
az.ClientSecretValue = "test-client-secret"
cnf.Provider = ProviderConfig{}
cnf.Provider.Azure = az
return cnf
}(),
},
"default GCP config is not valid": {
cnf: func() *Config {
cnf := Default()
gcp := cnf.Provider.GCP
cnf.Provider = ProviderConfig{}
cnf.Provider.GCP = gcp
return cnf
}(),
wantErr: true,
wantErrCount: gcpErrCount,
},
"GCP config with all required fields is valid": {
cnf: func() *Config {
cnf := Default()
gcp := cnf.Provider.GCP
gcp.Region = "test-region"
gcp.Project = "test-project"
gcp.Image = "some/image/location"
gcp.Zone = "test-zone"
gcp.ServiceAccountKeyPath = "test-key-path"
cnf.Provider = ProviderConfig{}
cnf.Provider.GCP = gcp
return cnf
}(),
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := tc.cnf.Validate()
if tc.wantErr {
assert.Error(err)
assert.Len(multierr.Errors(err), tc.wantErrCount)
return
}
assert.NoError(err)
})
}
}
@ -260,10 +395,10 @@ func TestConfigGeneratedDocsFresh(t *testing.T) {
func TestConfig_UpdateMeasurements(t *testing.T) {
assert := assert.New(t)
newMeasurements := Measurements{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
newMeasurements := measurements.M{
1: measurements.PCRWithAllBytes(0x00),
2: measurements.PCRWithAllBytes(0x01),
3: measurements.PCRWithAllBytes(0x02),
}
{ // AWS

View File

@ -0,0 +1,212 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package config
import (
"fmt"
"strings"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config/instancetypes"
"github.com/edgelesssys/constellation/v2/internal/versions"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
)
func validateK8sVersion(fl validator.FieldLevel) bool {
return versions.IsSupportedK8sVersion(fl.Field().String())
}
func validateAWSInstanceType(fl validator.FieldLevel) bool {
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.AWS)
}
func validateAzureInstanceType(fl validator.FieldLevel) bool {
azureConfig := fl.Parent().Interface().(AzureConfig)
var acceptNonCVM bool
if azureConfig.ConfidentialVM != nil {
// This is the inverse of the config value (acceptNonCVMs is true if confidentialVM is false).
// We could make the validator the other way around, but this should be an explicit bypass rather than checking if CVMs are "allowed".
acceptNonCVM = !*azureConfig.ConfidentialVM
}
return validInstanceTypeForProvider(fl.Field().String(), acceptNonCVM, cloudprovider.Azure)
}
func validateGCPInstanceType(fl validator.FieldLevel) bool {
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.GCP)
}
// validateProvider checks if zero or more than one providers are defined in the config.
func validateProvider(sl validator.StructLevel) {
provider := sl.Current().Interface().(ProviderConfig)
providerCount := 0
if provider.AWS != nil {
providerCount++
}
if provider.Azure != nil {
providerCount++
}
if provider.GCP != nil {
providerCount++
}
if provider.QEMU != nil {
providerCount++
}
if providerCount < 1 {
sl.ReportError(provider, "Provider", "Provider", "no_provider", "")
} else if providerCount > 1 {
sl.ReportError(provider, "Provider", "Provider", "more_than_one_provider", "")
}
}
func registerTranslateAWSInstanceTypeError(ut ut.Translator) error {
return ut.Add("aws_instance_type", fmt.Sprintf("{0} must be an instance from one of the following families types with size xlarge or higher: %v", instancetypes.AWSSupportedInstanceFamilies), true)
}
func translateAWSInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("aws_instance_type", fe.Field())
return t
}
func registerTranslateGCPInstanceTypeError(ut ut.Translator) error {
return ut.Add("gcp_instance_type", fmt.Sprintf("{0} must be one of %v", instancetypes.GCPInstanceTypes), true)
}
func translateGCPInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("gcp_instance_type", fe.Field())
return t
}
// Validation translation functions for Provider errors.
func registerNoProviderError(ut ut.Translator) error {
return ut.Add("no_provider", "{0}: No provider has been defined (requires either Azure, GCP or QEMU)", true)
}
func translateNoProviderError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("no_provider", fe.Field())
return t
}
func registerMoreThanOneProviderError(ut ut.Translator) error {
return ut.Add("more_than_one_provider", "{0}: Only one provider can be defined ({1} are defined)", true)
}
func (c *Config) translateMoreThanOneProviderError(ut ut.Translator, fe validator.FieldError) string {
definedProviders := make([]string, 0)
// c.Provider should not be nil as Provider would need to be defined for the validation to fail in this place.
if c.Provider.AWS != nil {
definedProviders = append(definedProviders, "AWS")
}
if c.Provider.Azure != nil {
definedProviders = append(definedProviders, "Azure")
}
if c.Provider.GCP != nil {
definedProviders = append(definedProviders, "GCP")
}
if c.Provider.QEMU != nil {
definedProviders = append(definedProviders, "QEMU")
}
// Show single string if only one other provider is defined, show list with brackets if multiple are defined.
t, _ := ut.T("more_than_one_provider", fe.Field(), strings.Join(definedProviders, ", "))
return t
}
func validInstanceTypeForProvider(insType string, acceptNonCVM bool, provider cloudprovider.Provider) bool {
switch provider {
case cloudprovider.AWS:
return checkIfAWSInstanceTypeIsValid(insType)
case cloudprovider.Azure:
if acceptNonCVM {
for _, instanceType := range instancetypes.AzureTrustedLaunchInstanceTypes {
if insType == instanceType {
return true
}
}
} else {
for _, instanceType := range instancetypes.AzureCVMInstanceTypes {
if insType == instanceType {
return true
}
}
}
return false
case cloudprovider.GCP:
for _, instanceType := range instancetypes.GCPInstanceTypes {
if insType == instanceType {
return true
}
}
return false
default:
return false
}
}
// checkIfAWSInstanceTypeIsValid checks if an AWS instance type passed as user input is in one of the instance families supporting NitroTPM.
func checkIfAWSInstanceTypeIsValid(userInput string) bool {
// Check if user or code does anything weird and tries to pass multiple strings as one
if strings.Contains(userInput, " ") {
return false
}
if strings.Contains(userInput, ",") {
return false
}
if strings.Contains(userInput, ";") {
return false
}
splitInstanceType := strings.Split(userInput, ".")
if len(splitInstanceType) != 2 {
return false
}
userDefinedFamily := splitInstanceType[0]
userDefinedSize := splitInstanceType[1]
// Check if instace type has at least 4 vCPUs (= contains "xlarge" in its name)
hasEnoughVCPUs := strings.Contains(userDefinedSize, "xlarge")
if !hasEnoughVCPUs {
return false
}
// Now check if the user input is a supported family
// Note that we cannot directly use the family split from the Graviton check above, as some instances are directly specified by their full name and not just the family in general
for _, supportedFamily := range instancetypes.AWSSupportedInstanceFamilies {
supportedFamilyLowercase := strings.ToLower(supportedFamily)
if userDefinedFamily == supportedFamilyLowercase {
return true
}
}
return false
}
// Validation translation functions for Azure & GCP instance type errors.
func registerTranslateAzureInstanceTypeError(ut ut.Translator) error {
return ut.Add("azure_instance_type", "{0} must be one of {1}", true)
}
func (c *Config) translateAzureInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
// Suggest trusted launch VMs if confidential VMs have been specifically disabled
var t string
if c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM {
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureTrustedLaunchInstanceTypes))
} else {
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureCVMInstanceTypes))
}
return t
}

View File

@ -113,6 +113,11 @@ const (
MinControllerCount = 1
// MinWorkerCount is the minimum number of worker nodes.
MinWorkerCount = 1
// EnvVarPrefix is expected prefix for environment variables used to overwrite config parameters.
EnvVarPrefix = "CONSTELL_"
// EnvVarAzureClientSecretValue is environment variable to overwrite
// provider.azure.clientSecretValue .
EnvVarAzureClientSecretValue = EnvVarPrefix + "AZURE_CLIENT_SECRET_VALUE"
//
// Kubernetes.