mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-26 07:16:08 -05:00
AB#2512 Config secrets via env var & config refactoring (#544)
* refactor measurements to use consistent types and less byte pushing * refactor: only rely on a single multierr dependency * extend config creation with envar support * document changes Signed-off-by: Fabian Kammel <fk@edgeless.systems>
This commit is contained in:
parent
80a801629e
commit
bb76a4e4c8
@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Environment variable `CONSTELL_AZURE_CLIENT_SECRET_VALUE` as an alternative way to provide the configuration value `provider.azure.clientSecretValue`.
|
||||
|
||||
### Changed
|
||||
<!-- For changes in existing functionality. -->
|
||||
|
@ -22,10 +22,10 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/multierr"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
testclock "k8s.io/utils/clock/testing"
|
||||
)
|
||||
@ -713,10 +713,10 @@ func (w *tarGzWriter) Bytes() []byte {
|
||||
|
||||
func (w *tarGzWriter) Close() (result error) {
|
||||
if err := w.tarWriter.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
result = multierr.Append(result, err)
|
||||
}
|
||||
if err := w.gzWriter.Close(); err != nil {
|
||||
result = multierror.Append(result, err)
|
||||
result = multierr.Append(result, err)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
@ -7,13 +7,13 @@ SPDX-License-Identifier: AGPL-3.0-only
|
||||
package cloudcmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
corev1 "k8s.io/api/core/v1"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
@ -58,7 +58,7 @@ func NewUpgrader(outWriter io.Writer) (*Upgrader, error) {
|
||||
}
|
||||
|
||||
// Upgrade upgrades the cluster to the given measurements and image.
|
||||
func (u *Upgrader) Upgrade(ctx context.Context, image string, measurements map[uint32][]byte) error {
|
||||
func (u *Upgrader) Upgrade(ctx context.Context, image string, measurements measurements.M) error {
|
||||
if err := u.updateMeasurements(ctx, measurements); err != nil {
|
||||
return fmt.Errorf("updating measurements: %w", err)
|
||||
}
|
||||
@ -97,36 +97,25 @@ func (u *Upgrader) GetCurrentImage(ctx context.Context) (*unstructured.Unstructu
|
||||
return imageStruct, imageDefinition, nil
|
||||
}
|
||||
|
||||
func (u *Upgrader) updateMeasurements(ctx context.Context, measurements map[uint32][]byte) error {
|
||||
func (u *Upgrader) updateMeasurements(ctx context.Context, newMeasurements measurements.M) error {
|
||||
existingConf, err := u.measurementsUpdater.getCurrent(ctx, constants.JoinConfigMap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("retrieving current measurements: %w", err)
|
||||
}
|
||||
|
||||
var currentMeasurements map[uint32][]byte
|
||||
var currentMeasurements measurements.M
|
||||
if err := json.Unmarshal([]byte(existingConf.Data[constants.MeasurementsFilename]), ¤tMeasurements); err != nil {
|
||||
return fmt.Errorf("retrieving current measurements: %w", err)
|
||||
}
|
||||
if len(currentMeasurements) == len(measurements) {
|
||||
changed := false
|
||||
for k, v := range currentMeasurements {
|
||||
if !bytes.Equal(v, measurements[k]) {
|
||||
// measurements have changed
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
// measurements are the same, nothing to be done
|
||||
fmt.Fprintln(u.outWriter, "Cluster is already using the chosen measurements, skipping measurements upgrade")
|
||||
return nil
|
||||
}
|
||||
if currentMeasurements.EqualTo(newMeasurements) {
|
||||
fmt.Fprintln(u.outWriter, "Cluster is already using the chosen measurements, skipping measurements upgrade")
|
||||
return nil
|
||||
}
|
||||
|
||||
// backup of previous measurements
|
||||
existingConf.Data["oldMeasurements"] = existingConf.Data[constants.MeasurementsFilename]
|
||||
|
||||
measurementsJSON, err := json.Marshal(measurements)
|
||||
measurementsJSON, err := json.Marshal(newMeasurements)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling measurements: %w", err)
|
||||
}
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -24,7 +25,7 @@ func TestUpdateMeasurements(t *testing.T) {
|
||||
someErr := errors.New("error")
|
||||
testCases := map[string]struct {
|
||||
updater *stubMeasurementsUpdater
|
||||
newMeasurements map[uint32][]byte
|
||||
newMeasurements measurements.M
|
||||
wantUpdate bool
|
||||
wantErr bool
|
||||
}{
|
||||
@ -36,7 +37,7 @@ func TestUpdateMeasurements(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
newMeasurements: map[uint32][]byte{
|
||||
newMeasurements: measurements.M{
|
||||
0: []byte("1"),
|
||||
},
|
||||
wantUpdate: true,
|
||||
@ -49,7 +50,7 @@ func TestUpdateMeasurements(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
newMeasurements: map[uint32][]byte{
|
||||
newMeasurements: measurements.M{
|
||||
0: []byte("1"),
|
||||
},
|
||||
},
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
@ -28,7 +29,7 @@ import (
|
||||
// Validator validates Platform Configuration Registers (PCRs).
|
||||
type Validator struct {
|
||||
provider cloudprovider.Provider
|
||||
pcrs map[uint32][]byte
|
||||
pcrs measurements.M
|
||||
enforcedPCRs []uint32
|
||||
idkeydigest []byte
|
||||
enforceIDKeyDigest bool
|
||||
@ -147,7 +148,7 @@ func (v *Validator) V(cmd *cobra.Command) atls.Validator {
|
||||
}
|
||||
|
||||
// PCRS returns the validator's PCR map.
|
||||
func (v *Validator) PCRS() map[uint32][]byte {
|
||||
func (v *Validator) PCRS() measurements.M {
|
||||
return v.pcrs
|
||||
}
|
||||
|
||||
@ -169,7 +170,7 @@ func (v *Validator) updateValidator(cmd *cobra.Command) {
|
||||
}
|
||||
}
|
||||
|
||||
func (v *Validator) checkPCRs(pcrs map[uint32][]byte, enforcedPCRs []uint32) error {
|
||||
func (v *Validator) checkPCRs(pcrs measurements.M, enforcedPCRs []uint32) error {
|
||||
if len(pcrs) == 0 {
|
||||
return errors.New("no PCR values provided")
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only
|
||||
package cloudcmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"testing"
|
||||
@ -15,6 +16,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
@ -24,21 +26,19 @@ import (
|
||||
)
|
||||
|
||||
func TestNewValidator(t *testing.T) {
|
||||
zero := []byte("00000000000000000000000000000000")
|
||||
one := []byte("11111111111111111111111111111111")
|
||||
testPCRs := map[uint32][]byte{
|
||||
0: zero,
|
||||
1: one,
|
||||
2: zero,
|
||||
3: one,
|
||||
4: zero,
|
||||
5: zero,
|
||||
testPCRs := measurements.M{
|
||||
0: measurements.PCRWithAllBytes(0x00),
|
||||
1: measurements.PCRWithAllBytes(0xFF),
|
||||
2: measurements.PCRWithAllBytes(0x00),
|
||||
3: measurements.PCRWithAllBytes(0xFF),
|
||||
4: measurements.PCRWithAllBytes(0x00),
|
||||
5: measurements.PCRWithAllBytes(0x00),
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
provider cloudprovider.Provider
|
||||
config *config.Config
|
||||
pcrs map[uint32][]byte
|
||||
pcrs measurements.M
|
||||
enforceIDKeyDigest bool
|
||||
idKeyDigest string
|
||||
azureCVM bool
|
||||
@ -64,13 +64,15 @@ func TestNewValidator(t *testing.T) {
|
||||
},
|
||||
"no pcrs provided": {
|
||||
provider: cloudprovider.Azure,
|
||||
pcrs: map[uint32][]byte{},
|
||||
pcrs: measurements.M{},
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid pcr length": {
|
||||
provider: cloudprovider.GCP,
|
||||
pcrs: map[uint32][]byte{0: []byte("0000000000000000000000000000000")},
|
||||
wantErr: true,
|
||||
pcrs: measurements.M{
|
||||
0: bytes.Repeat([]byte{0x00}, 31),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"unknown provider": {
|
||||
provider: cloudprovider.Unknown,
|
||||
@ -99,16 +101,13 @@ func TestNewValidator(t *testing.T) {
|
||||
|
||||
conf := &config.Config{Provider: config.ProviderConfig{}}
|
||||
if tc.provider == cloudprovider.GCP {
|
||||
measurements := config.Measurements(tc.pcrs)
|
||||
conf.Provider.GCP = &config.GCPConfig{Measurements: measurements}
|
||||
conf.Provider.GCP = &config.GCPConfig{Measurements: tc.pcrs}
|
||||
}
|
||||
if tc.provider == cloudprovider.Azure {
|
||||
measurements := config.Measurements(tc.pcrs)
|
||||
conf.Provider.Azure = &config.AzureConfig{Measurements: measurements, EnforceIDKeyDigest: &tc.enforceIDKeyDigest, IDKeyDigest: tc.idKeyDigest, ConfidentialVM: &tc.azureCVM}
|
||||
conf.Provider.Azure = &config.AzureConfig{Measurements: tc.pcrs, EnforceIDKeyDigest: &tc.enforceIDKeyDigest, IDKeyDigest: tc.idKeyDigest, ConfidentialVM: &tc.azureCVM}
|
||||
}
|
||||
if tc.provider == cloudprovider.QEMU {
|
||||
measurements := config.Measurements(tc.pcrs)
|
||||
conf.Provider.QEMU = &config.QEMUConfig{Measurements: measurements}
|
||||
conf.Provider.QEMU = &config.QEMUConfig{Measurements: tc.pcrs}
|
||||
}
|
||||
|
||||
validators, err := NewValidator(tc.provider, conf)
|
||||
@ -125,29 +124,27 @@ func TestNewValidator(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestValidatorV(t *testing.T) {
|
||||
zero := []byte("00000000000000000000000000000000")
|
||||
|
||||
newTestPCRs := func() map[uint32][]byte {
|
||||
return map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
6: zero,
|
||||
7: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
10: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
newTestPCRs := func() measurements.M {
|
||||
return measurements.M{
|
||||
0: measurements.PCRWithAllBytes(0x00),
|
||||
1: measurements.PCRWithAllBytes(0x00),
|
||||
2: measurements.PCRWithAllBytes(0x00),
|
||||
3: measurements.PCRWithAllBytes(0x00),
|
||||
4: measurements.PCRWithAllBytes(0x00),
|
||||
5: measurements.PCRWithAllBytes(0x00),
|
||||
6: measurements.PCRWithAllBytes(0x00),
|
||||
7: measurements.PCRWithAllBytes(0x00),
|
||||
8: measurements.PCRWithAllBytes(0x00),
|
||||
9: measurements.PCRWithAllBytes(0x00),
|
||||
10: measurements.PCRWithAllBytes(0x00),
|
||||
11: measurements.PCRWithAllBytes(0x00),
|
||||
12: measurements.PCRWithAllBytes(0x00),
|
||||
}
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
provider cloudprovider.Provider
|
||||
pcrs map[uint32][]byte
|
||||
pcrs measurements.M
|
||||
wantVs atls.Validator
|
||||
azureCVM bool
|
||||
}{
|
||||
@ -224,7 +221,7 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
|
||||
|
||||
testCases := map[string]struct {
|
||||
provider cloudprovider.Provider
|
||||
pcrs map[uint32][]byte
|
||||
pcrs measurements.M
|
||||
ownerID string
|
||||
clusterID string
|
||||
wantErr bool
|
||||
@ -321,14 +318,14 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdatePCR(t *testing.T) {
|
||||
emptyMap := map[uint32][]byte{}
|
||||
defaultMap := map[uint32][]byte{
|
||||
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
|
||||
emptyMap := measurements.M{}
|
||||
defaultMap := measurements.M{
|
||||
0: measurements.PCRWithAllBytes(0xAA),
|
||||
1: measurements.PCRWithAllBytes(0xBB),
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
pcrMap map[uint32][]byte
|
||||
pcrMap measurements.M
|
||||
pcrIndex uint32
|
||||
encoded string
|
||||
wantEntries int
|
||||
@ -389,7 +386,7 @@ func TestUpdatePCR(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
pcrs := make(map[uint32][]byte)
|
||||
pcrs := make(measurements.M)
|
||||
for k, v := range tc.pcrMap {
|
||||
pcrs[k] = v
|
||||
}
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
@ -39,7 +40,7 @@ func newConfigFetchMeasurementsCmd() *cobra.Command {
|
||||
type fetchMeasurementsFlags struct {
|
||||
measurementsURL *url.URL
|
||||
signatureURL *url.URL
|
||||
config string
|
||||
configPath string
|
||||
}
|
||||
|
||||
func runConfigFetchMeasurements(cmd *cobra.Command, args []string) error {
|
||||
@ -57,9 +58,9 @@ func configFetchMeasurements(cmd *cobra.Command, verifier rekorVerifier, fileHan
|
||||
return err
|
||||
}
|
||||
|
||||
conf, err := config.FromFile(fileHandler, flags.config)
|
||||
conf, err := config.New(fileHandler, flags.configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
|
||||
}
|
||||
|
||||
if conf.IsDebugImage() {
|
||||
@ -72,7 +73,7 @@ func configFetchMeasurements(cmd *cobra.Command, verifier rekorVerifier, fileHan
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
var fetchedMeasurements config.Measurements
|
||||
var fetchedMeasurements measurements.M
|
||||
hash, err := fetchedMeasurements.FetchAndVerify(ctx, client, flags.measurementsURL, flags.signatureURL, []byte(constants.CosignPublicKey))
|
||||
if err != nil {
|
||||
return err
|
||||
@ -84,7 +85,7 @@ func configFetchMeasurements(cmd *cobra.Command, verifier rekorVerifier, fileHan
|
||||
}
|
||||
|
||||
conf.UpdateMeasurements(fetchedMeasurements)
|
||||
if err := fileHandler.WriteYAML(flags.config, conf, file.OptOverwrite); err != nil {
|
||||
if err := fileHandler.WriteYAML(flags.configPath, conf, file.OptOverwrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -123,7 +124,7 @@ func parseFetchMeasurementsFlags(cmd *cobra.Command) (*fetchMeasurementsFlags, e
|
||||
return &fetchMeasurementsFlags{
|
||||
measurementsURL: measurementsURL,
|
||||
signatureURL: measurementsSignatureURL,
|
||||
config: config,
|
||||
configPath: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -40,7 +40,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
|
||||
wantFlags: &fetchMeasurementsFlags{
|
||||
measurementsURL: nil,
|
||||
signatureURL: nil,
|
||||
config: constants.ConfigFilename,
|
||||
configPath: constants.ConfigFilename,
|
||||
},
|
||||
},
|
||||
"url": {
|
||||
@ -49,7 +49,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
|
||||
wantFlags: &fetchMeasurementsFlags{
|
||||
measurementsURL: urlMustParse("https://some.other.url/with/path"),
|
||||
signatureURL: urlMustParse("https://some.other.url/with/path.sig"),
|
||||
config: constants.ConfigFilename,
|
||||
configPath: constants.ConfigFilename,
|
||||
},
|
||||
},
|
||||
"broken url": {
|
||||
@ -59,7 +59,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
|
||||
"config": {
|
||||
configFlag: "someOtherConfig.yaml",
|
||||
wantFlags: &fetchMeasurementsFlags{
|
||||
config: "someOtherConfig.yaml",
|
||||
configPath: "someOtherConfig.yaml",
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -212,8 +212,7 @@ func TestConfigFetchMeasurements(t *testing.T) {
|
||||
cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually
|
||||
fileHandler := file.NewHandler(afero.NewMemMapFs())
|
||||
|
||||
gcpConfig := config.Default()
|
||||
gcpConfig.RemoveProviderExcept(cloudprovider.GCP)
|
||||
gcpConfig := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP)
|
||||
gcpConfig.Provider.GCP.Image = "projects/constellation-images/global/images/constellation-coreos-1658216163"
|
||||
|
||||
err := fileHandler.WriteYAML(constants.ConfigFilename, gcpConfig, file.OptMkdirAll)
|
||||
|
28
cli/internal/cmd/configvalidation.go
Normal file
28
cli/internal/cmd/configvalidation.go
Normal file
@ -0,0 +1,28 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
func displayConfigValidationErrors(errWriter io.Writer, configError error) error {
|
||||
errs := multierr.Errors(configError)
|
||||
if errs != nil {
|
||||
fmt.Fprintln(errWriter, "Problems validating config file:")
|
||||
for _, err := range errs {
|
||||
fmt.Fprintln(errWriter, "\t"+err.Error())
|
||||
}
|
||||
fmt.Fprintln(errWriter, "Fix the invalid entries or generate a new configuration using `constellation config generate`")
|
||||
return errors.New("invalid configuration")
|
||||
}
|
||||
return nil
|
||||
}
|
@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
"github.com/spf13/afero"
|
||||
@ -59,27 +60,27 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler,
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath)
|
||||
conf, err := config.New(fileHandler, flags.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading and validating config: %w", err)
|
||||
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
|
||||
}
|
||||
|
||||
var printedAWarning bool
|
||||
if config.IsDebugImage() {
|
||||
if conf.IsDebugImage() {
|
||||
cmd.PrintErrln("Configured image doesn't look like a released production image. Double check image before deploying to production.")
|
||||
printedAWarning = true
|
||||
}
|
||||
|
||||
if config.IsDebugCluster() {
|
||||
if conf.IsDebugCluster() {
|
||||
cmd.PrintErrln("WARNING: Creating a debug cluster. This cluster is not secure and should only be used for debugging purposes.")
|
||||
cmd.PrintErrln("DO NOT USE THIS CLUSTER IN PRODUCTION.")
|
||||
printedAWarning = true
|
||||
}
|
||||
|
||||
if config.IsAzureNonCVM() {
|
||||
if conf.IsAzureNonCVM() {
|
||||
cmd.PrintErrln("Disabling Confidential VMs is insecure. Use only for evaluation purposes.")
|
||||
printedAWarning = true
|
||||
if config.EnforcesIDKeyDigest() {
|
||||
if conf.EnforcesIDKeyDigest() {
|
||||
cmd.PrintErrln("Your config asks for enforcing the idkeydigest. This is only available on Confidential VMs. It will not be enforced.")
|
||||
}
|
||||
}
|
||||
@ -89,20 +90,20 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler,
|
||||
cmd.PrintErrln("")
|
||||
}
|
||||
|
||||
provider := config.GetProvider()
|
||||
provider := conf.GetProvider()
|
||||
var instanceType string
|
||||
switch provider {
|
||||
case cloudprovider.AWS:
|
||||
instanceType = config.Provider.AWS.InstanceType
|
||||
instanceType = conf.Provider.AWS.InstanceType
|
||||
if len(flags.name) > 10 {
|
||||
return fmt.Errorf("cluster name on AWS must not be longer than 10 characters")
|
||||
}
|
||||
case cloudprovider.Azure:
|
||||
instanceType = config.Provider.Azure.InstanceType
|
||||
instanceType = conf.Provider.Azure.InstanceType
|
||||
case cloudprovider.GCP:
|
||||
instanceType = config.Provider.GCP.InstanceType
|
||||
instanceType = conf.Provider.GCP.InstanceType
|
||||
case cloudprovider.QEMU:
|
||||
cpus := config.Provider.QEMU.VCPUs
|
||||
cpus := conf.Provider.QEMU.VCPUs
|
||||
instanceType = fmt.Sprintf("%d-vCPU", cpus)
|
||||
}
|
||||
|
||||
@ -122,7 +123,7 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler,
|
||||
}
|
||||
|
||||
spinner.Start("Creating", false)
|
||||
idFile, err := creator.Create(cmd.Context(), provider, config, flags.name, instanceType, flags.controllerCount, flags.workerCount)
|
||||
idFile, err := creator.Create(cmd.Context(), provider, conf, flags.name, instanceType, flags.controllerCount, flags.workerCount)
|
||||
spinner.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -78,9 +78,9 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath)
|
||||
conf, err := config.New(fileHandler, flags.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading and validating config: %w", err)
|
||||
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
|
||||
}
|
||||
|
||||
var idFile clusterid.File
|
||||
@ -88,7 +88,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
|
||||
return fmt.Errorf("reading cluster ID file: %w", err)
|
||||
}
|
||||
|
||||
k8sVersion, err := versions.NewValidK8sVersion(config.KubernetesVersion)
|
||||
k8sVersion, err := versions.NewValidK8sVersion(conf.KubernetesVersion)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validating kubernetes version: %w", err)
|
||||
}
|
||||
@ -96,18 +96,18 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
|
||||
cmd.PrintErrf("Warning: Constellation with Kubernetes %v is still in preview. Use only for evaluation purposes.\n", k8sVersion)
|
||||
}
|
||||
|
||||
provider := config.GetProvider()
|
||||
provider := conf.GetProvider()
|
||||
checker := license.NewChecker(quotaChecker, fileHandler)
|
||||
if err := checker.CheckLicense(cmd.Context(), provider, config.Provider, cmd.Printf); err != nil {
|
||||
if err := checker.CheckLicense(cmd.Context(), provider, conf.Provider, cmd.Printf); err != nil {
|
||||
cmd.PrintErrf("License check failed: %v", err)
|
||||
}
|
||||
|
||||
validator, err := cloudcmd.NewValidator(provider, config)
|
||||
validator, err := cloudcmd.NewValidator(provider, conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serviceAccURI, err := getMarshaledServiceAccountURI(provider, config, fileHandler)
|
||||
serviceAccURI, err := getMarshaledServiceAccountURI(provider, conf, fileHandler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -117,7 +117,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
|
||||
return fmt.Errorf("parsing or generating master secret from file %s: %w", flags.masterSecretPath, err)
|
||||
}
|
||||
helmLoader := helm.New(provider, k8sVersion)
|
||||
helmDeployments, err := helmLoader.Load(provider, flags.conformance, masterSecret.Key, masterSecret.Salt, getEnforcedPCRs(provider, config), getEnforceIDKeyDigest(provider, config))
|
||||
helmDeployments, err := helmLoader.Load(provider, flags.conformance, masterSecret.Key, masterSecret.Salt, getEnforcedPCRs(provider, conf), getEnforceIDKeyDigest(provider, conf))
|
||||
if err != nil {
|
||||
return fmt.Errorf("loading Helm charts: %w", err)
|
||||
}
|
||||
@ -131,10 +131,10 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
|
||||
KeyEncryptionKeyId: "",
|
||||
UseExistingKek: false,
|
||||
CloudServiceAccountUri: serviceAccURI,
|
||||
KubernetesVersion: config.KubernetesVersion,
|
||||
KubernetesVersion: conf.KubernetesVersion,
|
||||
HelmDeployments: helmDeployments,
|
||||
EnforcedPcrs: getEnforcedPCRs(provider, config),
|
||||
EnforceIdkeydigest: getEnforceIDKeyDigest(provider, config),
|
||||
EnforcedPcrs: getEnforcedPCRs(provider, conf),
|
||||
EnforceIdkeydigest: getEnforceIDKeyDigest(provider, conf),
|
||||
ConformanceMode: flags.conformance,
|
||||
}
|
||||
resp, err := initCall(cmd.Context(), newDialer(validator), idFile.IP, req)
|
||||
|
@ -21,6 +21,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config"
|
||||
@ -360,11 +361,11 @@ func TestAttestation(t *testing.T) {
|
||||
|
||||
issuer := &testIssuer{
|
||||
Getter: oid.QEMU{},
|
||||
pcrs: map[uint32][]byte{
|
||||
0: []byte("ffffffffffffffffffffffffffffffff"),
|
||||
1: []byte("ffffffffffffffffffffffffffffffff"),
|
||||
2: []byte("ffffffffffffffffffffffffffffffff"),
|
||||
3: []byte("ffffffffffffffffffffffffffffffff"),
|
||||
pcrs: measurements.M{
|
||||
0: measurements.PCRWithAllBytes(0xFF),
|
||||
1: measurements.PCRWithAllBytes(0xFF),
|
||||
2: measurements.PCRWithAllBytes(0xFF),
|
||||
3: measurements.PCRWithAllBytes(0xFF),
|
||||
},
|
||||
}
|
||||
serverCreds := atlscredentials.New(issuer, nil)
|
||||
@ -389,13 +390,13 @@ func TestAttestation(t *testing.T) {
|
||||
cfg := config.Default()
|
||||
cfg.RemoveProviderExcept(cloudprovider.QEMU)
|
||||
cfg.Provider.QEMU.Image = "some/image/location"
|
||||
cfg.Provider.QEMU.Measurements[0] = []byte("00000000000000000000000000000000")
|
||||
cfg.Provider.QEMU.Measurements[1] = []byte("11111111111111111111111111111111")
|
||||
cfg.Provider.QEMU.Measurements[2] = []byte("22222222222222222222222222222222")
|
||||
cfg.Provider.QEMU.Measurements[3] = []byte("33333333333333333333333333333333")
|
||||
cfg.Provider.QEMU.Measurements[4] = []byte("44444444444444444444444444444444")
|
||||
cfg.Provider.QEMU.Measurements[8] = []byte("88888888888888888888888888888888")
|
||||
cfg.Provider.QEMU.Measurements[9] = []byte("99999999999999999999999999999999")
|
||||
cfg.Provider.QEMU.Measurements[0] = measurements.PCRWithAllBytes(0x00)
|
||||
cfg.Provider.QEMU.Measurements[1] = measurements.PCRWithAllBytes(0x11)
|
||||
cfg.Provider.QEMU.Measurements[2] = measurements.PCRWithAllBytes(0x22)
|
||||
cfg.Provider.QEMU.Measurements[3] = measurements.PCRWithAllBytes(0x33)
|
||||
cfg.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44)
|
||||
cfg.Provider.QEMU.Measurements[8] = measurements.PCRWithAllBytes(0x88)
|
||||
cfg.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x99)
|
||||
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone))
|
||||
|
||||
ctx := context.Background()
|
||||
@ -411,13 +412,13 @@ func TestAttestation(t *testing.T) {
|
||||
|
||||
type testValidator struct {
|
||||
oid.Getter
|
||||
pcrs map[uint32][]byte
|
||||
pcrs measurements.M
|
||||
}
|
||||
|
||||
func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
|
||||
var attestation struct {
|
||||
UserData []byte
|
||||
PCRs map[uint32][]byte
|
||||
PCRs measurements.M
|
||||
}
|
||||
if err := json.Unmarshal(attDoc, &attestation); err != nil {
|
||||
return nil, err
|
||||
@ -433,14 +434,14 @@ func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
|
||||
|
||||
type testIssuer struct {
|
||||
oid.Getter
|
||||
pcrs map[uint32][]byte
|
||||
pcrs measurements.M
|
||||
}
|
||||
|
||||
func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
|
||||
return json.Marshal(
|
||||
struct {
|
||||
UserData []byte
|
||||
PCRs map[uint32][]byte
|
||||
PCRs measurements.M
|
||||
}{
|
||||
UserData: userData,
|
||||
PCRs: i.pcrs,
|
||||
@ -472,23 +473,23 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
|
||||
conf.Provider.Azure.ResourceGroup = "test-resource-group"
|
||||
conf.Provider.Azure.AppClientID = "01234567-0123-0123-0123-0123456789ab"
|
||||
conf.Provider.Azure.ClientSecretValue = "test-client-secret"
|
||||
conf.Provider.Azure.Measurements[4] = []byte("44444444444444444444444444444444")
|
||||
conf.Provider.Azure.Measurements[8] = []byte("00000000000000000000000000000000")
|
||||
conf.Provider.Azure.Measurements[9] = []byte("11111111111111111111111111111111")
|
||||
conf.Provider.Azure.Measurements[4] = measurements.PCRWithAllBytes(0x44)
|
||||
conf.Provider.Azure.Measurements[8] = measurements.PCRWithAllBytes(0x00)
|
||||
conf.Provider.Azure.Measurements[9] = measurements.PCRWithAllBytes(0x11)
|
||||
case cloudprovider.GCP:
|
||||
conf.Provider.GCP.Region = "test-region"
|
||||
conf.Provider.GCP.Project = "test-project"
|
||||
conf.Provider.GCP.Image = "some/image/location"
|
||||
conf.Provider.GCP.Zone = "test-zone"
|
||||
conf.Provider.GCP.ServiceAccountKeyPath = "test-key-path"
|
||||
conf.Provider.GCP.Measurements[4] = []byte("44444444444444444444444444444444")
|
||||
conf.Provider.GCP.Measurements[8] = []byte("00000000000000000000000000000000")
|
||||
conf.Provider.GCP.Measurements[9] = []byte("11111111111111111111111111111111")
|
||||
conf.Provider.GCP.Measurements[4] = measurements.PCRWithAllBytes(0x44)
|
||||
conf.Provider.GCP.Measurements[8] = measurements.PCRWithAllBytes(0x00)
|
||||
conf.Provider.GCP.Measurements[9] = measurements.PCRWithAllBytes(0x11)
|
||||
case cloudprovider.QEMU:
|
||||
conf.Provider.QEMU.Image = "some/image/location"
|
||||
conf.Provider.QEMU.Measurements[4] = []byte("44444444444444444444444444444444")
|
||||
conf.Provider.QEMU.Measurements[8] = []byte("00000000000000000000000000000000")
|
||||
conf.Provider.QEMU.Measurements[9] = []byte("11111111111111111111111111111111")
|
||||
conf.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44)
|
||||
conf.Provider.QEMU.Measurements[8] = measurements.PCRWithAllBytes(0x00)
|
||||
conf.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x11)
|
||||
}
|
||||
|
||||
conf.RemoveProviderExcept(csp)
|
||||
|
@ -163,14 +163,14 @@ func prepareConfig(cmd *cobra.Command, fileHandler file.Handler) (*config.Config
|
||||
|
||||
// check for existing config
|
||||
if configPath != "" {
|
||||
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, configPath)
|
||||
conf, err := config.New(fileHandler, configPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, displayConfigValidationErrors(cmd.ErrOrStderr(), err)
|
||||
}
|
||||
if config.GetProvider() != cloudprovider.QEMU {
|
||||
if conf.GetProvider() != cloudprovider.QEMU {
|
||||
return nil, errors.New("invalid provider for MiniConstellation cluster")
|
||||
}
|
||||
return config, nil
|
||||
return conf, nil
|
||||
}
|
||||
if err := cmd.Flags().Set("config", constants.ConfigFilename); err != nil {
|
||||
return nil, err
|
||||
|
@ -1,46 +0,0 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/config"
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
)
|
||||
|
||||
func readConfig(errWriter io.Writer, fileHandler file.Handler, name string) (*config.Config, error) {
|
||||
cnf, err := config.FromFile(fileHandler, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateConfig(errWriter, cnf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return cnf, nil
|
||||
}
|
||||
|
||||
func validateConfig(errWriter io.Writer, cnf *config.Config) error {
|
||||
msgs, err := cnf.Validate()
|
||||
if err != nil {
|
||||
return fmt.Errorf("performing config validation: %w", err)
|
||||
}
|
||||
|
||||
if len(msgs) > 0 {
|
||||
fmt.Fprintln(errWriter, "Invalid fields in config file:")
|
||||
for _, m := range msgs {
|
||||
fmt.Fprintln(errWriter, "\t"+m)
|
||||
}
|
||||
fmt.Fprintln(errWriter, "Fix the invalid entries or generate a new configuration using `constellation config generate`")
|
||||
return errors.New("invalid configuration")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -1,126 +0,0 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateConfig(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
cnf *config.Config
|
||||
provider cloudprovider.Provider
|
||||
wantOutput bool
|
||||
wantErr bool
|
||||
}{
|
||||
"default config is not valid": {
|
||||
cnf: config.Default(),
|
||||
wantOutput: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"default Azure config is not valid": {
|
||||
cnf: func() *config.Config {
|
||||
cnf := config.Default()
|
||||
az := cnf.Provider.Azure
|
||||
cnf.Provider = config.ProviderConfig{}
|
||||
cnf.Provider.Azure = az
|
||||
return cnf
|
||||
}(),
|
||||
provider: cloudprovider.Azure,
|
||||
wantOutput: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"Azure config with all required fields is valid": {
|
||||
cnf: defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure),
|
||||
provider: cloudprovider.Azure,
|
||||
},
|
||||
"default GCP config is not valid": {
|
||||
cnf: func() *config.Config {
|
||||
cnf := config.Default()
|
||||
gcp := cnf.Provider.GCP
|
||||
cnf.Provider = config.ProviderConfig{}
|
||||
cnf.Provider.GCP = gcp
|
||||
return cnf
|
||||
}(),
|
||||
provider: cloudprovider.GCP,
|
||||
wantOutput: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"GCP config with all required fields is valid": {
|
||||
cnf: defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP),
|
||||
provider: cloudprovider.GCP,
|
||||
},
|
||||
"default QEMU config is not valid": {
|
||||
cnf: func() *config.Config {
|
||||
cnf := config.Default()
|
||||
qemu := cnf.Provider.QEMU
|
||||
cnf.Provider = config.ProviderConfig{}
|
||||
cnf.Provider.QEMU = qemu
|
||||
return cnf
|
||||
}(),
|
||||
provider: cloudprovider.QEMU,
|
||||
wantOutput: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"QEMU config with all required fields is valid": {
|
||||
cnf: defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.QEMU),
|
||||
provider: cloudprovider.QEMU,
|
||||
},
|
||||
"config with an error": {
|
||||
cnf: func() *config.Config {
|
||||
cnf := config.Default()
|
||||
cnf.Version = "v0"
|
||||
return cnf
|
||||
}(),
|
||||
wantOutput: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"config without provider is not ok": {
|
||||
cnf: func() *config.Config {
|
||||
cnf := config.Default()
|
||||
cnf.Provider = config.ProviderConfig{}
|
||||
return cnf
|
||||
}(),
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
"config without required provider": {
|
||||
cnf: func() *config.Config {
|
||||
cnf := config.Default()
|
||||
cnf.Provider.Azure = nil
|
||||
return cnf
|
||||
}(),
|
||||
provider: cloudprovider.Azure,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
out := &bytes.Buffer{}
|
||||
|
||||
err := validateConfig(out, tc.cnf)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
assert.Equal(tc.wantOutput, out.Len() > 0)
|
||||
})
|
||||
}
|
||||
}
|
@ -19,6 +19,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/edgelesssys/constellation/v2/internal/crypto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
@ -66,16 +67,16 @@ func recover(
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath)
|
||||
conf, err := config.New(fileHandler, flags.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading and validating config: %w", err)
|
||||
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
|
||||
}
|
||||
provider := config.GetProvider()
|
||||
provider := conf.GetProvider()
|
||||
if provider == cloudprovider.Azure {
|
||||
interval = 20 * time.Second // Azure LB takes a while to remove unhealthy instances
|
||||
}
|
||||
|
||||
validator, err := cloudcmd.NewValidator(provider, config)
|
||||
validator, err := cloudcmd.NewValidator(provider, conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config"
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
"github.com/spf13/afero"
|
||||
@ -43,17 +44,17 @@ func upgradeExecute(cmd *cobra.Command, upgrader cloudUpgrader, fileHandler file
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config, err := config.FromFile(fileHandler, configPath)
|
||||
conf, err := config.New(fileHandler, configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
|
||||
}
|
||||
|
||||
// TODO: validate upgrade config? Should be basic things like checking image is not an empty string
|
||||
// More sophisticated validation, like making sure we don't downgrade the cluster, should be done by `constellation upgrade plan`
|
||||
|
||||
return upgrader.Upgrade(cmd.Context(), config.Upgrade.Image, config.Upgrade.Measurements)
|
||||
return upgrader.Upgrade(cmd.Context(), conf.Upgrade.Image, conf.Upgrade.Measurements)
|
||||
}
|
||||
|
||||
type cloudUpgrader interface {
|
||||
Upgrade(ctx context.Context, image string, measurements map[uint32][]byte) error
|
||||
Upgrade(ctx context.Context, image string, measurements measurements.M) error
|
||||
}
|
||||
|
@ -11,6 +11,8 @@ import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
@ -41,7 +43,8 @@ func TestUpgradeExecute(t *testing.T) {
|
||||
cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually
|
||||
|
||||
handler := file.NewHandler(afero.NewMemMapFs())
|
||||
require.NoError(handler.WriteYAML(constants.ConfigFilename, config.Default()))
|
||||
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure)
|
||||
require.NoError(handler.WriteYAML(constants.ConfigFilename, cfg))
|
||||
|
||||
err := upgradeExecute(cmd, tc.upgrader, handler)
|
||||
if tc.wantErr {
|
||||
@ -57,6 +60,6 @@ type stubUpgrader struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (u stubUpgrader) Upgrade(context.Context, string, map[uint32][]byte) error {
|
||||
func (u stubUpgrader) Upgrade(context.Context, string, measurements.M) error {
|
||||
return u.err
|
||||
}
|
||||
|
@ -73,13 +73,13 @@ func runUpgradePlan(cmd *cobra.Command, args []string) error {
|
||||
func upgradePlan(cmd *cobra.Command, planner upgradePlanner,
|
||||
fileHandler file.Handler, client *http.Client, rekor rekorVerifier, flags upgradePlanFlags,
|
||||
) error {
|
||||
config, err := config.FromFile(fileHandler, flags.configPath)
|
||||
conf, err := config.New(fileHandler, flags.configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
|
||||
}
|
||||
|
||||
// get current image version of the cluster
|
||||
csp := config.GetProvider()
|
||||
csp := conf.GetProvider()
|
||||
|
||||
version, err := getCurrentImageVersion(cmd.Context(), planner, csp)
|
||||
if err != nil {
|
||||
@ -108,7 +108,7 @@ func upgradePlan(cmd *cobra.Command, planner upgradePlanner,
|
||||
return upgradePlanInteractive(
|
||||
&nopWriteCloser{cmd.OutOrStdout()},
|
||||
io.NopCloser(cmd.InOrStdin()),
|
||||
flags.configPath, config, fileHandler,
|
||||
flags.configPath, conf, fileHandler,
|
||||
compatibleImages,
|
||||
)
|
||||
}
|
||||
|
@ -451,8 +451,8 @@ func TestUpgradePlan(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
fileHandler := file.NewHandler(afero.NewMemMapFs())
|
||||
cfg := config.Default()
|
||||
cfg.RemoveProviderExcept(tc.csp)
|
||||
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.csp)
|
||||
|
||||
require.NoError(fileHandler.WriteYAML(tc.flags.configPath, cfg))
|
||||
|
||||
cmd := newUpgradePlanCmd()
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
|
||||
"github.com/edgelesssys/constellation/v2/internal/atls"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/edgelesssys/constellation/v2/internal/crypto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
@ -59,13 +60,13 @@ func verify(cmd *cobra.Command, fileHandler file.Handler, verifyClient verifyCli
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath)
|
||||
conf, err := config.New(fileHandler, flags.configPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading and validating config: %w", err)
|
||||
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
|
||||
}
|
||||
|
||||
provider := config.GetProvider()
|
||||
validators, err := cloudcmd.NewValidator(provider, config)
|
||||
provider := conf.GetProvider()
|
||||
validators, err := cloudcmd.NewValidator(provider, conf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -67,6 +67,12 @@ If you don't have a cloud subscription, check out [MiniConstellation](first-step
|
||||
|
||||
Fill the values produced by the script into your configuration file.
|
||||
|
||||
:::tip
|
||||
|
||||
Alternatively, you can leave `clientSecretValue` empty and provide the secret via the `CONSTELL_AZURE_CLIENT_SECRET_VALUE` environment variable.
|
||||
|
||||
:::
|
||||
|
||||
By default, Constellation uses `Standard_DC4as_v5` CVMs (4 vCPUs, 16 GB RAM) to create your cluster. Optionally, you can switch to a different VM type by modifying **instanceType** in the configuration file. For CVMs, any VM type with a minimum of 4 vCPUs from the [DCasv5 & DCadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/dcasv5-dcadsv5-series) or [ECasv5 & ECadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/ecasv5-ecadsv5-series) families is supported.
|
||||
|
||||
Run `constellation config instance-types` to get the list of all supported options.
|
||||
@ -112,6 +118,12 @@ If you don't have a cloud subscription, check out [MiniConstellation](first-step
|
||||
|
||||
Set the configuration value to the secret value.
|
||||
|
||||
:::tip
|
||||
|
||||
Alternatively, you can leave `clientSecretValue` empty and provide the secret via the `CONSTELL_AZURE_CLIENT_SECRET_VALUE` environment variable.
|
||||
|
||||
:::
|
||||
|
||||
* **instanceType**: The VM type you want to use for your Constellation nodes.
|
||||
|
||||
For CVMs, any type with a minimum of 4 vCPUs from the [DCasv5 & DCadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/dcasv5-dcadsv5-series) or [ECasv5 & ECadsv5](https://docs.microsoft.com/en-us/azure/virtual-machines/ecasv5-ecadsv5-series) families is supported. It defaults to `Standard_DC4as_v5` (4 vCPUs, 16 GB RAM).
|
||||
|
2
go.mod
2
go.mod
@ -66,7 +66,6 @@ require (
|
||||
github.com/google/tink/go v1.7.0
|
||||
github.com/googleapis/gax-go/v2 v2.7.0
|
||||
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/go-version v1.6.0
|
||||
github.com/hashicorp/hc-install v0.4.0
|
||||
github.com/hashicorp/terraform-exec v0.17.3
|
||||
@ -117,6 +116,7 @@ require (
|
||||
github.com/golang-jwt/jwt/v4 v4.4.2 // indirect
|
||||
github.com/google/logger v1.1.1 // indirect
|
||||
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/hashicorp/go-retryablehttp v0.7.1 // indirect
|
||||
github.com/rogpeppe/go-internal v1.8.1 // indirect
|
||||
golang.org/x/text v0.4.0 // indirect
|
||||
|
@ -16,7 +16,7 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/hack/image-measurement/server"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"go.uber.org/multierr"
|
||||
"go.uber.org/zap"
|
||||
@ -282,7 +282,7 @@ func (l *libvirtInstance) deleteLibvirtInstance() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *libvirtInstance) obtainMeasurements() (measurements config.Measurements, err error) {
|
||||
func (l *libvirtInstance) obtainMeasurements() (measurements measurements.M, err error) {
|
||||
// sanity check
|
||||
if err := l.deleteLibvirtInstance(); err != nil {
|
||||
return nil, err
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@ -20,7 +21,7 @@ import (
|
||||
type Server struct {
|
||||
log *logger.Logger
|
||||
server http.Server
|
||||
measurements map[uint32][]byte
|
||||
measurements measurements.M
|
||||
done chan<- struct{}
|
||||
}
|
||||
|
||||
@ -72,7 +73,7 @@ func (s *Server) logPCRs(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// unmarshal the request body into a map of PCRs
|
||||
var pcrs map[uint32][]byte
|
||||
var pcrs measurements.M
|
||||
if err := json.NewDecoder(r.Body).Decode(&pcrs); err != nil {
|
||||
log.With(zap.Error(err)).Errorf("Failed to read request body")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
@ -89,6 +90,6 @@ func (s *Server) logPCRs(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// GetMeasurements returns the static measurements for QEMU environment.
|
||||
func (s *Server) GetMeasurements() map[uint32][]byte {
|
||||
func (s *Server) GetMeasurements() measurements.M {
|
||||
return s.measurements
|
||||
}
|
||||
|
@ -8,7 +8,6 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
@ -20,6 +19,7 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/edgelesssys/constellation/v2/internal/crypto"
|
||||
@ -68,23 +68,6 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
// Measurements contains all PCR values.
|
||||
type Measurements map[uint32][]byte
|
||||
|
||||
var _ yaml.Marshaler = Measurements{}
|
||||
|
||||
// MarshalYAML forces that measurements are written as base64. Default would
|
||||
// be to print list of bytes.
|
||||
func (m Measurements) MarshalYAML() (any, error) {
|
||||
base64Map := make(map[uint32]string)
|
||||
|
||||
for key, value := range m {
|
||||
base64Map[key] = base64.StdEncoding.EncodeToString(value[:])
|
||||
}
|
||||
|
||||
return base64Map, nil
|
||||
}
|
||||
|
||||
// getAttestation connects to the Constellation verification service and returns its attestation document.
|
||||
func getAttestation(ctx context.Context, addr string) ([]byte, error) {
|
||||
conn, err := grpc.DialContext(
|
||||
@ -109,7 +92,7 @@ func getAttestation(ctx context.Context, addr string) ([]byte, error) {
|
||||
}
|
||||
|
||||
// validatePCRAttDoc parses and validates PCRs of an attestation document.
|
||||
func validatePCRAttDoc(attDocRaw []byte) (map[uint32][]byte, error) {
|
||||
func validatePCRAttDoc(attDocRaw []byte) (measurements.M, error) {
|
||||
attDoc := vtpm.AttestationDocument{}
|
||||
if err := json.Unmarshal(attDocRaw, &attDoc); err != nil {
|
||||
return nil, err
|
||||
@ -131,7 +114,7 @@ func validatePCRAttDoc(attDocRaw []byte) (map[uint32][]byte, error) {
|
||||
|
||||
// printPCRs formates and prints PCRs to the given writer.
|
||||
// format can be one of 'json' or 'yaml'. If it doesnt match defaults to 'json'.
|
||||
func printPCRs(w io.Writer, pcrs map[uint32][]byte, format string) error {
|
||||
func printPCRs(w io.Writer, pcrs measurements.M, format string) error {
|
||||
switch format {
|
||||
case "json":
|
||||
return printPCRsJSON(w, pcrs)
|
||||
@ -142,7 +125,7 @@ func printPCRs(w io.Writer, pcrs map[uint32][]byte, format string) error {
|
||||
}
|
||||
}
|
||||
|
||||
func printPCRsYAML(w io.Writer, pcrs Measurements) error {
|
||||
func printPCRsYAML(w io.Writer, pcrs measurements.M) error {
|
||||
pcrYAML, err := yaml.Marshal(pcrs)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -151,7 +134,7 @@ func printPCRsYAML(w io.Writer, pcrs Measurements) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func printPCRsJSON(w io.Writer, pcrs map[uint32][]byte) error {
|
||||
func printPCRsJSON(w io.Writer, pcrs measurements.M) error {
|
||||
pcrJSON, err := json.MarshalIndent(pcrs, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
@ -162,7 +145,7 @@ func printPCRsJSON(w io.Writer, pcrs map[uint32][]byte) error {
|
||||
|
||||
// exportToFile writes pcrs to a file, formatted to be valid Go code.
|
||||
// Validity of the PCR map is not checked, and should be handled by the caller.
|
||||
func exportToFile(path string, pcrs map[uint32][]byte, fs *afero.Afero) error {
|
||||
func exportToFile(path string, pcrs measurements.M, fs *afero.Afero) error {
|
||||
goCode := `package pcrs
|
||||
|
||||
var pcrs = map[uint32][]byte{%s
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||
"github.com/google/go-tpm-tools/proto/attest"
|
||||
"github.com/google/go-tpm-tools/proto/tpm"
|
||||
@ -31,12 +32,12 @@ func TestMain(m *testing.M) {
|
||||
|
||||
func TestExportToFile(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
pcrs map[uint32][]byte
|
||||
pcrs measurements.M
|
||||
fs *afero.Afero
|
||||
wantErr bool
|
||||
}{
|
||||
"file not writeable": {
|
||||
pcrs: map[uint32][]byte{
|
||||
pcrs: measurements.M{
|
||||
0: {0x1, 0x2, 0x3},
|
||||
1: {0x1, 0x2, 0x3},
|
||||
2: {0x1, 0x2, 0x3},
|
||||
@ -45,7 +46,7 @@ func TestExportToFile(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
"file writeable": {
|
||||
pcrs: map[uint32][]byte{
|
||||
pcrs: measurements.M{
|
||||
0: {0x1, 0x2, 0x3},
|
||||
1: {0x1, 0x2, 0x3},
|
||||
2: {0x1, 0x2, 0x3},
|
||||
@ -105,7 +106,7 @@ func TestValidatePCRAttDoc(t *testing.T) {
|
||||
{
|
||||
Pcrs: &tpm.PCRs{
|
||||
Hash: tpm.HashAlgo_SHA256,
|
||||
Pcrs: map[uint32][]byte{
|
||||
Pcrs: measurements.M{
|
||||
0: {0x1, 0x2, 0x3},
|
||||
},
|
||||
},
|
||||
@ -122,8 +123,8 @@ func TestValidatePCRAttDoc(t *testing.T) {
|
||||
{
|
||||
Pcrs: &tpm.PCRs{
|
||||
Hash: tpm.HashAlgo_SHA256,
|
||||
Pcrs: map[uint32][]byte{
|
||||
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
|
||||
Pcrs: measurements.M{
|
||||
0: measurements.PCRWithAllBytes(0xAA),
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -163,11 +164,11 @@ func mustMarshalAttDoc(t *testing.T, attDoc vtpm.AttestationDocument) []byte {
|
||||
|
||||
func TestPrintPCRs(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
pcrs map[uint32][]byte
|
||||
pcrs measurements.M
|
||||
format string
|
||||
}{
|
||||
"json": {
|
||||
pcrs: map[uint32][]byte{
|
||||
pcrs: measurements.M{
|
||||
0: {0x1, 0x2, 0x3},
|
||||
1: {0x1, 0x2, 0x3},
|
||||
2: {0x1, 0x2, 0x3},
|
||||
@ -175,7 +176,7 @@ func TestPrintPCRs(t *testing.T) {
|
||||
format: "json",
|
||||
},
|
||||
"empty format": {
|
||||
pcrs: map[uint32][]byte{
|
||||
pcrs: measurements.M{
|
||||
0: {0x1, 0x2, 0x3},
|
||||
1: {0x1, 0x2, 0x3},
|
||||
2: {0x1, 0x2, 0x3},
|
||||
@ -183,7 +184,7 @@ func TestPrintPCRs(t *testing.T) {
|
||||
format: "",
|
||||
},
|
||||
"yaml": {
|
||||
pcrs: map[uint32][]byte{
|
||||
pcrs: measurements.M{
|
||||
0: {0x1, 0x2, 0x3},
|
||||
1: {0x1, 0x2, 0x3},
|
||||
2: {0x1, 0x2, 0x3},
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/hack/qemu-metadata-api/virtwrapper"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/edgelesssys/constellation/v2/internal/role"
|
||||
@ -197,7 +198,7 @@ func (s *Server) exportPCRs(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// unmarshal the request body into a map of PCRs
|
||||
var pcrs map[uint32][]byte
|
||||
var pcrs measurements.M
|
||||
if err := json.NewDecoder(r.Body).Decode(&pcrs); err != nil {
|
||||
log.With(zap.Error(err)).Errorf("Failed to read request body")
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/hack/qemu-metadata-api/virtwrapper"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -306,12 +307,12 @@ func TestExportPCRs(t *testing.T) {
|
||||
remoteAddr: "192.0.100.1:1234",
|
||||
connect: defaultConnect,
|
||||
method: http.MethodPost,
|
||||
message: mustMarshal(t, map[uint32][]byte{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
|
||||
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
|
||||
},
|
||||
"incorrect method": {
|
||||
remoteAddr: "192.0.100.1:1234",
|
||||
connect: defaultConnect,
|
||||
message: mustMarshal(t, map[uint32][]byte{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
|
||||
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
|
||||
method: http.MethodGet,
|
||||
wantErr: true,
|
||||
},
|
||||
@ -320,7 +321,7 @@ func TestExportPCRs(t *testing.T) {
|
||||
connect: &stubConnect{
|
||||
getNetworkErr: errors.New("error"),
|
||||
},
|
||||
message: mustMarshal(t, map[uint32][]byte{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
|
||||
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
|
||||
method: http.MethodPost,
|
||||
wantErr: true,
|
||||
},
|
||||
@ -335,7 +336,7 @@ func TestExportPCRs(t *testing.T) {
|
||||
remoteAddr: "localhost",
|
||||
connect: defaultConnect,
|
||||
method: http.MethodPost,
|
||||
message: mustMarshal(t, map[uint32][]byte{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
|
||||
message: mustMarshal(t, measurements.M{0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")}),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
@ -12,9 +12,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/config"
|
||||
awsConfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ec2"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/v2/internal/oid"
|
||||
"github.com/google/go-tpm/tpm2"
|
||||
@ -28,7 +29,7 @@ type Validator struct {
|
||||
}
|
||||
|
||||
// NewValidator create a new Validator structure and returns it.
|
||||
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
|
||||
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
|
||||
v := &Validator{}
|
||||
v.Validator = vtpm.NewValidator(
|
||||
pcrs,
|
||||
@ -88,7 +89,7 @@ func (v *Validator) tpmEnabled(attestation vtpm.AttestationDocument) error {
|
||||
}
|
||||
|
||||
func getEC2Client(ctx context.Context, region string) (awsMetadataAPI, error) {
|
||||
client, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
|
||||
client, err := awsConfig.LoadDefaultConfig(ctx, awsConfig.WithRegion(region))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -19,6 +19,7 @@ import (
|
||||
"fmt"
|
||||
"math/big"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||
internalCrypto "github.com/edgelesssys/constellation/v2/internal/crypto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/oid"
|
||||
@ -41,7 +42,7 @@ type Validator struct {
|
||||
}
|
||||
|
||||
// NewValidator initializes a new Azure validator with the provided PCR values.
|
||||
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, idKeyDigest []byte, enforceIDKeyDigest bool, log vtpm.AttestationLogger) *Validator {
|
||||
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, idKeyDigest []byte, enforceIDKeyDigest bool, log vtpm.AttestationLogger) *Validator {
|
||||
return &Validator{
|
||||
Validator: vtpm.NewValidator(
|
||||
pcrs,
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
|
||||
"github.com/edgelesssys/constellation/v2/internal/crypto"
|
||||
tpmclient "github.com/google/go-tpm-tools/client"
|
||||
@ -188,7 +189,7 @@ func TestGetAttestationCert(t *testing.T) {
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
validator := NewValidator(map[uint32][]byte{}, []uint32{}, nil)
|
||||
validator := NewValidator(measurements.M{}, []uint32{}, nil)
|
||||
cert, err := x509.ParseCertificate(rootCert.Raw)
|
||||
require.NoError(err)
|
||||
roots := x509.NewCertPool()
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||
certutil "github.com/edgelesssys/constellation/v2/internal/crypto"
|
||||
"github.com/edgelesssys/constellation/v2/internal/oid"
|
||||
@ -32,7 +33,7 @@ type Validator struct {
|
||||
}
|
||||
|
||||
// NewValidator initializes a new Azure validator with the provided PCR values.
|
||||
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
|
||||
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
|
||||
rootPool := x509.NewCertPool()
|
||||
rootPool.AddCert(ameRoot)
|
||||
v := &Validator{roots: rootPool}
|
||||
|
@ -18,6 +18,7 @@ import (
|
||||
"time"
|
||||
|
||||
compute "cloud.google.com/go/compute/apiv1"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/v2/internal/oid"
|
||||
"github.com/google/go-tpm-tools/proto/attest"
|
||||
@ -34,7 +35,7 @@ type Validator struct {
|
||||
}
|
||||
|
||||
// NewValidator initializes a new GCP validator with the provided PCR values.
|
||||
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
|
||||
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
|
||||
return &Validator{
|
||||
Validator: vtpm.NewValidator(
|
||||
pcrs,
|
||||
|
@ -4,7 +4,7 @@ Copyright (c) Edgeless Systems GmbH
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package config
|
||||
package measurements
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@ -18,52 +18,60 @@ import (
|
||||
"net/url"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/v2/internal/sigstore"
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
// Measurements are Platform Configuration Register (PCR) values.
|
||||
type Measurements map[uint32][]byte
|
||||
// M are Platform Configuration Register (PCR) values that make up the Measurements.
|
||||
type M map[uint32][]byte
|
||||
|
||||
var (
|
||||
zero = []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}
|
||||
// gcpPCRs are the PCR values for a GCP Constellation node that are initially set in a generated config file.
|
||||
gcpPCRs = Measurements{
|
||||
0: {0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF},
|
||||
11: zero,
|
||||
12: zero,
|
||||
13: zero,
|
||||
uint32(vtpm.PCRIndexClusterID): zero,
|
||||
}
|
||||
// PCRWithAllBytes returns a PCR value where all 32 bytes are set to b.
|
||||
func PCRWithAllBytes(b byte) []byte {
|
||||
return bytes.Repeat([]byte{b}, 32)
|
||||
}
|
||||
|
||||
// azurePCRs are the PCR values for an Azure Constellation node that are initially set in a generated config file.
|
||||
azurePCRs = Measurements{
|
||||
11: zero,
|
||||
12: zero,
|
||||
13: zero,
|
||||
uint32(vtpm.PCRIndexClusterID): zero,
|
||||
// DefaultsFor provides the default measurements for given cloud provider.
|
||||
func DefaultsFor(provider cloudprovider.Provider) M {
|
||||
switch provider {
|
||||
case cloudprovider.AWS:
|
||||
return M{
|
||||
11: PCRWithAllBytes(0x00),
|
||||
12: PCRWithAllBytes(0x00),
|
||||
13: PCRWithAllBytes(0x00),
|
||||
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
|
||||
}
|
||||
case cloudprovider.Azure:
|
||||
return M{
|
||||
11: PCRWithAllBytes(0x00),
|
||||
12: PCRWithAllBytes(0x00),
|
||||
13: PCRWithAllBytes(0x00),
|
||||
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
|
||||
}
|
||||
case cloudprovider.GCP:
|
||||
return M{
|
||||
0: {0x0F, 0x35, 0xC2, 0x14, 0x60, 0x8D, 0x93, 0xC7, 0xA6, 0xE6, 0x8A, 0xE7, 0x35, 0x9B, 0x4A, 0x8B, 0xE5, 0xA0, 0xE9, 0x9E, 0xEA, 0x91, 0x07, 0xEC, 0xE4, 0x27, 0xC4, 0xDE, 0xA4, 0xE4, 0x39, 0xCF},
|
||||
11: PCRWithAllBytes(0x00),
|
||||
12: PCRWithAllBytes(0x00),
|
||||
13: PCRWithAllBytes(0x00),
|
||||
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
|
||||
}
|
||||
case cloudprovider.QEMU:
|
||||
return M{
|
||||
11: PCRWithAllBytes(0x00),
|
||||
12: PCRWithAllBytes(0x00),
|
||||
13: PCRWithAllBytes(0x00),
|
||||
uint32(vtpm.PCRIndexClusterID): PCRWithAllBytes(0x00),
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
// awsPCRs are the PCR values for an AWS Nitro Constellation node that are initially set in a generated config file.
|
||||
awsPCRs = Measurements{
|
||||
11: zero,
|
||||
12: zero,
|
||||
13: zero,
|
||||
uint32(vtpm.PCRIndexClusterID): zero,
|
||||
}
|
||||
|
||||
qemuPCRs = Measurements{
|
||||
11: zero,
|
||||
12: zero,
|
||||
13: zero,
|
||||
uint32(vtpm.PCRIndexClusterID): zero,
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// FetchAndVerify fetches measurement and signature files via provided URLs,
|
||||
// using client for download. The publicKey is used to verify the measurements.
|
||||
// The hash of the fetched measurements is returned.
|
||||
func (m *Measurements) FetchAndVerify(ctx context.Context, client *http.Client, measurementsURL *url.URL, signatureURL *url.URL, publicKey []byte) (string, error) {
|
||||
func (m *M) FetchAndVerify(ctx context.Context, client *http.Client, measurementsURL *url.URL, signatureURL *url.URL, publicKey []byte) (string, error) {
|
||||
measurements, err := getFromURL(ctx, client, measurementsURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch measurements: %w", err)
|
||||
@ -86,15 +94,29 @@ func (m *Measurements) FetchAndVerify(ctx context.Context, client *http.Client,
|
||||
|
||||
// CopyFrom copies over all values from other. Overwriting existing values,
|
||||
// but keeping not specified values untouched.
|
||||
func (m Measurements) CopyFrom(other Measurements) {
|
||||
func (m M) CopyFrom(other M) {
|
||||
for idx := range other {
|
||||
m[idx] = other[idx]
|
||||
}
|
||||
}
|
||||
|
||||
// EqualTo tests whether the provided other Measurements are equal to these
|
||||
// measurements.
|
||||
func (m M) EqualTo(other M) bool {
|
||||
if len(m) != len(other) {
|
||||
return false
|
||||
}
|
||||
for k, v := range m {
|
||||
if !bytes.Equal(v, other[k]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// MarshalYAML overwrites the default behaviour of writing out []byte not as
|
||||
// single bytes, but as a single base64 encoded string.
|
||||
func (m Measurements) MarshalYAML() (any, error) {
|
||||
func (m M) MarshalYAML() (any, error) {
|
||||
base64Map := make(map[uint32]string)
|
||||
|
||||
for key, value := range m {
|
||||
@ -106,14 +128,14 @@ func (m Measurements) MarshalYAML() (any, error) {
|
||||
|
||||
// UnmarshalYAML overwrites the default behaviour of reading []byte not as
|
||||
// single bytes, but as a single base64 encoded string.
|
||||
func (m *Measurements) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
func (m *M) UnmarshalYAML(unmarshal func(any) error) error {
|
||||
base64Map := make(map[uint32]string)
|
||||
err := unmarshal(base64Map)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*m = make(Measurements)
|
||||
*m = make(M)
|
||||
for key, value := range base64Map {
|
||||
measurement, err := base64.StdEncoding.DecodeString(value)
|
||||
if err != nil {
|
@ -4,7 +4,7 @@ Copyright (c) Edgeless Systems GmbH
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package config
|
||||
package measurements
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -21,11 +21,11 @@ import (
|
||||
|
||||
func TestMarshalYAML(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
measurements Measurements
|
||||
measurements M
|
||||
wantBase64Map map[uint32]string
|
||||
}{
|
||||
"valid measurements": {
|
||||
measurements: Measurements{
|
||||
measurements: M{
|
||||
2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
|
||||
3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
|
||||
},
|
||||
@ -35,7 +35,7 @@ func TestMarshalYAML(t *testing.T) {
|
||||
},
|
||||
},
|
||||
"omit bytes": {
|
||||
measurements: Measurements{
|
||||
measurements: M{
|
||||
2: []byte{},
|
||||
3: []byte{1, 2, 3, 4},
|
||||
},
|
||||
@ -63,7 +63,7 @@ func TestUnmarshalYAML(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
inputBase64Map map[uint32]string
|
||||
forceUnmarshalError bool
|
||||
wantMeasurements Measurements
|
||||
wantMeasurements M
|
||||
wantErr bool
|
||||
}{
|
||||
"valid measurements": {
|
||||
@ -71,7 +71,7 @@ func TestUnmarshalYAML(t *testing.T) {
|
||||
2: "/V3p3zUOO8RBCsBrv+XM3rk/U7nvUSOfdSzmnbxgDzU=",
|
||||
3: "1aRJbSHeyaUljdsZxv61O7TTwEY/5gfySI3fTxAG754=",
|
||||
},
|
||||
wantMeasurements: Measurements{
|
||||
wantMeasurements: M{
|
||||
2: []byte{253, 93, 233, 223, 53, 14, 59, 196, 65, 10, 192, 107, 191, 229, 204, 222, 185, 63, 83, 185, 239, 81, 35, 159, 117, 44, 230, 157, 188, 96, 15, 53},
|
||||
3: []byte{213, 164, 73, 109, 33, 222, 201, 165, 37, 141, 219, 25, 198, 254, 181, 59, 180, 211, 192, 70, 63, 230, 7, 242, 72, 141, 223, 79, 16, 6, 239, 158},
|
||||
},
|
||||
@ -81,7 +81,7 @@ func TestUnmarshalYAML(t *testing.T) {
|
||||
2: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
|
||||
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
|
||||
},
|
||||
wantMeasurements: Measurements{
|
||||
wantMeasurements: M{
|
||||
2: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
3: []byte{1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
},
|
||||
@ -91,7 +91,7 @@ func TestUnmarshalYAML(t *testing.T) {
|
||||
2: "This is not base64",
|
||||
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
|
||||
},
|
||||
wantMeasurements: Measurements{
|
||||
wantMeasurements: M{
|
||||
2: []byte{},
|
||||
3: []byte{1, 2, 3, 4},
|
||||
},
|
||||
@ -103,7 +103,7 @@ func TestUnmarshalYAML(t *testing.T) {
|
||||
3: "AQIDBAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=",
|
||||
},
|
||||
forceUnmarshalError: true,
|
||||
wantMeasurements: Measurements{
|
||||
wantMeasurements: M{
|
||||
2: []byte{},
|
||||
3: []byte{1, 2, 3, 4},
|
||||
},
|
||||
@ -116,7 +116,7 @@ func TestUnmarshalYAML(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
var m Measurements
|
||||
var m M
|
||||
err := m.UnmarshalYAML(func(i any) error {
|
||||
if base64Map, ok := i.(map[uint32]string); ok {
|
||||
for key, value := range tc.inputBase64Map {
|
||||
@ -141,55 +141,55 @@ func TestUnmarshalYAML(t *testing.T) {
|
||||
|
||||
func TestMeasurementsCopyFrom(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
current Measurements
|
||||
newMeasurements Measurements
|
||||
wantMeasurements Measurements
|
||||
current M
|
||||
newMeasurements M
|
||||
wantMeasurements M
|
||||
}{
|
||||
"add to empty": {
|
||||
current: Measurements{},
|
||||
newMeasurements: Measurements{
|
||||
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
|
||||
current: M{},
|
||||
newMeasurements: M{
|
||||
1: PCRWithAllBytes(0x00),
|
||||
2: PCRWithAllBytes(0x01),
|
||||
3: PCRWithAllBytes(0x02),
|
||||
},
|
||||
wantMeasurements: Measurements{
|
||||
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
|
||||
wantMeasurements: M{
|
||||
1: PCRWithAllBytes(0x00),
|
||||
2: PCRWithAllBytes(0x01),
|
||||
3: PCRWithAllBytes(0x02),
|
||||
},
|
||||
},
|
||||
"keep existing": {
|
||||
current: Measurements{
|
||||
4: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
5: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
|
||||
current: M{
|
||||
4: PCRWithAllBytes(0x01),
|
||||
5: PCRWithAllBytes(0x02),
|
||||
},
|
||||
newMeasurements: Measurements{
|
||||
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
|
||||
newMeasurements: M{
|
||||
1: PCRWithAllBytes(0x00),
|
||||
2: PCRWithAllBytes(0x01),
|
||||
3: PCRWithAllBytes(0x02),
|
||||
},
|
||||
wantMeasurements: Measurements{
|
||||
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
|
||||
4: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
5: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
|
||||
wantMeasurements: M{
|
||||
1: PCRWithAllBytes(0x00),
|
||||
2: PCRWithAllBytes(0x01),
|
||||
3: PCRWithAllBytes(0x02),
|
||||
4: PCRWithAllBytes(0x01),
|
||||
5: PCRWithAllBytes(0x02),
|
||||
},
|
||||
},
|
||||
"overwrite existing": {
|
||||
current: Measurements{
|
||||
2: []byte{4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4},
|
||||
3: []byte{5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5},
|
||||
current: M{
|
||||
2: PCRWithAllBytes(0x04),
|
||||
3: PCRWithAllBytes(0x05),
|
||||
},
|
||||
newMeasurements: Measurements{
|
||||
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
|
||||
newMeasurements: M{
|
||||
1: PCRWithAllBytes(0x00),
|
||||
2: PCRWithAllBytes(0x01),
|
||||
3: PCRWithAllBytes(0x02),
|
||||
},
|
||||
wantMeasurements: Measurements{
|
||||
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
|
||||
wantMeasurements: M{
|
||||
1: PCRWithAllBytes(0x00),
|
||||
2: PCRWithAllBytes(0x01),
|
||||
3: PCRWithAllBytes(0x02),
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -230,7 +230,7 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
|
||||
signature string
|
||||
signatureStatus int
|
||||
publicKey []byte
|
||||
wantMeasurements Measurements
|
||||
wantMeasurements M
|
||||
wantSHA string
|
||||
wantError bool
|
||||
}{
|
||||
@ -240,8 +240,8 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
|
||||
signature: "MEUCIBs1g2/n0FsgPfJ+0uLD5TaunGhxwDcQcUGBroejKvg3AiEAzZtcLU9O6IiVhxB8tBS+ty6MXoPNwL8WRWMzyr35eKI=",
|
||||
signatureStatus: http.StatusOK,
|
||||
publicKey: []byte("-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEUs5fDUIz9aiwrfr8BK4VjN7jE6sl\ngz7UuXsOin8+dB0SGrbNHy7TJToa2fAiIKPVLTOfvY75DqRAtffhO1fpBA==\n-----END PUBLIC KEY-----"),
|
||||
wantMeasurements: Measurements{
|
||||
0: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
wantMeasurements: M{
|
||||
0: PCRWithAllBytes(0x00),
|
||||
},
|
||||
wantSHA: "4cd9d6ed8d9322150dff7738994c5e2fabff35f3bae6f5c993412d13249a5e87",
|
||||
},
|
||||
@ -308,7 +308,7 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
m := Measurements{}
|
||||
m := M{}
|
||||
hash, err := m.FetchAndVerify(context.Background(), client, measurementsURL, signatureURL, tc.publicKey)
|
||||
|
||||
if tc.wantError {
|
||||
@ -321,3 +321,83 @@ func TestMeasurementsFetchAndVerify(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPCRWithAllBytes(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
b byte
|
||||
wantPCR []byte
|
||||
}{
|
||||
"0x00": {
|
||||
b: 0x00,
|
||||
wantPCR: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
|
||||
},
|
||||
"0x01": {
|
||||
b: 0x01,
|
||||
wantPCR: []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01},
|
||||
},
|
||||
"0xFF": {
|
||||
b: 0xFF,
|
||||
wantPCR: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
pcr := PCRWithAllBytes(tc.b)
|
||||
assert.Equal(tc.wantPCR, pcr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEqualTo(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
given M
|
||||
other M
|
||||
wantEqual bool
|
||||
}{
|
||||
"same values": {
|
||||
given: M{
|
||||
0: PCRWithAllBytes(0x00),
|
||||
1: PCRWithAllBytes(0xFF),
|
||||
},
|
||||
other: M{
|
||||
0: PCRWithAllBytes(0x00),
|
||||
1: PCRWithAllBytes(0xFF),
|
||||
},
|
||||
wantEqual: true,
|
||||
},
|
||||
"different number of elements": {
|
||||
given: M{
|
||||
0: PCRWithAllBytes(0x00),
|
||||
1: PCRWithAllBytes(0xFF),
|
||||
},
|
||||
other: M{
|
||||
0: PCRWithAllBytes(0x00),
|
||||
},
|
||||
wantEqual: false,
|
||||
},
|
||||
"different values": {
|
||||
given: M{
|
||||
0: PCRWithAllBytes(0x00),
|
||||
1: PCRWithAllBytes(0xFF),
|
||||
},
|
||||
other: M{
|
||||
0: PCRWithAllBytes(0xFF),
|
||||
1: PCRWithAllBytes(0x00),
|
||||
},
|
||||
wantEqual: false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
if tc.wantEqual {
|
||||
assert.True(tc.given.EqualTo(tc.other))
|
||||
} else {
|
||||
assert.False(tc.given.EqualTo(tc.other))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -9,6 +9,7 @@ package qemu
|
||||
import (
|
||||
"crypto"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/v2/internal/oid"
|
||||
"github.com/google/go-tpm/tpm2"
|
||||
@ -21,7 +22,7 @@ type Validator struct {
|
||||
}
|
||||
|
||||
// NewValidator initializes a new QEMU validator with the provided PCR values.
|
||||
func NewValidator(pcrs map[uint32][]byte, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
|
||||
func NewValidator(pcrs measurements.M, enforcedPCRs []uint32, log vtpm.AttestationLogger) *Validator {
|
||||
return &Validator{
|
||||
Validator: vtpm.NewValidator(
|
||||
pcrs,
|
||||
|
@ -14,19 +14,25 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config/instancetypes"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/edgelesssys/constellation/v2/internal/file"
|
||||
"github.com/edgelesssys/constellation/v2/internal/versions"
|
||||
"github.com/go-playground/locales/en"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
"github.com/go-playground/validator/v10"
|
||||
en_translations "github.com/go-playground/validator/v10/translations/en"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
// Measurements is a required alias since docgen is not able to work with
|
||||
// types in other packages.
|
||||
type Measurements = measurements.M
|
||||
|
||||
const (
|
||||
// Version1 is the first version number for Constellation config file.
|
||||
Version1 = "v1"
|
||||
@ -164,7 +170,7 @@ type AzureConfig struct {
|
||||
// Application client ID of the Active Directory app registration.
|
||||
AppClientID string `yaml:"appClientID" validate:"uuid"`
|
||||
// description: |
|
||||
// Client secret value of the Active Directory app registration credentials.
|
||||
// Client secret value of the Active Directory app registration credentials. Alternatively leave empty and pass value via CONSTELL_AZURE_CLIENT_SECRET_VALUE environment variable.
|
||||
ClientSecretValue string `yaml:"clientSecretValue" validate:"required"`
|
||||
// description: |
|
||||
// Machine image used to create Constellation nodes.
|
||||
@ -277,7 +283,7 @@ func Default() *Config {
|
||||
StateDiskType: "gp3",
|
||||
IAMProfileControlPlane: "",
|
||||
IAMProfileWorkerNodes: "",
|
||||
Measurements: copyPCRMap(awsPCRs),
|
||||
Measurements: measurements.DefaultsFor(cloudprovider.AWS),
|
||||
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
|
||||
},
|
||||
Azure: &AzureConfig{
|
||||
@ -289,7 +295,7 @@ func Default() *Config {
|
||||
Image: DefaultImageAzure,
|
||||
InstanceType: "Standard_DC4as_v5",
|
||||
StateDiskType: "Premium_LRS",
|
||||
Measurements: copyPCRMap(azurePCRs),
|
||||
Measurements: measurements.DefaultsFor(cloudprovider.Azure),
|
||||
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
|
||||
IDKeyDigest: "57486a447ec0f1958002a22a06b7673b9fd27d11e1c6527498056054c5fa92d23c50f9de44072760fe2b6fb89740b696",
|
||||
EnforceIDKeyDigest: func() *bool { b := true; return &b }(),
|
||||
@ -304,7 +310,7 @@ func Default() *Config {
|
||||
InstanceType: "n2d-standard-4",
|
||||
StateDiskType: "pd-ssd",
|
||||
ServiceAccountKeyPath: "",
|
||||
Measurements: copyPCRMap(gcpPCRs),
|
||||
Measurements: measurements.DefaultsFor(cloudprovider.GCP),
|
||||
EnforcedMeasurements: []uint32{0, 4, 8, 9, 11, 12, 13, 15},
|
||||
},
|
||||
QEMU: &QEMUConfig{
|
||||
@ -314,7 +320,7 @@ func Default() *Config {
|
||||
MetadataAPIImage: versions.QEMUMetadataImage,
|
||||
LibvirtURI: "",
|
||||
LibvirtContainerImage: versions.LibvirtImage,
|
||||
Measurements: copyPCRMap(qemuPCRs),
|
||||
Measurements: measurements.DefaultsFor(cloudprovider.QEMU),
|
||||
EnforcedMeasurements: []uint32{4, 8, 9, 11, 12, 13, 15},
|
||||
NVRAM: "production",
|
||||
},
|
||||
@ -323,198 +329,38 @@ func Default() *Config {
|
||||
}
|
||||
}
|
||||
|
||||
func validateK8sVersion(fl validator.FieldLevel) bool {
|
||||
return versions.IsSupportedK8sVersion(fl.Field().String())
|
||||
// FromFile returns config file with `name` read from `fileHandler` by parsing
|
||||
// it as YAML. You should prefer config.New to read env vars and validate
|
||||
// config in a consistent manner.
|
||||
func FromFile(fileHandler file.Handler, name string) (*Config, error) {
|
||||
var conf Config
|
||||
if err := fileHandler.ReadYAMLStrict(name, &conf); err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, fmt.Errorf("unable to find %s - use `constellation config generate` to generate it first", name)
|
||||
}
|
||||
return nil, fmt.Errorf("could not load config from file %s: %w", name, err)
|
||||
}
|
||||
return &conf, nil
|
||||
}
|
||||
|
||||
func validateAWSInstanceType(fl validator.FieldLevel) bool {
|
||||
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.AWS)
|
||||
}
|
||||
|
||||
func validateAzureInstanceType(fl validator.FieldLevel) bool {
|
||||
azureConfig := fl.Parent().Interface().(AzureConfig)
|
||||
var acceptNonCVM bool
|
||||
if azureConfig.ConfidentialVM != nil {
|
||||
// This is the inverse of the config value (acceptNonCVMs is true if confidentialVM is false).
|
||||
// We could make the validator the other way around, but this should be an explicit bypass rather than checking if CVMs are "allowed".
|
||||
acceptNonCVM = !*azureConfig.ConfidentialVM
|
||||
}
|
||||
return validInstanceTypeForProvider(fl.Field().String(), acceptNonCVM, cloudprovider.Azure)
|
||||
}
|
||||
|
||||
func validateGCPInstanceType(fl validator.FieldLevel) bool {
|
||||
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.GCP)
|
||||
}
|
||||
|
||||
// validateProvider checks if zero or more than one providers are defined in the config.
|
||||
func validateProvider(sl validator.StructLevel) {
|
||||
provider := sl.Current().Interface().(ProviderConfig)
|
||||
providerCount := 0
|
||||
|
||||
if provider.AWS != nil {
|
||||
providerCount++
|
||||
}
|
||||
if provider.Azure != nil {
|
||||
providerCount++
|
||||
}
|
||||
if provider.GCP != nil {
|
||||
providerCount++
|
||||
}
|
||||
if provider.QEMU != nil {
|
||||
providerCount++
|
||||
}
|
||||
|
||||
if providerCount < 1 {
|
||||
sl.ReportError(provider, "Provider", "Provider", "no_provider", "")
|
||||
} else if providerCount > 1 {
|
||||
sl.ReportError(provider, "Provider", "Provider", "more_than_one_provider", "")
|
||||
}
|
||||
}
|
||||
|
||||
// Validate checks the config values and returns validation error messages.
|
||||
// The function only returns an error if the validation itself fails.
|
||||
func (c *Config) Validate() ([]string, error) {
|
||||
trans := ut.New(en.New()).GetFallback()
|
||||
validate := validator.New()
|
||||
if err := en_translations.RegisterDefaultTranslations(validate, trans); err != nil {
|
||||
// New creates a new config by:
|
||||
// 1. Reading config file via provided fileHandler from file with name.
|
||||
// 2. Read secrets from environment variables.
|
||||
// 3. Validate config.
|
||||
func New(fileHandler file.Handler, name string) (*Config, error) {
|
||||
// Read config file
|
||||
c, err := FromFile(fileHandler, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register AWS, Azure & GCP InstanceType validation error types
|
||||
if err := validate.RegisterTranslation("aws_instance_type", trans, registerTranslateAWSInstanceTypeError, translateAWSInstanceTypeError); err != nil {
|
||||
return nil, err
|
||||
// Read secrets from env-vars.
|
||||
clientSecretValue := os.Getenv(constants.EnvVarAzureClientSecretValue)
|
||||
if clientSecretValue != "" && c.Provider.Azure != nil {
|
||||
c.Provider.Azure.ClientSecretValue = clientSecretValue
|
||||
}
|
||||
|
||||
if err := validate.RegisterTranslation("azure_instance_type", trans, registerTranslateAzureInstanceTypeError, c.translateAzureInstanceTypeError); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validate.RegisterTranslation("gcp_instance_type", trans, registerTranslateGCPInstanceTypeError, translateGCPInstanceTypeError); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register Provider validation error types
|
||||
if err := validate.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validate.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, c.translateMoreThanOneProviderError); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// register custom validator with label supported_k8s_version to validate version based on available versionConfigs.
|
||||
if err := validate.RegisterValidation("supported_k8s_version", validateK8sVersion); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// register custom validator with label aws_instance_type to validate the AWS instance type from config input.
|
||||
if err := validate.RegisterValidation("aws_instance_type", validateAWSInstanceType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// register custom validator with label azure_instance_type to validate the Azure instance type from config input.
|
||||
if err := validate.RegisterValidation("azure_instance_type", validateAzureInstanceType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// register custom validator with label gcp_instance_type to validate the GCP instance type from config input.
|
||||
if err := validate.RegisterValidation("gcp_instance_type", validateGCPInstanceType); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Register provider validation
|
||||
validate.RegisterStructValidation(validateProvider, ProviderConfig{})
|
||||
|
||||
err := validate.Struct(c)
|
||||
if err == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var errs validator.ValidationErrors
|
||||
if !errors.As(err, &errs) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var msgs []string
|
||||
for _, e := range errs {
|
||||
msgs = append(msgs, e.Translate(trans))
|
||||
}
|
||||
return msgs, nil
|
||||
}
|
||||
|
||||
// Validation translation functions for Azure & GCP instance type errors.
|
||||
func registerTranslateAzureInstanceTypeError(ut ut.Translator) error {
|
||||
return ut.Add("azure_instance_type", "{0} must be one of {1}", true)
|
||||
}
|
||||
|
||||
func (c *Config) translateAzureInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
|
||||
// Suggest trusted launch VMs if confidential VMs have been specifically disabled
|
||||
var t string
|
||||
if c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM {
|
||||
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureTrustedLaunchInstanceTypes))
|
||||
} else {
|
||||
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureCVMInstanceTypes))
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func registerTranslateAWSInstanceTypeError(ut ut.Translator) error {
|
||||
return ut.Add("aws_instance_type", fmt.Sprintf("{0} must be an instance from one of the following families types with size xlarge or higher: %v", instancetypes.AWSSupportedInstanceFamilies), true)
|
||||
}
|
||||
|
||||
func translateAWSInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("aws_instance_type", fe.Field())
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func registerTranslateGCPInstanceTypeError(ut ut.Translator) error {
|
||||
return ut.Add("gcp_instance_type", fmt.Sprintf("{0} must be one of %v", instancetypes.GCPInstanceTypes), true)
|
||||
}
|
||||
|
||||
func translateGCPInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("gcp_instance_type", fe.Field())
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// Validation translation functions for Provider errors.
|
||||
func registerNoProviderError(ut ut.Translator) error {
|
||||
return ut.Add("no_provider", "{0}: No provider has been defined (requires either Azure, GCP or QEMU)", true)
|
||||
}
|
||||
|
||||
func translateNoProviderError(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("no_provider", fe.Field())
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func registerMoreThanOneProviderError(ut ut.Translator) error {
|
||||
return ut.Add("more_than_one_provider", "{0}: Only one provider can be defined ({1} are defined)", true)
|
||||
}
|
||||
|
||||
func (c *Config) translateMoreThanOneProviderError(ut ut.Translator, fe validator.FieldError) string {
|
||||
definedProviders := make([]string, 0)
|
||||
|
||||
// c.Provider should not be nil as Provider would need to be defined for the validation to fail in this place.
|
||||
if c.Provider.AWS != nil {
|
||||
definedProviders = append(definedProviders, "AWS")
|
||||
}
|
||||
if c.Provider.Azure != nil {
|
||||
definedProviders = append(definedProviders, "Azure")
|
||||
}
|
||||
if c.Provider.GCP != nil {
|
||||
definedProviders = append(definedProviders, "GCP")
|
||||
}
|
||||
if c.Provider.QEMU != nil {
|
||||
definedProviders = append(definedProviders, "QEMU")
|
||||
}
|
||||
|
||||
// Show single string if only one other provider is defined, show list with brackets if multiple are defined.
|
||||
t, _ := ut.T("more_than_one_provider", fe.Field(), strings.Join(definedProviders, ", "))
|
||||
|
||||
return t
|
||||
return c, c.Validate()
|
||||
}
|
||||
|
||||
// HasProvider checks whether the config contains the provider.
|
||||
@ -584,6 +430,19 @@ func (c *Config) RemoveProviderExcept(provider cloudprovider.Provider) {
|
||||
}
|
||||
}
|
||||
|
||||
// IsAzureNonCVM checks whether the chosen provider is azure and confidential VMs are disabled.
|
||||
func (c *Config) IsAzureNonCVM() bool {
|
||||
return c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM
|
||||
}
|
||||
|
||||
// IsDebugCluster checks whether the cluster is configured as a debug cluster.
|
||||
func (c *Config) IsDebugCluster() bool {
|
||||
if c.DebugCluster != nil && *c.DebugCluster {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsDebugImage checks whether image name looks like a release image, if not it is
|
||||
// probably a debug image. In the end we do not if bootstrapper or debugd
|
||||
// was put inside an image just by looking at its name.
|
||||
@ -618,110 +477,80 @@ func (c *Config) GetProvider() cloudprovider.Provider {
|
||||
return cloudprovider.Unknown
|
||||
}
|
||||
|
||||
// IsAzureNonCVM checks whether the chosen provider is azure and confidential VMs are disabled.
|
||||
func (c *Config) IsAzureNonCVM() bool {
|
||||
return c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM
|
||||
}
|
||||
|
||||
// EnforcesIDKeyDigest checks whether ID Key Digest should be enforced for respective cloud provider.
|
||||
func (c *Config) EnforcesIDKeyDigest() bool {
|
||||
return c.Provider.Azure != nil && c.Provider.Azure.EnforceIDKeyDigest != nil && *c.Provider.Azure.EnforceIDKeyDigest
|
||||
}
|
||||
|
||||
// FromFile returns config file with `name` read from `fileHandler` by parsing
|
||||
// it as YAML.
|
||||
func FromFile(fileHandler file.Handler, name string) (*Config, error) {
|
||||
var conf Config
|
||||
if err := fileHandler.ReadYAMLStrict(name, &conf); err != nil {
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, fmt.Errorf("unable to find %s - use `constellation config generate` to generate it first", name)
|
||||
}
|
||||
return nil, fmt.Errorf("could not load config from file %s: %w", name, err)
|
||||
// Validate checks the config values and returns validation errors.
|
||||
func (c *Config) Validate() error {
|
||||
trans := ut.New(en.New()).GetFallback()
|
||||
validate := validator.New()
|
||||
if err := en_translations.RegisterDefaultTranslations(validate, trans); err != nil {
|
||||
return err
|
||||
}
|
||||
return &conf, nil
|
||||
}
|
||||
|
||||
func copyPCRMap(m map[uint32][]byte) map[uint32][]byte {
|
||||
res := make(Measurements)
|
||||
res.CopyFrom(m)
|
||||
return res
|
||||
}
|
||||
|
||||
func validInstanceTypeForProvider(insType string, acceptNonCVM bool, provider cloudprovider.Provider) bool {
|
||||
switch provider {
|
||||
case cloudprovider.AWS:
|
||||
return checkIfAWSInstanceTypeIsValid(insType)
|
||||
case cloudprovider.Azure:
|
||||
if acceptNonCVM {
|
||||
for _, instanceType := range instancetypes.AzureTrustedLaunchInstanceTypes {
|
||||
if insType == instanceType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, instanceType := range instancetypes.AzureCVMInstanceTypes {
|
||||
if insType == instanceType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
case cloudprovider.GCP:
|
||||
for _, instanceType := range instancetypes.GCPInstanceTypes {
|
||||
if insType == instanceType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// checkIfAWSInstanceTypeIsValid checks if an AWS instance type passed as user input is in one of the instance families supporting NitroTPM.
|
||||
func checkIfAWSInstanceTypeIsValid(userInput string) bool {
|
||||
// Check if user or code does anything weird and tries to pass multiple strings as one
|
||||
if strings.Contains(userInput, " ") {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(userInput, ",") {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(userInput, ";") {
|
||||
return false
|
||||
}
|
||||
|
||||
splitInstanceType := strings.Split(userInput, ".")
|
||||
|
||||
if len(splitInstanceType) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
userDefinedFamily := splitInstanceType[0]
|
||||
userDefinedSize := splitInstanceType[1]
|
||||
|
||||
// Check if instace type has at least 4 vCPUs (= contains "xlarge" in its name)
|
||||
hasEnoughVCPUs := strings.Contains(userDefinedSize, "xlarge")
|
||||
if !hasEnoughVCPUs {
|
||||
return false
|
||||
}
|
||||
|
||||
// Now check if the user input is a supported family
|
||||
// Note that we cannot directly use the family split from the Graviton check above, as some instances are directly specified by their full name and not just the family in general
|
||||
for _, supportedFamily := range instancetypes.AWSSupportedInstanceFamilies {
|
||||
supportedFamilyLowercase := strings.ToLower(supportedFamily)
|
||||
if userDefinedFamily == supportedFamilyLowercase {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsDebugCluster checks whether the cluster is configured as a debug cluster.
|
||||
func (c *Config) IsDebugCluster() bool {
|
||||
if c.DebugCluster != nil && *c.DebugCluster {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
// Register AWS, Azure & GCP InstanceType validation error types
|
||||
if err := validate.RegisterTranslation("aws_instance_type", trans, registerTranslateAWSInstanceTypeError, translateAWSInstanceTypeError); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validate.RegisterTranslation("azure_instance_type", trans, registerTranslateAzureInstanceTypeError, c.translateAzureInstanceTypeError); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validate.RegisterTranslation("gcp_instance_type", trans, registerTranslateGCPInstanceTypeError, translateGCPInstanceTypeError); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register Provider validation error types
|
||||
if err := validate.RegisterTranslation("no_provider", trans, registerNoProviderError, translateNoProviderError); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validate.RegisterTranslation("more_than_one_provider", trans, registerMoreThanOneProviderError, c.translateMoreThanOneProviderError); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// register custom validator with label supported_k8s_version to validate version based on available versionConfigs.
|
||||
if err := validate.RegisterValidation("supported_k8s_version", validateK8sVersion); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// register custom validator with label aws_instance_type to validate the AWS instance type from config input.
|
||||
if err := validate.RegisterValidation("aws_instance_type", validateAWSInstanceType); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// register custom validator with label azure_instance_type to validate the Azure instance type from config input.
|
||||
if err := validate.RegisterValidation("azure_instance_type", validateAzureInstanceType); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// register custom validator with label gcp_instance_type to validate the GCP instance type from config input.
|
||||
if err := validate.RegisterValidation("gcp_instance_type", validateGCPInstanceType); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Register provider validation
|
||||
validate.RegisterStructValidation(validateProvider, ProviderConfig{})
|
||||
|
||||
err := validate.Struct(c)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var errs validator.ValidationErrors
|
||||
if !errors.As(err, &errs) {
|
||||
return err
|
||||
}
|
||||
|
||||
var validationErrors error
|
||||
for _, e := range errs {
|
||||
validationErrors = multierr.Append(
|
||||
validationErrors,
|
||||
errors.New(e.Translate(trans)),
|
||||
)
|
||||
}
|
||||
return validationErrors
|
||||
}
|
||||
|
@ -242,8 +242,8 @@ func init() {
|
||||
AzureConfigDoc.Fields[6].Name = "clientSecretValue"
|
||||
AzureConfigDoc.Fields[6].Type = "string"
|
||||
AzureConfigDoc.Fields[6].Note = ""
|
||||
AzureConfigDoc.Fields[6].Description = "Client secret value of the Active Directory app registration credentials."
|
||||
AzureConfigDoc.Fields[6].Comments[encoder.LineComment] = "Client secret value of the Active Directory app registration credentials."
|
||||
AzureConfigDoc.Fields[6].Description = "Client secret value of the Active Directory app registration credentials. Alternatively leave empty and pass value via CONSTELL_AZURE_CLIENT_SECRET_VALUE environment variable."
|
||||
AzureConfigDoc.Fields[6].Comments[encoder.LineComment] = "Client secret value of the Active Directory app registration credentials. Alternatively leave empty and pass value via CONSTELL_AZURE_CLIENT_SECRET_VALUE environment variable."
|
||||
AzureConfigDoc.Fields[7].Name = "image"
|
||||
AzureConfigDoc.Fields[7].Type = "string"
|
||||
AzureConfigDoc.Fields[7].Note = ""
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config/instancetypes"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
@ -21,6 +22,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
"go.uber.org/multierr"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
@ -108,33 +110,49 @@ func TestFromFile(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
const defaultMsgCount = 20 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default
|
||||
|
||||
func TestNewWithDefaultOptions(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
cnf *Config
|
||||
wantMsgCount int
|
||||
confToWrite *Config
|
||||
envToSet map[string]string
|
||||
wantErr bool
|
||||
wantClientSecretValue string
|
||||
}{
|
||||
"default config is valid": {
|
||||
cnf: Default(),
|
||||
wantMsgCount: defaultMsgCount,
|
||||
},
|
||||
"config with 1 error": {
|
||||
cnf: func() *Config {
|
||||
cnf := Default()
|
||||
cnf.Version = "v0"
|
||||
return cnf
|
||||
"set env works": {
|
||||
confToWrite: func() *Config { // valid config with all, but clientSecretValue
|
||||
c := Default()
|
||||
c.RemoveProviderExcept(cloudprovider.Azure)
|
||||
c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5"
|
||||
c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa"
|
||||
c.Provider.Azure.Location = "westus"
|
||||
c.Provider.Azure.ResourceGroup = "test"
|
||||
c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity"
|
||||
c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e"
|
||||
c.Provider.Azure.Image = "/communityGalleries/ConstellationCVM-b3782fa0-0df7-4f2f-963e-fc7fc42663df/images/constellation/versions/2.2.0"
|
||||
return c
|
||||
}(),
|
||||
wantMsgCount: defaultMsgCount + 1,
|
||||
envToSet: map[string]string{
|
||||
constants.EnvVarAzureClientSecretValue: "some-secret",
|
||||
},
|
||||
wantClientSecretValue: "some-secret",
|
||||
},
|
||||
"config with 2 errors": {
|
||||
cnf: func() *Config {
|
||||
cnf := Default()
|
||||
cnf.Version = "v0"
|
||||
cnf.StateDiskSizeGB = -1
|
||||
return cnf
|
||||
"set env overwrites": {
|
||||
confToWrite: func() *Config {
|
||||
c := Default()
|
||||
c.RemoveProviderExcept(cloudprovider.Azure)
|
||||
c.Provider.Azure.SubscriptionID = "f4278079-288c-4766-a98c-ab9d5dba01a5"
|
||||
c.Provider.Azure.TenantID = "d4ff9d63-6d6d-4042-8f6a-21e804add5aa"
|
||||
c.Provider.Azure.Location = "westus"
|
||||
c.Provider.Azure.ResourceGroup = "test"
|
||||
c.Provider.Azure.ClientSecretValue = "other-value" // < Note secret set in config, as well.
|
||||
c.Provider.Azure.UserAssignedIdentity = "/subscriptions/8b8bd01f-efd9-4113-9bd1-c82137c32da7/resourcegroups/constellation-identity/providers/Microsoft.ManagedIdentity/userAssignedIdentities/constellation-identity"
|
||||
c.Provider.Azure.AppClientID = "3ea4bdc1-1cc1-4237-ae78-0831eff3491e"
|
||||
c.Provider.Azure.Image = "/communityGalleries/ConstellationCVM-b3782fa0-0df7-4f2f-963e-fc7fc42663df/images/constellation/versions/2.2.0"
|
||||
return c
|
||||
}(),
|
||||
wantMsgCount: defaultMsgCount + 2,
|
||||
envToSet: map[string]string{
|
||||
constants.EnvVarAzureClientSecretValue: "some-secret",
|
||||
},
|
||||
wantClientSecretValue: "some-secret",
|
||||
},
|
||||
}
|
||||
|
||||
@ -143,9 +161,126 @@ func TestValidate(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
msgs, err := tc.cnf.Validate()
|
||||
// Setup
|
||||
fileHandler := file.NewHandler(afero.NewMemMapFs())
|
||||
err := fileHandler.WriteYAML(constants.ConfigFilename, tc.confToWrite)
|
||||
require.NoError(err)
|
||||
assert.Len(msgs, tc.wantMsgCount)
|
||||
for envKey, envValue := range tc.envToSet {
|
||||
t.Setenv(envKey, envValue)
|
||||
}
|
||||
|
||||
// Test
|
||||
c, err := New(fileHandler, constants.ConfigFilename)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
assert.Equal(c.Provider.Azure.ClientSecretValue, tc.wantClientSecretValue)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate(t *testing.T) {
|
||||
const defaultErrCount = 20 // expect this number of error messages by default because user-specific values are not set and multiple providers are defined by default
|
||||
const azErrCount = 8
|
||||
const gcpErrCount = 5
|
||||
|
||||
testCases := map[string]struct {
|
||||
cnf *Config
|
||||
wantErr bool
|
||||
wantErrCount int
|
||||
}{
|
||||
"default config is not valid": {
|
||||
cnf: Default(),
|
||||
wantErr: true,
|
||||
wantErrCount: defaultErrCount,
|
||||
},
|
||||
"v0 is one error": {
|
||||
cnf: func() *Config {
|
||||
cnf := Default()
|
||||
cnf.Version = "v0"
|
||||
return cnf
|
||||
}(),
|
||||
wantErr: true,
|
||||
wantErrCount: defaultErrCount + 1,
|
||||
},
|
||||
"v0 and negative state disk are two errors": {
|
||||
cnf: func() *Config {
|
||||
cnf := Default()
|
||||
cnf.Version = "v0"
|
||||
cnf.StateDiskSizeGB = -1
|
||||
return cnf
|
||||
}(),
|
||||
wantErr: true,
|
||||
wantErrCount: defaultErrCount + 2,
|
||||
},
|
||||
"default Azure config is not valid": {
|
||||
cnf: func() *Config {
|
||||
cnf := Default()
|
||||
az := cnf.Provider.Azure
|
||||
cnf.Provider = ProviderConfig{}
|
||||
cnf.Provider.Azure = az
|
||||
return cnf
|
||||
}(),
|
||||
wantErr: true,
|
||||
wantErrCount: azErrCount,
|
||||
},
|
||||
"Azure config with all required fields is valid": {
|
||||
cnf: func() *Config {
|
||||
cnf := Default()
|
||||
az := cnf.Provider.Azure
|
||||
az.SubscriptionID = "01234567-0123-0123-0123-0123456789ab"
|
||||
az.TenantID = "01234567-0123-0123-0123-0123456789ab"
|
||||
az.Location = "test-location"
|
||||
az.UserAssignedIdentity = "test-identity"
|
||||
az.Image = "some/image/location"
|
||||
az.ResourceGroup = "test-resource-group"
|
||||
az.AppClientID = "01234567-0123-0123-0123-0123456789ab"
|
||||
az.ClientSecretValue = "test-client-secret"
|
||||
cnf.Provider = ProviderConfig{}
|
||||
cnf.Provider.Azure = az
|
||||
return cnf
|
||||
}(),
|
||||
},
|
||||
"default GCP config is not valid": {
|
||||
cnf: func() *Config {
|
||||
cnf := Default()
|
||||
gcp := cnf.Provider.GCP
|
||||
cnf.Provider = ProviderConfig{}
|
||||
cnf.Provider.GCP = gcp
|
||||
return cnf
|
||||
}(),
|
||||
wantErr: true,
|
||||
wantErrCount: gcpErrCount,
|
||||
},
|
||||
"GCP config with all required fields is valid": {
|
||||
cnf: func() *Config {
|
||||
cnf := Default()
|
||||
gcp := cnf.Provider.GCP
|
||||
gcp.Region = "test-region"
|
||||
gcp.Project = "test-project"
|
||||
gcp.Image = "some/image/location"
|
||||
gcp.Zone = "test-zone"
|
||||
gcp.ServiceAccountKeyPath = "test-key-path"
|
||||
cnf.Provider = ProviderConfig{}
|
||||
cnf.Provider.GCP = gcp
|
||||
return cnf
|
||||
}(),
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
err := tc.cnf.Validate()
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
assert.Len(multierr.Errors(err), tc.wantErrCount)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -260,10 +395,10 @@ func TestConfigGeneratedDocsFresh(t *testing.T) {
|
||||
|
||||
func TestConfig_UpdateMeasurements(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
newMeasurements := Measurements{
|
||||
1: []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
|
||||
2: []byte{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1},
|
||||
3: []byte{2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2},
|
||||
newMeasurements := measurements.M{
|
||||
1: measurements.PCRWithAllBytes(0x00),
|
||||
2: measurements.PCRWithAllBytes(0x01),
|
||||
3: measurements.PCRWithAllBytes(0x02),
|
||||
}
|
||||
|
||||
{ // AWS
|
||||
|
212
internal/config/validation.go
Normal file
212
internal/config/validation.go
Normal file
@ -0,0 +1,212 @@
|
||||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/v2/internal/config/instancetypes"
|
||||
"github.com/edgelesssys/constellation/v2/internal/versions"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
func validateK8sVersion(fl validator.FieldLevel) bool {
|
||||
return versions.IsSupportedK8sVersion(fl.Field().String())
|
||||
}
|
||||
|
||||
func validateAWSInstanceType(fl validator.FieldLevel) bool {
|
||||
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.AWS)
|
||||
}
|
||||
|
||||
func validateAzureInstanceType(fl validator.FieldLevel) bool {
|
||||
azureConfig := fl.Parent().Interface().(AzureConfig)
|
||||
var acceptNonCVM bool
|
||||
if azureConfig.ConfidentialVM != nil {
|
||||
// This is the inverse of the config value (acceptNonCVMs is true if confidentialVM is false).
|
||||
// We could make the validator the other way around, but this should be an explicit bypass rather than checking if CVMs are "allowed".
|
||||
acceptNonCVM = !*azureConfig.ConfidentialVM
|
||||
}
|
||||
return validInstanceTypeForProvider(fl.Field().String(), acceptNonCVM, cloudprovider.Azure)
|
||||
}
|
||||
|
||||
func validateGCPInstanceType(fl validator.FieldLevel) bool {
|
||||
return validInstanceTypeForProvider(fl.Field().String(), false, cloudprovider.GCP)
|
||||
}
|
||||
|
||||
// validateProvider checks if zero or more than one providers are defined in the config.
|
||||
func validateProvider(sl validator.StructLevel) {
|
||||
provider := sl.Current().Interface().(ProviderConfig)
|
||||
providerCount := 0
|
||||
|
||||
if provider.AWS != nil {
|
||||
providerCount++
|
||||
}
|
||||
if provider.Azure != nil {
|
||||
providerCount++
|
||||
}
|
||||
if provider.GCP != nil {
|
||||
providerCount++
|
||||
}
|
||||
if provider.QEMU != nil {
|
||||
providerCount++
|
||||
}
|
||||
|
||||
if providerCount < 1 {
|
||||
sl.ReportError(provider, "Provider", "Provider", "no_provider", "")
|
||||
} else if providerCount > 1 {
|
||||
sl.ReportError(provider, "Provider", "Provider", "more_than_one_provider", "")
|
||||
}
|
||||
}
|
||||
|
||||
func registerTranslateAWSInstanceTypeError(ut ut.Translator) error {
|
||||
return ut.Add("aws_instance_type", fmt.Sprintf("{0} must be an instance from one of the following families types with size xlarge or higher: %v", instancetypes.AWSSupportedInstanceFamilies), true)
|
||||
}
|
||||
|
||||
func translateAWSInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("aws_instance_type", fe.Field())
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func registerTranslateGCPInstanceTypeError(ut ut.Translator) error {
|
||||
return ut.Add("gcp_instance_type", fmt.Sprintf("{0} must be one of %v", instancetypes.GCPInstanceTypes), true)
|
||||
}
|
||||
|
||||
func translateGCPInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("gcp_instance_type", fe.Field())
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// Validation translation functions for Provider errors.
|
||||
func registerNoProviderError(ut ut.Translator) error {
|
||||
return ut.Add("no_provider", "{0}: No provider has been defined (requires either Azure, GCP or QEMU)", true)
|
||||
}
|
||||
|
||||
func translateNoProviderError(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("no_provider", fe.Field())
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func registerMoreThanOneProviderError(ut ut.Translator) error {
|
||||
return ut.Add("more_than_one_provider", "{0}: Only one provider can be defined ({1} are defined)", true)
|
||||
}
|
||||
|
||||
func (c *Config) translateMoreThanOneProviderError(ut ut.Translator, fe validator.FieldError) string {
|
||||
definedProviders := make([]string, 0)
|
||||
|
||||
// c.Provider should not be nil as Provider would need to be defined for the validation to fail in this place.
|
||||
if c.Provider.AWS != nil {
|
||||
definedProviders = append(definedProviders, "AWS")
|
||||
}
|
||||
if c.Provider.Azure != nil {
|
||||
definedProviders = append(definedProviders, "Azure")
|
||||
}
|
||||
if c.Provider.GCP != nil {
|
||||
definedProviders = append(definedProviders, "GCP")
|
||||
}
|
||||
if c.Provider.QEMU != nil {
|
||||
definedProviders = append(definedProviders, "QEMU")
|
||||
}
|
||||
|
||||
// Show single string if only one other provider is defined, show list with brackets if multiple are defined.
|
||||
t, _ := ut.T("more_than_one_provider", fe.Field(), strings.Join(definedProviders, ", "))
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func validInstanceTypeForProvider(insType string, acceptNonCVM bool, provider cloudprovider.Provider) bool {
|
||||
switch provider {
|
||||
case cloudprovider.AWS:
|
||||
return checkIfAWSInstanceTypeIsValid(insType)
|
||||
case cloudprovider.Azure:
|
||||
if acceptNonCVM {
|
||||
for _, instanceType := range instancetypes.AzureTrustedLaunchInstanceTypes {
|
||||
if insType == instanceType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, instanceType := range instancetypes.AzureCVMInstanceTypes {
|
||||
if insType == instanceType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
case cloudprovider.GCP:
|
||||
for _, instanceType := range instancetypes.GCPInstanceTypes {
|
||||
if insType == instanceType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// checkIfAWSInstanceTypeIsValid checks if an AWS instance type passed as user input is in one of the instance families supporting NitroTPM.
|
||||
func checkIfAWSInstanceTypeIsValid(userInput string) bool {
|
||||
// Check if user or code does anything weird and tries to pass multiple strings as one
|
||||
if strings.Contains(userInput, " ") {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(userInput, ",") {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(userInput, ";") {
|
||||
return false
|
||||
}
|
||||
|
||||
splitInstanceType := strings.Split(userInput, ".")
|
||||
|
||||
if len(splitInstanceType) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
userDefinedFamily := splitInstanceType[0]
|
||||
userDefinedSize := splitInstanceType[1]
|
||||
|
||||
// Check if instace type has at least 4 vCPUs (= contains "xlarge" in its name)
|
||||
hasEnoughVCPUs := strings.Contains(userDefinedSize, "xlarge")
|
||||
if !hasEnoughVCPUs {
|
||||
return false
|
||||
}
|
||||
|
||||
// Now check if the user input is a supported family
|
||||
// Note that we cannot directly use the family split from the Graviton check above, as some instances are directly specified by their full name and not just the family in general
|
||||
for _, supportedFamily := range instancetypes.AWSSupportedInstanceFamilies {
|
||||
supportedFamilyLowercase := strings.ToLower(supportedFamily)
|
||||
if userDefinedFamily == supportedFamilyLowercase {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Validation translation functions for Azure & GCP instance type errors.
|
||||
func registerTranslateAzureInstanceTypeError(ut ut.Translator) error {
|
||||
return ut.Add("azure_instance_type", "{0} must be one of {1}", true)
|
||||
}
|
||||
|
||||
func (c *Config) translateAzureInstanceTypeError(ut ut.Translator, fe validator.FieldError) string {
|
||||
// Suggest trusted launch VMs if confidential VMs have been specifically disabled
|
||||
var t string
|
||||
if c.Provider.Azure != nil && c.Provider.Azure.ConfidentialVM != nil && !*c.Provider.Azure.ConfidentialVM {
|
||||
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureTrustedLaunchInstanceTypes))
|
||||
} else {
|
||||
t, _ = ut.T("azure_instance_type", fe.Field(), fmt.Sprintf("%v", instancetypes.AzureCVMInstanceTypes))
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
@ -113,6 +113,11 @@ const (
|
||||
MinControllerCount = 1
|
||||
// MinWorkerCount is the minimum number of worker nodes.
|
||||
MinWorkerCount = 1
|
||||
// EnvVarPrefix is expected prefix for environment variables used to overwrite config parameters.
|
||||
EnvVarPrefix = "CONSTELL_"
|
||||
// EnvVarAzureClientSecretValue is environment variable to overwrite
|
||||
// provider.azure.clientSecretValue .
|
||||
EnvVarAzureClientSecretValue = EnvVarPrefix + "AZURE_CLIENT_SECRET_VALUE"
|
||||
|
||||
//
|
||||
// Kubernetes.
|
||||
|
Loading…
x
Reference in New Issue
Block a user