AB#2512 Config secrets via env var & config refactoring (#544)

* refactor measurements to use consistent types and less byte pushing
* refactor: only rely on a single multierr dependency
* extend config creation with envar support
* document changes
Signed-off-by: Fabian Kammel <fk@edgeless.systems>
This commit is contained in:
Fabian Kammel 2022-11-15 15:40:49 +01:00 committed by GitHub
parent 80a801629e
commit bb76a4e4c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
42 changed files with 932 additions and 791 deletions

View file

@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
### Added ### Added
- Environment variable `CONSTELL_AZURE_CLIENT_SECRET_VALUE` as an alternative way to provide the configuration value `provider.azure.clientSecretValue`.
### Changed ### Changed
<!-- For changes in existing functionality. --> <!-- For changes in existing functionality. -->

View file

@ -22,10 +22,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/hashicorp/go-multierror"
"github.com/spf13/afero" "github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/multierr"
"google.golang.org/grpc/test/bufconn" "google.golang.org/grpc/test/bufconn"
testclock "k8s.io/utils/clock/testing" testclock "k8s.io/utils/clock/testing"
) )
@ -713,10 +713,10 @@ func (w *tarGzWriter) Bytes() []byte {
func (w *tarGzWriter) Close() (result error) { func (w *tarGzWriter) Close() (result error) {
if err := w.tarWriter.Close(); err != nil { if err := w.tarWriter.Close(); err != nil {
result = multierror.Append(result, err) result = multierr.Append(result, err)
} }
if err := w.gzWriter.Close(); err != nil { if err := w.gzWriter.Close(); err != nil {
result = multierror.Append(result, err) result = multierr.Append(result, err)
} }
return result return result
} }

View file

@ -7,13 +7,13 @@ SPDX-License-Identifier: AGPL-3.0-only
package cloudcmd package cloudcmd
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -58,7 +58,7 @@ func NewUpgrader(outWriter io.Writer) (*Upgrader, error) {
} }
// Upgrade upgrades the cluster to the given measurements and image. // Upgrade upgrades the cluster to the given measurements and image.
func (u *Upgrader) Upgrade(ctx context.Context, image string, measurements map[uint32][]byte) error { func (u *Upgrader) Upgrade(ctx context.Context, image string, measurements measurements.M) error {
if err := u.updateMeasurements(ctx, measurements); err != nil { if err := u.updateMeasurements(ctx, measurements); err != nil {
return fmt.Errorf("updating measurements: %w", err) return fmt.Errorf("updating measurements: %w", err)
} }
@ -97,36 +97,25 @@ func (u *Upgrader) GetCurrentImage(ctx context.Context) (*unstructured.Unstructu
return imageStruct, imageDefinition, nil return imageStruct, imageDefinition, nil
} }
func (u *Upgrader) updateMeasurements(ctx context.Context, measurements map[uint32][]byte) error { func (u *Upgrader) updateMeasurements(ctx context.Context, newMeasurements measurements.M) error {
existingConf, err := u.measurementsUpdater.getCurrent(ctx, constants.JoinConfigMap) existingConf, err := u.measurementsUpdater.getCurrent(ctx, constants.JoinConfigMap)
if err != nil { if err != nil {
return fmt.Errorf("retrieving current measurements: %w", err) return fmt.Errorf("retrieving current measurements: %w", err)
} }
var currentMeasurements map[uint32][]byte var currentMeasurements measurements.M
if err := json.Unmarshal([]byte(existingConf.Data[constants.MeasurementsFilename]), &currentMeasurements); err != nil { if err := json.Unmarshal([]byte(existingConf.Data[constants.MeasurementsFilename]), &currentMeasurements); err != nil {
return fmt.Errorf("retrieving current measurements: %w", err) return fmt.Errorf("retrieving current measurements: %w", err)
} }
if len(currentMeasurements) == len(measurements) { if currentMeasurements.EqualTo(newMeasurements) {
changed := false
for k, v := range currentMeasurements {
if !bytes.Equal(v, measurements[k]) {
// measurements have changed
changed = true
break
}
}
if !changed {
// measurements are the same, nothing to be done
fmt.Fprintln(u.outWriter, "Cluster is already using the chosen measurements, skipping measurements upgrade") fmt.Fprintln(u.outWriter, "Cluster is already using the chosen measurements, skipping measurements upgrade")
return nil return nil
} }
}
// backup of previous measurements // backup of previous measurements
existingConf.Data["oldMeasurements"] = existingConf.Data[constants.MeasurementsFilename] existingConf.Data["oldMeasurements"] = existingConf.Data[constants.MeasurementsFilename]
measurementsJSON, err := json.Marshal(measurements) measurementsJSON, err := json.Marshal(newMeasurements)
if err != nil { if err != nil {
return fmt.Errorf("marshaling measurements: %w", err) return fmt.Errorf("marshaling measurements: %w", err)
} }

View file

@ -13,6 +13,7 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -24,7 +25,7 @@ func TestUpdateMeasurements(t *testing.T) {
someErr := errors.New("error") someErr := errors.New("error")
testCases := map[string]struct { testCases := map[string]struct {
updater *stubMeasurementsUpdater updater *stubMeasurementsUpdater
newMeasurements map[uint32][]byte newMeasurements measurements.M
wantUpdate bool wantUpdate bool
wantErr bool wantErr bool
}{ }{
@ -36,7 +37,7 @@ func TestUpdateMeasurements(t *testing.T) {
}, },
}, },
}, },
newMeasurements: map[uint32][]byte{ newMeasurements: measurements.M{
0: []byte("1"), 0: []byte("1"),
}, },
wantUpdate: true, wantUpdate: true,
@ -49,7 +50,7 @@ func TestUpdateMeasurements(t *testing.T) {
}, },
}, },
}, },
newMeasurements: map[uint32][]byte{ newMeasurements: measurements.M{
0: []byte("1"), 0: []byte("1"),
}, },
}, },

View file

@ -18,6 +18,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp" "github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp"
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch" "github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch"
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp" "github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu" "github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
@ -28,7 +29,7 @@ import (
// Validator validates Platform Configuration Registers (PCRs). // Validator validates Platform Configuration Registers (PCRs).
type Validator struct { type Validator struct {
provider cloudprovider.Provider provider cloudprovider.Provider
pcrs map[uint32][]byte pcrs measurements.M
enforcedPCRs []uint32 enforcedPCRs []uint32
idkeydigest []byte idkeydigest []byte
enforceIDKeyDigest bool enforceIDKeyDigest bool
@ -147,7 +148,7 @@ func (v *Validator) V(cmd *cobra.Command) atls.Validator {
} }
// PCRS returns the validator's PCR map. // PCRS returns the validator's PCR map.
func (v *Validator) PCRS() map[uint32][]byte { func (v *Validator) PCRS() measurements.M {
return v.pcrs return v.pcrs
} }
@ -169,7 +170,7 @@ func (v *Validator) updateValidator(cmd *cobra.Command) {
} }
} }
func (v *Validator) checkPCRs(pcrs map[uint32][]byte, enforcedPCRs []uint32) error { func (v *Validator) checkPCRs(pcrs measurements.M, enforcedPCRs []uint32) error {
if len(pcrs) == 0 { if len(pcrs) == 0 {
return errors.New("no PCR values provided") return errors.New("no PCR values provided")
} }

View file

@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only
package cloudcmd package cloudcmd
import ( import (
"bytes"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"testing" "testing"
@ -15,6 +16,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp" "github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp"
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch" "github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch"
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp" "github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu" "github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
@ -24,21 +26,19 @@ import (
) )
func TestNewValidator(t *testing.T) { func TestNewValidator(t *testing.T) {
zero := []byte("00000000000000000000000000000000") testPCRs := measurements.M{
one := []byte("11111111111111111111111111111111") 0: measurements.PCRWithAllBytes(0x00),
testPCRs := map[uint32][]byte{ 1: measurements.PCRWithAllBytes(0xFF),
0: zero, 2: measurements.PCRWithAllBytes(0x00),
1: one, 3: measurements.PCRWithAllBytes(0xFF),
2: zero, 4: measurements.PCRWithAllBytes(0x00),
3: one, 5: measurements.PCRWithAllBytes(0x00),
4: zero,
5: zero,
} }
testCases := map[string]struct { testCases := map[string]struct {
provider cloudprovider.Provider provider cloudprovider.Provider
config *config.Config config *config.Config
pcrs map[uint32][]byte pcrs measurements.M
enforceIDKeyDigest bool enforceIDKeyDigest bool
idKeyDigest string idKeyDigest string
azureCVM bool azureCVM bool
@ -64,12 +64,14 @@ func TestNewValidator(t *testing.T) {
}, },
"no pcrs provided": { "no pcrs provided": {
provider: cloudprovider.Azure, provider: cloudprovider.Azure,
pcrs: map[uint32][]byte{}, pcrs: measurements.M{},
wantErr: true, wantErr: true,
}, },
"invalid pcr length": { "invalid pcr length": {
provider: cloudprovider.GCP, provider: cloudprovider.GCP,
pcrs: map[uint32][]byte{0: []byte("0000000000000000000000000000000")}, pcrs: measurements.M{
0: bytes.Repeat([]byte{0x00}, 31),
},
wantErr: true, wantErr: true,
}, },
"unknown provider": { "unknown provider": {
@ -99,16 +101,13 @@ func TestNewValidator(t *testing.T) {
conf := &config.Config{Provider: config.ProviderConfig{}} conf := &config.Config{Provider: config.ProviderConfig{}}
if tc.provider == cloudprovider.GCP { if tc.provider == cloudprovider.GCP {
measurements := config.Measurements(tc.pcrs) conf.Provider.GCP = &config.GCPConfig{Measurements: tc.pcrs}
conf.Provider.GCP = &config.GCPConfig{Measurements: measurements}
} }
if tc.provider == cloudprovider.Azure { if tc.provider == cloudprovider.Azure {
measurements := config.Measurements(tc.pcrs) conf.Provider.Azure = &config.AzureConfig{Measurements: tc.pcrs, EnforceIDKeyDigest: &tc.enforceIDKeyDigest, IDKeyDigest: tc.idKeyDigest, ConfidentialVM: &tc.azureCVM}
conf.Provider.Azure = &config.AzureConfig{Measurements: measurements, EnforceIDKeyDigest: &tc.enforceIDKeyDigest, IDKeyDigest: tc.idKeyDigest, ConfidentialVM: &tc.azureCVM}
} }
if tc.provider == cloudprovider.QEMU { if tc.provider == cloudprovider.QEMU {
measurements := config.Measurements(tc.pcrs) conf.Provider.QEMU = &config.QEMUConfig{Measurements: tc.pcrs}
conf.Provider.QEMU = &config.QEMUConfig{Measurements: measurements}
} }
validators, err := NewValidator(tc.provider, conf) validators, err := NewValidator(tc.provider, conf)
@ -125,29 +124,27 @@ func TestNewValidator(t *testing.T) {
} }
func TestValidatorV(t *testing.T) { func TestValidatorV(t *testing.T) {
zero := []byte("00000000000000000000000000000000") newTestPCRs := func() measurements.M {
return measurements.M{
newTestPCRs := func() map[uint32][]byte { 0: measurements.PCRWithAllBytes(0x00),
return map[uint32][]byte{ 1: measurements.PCRWithAllBytes(0x00),
0: zero, 2: measurements.PCRWithAllBytes(0x00),
1: zero, 3: measurements.PCRWithAllBytes(0x00),
2: zero, 4: measurements.PCRWithAllBytes(0x00),
3: zero, 5: measurements.PCRWithAllBytes(0x00),
4: zero, 6: measurements.PCRWithAllBytes(0x00),
5: zero, 7: measurements.PCRWithAllBytes(0x00),
6: zero, 8: measurements.PCRWithAllBytes(0x00),
7: zero, 9: measurements.PCRWithAllBytes(0x00),
8: zero, 10: measurements.PCRWithAllBytes(0x00),
9: zero, 11: measurements.PCRWithAllBytes(0x00),
10: zero, 12: measurements.PCRWithAllBytes(0x00),
11: zero,
12: zero,
} }
} }
testCases := map[string]struct { testCases := map[string]struct {
provider cloudprovider.Provider provider cloudprovider.Provider
pcrs map[uint32][]byte pcrs measurements.M
wantVs atls.Validator wantVs atls.Validator
azureCVM bool azureCVM bool
}{ }{
@ -224,7 +221,7 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
provider cloudprovider.Provider provider cloudprovider.Provider
pcrs map[uint32][]byte pcrs measurements.M
ownerID string ownerID string
clusterID string clusterID string
wantErr bool wantErr bool
@ -321,14 +318,14 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
} }
func TestUpdatePCR(t *testing.T) { func TestUpdatePCR(t *testing.T) {
emptyMap := map[uint32][]byte{} emptyMap := measurements.M{}
defaultMap := map[uint32][]byte{ defaultMap := measurements.M{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), 0: measurements.PCRWithAllBytes(0xAA),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"), 1: measurements.PCRWithAllBytes(0xBB),
} }
testCases := map[string]struct { testCases := map[string]struct {
pcrMap map[uint32][]byte pcrMap measurements.M
pcrIndex uint32 pcrIndex uint32
encoded string encoded string
wantEntries int wantEntries int
@ -389,7 +386,7 @@ func TestUpdatePCR(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
pcrs := make(map[uint32][]byte) pcrs := make(measurements.M)
for k, v := range tc.pcrMap { for k, v := range tc.pcrMap {
pcrs[k] = v pcrs[k] = v
} }

View file

@ -14,6 +14,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
@ -39,7 +40,7 @@ func newConfigFetchMeasurementsCmd() *cobra.Command {
type fetchMeasurementsFlags struct { type fetchMeasurementsFlags struct {
measurementsURL *url.URL measurementsURL *url.URL
signatureURL *url.URL signatureURL *url.URL
config string configPath string
} }
func runConfigFetchMeasurements(cmd *cobra.Command, args []string) error { func runConfigFetchMeasurements(cmd *cobra.Command, args []string) error {
@ -57,9 +58,9 @@ func configFetchMeasurements(cmd *cobra.Command, verifier rekorVerifier, fileHan
return err return err
} }
conf, err := config.FromFile(fileHandler, flags.config) conf, err := config.New(fileHandler, flags.configPath)
if err != nil { if err != nil {
return err return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
} }
if conf.IsDebugImage() { if conf.IsDebugImage() {
@ -72,7 +73,7 @@ func configFetchMeasurements(cmd *cobra.Command, verifier rekorVerifier, fileHan
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
var fetchedMeasurements config.Measurements var fetchedMeasurements measurements.M
hash, err := fetchedMeasurements.FetchAndVerify(ctx, client, flags.measurementsURL, flags.signatureURL, []byte(constants.CosignPublicKey)) hash, err := fetchedMeasurements.FetchAndVerify(ctx, client, flags.measurementsURL, flags.signatureURL, []byte(constants.CosignPublicKey))
if err != nil { if err != nil {
return err return err
@ -84,7 +85,7 @@ func configFetchMeasurements(cmd *cobra.Command, verifier rekorVerifier, fileHan
} }
conf.UpdateMeasurements(fetchedMeasurements) conf.UpdateMeasurements(fetchedMeasurements)
if err := fileHandler.WriteYAML(flags.config, conf, file.OptOverwrite); err != nil { if err := fileHandler.WriteYAML(flags.configPath, conf, file.OptOverwrite); err != nil {
return err return err
} }
@ -123,7 +124,7 @@ func parseFetchMeasurementsFlags(cmd *cobra.Command) (*fetchMeasurementsFlags, e
return &fetchMeasurementsFlags{ return &fetchMeasurementsFlags{
measurementsURL: measurementsURL, measurementsURL: measurementsURL,
signatureURL: measurementsSignatureURL, signatureURL: measurementsSignatureURL,
config: config, configPath: config,
}, nil }, nil
} }

View file

@ -40,7 +40,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
wantFlags: &fetchMeasurementsFlags{ wantFlags: &fetchMeasurementsFlags{
measurementsURL: nil, measurementsURL: nil,
signatureURL: nil, signatureURL: nil,
config: constants.ConfigFilename, configPath: constants.ConfigFilename,
}, },
}, },
"url": { "url": {
@ -49,7 +49,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
wantFlags: &fetchMeasurementsFlags{ wantFlags: &fetchMeasurementsFlags{
measurementsURL: urlMustParse("https://some.other.url/with/path"), measurementsURL: urlMustParse("https://some.other.url/with/path"),
signatureURL: urlMustParse("https://some.other.url/with/path.sig"), signatureURL: urlMustParse("https://some.other.url/with/path.sig"),
config: constants.ConfigFilename, configPath: constants.ConfigFilename,
}, },
}, },
"broken url": { "broken url": {
@ -59,7 +59,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
"config": { "config": {
configFlag: "someOtherConfig.yaml", configFlag: "someOtherConfig.yaml",
wantFlags: &fetchMeasurementsFlags{ wantFlags: &fetchMeasurementsFlags{
config: "someOtherConfig.yaml", configPath: "someOtherConfig.yaml",
}, },
}, },
} }
@ -212,8 +212,7 @@ func TestConfigFetchMeasurements(t *testing.T) {
cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually
fileHandler := file.NewHandler(afero.NewMemMapFs()) fileHandler := file.NewHandler(afero.NewMemMapFs())
gcpConfig := config.Default() gcpConfig := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP)
gcpConfig.RemoveProviderExcept(cloudprovider.GCP)
gcpConfig.Provider.GCP.Image = "projects/constellation-images/global/images/constellation-coreos-1658216163" gcpConfig.Provider.GCP.Image = "projects/constellation-images/global/images/constellation-coreos-1658216163"
err := fileHandler.WriteYAML(constants.ConfigFilename, gcpConfig, file.OptMkdirAll) err := fileHandler.WriteYAML(constants.ConfigFilename, gcpConfig, file.OptMkdirAll)

View file

@ -0,0 +1,28 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"errors"
"fmt"
"io"
"go.uber.org/multierr"
)
func displayConfigValidationErrors(errWriter io.Writer, configError error) error {
errs := multierr.Errors(configError)
if errs != nil {
fmt.Fprintln(errWriter, "Problems validating config file:")
for _, err := range errs {
fmt.Fprintln(errWriter, "\t"+err.Error())
}
fmt.Fprintln(errWriter, "Fix the invalid entries or generate a new configuration using `constellation config generate`")
return errors.New("invalid configuration")
}
return nil
}

View file

@ -13,6 +13,7 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero" "github.com/spf13/afero"
@ -59,27 +60,27 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler,
return err return err
} }
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath) conf, err := config.New(fileHandler, flags.configPath)
if err != nil { if err != nil {
return fmt.Errorf("reading and validating config: %w", err) return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
} }
var printedAWarning bool var printedAWarning bool
if config.IsDebugImage() { if conf.IsDebugImage() {
cmd.PrintErrln("Configured image doesn't look like a released production image. Double check image before deploying to production.") cmd.PrintErrln("Configured image doesn't look like a released production image. Double check image before deploying to production.")
printedAWarning = true printedAWarning = true
} }
if config.IsDebugCluster() { if conf.IsDebugCluster() {
cmd.PrintErrln("WARNING: Creating a debug cluster. This cluster is not secure and should only be used for debugging purposes.") cmd.PrintErrln("WARNING: Creating a debug cluster. This cluster is not secure and should only be used for debugging purposes.")
cmd.PrintErrln("DO NOT USE THIS CLUSTER IN PRODUCTION.") cmd.PrintErrln("DO NOT USE THIS CLUSTER IN PRODUCTION.")
printedAWarning = true printedAWarning = true
} }
if config.IsAzureNonCVM() { if conf.IsAzureNonCVM() {
cmd.PrintErrln("Disabling Confidential VMs is insecure. Use only for evaluation purposes.") cmd.PrintErrln("Disabling Confidential VMs is insecure. Use only for evaluation purposes.")
printedAWarning = true printedAWarning = true
if config.EnforcesIDKeyDigest() { if conf.EnforcesIDKeyDigest() {
cmd.PrintErrln("Your config asks for enforcing the idkeydigest. This is only available on Confidential VMs. It will not be enforced.") cmd.PrintErrln("Your config asks for enforcing the idkeydigest. This is only available on Confidential VMs. It will not be enforced.")
} }
} }
@ -89,20 +90,20 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler,
cmd.PrintErrln("") cmd.PrintErrln("")
} }
provider := config.GetProvider() provider := conf.GetProvider()
var instanceType string var instanceType string
switch provider { switch provider {
case cloudprovider.AWS: case cloudprovider.AWS:
instanceType = config.Provider.AWS.InstanceType instanceType = conf.Provider.AWS.InstanceType
if len(flags.name) > 10 { if len(flags.name) > 10 {
return fmt.Errorf("cluster name on AWS must not be longer than 10 characters") return fmt.Errorf("cluster name on AWS must not be longer than 10 characters")
} }
case cloudprovider.Azure: case cloudprovider.Azure:
instanceType = config.Provider.Azure.InstanceType instanceType = conf.Provider.Azure.InstanceType
case cloudprovider.GCP: case cloudprovider.GCP:
instanceType = config.Provider.GCP.InstanceType instanceType = conf.Provider.GCP.InstanceType
case cloudprovider.QEMU: case cloudprovider.QEMU:
cpus := config.Provider.QEMU.VCPUs cpus := conf.Provider.QEMU.VCPUs
instanceType = fmt.Sprintf("%d-vCPU", cpus) instanceType = fmt.Sprintf("%d-vCPU", cpus)
} }
@ -122,7 +123,7 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler,
} }
spinner.Start("Creating", false) spinner.Start("Creating", false)
idFile, err := creator.Create(cmd.Context(), provider, config, flags.name, instanceType, flags.controllerCount, flags.workerCount) idFile, err := creator.Create(cmd.Context(), provider, conf, flags.name, instanceType, flags.controllerCount, flags.workerCount)
spinner.Stop() spinner.Stop()
if err != nil { if err != nil {
return err return err

View file

@ -78,9 +78,9 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return err return err
} }
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath) conf, err := config.New(fileHandler, flags.configPath)
if err != nil { if err != nil {
return fmt.Errorf("reading and validating config: %w", err) return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
} }
var idFile clusterid.File var idFile clusterid.File
@ -88,7 +88,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return fmt.Errorf("reading cluster ID file: %w", err) return fmt.Errorf("reading cluster ID file: %w", err)
} }
k8sVersion, err := versions.NewValidK8sVersion(config.KubernetesVersion) k8sVersion, err := versions.NewValidK8sVersion(conf.KubernetesVersion)
if err != nil { if err != nil {
return fmt.Errorf("validating kubernetes version: %w", err) return fmt.Errorf("validating kubernetes version: %w", err)
} }
@ -96,18 +96,18 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
cmd.PrintErrf("Warning: Constellation with Kubernetes %v is still in preview. Use only for evaluation purposes.\n", k8sVersion) cmd.PrintErrf("Warning: Constellation with Kubernetes %v is still in preview. Use only for evaluation purposes.\n", k8sVersion)
} }
provider := config.GetProvider() provider := conf.GetProvider()
checker := license.NewChecker(quotaChecker, fileHandler) checker := license.NewChecker(quotaChecker, fileHandler)
if err := checker.CheckLicense(cmd.Context(), provider, config.Provider, cmd.Printf); err != nil { if err := checker.CheckLicense(cmd.Context(), provider, conf.Provider, cmd.Printf); err != nil {
cmd.PrintErrf("License check failed: %v", err) cmd.PrintErrf("License check failed: %v", err)
} }
validator, err := cloudcmd.NewValidator(provider, config) validator, err := cloudcmd.NewValidator(provider, conf)
if err != nil { if err != nil {
return err return err
} }
serviceAccURI, err := getMarshaledServiceAccountURI(provider, config, fileHandler) serviceAccURI, err := getMarshaledServiceAccountURI(provider, conf, fileHandler)
if err != nil { if err != nil {
return err return err
} }
@ -117,7 +117,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return fmt.Errorf("parsing or generating master secret from file %s: %w", flags.masterSecretPath, err) return fmt.Errorf("parsing or generating master secret from file %s: %w", flags.masterSecretPath, err)
} }
helmLoader := helm.New(provider, k8sVersion) helmLoader := helm.New(provider, k8sVersion)
helmDeployments, err := helmLoader.Load(provider, flags.conformance, masterSecret.Key, masterSecret.Salt, getEnforcedPCRs(provider, config), getEnforceIDKeyDigest(provider, config)) helmDeployments, err := helmLoader.Load(provider, flags.conformance, masterSecret.Key, masterSecret.Salt, getEnforcedPCRs(provider, conf), getEnforceIDKeyDigest(provider, conf))
if err != nil { if err != nil {
return fmt.Errorf("loading Helm charts: %w", err) return fmt.Errorf("loading Helm charts: %w", err)
} }
@ -131,10 +131,10 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
KeyEncryptionKeyId: "", KeyEncryptionKeyId: "",
UseExistingKek: false, UseExistingKek: false,
CloudServiceAccountUri: serviceAccURI, CloudServiceAccountUri: serviceAccURI,
KubernetesVersion: config.KubernetesVersion, KubernetesVersion: conf.KubernetesVersion,
HelmDeployments: helmDeployments, HelmDeployments: helmDeployments,
EnforcedPcrs: getEnforcedPCRs(provider, config), EnforcedPcrs: getEnforcedPCRs(provider, conf),
EnforceIdkeydigest: getEnforceIDKeyDigest(provider, config), EnforceIdkeydigest: getEnforceIDKeyDigest(provider, conf),
ConformanceMode: flags.conformance, ConformanceMode: flags.conformance,
} }
resp, err := initCall(cmd.Context(), newDialer(validator), idFile.IP, req) resp, err := initCall(cmd.Context(), newDialer(validator), idFile.IP, req)

View file

@ -21,6 +21,7 @@ import (
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto" "github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid" "github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared" "github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
@ -360,11 +361,11 @@ func TestAttestation(t *testing.T) {
issuer := &testIssuer{ issuer := &testIssuer{
Getter: oid.QEMU{}, Getter: oid.QEMU{},
pcrs: map[uint32][]byte{ pcrs: measurements.M{
0: []byte("ffffffffffffffffffffffffffffffff"), 0: measurements.PCRWithAllBytes(0xFF),
1: []byte("ffffffffffffffffffffffffffffffff"), 1: measurements.PCRWithAllBytes(0xFF),
2: []byte("ffffffffffffffffffffffffffffffff"), 2: measurements.PCRWithAllBytes(0xFF),
3: []byte("ffffffffffffffffffffffffffffffff"), 3: measurements.PCRWithAllBytes(0xFF),
}, },
} }
serverCreds := atlscredentials.New(issuer, nil) serverCreds := atlscredentials.New(issuer, nil)
@ -389,13 +390,13 @@ func TestAttestation(t *testing.T) {
cfg := config.Default() cfg := config.Default()
cfg.RemoveProviderExcept(cloudprovider.QEMU) cfg.RemoveProviderExcept(cloudprovider.QEMU)
cfg.Provider.QEMU.Image = "some/image/location" cfg.Provider.QEMU.Image = "some/image/location"
cfg.Provider.QEMU.Measurements[0] = []byte("00000000000000000000000000000000") cfg.Provider.QEMU.Measurements[0] = measurements.PCRWithAllBytes(0x00)
cfg.Provider.QEMU.Measurements[1] = []byte("11111111111111111111111111111111") cfg.Provider.QEMU.Measurements[1] = measurements.PCRWithAllBytes(0x11)
cfg.Provider.QEMU.Measurements[2] = []byte("22222222222222222222222222222222") cfg.Provider.QEMU.Measurements[2] = measurements.PCRWithAllBytes(0x22)
cfg.Provider.QEMU.Measurements[3] = []byte("33333333333333333333333333333333") cfg.Provider.QEMU.Measurements[3] = measurements.PCRWithAllBytes(0x33)
cfg.Provider.QEMU.Measurements[4] = []byte("44444444444444444444444444444444") cfg.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44)
cfg.Provider.QEMU.Measurements[8] = []byte("88888888888888888888888888888888") cfg.Provider.QEMU.Measurements[8] = measurements.PCRWithAllBytes(0x88)
cfg.Provider.QEMU.Measurements[9] = []byte("99999999999999999999999999999999") cfg.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x99)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone)) require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone))
ctx := context.Background() ctx := context.Background()
@ -411,13 +412,13 @@ func TestAttestation(t *testing.T) {
type testValidator struct { type testValidator struct {
oid.Getter oid.Getter
pcrs map[uint32][]byte pcrs measurements.M
} }
func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) { func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
var attestation struct { var attestation struct {
UserData []byte UserData []byte
PCRs map[uint32][]byte PCRs measurements.M
} }
if err := json.Unmarshal(attDoc, &attestation); err != nil { if err := json.Unmarshal(attDoc, &attestation); err != nil {
return nil, err return nil, err
@ -433,14 +434,14 @@ func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
type testIssuer struct { type testIssuer struct {
oid.Getter oid.Getter
pcrs map[uint32][]byte pcrs measurements.M
} }
func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) { func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal( return json.Marshal(
struct { struct {
UserData []byte UserData []byte
PCRs map[uint32][]byte PCRs measurements.M
}{ }{
UserData: userData, UserData: userData,
PCRs: i.pcrs, PCRs: i.pcrs,
@ -472,23 +473,23 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
conf.Provider.Azure.ResourceGroup = "test-resource-group" conf.Provider.Azure.ResourceGroup = "test-resource-group"
conf.Provider.Azure.AppClientID = "01234567-0123-0123-0123-0123456789ab" conf.Provider.Azure.AppClientID = "01234567-0123-0123-0123-0123456789ab"
conf.Provider.Azure.ClientSecretValue = "test-client-secret" conf.Provider.Azure.ClientSecretValue = "test-client-secret"
conf.Provider.Azure.Measurements[4] = []byte("44444444444444444444444444444444") conf.Provider.Azure.Measurements[4] = measurements.PCRWithAllBytes(0x44)
conf.Provider.Azure.Measurements[8] = []byte("00000000000000000000000000000000") conf.Provider.Azure.Measurements[8] = measurements.PCRWithAllBytes(0x00)
conf.Provider.Azure.Measurements[9] = []byte("11111111111111111111111111111111") conf.Provider.Azure.Measurements[9] = measurements.PCRWithAllBytes(0x11)
case cloudprovider.GCP: case cloudprovider.GCP:
conf.Provider.GCP.Region = "test-region" conf.Provider.GCP.Region = "test-region"
conf.Provider.GCP.Project = "test-project" conf.Provider.GCP.Project = "test-project"
conf.Provider.GCP.Image = "some/image/location" conf.Provider.GCP.Image = "some/image/location"
conf.Provider.GCP.Zone = "test-zone" conf.Provider.GCP.Zone = "test-zone"
conf.Provider.GCP.ServiceAccountKeyPath = "test-key-path" conf.Provider.GCP.ServiceAccountKeyPath = "test-key-path"
conf.Provider.GCP.Measurements[4] = []byte("44444444444444444444444444444444") conf.Provider.GCP.Measurements[4] = measurements.PCRWithAllBytes(0x44)
conf.Provider.GCP.Measurements[8] = []byte("00000000000000000000000000000000") conf.Provider.GCP.Measurements[8] = measurements.PCRWithAllBytes(0x00)
conf.Provider.GCP.Measurements[9] = []byte("11111111111111111111111111111111") conf.Provider.GCP.Measurements[9] = measurements.PCRWithAllBytes(0x11)
case cloudprovider.QEMU: case cloudprovider.QEMU:
conf.Provider.QEMU.Image = "some/image/location" conf.Provider.QEMU.Image = "some/image/location"
conf.Provider.QEMU.Measurements[4] = []byte("44444444444444444444444444444444") conf.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44)
conf.Provider.QEMU.Measurements[8] = []byte("00000000000000000000000000000000") conf.Provider.QEMU.Measurements[8] = measurements.PCRWithAllBytes(0x00)
conf.Provider.QEMU.Measurements[9] = []byte("11111111111111111111111111111111") conf.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x11)
} }
conf.RemoveProviderExcept(csp) conf.RemoveProviderExcept(csp)

View file

@ -163,14 +163,14 @@ func prepareConfig(cmd *cobra.Command, fileHandler file.Handler) (*config.Config
// check for existing config // check for existing config
if configPath != "" { if configPath != "" {
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, configPath) conf, err := config.New(fileHandler, configPath)
if err != nil { if err != nil {
return nil, err return nil, displayConfigValidationErrors(cmd.ErrOrStderr(), err)
} }
if config.GetProvider() != cloudprovider.QEMU { if conf.GetProvider() != cloudprovider.QEMU {
return nil, errors.New("invalid provider for MiniConstellation cluster") return nil, errors.New("invalid provider for MiniConstellation cluster")
} }
return config, nil return conf, nil
} }
if err := cmd.Flags().Set("config", constants.ConfigFilename); err != nil { if err := cmd.Flags().Set("config", constants.ConfigFilename); err != nil {
return nil, err return nil, err

View file

@ -1,46 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"errors"
"fmt"
"io"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file"
)
func readConfig(errWriter io.Writer, fileHandler file.Handler, name string) (*config.Config, error) {
cnf, err := config.FromFile(fileHandler, name)
if err != nil {
return nil, err
}
if err := validateConfig(errWriter, cnf); err != nil {
return nil, err
}
return cnf, nil
}
func validateConfig(errWriter io.Writer, cnf *config.Config) error {
msgs, err := cnf.Validate()
if err != nil {
return fmt.Errorf("performing config validation: %w", err)
}
if len(msgs) > 0 {
fmt.Fprintln(errWriter, "Invalid fields in config file:")
for _, m := range msgs {
fmt.Fprintln(errWriter, "\t"+m)
}
fmt.Fprintln(errWriter, "Fix the invalid entries or generate a new configuration using `constellation config generate`")
return errors.New("invalid configuration")
}
return nil
}

View file

@ -1,126 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"bytes"
"testing"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateConfig(t *testing.T) {
testCases := map[string]struct {
cnf *config.Config
provider cloudprovider.Provider
wantOutput bool
wantErr bool
}{
"default config is not valid": {
cnf: config.Default(),
wantOutput: true,
wantErr: true,
},
"default Azure config is not valid": {
cnf: func() *config.Config {
cnf := config.Default()
az := cnf.Provider.Azure
cnf.Provider = config.ProviderConfig{}
cnf.Provider.Azure = az
return cnf
}(),
provider: cloudprovider.Azure,
wantOutput: true,
wantErr: true,
},
"Azure config with all required fields is valid": {
cnf: defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure),
provider: cloudprovider.Azure,
},
"default GCP config is not valid": {
cnf: func() *config.Config {
cnf := config.Default()
gcp := cnf.Provider.GCP
cnf.Provider = config.ProviderConfig{}
cnf.Provider.GCP = gcp
return cnf
}(),
provider: cloudprovider.GCP,
wantOutput: true,
wantErr: true,
},
"GCP config with all required fields is valid": {
cnf: defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP),
provider: cloudprovider.GCP,
},
"default QEMU config is not valid": {
cnf: func() *config.Config {
cnf := config.Default()
qemu := cnf.Provider.QEMU
cnf.Provider = config.ProviderConfig{}
cnf.Provider.QEMU = qemu
return cnf
}(),
provider: cloudprovider.QEMU,
wantOutput: true,
wantErr: true,
},
"QEMU config with all required fields is valid": {
cnf: defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.QEMU),
provider: cloudprovider.QEMU,
},
"config with an error": {
cnf: func() *config.Config {
cnf := config.Default()
cnf.Version = "v0"
return cnf
}(),
wantOutput: true,
wantErr: true,
},
"config without provider is not ok": {
cnf: func() *config.Config {
cnf := config.Default()
cnf.Provider = config.ProviderConfig{}
return cnf
}(),
wantErr: true,
},
"config without required provider": {
cnf: func() *config.Config {
cnf := config.Default()
cnf.Provider.Azure = nil
return cnf
}(),
provider: cloudprovider.Azure,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
out := &bytes.Buffer{}
err := validateConfig(out, tc.cnf)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantOutput, out.Len() > 0)
})
}
}

View file

@ -19,6 +19,7 @@ import (
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto" "github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/v2/internal/attestation" "github.com/edgelesssys/constellation/v2/internal/attestation"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto" "github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
@ -66,16 +67,16 @@ func recover(
return err return err
} }
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath) conf, err := config.New(fileHandler, flags.configPath)
if err != nil { if err != nil {
return fmt.Errorf("reading and validating config: %w", err) return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
} }
provider := config.GetProvider() provider := conf.GetProvider()
if provider == cloudprovider.Azure { if provider == cloudprovider.Azure {
interval = 20 * time.Second // Azure LB takes a while to remove unhealthy instances interval = 20 * time.Second // Azure LB takes a while to remove unhealthy instances
} }
validator, err := cloudcmd.NewValidator(provider, config) validator, err := cloudcmd.NewValidator(provider, conf)
if err != nil { if err != nil {
return err return err
} }

View file

@ -10,6 +10,7 @@ import (
"context" "context"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero" "github.com/spf13/afero"
@ -43,17 +44,17 @@ func upgradeExecute(cmd *cobra.Command, upgrader cloudUpgrader, fileHandler file
if err != nil { if err != nil {
return err return err
} }
config, err := config.FromFile(fileHandler, configPath) conf, err := config.New(fileHandler, configPath)
if err != nil { if err != nil {
return err return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
} }
// TODO: validate upgrade config? Should be basic things like checking image is not an empty string // TODO: validate upgrade config? Should be basic things like checking image is not an empty string
// More sophisticated validation, like making sure we don't downgrade the cluster, should be done by `constellation upgrade plan` // More sophisticated validation, like making sure we don't downgrade the cluster, should be done by `constellation upgrade plan`
return upgrader.Upgrade(cmd.Context(), config.Upgrade.Image, config.Upgrade.Measurements) return upgrader.Upgrade(cmd.Context(), conf.Upgrade.Image, conf.Upgrade.Measurements)
} }
type cloudUpgrader interface { type cloudUpgrader interface {
Upgrade(ctx context.Context, image string, measurements map[uint32][]byte) error Upgrade(ctx context.Context, image string, measurements measurements.M) error
} }

View file

@ -11,6 +11,8 @@ import (
"errors" "errors"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
@ -41,7 +43,8 @@ func TestUpgradeExecute(t *testing.T) {
cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually
handler := file.NewHandler(afero.NewMemMapFs()) handler := file.NewHandler(afero.NewMemMapFs())
require.NoError(handler.WriteYAML(constants.ConfigFilename, config.Default())) cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure)
require.NoError(handler.WriteYAML(constants.ConfigFilename, cfg))
err := upgradeExecute(cmd, tc.upgrader, handler) err := upgradeExecute(cmd, tc.upgrader, handler)
if tc.wantErr { if tc.wantErr {
@ -57,6 +60,6 @@ type stubUpgrader struct {
err error err error
} }
func (u stubUpgrader) Upgrade(context.Context, string, map[uint32][]byte) error { func (u stubUpgrader) Upgrade(context.Context, string, measurements.M) error {
return u.err return u.err
} }

View file

@ -73,13 +73,13 @@ func runUpgradePlan(cmd *cobra.Command, args []string) error {
func upgradePlan(cmd *cobra.Command, planner upgradePlanner, func upgradePlan(cmd *cobra.Command, planner upgradePlanner,
fileHandler file.Handler, client *http.Client, rekor rekorVerifier, flags upgradePlanFlags, fileHandler file.Handler, client *http.Client, rekor rekorVerifier, flags upgradePlanFlags,
) error { ) error {
config, err := config.FromFile(fileHandler, flags.configPath) conf, err := config.New(fileHandler, flags.configPath)
if err != nil { if err != nil {
return err return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
} }
// get current image version of the cluster // get current image version of the cluster
csp := config.GetProvider() csp := conf.GetProvider()
version, err := getCurrentImageVersion(cmd.Context(), planner, csp) version, err := getCurrentImageVersion(cmd.Context(), planner, csp)
if err != nil { if err != nil {
@ -108,7 +108,7 @@ func upgradePlan(cmd *cobra.Command, planner upgradePlanner,
return upgradePlanInteractive( return upgradePlanInteractive(
&nopWriteCloser{cmd.OutOrStdout()}, &nopWriteCloser{cmd.OutOrStdout()},
io.NopCloser(cmd.InOrStdin()), io.NopCloser(cmd.InOrStdin()),
flags.configPath, config, fileHandler, flags.configPath, conf, fileHandler,
compatibleImages, compatibleImages,
) )
} }

View file

@ -451,8 +451,8 @@ func TestUpgradePlan(t *testing.T) {
require := require.New(t) require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs()) fileHandler := file.NewHandler(afero.NewMemMapFs())
cfg := config.Default() cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.csp)
cfg.RemoveProviderExcept(tc.csp)
require.NoError(fileHandler.WriteYAML(tc.flags.configPath, cfg)) require.NoError(fileHandler.WriteYAML(tc.flags.configPath, cfg))
cmd := newUpgradePlanCmd() cmd := newUpgradePlanCmd()

View file

@ -19,6 +19,7 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid" "github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
"github.com/edgelesssys/constellation/v2/internal/atls" "github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto" "github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
@ -59,13 +60,13 @@ func verify(cmd *cobra.Command, fileHandler file.Handler, verifyClient verifyCli
return err return err
} }
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath) conf, err := config.New(fileHandler, flags.configPath)
if err != nil { if err != nil {
return fmt.Errorf("reading and validating config: %w", err) return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
} }
provider := config.GetProvider() provider := conf.GetProvider()
validators, err := cloudcmd.NewValidator(provider, config) validators, err := cloudcmd.NewValidator(provider, conf)
if err != nil { if err != nil {
return err return err
} }

View file

@ -67,6 +67,12 @@ If you don't have a cloud subscription, check out [MiniConstellation](first-step
Fill the values produced by the script into your configuration file. Fill the values produced by the script into your configuration file.
:::tip
Alternatively, you can leave `clientSecretValue` empty and provide the secret via the `CONSTELL_AZURE_CLIENT_SECRET_VALUE` environment variable.
:::
By default, Constellation uses `Standard_DC4as_v5` CVMs (4 vCPUs, 16 GB RAM) to create your cluster. Optionally, you can switch to a different VM type by modifying **instanceType** in the configuration file. For CVMs, any VM type with a minimum of 4 vCPUs from the [DCasv5 & DCadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/dcasv5-dcadsv5-series) or [ECasv5 & ECadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/ecasv5-ecadsv5-series) families is supported. By default, Constellation uses `Standard_DC4as_v5` CVMs (4 vCPUs, 16 GB RAM) to create your cluster. Optionally, you can switch to a different VM type by modifying **instanceType** in the configuration file. For CVMs, any VM type with a minimum of 4 vCPUs from the [DCasv5 & DCadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/dcasv5-dcadsv5-series) or [ECasv5 & ECadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/ecasv5-ecadsv5-series) families is supported.
Run `constellation config instance-types` to get the list of all supported options. Run `constellation config instance-types` to get the list of all supported options.
@ -112,6 +118,12 @@ If you don't have a cloud subscription, check out [MiniConstellation](first-step
Set the configuration value to the secret value. Set the configuration value to the secret value.
:::tip
Alternatively, you can leave `clientSecretValue` empty and provide the secret via the `CONSTELL_AZURE_CLIENT_SECRET_VALUE` environment variable.
:::
* **instanceType**: The VM type you want to use for your Constellation nodes. * **instanceType**: The VM type you want to use for your Constellation nodes.
For CVMs, any type with a minimum of 4 vCPUs from the [DCasv5 & DCadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/dcasv5-dcadsv5-series) or [ECasv5 & ECadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/ecasv5-ecadsv5-series) families is supported. It defaults to `Standard_DC4as_v5` (4 vCPUs, 16 GB RAM). For CVMs, any type with a minimum of 4 vCPUs from the [DCasv5 & DCadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/dcasv5-dcadsv5-series) or [ECasv5 & ECadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/ecasv5-ecadsv5-series) families is supported. It defaults to `Standard_DC4as_v5` (4 vCPUs, 16 GB RAM).

2
go.mod
View file

@ -66,7 +66,6 @@ require (
github.com/google/tink/go v1.7.0 github.com/google/tink/go v1.7.0
github.com/googleapis/gax-go/v2 v2.7.0 github.com/googleapis/gax-go/v2 v2.7.0
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-version v1.6.0 github.com/hashicorp/go-version v1.6.0
github.com/hashicorp/hc-install v0.4.0 github.com/hashicorp/hc-install v0.4.0
github.com/hashicorp/terraform-exec v0.17.3 github.com/hashicorp/terraform-exec v0.17.3
@ -117,6 +116,7 @@ require (
github.com/golang-jwt/jwt/v4 v4.4.2 // indirect github.com/golang-jwt/jwt/v4 v4.4.2 // indirect
github.com/google/logger v1.1.1 // indirect github.com/google/logger v1.1.1 // indirect
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hashicorp/go-retryablehttp v0.7.1 // indirect github.com/hashicorp/go-retryablehttp v0.7.1 // indirect
github.com/rogpeppe/go-internal v1.8.1 // indirect github.com/rogpeppe/go-internal v1.8.1 // indirect
golang.org/x/text v0.4.0 // indirect golang.org/x/text v0.4.0 // indirect

View file

@ -16,7 +16,7 @@ import (
"syscall" "syscall"
"github.com/edgelesssys/constellation/v2/hack/image-measurement/server" "github.com/edgelesssys/constellation/v2/hack/image-measurement/server"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"go.uber.org/multierr" "go.uber.org/multierr"
"go.uber.org/zap" "go.uber.org/zap"
@ -282,7 +282,7 @@ func (l *libvirtInstance) deleteLibvirtInstance() error {
return err return err
} }
func (l *libvirtInstance) obtainMeasurements() (measurements config.Measurements, err error) { func (l *libvirtInstance) obtainMeasurements() (measurements measurements.M, err error) {
// sanity check // sanity check
if err := l.deleteLibvirtInstance(); err != nil { if err := l.deleteLibvirtInstance(); err != nil {
return nil, err return nil, err

View file

@ -12,6 +12,7 @@ import (
"net" "net"
"net/http" "net/http"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -20,7 +21,7 @@ import (
type Server struct { type Server struct {
log *logger.Logger log *logger.Logger
server http.Server server http.Server
measurements map[uint32][]byte measurements measurements.M
done chan<- struct{} done chan<- struct{}
} }
@ -72,7 +73,7 @@ func (s *Server) logPCRs(w http.ResponseWriter, r *http.Request) {
} }
// unmarshal the request body into a map of PCRs // unmarshal the request body into a map of PCRs
var pcrs map[uint32][]byte var pcrs measurements.M
if err := json.NewDecoder(r.Body).Decode(&pcrs); err != nil { if err := json.NewDecoder(r.Body).Decode(&pcrs); err != nil {
log.With(zap.Error(err)).Errorf("Failed to read request body") log.With(zap.Error(err)).Errorf("Failed to read request body")
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@ -89,6 +90,6 @@ func (s *Server) logPCRs(w http.ResponseWriter, r *http.Request) {
} }
// GetMeasurements returns the static measurements for QEMU environment. // GetMeasurements returns the static measurements for QEMU environment.
func (s *Server) GetMeasurements() map[uint32][]byte { func (s *Server) GetMeasurements() measurements.M {
return s.measurements return s.measurements
} }

View file

@ -8,7 +8,6 @@ package main
import ( import (
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"flag" "flag"
@ -20,6 +19,7 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto" "github.com/edgelesssys/constellation/v2/internal/crypto"
@ -68,23 +68,6 @@ func main() {
} }
} }
// Measurements contains all PCR values.
type Measurements map[uint32][]byte
var _ yaml.Marshaler = Measurements{}
// MarshalYAML forces that measurements are written as base64. Default would
// be to print list of bytes.
func (m Measurements) MarshalYAML() (any, error) {
base64Map := make(map[uint32]string)
for key, value := range m {
base64Map[key] = base64.StdEncoding.EncodeToString(value[:])
}
return base64Map, nil
}
// getAttestation connects to the Constellation verification service and returns its attestation document. // getAttestation connects to the Constellation verification service and returns its attestation document.
func getAttestation(ctx context.Context, addr string) ([]byte, error) { func getAttestation(ctx context.Context, addr string) ([]byte, error) {
conn, err := grpc.DialContext( conn, err := grpc.DialContext(
@ -109,7 +92,7 @@ func getAttestation(ctx context.Context, addr string) ([]byte, error) {
} }
// validatePCRAttDoc parses and validates PCRs of an attestation document. // validatePCRAttDoc parses and validates PCRs of an attestation document.
func validatePCRAttDoc(attDocRaw []byte) (map[uint32][]byte, error) { func validatePCRAttDoc(attDocRaw []byte) (measurements.M, error) {
attDoc := vtpm.AttestationDocument{} attDoc := vtpm.AttestationDocument{}
if err := json.Unmarshal(attDocRaw, &attDoc); err != nil { if err := json.Unmarshal(attDocRaw, &attDoc); err != nil {
return nil, err return nil, err
@ -131,7 +114,7 @@ func validatePCRAttDoc(attDocRaw []byte) (map[uint32][]byte, error) {
// printPCRs formates and prints PCRs to the given writer. // printPCRs formates and prints PCRs to the given writer.
// format can be one of 'json' or 'yaml'. If it doesnt match defaults to 'json'. // format can be one of 'json' or 'yaml'. If it doesnt match defaults to 'json'.
func printPCRs(w io.Writer, pcrs map[uint32][]byte, format string) error { func printPCRs(w io.Writer, pcrs measurements.M, format string) error {
switch format { switch format {
case "json": case "json":
return printPCRsJSON(w, pcrs) return printPCRsJSON(w, pcrs)
@ -142,7 +125,7 @@ func printPCRs(w io.Writer, pcrs map[uint32][]byte, format string) error {
} }
} }
func printPCRsYAML(w io.Writer, pcrs Measurements) error { func printPCRsYAML(w io.Writer, pcrs measurements.M) error {
pcrYAML, err := yaml.Marshal(pcrs) pcrYAML, err := yaml.Marshal(pcrs)
if err != nil { if err != nil {
return err return err
@ -151,7 +134,7 @@ func printPCRsYAML(w io.Writer, pcrs Measurements) error {
return nil return nil
} }
func printPCRsJSON(w io.Writer, pcrs map[uint32][]byte) error { func printPCRsJSON(w io.Writer, pcrs measurements.M) error {
pcrJSON, err := json.MarshalIndent(pcrs, "", " ") pcrJSON, err := json.MarshalIndent(pcrs, "", " ")
if err != nil { if err != nil {
return err return err
@ -162,7 +145,7 @@ func printPCRsJSON(w io.Writer, pcrs map[uint32][]byte) error {
// exportToFile writes pcrs to a file, formatted to be valid Go code. // exportToFile writes pcrs to a file, formatted to be valid Go code.
// Validity of the PCR map is not checked, and should be handled by the caller. // Validity of the PCR map is not checked, and should be handled by the caller.
func exportToFile(path string, pcrs map[uint32][]byte, fs *afero.Afero) error { func exportToFile(path string, pcrs measurements.M, fs *afero.Afero) error {
goCode := `package pcrs goCode := `package pcrs
var pcrs = map[uint32][]byte{%s var pcrs = map[uint32][]byte{%s

View file

@ -13,6 +13,7 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/google/go-tpm-tools/proto/attest" "github.com/google/go-tpm-tools/proto/attest"
"github.com/google/go-tpm-tools/proto/tpm" "github.com/google/go-tpm-tools/proto/tpm"
@ -31,12 +32,12 @@ func TestMain(m *testing.M) {
func TestExportToFile(t *testing.T) { func TestExportToFile(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
pcrs map[uint32][]byte pcrs measurements.M
fs *afero.Afero fs *afero.Afero
wantErr bool wantErr bool
}{ }{
"file not writeable": { "file not writeable": {
pcrs: map[uint32][]byte{ pcrs: measurements.M{
0: {0x1, 0x2, 0x3}, 0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3}, 1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3}, 2: {0x1, 0x2, 0x3},
@ -45,7 +46,7 @@ func TestExportToFile(t *testing.T) {
wantErr: true, wantErr: true,
}, },
"file writeable": { "file writeable": {
pcrs: map[uint32][]byte{ pcrs: measurements.M{
0: {0x1, 0x2, 0x3}, 0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3}, 1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3}, 2: {0x1, 0x2, 0x3},
@ -105,7 +106,7 @@ func TestValidatePCRAttDoc(t *testing.T) {
{ {
Pcrs: &tpm.PCRs{ Pcrs: &tpm.PCRs{
Hash: tpm.HashAlgo_SHA256, Hash: tpm.HashAlgo_SHA256,
Pcrs: map[uint32][]byte{ Pcrs: measurements.M{
0: {0x1, 0x2, 0x3}, 0: {0x1, 0x2, 0x3},
}, },
}, },
@ -122,8 +123,8 @@ func TestValidatePCRAttDoc(t *testing.T) {
{ {
Pcrs: &tpm.PCRs{ Pcrs: &tpm.PCRs{
Hash: tpm.HashAlgo_SHA256, Hash: tpm.HashAlgo_SHA256,
Pcrs: map[uint32][]byte{ Pcrs: measurements.M{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"), 0: measurements.PCRWithAllBytes(0xAA),
}, },
}, },
}, },
@ -163,11 +164,11 @@ func mustMarshalAttDoc(t *testing.T, attDoc vtpm.AttestationDocument) []byte {
func TestPrintPCRs(t *testing.T) { func TestPrintPCRs(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
pcrs map[uint32][]byte pcrs measurements.M
format string format string
}{ }{
"json": { "json": {
pcrs: map[uint32][]byte{ pcrs: measurements.M{
0: {0x1, 0x2, 0x3}, 0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3}, 1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3}, 2: {0x1, 0x2, 0x3},
@ -175,7 +176,7 @@ func TestPrintPCRs(t *testing.T) {
format: "json", format: "json",
}, },
"empty format": { "empty format": {
pcrs: map[uint32][]byte{ pcrs: measurements.M{
0: {0x1, 0x2, 0x3}, 0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3}, 1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3}, 2: {0x1, 0x2, 0x3},
@ -183,7 +184,7 @@ func TestPrintPCRs(t *testing.T) {
format: "", format: "",
}, },
"yaml": { "yaml": {
pcrs: map[uint32][]byte{ pcrs: measurements.M{
0: {0x1, 0x2, 0x3}, 0: {0x1, 0x2, 0x3},
1: {0x1, 0x2, 0x3}, 1: {0x1, 0x2, 0x3},
2: {0x1, 0x2, 0x3}, 2: {0x1, 0x2, 0x3},

View file

@ -15,6 +15,7 @@ import (
"strings" "strings"
"github.com/edgelesssys/constellation/v2/hack/qemu-metadata-api/virtwrapper" "github.com/edgelesssys/constellation/v2/hack/qemu-metadata-api/virtwrapper"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata" "github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/role" "github.com/edgelesssys/constellation/v2/internal/role"
@ -197,7 +198,7 @@ func (s *Server) exportPCRs(w http.ResponseWriter, r *http.Request) {
} }
// unmarshal the request body into a map of PCRs // unmarshal the request body into a map of PCRs
var pcrs map[uint32][]byte var pcrs measurements.M
if err := json.NewDecoder(r.Body).Decode(&pcrs); err != nil { if err := json.NewDecoder(r.Body).Decode(&pcrs); err != nil {
log.With(zap.Error(err)).Errorf("Failed to read request body") log.With(zap.Error(err)).Errorf("Failed to read request body")
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)

View file

@ -17,6 +17,7 @@ import (
"testing" "testing"
"github.com/edgelesssys/constellation/v2/hack/qemu-metadata-api/virtwrapper" "github.com/edgelesssys/constellation/v2/hack/qemu-metadata-api/virtwrapper"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata" "github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -306,12 +307,12 @@ func TestExportPCRs(t *testing.T) {
remoteAddr: "192.0.100.1:1234", remoteAddr: "192.0.100.1:1234",
connect: defaultConnect, connect: defaultConnect,
method: http.MethodPost, method: http.MethodPost,
message: mustMarshal(t, map[uint32][]byte{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
}, },
"incorrect method": { "incorrect method": {
remoteAddr: "192.0.100.1:1234", remoteAddr: "192.0.100.1:1234",
connect: defaultConnect, connect: defaultConnect,
message: mustMarshal(t, map[uint32][]byte{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
method: http.MethodGet, method: http.MethodGet,
wantErr: true, wantErr: true,
}, },
@ -320,7 +321,7 @@ func TestExportPCRs(t *testing.T) {
connect: &stubConnect{ connect: &stubConnect{
getNetworkErr: errors.New("error"), getNetworkErr: errors.New("error"),
}, },
message: mustMarshal(t, map[uint32][]byte{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
method: http.MethodPost, method: http.MethodPost,
wantErr: true, wantErr: true,
}, },
@ -335,7 +336,7 @@ func TestExportPCRs(t *testing.T) {
remoteAddr: "localhost", remoteAddr: "localhost",
connect: defaultConnect, connect: defaultConnect,
method: http.MethodPost, method: http.MethodPost,
message: mustMarshal(t, map[uint32][]byte{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}), message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
wantErr: true, wantErr: true,
}, },
} }

View file

@ -12,9 +12,10 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/aws/aws-sdk-go-v2/config" awsConfig "github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/oid" "github.com/edgelesssys/constellation/v2/internal/oid"
"github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2"
@ -28,7 +29,7 @@ type Validator struct {
} }
// NewValidator create a new Validator structure and returns it. // NewValidator create a new Validator structure and returns it.
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
v := &Validator{} v := &Validator{}
v.Validator = vtpm.NewValidator( v.Validator = vtpm.NewValidator(
pcrs, pcrs,
@ -88,7 +89,7 @@ func (v *Validator) tpmEnabled(attestation vtpm.AttestationDocument) error {
} }
func getEC2Client(ctx context.Context, region string) (awsMetadataAPI, error) { func getEC2Client(ctx context.Context, region string) (awsMetadataAPI, error) {
client, err := config.LoadDefaultConfig(ctx, config.WithRegion(region)) client, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithRegion(region))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -19,6 +19,7 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
internalCrypto "github.com/edgelesssys/constellation/v2/internal/crypto" internalCrypto "github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/oid" "github.com/edgelesssys/constellation/v2/internal/oid"
@ -41,7 +42,7 @@ type Validator struct {
} }
// NewValidator initializes a new Azure validator with the provided PCR values. // NewValidator initializes a new Azure validator with the provided PCR values.
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, idKeyDigest []byte, enforceIDKeyDigest bool, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, idKeyDigest []byte, enforceIDKeyDigest bool, log vtpm.AttestationLogger) *Validator {
return &Validator{ return &Validator{
Validator: vtpm.NewValidator( Validator: vtpm.NewValidator(
pcrs, pcrs,

View file

@ -17,6 +17,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/simulator" "github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
"github.com/edgelesssys/constellation/v2/internal/crypto" "github.com/edgelesssys/constellation/v2/internal/crypto"
tpmclient "github.com/google/go-tpm-tools/client" tpmclient "github.com/google/go-tpm-tools/client"
@ -188,7 +189,7 @@ func TestGetAttestationCert(t *testing.T) {
} }
require.NoError(err) require.NoError(err)
validator := NewValidator(map[uint32][]byte{}, []uint32{}, nil) validator := NewValidator(measurements.M{}, []uint32{}, nil)
cert, err := x509.ParseCertificate(rootCert.Raw) cert, err := x509.ParseCertificate(rootCert.Raw)
require.NoError(err) require.NoError(err)
roots := x509.NewCertPool() roots := x509.NewCertPool()

View file

@ -14,6 +14,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
certutil "github.com/edgelesssys/constellation/v2/internal/crypto" certutil "github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/oid" "github.com/edgelesssys/constellation/v2/internal/oid"
@ -32,7 +33,7 @@ type Validator struct {
} }
// NewValidator initializes a new Azure validator with the provided PCR values. // NewValidator initializes a new Azure validator with the provided PCR values.
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
rootPool := x509.NewCertPool() rootPool := x509.NewCertPool()
rootPool.AddCert(ameRoot) rootPool.AddCert(ameRoot)
v := &Validator{roots: rootPool} v := &Validator{roots: rootPool}

View file

@ -18,6 +18,7 @@ import (
"time" "time"
compute "cloud.google.com/go/compute/apiv1" compute "cloud.google.com/go/compute/apiv1"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/oid" "github.com/edgelesssys/constellation/v2/internal/oid"
"github.com/google/go-tpm-tools/proto/attest" "github.com/google/go-tpm-tools/proto/attest"
@ -34,7 +35,7 @@ type Validator struct {
} }
// NewValidator initializes a new GCP validator with the provided PCR values. // NewValidator initializes a new GCP validator with the provided PCR values.
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
return &Validator{ return &Validator{
Validator: vtpm.NewValidator( Validator: vtpm.NewValidator(
pcrs, pcrs,

View file

@ -4,7 +4,7 @@ Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only SPDX-License-Identifier: AGPL-3.0-only
*/ */
package config package measurements
import ( import (
"bytes" "bytes"
@ -18,52 +18,60 @@ import (
"net/url" "net/url"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/sigstore" "github.com/edgelesssys/constellation/v2/internal/sigstore"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
// Measurements are Platform Configuration Register (PCR) values. // M are Platform Configuration Register (PCR) values that make up the Measurements.
type Measurements map[uint32][]byte type M map[uint32][]byte
var ( // PCRWithAllBytes returns a PCR value where all 32 bytes are set to b.
zero = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} func PCRWithAllBytes(b byte) []byte {
// gcpPCRs are the PCR values for a GCP Constellation node that are initially set in a generated config file. return bytes.Repeat([]byte{b}, 32)
gcpPCRs = Measurements{ }
// DefaultsFor provides the default measurements for given cloud provider.
func DefaultsFor(provider cloudprovider.Provider) M {
switch provider {
case cloudprovider.AWS:
return M{
11: PCRWithAllBytes(0x00),
12: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
case cloudprovider.Azure:
return M{
11: PCRWithAllBytes(0x00),
12: PCRWithAllBytes(0x00),
13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
}
case cloudprovider.GCP:
return M{
0: {0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF}, 0: {0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF},
11: zero, 11: PCRWithAllBytes(0x00),
12: zero, 12: PCRWithAllBytes(0x00),
13: zero, 13: PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): zero, uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
} }
case cloudprovider.QEMU:
// azurePCRs are the PCR values for an Azure Constellation node that are initially set in a generated config file. return M{
azurePCRs = Measurements{ 11: PCRWithAllBytes(0x00),
11: zero, 12: PCRWithAllBytes(0x00),
12: zero, 13: PCRWithAllBytes(0x00),
13: zero, uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
uint32(vtpm.PCRIndexClusterID): zero,
} }
default:
// awsPCRs are the PCR values for an AWS Nitro Constellation node that are initially set in a generated config file. return nil
awsPCRs = Measurements{
11: zero,
12: zero,
13: zero,
uint32(vtpm.PCRIndexClusterID): zero,
} }
}
qemuPCRs = Measurements{
11: zero,
12: zero,
13: zero,
uint32(vtpm.PCRIndexClusterID): zero,
}
)
// FetchAndVerify fetches measurement and signature files via provided URLs, // FetchAndVerify fetches measurement and signature files via provided URLs,
// using client for download. The publicKey is used to verify the measurements. // using client for download. The publicKey is used to verify the measurements.
// The hash of the fetched measurements is returned. // The hash of the fetched measurements is returned.
func (m *Measurements) FetchAndVerify(ctx context.Context, client *http.Client, measurementsURL *url.URL, signatureURL *url.URL, publicKey []byte) (string, error) { func (m *M) FetchAndVerify(ctx context.Context, client *http.Client, measurementsURL *url.URL, signatureURL *url.URL, publicKey []byte) (string, error) {
measurements, err := getFromURL(ctx, client, measurementsURL) measurements, err := getFromURL(ctx, client, measurementsURL)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to fetch measurements: %w", err) return "", fmt.Errorf("failed to fetch measurements: %w", err)
@ -86,15 +94,29 @@ func (m *Measurements) FetchAndVerify(ctx context.Context, client *http.Client,
// CopyFrom copies over all values from other. Overwriting existing values, // CopyFrom copies over all values from other. Overwriting existing values,
// but keeping not specified values untouched. // but keeping not specified values untouched.
func (m Measurements) CopyFrom(other Measurements) { func (m M) CopyFrom(other M) {
for idx := range other { for idx := range other {
m[idx] = other[idx] m[idx] = other[idx]
} }
} }
// EqualTo tests whether the provided other Measurements are equal to these
// measurements.
func (m M) EqualTo(other M) bool {
if len(m) != len(other) {
return false
}
for k, v := range m {
if !bytes.Equal(v, other[k]) {
return false
}
}
return true
}
// MarshalYAML overwrites the default behaviour of writing out []byte not as // MarshalYAML overwrites the default behaviour of writing out []byte not as
// single bytes, but as a single base64 encoded string. // single bytes, but as a single base64 encoded string.
func (m Measurements) MarshalYAML() (any, error) { func (m M) MarshalYAML() (any, error) {
base64Map := make(map[uint32]string) base64Map := make(map[uint32]string)
for key, value := range m { for key, value := range m {
@ -106,14 +128,14 @@ func (m Measurements) MarshalYAML() (any, error) {
// UnmarshalYAML overwrites the default behaviour of reading []byte not as // UnmarshalYAML overwrites the default behaviour of reading []byte not as
// single bytes, but as a single base64 encoded string. // single bytes, but as a single base64 encoded string.
func (m *Measurements) UnmarshalYAML(unmarshal func(any) error) error { func (m *M) UnmarshalYAML(unmarshal func(any) error) error {
base64Map := make(map[uint32]string) base64Map := make(map[uint32]string)
err := unmarshal(base64Map) err := unmarshal(base64Map)
if err != nil { if err != nil {
return err return err
} }
*m = make(Measurements) *m = make(M)
for key, value := range base64Map { for key, value := range base64Map {
measurement, err := base64.StdEncoding.DecodeString(value) measurement, err := base64.StdEncoding.DecodeString(value)
if err != nil { if err != nil {

View file

@ -4,7 +4,7 @@ Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only SPDX-License-Identifier: AGPL-3.0-only
*/ */
package config package measurements
import ( import (
"context" "context"
@ -21,11 +21,11 @@ import (
func TestMarshalYAML(t *testing.T) { func TestMarshalYAML(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
measurements Measurements measurements M
wantBase64Map map[uint32]string wantBase64Map map[uint32]string
}{ }{
"valid measurements": { "valid measurements": {
measurements: Measurements{ measurements: M{
2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, 2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158}, 3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
}, },
@ -35,7 +35,7 @@ func TestMarshalYAML(t *testing.T) {
}, },
}, },
"omit bytes": { "omit bytes": {
measurements: Measurements{ measurements: M{
2: []byte{}, 2: []byte{},
3: []byte{1, 2, 3, 4}, 3: []byte{1, 2, 3, 4},
}, },
@ -63,7 +63,7 @@ func TestUnmarshalYAML(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
inputBase64Map map[uint32]string inputBase64Map map[uint32]string
forceUnmarshalError bool forceUnmarshalError bool
wantMeasurements Measurements wantMeasurements M
wantErr bool wantErr bool
}{ }{
"valid measurements": { "valid measurements": {
@ -71,7 +71,7 @@ func TestUnmarshalYAML(t *testing.T) {
2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=", 2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=",
3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=", 3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=",
}, },
wantMeasurements: Measurements{ wantMeasurements: M{
2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53}, 2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158}, 3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
}, },
@ -81,7 +81,7 @@ func TestUnmarshalYAML(t *testing.T) {
2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
}, },
wantMeasurements: Measurements{ wantMeasurements: M{
2: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 2: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
3: []byte{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 3: []byte{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
}, },
@ -91,7 +91,7 @@ func TestUnmarshalYAML(t *testing.T) {
2: "This is not base64", 2: "This is not base64",
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
}, },
wantMeasurements: Measurements{ wantMeasurements: M{
2: []byte{}, 2: []byte{},
3: []byte{1, 2, 3, 4}, 3: []byte{1, 2, 3, 4},
}, },
@ -103,7 +103,7 @@ func TestUnmarshalYAML(t *testing.T) {
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", 3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
}, },
forceUnmarshalError: true, forceUnmarshalError: true,
wantMeasurements: Measurements{ wantMeasurements: M{
2: []byte{}, 2: []byte{},
3: []byte{1, 2, 3, 4}, 3: []byte{1, 2, 3, 4},
}, },
@ -116,7 +116,7 @@ func TestUnmarshalYAML(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
var m Measurements var m M
err := m.UnmarshalYAML(func(i any) error { err := m.UnmarshalYAML(func(i any) error {
if base64Map, ok := i.(map[uint32]string); ok { if base64Map, ok := i.(map[uint32]string); ok {
for key, value := range tc.inputBase64Map { for key, value := range tc.inputBase64Map {
@ -141,55 +141,55 @@ func TestUnmarshalYAML(t *testing.T) {
func TestMeasurementsCopyFrom(t *testing.T) { func TestMeasurementsCopyFrom(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
current Measurements current M
newMeasurements Measurements newMeasurements M
wantMeasurements Measurements wantMeasurements M
}{ }{
"add to empty": { "add to empty": {
current: Measurements{}, current: M{},
newMeasurements: Measurements{ newMeasurements: M{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1: PCRWithAllBytes(0x00),
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 2: PCRWithAllBytes(0x01),
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 3: PCRWithAllBytes(0x02),
}, },
wantMeasurements: Measurements{ wantMeasurements: M{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1: PCRWithAllBytes(0x00),
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 2: PCRWithAllBytes(0x01),
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 3: PCRWithAllBytes(0x02),
}, },
}, },
"keep existing": { "keep existing": {
current: Measurements{ current: M{
4: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 4: PCRWithAllBytes(0x01),
5: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 5: PCRWithAllBytes(0x02),
}, },
newMeasurements: Measurements{ newMeasurements: M{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1: PCRWithAllBytes(0x00),
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 2: PCRWithAllBytes(0x01),
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 3: PCRWithAllBytes(0x02),
}, },
wantMeasurements: Measurements{ wantMeasurements: M{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1: PCRWithAllBytes(0x00),
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 2: PCRWithAllBytes(0x01),
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 3: PCRWithAllBytes(0x02),
4: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 4: PCRWithAllBytes(0x01),
5: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 5: PCRWithAllBytes(0x02),
}, },
}, },
"overwrite existing": { "overwrite existing": {
current: Measurements{ current: M{
2: []byte{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4}, 2: PCRWithAllBytes(0x04),
3: []byte{5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5}, 3: PCRWithAllBytes(0x05),
}, },
newMeasurements: Measurements{ newMeasurements: M{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1: PCRWithAllBytes(0x00),
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 2: PCRWithAllBytes(0x01),
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 3: PCRWithAllBytes(0x02),
}, },
wantMeasurements: Measurements{ wantMeasurements: M{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1: PCRWithAllBytes(0x00),
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 2: PCRWithAllBytes(0x01),
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 3: PCRWithAllBytes(0x02),
}, },
}, },
} }
@ -230,7 +230,7 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
signature string signature string
signatureStatus int signatureStatus int
publicKey []byte publicKey []byte
wantMeasurements Measurements wantMeasurements M
wantSHA string wantSHA string
wantError bool wantError bool
}{ }{
@ -240,8 +240,8 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=", signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=",
signatureStatus: http.StatusOK, signatureStatus: http.StatusOK,
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"), publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"),
wantMeasurements: Measurements{ wantMeasurements: M{
0: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 0: PCRWithAllBytes(0x00),
}, },
wantSHA: "4cd9d6ed8d9322150dff7738994c5e2fabff35f3bae6f5c993412d13249a5e87", wantSHA: "4cd9d6ed8d9322150dff7738994c5e2fabff35f3bae6f5c993412d13249a5e87",
}, },
@ -308,7 +308,7 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
} }
}) })
m := Measurements{} m := M{}
hash, err := m.FetchAndVerify(context.Background(), client, measurementsURL, signatureURL, tc.publicKey) hash, err := m.FetchAndVerify(context.Background(), client, measurementsURL, signatureURL, tc.publicKey)
if tc.wantError { if tc.wantError {
@ -321,3 +321,83 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
}) })
} }
} }
func TestPCRWithAllBytes(t *testing.T) {
testCases := map[string]struct {
b byte
wantPCR []byte
}{
"0x00": {
b: 0x00,
wantPCR: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
},
"0x01": {
b: 0x01,
wantPCR: []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
},
"0xFF": {
b: 0xFF,
wantPCR: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
pcr := PCRWithAllBytes(tc.b)
assert.Equal(tc.wantPCR, pcr)
})
}
}
func TestEqualTo(t *testing.T) {
testCases := map[string]struct {
given M
other M
wantEqual bool
}{
"same values": {
given: M{
0: PCRWithAllBytes(0x00),
1: PCRWithAllBytes(0xFF),
},
other: M{
0: PCRWithAllBytes(0x00),
1: PCRWithAllBytes(0xFF),
},
wantEqual: true,
},
"different number of elements": {
given: M{
0: PCRWithAllBytes(0x00),
1: PCRWithAllBytes(0xFF),
},
other: M{
0: PCRWithAllBytes(0x00),
},
wantEqual: false,
},
"different values": {
given: M{
0: PCRWithAllBytes(0x00),
1: PCRWithAllBytes(0xFF),
},
other: M{
0: PCRWithAllBytes(0xFF),
1: PCRWithAllBytes(0x00),
},
wantEqual: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
if tc.wantEqual {
assert.True(tc.given.EqualTo(tc.other))
} else {
assert.False(tc.given.EqualTo(tc.other))
}
})
}
}

View file

@ -9,6 +9,7 @@ package qemu
import ( import (
"crypto" "crypto"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/oid" "github.com/edgelesssys/constellation/v2/internal/oid"
"github.com/google/go-tpm/tpm2" "github.com/google/go-tpm/tpm2"
@ -21,7 +22,7 @@ type Validator struct {
} }
// NewValidator initializes a new QEMU validator with the provided PCR values. // NewValidator initializes a new QEMU validator with the provided PCR values.
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator { func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
return &Validator{ return &Validator{
Validator: vtpm.NewValidator( Validator: vtpm.NewValidator(
pcrs, pcrs,

View file

@ -14,19 +14,25 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
"os"
"regexp" "regexp"
"strings"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config/instancetypes" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/versions" "github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/go-playground/locales/en" "github.com/go-playground/locales/en"
ut "github.com/go-playground/universal-translator" ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
en_translations "github.com/go-playground/validator/v10/translations/en" en_translations "github.com/go-playground/validator/v10/translations/en"
"go.uber.org/multierr"
) )
// Measurements is a required alias since docgen is not able to work with
// types in other packages.
type Measurements = measurements.M
const ( const (
// Version1 is the first version number for Constellation config file. // Version1 is the first version number for Constellation config file.
Version1 = "v1" Version1 = "v1"
@ -164,7 +170,7 @@ type AzureConfig struct {
// Application client ID of the Active Directory app registration. // Application client ID of the Active Directory app registration.
AppClientID string `yaml:"appClientID" validate:"uuid"` AppClientID string `yaml:"appClientID" validate:"uuid"`
// description: | // description: |
// Client secret value of the Active Directory app registration credentials. // Client secret value of the Active Directory app registration credentials. Alternatively leave empty and pass value via CONSTELL_AZURE_CLIENT_SECRET_VALUE environment variable.
ClientSecretValue string `yaml:"clientSecretValue" validate:"required"` ClientSecretValue string `yaml:"clientSecretValue" validate:"required"`
// description: | // description: |
// Machine image used to create Constellation nodes. // Machine image used to create Constellation nodes.
@ -277,7 +283,7 @@ func Default() *Config {
StateDiskType: "gp3", StateDiskType: "gp3",
IAMProfileControlPlane: "", IAMProfileControlPlane: "",
IAMProfileWorkerNodes: "", IAMProfileWorkerNodes: "",
Measurements: copyPCRMap(awsPCRs), Measurements: measurements.DefaultsFor(cloudprovider.AWS),
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15}, EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
}, },
Azure: &AzureConfig{ Azure: &AzureConfig{
@ -289,7 +295,7 @@ func Default() *Config {
Image: DefaultImageAzure, Image: DefaultImageAzure,
InstanceType: "Standard_DC4as_v5", InstanceType: "Standard_DC4as_v5",
StateDiskType: "Premium_LRS", StateDiskType: "Premium_LRS",
Measurements: copyPCRMap(azurePCRs), Measurements: measurements.DefaultsFor(cloudprovider.Azure),
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15}, EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
IDKeyDigest: "57486a447ec0f1958002a22a06b7673b9fd27d11e1c6527498056054c5fa92d23c50f9de44072760fe2b6fb89740b696", IDKeyDigest: "57486a447ec0f1958002a22a06b7673b9fd27d11e1c6527498056054c5fa92d23c50f9de44072760fe2b6fb89740b696",
EnforceIDKeyDigest: func() *bool { b := true; return &b }(), EnforceIDKeyDigest: func() *bool { b := true; return &b }(),
@ -304,7 +310,7 @@ func Default() *Config {
InstanceType: "n2d-standard-4", InstanceType: "n2d-standard-4",
StateDiskType: "pd-ssd", StateDiskType: "pd-ssd",
ServiceAccountKeyPath: "", ServiceAccountKeyPath: "",
Measurements: copyPCRMap(gcpPCRs), Measurements: measurements.DefaultsFor(cloudprovider.GCP),
EnforcedMeasurements: []uint32{0, 4, 8, 9, 11, 12, 13, 15}, EnforcedMeasurements: []uint32{0, 4, 8, 9, 11, 12, 13, 15},
}, },
QEMU: &QEMUConfig{ QEMU: &QEMUConfig{
@ -314,7 +320,7 @@ func Default() *Config {
MetadataAPIImage: versions.QEMUMetadataImage, MetadataAPIImage: versions.QEMUMetadataImage,
LibvirtURI: "", LibvirtURI: "",
LibvirtContainerImage: versions.LibvirtImage, LibvirtContainerImage: versions.LibvirtImage,
Measurements: copyPCRMap(qemuPCRs), Measurements: measurements.DefaultsFor(cloudprovider.QEMU),
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15}, EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
NVRAM: "production", NVRAM: "production",
}, },
@ -323,198 +329,38 @@ func Default() *Config {
} }
} }
func validateK8sVersion(fl validator.FieldLevel) bool { // FromFile returns config file with `name` read from `fileHandler` by parsing
return versions.IsSupportedK8sVersion(fl.Field().String()) // it as YAML. You should prefer config.New to read env vars and validate
// config in a consistent manner.
func FromFile(fileHandler file.Handler, name string) (*Config, error) {
var conf Config
if err := fileHandler.ReadYAMLStrict(name, &conf); err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("unable to find %s - use `constellation config generate` to generate it first", name)
}
return nil, fmt.Errorf("could not load config from file %s: %w", name, err)
}
return &conf, nil
} }
func validateAWSInstanceType(fl validator.FieldLevel) bool { // New creates a new config by:
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.AWS) // 1. Reading config file via provided fileHandler from file with name.
} // 2. Read secrets from environment variables.
// 3. Validate config.
func validateAzureInstanceType(fl validator.FieldLevel) bool { func New(fileHandler file.Handler, name string) (*Config, error) {
azureConfig := fl.Parent().Interface().(AzureConfig) // Read config file
var acceptNonCVM bool c, err := FromFile(fileHandler, name)
if azureConfig.ConfidentialVM != nil { if err != nil {
// This is the inverse of the config value (acceptNonCVMs is true if confidentialVM is false).
// We could make the validator the other way around, but this should be an explicit bypass rather than checking if CVMs are "allowed".
acceptNonCVM = !*azureConfig.ConfidentialVM
}
return validInstanceTypeForProvider(fl.Field().String(), acceptNonCVM, cloudprovider.Azure)
}
func validateGCPInstanceType(fl validator.FieldLevel) bool {
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.GCP)
}
// validateProvider checks if zero or more than one providers are defined in the config.
func validateProvider(sl validator.StructLevel) {
provider := sl.Current().Interface().(ProviderConfig)
providerCount := 0
if provider.AWS != nil {
providerCount++
}
if provider.Azure != nil {
providerCount++
}
if provider.GCP != nil {
providerCount++
}
if provider.QEMU != nil {
providerCount++
}
if providerCount < 1 {
sl.ReportError(provider, "Provider", "Provider", "no_provider", "")
} else if providerCount > 1 {
sl.ReportError(provider, "Provider", "Provider", "more_than_one_provider", "")
}
}
// Validate checks the config values and returns validation error messages.
// The function only returns an error if the validation itself fails.
func (c *Config) Validate() ([]string, error) {
trans := ut.New(en.New()).GetFallback()
validate := validator.New()
if err := en_translations.RegisterDefaultTranslations(validate, trans); err != nil {
return nil, err return nil, err
} }
// Register AWS, Azure & GCP InstanceType validation error types // Read secrets from env-vars.
if err := validate.RegisterTranslation("aws_instance_type", trans, registerTranslateAWSInstanceTypeError, translateAWSInstanceTypeError); err != nil { clientSecretValue := os.Getenv(constants.EnvVarAzureClientSecretValue)
return nil, err if clientSecretValue != "" && c.Provider.Azure != nil {
c.Provider.Azure.ClientSecretValue = clientSecretValue
} }
if err := validate.RegisterTranslation("azure_instance_type", trans, registerTranslateAzureInstanceTypeError, c.translateAzureInstanceTypeError); err != nil { return c, c.Validate()
return nil, err
}
if err := validate.RegisterTranslation("gcp_instance_type", trans, registerTranslateGCPInstanceTypeError, translateGCPInstanceTypeError); err != nil {
return nil, err
}
// Register Provider validation error types
if err := validate.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError); err != nil {
return nil, err
}
if err := validate.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, c.translateMoreThanOneProviderError); err != nil {
return nil, err
}
// register custom validator with label supported_k8s_version to validate version based on available versionConfigs.
if err := validate.RegisterValidation("supported_k8s_version", validateK8sVersion); err != nil {
return nil, err
}
// register custom validator with label aws_instance_type to validate the AWS instance type from config input.
if err := validate.RegisterValidation("aws_instance_type", validateAWSInstanceType); err != nil {
return nil, err
}
// register custom validator with label azure_instance_type to validate the Azure instance type from config input.
if err := validate.RegisterValidation("azure_instance_type", validateAzureInstanceType); err != nil {
return nil, err
}
// register custom validator with label gcp_instance_type to validate the GCP instance type from config input.
if err := validate.RegisterValidation("gcp_instance_type", validateGCPInstanceType); err != nil {
return nil, err
}
// Register provider validation
validate.RegisterStructValidation(validateProvider, ProviderConfig{})
err := validate.Struct(c)
if err == nil {
return nil, nil
}
var errs validator.ValidationErrors
if !errors.As(err, &errs) {
return nil, err
}
var msgs []string
for _, e := range errs {
msgs = append(msgs, e.Translate(trans))
}
return msgs, nil
}
// Validation translation functions for Azure & GCP instance type errors.
func registerTranslateAzureInstanceTypeError(ut ut.Translator) error {
return ut.Add("azure_instance_type", "{0} must be one of {1}", true)
}
func (c *Config) translateAzureInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
// Suggest trusted launch VMs if confidential VMs have been specifically disabled
var t string
if c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM {
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureTrustedLaunchInstanceTypes))
} else {
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureCVMInstanceTypes))
}
return t
}
func registerTranslateAWSInstanceTypeError(ut ut.Translator) error {
return ut.Add("aws_instance_type", fmt.Sprintf("{0} must be an instance from one of the following families types with size xlarge or higher: %v", instancetypes.AWSSupportedInstanceFamilies), true)
}
func translateAWSInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("aws_instance_type", fe.Field())
return t
}
func registerTranslateGCPInstanceTypeError(ut ut.Translator) error {
return ut.Add("gcp_instance_type", fmt.Sprintf("{0} must be one of %v", instancetypes.GCPInstanceTypes), true)
}
func translateGCPInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("gcp_instance_type", fe.Field())
return t
}
// Validation translation functions for Provider errors.
func registerNoProviderError(ut ut.Translator) error {
return ut.Add("no_provider", "{0}: No provider has been defined (requires either Azure, GCP or QEMU)", true)
}
func translateNoProviderError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("no_provider", fe.Field())
return t
}
func registerMoreThanOneProviderError(ut ut.Translator) error {
return ut.Add("more_than_one_provider", "{0}: Only one provider can be defined ({1} are defined)", true)
}
func (c *Config) translateMoreThanOneProviderError(ut ut.Translator, fe validator.FieldError) string {
definedProviders := make([]string, 0)
// c.Provider should not be nil as Provider would need to be defined for the validation to fail in this place.
if c.Provider.AWS != nil {
definedProviders = append(definedProviders, "AWS")
}
if c.Provider.Azure != nil {
definedProviders = append(definedProviders, "Azure")
}
if c.Provider.GCP != nil {
definedProviders = append(definedProviders, "GCP")
}
if c.Provider.QEMU != nil {
definedProviders = append(definedProviders, "QEMU")
}
// Show single string if only one other provider is defined, show list with brackets if multiple are defined.
t, _ := ut.T("more_than_one_provider", fe.Field(), strings.Join(definedProviders, ", "))
return t
} }
// HasProvider checks whether the config contains the provider. // HasProvider checks whether the config contains the provider.
@ -584,6 +430,19 @@ func (c *Config) RemoveProviderExcept(provider cloudprovider.Provider) {
} }
} }
// IsAzureNonCVM checks whether the chosen provider is azure and confidential VMs are disabled.
func (c *Config) IsAzureNonCVM() bool {
return c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM
}
// IsDebugCluster checks whether the cluster is configured as a debug cluster.
func (c *Config) IsDebugCluster() bool {
if c.DebugCluster != nil && *c.DebugCluster {
return true
}
return false
}
// IsDebugImage checks whether image name looks like a release image, if not it is // IsDebugImage checks whether image name looks like a release image, if not it is
// probably a debug image. In the end we do not if bootstrapper or debugd // probably a debug image. In the end we do not if bootstrapper or debugd
// was put inside an image just by looking at its name. // was put inside an image just by looking at its name.
@ -618,110 +477,80 @@ func (c *Config) GetProvider() cloudprovider.Provider {
return cloudprovider.Unknown return cloudprovider.Unknown
} }
// IsAzureNonCVM checks whether the chosen provider is azure and confidential VMs are disabled.
func (c *Config) IsAzureNonCVM() bool {
return c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM
}
// EnforcesIDKeyDigest checks whether ID Key Digest should be enforced for respective cloud provider. // EnforcesIDKeyDigest checks whether ID Key Digest should be enforced for respective cloud provider.
func (c *Config) EnforcesIDKeyDigest() bool { func (c *Config) EnforcesIDKeyDigest() bool {
return c.Provider.Azure != nil && c.Provider.Azure.EnforceIDKeyDigest != nil && *c.Provider.Azure.EnforceIDKeyDigest return c.Provider.Azure != nil && c.Provider.Azure.EnforceIDKeyDigest != nil && *c.Provider.Azure.EnforceIDKeyDigest
} }
// FromFile returns config file with `name` read from `fileHandler` by parsing // Validate checks the config values and returns validation errors.
// it as YAML. func (c *Config) Validate() error {
func FromFile(fileHandler file.Handler, name string) (*Config, error) { trans := ut.New(en.New()).GetFallback()
var conf Config validate := validator.New()
if err := fileHandler.ReadYAMLStrict(name, &conf); err != nil { if err := en_translations.RegisterDefaultTranslations(validate, trans); err != nil {
if errors.Is(err, fs.ErrNotExist) { return err
return nil, fmt.Errorf("unable to find %s - use `constellation config generate` to generate it first", name)
} }
return nil, fmt.Errorf("could not load config from file %s: %w", name, err)
// Register AWS, Azure & GCP InstanceType validation error types
if err := validate.RegisterTranslation("aws_instance_type", trans, registerTranslateAWSInstanceTypeError, translateAWSInstanceTypeError); err != nil {
return err
} }
return &conf, nil
} if err := validate.RegisterTranslation("azure_instance_type", trans, registerTranslateAzureInstanceTypeError, c.translateAzureInstanceTypeError); err != nil {
return err
func copyPCRMap(m map[uint32][]byte) map[uint32][]byte { }
res := make(Measurements)
res.CopyFrom(m) if err := validate.RegisterTranslation("gcp_instance_type", trans, registerTranslateGCPInstanceTypeError, translateGCPInstanceTypeError); err != nil {
return res return err
} }
func validInstanceTypeForProvider(insType string, acceptNonCVM bool, provider cloudprovider.Provider) bool { // Register Provider validation error types
switch provider { if err := validate.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError); err != nil {
case cloudprovider.AWS: return err
return checkIfAWSInstanceTypeIsValid(insType) }
case cloudprovider.Azure:
if acceptNonCVM { if err := validate.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, c.translateMoreThanOneProviderError); err != nil {
for _, instanceType := range instancetypes.AzureTrustedLaunchInstanceTypes { return err
if insType == instanceType { }
return true
} // register custom validator with label supported_k8s_version to validate version based on available versionConfigs.
} if err := validate.RegisterValidation("supported_k8s_version", validateK8sVersion); err != nil {
} else { return err
for _, instanceType := range instancetypes.AzureCVMInstanceTypes { }
if insType == instanceType {
return true // register custom validator with label aws_instance_type to validate the AWS instance type from config input.
} if err := validate.RegisterValidation("aws_instance_type", validateAWSInstanceType); err != nil {
} return err
} }
return false
case cloudprovider.GCP: // register custom validator with label azure_instance_type to validate the Azure instance type from config input.
for _, instanceType := range instancetypes.GCPInstanceTypes { if err := validate.RegisterValidation("azure_instance_type", validateAzureInstanceType); err != nil {
if insType == instanceType { return err
return true }
}
} // register custom validator with label gcp_instance_type to validate the GCP instance type from config input.
return false if err := validate.RegisterValidation("gcp_instance_type", validateGCPInstanceType); err != nil {
default: return err
return false }
}
} // Register provider validation
validate.RegisterStructValidation(validateProvider, ProviderConfig{})
// checkIfAWSInstanceTypeIsValid checks if an AWS instance type passed as user input is in one of the instance families supporting NitroTPM.
func checkIfAWSInstanceTypeIsValid(userInput string) bool { err := validate.Struct(c)
// Check if user or code does anything weird and tries to pass multiple strings as one if err == nil {
if strings.Contains(userInput, " ") { return nil
return false }
}
if strings.Contains(userInput, ",") { var errs validator.ValidationErrors
return false if !errors.As(err, &errs) {
} return err
if strings.Contains(userInput, ";") { }
return false
} var validationErrors error
for _, e := range errs {
splitInstanceType := strings.Split(userInput, ".") validationErrors = multierr.Append(
validationErrors,
if len(splitInstanceType) != 2 { errors.New(e.Translate(trans)),
return false )
} }
return validationErrors
userDefinedFamily := splitInstanceType[0]
userDefinedSize := splitInstanceType[1]
// Check if instace type has at least 4 vCPUs (= contains "xlarge" in its name)
hasEnoughVCPUs := strings.Contains(userDefinedSize, "xlarge")
if !hasEnoughVCPUs {
return false
}
// Now check if the user input is a supported family
// Note that we cannot directly use the family split from the Graviton check above, as some instances are directly specified by their full name and not just the family in general
for _, supportedFamily := range instancetypes.AWSSupportedInstanceFamilies {
supportedFamilyLowercase := strings.ToLower(supportedFamily)
if userDefinedFamily == supportedFamilyLowercase {
return true
}
}
return false
}
// IsDebugCluster checks whether the cluster is configured as a debug cluster.
func (c *Config) IsDebugCluster() bool {
if c.DebugCluster != nil && *c.DebugCluster {
return true
}
return false
} }

View file

@ -242,8 +242,8 @@ func init() {
AzureConfigDoc.Fields[6].Name = "clientSecretValue" AzureConfigDoc.Fields[6].Name = "clientSecretValue"
AzureConfigDoc.Fields[6].Type = "string" AzureConfigDoc.Fields[6].Type = "string"
AzureConfigDoc.Fields[6].Note = "" AzureConfigDoc.Fields[6].Note = ""
AzureConfigDoc.Fields[6].Description = "Client secret value of the Active Directory app registration credentials." AzureConfigDoc.Fields[6].Description = "Client secret value of the Active Directory app registration credentials. Alternatively leave empty and pass value via CONSTELL_AZURE_CLIENT_SECRET_VALUE environment variable."
AzureConfigDoc.Fields[6].Comments[encoder.LineComment] = "Client secret value of the Active Directory app registration credentials." AzureConfigDoc.Fields[6].Comments[encoder.LineComment] = "Client secret value of the Active Directory app registration credentials. Alternatively leave empty and pass value via CONSTELL_AZURE_CLIENT_SECRET_VALUE environment variable."
AzureConfigDoc.Fields[7].Name = "image" AzureConfigDoc.Fields[7].Name = "image"
AzureConfigDoc.Fields[7].Type = "string" AzureConfigDoc.Fields[7].Type = "string"
AzureConfigDoc.Fields[7].Note = "" AzureConfigDoc.Fields[7].Note = ""

View file

@ -10,6 +10,7 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config/instancetypes" "github.com/edgelesssys/constellation/v2/internal/config/instancetypes"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
@ -21,6 +22,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/goleak" "go.uber.org/goleak"
"go.uber.org/multierr"
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
@ -108,33 +110,49 @@ func TestFromFile(t *testing.T) {
} }
} }
func TestValidate(t *testing.T) { func TestNewWithDefaultOptions(t *testing.T) {
const defaultMsgCount = 20 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default
testCases := map[string]struct { testCases := map[string]struct {
cnf *Config confToWrite *Config
wantMsgCount int envToSet map[string]string
wantErr bool
wantClientSecretValue string
}{ }{
"default config is valid": { "set env works": {
cnf: Default(), confToWrite: func() *Config { // valid config with all, but clientSecretValue
wantMsgCount: defaultMsgCount, c := Default()
}, c.RemoveProviderExcept(cloudprovider.Azure)
"config with 1 error": { c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5"
cnf: func() *Config { c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa"
cnf := Default() c.Provider.Azure.Location = "westus"
cnf.Version = "v0" c.Provider.Azure.ResourceGroup = "test"
return cnf c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity"
c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e"
c.Provider.Azure.Image = "/communityGalleries/ConstellationCVM-b3782fa0-0df7-4f2f-963e-fc7fc42663df/images/constellation/versions/2.2.0"
return c
}(), }(),
wantMsgCount: defaultMsgCount + 1, envToSet: map[string]string{
constants.EnvVarAzureClientSecretValue: "some-secret",
}, },
"config with 2 errors": { wantClientSecretValue: "some-secret",
cnf: func() *Config { },
cnf := Default() "set env overwrites": {
cnf.Version = "v0" confToWrite: func() *Config {
cnf.StateDiskSizeGB = -1 c := Default()
return cnf c.RemoveProviderExcept(cloudprovider.Azure)
c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5"
c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa"
c.Provider.Azure.Location = "westus"
c.Provider.Azure.ResourceGroup = "test"
c.Provider.Azure.ClientSecretValue = "other-value" // < Note secret set in config, as well.
c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity"
c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e"
c.Provider.Azure.Image = "/communityGalleries/ConstellationCVM-b3782fa0-0df7-4f2f-963e-fc7fc42663df/images/constellation/versions/2.2.0"
return c
}(), }(),
wantMsgCount: defaultMsgCount + 2, envToSet: map[string]string{
constants.EnvVarAzureClientSecretValue: "some-secret",
},
wantClientSecretValue: "some-secret",
}, },
} }
@ -143,9 +161,126 @@ func TestValidate(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
require := require.New(t) require := require.New(t)
msgs, err := tc.cnf.Validate() // Setup
fileHandler := file.NewHandler(afero.NewMemMapFs())
err := fileHandler.WriteYAML(constants.ConfigFilename, tc.confToWrite)
require.NoError(err) require.NoError(err)
assert.Len(msgs, tc.wantMsgCount) for envKey, envValue := range tc.envToSet {
t.Setenv(envKey, envValue)
}
// Test
c, err := New(fileHandler, constants.ConfigFilename)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(c.Provider.Azure.ClientSecretValue, tc.wantClientSecretValue)
})
}
}
func TestValidate(t *testing.T) {
const defaultErrCount = 20 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default
const azErrCount = 8
const gcpErrCount = 5
testCases := map[string]struct {
cnf *Config
wantErr bool
wantErrCount int
}{
"default config is not valid": {
cnf: Default(),
wantErr: true,
wantErrCount: defaultErrCount,
},
"v0 is one error": {
cnf: func() *Config {
cnf := Default()
cnf.Version = "v0"
return cnf
}(),
wantErr: true,
wantErrCount: defaultErrCount + 1,
},
"v0 and negative state disk are two errors": {
cnf: func() *Config {
cnf := Default()
cnf.Version = "v0"
cnf.StateDiskSizeGB = -1
return cnf
}(),
wantErr: true,
wantErrCount: defaultErrCount + 2,
},
"default Azure config is not valid": {
cnf: func() *Config {
cnf := Default()
az := cnf.Provider.Azure
cnf.Provider = ProviderConfig{}
cnf.Provider.Azure = az
return cnf
}(),
wantErr: true,
wantErrCount: azErrCount,
},
"Azure config with all required fields is valid": {
cnf: func() *Config {
cnf := Default()
az := cnf.Provider.Azure
az.SubscriptionID = "01234567-0123-0123-0123-0123456789ab"
az.TenantID = "01234567-0123-0123-0123-0123456789ab"
az.Location = "test-location"
az.UserAssignedIdentity = "test-identity"
az.Image = "some/image/location"
az.ResourceGroup = "test-resource-group"
az.AppClientID = "01234567-0123-0123-0123-0123456789ab"
az.ClientSecretValue = "test-client-secret"
cnf.Provider = ProviderConfig{}
cnf.Provider.Azure = az
return cnf
}(),
},
"default GCP config is not valid": {
cnf: func() *Config {
cnf := Default()
gcp := cnf.Provider.GCP
cnf.Provider = ProviderConfig{}
cnf.Provider.GCP = gcp
return cnf
}(),
wantErr: true,
wantErrCount: gcpErrCount,
},
"GCP config with all required fields is valid": {
cnf: func() *Config {
cnf := Default()
gcp := cnf.Provider.GCP
gcp.Region = "test-region"
gcp.Project = "test-project"
gcp.Image = "some/image/location"
gcp.Zone = "test-zone"
gcp.ServiceAccountKeyPath = "test-key-path"
cnf.Provider = ProviderConfig{}
cnf.Provider.GCP = gcp
return cnf
}(),
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := tc.cnf.Validate()
if tc.wantErr {
assert.Error(err)
assert.Len(multierr.Errors(err), tc.wantErrCount)
return
}
assert.NoError(err)
}) })
} }
} }
@ -260,10 +395,10 @@ func TestConfigGeneratedDocsFresh(t *testing.T) {
func TestConfig_UpdateMeasurements(t *testing.T) { func TestConfig_UpdateMeasurements(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
newMeasurements := Measurements{ newMeasurements := measurements.M{
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1: measurements.PCRWithAllBytes(0x00),
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 2: measurements.PCRWithAllBytes(0x01),
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}, 3: measurements.PCRWithAllBytes(0x02),
} }
{ // AWS { // AWS

View file

@ -0,0 +1,212 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package config
import (
"fmt"
"strings"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config/instancetypes"
"github.com/edgelesssys/constellation/v2/internal/versions"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
)
func validateK8sVersion(fl validator.FieldLevel) bool {
return versions.IsSupportedK8sVersion(fl.Field().String())
}
func validateAWSInstanceType(fl validator.FieldLevel) bool {
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.AWS)
}
func validateAzureInstanceType(fl validator.FieldLevel) bool {
azureConfig := fl.Parent().Interface().(AzureConfig)
var acceptNonCVM bool
if azureConfig.ConfidentialVM != nil {
// This is the inverse of the config value (acceptNonCVMs is true if confidentialVM is false).
// We could make the validator the other way around, but this should be an explicit bypass rather than checking if CVMs are "allowed".
acceptNonCVM = !*azureConfig.ConfidentialVM
}
return validInstanceTypeForProvider(fl.Field().String(), acceptNonCVM, cloudprovider.Azure)
}
func validateGCPInstanceType(fl validator.FieldLevel) bool {
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.GCP)
}
// validateProvider checks if zero or more than one providers are defined in the config.
func validateProvider(sl validator.StructLevel) {
provider := sl.Current().Interface().(ProviderConfig)
providerCount := 0
if provider.AWS != nil {
providerCount++
}
if provider.Azure != nil {
providerCount++
}
if provider.GCP != nil {
providerCount++
}
if provider.QEMU != nil {
providerCount++
}
if providerCount < 1 {
sl.ReportError(provider, "Provider", "Provider", "no_provider", "")
} else if providerCount > 1 {
sl.ReportError(provider, "Provider", "Provider", "more_than_one_provider", "")
}
}
func registerTranslateAWSInstanceTypeError(ut ut.Translator) error {
return ut.Add("aws_instance_type", fmt.Sprintf("{0} must be an instance from one of the following families types with size xlarge or higher: %v", instancetypes.AWSSupportedInstanceFamilies), true)
}
func translateAWSInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("aws_instance_type", fe.Field())
return t
}
func registerTranslateGCPInstanceTypeError(ut ut.Translator) error {
return ut.Add("gcp_instance_type", fmt.Sprintf("{0} must be one of %v", instancetypes.GCPInstanceTypes), true)
}
func translateGCPInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("gcp_instance_type", fe.Field())
return t
}
// Validation translation functions for Provider errors.
func registerNoProviderError(ut ut.Translator) error {
return ut.Add("no_provider", "{0}: No provider has been defined (requires either Azure, GCP or QEMU)", true)
}
func translateNoProviderError(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("no_provider", fe.Field())
return t
}
func registerMoreThanOneProviderError(ut ut.Translator) error {
return ut.Add("more_than_one_provider", "{0}: Only one provider can be defined ({1} are defined)", true)
}
func (c *Config) translateMoreThanOneProviderError(ut ut.Translator, fe validator.FieldError) string {
definedProviders := make([]string, 0)
// c.Provider should not be nil as Provider would need to be defined for the validation to fail in this place.
if c.Provider.AWS != nil {
definedProviders = append(definedProviders, "AWS")
}
if c.Provider.Azure != nil {
definedProviders = append(definedProviders, "Azure")
}
if c.Provider.GCP != nil {
definedProviders = append(definedProviders, "GCP")
}
if c.Provider.QEMU != nil {
definedProviders = append(definedProviders, "QEMU")
}
// Show single string if only one other provider is defined, show list with brackets if multiple are defined.
t, _ := ut.T("more_than_one_provider", fe.Field(), strings.Join(definedProviders, ", "))
return t
}
func validInstanceTypeForProvider(insType string, acceptNonCVM bool, provider cloudprovider.Provider) bool {
switch provider {
case cloudprovider.AWS:
return checkIfAWSInstanceTypeIsValid(insType)
case cloudprovider.Azure:
if acceptNonCVM {
for _, instanceType := range instancetypes.AzureTrustedLaunchInstanceTypes {
if insType == instanceType {
return true
}
}
} else {
for _, instanceType := range instancetypes.AzureCVMInstanceTypes {
if insType == instanceType {
return true
}
}
}
return false
case cloudprovider.GCP:
for _, instanceType := range instancetypes.GCPInstanceTypes {
if insType == instanceType {
return true
}
}
return false
default:
return false
}
}
// checkIfAWSInstanceTypeIsValid checks if an AWS instance type passed as user input is in one of the instance families supporting NitroTPM.
func checkIfAWSInstanceTypeIsValid(userInput string) bool {
// Check if user or code does anything weird and tries to pass multiple strings as one
if strings.Contains(userInput, " ") {
return false
}
if strings.Contains(userInput, ",") {
return false
}
if strings.Contains(userInput, ";") {
return false
}
splitInstanceType := strings.Split(userInput, ".")
if len(splitInstanceType) != 2 {
return false
}
userDefinedFamily := splitInstanceType[0]
userDefinedSize := splitInstanceType[1]
// Check if instace type has at least 4 vCPUs (= contains "xlarge" in its name)
hasEnoughVCPUs := strings.Contains(userDefinedSize, "xlarge")
if !hasEnoughVCPUs {
return false
}
// Now check if the user input is a supported family
// Note that we cannot directly use the family split from the Graviton check above, as some instances are directly specified by their full name and not just the family in general
for _, supportedFamily := range instancetypes.AWSSupportedInstanceFamilies {
supportedFamilyLowercase := strings.ToLower(supportedFamily)
if userDefinedFamily == supportedFamilyLowercase {
return true
}
}
return false
}
// Validation translation functions for Azure & GCP instance type errors.
func registerTranslateAzureInstanceTypeError(ut ut.Translator) error {
return ut.Add("azure_instance_type", "{0} must be one of {1}", true)
}
func (c *Config) translateAzureInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
// Suggest trusted launch VMs if confidential VMs have been specifically disabled
var t string
if c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM {
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureTrustedLaunchInstanceTypes))
} else {
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureCVMInstanceTypes))
}
return t
}

View file

@ -113,6 +113,11 @@ const (
MinControllerCount = 1 MinControllerCount = 1
// MinWorkerCount is the minimum number of worker nodes. // MinWorkerCount is the minimum number of worker nodes.
MinWorkerCount = 1 MinWorkerCount = 1
// EnvVarPrefix is expected prefix for environment variables used to overwrite config parameters.
EnvVarPrefix = "CONSTELL_"
// EnvVarAzureClientSecretValue is environment variable to overwrite
// provider.azure.clientSecretValue .
EnvVarAzureClientSecretValue = EnvVarPrefix + "AZURE_CLIENT_SECRET_VALUE"
// //
// Kubernetes. // Kubernetes.