From d95ddd01d3928fae163778c4353ff37b76420e1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Wei=C3=9Fe?= <66256922+daniel-weisse@users.noreply.github.com> Date: Fri, 30 Jun 2023 16:46:05 +0200 Subject: [PATCH] helm: fix upgrade command unintentionally skipping all service upgrades (#1992) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix usage of errors.As in upgrade command implementation * Use struct pointers when working with custom errors --------- Signed-off-by: Daniel Weiße --- cli/internal/cmd/upgradeapply.go | 9 +++------ cli/internal/helm/client.go | 16 +--------------- cli/internal/helm/client_test.go | 2 +- cli/internal/kubernetes/upgrade.go | 2 +- cli/internal/kubernetes/upgrade_test.go | 2 +- internal/compatibility/compatibility.go | 4 ++-- internal/config/config.go | 10 ++++++---- internal/config/config_test.go | 2 +- internal/staticupload/staticupload.go | 17 +++++++++++------ internal/staticupload/staticupload_test.go | 18 +++++++++++------- 10 files changed, 38 insertions(+), 44 deletions(-) diff --git a/cli/internal/cmd/upgradeapply.go b/cli/internal/cmd/upgradeapply.go index 029bdecb1..c370015b0 100644 --- a/cli/internal/cmd/upgradeapply.go +++ b/cli/internal/cmd/upgradeapply.go @@ -116,13 +116,10 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, fileHandler file.Hand } if conf.GetProvider() == cloudprovider.Azure || conf.GetProvider() == cloudprovider.GCP || conf.GetProvider() == cloudprovider.AWS { + var upgradeErr *compatibility.InvalidUpgradeError err = u.handleServiceUpgrade(cmd, conf, flags) - upgradeErr := &compatibility.InvalidUpgradeError{} - noUpgradeRequiredError := &helm.NoUpgradeRequiredError{} switch { - case errors.As(err, upgradeErr): - cmd.PrintErrln(err) - case errors.As(err, noUpgradeRequiredError): + case errors.As(err, &upgradeErr): cmd.PrintErrln(err) case err != nil: return fmt.Errorf("upgrading services: %w", err) @@ -132,7 +129,7 @@ func (u *upgradeApplyCmd) upgradeApply(cmd *cobra.Command, fileHandler file.Hand switch { case errors.Is(err, kubernetes.ErrInProgress): cmd.PrintErrln("Skipping image and Kubernetes upgrades. Another upgrade is in progress.") - case errors.As(err, upgradeErr): + case errors.As(err, &upgradeErr): cmd.PrintErrln(err) case err != nil: return fmt.Errorf("upgrading NodeVersion: %w", err) diff --git a/cli/internal/helm/client.go b/cli/internal/helm/client.go index d1380101e..84dff51dc 100644 --- a/cli/internal/helm/client.go +++ b/cli/internal/helm/client.go @@ -85,10 +85,6 @@ func (c *Client) shouldUpgrade(releaseName, newVersion string, force bool) error // This may break for cert-manager or cilium if we decide to upgrade more than one minor version at a time. // Leaving it as is since it is not clear to me what kind of sanity check we could do. - if currentVersion == newVersion { - return NoUpgradeRequiredError{} - } - if !force { if err := compatibility.IsValidUpgrade(currentVersion, newVersion); err != nil { return err @@ -105,20 +101,12 @@ func (c *Client) shouldUpgrade(releaseName, newVersion string, force bool) error return nil } -// NoUpgradeRequiredError is returned if the current version is the same as the target version. -type NoUpgradeRequiredError struct{} - -func (e NoUpgradeRequiredError) Error() string { - return "no upgrade required since current version is the same as the target version" -} - // Upgrade runs a helm-upgrade on all deployments that are managed via Helm. // If the CLI receives an interrupt signal it will cancel the context. // Canceling the context will prompt helm to abort and roll back the ongoing upgrade. func (c *Client) Upgrade(ctx context.Context, config *config.Config, timeout time.Duration, allowDestructive, force bool, upgradeID string) error { upgradeErrs := []error{} upgradeReleases := []*chart.Chart{} - invalidUpgrade := &compatibility.InvalidUpgradeError{} for _, info := range []chartInfo{ciliumInfo, certManagerInfo, constellationOperatorsInfo, constellationServicesInfo} { chart, err := loadChartsDir(helmFS, info.path) @@ -136,13 +124,11 @@ func (c *Client) Upgrade(ctx context.Context, config *config.Config, timeout tim upgradeVersion = chart.Metadata.Version } + var invalidUpgrade *compatibility.InvalidUpgradeError err = c.shouldUpgrade(info.releaseName, upgradeVersion, force) - noUpgradeRequired := &NoUpgradeRequiredError{} switch { case errors.As(err, &invalidUpgrade): upgradeErrs = append(upgradeErrs, fmt.Errorf("skipping %s upgrade: %w", info.releaseName, err)) - case errors.As(err, &noUpgradeRequired): - upgradeErrs = append(upgradeErrs, fmt.Errorf("skipping %s upgrade: %w", info.releaseName, err)) case err != nil: return fmt.Errorf("should upgrade %s: %w", info.releaseName, err) case err == nil: diff --git a/cli/internal/helm/client_test.go b/cli/internal/helm/client_test.go index cb38d81e3..f04808347 100644 --- a/cli/internal/helm/client_test.go +++ b/cli/internal/helm/client_test.go @@ -32,7 +32,7 @@ func TestShouldUpgrade(t *testing.T) { "not a valid upgrade": { version: "1.0.0", assertCorrectError: func(t *testing.T, err error) bool { - target := &compatibility.InvalidUpgradeError{} + var target *compatibility.InvalidUpgradeError return assert.ErrorAs(t, err, &target) }, wantError: true, diff --git a/cli/internal/kubernetes/upgrade.go b/cli/internal/kubernetes/upgrade.go index e893a9fbc..ec68165c3 100644 --- a/cli/internal/kubernetes/upgrade.go +++ b/cli/internal/kubernetes/upgrade.go @@ -217,7 +217,7 @@ func (u *Upgrader) UpgradeNodeVersion(ctx context.Context, conf *config.Config, } upgradeErrs := []error{} - upgradeErr := &compatibility.InvalidUpgradeError{} + var upgradeErr *compatibility.InvalidUpgradeError err = u.updateImage(&nodeVersion, imageReference, imageVersion.Version, force) switch { diff --git a/cli/internal/kubernetes/upgrade_test.go b/cli/internal/kubernetes/upgrade_test.go index 842ddcbf9..31d82dba8 100644 --- a/cli/internal/kubernetes/upgrade_test.go +++ b/cli/internal/kubernetes/upgrade_test.go @@ -187,7 +187,7 @@ func TestUpgradeNodeVersion(t *testing.T) { wantUpdate: true, wantErr: true, assertCorrectError: func(t *testing.T, err error) bool { - upgradeErr := &compatibility.InvalidUpgradeError{} + var upgradeErr *compatibility.InvalidUpgradeError return assert.ErrorAs(t, err, &upgradeErr) }, }, diff --git a/internal/compatibility/compatibility.go b/internal/compatibility/compatibility.go index a22e4debe..e089a7ddd 100644 --- a/internal/compatibility/compatibility.go +++ b/internal/compatibility/compatibility.go @@ -41,12 +41,12 @@ func NewInvalidUpgradeError(from string, to string, innerErr error) *InvalidUpgr } // Unwrap returns the inner error, which is nil in this case. -func (e InvalidUpgradeError) Unwrap() error { +func (e *InvalidUpgradeError) Unwrap() error { return e.innerErr } // Error returns the String representation of this error. -func (e InvalidUpgradeError) Error() string { +func (e *InvalidUpgradeError) Error() string { return fmt.Sprintf("upgrading from %s to %s is not a valid upgrade: %s", e.from, e.to, e.innerErr) } diff --git a/internal/config/config.go b/internal/config/config.go index bbcde1d5d..553ee0143 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -52,6 +52,8 @@ const ( Version3 = "v3" defaultName = "constell" + + appRegistrationErrStr = "Azure app registrations are not supported since v2.9. Migrate to using a user assigned managed identity by following the migration guide: https://docs.edgeless.systems/constellation/reference/migration.\nPlease remove it from your config and from the Kubernetes secret in your running cluster. Ensure that the UAMI has all required permissions." ) // Config defines configuration used by CLI. @@ -395,7 +397,7 @@ func fromFile(fileHandler file.Handler, name string) (*Config, error) { return nil, fmt.Errorf("unable to find %s - use `constellation config generate` to generate it first", name) } if isAppClientIDError(err) { - return nil, UnsupportedAppRegistrationError{} + return nil, &UnsupportedAppRegistrationError{} } return nil, fmt.Errorf("could not load config from file %s: %w", name, err) } @@ -417,8 +419,8 @@ func isAppClientIDError(err error) bool { // UnsupportedAppRegistrationError is returned when the config contains configuration related to now unsupported app registrations. type UnsupportedAppRegistrationError struct{} -func (e UnsupportedAppRegistrationError) Error() string { - return "Azure app registrations are not supported since v2.9. Migrate to using a user assigned managed identity by following the migration guide: https://docs.edgeless.systems/constellation/reference/migration.\nPlease remove it from your config and from the Kubernetes secret in your running cluster. Ensure that the UAMI has all required permissions." +func (e *UnsupportedAppRegistrationError) Error() string { + return appRegistrationErrStr } // New creates a new config by: @@ -442,7 +444,7 @@ func New(fileHandler file.Handler, name string, fetcher attestationconfigapi.Fet // Read secrets from env-vars. clientSecretValue := os.Getenv(constants.EnvVarAzureClientSecretValue) if clientSecretValue != "" && c.Provider.Azure != nil { - fmt.Fprintf(os.Stderr, "WARNING: the environment variable %s is no longer used %s", constants.EnvVarAzureClientSecretValue, UnsupportedAppRegistrationError{}.Error()) + fmt.Fprintf(os.Stderr, "WARNING: the environment variable %s is no longer used %s", constants.EnvVarAzureClientSecretValue, appRegistrationErrStr) } openstackPassword := os.Getenv(constants.EnvVarOpenStackPassword) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 4c07fd747..dfce5365d 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -159,7 +159,7 @@ func TestReadConfigFile(t *testing.T) { return m }(), configName: constants.ConfigFilename, - wantedErrType: UnsupportedAppRegistrationError{}, + wantedErrType: &UnsupportedAppRegistrationError{}, }, } for name, tc := range testCases { diff --git a/internal/staticupload/staticupload.go b/internal/staticupload/staticupload.go index 3276da613..ea5f09931 100644 --- a/internal/staticupload/staticupload.go +++ b/internal/staticupload/staticupload.go @@ -83,13 +83,18 @@ type InvalidationError struct { inner error } +// NewInvalidationError creates a new InvalidationError. +func NewInvalidationError(err error) *InvalidationError { + return &InvalidationError{inner: err} +} + // Error returns the error message. -func (e InvalidationError) Error() string { +func (e *InvalidationError) Error() string { return fmt.Sprintf("invalidating CDN cache: %v", e.inner) } // Unwrap returns the inner error. -func (e InvalidationError) Unwrap() error { +func (e *InvalidationError) Unwrap() error { return e.inner } @@ -172,7 +177,7 @@ func (c *Client) invalidate(ctx context.Context, keys []string) error { // https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/Invalidation.html#InvalidationLimits func (c *Client) invalidateCacheForKeys(ctx context.Context, keys []string) (string, error) { if len(keys) > 3000 { - return "", InvalidationError{inner: fmt.Errorf("too many keys to invalidate: %d", len(keys))} + return "", NewInvalidationError(fmt.Errorf("too many keys to invalidate: %d", len(keys))) } for i, key := range keys { @@ -193,10 +198,10 @@ func (c *Client) invalidateCacheForKeys(ctx context.Context, keys []string) (str } invalidation, err := c.cdnClient.CreateInvalidation(ctx, in) if err != nil { - return "", InvalidationError{inner: fmt.Errorf("creating invalidation: %w", err)} + return "", NewInvalidationError(fmt.Errorf("creating invalidation: %w", err)) } if invalidation.Invalidation == nil || invalidation.Invalidation.Id == nil { - return "", InvalidationError{inner: fmt.Errorf("invalidation ID is not set")} + return "", NewInvalidationError(fmt.Errorf("invalidation ID is not set")) } return *invalidation.Invalidation.Id, nil } @@ -214,7 +219,7 @@ func (c *Client) waitForInvalidations(ctx context.Context) error { Id: &invalidationID, } if err := waiter.Wait(ctx, waitIn, c.cacheInvalidationWaitTimeout); err != nil { - return InvalidationError{inner: fmt.Errorf("waiting for invalidation to complete: %w", err)} + return NewInvalidationError(fmt.Errorf("waiting for invalidation to complete: %w", err)) } } c.invalidationIDs = nil diff --git a/internal/staticupload/staticupload_test.go b/internal/staticupload/staticupload_test.go index 907f98be9..8cccb6a86 100644 --- a/internal/staticupload/staticupload_test.go +++ b/internal/staticupload/staticupload_test.go @@ -108,12 +108,13 @@ func TestUpload(t *testing.T) { } _, err := client.Upload(context.Background(), tc.in) + var invalidationErr *InvalidationError if tc.wantCacheInvalidationErr { - assert.ErrorAs(err, &InvalidationError{}) + assert.ErrorAs(err, &invalidationErr) return } if tc.wantErr { - assert.False(errors.As(err, &InvalidationError{})) + assert.False(errors.As(err, &invalidationErr)) assert.Error(err) return } @@ -218,12 +219,13 @@ func TestDeleteObject(t *testing.T) { } _, err := client.DeleteObject(context.Background(), newObjectInput(tc.nilInput, tc.nilKey)) + var invalidationErr *InvalidationError if tc.wantCacheInvalidationErr { - assert.ErrorAs(err, &InvalidationError{}) + assert.ErrorAs(err, &invalidationErr) return } if tc.wantErr { - assert.False(errors.As(err, &InvalidationError{})) + assert.False(errors.As(err, &invalidationErr)) assert.Error(err) return } @@ -255,12 +257,13 @@ func TestDeleteObject(t *testing.T) { } _, err := client.DeleteObjects(context.Background(), newObjectsInput(tc.nilInput, tc.nilKey)) + var invalidationErr *InvalidationError if tc.wantCacheInvalidationErr { - assert.ErrorAs(err, &InvalidationError{}) + assert.ErrorAs(err, &invalidationErr) return } if tc.wantErr { - assert.False(errors.As(err, &InvalidationError{})) + assert.False(errors.As(err, &invalidationErr)) assert.Error(err) return } @@ -396,7 +399,8 @@ func TestFlush(t *testing.T) { err := client.Flush(context.Background()) if tc.wantCacheInvalidationErr { - assert.ErrorAs(err, &InvalidationError{}) + var invalidationErr *InvalidationError + assert.ErrorAs(err, &invalidationErr) return }