AB#2512 Config secrets via env var & config refactoring (#544)

* refactor measurements to use consistent types and less byte pushing
* refactor: only rely on a single multierr dependency
* extend config creation with envar support
* document changes
Signed-off-by: Fabian Kammel <fk@edgeless.systems>
This commit is contained in:
Fabian Kammel 2022-11-15 15:40:49 +01:00 committed by GitHub
parent 80a801629e
commit bb76a4e4c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
42 changed files with 932 additions and 791 deletions

View file

@ -7,13 +7,13 @@ SPDX-License-Identifier: AGPL-3.0-only
package cloudcmd
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/constants"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
@ -58,7 +58,7 @@ func NewUpgrader(outWriter io.Writer) (*Upgrader, error) {
}
// Upgrade upgrades the cluster to the given measurements and image.
func (u *Upgrader) Upgrade(ctx context.Context, image string, measurements map[uint32][]byte) error {
func (u *Upgrader) Upgrade(ctx context.Context, image string, measurements measurements.M) error {
if err := u.updateMeasurements(ctx, measurements); err != nil {
return fmt.Errorf("updating measurements: %w", err)
}
@ -97,36 +97,25 @@ func (u *Upgrader) GetCurrentImage(ctx context.Context) (*unstructured.Unstructu
return imageStruct, imageDefinition, nil
}
func (u *Upgrader) updateMeasurements(ctx context.Context, measurements map[uint32][]byte) error {
func (u *Upgrader) updateMeasurements(ctx context.Context, newMeasurements measurements.M) error {
existingConf, err := u.measurementsUpdater.getCurrent(ctx, constants.JoinConfigMap)
if err != nil {
return fmt.Errorf("retrieving current measurements: %w", err)
}
var currentMeasurements map[uint32][]byte
var currentMeasurements measurements.M
if err := json.Unmarshal([]byte(existingConf.Data[constants.MeasurementsFilename]), &currentMeasurements); err != nil {
return fmt.Errorf("retrieving current measurements: %w", err)
}
if len(currentMeasurements) == len(measurements) {
changed := false
for k, v := range currentMeasurements {
if !bytes.Equal(v, measurements[k]) {
// measurements have changed
changed = true
break
}
}
if !changed {
// measurements are the same, nothing to be done
fmt.Fprintln(u.outWriter, "Cluster is already using the chosen measurements, skipping measurements upgrade")
return nil
}
if currentMeasurements.EqualTo(newMeasurements) {
fmt.Fprintln(u.outWriter, "Cluster is already using the chosen measurements, skipping measurements upgrade")
return nil
}
// backup of previous measurements
existingConf.Data["oldMeasurements"] = existingConf.Data[constants.MeasurementsFilename]
measurementsJSON, err := json.Marshal(measurements)
measurementsJSON, err := json.Marshal(newMeasurements)
if err != nil {
return fmt.Errorf("marshaling measurements: %w", err)
}

View file

@ -13,6 +13,7 @@ import (
"errors"
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -24,7 +25,7 @@ func TestUpdateMeasurements(t *testing.T) {
someErr := errors.New("error")
testCases := map[string]struct {
updater *stubMeasurementsUpdater
newMeasurements map[uint32][]byte
newMeasurements measurements.M
wantUpdate bool
wantErr bool
}{
@ -36,7 +37,7 @@ func TestUpdateMeasurements(t *testing.T) {
},
},
},
newMeasurements: map[uint32][]byte{
newMeasurements: measurements.M{
0: []byte("1"),
},
wantUpdate: true,
@ -49,7 +50,7 @@ func TestUpdateMeasurements(t *testing.T) {
},
},
},
newMeasurements: map[uint32][]byte{
newMeasurements: measurements.M{
0: []byte("1"),
},
},

View file

@ -18,6 +18,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp"
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch"
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
@ -28,7 +29,7 @@ import (
// Validator validates Platform Configuration Registers (PCRs).
type Validator struct {
provider cloudprovider.Provider
pcrs map[uint32][]byte
pcrs measurements.M
enforcedPCRs []uint32
idkeydigest []byte
enforceIDKeyDigest bool
@ -147,7 +148,7 @@ func (v *Validator) V(cmd *cobra.Command) atls.Validator {
}
// PCRS returns the validator's PCR map.
func (v *Validator) PCRS() map[uint32][]byte {
func (v *Validator) PCRS() measurements.M {
return v.pcrs
}
@ -169,7 +170,7 @@ func (v *Validator) updateValidator(cmd *cobra.Command) {
}
}
func (v *Validator) checkPCRs(pcrs map[uint32][]byte, enforcedPCRs []uint32) error {
func (v *Validator) checkPCRs(pcrs measurements.M, enforcedPCRs []uint32) error {
if len(pcrs) == 0 {
return errors.New("no PCR values provided")
}

View file

@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only
package cloudcmd
import (
"bytes"
"crypto/sha256"
"encoding/base64"
"testing"
@ -15,6 +16,7 @@ import (
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp"
"github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch"
"github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
@ -24,21 +26,19 @@ import (
)
func TestNewValidator(t *testing.T) {
zero := []byte("00000000000000000000000000000000")
one := []byte("11111111111111111111111111111111")
testPCRs := map[uint32][]byte{
0: zero,
1: one,
2: zero,
3: one,
4: zero,
5: zero,
testPCRs := measurements.M{
0: measurements.PCRWithAllBytes(0x00),
1: measurements.PCRWithAllBytes(0xFF),
2: measurements.PCRWithAllBytes(0x00),
3: measurements.PCRWithAllBytes(0xFF),
4: measurements.PCRWithAllBytes(0x00),
5: measurements.PCRWithAllBytes(0x00),
}
testCases := map[string]struct {
provider cloudprovider.Provider
config *config.Config
pcrs map[uint32][]byte
pcrs measurements.M
enforceIDKeyDigest bool
idKeyDigest string
azureCVM bool
@ -64,13 +64,15 @@ func TestNewValidator(t *testing.T) {
},
"no pcrs provided": {
provider: cloudprovider.Azure,
pcrs: map[uint32][]byte{},
pcrs: measurements.M{},
wantErr: true,
},
"invalid pcr length": {
provider: cloudprovider.GCP,
pcrs: map[uint32][]byte{0: []byte("0000000000000000000000000000000")},
wantErr: true,
pcrs: measurements.M{
0: bytes.Repeat([]byte{0x00}, 31),
},
wantErr: true,
},
"unknown provider": {
provider: cloudprovider.Unknown,
@ -99,16 +101,13 @@ func TestNewValidator(t *testing.T) {
conf := &config.Config{Provider: config.ProviderConfig{}}
if tc.provider == cloudprovider.GCP {
measurements := config.Measurements(tc.pcrs)
conf.Provider.GCP = &config.GCPConfig{Measurements: measurements}
conf.Provider.GCP = &config.GCPConfig{Measurements: tc.pcrs}
}
if tc.provider == cloudprovider.Azure {
measurements := config.Measurements(tc.pcrs)
conf.Provider.Azure = &config.AzureConfig{Measurements: measurements, EnforceIDKeyDigest: &tc.enforceIDKeyDigest, IDKeyDigest: tc.idKeyDigest, ConfidentialVM: &tc.azureCVM}
conf.Provider.Azure = &config.AzureConfig{Measurements: tc.pcrs, EnforceIDKeyDigest: &tc.enforceIDKeyDigest, IDKeyDigest: tc.idKeyDigest, ConfidentialVM: &tc.azureCVM}
}
if tc.provider == cloudprovider.QEMU {
measurements := config.Measurements(tc.pcrs)
conf.Provider.QEMU = &config.QEMUConfig{Measurements: measurements}
conf.Provider.QEMU = &config.QEMUConfig{Measurements: tc.pcrs}
}
validators, err := NewValidator(tc.provider, conf)
@ -125,29 +124,27 @@ func TestNewValidator(t *testing.T) {
}
func TestValidatorV(t *testing.T) {
zero := []byte("00000000000000000000000000000000")
newTestPCRs := func() map[uint32][]byte {
return map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
6: zero,
7: zero,
8: zero,
9: zero,
10: zero,
11: zero,
12: zero,
newTestPCRs := func() measurements.M {
return measurements.M{
0: measurements.PCRWithAllBytes(0x00),
1: measurements.PCRWithAllBytes(0x00),
2: measurements.PCRWithAllBytes(0x00),
3: measurements.PCRWithAllBytes(0x00),
4: measurements.PCRWithAllBytes(0x00),
5: measurements.PCRWithAllBytes(0x00),
6: measurements.PCRWithAllBytes(0x00),
7: measurements.PCRWithAllBytes(0x00),
8: measurements.PCRWithAllBytes(0x00),
9: measurements.PCRWithAllBytes(0x00),
10: measurements.PCRWithAllBytes(0x00),
11: measurements.PCRWithAllBytes(0x00),
12: measurements.PCRWithAllBytes(0x00),
}
}
testCases := map[string]struct {
provider cloudprovider.Provider
pcrs map[uint32][]byte
pcrs measurements.M
wantVs atls.Validator
azureCVM bool
}{
@ -224,7 +221,7 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
testCases := map[string]struct {
provider cloudprovider.Provider
pcrs map[uint32][]byte
pcrs measurements.M
ownerID string
clusterID string
wantErr bool
@ -321,14 +318,14 @@ func TestValidatorUpdateInitPCRs(t *testing.T) {
}
func TestUpdatePCR(t *testing.T) {
emptyMap := map[uint32][]byte{}
defaultMap := map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
emptyMap := measurements.M{}
defaultMap := measurements.M{
0: measurements.PCRWithAllBytes(0xAA),
1: measurements.PCRWithAllBytes(0xBB),
}
testCases := map[string]struct {
pcrMap map[uint32][]byte
pcrMap measurements.M
pcrIndex uint32
encoded string
wantEntries int
@ -389,7 +386,7 @@ func TestUpdatePCR(t *testing.T) {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
pcrs := make(map[uint32][]byte)
pcrs := make(measurements.M)
for k, v := range tc.pcrMap {
pcrs[k] = v
}

View file

@ -14,6 +14,7 @@ import (
"strings"
"time"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
@ -39,7 +40,7 @@ func newConfigFetchMeasurementsCmd() *cobra.Command {
type fetchMeasurementsFlags struct {
measurementsURL *url.URL
signatureURL *url.URL
config string
configPath string
}
func runConfigFetchMeasurements(cmd *cobra.Command, args []string) error {
@ -57,9 +58,9 @@ func configFetchMeasurements(cmd *cobra.Command, verifier rekorVerifier, fileHan
return err
}
conf, err := config.FromFile(fileHandler, flags.config)
conf, err := config.New(fileHandler, flags.configPath)
if err != nil {
return err
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
if conf.IsDebugImage() {
@ -72,7 +73,7 @@ func configFetchMeasurements(cmd *cobra.Command, verifier rekorVerifier, fileHan
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
var fetchedMeasurements config.Measurements
var fetchedMeasurements measurements.M
hash, err := fetchedMeasurements.FetchAndVerify(ctx, client, flags.measurementsURL, flags.signatureURL, []byte(constants.CosignPublicKey))
if err != nil {
return err
@ -84,7 +85,7 @@ func configFetchMeasurements(cmd *cobra.Command, verifier rekorVerifier, fileHan
}
conf.UpdateMeasurements(fetchedMeasurements)
if err := fileHandler.WriteYAML(flags.config, conf, file.OptOverwrite); err != nil {
if err := fileHandler.WriteYAML(flags.configPath, conf, file.OptOverwrite); err != nil {
return err
}
@ -123,7 +124,7 @@ func parseFetchMeasurementsFlags(cmd *cobra.Command) (*fetchMeasurementsFlags, e
return &fetchMeasurementsFlags{
measurementsURL: measurementsURL,
signatureURL: measurementsSignatureURL,
config: config,
configPath: config,
}, nil
}

View file

@ -40,7 +40,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
wantFlags: &fetchMeasurementsFlags{
measurementsURL: nil,
signatureURL: nil,
config: constants.ConfigFilename,
configPath: constants.ConfigFilename,
},
},
"url": {
@ -49,7 +49,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
wantFlags: &fetchMeasurementsFlags{
measurementsURL: urlMustParse("https://some.other.url/with/path"),
signatureURL: urlMustParse("https://some.other.url/with/path.sig"),
config: constants.ConfigFilename,
configPath: constants.ConfigFilename,
},
},
"broken url": {
@ -59,7 +59,7 @@ func TestParseFetchMeasurementsFlags(t *testing.T) {
"config": {
configFlag: "someOtherConfig.yaml",
wantFlags: &fetchMeasurementsFlags{
config: "someOtherConfig.yaml",
configPath: "someOtherConfig.yaml",
},
},
}
@ -212,8 +212,7 @@ func TestConfigFetchMeasurements(t *testing.T) {
cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually
fileHandler := file.NewHandler(afero.NewMemMapFs())
gcpConfig := config.Default()
gcpConfig.RemoveProviderExcept(cloudprovider.GCP)
gcpConfig := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP)
gcpConfig.Provider.GCP.Image = "projects/constellation-images/global/images/constellation-coreos-1658216163"
err := fileHandler.WriteYAML(constants.ConfigFilename, gcpConfig, file.OptMkdirAll)

View file

@ -0,0 +1,28 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"errors"
"fmt"
"io"
"go.uber.org/multierr"
)
func displayConfigValidationErrors(errWriter io.Writer, configError error) error {
errs := multierr.Errors(configError)
if errs != nil {
fmt.Fprintln(errWriter, "Problems validating config file:")
for _, err := range errs {
fmt.Fprintln(errWriter, "\t"+err.Error())
}
fmt.Fprintln(errWriter, "Fix the invalid entries or generate a new configuration using `constellation config generate`")
return errors.New("invalid configuration")
}
return nil
}

View file

@ -13,6 +13,7 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero"
@ -59,27 +60,27 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler,
return err
}
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath)
conf, err := config.New(fileHandler, flags.configPath)
if err != nil {
return fmt.Errorf("reading and validating config: %w", err)
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
var printedAWarning bool
if config.IsDebugImage() {
if conf.IsDebugImage() {
cmd.PrintErrln("Configured image doesn't look like a released production image. Double check image before deploying to production.")
printedAWarning = true
}
if config.IsDebugCluster() {
if conf.IsDebugCluster() {
cmd.PrintErrln("WARNING: Creating a debug cluster. This cluster is not secure and should only be used for debugging purposes.")
cmd.PrintErrln("DO NOT USE THIS CLUSTER IN PRODUCTION.")
printedAWarning = true
}
if config.IsAzureNonCVM() {
if conf.IsAzureNonCVM() {
cmd.PrintErrln("Disabling Confidential VMs is insecure. Use only for evaluation purposes.")
printedAWarning = true
if config.EnforcesIDKeyDigest() {
if conf.EnforcesIDKeyDigest() {
cmd.PrintErrln("Your config asks for enforcing the idkeydigest. This is only available on Confidential VMs. It will not be enforced.")
}
}
@ -89,20 +90,20 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler,
cmd.PrintErrln("")
}
provider := config.GetProvider()
provider := conf.GetProvider()
var instanceType string
switch provider {
case cloudprovider.AWS:
instanceType = config.Provider.AWS.InstanceType
instanceType = conf.Provider.AWS.InstanceType
if len(flags.name) > 10 {
return fmt.Errorf("cluster name on AWS must not be longer than 10 characters")
}
case cloudprovider.Azure:
instanceType = config.Provider.Azure.InstanceType
instanceType = conf.Provider.Azure.InstanceType
case cloudprovider.GCP:
instanceType = config.Provider.GCP.InstanceType
instanceType = conf.Provider.GCP.InstanceType
case cloudprovider.QEMU:
cpus := config.Provider.QEMU.VCPUs
cpus := conf.Provider.QEMU.VCPUs
instanceType = fmt.Sprintf("%d-vCPU", cpus)
}
@ -122,7 +123,7 @@ func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler,
}
spinner.Start("Creating", false)
idFile, err := creator.Create(cmd.Context(), provider, config, flags.name, instanceType, flags.controllerCount, flags.workerCount)
idFile, err := creator.Create(cmd.Context(), provider, conf, flags.name, instanceType, flags.controllerCount, flags.workerCount)
spinner.Stop()
if err != nil {
return err

View file

@ -78,9 +78,9 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return err
}
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath)
conf, err := config.New(fileHandler, flags.configPath)
if err != nil {
return fmt.Errorf("reading and validating config: %w", err)
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
var idFile clusterid.File
@ -88,7 +88,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return fmt.Errorf("reading cluster ID file: %w", err)
}
k8sVersion, err := versions.NewValidK8sVersion(config.KubernetesVersion)
k8sVersion, err := versions.NewValidK8sVersion(conf.KubernetesVersion)
if err != nil {
return fmt.Errorf("validating kubernetes version: %w", err)
}
@ -96,18 +96,18 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
cmd.PrintErrf("Warning: Constellation with Kubernetes %v is still in preview. Use only for evaluation purposes.\n", k8sVersion)
}
provider := config.GetProvider()
provider := conf.GetProvider()
checker := license.NewChecker(quotaChecker, fileHandler)
if err := checker.CheckLicense(cmd.Context(), provider, config.Provider, cmd.Printf); err != nil {
if err := checker.CheckLicense(cmd.Context(), provider, conf.Provider, cmd.Printf); err != nil {
cmd.PrintErrf("License check failed: %v", err)
}
validator, err := cloudcmd.NewValidator(provider, config)
validator, err := cloudcmd.NewValidator(provider, conf)
if err != nil {
return err
}
serviceAccURI, err := getMarshaledServiceAccountURI(provider, config, fileHandler)
serviceAccURI, err := getMarshaledServiceAccountURI(provider, conf, fileHandler)
if err != nil {
return err
}
@ -117,7 +117,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return fmt.Errorf("parsing or generating master secret from file %s: %w", flags.masterSecretPath, err)
}
helmLoader := helm.New(provider, k8sVersion)
helmDeployments, err := helmLoader.Load(provider, flags.conformance, masterSecret.Key, masterSecret.Salt, getEnforcedPCRs(provider, config), getEnforceIDKeyDigest(provider, config))
helmDeployments, err := helmLoader.Load(provider, flags.conformance, masterSecret.Key, masterSecret.Salt, getEnforcedPCRs(provider, conf), getEnforceIDKeyDigest(provider, conf))
if err != nil {
return fmt.Errorf("loading Helm charts: %w", err)
}
@ -131,10 +131,10 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
KeyEncryptionKeyId: "",
UseExistingKek: false,
CloudServiceAccountUri: serviceAccURI,
KubernetesVersion: config.KubernetesVersion,
KubernetesVersion: conf.KubernetesVersion,
HelmDeployments: helmDeployments,
EnforcedPcrs: getEnforcedPCRs(provider, config),
EnforceIdkeydigest: getEnforceIDKeyDigest(provider, config),
EnforcedPcrs: getEnforcedPCRs(provider, conf),
EnforceIdkeydigest: getEnforceIDKeyDigest(provider, conf),
ConformanceMode: flags.conformance,
}
resp, err := initCall(cmd.Context(), newDialer(validator), idFile.IP, req)

View file

@ -21,6 +21,7 @@ import (
"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared"
"github.com/edgelesssys/constellation/v2/internal/config"
@ -360,11 +361,11 @@ func TestAttestation(t *testing.T) {
issuer := &testIssuer{
Getter: oid.QEMU{},
pcrs: map[uint32][]byte{
0: []byte("ffffffffffffffffffffffffffffffff"),
1: []byte("ffffffffffffffffffffffffffffffff"),
2: []byte("ffffffffffffffffffffffffffffffff"),
3: []byte("ffffffffffffffffffffffffffffffff"),
pcrs: measurements.M{
0: measurements.PCRWithAllBytes(0xFF),
1: measurements.PCRWithAllBytes(0xFF),
2: measurements.PCRWithAllBytes(0xFF),
3: measurements.PCRWithAllBytes(0xFF),
},
}
serverCreds := atlscredentials.New(issuer, nil)
@ -389,13 +390,13 @@ func TestAttestation(t *testing.T) {
cfg := config.Default()
cfg.RemoveProviderExcept(cloudprovider.QEMU)
cfg.Provider.QEMU.Image = "some/image/location"
cfg.Provider.QEMU.Measurements[0] = []byte("00000000000000000000000000000000")
cfg.Provider.QEMU.Measurements[1] = []byte("11111111111111111111111111111111")
cfg.Provider.QEMU.Measurements[2] = []byte("22222222222222222222222222222222")
cfg.Provider.QEMU.Measurements[3] = []byte("33333333333333333333333333333333")
cfg.Provider.QEMU.Measurements[4] = []byte("44444444444444444444444444444444")
cfg.Provider.QEMU.Measurements[8] = []byte("88888888888888888888888888888888")
cfg.Provider.QEMU.Measurements[9] = []byte("99999999999999999999999999999999")
cfg.Provider.QEMU.Measurements[0] = measurements.PCRWithAllBytes(0x00)
cfg.Provider.QEMU.Measurements[1] = measurements.PCRWithAllBytes(0x11)
cfg.Provider.QEMU.Measurements[2] = measurements.PCRWithAllBytes(0x22)
cfg.Provider.QEMU.Measurements[3] = measurements.PCRWithAllBytes(0x33)
cfg.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44)
cfg.Provider.QEMU.Measurements[8] = measurements.PCRWithAllBytes(0x88)
cfg.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x99)
require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg, file.OptNone))
ctx := context.Background()
@ -411,13 +412,13 @@ func TestAttestation(t *testing.T) {
type testValidator struct {
oid.Getter
pcrs map[uint32][]byte
pcrs measurements.M
}
func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
var attestation struct {
UserData []byte
PCRs map[uint32][]byte
PCRs measurements.M
}
if err := json.Unmarshal(attDoc, &attestation); err != nil {
return nil, err
@ -433,14 +434,14 @@ func (v *testValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
type testIssuer struct {
oid.Getter
pcrs map[uint32][]byte
pcrs measurements.M
}
func (i *testIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal(
struct {
UserData []byte
PCRs map[uint32][]byte
PCRs measurements.M
}{
UserData: userData,
PCRs: i.pcrs,
@ -472,23 +473,23 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
conf.Provider.Azure.ResourceGroup = "test-resource-group"
conf.Provider.Azure.AppClientID = "01234567-0123-0123-0123-0123456789ab"
conf.Provider.Azure.ClientSecretValue = "test-client-secret"
conf.Provider.Azure.Measurements[4] = []byte("44444444444444444444444444444444")
conf.Provider.Azure.Measurements[8] = []byte("00000000000000000000000000000000")
conf.Provider.Azure.Measurements[9] = []byte("11111111111111111111111111111111")
conf.Provider.Azure.Measurements[4] = measurements.PCRWithAllBytes(0x44)
conf.Provider.Azure.Measurements[8] = measurements.PCRWithAllBytes(0x00)
conf.Provider.Azure.Measurements[9] = measurements.PCRWithAllBytes(0x11)
case cloudprovider.GCP:
conf.Provider.GCP.Region = "test-region"
conf.Provider.GCP.Project = "test-project"
conf.Provider.GCP.Image = "some/image/location"
conf.Provider.GCP.Zone = "test-zone"
conf.Provider.GCP.ServiceAccountKeyPath = "test-key-path"
conf.Provider.GCP.Measurements[4] = []byte("44444444444444444444444444444444")
conf.Provider.GCP.Measurements[8] = []byte("00000000000000000000000000000000")
conf.Provider.GCP.Measurements[9] = []byte("11111111111111111111111111111111")
conf.Provider.GCP.Measurements[4] = measurements.PCRWithAllBytes(0x44)
conf.Provider.GCP.Measurements[8] = measurements.PCRWithAllBytes(0x00)
conf.Provider.GCP.Measurements[9] = measurements.PCRWithAllBytes(0x11)
case cloudprovider.QEMU:
conf.Provider.QEMU.Image = "some/image/location"
conf.Provider.QEMU.Measurements[4] = []byte("44444444444444444444444444444444")
conf.Provider.QEMU.Measurements[8] = []byte("00000000000000000000000000000000")
conf.Provider.QEMU.Measurements[9] = []byte("11111111111111111111111111111111")
conf.Provider.QEMU.Measurements[4] = measurements.PCRWithAllBytes(0x44)
conf.Provider.QEMU.Measurements[8] = measurements.PCRWithAllBytes(0x00)
conf.Provider.QEMU.Measurements[9] = measurements.PCRWithAllBytes(0x11)
}
conf.RemoveProviderExcept(csp)

View file

@ -163,14 +163,14 @@ func prepareConfig(cmd *cobra.Command, fileHandler file.Handler) (*config.Config
// check for existing config
if configPath != "" {
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, configPath)
conf, err := config.New(fileHandler, configPath)
if err != nil {
return nil, err
return nil, displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
if config.GetProvider() != cloudprovider.QEMU {
if conf.GetProvider() != cloudprovider.QEMU {
return nil, errors.New("invalid provider for MiniConstellation cluster")
}
return config, nil
return conf, nil
}
if err := cmd.Flags().Set("config", constants.ConfigFilename); err != nil {
return nil, err

View file

@ -1,46 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"errors"
"fmt"
"io"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file"
)
func readConfig(errWriter io.Writer, fileHandler file.Handler, name string) (*config.Config, error) {
cnf, err := config.FromFile(fileHandler, name)
if err != nil {
return nil, err
}
if err := validateConfig(errWriter, cnf); err != nil {
return nil, err
}
return cnf, nil
}
func validateConfig(errWriter io.Writer, cnf *config.Config) error {
msgs, err := cnf.Validate()
if err != nil {
return fmt.Errorf("performing config validation: %w", err)
}
if len(msgs) > 0 {
fmt.Fprintln(errWriter, "Invalid fields in config file:")
for _, m := range msgs {
fmt.Fprintln(errWriter, "\t"+m)
}
fmt.Fprintln(errWriter, "Fix the invalid entries or generate a new configuration using `constellation config generate`")
return errors.New("invalid configuration")
}
return nil
}

View file

@ -1,126 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"bytes"
"testing"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateConfig(t *testing.T) {
testCases := map[string]struct {
cnf *config.Config
provider cloudprovider.Provider
wantOutput bool
wantErr bool
}{
"default config is not valid": {
cnf: config.Default(),
wantOutput: true,
wantErr: true,
},
"default Azure config is not valid": {
cnf: func() *config.Config {
cnf := config.Default()
az := cnf.Provider.Azure
cnf.Provider = config.ProviderConfig{}
cnf.Provider.Azure = az
return cnf
}(),
provider: cloudprovider.Azure,
wantOutput: true,
wantErr: true,
},
"Azure config with all required fields is valid": {
cnf: defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure),
provider: cloudprovider.Azure,
},
"default GCP config is not valid": {
cnf: func() *config.Config {
cnf := config.Default()
gcp := cnf.Provider.GCP
cnf.Provider = config.ProviderConfig{}
cnf.Provider.GCP = gcp
return cnf
}(),
provider: cloudprovider.GCP,
wantOutput: true,
wantErr: true,
},
"GCP config with all required fields is valid": {
cnf: defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP),
provider: cloudprovider.GCP,
},
"default QEMU config is not valid": {
cnf: func() *config.Config {
cnf := config.Default()
qemu := cnf.Provider.QEMU
cnf.Provider = config.ProviderConfig{}
cnf.Provider.QEMU = qemu
return cnf
}(),
provider: cloudprovider.QEMU,
wantOutput: true,
wantErr: true,
},
"QEMU config with all required fields is valid": {
cnf: defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.QEMU),
provider: cloudprovider.QEMU,
},
"config with an error": {
cnf: func() *config.Config {
cnf := config.Default()
cnf.Version = "v0"
return cnf
}(),
wantOutput: true,
wantErr: true,
},
"config without provider is not ok": {
cnf: func() *config.Config {
cnf := config.Default()
cnf.Provider = config.ProviderConfig{}
return cnf
}(),
wantErr: true,
},
"config without required provider": {
cnf: func() *config.Config {
cnf := config.Default()
cnf.Provider.Azure = nil
return cnf
}(),
provider: cloudprovider.Azure,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
out := &bytes.Buffer{}
err := validateConfig(out, tc.cnf)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantOutput, out.Len() > 0)
})
}
}

View file

@ -19,6 +19,7 @@ import (
"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
"github.com/edgelesssys/constellation/v2/internal/attestation"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/file"
@ -66,16 +67,16 @@ func recover(
return err
}
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath)
conf, err := config.New(fileHandler, flags.configPath)
if err != nil {
return fmt.Errorf("reading and validating config: %w", err)
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
provider := config.GetProvider()
provider := conf.GetProvider()
if provider == cloudprovider.Azure {
interval = 20 * time.Second // Azure LB takes a while to remove unhealthy instances
}
validator, err := cloudcmd.NewValidator(provider, config)
validator, err := cloudcmd.NewValidator(provider, conf)
if err != nil {
return err
}

View file

@ -10,6 +10,7 @@ import (
"context"
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero"
@ -43,17 +44,17 @@ func upgradeExecute(cmd *cobra.Command, upgrader cloudUpgrader, fileHandler file
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, configPath)
conf, err := config.New(fileHandler, configPath)
if err != nil {
return err
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
// TODO: validate upgrade config? Should be basic things like checking image is not an empty string
// More sophisticated validation, like making sure we don't downgrade the cluster, should be done by `constellation upgrade plan`
return upgrader.Upgrade(cmd.Context(), config.Upgrade.Image, config.Upgrade.Measurements)
return upgrader.Upgrade(cmd.Context(), conf.Upgrade.Image, conf.Upgrade.Measurements)
}
type cloudUpgrader interface {
Upgrade(ctx context.Context, image string, measurements map[uint32][]byte) error
Upgrade(ctx context.Context, image string, measurements measurements.M) error
}

View file

@ -11,6 +11,8 @@ import (
"errors"
"testing"
"github.com/edgelesssys/constellation/v2/internal/attestation/measurements"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
@ -41,7 +43,8 @@ func TestUpgradeExecute(t *testing.T) {
cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually
handler := file.NewHandler(afero.NewMemMapFs())
require.NoError(handler.WriteYAML(constants.ConfigFilename, config.Default()))
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure)
require.NoError(handler.WriteYAML(constants.ConfigFilename, cfg))
err := upgradeExecute(cmd, tc.upgrader, handler)
if tc.wantErr {
@ -57,6 +60,6 @@ type stubUpgrader struct {
err error
}
func (u stubUpgrader) Upgrade(context.Context, string, map[uint32][]byte) error {
func (u stubUpgrader) Upgrade(context.Context, string, measurements.M) error {
return u.err
}

View file

@ -73,13 +73,13 @@ func runUpgradePlan(cmd *cobra.Command, args []string) error {
func upgradePlan(cmd *cobra.Command, planner upgradePlanner,
fileHandler file.Handler, client *http.Client, rekor rekorVerifier, flags upgradePlanFlags,
) error {
config, err := config.FromFile(fileHandler, flags.configPath)
conf, err := config.New(fileHandler, flags.configPath)
if err != nil {
return err
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
// get current image version of the cluster
csp := config.GetProvider()
csp := conf.GetProvider()
version, err := getCurrentImageVersion(cmd.Context(), planner, csp)
if err != nil {
@ -108,7 +108,7 @@ func upgradePlan(cmd *cobra.Command, planner upgradePlanner,
return upgradePlanInteractive(
&nopWriteCloser{cmd.OutOrStdout()},
io.NopCloser(cmd.InOrStdin()),
flags.configPath, config, fileHandler,
flags.configPath, conf, fileHandler,
compatibleImages,
)
}

View file

@ -451,8 +451,8 @@ func TestUpgradePlan(t *testing.T) {
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
cfg := config.Default()
cfg.RemoveProviderExcept(tc.csp)
cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.csp)
require.NoError(fileHandler.WriteYAML(tc.flags.configPath, cfg))
cmd := newUpgradePlanCmd()

View file

@ -19,6 +19,7 @@ import (
"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
"github.com/edgelesssys/constellation/v2/internal/atls"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/crypto"
"github.com/edgelesssys/constellation/v2/internal/file"
@ -59,13 +60,13 @@ func verify(cmd *cobra.Command, fileHandler file.Handler, verifyClient verifyCli
return err
}
config, err := readConfig(cmd.ErrOrStderr(), fileHandler, flags.configPath)
conf, err := config.New(fileHandler, flags.configPath)
if err != nil {
return fmt.Errorf("reading and validating config: %w", err)
return displayConfigValidationErrors(cmd.ErrOrStderr(), err)
}
provider := config.GetProvider()
validators, err := cloudcmd.NewValidator(provider, config)
provider := conf.GetProvider()
validators, err := cloudcmd.NewValidator(provider, conf)
if err != nil {
return err
}