From bb2b5e1bd17d233496fbda2195ba8df3910f8b13 Mon Sep 17 00:00:00 2001 From: Otto Bittner Date: Tue, 14 Mar 2023 18:34:58 +0100 Subject: [PATCH] cli: allow users to only upgrade measurements In case only measurements are upgrades a confirmation is required. Alternatively, the `yes` flag can be used. --- cli/internal/cloudcmd/upgrade.go | 139 +++++++++++++------------- cli/internal/cloudcmd/upgrade_test.go | 6 +- cli/internal/cmd/BUILD.bazel | 2 + cli/internal/cmd/upgradeapply.go | 39 +++++++- cli/internal/cmd/upgradeapply_test.go | 15 ++- 5 files changed, 128 insertions(+), 73 deletions(-) diff --git a/cli/internal/cloudcmd/upgrade.go b/cli/internal/cloudcmd/upgrade.go index 221b1f9c7..6566df901 100644 --- a/cli/internal/cloudcmd/upgrade.go +++ b/cli/internal/cloudcmd/upgrade.go @@ -139,7 +139,7 @@ func (u *Upgrader) UpgradeNodeVersion(ctx context.Context, conf *config.Config) return errors.Join(upgradeErrs...) } - if err := u.updateMeasurements(ctx, conf.GetMeasurements()); err != nil { + if err := u.UpdateMeasurements(ctx, conf.GetMeasurements()); err != nil { return fmt.Errorf("updating measurements: %w", err) } @@ -161,6 +161,76 @@ func (u *Upgrader) UpgradeNodeVersion(ctx context.Context, conf *config.Config) return errors.Join(upgradeErrs...) } +// KubernetesVersion returns the version of Kubernetes the Constellation is currently running on. +func (u *Upgrader) KubernetesVersion() (string, error) { + return u.stableInterface.kubernetesVersion() +} + +// CurrentImage returns the currently used image version of the cluster. +func (u *Upgrader) CurrentImage(ctx context.Context) (string, error) { + nodeVersion, err := u.getConstellationVersion(ctx) + if err != nil { + return "", fmt.Errorf("getting constellation-version: %w", err) + } + return nodeVersion.Spec.ImageVersion, nil +} + +// CurrentKubernetesVersion returns the currently used Kubernetes version. +func (u *Upgrader) CurrentKubernetesVersion(ctx context.Context) (string, error) { + nodeVersion, err := u.getConstellationVersion(ctx) + if err != nil { + return "", fmt.Errorf("getting constellation-version: %w", err) + } + return nodeVersion.Spec.KubernetesClusterVersion, nil +} + +// UpdateMeasurements fetches the cluster's measurements, compares them to a set of new measurements +// and updates the cluster's measurements if they are different from the new ones. +func (u *Upgrader) UpdateMeasurements(ctx context.Context, newMeasurements measurements.M) error { + currentMeasurements, existingConf, err := u.GetClusterMeasurements(ctx) + if err != nil { + return fmt.Errorf("getting cluster measurements: %w", err) + } + if currentMeasurements.EqualTo(newMeasurements) { + fmt.Fprintln(u.outWriter, "Cluster is already using the chosen measurements, skipping measurements upgrade") + return nil + } + + // backup of previous measurements + existingConf.Data["oldMeasurements"] = existingConf.Data[constants.MeasurementsFilename] + + measurementsJSON, err := json.Marshal(newMeasurements) + if err != nil { + return fmt.Errorf("marshaling measurements: %w", err) + } + existingConf.Data[constants.MeasurementsFilename] = string(measurementsJSON) + u.log.Debugf("Triggering measurements config map update now") + if _, err = u.stableInterface.updateConfigMap(ctx, existingConf); err != nil { + return fmt.Errorf("setting new measurements: %w", err) + } + + fmt.Fprintln(u.outWriter, "Successfully updated the cluster's expected measurements") + return nil +} + +// GetClusterMeasurements fetches the join-config configmap from the cluster, extracts the measurements +// and returns both the full configmap and the measurements. +func (u *Upgrader) GetClusterMeasurements(ctx context.Context) (measurements.M, *corev1.ConfigMap, error) { + existingConf, err := u.stableInterface.getCurrentConfigMap(ctx, constants.JoinConfigMap) + if err != nil { + return measurements.M{}, &corev1.ConfigMap{}, fmt.Errorf("retrieving current measurements: %w", err) + } + if _, ok := existingConf.Data[constants.MeasurementsFilename]; !ok { + return measurements.M{}, &corev1.ConfigMap{}, errors.New("measurements missing from join-config") + } + var currentMeasurements measurements.M + if err := json.Unmarshal([]byte(existingConf.Data[constants.MeasurementsFilename]), ¤tMeasurements); err != nil { + return measurements.M{}, &corev1.ConfigMap{}, fmt.Errorf("retrieving current measurements: %w", err) + } + + return currentMeasurements, existingConf, nil +} + // applyComponentsCM applies the k8s components ConfigMap to the cluster. func (u *Upgrader) applyComponentsCM(ctx context.Context, components *corev1.ConfigMap) error { _, err := u.stableInterface.createConfigMap(ctx, components) @@ -236,29 +306,6 @@ func (u *Upgrader) updateK8s(nodeVersion *updatev1alpha1.NodeVersion, newCluster return &configMap, nil } -// KubernetesVersion returns the version of Kubernetes the Constellation is currently running on. -func (u *Upgrader) KubernetesVersion() (string, error) { - return u.stableInterface.kubernetesVersion() -} - -// CurrentImage returns the currently used image version of the cluster. -func (u *Upgrader) CurrentImage(ctx context.Context) (string, error) { - nodeVersion, err := u.getConstellationVersion(ctx) - if err != nil { - return "", fmt.Errorf("getting constellation-version: %w", err) - } - return nodeVersion.Spec.ImageVersion, nil -} - -// CurrentKubernetesVersion returns the currently used Kubernetes version. -func (u *Upgrader) CurrentKubernetesVersion(ctx context.Context) (string, error) { - nodeVersion, err := u.getConstellationVersion(ctx) - if err != nil { - return "", fmt.Errorf("getting constellation-version: %w", err) - } - return nodeVersion.Spec.KubernetesClusterVersion, nil -} - // getFromConstellationVersion queries the constellation-version object for a given field. func (u *Upgrader) getConstellationVersion(ctx context.Context) (updatev1alpha1.NodeVersion, error) { raw, err := u.dynamicInterface.getCurrent(ctx, "constellation-version") @@ -273,50 +320,6 @@ func (u *Upgrader) getConstellationVersion(ctx context.Context) (updatev1alpha1. return nodeVersion, nil } -func (u *Upgrader) updateMeasurements(ctx context.Context, newMeasurements measurements.M) error { - existingConf, err := u.stableInterface.getCurrentConfigMap(ctx, constants.JoinConfigMap) - if err != nil { - return fmt.Errorf("retrieving current measurements: %w", err) - } - if _, ok := existingConf.Data[constants.MeasurementsFilename]; !ok { - return errors.New("measurements missing from join-config") - } - var currentMeasurements measurements.M - if err := json.Unmarshal([]byte(existingConf.Data[constants.MeasurementsFilename]), ¤tMeasurements); err != nil { - return fmt.Errorf("retrieving current measurements: %w", err) - } - if currentMeasurements.EqualTo(newMeasurements) { - fmt.Fprintln(u.outWriter, "Cluster is already using the chosen measurements, skipping measurements upgrade") - return nil - } - - // don't allow potential security downgrades by setting the warnOnly flag to true - for k, newM := range newMeasurements { - if currentM, ok := currentMeasurements[k]; ok && - currentM.ValidationOpt != measurements.WarnOnly && - newM.ValidationOpt == measurements.WarnOnly { - return fmt.Errorf("setting enforced measurement %d to warn only: not allowed", k) - } - } - - // backup of previous measurements - existingConf.Data["oldMeasurements"] = existingConf.Data[constants.MeasurementsFilename] - - measurementsJSON, err := json.Marshal(newMeasurements) - if err != nil { - return fmt.Errorf("marshaling measurements: %w", err) - } - existingConf.Data[constants.MeasurementsFilename] = string(measurementsJSON) - u.log.Debugf("Triggering measurements config map update now") - _, err = u.stableInterface.updateConfigMap(ctx, existingConf) - if err != nil { - return fmt.Errorf("setting new measurements: %w", err) - } - - fmt.Fprintln(u.outWriter, "Successfully updated the cluster's expected measurements") - return nil -} - // upgradeInProgress checks if an upgrade is in progress. // Returns true with errors as it's the "safer" response. If caller does not check err they at least won't update the cluster. func upgradeInProgress(nodeVersion updatev1alpha1.NodeVersion) bool { diff --git a/cli/internal/cloudcmd/upgrade_test.go b/cli/internal/cloudcmd/upgrade_test.go index bd6c19d60..7e136d10b 100644 --- a/cli/internal/cloudcmd/upgrade_test.go +++ b/cli/internal/cloudcmd/upgrade_test.go @@ -224,7 +224,7 @@ func TestUpdateMeasurements(t *testing.T) { 0: measurements.WithAllBytes(0xAA, measurements.Enforce), }, }, - "trying to set warnOnly to true results in error": { + "setting warnOnly to true is allowed": { updater: &stubStableClient{ configMaps: map[string]*corev1.ConfigMap{ constants.JoinConfigMap: newJoinConfigMap(`{"0":{"expected":"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA","warnOnly":false}}`), @@ -233,7 +233,7 @@ func TestUpdateMeasurements(t *testing.T) { newMeasurements: measurements.M{ 0: measurements.WithAllBytes(0xAA, measurements.WarnOnly), }, - wantErr: true, + wantUpdate: true, }, "setting warnOnly to false is allowed": { updater: &stubStableClient{ @@ -271,7 +271,7 @@ func TestUpdateMeasurements(t *testing.T) { log: logger.NewTest(t), } - err := upgrader.updateMeasurements(context.Background(), tc.newMeasurements) + err := upgrader.UpdateMeasurements(context.Background(), tc.newMeasurements) if tc.wantErr { assert.Error(err) return diff --git a/cli/internal/cmd/BUILD.bazel b/cli/internal/cmd/BUILD.bazel index 0e5e6d865..c1a4822a6 100644 --- a/cli/internal/cmd/BUILD.bazel +++ b/cli/internal/cmd/BUILD.bazel @@ -72,6 +72,7 @@ go_library( "@com_github_siderolabs_talos_pkg_machinery//config/encoder", "@com_github_spf13_afero//:afero", "@com_github_spf13_cobra//:cobra", + "@io_k8s_api//core/v1:core", "@io_k8s_apimachinery//pkg/runtime", "@io_k8s_client_go//tools/clientcmd", "@io_k8s_client_go//tools/clientcmd/api/latest", @@ -135,6 +136,7 @@ go_test( "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", "@in_gopkg_yaml_v3//:yaml_v3", + "@io_k8s_api//core/v1:core", "@org_golang_google_grpc//:go_default_library", "@org_golang_google_grpc//codes", "@org_golang_google_grpc//status", diff --git a/cli/internal/cmd/upgradeapply.go b/cli/internal/cmd/upgradeapply.go index 8239be475..7c9d42aaf 100644 --- a/cli/internal/cmd/upgradeapply.go +++ b/cli/internal/cmd/upgradeapply.go @@ -14,11 +14,13 @@ import ( "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" "github.com/edgelesssys/constellation/v2/cli/internal/helm" + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/compatibility" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/file" "github.com/spf13/afero" "github.com/spf13/cobra" + corev1 "k8s.io/api/core/v1" ) func newUpgradeApplyCmd() *cobra.Command { @@ -31,7 +33,8 @@ func newUpgradeApplyCmd() *cobra.Command { } cmd.Flags().BoolP("yes", "y", false, "run upgrades without further confirmation\n"+ - "WARNING: might delete your resources in case you are using cert-manager in your cluster. Please read the docs.") + "WARNING: might delete your resources in case you are using cert-manager in your cluster. Please read the docs.\n"+ + "WARNING: might unintentionally overwrite measurements in the running cluster.") cmd.Flags().Duration("timeout", 3*time.Minute, "change helm upgrade timeout\n"+ "Might be useful for slow connections or big clusters.") if err := cmd.Flags().MarkHidden("timeout"); err != nil { @@ -96,6 +99,38 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, fileHandler file.Hand return fmt.Errorf("upgrading NodeVersion: %w", err) } + // If an image upgrade was just executed there won't be a diff. The function will return nil in that case. + if err := u.upgradeMeasurementsIfDiff(cmd, conf.GetMeasurements(), flags); err != nil { + return fmt.Errorf("upgrading measurements: %w", err) + } + + return nil +} + +// upgradeMeasurementsIfDiff checks if the locally configured measurements are different from the cluster's measurements. +// If so the function will ask the user to confirm (if --yes is not set) and upgrade the measurements only. +func (u *upgradeApplyCmd) upgradeMeasurementsIfDiff(cmd *cobra.Command, newMeasurements measurements.M, flags upgradeApplyFlags) error { + clusterMeasurements, _, err := u.upgrader.GetClusterMeasurements(cmd.Context()) + if err != nil { + return fmt.Errorf("getting cluster measurements: %w", err) + } + if clusterMeasurements.EqualTo(newMeasurements) { + return nil + } + + if !flags.yes { + ok, err := askToConfirm(cmd, "You are about to change your cluster's measurements. Are you sure you want to continue?") + if err != nil { + return fmt.Errorf("asking for confirmation: %w", err) + } + if !ok { + cmd.Println("Aborting upgrade.") + return nil + } + } + if err := u.upgrader.UpdateMeasurements(cmd.Context(), newMeasurements); err != nil { + return fmt.Errorf("updating measurements: %w", err) + } return nil } @@ -153,4 +188,6 @@ type upgradeApplyFlags struct { type cloudUpgrader interface { UpgradeNodeVersion(ctx context.Context, conf *config.Config) error UpgradeHelmServices(ctx context.Context, config *config.Config, timeout time.Duration, allowDestructive bool) error + UpdateMeasurements(ctx context.Context, newMeasurements measurements.M) error + GetClusterMeasurements(ctx context.Context) (measurements.M, *corev1.ConfigMap, error) } diff --git a/cli/internal/cmd/upgradeapply_test.go b/cli/internal/cmd/upgradeapply_test.go index 9ffd68996..67dbfb88a 100644 --- a/cli/internal/cmd/upgradeapply_test.go +++ b/cli/internal/cmd/upgradeapply_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd" + "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/constants" @@ -21,6 +22,7 @@ import ( "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" ) func TestUpgradeApply(t *testing.T) { @@ -53,12 +55,15 @@ func TestUpgradeApply(t *testing.T) { cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually cmd.Flags().Bool("force", true, "") // register persistent flag manually + err := cmd.Flags().Set("yes", "true") + require.NoError(err) + handler := file.NewHandler(afero.NewMemMapFs()) cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure) require.NoError(handler.WriteYAML(constants.ConfigFilename, cfg)) upgrader := upgradeApplyCmd{upgrader: tc.upgrader, log: logger.NewTest(t)} - err := upgrader.upgradeApply(cmd, handler) + err = upgrader.upgradeApply(cmd, handler) if tc.wantErr { assert.Error(err) } else { @@ -80,3 +85,11 @@ func (u stubUpgrader) UpgradeNodeVersion(_ context.Context, _ *config.Config) er func (u stubUpgrader) UpgradeHelmServices(_ context.Context, _ *config.Config, _ time.Duration, _ bool) error { return u.helmErr } + +func (u stubUpgrader) UpdateMeasurements(_ context.Context, _ measurements.M) error { + return nil +} + +func (u stubUpgrader) GetClusterMeasurements(_ context.Context) (measurements.M, *corev1.ConfigMap, error) { + return measurements.M{}, &corev1.ConfigMap{}, nil +}