From 4d6a7fa7591636adbf2fc1fb9b0d551b4d4cf751 Mon Sep 17 00:00:00 2001 From: Moritz Sanft <58110325+msanft@users.noreply.github.com> Date: Fri, 1 Dec 2023 08:37:52 +0100 Subject: [PATCH] license: refactor license check to be agnostic of input (#2659) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * license: refactor license check to be agnostic of input * license: remove unused code * cli: only check license file in enterprise version Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> * bazel: fix enterprise CLI build * bazel: add keep directive * Update internal/constellation/apply.go Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> * license: check for return value --------- Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com> Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com> --- cli/internal/cmd/BUILD.bazel | 6 ++- cli/internal/cmd/apply.go | 25 +++++---- cli/internal/cmd/apply_test.go | 6 +++ cli/internal/cmd/create_test.go | 4 +- cli/internal/cmd/init_test.go | 12 +---- cli/internal/cmd/license_enterprise.go | 56 ++++++++++++++++++++ cli/internal/cmd/license_oss.go | 20 +++++++ cli/internal/cmd/upgradeapply_test.go | 3 +- internal/constellation/BUILD.bazel | 9 +++- internal/constellation/apply.go | 47 +++++++++++++++++ internal/license/BUILD.bazel | 5 -- internal/license/checker_enterprise.go | 36 ++----------- internal/license/checker_oss.go | 8 ++- internal/license/file.go | 16 ++---- internal/license/file_test.go | 72 ++++++++++---------------- 15 files changed, 201 insertions(+), 124 deletions(-) create mode 100644 cli/internal/cmd/license_enterprise.go create mode 100644 cli/internal/cmd/license_oss.go create mode 100644 internal/constellation/apply.go diff --git a/cli/internal/cmd/BUILD.bazel b/cli/internal/cmd/BUILD.bazel index 78491129d..e605f26a0 100644 --- a/cli/internal/cmd/BUILD.bazel +++ b/cli/internal/cmd/BUILD.bazel @@ -25,6 +25,9 @@ go_library( "iamdestroy.go", "iamupgradeapply.go", "init.go", + # keep + "license_enterprise.go", + "license_oss.go", "log.go", "maapatch.go", "mini.go", @@ -68,6 +71,7 @@ go_library( "//internal/config/instancetypes", "//internal/config/migration", "//internal/constants", + "//internal/constellation", "//internal/crypto", "//internal/featureset", "//internal/file", @@ -77,6 +81,7 @@ go_library( "//internal/helm", "//internal/kms/uri", "//internal/kubecmd", + # keep "//internal/license", "//internal/logger", "//internal/maa", @@ -167,7 +172,6 @@ go_test( "//internal/helm", "//internal/kms/uri", "//internal/kubecmd", - "//internal/license", "//internal/logger", "//internal/semver", "//internal/state", diff --git a/cli/internal/cmd/apply.go b/cli/internal/cmd/apply.go index 312fec392..40cea7c84 100644 --- a/cli/internal/cmd/apply.go +++ b/cli/internal/cmd/apply.go @@ -27,11 +27,11 @@ import ( "github.com/edgelesssys/constellation/v2/internal/compatibility" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/constants" + "github.com/edgelesssys/constellation/v2/internal/constellation" "github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/helm" "github.com/edgelesssys/constellation/v2/internal/kubecmd" - "github.com/edgelesssys/constellation/v2/internal/license" "github.com/edgelesssys/constellation/v2/internal/state" "github.com/edgelesssys/constellation/v2/internal/versions" "github.com/spf13/afero" @@ -244,6 +244,8 @@ func runApply(cmd *cobra.Command, _ []string) error { ) } + applier := constellation.NewApplier(log) + apply := &applyCmd{ fileHandler: fileHandler, flags: flags, @@ -254,13 +256,14 @@ func runApply(cmd *cobra.Command, _ []string) error { newDialer: newDialer, newKubeUpgrader: newKubeUpgrader, newInfraApplier: newInfraApplier, + applier: applier, } ctx, cancel := context.WithTimeout(cmd.Context(), time.Hour) defer cancel() cmd.SetContext(ctx) - return apply.apply(cmd, attestationconfigapi.NewFetcher(), license.NewClient(), upgradeDir) + return apply.apply(cmd, attestationconfigapi.NewFetcher(), upgradeDir) } type applyCmd struct { @@ -272,12 +275,18 @@ type applyCmd struct { merger configMerger + applier applier + newHelmClient func(kubeConfigPath string, log debugLog) (helmApplier, error) newDialer func(validator atls.Validator) *dialer.Dialer newKubeUpgrader func(out io.Writer, kubeConfigPath string, log debugLog) (kubernetesUpgrader, error) newInfraApplier func(context.Context) (cloudApplier, func(), error) } +type applier interface { + CheckLicense(ctx context.Context, csp cloudprovider.Provider, licenseID string) (int, error) +} + /* apply updates a Constellation cluster by applying a user's config. The control flow is as follows: @@ -339,7 +348,7 @@ The control flow is as follows: │ ───┐ ┌─────────────▼────────────┐ │ Can be skipped │Upgrade NodeVersion object│ │K8s/Image - if we ran Init RP │ (Image and K8s update) │ │Phase + if we ran Init RPC │ (Image and K8s update) │ │Phase └─────────────┬────────────┘ │ │ ───┘ ┌─────────▼──────────┐ @@ -347,8 +356,7 @@ The control flow is as follows: └────────────────────┘ */ func (a *applyCmd) apply( - cmd *cobra.Command, configFetcher attestationconfigapi.Fetcher, - quotaChecker license.QuotaChecker, upgradeDir string, + cmd *cobra.Command, configFetcher attestationconfigapi.Fetcher, upgradeDir string, ) error { // Validate inputs conf, stateFile, err := a.validateInputs(cmd, configFetcher) @@ -357,12 +365,7 @@ func (a *applyCmd) apply( } // Check license - a.log.Debugf("Running license check") - checker := license.NewChecker(quotaChecker, a.fileHandler) - if err := checker.CheckLicense(cmd.Context(), conf.GetProvider(), conf.Provider, cmd.Printf); err != nil { - cmd.PrintErrf("License check failed: %s", err) - } - a.log.Debugf("Checked license") + a.checkLicenseFile(cmd, conf.GetProvider()) // Now start actually running the apply command diff --git a/cli/internal/cmd/apply_test.go b/cli/internal/cmd/apply_test.go index ab6b5b742..05269245d 100644 --- a/cli/internal/cmd/apply_test.go +++ b/cli/internal/cmd/apply_test.go @@ -487,3 +487,9 @@ func newPhases(phases ...skipPhase) skipPhases { skipPhases.add(phases...) return skipPhases } + +type stubConstellApplier struct{} + +func (s *stubConstellApplier) CheckLicense(context.Context, cloudprovider.Provider, string) (int, error) { + return 0, nil +} diff --git a/cli/internal/cmd/create_test.go b/cli/internal/cmd/create_test.go index 719682dc3..53376513f 100644 --- a/cli/internal/cmd/create_test.go +++ b/cli/internal/cmd/create_test.go @@ -233,9 +233,11 @@ func TestCreate(t *testing.T) { newInfraApplier: func(_ context.Context) (cloudApplier, func(), error) { return tc.creator, func() {}, tc.getCreatorErr }, + + applier: &stubConstellApplier{}, } - err := a.apply(cmd, stubAttestationFetcher{}, &stubLicenseClient{}, "create") + err := a.apply(cmd, stubAttestationFetcher{}, "create") if tc.wantErr { assert.Error(err) diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index 4e00afb96..ec158112e 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -35,7 +35,6 @@ import ( "github.com/edgelesssys/constellation/v2/internal/grpc/testdialer" "github.com/edgelesssys/constellation/v2/internal/helm" "github.com/edgelesssys/constellation/v2/internal/kms/uri" - "github.com/edgelesssys/constellation/v2/internal/license" "github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/semver" "github.com/edgelesssys/constellation/v2/internal/state" @@ -280,9 +279,10 @@ func TestInitialize(t *testing.T) { getClusterAttestationConfigErr: k8serrors.NewNotFound(schema.GroupResource{}, ""), }, nil }, + applier: &stubConstellApplier{}, } - err := i.apply(cmd, stubAttestationFetcher{}, &stubLicenseClient{}, "test") + err := i.apply(cmd, stubAttestationFetcher{}, "test") if tc.wantErr { assert.Error(err) @@ -782,14 +782,6 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs return conf } -type stubLicenseClient struct{} - -func (c *stubLicenseClient) QuotaCheck(_ context.Context, _ license.QuotaCheckRequest) (license.QuotaCheckResponse, error) { - return license.QuotaCheckResponse{ - Quota: 25, - }, nil -} - type stubInitClient struct { res io.Reader err error diff --git a/cli/internal/cmd/license_enterprise.go b/cli/internal/cmd/license_enterprise.go new file mode 100644 index 000000000..2bc1d8797 --- /dev/null +++ b/cli/internal/cmd/license_enterprise.go @@ -0,0 +1,56 @@ +//go:build enterprise + +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package cmd + +import ( + "errors" + "io/fs" + + "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" + "github.com/edgelesssys/constellation/v2/internal/constants" + "github.com/edgelesssys/constellation/v2/internal/license" + "github.com/spf13/cobra" +) + +// checkLicenseFile reads the local license file and checks it's quota +// with the license server. If no license file is present or if errors +// occur during the check, the user is informed and the community license +// is used. It is a no-op in the open source version of Constellation. +func (a *applyCmd) checkLicenseFile(cmd *cobra.Command, csp cloudprovider.Provider) { + var licenseID string + a.log.Debugf("Running license check") + + readBytes, err := a.fileHandler.Read(constants.LicenseFilename) + if errors.Is(err, fs.ErrNotExist) { + cmd.Printf("Using community license.\n") + licenseID = license.CommunityLicense + } else if err != nil { + cmd.Printf("Error: %v\nContinuing with community license.\n", err) + licenseID = license.CommunityLicense + } else { + cmd.Printf("Constellation license found!\n") + licenseID, err = license.FromBytes(readBytes) + if err != nil { + cmd.Printf("Error: %v\nContinuing with community license.\n", err) + licenseID = license.CommunityLicense + } + } + + quota, err := a.applier.CheckLicense(cmd.Context(), csp, licenseID) + if err != nil { + cmd.Printf("Unable to contact license server.\n") + cmd.Printf("Please keep your vCPU quota in mind.\n") + } else if licenseID == license.CommunityLicense { + cmd.Printf("For details, see https://docs.edgeless.systems/constellation/overview/license\n") + } else { + cmd.Printf("Please keep your vCPU quota (%d) in mind.\n", quota) + } + + a.log.Debugf("Checked license") +} diff --git a/cli/internal/cmd/license_oss.go b/cli/internal/cmd/license_oss.go new file mode 100644 index 000000000..8fba56114 --- /dev/null +++ b/cli/internal/cmd/license_oss.go @@ -0,0 +1,20 @@ +//go:build !enterprise + +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package cmd + +import ( + "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" + "github.com/spf13/cobra" +) + +// checkLicenseFile reads the local license file and checks it's quota +// with the license server. If no license file is present or if errors +// occur during the check, the user is informed and the community license +// is used. It is a no-op in the open source version of Constellation. +func (a *applyCmd) checkLicenseFile(*cobra.Command, cloudprovider.Provider) {} diff --git a/cli/internal/cmd/upgradeapply_test.go b/cli/internal/cmd/upgradeapply_test.go index d5536b296..82ac9706a 100644 --- a/cli/internal/cmd/upgradeapply_test.go +++ b/cli/internal/cmd/upgradeapply_test.go @@ -254,8 +254,9 @@ func TestUpgradeApply(t *testing.T) { newInfraApplier: func(ctx context.Context) (cloudApplier, func(), error) { return tc.terraformUpgrader, func() {}, nil }, + applier: &stubConstellApplier{}, } - err := upgrader.apply(cmd, stubAttestationFetcher{}, &stubLicenseClient{}, "test") + err := upgrader.apply(cmd, stubAttestationFetcher{}, "test") if tc.wantErr { assert.Error(err) return diff --git a/internal/constellation/BUILD.bazel b/internal/constellation/BUILD.bazel index 4fc1e733e..53dacaf04 100644 --- a/internal/constellation/BUILD.bazel +++ b/internal/constellation/BUILD.bazel @@ -2,7 +2,14 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library") go_library( name = "constellation", - srcs = ["constellation.go"], + srcs = [ + "apply.go", + "constellation.go", + ], importpath = "github.com/edgelesssys/constellation/v2/internal/constellation", visibility = ["//:__subpackages__"], + deps = [ + "//internal/cloud/cloudprovider", + "//internal/license", + ], ) diff --git a/internal/constellation/apply.go b/internal/constellation/apply.go new file mode 100644 index 000000000..91e576e48 --- /dev/null +++ b/internal/constellation/apply.go @@ -0,0 +1,47 @@ +/* +Copyright (c) Edgeless Systems GmbH + +SPDX-License-Identifier: AGPL-3.0-only +*/ + +package constellation + +import ( + "context" + "fmt" + + "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" + "github.com/edgelesssys/constellation/v2/internal/license" +) + +// An Applier handles applying a specific configuration to a Constellation cluster. +// In Particular, this involves Initialization and Upgrading of the cluster. +type Applier struct { + log debugLog + licenseChecker *license.Checker +} + +type debugLog interface { + Debugf(format string, args ...any) +} + +// NewApplier creates a new Applier. +func NewApplier(log debugLog) *Applier { + return &Applier{ + log: log, + licenseChecker: license.NewChecker(license.NewClient()), + } +} + +// CheckLicense checks the given Constellation license with the license server +// and returns the allowed quota for the license. +func (a *Applier) CheckLicense(ctx context.Context, csp cloudprovider.Provider, licenseID string) (int, error) { + a.log.Debugf("Contacting license server for license '%s'", licenseID) + quotaResp, err := a.licenseChecker.CheckLicense(ctx, csp, licenseID) + if err != nil { + return 0, fmt.Errorf("checking license: %w", err) + } + a.log.Debugf("Got response from license server for license '%s'", licenseID) + + return quotaResp.Quota, nil +} diff --git a/internal/license/BUILD.bazel b/internal/license/BUILD.bazel index 84e64da1a..4cd4e56a8 100644 --- a/internal/license/BUILD.bazel +++ b/internal/license/BUILD.bazel @@ -15,10 +15,8 @@ go_library( visibility = ["//:__subpackages__"], deps = [ "//internal/cloud/cloudprovider", - "//internal/config", # keep "//internal/constants", - "//internal/file", ], ) @@ -30,9 +28,6 @@ go_test( ], embed = [":license"], deps = [ - "//internal/constants", - "//internal/file", - "@com_github_spf13_afero//:afero", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", ], diff --git a/internal/license/checker_enterprise.go b/internal/license/checker_enterprise.go index 52012a123..91b6fd753 100644 --- a/internal/license/checker_enterprise.go +++ b/internal/license/checker_enterprise.go @@ -10,53 +10,25 @@ package license import ( "context" - "errors" - "io/fs" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" - "github.com/edgelesssys/constellation/v2/internal/config" - "github.com/edgelesssys/constellation/v2/internal/constants" - "github.com/edgelesssys/constellation/v2/internal/file" ) type Checker struct { quotaChecker QuotaChecker - fileHandler file.Handler } -func NewChecker(quotaChecker QuotaChecker, fileHandler file.Handler) *Checker { +func NewChecker(quotaChecker QuotaChecker) *Checker { return &Checker{ quotaChecker: quotaChecker, - fileHandler: fileHandler, } } -// CheckLicense tries to read the license file and contact license server -// to fetch quota information. -// If no license file is found, community license is assumed. -func (c *Checker) CheckLicense(ctx context.Context, provider cloudprovider.Provider, providerCfg config.ProviderConfig, printer func(string, ...any)) error { - licenseID, err := FromFile(c.fileHandler, constants.LicenseFilename) - if errors.Is(err, fs.ErrNotExist) { - printer("Using community license.\n") - licenseID = CommunityLicense - } else if err != nil { - printer("Error: %v\nContinuing with community license.\n", err) - licenseID = CommunityLicense - } else { - printer("Constellation license found!\n") - } - quotaResp, err := c.quotaChecker.QuotaCheck(ctx, QuotaCheckRequest{ +// CheckLicense contacts the license server to fetch quota information for the given license. +func (c *Checker) CheckLicense(ctx context.Context, provider cloudprovider.Provider, licenseID string) (QuotaCheckResponse, error) { + return c.quotaChecker.QuotaCheck(ctx, QuotaCheckRequest{ License: licenseID, Action: Init, Provider: provider.String(), }) - if err != nil { - printer("Unable to contact license server.\n") - printer("Please keep your vCPU quota in mind.\n") - } else if licenseID == CommunityLicense { - printer("For details, see https://docs.edgeless.systems/constellation/overview/license\n") - } else { - printer("Please keep your vCPU quota (%d) in mind.\n", quotaResp.Quota) - } - return nil } diff --git a/internal/license/checker_oss.go b/internal/license/checker_oss.go index 1ea297930..a7b18327c 100644 --- a/internal/license/checker_oss.go +++ b/internal/license/checker_oss.go @@ -12,19 +12,17 @@ import ( "context" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" - "github.com/edgelesssys/constellation/v2/internal/config" - "github.com/edgelesssys/constellation/v2/internal/file" ) // Checker checks the Constellation license. type Checker struct{} // NewChecker creates a new Checker. -func NewChecker(_ QuotaChecker, _ file.Handler) *Checker { +func NewChecker(QuotaChecker) *Checker { return &Checker{} } // CheckLicense is a no-op for open source version of Constellation. -func (c *Checker) CheckLicense(_ context.Context, _ cloudprovider.Provider, _ config.ProviderConfig, _ func(string, ...any)) error { - return nil +func (c *Checker) CheckLicense(context.Context, cloudprovider.Provider, string) (QuotaCheckResponse, error) { + return QuotaCheckResponse{}, nil } diff --git a/internal/license/file.go b/internal/license/file.go index 97bd0b0c3..01f5afdff 100644 --- a/internal/license/file.go +++ b/internal/license/file.go @@ -9,20 +9,13 @@ package license import ( "encoding/base64" "fmt" - - "github.com/edgelesssys/constellation/v2/internal/file" ) -// FromFile reads the license from fileHandler at path and returns it as a string. -func FromFile(fileHandler file.Handler, path string) (string, error) { - readBytes, err := fileHandler.Read(path) - if err != nil { - return "", fmt.Errorf("unable to read from '%s': %w", path, err) - } - - maxSize := base64.StdEncoding.DecodedLen(len(readBytes)) +// FromBytes reads the given license bytes and returns it as a string. +func FromBytes(license []byte) (string, error) { + maxSize := base64.StdEncoding.DecodedLen(len(license)) decodedLicense := make([]byte, maxSize) - n, err := base64.StdEncoding.Decode(decodedLicense, readBytes) + n, err := base64.StdEncoding.Decode(decodedLicense, license) if err != nil { return "", fmt.Errorf("unable to base64 decode license file: %w", err) } @@ -30,6 +23,5 @@ func FromFile(fileHandler file.Handler, path string) (string, error) { return "", fmt.Errorf("license file corrupt: wrong length") } decodedLicense = decodedLicense[:n] - return string(decodedLicense), nil } diff --git a/internal/license/file_test.go b/internal/license/file_test.go index 47ca08386..84101dd72 100644 --- a/internal/license/file_test.go +++ b/internal/license/file_test.go @@ -9,72 +9,54 @@ package license import ( "testing" - "github.com/edgelesssys/constellation/v2/internal/constants" - "github.com/edgelesssys/constellation/v2/internal/file" - "github.com/spf13/afero" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestFromFile(t *testing.T) { +func TestFromBytes(t *testing.T) { testCases := map[string]struct { - licenseFileBytes []byte - licenseFilePath string - dontCreate bool - wantLicense string - wantError bool + licenseBytes []byte + wantLicense string + wantErr bool }{ "community license": { - licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAw"), - licenseFilePath: constants.LicenseFilename, - wantLicense: "00000000-0000-0000-0000-000000000000", + licenseBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAw"), + wantLicense: CommunityLicense, }, - "license file corrupt: too short": { - licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDA="), - licenseFilePath: constants.LicenseFilename, - wantError: true, + "too short": { + licenseBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDA="), + wantErr: true, }, - "license file corrupt: too short by 1 character": { - licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDA="), - licenseFilePath: constants.LicenseFilename, - wantError: true, + "too long": { + licenseBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwMA=="), + wantErr: true, }, - "license file corrupt: too long by 1 character": { - licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwMA=="), - licenseFilePath: constants.LicenseFilename, - wantError: true, + "not base64": { + licenseBytes: []byte("not base64"), + wantErr: true, }, - "license file corrupt: not base64": { - licenseFileBytes: []byte("I am a license file."), - licenseFilePath: constants.LicenseFilename, - wantError: true, + "empty": { + licenseBytes: []byte(""), + wantErr: true, }, - "license file missing": { - licenseFilePath: constants.LicenseFilename, - dontCreate: true, - wantError: true, + "nil": { + licenseBytes: nil, + wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { - assert := assert.New(t) require := require.New(t) + assert := assert.New(t) - testFS := file.NewHandler(afero.NewMemMapFs()) - - if !tc.dontCreate { - err := testFS.Write(tc.licenseFilePath, tc.licenseFileBytes) + out, err := FromBytes(tc.licenseBytes) + if tc.wantErr { + require.Error(err) + } else { require.NoError(err) } - - license, err := FromFile(testFS, tc.licenseFilePath) - if tc.wantError { - assert.Error(err) - return - } - assert.NoError(err) - assert.Equal(tc.wantLicense, license) + assert.Equal(tc.wantLicense, out) }) } }