license: refactor license check to be agnostic of input (#2659)

* license: refactor license check to be agnostic of input

* license: remove unused code

* cli: only check license file in enterprise version

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* bazel: fix enterprise CLI build

* bazel: add keep directive

* Update internal/constellation/apply.go

Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>

* license: check for return value

---------

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>
Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>
This commit is contained in:
Moritz Sanft 2023-12-01 08:37:52 +01:00 committed by GitHub
parent 381c546c88
commit 4d6a7fa759
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 201 additions and 124 deletions

View file

@ -25,6 +25,9 @@ go_library(
"iamdestroy.go", "iamdestroy.go",
"iamupgradeapply.go", "iamupgradeapply.go",
"init.go", "init.go",
# keep
"license_enterprise.go",
"license_oss.go",
"log.go", "log.go",
"maapatch.go", "maapatch.go",
"mini.go", "mini.go",
@ -68,6 +71,7 @@ go_library(
"//internal/config/instancetypes", "//internal/config/instancetypes",
"//internal/config/migration", "//internal/config/migration",
"//internal/constants", "//internal/constants",
"//internal/constellation",
"//internal/crypto", "//internal/crypto",
"//internal/featureset", "//internal/featureset",
"//internal/file", "//internal/file",
@ -77,6 +81,7 @@ go_library(
"//internal/helm", "//internal/helm",
"//internal/kms/uri", "//internal/kms/uri",
"//internal/kubecmd", "//internal/kubecmd",
# keep
"//internal/license", "//internal/license",
"//internal/logger", "//internal/logger",
"//internal/maa", "//internal/maa",
@ -167,7 +172,6 @@ go_test(
"//internal/helm", "//internal/helm",
"//internal/kms/uri", "//internal/kms/uri",
"//internal/kubecmd", "//internal/kubecmd",
"//internal/license",
"//internal/logger", "//internal/logger",
"//internal/semver", "//internal/semver",
"//internal/state", "//internal/state",

View file

@ -27,11 +27,11 @@ import (
"github.com/edgelesssys/constellation/v2/internal/compatibility" "github.com/edgelesssys/constellation/v2/internal/compatibility"
"github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/constellation"
"github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/helm" "github.com/edgelesssys/constellation/v2/internal/helm"
"github.com/edgelesssys/constellation/v2/internal/kubecmd" "github.com/edgelesssys/constellation/v2/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/edgelesssys/constellation/v2/internal/state" "github.com/edgelesssys/constellation/v2/internal/state"
"github.com/edgelesssys/constellation/v2/internal/versions" "github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/spf13/afero" "github.com/spf13/afero"
@ -244,6 +244,8 @@ func runApply(cmd *cobra.Command, _ []string) error {
) )
} }
applier := constellation.NewApplier(log)
apply := &applyCmd{ apply := &applyCmd{
fileHandler: fileHandler, fileHandler: fileHandler,
flags: flags, flags: flags,
@ -254,13 +256,14 @@ func runApply(cmd *cobra.Command, _ []string) error {
newDialer: newDialer, newDialer: newDialer,
newKubeUpgrader: newKubeUpgrader, newKubeUpgrader: newKubeUpgrader,
newInfraApplier: newInfraApplier, newInfraApplier: newInfraApplier,
applier: applier,
} }
ctx, cancel := context.WithTimeout(cmd.Context(), time.Hour) ctx, cancel := context.WithTimeout(cmd.Context(), time.Hour)
defer cancel() defer cancel()
cmd.SetContext(ctx) cmd.SetContext(ctx)
return apply.apply(cmd, attestationconfigapi.NewFetcher(), license.NewClient(), upgradeDir) return apply.apply(cmd, attestationconfigapi.NewFetcher(), upgradeDir)
} }
type applyCmd struct { type applyCmd struct {
@ -272,12 +275,18 @@ type applyCmd struct {
merger configMerger merger configMerger
applier applier
newHelmClient func(kubeConfigPath string, log debugLog) (helmApplier, error) newHelmClient func(kubeConfigPath string, log debugLog) (helmApplier, error)
newDialer func(validator atls.Validator) *dialer.Dialer newDialer func(validator atls.Validator) *dialer.Dialer
newKubeUpgrader func(out io.Writer, kubeConfigPath string, log debugLog) (kubernetesUpgrader, error) newKubeUpgrader func(out io.Writer, kubeConfigPath string, log debugLog) (kubernetesUpgrader, error)
newInfraApplier func(context.Context) (cloudApplier, func(), error) newInfraApplier func(context.Context) (cloudApplier, func(), error)
} }
type applier interface {
CheckLicense(ctx context.Context, csp cloudprovider.Provider, licenseID string) (int, error)
}
/* /*
apply updates a Constellation cluster by applying a user's config. apply updates a Constellation cluster by applying a user's config.
The control flow is as follows: The control flow is as follows:
@ -339,7 +348,7 @@ The control flow is as follows:
Can be skipped Upgrade NodeVersion object K8s/Image Can be skipped Upgrade NodeVersion object K8s/Image
if we ran Init RP (Image and K8s update) Phase if we ran Init RPC (Image and K8s update) Phase
@ -347,8 +356,7 @@ The control flow is as follows:
*/ */
func (a *applyCmd) apply( func (a *applyCmd) apply(
cmd *cobra.Command, configFetcher attestationconfigapi.Fetcher, cmd *cobra.Command, configFetcher attestationconfigapi.Fetcher, upgradeDir string,
quotaChecker license.QuotaChecker, upgradeDir string,
) error { ) error {
// Validate inputs // Validate inputs
conf, stateFile, err := a.validateInputs(cmd, configFetcher) conf, stateFile, err := a.validateInputs(cmd, configFetcher)
@ -357,12 +365,7 @@ func (a *applyCmd) apply(
} }
// Check license // Check license
a.log.Debugf("Running license check") a.checkLicenseFile(cmd, conf.GetProvider())
checker := license.NewChecker(quotaChecker, a.fileHandler)
if err := checker.CheckLicense(cmd.Context(), conf.GetProvider(), conf.Provider, cmd.Printf); err != nil {
cmd.PrintErrf("License check failed: %s", err)
}
a.log.Debugf("Checked license")
// Now start actually running the apply command // Now start actually running the apply command

View file

@ -487,3 +487,9 @@ func newPhases(phases ...skipPhase) skipPhases {
skipPhases.add(phases...) skipPhases.add(phases...)
return skipPhases return skipPhases
} }
type stubConstellApplier struct{}
func (s *stubConstellApplier) CheckLicense(context.Context, cloudprovider.Provider, string) (int, error) {
return 0, nil
}

View file

@ -233,9 +233,11 @@ func TestCreate(t *testing.T) {
newInfraApplier: func(_ context.Context) (cloudApplier, func(), error) { newInfraApplier: func(_ context.Context) (cloudApplier, func(), error) {
return tc.creator, func() {}, tc.getCreatorErr return tc.creator, func() {}, tc.getCreatorErr
}, },
applier: &stubConstellApplier{},
} }
err := a.apply(cmd, stubAttestationFetcher{}, &stubLicenseClient{}, "create") err := a.apply(cmd, stubAttestationFetcher{}, "create")
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)

View file

@ -35,7 +35,6 @@ import (
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer" "github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/v2/internal/helm" "github.com/edgelesssys/constellation/v2/internal/helm"
"github.com/edgelesssys/constellation/v2/internal/kms/uri" "github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/edgelesssys/constellation/v2/internal/logger" "github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/semver" "github.com/edgelesssys/constellation/v2/internal/semver"
"github.com/edgelesssys/constellation/v2/internal/state" "github.com/edgelesssys/constellation/v2/internal/state"
@ -280,9 +279,10 @@ func TestInitialize(t *testing.T) {
getClusterAttestationConfigErr: k8serrors.NewNotFound(schema.GroupResource{}, ""), getClusterAttestationConfigErr: k8serrors.NewNotFound(schema.GroupResource{}, ""),
}, nil }, nil
}, },
applier: &stubConstellApplier{},
} }
err := i.apply(cmd, stubAttestationFetcher{}, &stubLicenseClient{}, "test") err := i.apply(cmd, stubAttestationFetcher{}, "test")
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
@ -782,14 +782,6 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
return conf return conf
} }
type stubLicenseClient struct{}
func (c *stubLicenseClient) QuotaCheck(_ context.Context, _ license.QuotaCheckRequest) (license.QuotaCheckResponse, error) {
return license.QuotaCheckResponse{
Quota: 25,
}, nil
}
type stubInitClient struct { type stubInitClient struct {
res io.Reader res io.Reader
err error err error

View file

@ -0,0 +1,56 @@
//go:build enterprise
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"errors"
"io/fs"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/spf13/cobra"
)
// checkLicenseFile reads the local license file and checks it's quota
// with the license server. If no license file is present or if errors
// occur during the check, the user is informed and the community license
// is used. It is a no-op in the open source version of Constellation.
func (a *applyCmd) checkLicenseFile(cmd *cobra.Command, csp cloudprovider.Provider) {
var licenseID string
a.log.Debugf("Running license check")
readBytes, err := a.fileHandler.Read(constants.LicenseFilename)
if errors.Is(err, fs.ErrNotExist) {
cmd.Printf("Using community license.\n")
licenseID = license.CommunityLicense
} else if err != nil {
cmd.Printf("Error: %v\nContinuing with community license.\n", err)
licenseID = license.CommunityLicense
} else {
cmd.Printf("Constellation license found!\n")
licenseID, err = license.FromBytes(readBytes)
if err != nil {
cmd.Printf("Error: %v\nContinuing with community license.\n", err)
licenseID = license.CommunityLicense
}
}
quota, err := a.applier.CheckLicense(cmd.Context(), csp, licenseID)
if err != nil {
cmd.Printf("Unable to contact license server.\n")
cmd.Printf("Please keep your vCPU quota in mind.\n")
} else if licenseID == license.CommunityLicense {
cmd.Printf("For details, see https://docs.edgeless.systems/constellation/overview/license\n")
} else {
cmd.Printf("Please keep your vCPU quota (%d) in mind.\n", quota)
}
a.log.Debugf("Checked license")
}

View file

@ -0,0 +1,20 @@
//go:build !enterprise
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/spf13/cobra"
)
// checkLicenseFile reads the local license file and checks it's quota
// with the license server. If no license file is present or if errors
// occur during the check, the user is informed and the community license
// is used. It is a no-op in the open source version of Constellation.
func (a *applyCmd) checkLicenseFile(*cobra.Command, cloudprovider.Provider) {}

View file

@ -254,8 +254,9 @@ func TestUpgradeApply(t *testing.T) {
newInfraApplier: func(ctx context.Context) (cloudApplier, func(), error) { newInfraApplier: func(ctx context.Context) (cloudApplier, func(), error) {
return tc.terraformUpgrader, func() {}, nil return tc.terraformUpgrader, func() {}, nil
}, },
applier: &stubConstellApplier{},
} }
err := upgrader.apply(cmd, stubAttestationFetcher{}, &stubLicenseClient{}, "test") err := upgrader.apply(cmd, stubAttestationFetcher{}, "test")
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
return return

View file

@ -2,7 +2,14 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library( go_library(
name = "constellation", name = "constellation",
srcs = ["constellation.go"], srcs = [
"apply.go",
"constellation.go",
],
importpath = "github.com/edgelesssys/constellation/v2/internal/constellation", importpath = "github.com/edgelesssys/constellation/v2/internal/constellation",
visibility = ["//:__subpackages__"], visibility = ["//:__subpackages__"],
deps = [
"//internal/cloud/cloudprovider",
"//internal/license",
],
) )

View file

@ -0,0 +1,47 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package constellation
import (
"context"
"fmt"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/license"
)
// An Applier handles applying a specific configuration to a Constellation cluster.
// In Particular, this involves Initialization and Upgrading of the cluster.
type Applier struct {
log debugLog
licenseChecker *license.Checker
}
type debugLog interface {
Debugf(format string, args ...any)
}
// NewApplier creates a new Applier.
func NewApplier(log debugLog) *Applier {
return &Applier{
log: log,
licenseChecker: license.NewChecker(license.NewClient()),
}
}
// CheckLicense checks the given Constellation license with the license server
// and returns the allowed quota for the license.
func (a *Applier) CheckLicense(ctx context.Context, csp cloudprovider.Provider, licenseID string) (int, error) {
a.log.Debugf("Contacting license server for license '%s'", licenseID)
quotaResp, err := a.licenseChecker.CheckLicense(ctx, csp, licenseID)
if err != nil {
return 0, fmt.Errorf("checking license: %w", err)
}
a.log.Debugf("Got response from license server for license '%s'", licenseID)
return quotaResp.Quota, nil
}

View file

@ -15,10 +15,8 @@ go_library(
visibility = ["//:__subpackages__"], visibility = ["//:__subpackages__"],
deps = [ deps = [
"//internal/cloud/cloudprovider", "//internal/cloud/cloudprovider",
"//internal/config",
# keep # keep
"//internal/constants", "//internal/constants",
"//internal/file",
], ],
) )
@ -30,9 +28,6 @@ go_test(
], ],
embed = [":license"], embed = [":license"],
deps = [ deps = [
"//internal/constants",
"//internal/file",
"@com_github_spf13_afero//:afero",
"@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require", "@com_github_stretchr_testify//require",
], ],

View file

@ -10,53 +10,25 @@ package license
import ( import (
"context" "context"
"errors"
"io/fs"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
) )
type Checker struct { type Checker struct {
quotaChecker QuotaChecker quotaChecker QuotaChecker
fileHandler file.Handler
} }
func NewChecker(quotaChecker QuotaChecker, fileHandler file.Handler) *Checker { func NewChecker(quotaChecker QuotaChecker) *Checker {
return &Checker{ return &Checker{
quotaChecker: quotaChecker, quotaChecker: quotaChecker,
fileHandler: fileHandler,
} }
} }
// CheckLicense tries to read the license file and contact license server // CheckLicense contacts the license server to fetch quota information for the given license.
// to fetch quota information. func (c *Checker) CheckLicense(ctx context.Context, provider cloudprovider.Provider, licenseID string) (QuotaCheckResponse, error) {
// If no license file is found, community license is assumed. return c.quotaChecker.QuotaCheck(ctx, QuotaCheckRequest{
func (c *Checker) CheckLicense(ctx context.Context, provider cloudprovider.Provider, providerCfg config.ProviderConfig, printer func(string, ...any)) error {
licenseID, err := FromFile(c.fileHandler, constants.LicenseFilename)
if errors.Is(err, fs.ErrNotExist) {
printer("Using community license.\n")
licenseID = CommunityLicense
} else if err != nil {
printer("Error: %v\nContinuing with community license.\n", err)
licenseID = CommunityLicense
} else {
printer("Constellation license found!\n")
}
quotaResp, err := c.quotaChecker.QuotaCheck(ctx, QuotaCheckRequest{
License: licenseID, License: licenseID,
Action: Init, Action: Init,
Provider: provider.String(), Provider: provider.String(),
}) })
if err != nil {
printer("Unable to contact license server.\n")
printer("Please keep your vCPU quota in mind.\n")
} else if licenseID == CommunityLicense {
printer("For details, see https://docs.edgeless.systems/constellation/overview/license\n")
} else {
printer("Please keep your vCPU quota (%d) in mind.\n", quotaResp.Quota)
}
return nil
} }

View file

@ -12,19 +12,17 @@ import (
"context" "context"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider" "github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file"
) )
// Checker checks the Constellation license. // Checker checks the Constellation license.
type Checker struct{} type Checker struct{}
// NewChecker creates a new Checker. // NewChecker creates a new Checker.
func NewChecker(_ QuotaChecker, _ file.Handler) *Checker { func NewChecker(QuotaChecker) *Checker {
return &Checker{} return &Checker{}
} }
// CheckLicense is a no-op for open source version of Constellation. // CheckLicense is a no-op for open source version of Constellation.
func (c *Checker) CheckLicense(_ context.Context, _ cloudprovider.Provider, _ config.ProviderConfig, _ func(string, ...any)) error { func (c *Checker) CheckLicense(context.Context, cloudprovider.Provider, string) (QuotaCheckResponse, error) {
return nil return QuotaCheckResponse{}, nil
} }

View file

@ -9,20 +9,13 @@ package license
import ( import (
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/edgelesssys/constellation/v2/internal/file"
) )
// FromFile reads the license from fileHandler at path and returns it as a string. // FromBytes reads the given license bytes and returns it as a string.
func FromFile(fileHandler file.Handler, path string) (string, error) { func FromBytes(license []byte) (string, error) {
readBytes, err := fileHandler.Read(path) maxSize := base64.StdEncoding.DecodedLen(len(license))
if err != nil {
return "", fmt.Errorf("unable to read from '%s': %w", path, err)
}
maxSize := base64.StdEncoding.DecodedLen(len(readBytes))
decodedLicense := make([]byte, maxSize) decodedLicense := make([]byte, maxSize)
n, err := base64.StdEncoding.Decode(decodedLicense, readBytes) n, err := base64.StdEncoding.Decode(decodedLicense, license)
if err != nil { if err != nil {
return "", fmt.Errorf("unable to base64 decode license file: %w", err) return "", fmt.Errorf("unable to base64 decode license file: %w", err)
} }
@ -30,6 +23,5 @@ func FromFile(fileHandler file.Handler, path string) (string, error) {
return "", fmt.Errorf("license file corrupt: wrong length") return "", fmt.Errorf("license file corrupt: wrong length")
} }
decodedLicense = decodedLicense[:n] decodedLicense = decodedLicense[:n]
return string(decodedLicense), nil return string(decodedLicense), nil
} }

View file

@ -9,72 +9,54 @@ package license
import ( import (
"testing" "testing"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestFromFile(t *testing.T) { func TestFromBytes(t *testing.T) {
testCases := map[string]struct { testCases := map[string]struct {
licenseFileBytes []byte licenseBytes []byte
licenseFilePath string
dontCreate bool
wantLicense string wantLicense string
wantError bool wantErr bool
}{ }{
"community license": { "community license": {
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAw"), licenseBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAw"),
licenseFilePath: constants.LicenseFilename, wantLicense: CommunityLicense,
wantLicense: "00000000-0000-0000-0000-000000000000",
}, },
"license file corrupt: too short": { "too short": {
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDA="), licenseBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDA="),
licenseFilePath: constants.LicenseFilename, wantErr: true,
wantError: true,
}, },
"license file corrupt: too short by 1 character": { "too long": {
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDA="), licenseBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwMA=="),
licenseFilePath: constants.LicenseFilename, wantErr: true,
wantError: true,
}, },
"license file corrupt: too long by 1 character": { "not base64": {
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwMA=="), licenseBytes: []byte("not base64"),
licenseFilePath: constants.LicenseFilename, wantErr: true,
wantError: true,
}, },
"license file corrupt: not base64": { "empty": {
licenseFileBytes: []byte("I am a license file."), licenseBytes: []byte(""),
licenseFilePath: constants.LicenseFilename, wantErr: true,
wantError: true,
}, },
"license file missing": { "nil": {
licenseFilePath: constants.LicenseFilename, licenseBytes: nil,
dontCreate: true, wantErr: true,
wantError: true,
}, },
} }
for name, tc := range testCases { for name, tc := range testCases {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t) require := require.New(t)
assert := assert.New(t)
testFS := file.NewHandler(afero.NewMemMapFs()) out, err := FromBytes(tc.licenseBytes)
if tc.wantErr {
if !tc.dontCreate { require.Error(err)
err := testFS.Write(tc.licenseFilePath, tc.licenseFileBytes) } else {
require.NoError(err) require.NoError(err)
} }
assert.Equal(tc.wantLicense, out)
license, err := FromFile(testFS, tc.licenseFilePath)
if tc.wantError {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantLicense, license)
}) })
} }
} }