/* Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ package cmd import ( "bytes" "context" "fmt" "strings" "testing" "time" "github.com/edgelesssys/constellation/v2/bootstrapper/initproto" "github.com/edgelesssys/constellation/v2/cli/internal/cmd/pathprefix" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/gcpshared" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constellation" "github.com/edgelesssys/constellation/v2/internal/constellation/helm" "github.com/edgelesssys/constellation/v2/internal/constellation/state" "github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/kms/uri" "github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/semver" "github.com/edgelesssys/constellation/v2/internal/versions" "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" k8serrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/client-go/tools/clientcmd" k8sclientapi "k8s.io/client-go/tools/clientcmd/api" ) func TestInitArgumentValidation(t *testing.T) { assert := assert.New(t) cmd := NewInitCmd() assert.NoError(cmd.ValidateArgs(nil)) assert.Error(cmd.ValidateArgs([]string{"something"})) assert.Error(cmd.ValidateArgs([]string{"sth", "sth"})) } // preInitStateFile returns a state file satisfying the pre-init state file // constraints. func preInitStateFile(csp cloudprovider.Provider) *state.State { s := defaultStateFile(csp) s.ClusterValues = state.ClusterValues{} return s } func TestInitialize(t *testing.T) { respKubeconfig := k8sclientapi.Config{ Clusters: map[string]*k8sclientapi.Cluster{ "cluster": { Server: "https://192.0.2.1:6443", }, }, } respKubeconfigBytes, err := clientcmd.Write(respKubeconfig) require.NoError(t, err) gcpServiceAccKey := &gcpshared.ServiceAccountKey{ Type: "service_account", ProjectID: "project_id", PrivateKeyID: "key_id", PrivateKey: "key", ClientEmail: "client_email", ClientID: "client_id", AuthURI: "auth_uri", TokenURI: "token_uri", AuthProviderX509CertURL: "cert", ClientX509CertURL: "client_cert", } testInitOutput := constellation.InitOutput{ Kubeconfig: respKubeconfigBytes, OwnerID: "ownerID", ClusterID: "clusterID", } serviceAccPath := "/test/service-account.json" testCases := map[string]struct { provider cloudprovider.Provider stateFile *state.State configMutator func(*config.Config) serviceAccKey *gcpshared.ServiceAccountKey initOutput constellation.InitOutput initErr error retriable bool masterSecretShouldExist bool wantErr bool }{ "initialize some gcp instances": { provider: cloudprovider.GCP, stateFile: preInitStateFile(cloudprovider.GCP), configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, serviceAccKey: gcpServiceAccKey, initOutput: testInitOutput, }, "initialize some azure instances": { provider: cloudprovider.Azure, stateFile: preInitStateFile(cloudprovider.Azure), initOutput: testInitOutput, }, "initialize some qemu instances": { provider: cloudprovider.QEMU, stateFile: preInitStateFile(cloudprovider.QEMU), initOutput: testInitOutput, }, "non retriable error": { provider: cloudprovider.QEMU, stateFile: preInitStateFile(cloudprovider.QEMU), initErr: &constellation.NonRetriableInitError{Err: assert.AnError}, retriable: false, masterSecretShouldExist: true, wantErr: true, }, "non retriable error with failed log collection": { provider: cloudprovider.QEMU, stateFile: preInitStateFile(cloudprovider.QEMU), initErr: &constellation.NonRetriableInitError{Err: assert.AnError, LogCollectionErr: assert.AnError}, retriable: false, masterSecretShouldExist: true, wantErr: true, }, "invalid state file": { provider: cloudprovider.GCP, stateFile: &state.State{Version: "invalid"}, configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, serviceAccKey: gcpServiceAccKey, initOutput: testInitOutput, retriable: true, wantErr: true, }, "empty state file": { provider: cloudprovider.GCP, stateFile: &state.State{}, configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, serviceAccKey: gcpServiceAccKey, initOutput: testInitOutput, retriable: true, wantErr: true, }, "no state file": { provider: cloudprovider.GCP, configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, serviceAccKey: gcpServiceAccKey, initOutput: testInitOutput, retriable: true, wantErr: true, }, "init call fails": { provider: cloudprovider.GCP, configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, stateFile: preInitStateFile(cloudprovider.GCP), serviceAccKey: gcpServiceAccKey, initErr: &constellation.NonRetriableInitError{Err: assert.AnError}, retriable: false, masterSecretShouldExist: true, wantErr: true, }, "k8s version without v works": { provider: cloudprovider.Azure, stateFile: preInitStateFile(cloudprovider.Azure), initOutput: testInitOutput, configMutator: func(c *config.Config) { res, err := versions.NewValidK8sVersion(strings.TrimPrefix(string(versions.Default), "v"), true) require.NoError(t, err) c.KubernetesVersion = res }, }, "outdated k8s patch version doesn't work": { provider: cloudprovider.Azure, stateFile: preInitStateFile(cloudprovider.Azure), initOutput: testInitOutput, configMutator: func(c *config.Config) { v, err := semver.New(versions.SupportedK8sVersions()[0]) require.NoError(t, err) outdatedPatchVer := semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String() c.KubernetesVersion = versions.ValidK8sVersion(outdatedPatchVer) }, wantErr: true, retriable: true, // doesn't need to show retriable error message }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) require := require.New(t) // Command cmd := NewInitCmd() var out bytes.Buffer cmd.SetOut(&out) var errOut bytes.Buffer cmd.SetErr(&errOut) // File system preparation fs := afero.NewMemMapFs() fileHandler := file.NewHandler(fs) config := defaultConfigWithExpectedMeasurements(t, config.Default(), tc.provider) if tc.configMutator != nil { tc.configMutator(config) } require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, config, file.OptNone)) if tc.stateFile != nil { require.NoError(tc.stateFile.WriteToFile(fileHandler, constants.StateFilename)) } if tc.serviceAccKey != nil { require.NoError(fileHandler.WriteJSON(serviceAccPath, tc.serviceAccKey, file.OptNone)) } ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 4*time.Second) defer cancel() cmd.SetContext(ctx) i := &applyCmd{ fileHandler: fileHandler, flags: applyFlags{ rootFlags: rootFlags{force: true}, skipPhases: newPhases(skipInfrastructurePhase), }, log: logger.NewTest(t), spinner: &nopSpinner{}, merger: &stubMerger{}, applier: &stubConstellApplier{ masterSecret: uri.MasterSecret{ Key: bytes.Repeat([]byte{0x01}, 32), Salt: bytes.Repeat([]byte{0x02}, 32), }, measurementSalt: bytes.Repeat([]byte{0x03}, 32), initErr: tc.initErr, initOutput: tc.initOutput, stubKubernetesUpgrader: &stubKubernetesUpgrader{ // On init, no attestation config exists yet getClusterAttestationConfigErr: k8serrors.NewNotFound(schema.GroupResource{}, ""), }, helmApplier: &stubHelmApplier{}, }, } err := i.apply(cmd, stubAttestationFetcher{}, "test") if tc.wantErr { assert.Error(err) fmt.Println(err) if !tc.retriable { assert.Contains(errOut.String(), "This error is not recoverable") } else { assert.Empty(errOut.String()) } if !tc.masterSecretShouldExist { _, err = fileHandler.Stat(constants.MasterSecretFilename) assert.Error(err) } return } require.NoError(err) // assert.Contains(out.String(), base64.StdEncoding.EncodeToString([]byte("ownerID"))) assert.Contains(out.String(), "clusterID") var secret uri.MasterSecret assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret)) assert.NotEmpty(secret.Key) assert.NotEmpty(secret.Salt) }) } } type stubHelmApplier struct { err error } func (s stubHelmApplier) PrepareHelmCharts( _ helm.Options, _ *state.State, _ string, _ uri.MasterSecret, _ *config.OpenStackConfig, ) (helm.Applier, bool, error) { return stubRunner{}, false, s.err } type stubRunner struct { applyErr error saveChartsErr error } func (s stubRunner) Apply(_ context.Context) error { return s.applyErr } func (s stubRunner) SaveCharts(_ string, _ file.Handler) error { return s.saveChartsErr } func TestWriteOutput(t *testing.T) { assert := assert.New(t) require := require.New(t) clusterEndpoint := "cluster-endpoint" expectedKubeconfig := k8sclientapi.Config{ Clusters: map[string]*k8sclientapi.Cluster{ "cluster": { Server: fmt.Sprintf("https://%s:6443", clusterEndpoint), }, }, } expectedKubeconfigBytes, err := clientcmd.Write(expectedKubeconfig) require.NoError(err) respKubeconfig := k8sclientapi.Config{ Clusters: map[string]*k8sclientapi.Cluster{ "cluster": { Server: "https://cluster-endpoint:6443", }, }, } respKubeconfigBytes, err := clientcmd.Write(respKubeconfig) require.NoError(err) resp := &initproto.InitResponse{ Kind: &initproto.InitResponse_InitSuccess{ InitSuccess: &initproto.InitSuccessResponse{ OwnerId: []byte("ownerID"), ClusterId: []byte("clusterID"), Kubeconfig: respKubeconfigBytes, }, }, } ownerID := string(resp.GetInitSuccess().GetOwnerId()) clusterID := string(resp.GetInitSuccess().GetClusterId()) initOutput := constellation.InitOutput{ OwnerID: "ownerID", ClusterID: "clusterID", Kubeconfig: respKubeconfigBytes, } measurementSalt := []byte{0x41} expectedStateFile := &state.State{ Version: state.Version1, ClusterValues: state.ClusterValues{ ClusterID: clusterID, OwnerID: ownerID, MeasurementSalt: measurementSalt, }, Infrastructure: state.Infrastructure{ APIServerCertSANs: []string{}, InitSecret: []byte{}, ClusterEndpoint: clusterEndpoint, }, } var out bytes.Buffer testFs := afero.NewMemMapFs() fileHandler := file.NewHandler(testFs) stateFile := state.New().SetInfrastructure(state.Infrastructure{ ClusterEndpoint: clusterEndpoint, }) i := &applyCmd{ fileHandler: fileHandler, spinner: &nopSpinner{}, merger: &stubMerger{}, log: logger.NewTest(t), applier: constellation.NewApplier(logger.NewTest(t), &nopSpinner{}, nil), } err = i.writeInitOutput(stateFile, initOutput, false, &out, measurementSalt) require.NoError(err) assert.Contains(out.String(), clusterID) assert.Contains(out.String(), constants.AdminConfFilename) afs := afero.Afero{Fs: testFs} adminConf, err := afs.ReadFile(constants.AdminConfFilename) assert.NoError(err) assert.Contains(string(adminConf), clusterEndpoint) assert.Equal(string(expectedKubeconfigBytes), string(adminConf)) fh := file.NewHandler(afs) readStateFile, err := state.ReadFromFile(fh, constants.StateFilename) assert.NoError(err) assert.Equal(expectedStateFile, readStateFile) out.Reset() require.NoError(afs.Remove(constants.AdminConfFilename)) // test custom workspace i.flags.pathPrefixer = pathprefix.New("/some/path") err = i.writeInitOutput(stateFile, initOutput, true, &out, measurementSalt) require.NoError(err) assert.Contains(out.String(), clusterID) assert.Contains(out.String(), i.flags.pathPrefixer.PrefixPrintablePath(constants.AdminConfFilename)) out.Reset() // File is written to current working dir, we simply pass the workspace for generating readable user output require.NoError(afs.Remove(constants.AdminConfFilename)) i.flags.pathPrefixer = pathprefix.PathPrefixer{} // test config merging err = i.writeInitOutput(stateFile, initOutput, true, &out, measurementSalt) require.NoError(err) assert.Contains(out.String(), clusterID) assert.Contains(out.String(), constants.AdminConfFilename) assert.Contains(out.String(), "Constellation kubeconfig merged with default config") assert.Contains(out.String(), "You can now connect to your cluster") out.Reset() require.NoError(afs.Remove(constants.AdminConfFilename)) // test config merging with env vars set i.merger = &stubMerger{envVar: "/some/path/to/kubeconfig"} err = i.writeInitOutput(stateFile, initOutput, true, &out, measurementSalt) require.NoError(err) assert.Contains(out.String(), clusterID) assert.Contains(out.String(), constants.AdminConfFilename) assert.Contains(out.String(), "Constellation kubeconfig merged with default config") assert.Contains(out.String(), "Warning: KUBECONFIG environment variable is set") } func TestGenerateMasterSecret(t *testing.T) { testCases := map[string]struct { createFileFunc func(handler file.Handler) error fs func() afero.Fs wantErr bool }{ "file already exists": { fs: afero.NewMemMapFs, createFileFunc: func(handler file.Handler) error { return handler.WriteJSON( constants.MasterSecretFilename, uri.MasterSecret{Key: []byte("constellation-master-secret"), Salt: []byte("constellation-32Byte-length-salt")}, file.OptNone, ) }, wantErr: true, }, "file does not exist": { createFileFunc: func(handler file.Handler) error { return nil }, fs: afero.NewMemMapFs, wantErr: false, }, "file not writeable": { createFileFunc: func(handler file.Handler) error { return nil }, fs: func() afero.Fs { return afero.NewReadOnlyFs(afero.NewMemMapFs()) }, wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { assert := assert.New(t) require := require.New(t) fileHandler := file.NewHandler(tc.fs()) require.NoError(tc.createFileFunc(fileHandler)) var out bytes.Buffer i := &applyCmd{ fileHandler: fileHandler, log: logger.NewTest(t), applier: constellation.NewApplier(logger.NewTest(t), &nopSpinner{}, nil), } secret, err := i.generateAndPersistMasterSecret(&out) if tc.wantErr { assert.Error(err) } else { assert.NoError(err) require.Contains(out.String(), constants.MasterSecretFilename) var masterSecret uri.MasterSecret require.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &masterSecret)) assert.Equal(masterSecret.Key, secret.Key) assert.Equal(masterSecret.Salt, secret.Salt) } }) } } type stubMerger struct { envVar string mergeErr error } func (m *stubMerger) mergeConfigs(string, file.Handler) error { return m.mergeErr } func (m *stubMerger) kubeconfigEnvVar() string { return m.envVar } func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, csp cloudprovider.Provider) *config.Config { t.Helper() conf.RemoveProviderAndAttestationExcept(csp) conf.Image = constants.BinaryVersion().String() conf.Name = "kubernetes" var zone, instanceType, diskType string switch csp { case cloudprovider.AWS: conf.Provider.AWS.Region = "test-region-2" conf.Provider.AWS.Zone = "test-zone-2c" conf.Provider.AWS.IAMProfileControlPlane = "test-iam-profile" conf.Provider.AWS.IAMProfileWorkerNodes = "test-iam-profile" conf.Attestation.AWSSEVSNP.Measurements[4] = measurements.WithAllBytes(0x44, measurements.Enforce, measurements.PCRMeasurementLength) conf.Attestation.AWSSEVSNP.Measurements[9] = measurements.WithAllBytes(0x11, measurements.Enforce, measurements.PCRMeasurementLength) conf.Attestation.AWSSEVSNP.Measurements[12] = measurements.WithAllBytes(0xcc, measurements.Enforce, measurements.PCRMeasurementLength) zone = "test-zone-2c" instanceType = "c6a.xlarge" diskType = "gp3" case cloudprovider.Azure: conf.Provider.Azure.SubscriptionID = "01234567-0123-0123-0123-0123456789ab" conf.Provider.Azure.TenantID = "01234567-0123-0123-0123-0123456789ab" conf.Provider.Azure.Location = "test-location" conf.Provider.Azure.UserAssignedIdentity = "test-identity" conf.Provider.Azure.ResourceGroup = "test-resource-group" conf.Attestation.AzureSEVSNP.Measurements[4] = measurements.WithAllBytes(0x44, measurements.Enforce, measurements.PCRMeasurementLength) conf.Attestation.AzureSEVSNP.Measurements[9] = measurements.WithAllBytes(0x11, measurements.Enforce, measurements.PCRMeasurementLength) conf.Attestation.AzureSEVSNP.Measurements[12] = measurements.WithAllBytes(0xcc, measurements.Enforce, measurements.PCRMeasurementLength) instanceType = "Standard_DC4as_v5" diskType = "StandardSSD_LRS" case cloudprovider.GCP: conf.Provider.GCP.Region = "test-region" conf.Provider.GCP.Project = "test-project" conf.Provider.GCP.Zone = "test-zone" conf.Provider.GCP.ServiceAccountKeyPath = "test-key-path" conf.Attestation.GCPSEVES.Measurements[4] = measurements.WithAllBytes(0x44, measurements.Enforce, measurements.PCRMeasurementLength) conf.Attestation.GCPSEVES.Measurements[9] = measurements.WithAllBytes(0x11, measurements.Enforce, measurements.PCRMeasurementLength) conf.Attestation.GCPSEVES.Measurements[12] = measurements.WithAllBytes(0xcc, measurements.Enforce, measurements.PCRMeasurementLength) zone = "europe-west3-b" instanceType = "n2d-standard-4" diskType = "pd-ssd" case cloudprovider.QEMU: conf.Attestation.QEMUVTPM.Measurements[4] = measurements.WithAllBytes(0x44, measurements.Enforce, measurements.PCRMeasurementLength) conf.Attestation.QEMUVTPM.Measurements[9] = measurements.WithAllBytes(0x11, measurements.Enforce, measurements.PCRMeasurementLength) conf.Attestation.QEMUVTPM.Measurements[12] = measurements.WithAllBytes(0xcc, measurements.Enforce, measurements.PCRMeasurementLength) } for groupName, group := range conf.NodeGroups { group.Zone = zone group.InstanceType = instanceType group.StateDiskType = diskType conf.NodeGroups[groupName] = group } return conf }