license: refactor license check to be agnostic of input (#2659)

* license: refactor license check to be agnostic of input

* license: remove unused code

* cli: only check license file in enterprise version

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>

* bazel: fix enterprise CLI build

* bazel: add keep directive

* Update internal/constellation/apply.go

Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>

* license: check for return value

---------

Signed-off-by: Moritz Sanft <58110325+msanft@users.noreply.github.com>
Co-authored-by: Daniel Weiße <66256922+daniel-weisse@users.noreply.github.com>
This commit is contained in:
Moritz Sanft 2023-12-01 08:37:52 +01:00 committed by GitHub
parent 381c546c88
commit 4d6a7fa759
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 201 additions and 124 deletions

View File

@ -25,6 +25,9 @@ go_library(
"iamdestroy.go",
"iamupgradeapply.go",
"init.go",
# keep
"license_enterprise.go",
"license_oss.go",
"log.go",
"maapatch.go",
"mini.go",
@ -68,6 +71,7 @@ go_library(
"//internal/config/instancetypes",
"//internal/config/migration",
"//internal/constants",
"//internal/constellation",
"//internal/crypto",
"//internal/featureset",
"//internal/file",
@ -77,6 +81,7 @@ go_library(
"//internal/helm",
"//internal/kms/uri",
"//internal/kubecmd",
# keep
"//internal/license",
"//internal/logger",
"//internal/maa",
@ -167,7 +172,6 @@ go_test(
"//internal/helm",
"//internal/kms/uri",
"//internal/kubecmd",
"//internal/license",
"//internal/logger",
"//internal/semver",
"//internal/state",

View File

@ -27,11 +27,11 @@ import (
"github.com/edgelesssys/constellation/v2/internal/compatibility"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/constellation"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
"github.com/edgelesssys/constellation/v2/internal/helm"
"github.com/edgelesssys/constellation/v2/internal/kubecmd"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/edgelesssys/constellation/v2/internal/state"
"github.com/edgelesssys/constellation/v2/internal/versions"
"github.com/spf13/afero"
@ -244,6 +244,8 @@ func runApply(cmd *cobra.Command, _ []string) error {
)
}
applier := constellation.NewApplier(log)
apply := &applyCmd{
fileHandler: fileHandler,
flags: flags,
@ -254,13 +256,14 @@ func runApply(cmd *cobra.Command, _ []string) error {
newDialer: newDialer,
newKubeUpgrader: newKubeUpgrader,
newInfraApplier: newInfraApplier,
applier: applier,
}
ctx, cancel := context.WithTimeout(cmd.Context(), time.Hour)
defer cancel()
cmd.SetContext(ctx)
return apply.apply(cmd, attestationconfigapi.NewFetcher(), license.NewClient(), upgradeDir)
return apply.apply(cmd, attestationconfigapi.NewFetcher(), upgradeDir)
}
type applyCmd struct {
@ -272,12 +275,18 @@ type applyCmd struct {
merger configMerger
applier applier
newHelmClient func(kubeConfigPath string, log debugLog) (helmApplier, error)
newDialer func(validator atls.Validator) *dialer.Dialer
newKubeUpgrader func(out io.Writer, kubeConfigPath string, log debugLog) (kubernetesUpgrader, error)
newInfraApplier func(context.Context) (cloudApplier, func(), error)
}
type applier interface {
CheckLicense(ctx context.Context, csp cloudprovider.Provider, licenseID string) (int, error)
}
/*
apply updates a Constellation cluster by applying a user's config.
The control flow is as follows:
@ -339,7 +348,7 @@ The control flow is as follows:
Can be skipped Upgrade NodeVersion object K8s/Image
if we ran Init RP (Image and K8s update) Phase
if we ran Init RPC (Image and K8s update) Phase
@ -347,8 +356,7 @@ The control flow is as follows:
*/
func (a *applyCmd) apply(
cmd *cobra.Command, configFetcher attestationconfigapi.Fetcher,
quotaChecker license.QuotaChecker, upgradeDir string,
cmd *cobra.Command, configFetcher attestationconfigapi.Fetcher, upgradeDir string,
) error {
// Validate inputs
conf, stateFile, err := a.validateInputs(cmd, configFetcher)
@ -357,12 +365,7 @@ func (a *applyCmd) apply(
}
// Check license
a.log.Debugf("Running license check")
checker := license.NewChecker(quotaChecker, a.fileHandler)
if err := checker.CheckLicense(cmd.Context(), conf.GetProvider(), conf.Provider, cmd.Printf); err != nil {
cmd.PrintErrf("License check failed: %s", err)
}
a.log.Debugf("Checked license")
a.checkLicenseFile(cmd, conf.GetProvider())
// Now start actually running the apply command

View File

@ -487,3 +487,9 @@ func newPhases(phases ...skipPhase) skipPhases {
skipPhases.add(phases...)
return skipPhases
}
type stubConstellApplier struct{}
func (s *stubConstellApplier) CheckLicense(context.Context, cloudprovider.Provider, string) (int, error) {
return 0, nil
}

View File

@ -233,9 +233,11 @@ func TestCreate(t *testing.T) {
newInfraApplier: func(_ context.Context) (cloudApplier, func(), error) {
return tc.creator, func() {}, tc.getCreatorErr
},
applier: &stubConstellApplier{},
}
err := a.apply(cmd, stubAttestationFetcher{}, &stubLicenseClient{}, "create")
err := a.apply(cmd, stubAttestationFetcher{}, "create")
if tc.wantErr {
assert.Error(err)

View File

@ -35,7 +35,6 @@ import (
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/v2/internal/helm"
"github.com/edgelesssys/constellation/v2/internal/kms/uri"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/edgelesssys/constellation/v2/internal/semver"
"github.com/edgelesssys/constellation/v2/internal/state"
@ -280,9 +279,10 @@ func TestInitialize(t *testing.T) {
getClusterAttestationConfigErr: k8serrors.NewNotFound(schema.GroupResource{}, ""),
}, nil
},
applier: &stubConstellApplier{},
}
err := i.apply(cmd, stubAttestationFetcher{}, &stubLicenseClient{}, "test")
err := i.apply(cmd, stubAttestationFetcher{}, "test")
if tc.wantErr {
assert.Error(err)
@ -782,14 +782,6 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
return conf
}
type stubLicenseClient struct{}
func (c *stubLicenseClient) QuotaCheck(_ context.Context, _ license.QuotaCheckRequest) (license.QuotaCheckResponse, error) {
return license.QuotaCheckResponse{
Quota: 25,
}, nil
}
type stubInitClient struct {
res io.Reader
err error

View File

@ -0,0 +1,56 @@
//go:build enterprise
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"errors"
"io/fs"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/license"
"github.com/spf13/cobra"
)
// checkLicenseFile reads the local license file and checks it's quota
// with the license server. If no license file is present or if errors
// occur during the check, the user is informed and the community license
// is used. It is a no-op in the open source version of Constellation.
func (a *applyCmd) checkLicenseFile(cmd *cobra.Command, csp cloudprovider.Provider) {
var licenseID string
a.log.Debugf("Running license check")
readBytes, err := a.fileHandler.Read(constants.LicenseFilename)
if errors.Is(err, fs.ErrNotExist) {
cmd.Printf("Using community license.\n")
licenseID = license.CommunityLicense
} else if err != nil {
cmd.Printf("Error: %v\nContinuing with community license.\n", err)
licenseID = license.CommunityLicense
} else {
cmd.Printf("Constellation license found!\n")
licenseID, err = license.FromBytes(readBytes)
if err != nil {
cmd.Printf("Error: %v\nContinuing with community license.\n", err)
licenseID = license.CommunityLicense
}
}
quota, err := a.applier.CheckLicense(cmd.Context(), csp, licenseID)
if err != nil {
cmd.Printf("Unable to contact license server.\n")
cmd.Printf("Please keep your vCPU quota in mind.\n")
} else if licenseID == license.CommunityLicense {
cmd.Printf("For details, see https://docs.edgeless.systems/constellation/overview/license\n")
} else {
cmd.Printf("Please keep your vCPU quota (%d) in mind.\n", quota)
}
a.log.Debugf("Checked license")
}

View File

@ -0,0 +1,20 @@
//go:build !enterprise
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package cmd
import (
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/spf13/cobra"
)
// checkLicenseFile reads the local license file and checks it's quota
// with the license server. If no license file is present or if errors
// occur during the check, the user is informed and the community license
// is used. It is a no-op in the open source version of Constellation.
func (a *applyCmd) checkLicenseFile(*cobra.Command, cloudprovider.Provider) {}

View File

@ -254,8 +254,9 @@ func TestUpgradeApply(t *testing.T) {
newInfraApplier: func(ctx context.Context) (cloudApplier, func(), error) {
return tc.terraformUpgrader, func() {}, nil
},
applier: &stubConstellApplier{},
}
err := upgrader.apply(cmd, stubAttestationFetcher{}, &stubLicenseClient{}, "test")
err := upgrader.apply(cmd, stubAttestationFetcher{}, "test")
if tc.wantErr {
assert.Error(err)
return

View File

@ -2,7 +2,14 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "constellation",
srcs = ["constellation.go"],
srcs = [
"apply.go",
"constellation.go",
],
importpath = "github.com/edgelesssys/constellation/v2/internal/constellation",
visibility = ["//:__subpackages__"],
deps = [
"//internal/cloud/cloudprovider",
"//internal/license",
],
)

View File

@ -0,0 +1,47 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package constellation
import (
"context"
"fmt"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/license"
)
// An Applier handles applying a specific configuration to a Constellation cluster.
// In Particular, this involves Initialization and Upgrading of the cluster.
type Applier struct {
log debugLog
licenseChecker *license.Checker
}
type debugLog interface {
Debugf(format string, args ...any)
}
// NewApplier creates a new Applier.
func NewApplier(log debugLog) *Applier {
return &Applier{
log: log,
licenseChecker: license.NewChecker(license.NewClient()),
}
}
// CheckLicense checks the given Constellation license with the license server
// and returns the allowed quota for the license.
func (a *Applier) CheckLicense(ctx context.Context, csp cloudprovider.Provider, licenseID string) (int, error) {
a.log.Debugf("Contacting license server for license '%s'", licenseID)
quotaResp, err := a.licenseChecker.CheckLicense(ctx, csp, licenseID)
if err != nil {
return 0, fmt.Errorf("checking license: %w", err)
}
a.log.Debugf("Got response from license server for license '%s'", licenseID)
return quotaResp.Quota, nil
}

View File

@ -15,10 +15,8 @@ go_library(
visibility = ["//:__subpackages__"],
deps = [
"//internal/cloud/cloudprovider",
"//internal/config",
# keep
"//internal/constants",
"//internal/file",
],
)
@ -30,9 +28,6 @@ go_test(
],
embed = [":license"],
deps = [
"//internal/constants",
"//internal/file",
"@com_github_spf13_afero//:afero",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
],

View File

@ -10,53 +10,25 @@ package license
import (
"context"
"errors"
"io/fs"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
)
type Checker struct {
quotaChecker QuotaChecker
fileHandler file.Handler
}
func NewChecker(quotaChecker QuotaChecker, fileHandler file.Handler) *Checker {
func NewChecker(quotaChecker QuotaChecker) *Checker {
return &Checker{
quotaChecker: quotaChecker,
fileHandler: fileHandler,
}
}
// CheckLicense tries to read the license file and contact license server
// to fetch quota information.
// If no license file is found, community license is assumed.
func (c *Checker) CheckLicense(ctx context.Context, provider cloudprovider.Provider, providerCfg config.ProviderConfig, printer func(string, ...any)) error {
licenseID, err := FromFile(c.fileHandler, constants.LicenseFilename)
if errors.Is(err, fs.ErrNotExist) {
printer("Using community license.\n")
licenseID = CommunityLicense
} else if err != nil {
printer("Error: %v\nContinuing with community license.\n", err)
licenseID = CommunityLicense
} else {
printer("Constellation license found!\n")
}
quotaResp, err := c.quotaChecker.QuotaCheck(ctx, QuotaCheckRequest{
// CheckLicense contacts the license server to fetch quota information for the given license.
func (c *Checker) CheckLicense(ctx context.Context, provider cloudprovider.Provider, licenseID string) (QuotaCheckResponse, error) {
return c.quotaChecker.QuotaCheck(ctx, QuotaCheckRequest{
License: licenseID,
Action: Init,
Provider: provider.String(),
})
if err != nil {
printer("Unable to contact license server.\n")
printer("Please keep your vCPU quota in mind.\n")
} else if licenseID == CommunityLicense {
printer("For details, see https://docs.edgeless.systems/constellation/overview/license\n")
} else {
printer("Please keep your vCPU quota (%d) in mind.\n", quotaResp.Quota)
}
return nil
}

View File

@ -12,19 +12,17 @@ import (
"context"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/config"
"github.com/edgelesssys/constellation/v2/internal/file"
)
// Checker checks the Constellation license.
type Checker struct{}
// NewChecker creates a new Checker.
func NewChecker(_ QuotaChecker, _ file.Handler) *Checker {
func NewChecker(QuotaChecker) *Checker {
return &Checker{}
}
// CheckLicense is a no-op for open source version of Constellation.
func (c *Checker) CheckLicense(_ context.Context, _ cloudprovider.Provider, _ config.ProviderConfig, _ func(string, ...any)) error {
return nil
func (c *Checker) CheckLicense(context.Context, cloudprovider.Provider, string) (QuotaCheckResponse, error) {
return QuotaCheckResponse{}, nil
}

View File

@ -9,20 +9,13 @@ package license
import (
"encoding/base64"
"fmt"
"github.com/edgelesssys/constellation/v2/internal/file"
)
// FromFile reads the license from fileHandler at path and returns it as a string.
func FromFile(fileHandler file.Handler, path string) (string, error) {
readBytes, err := fileHandler.Read(path)
if err != nil {
return "", fmt.Errorf("unable to read from '%s': %w", path, err)
}
maxSize := base64.StdEncoding.DecodedLen(len(readBytes))
// FromBytes reads the given license bytes and returns it as a string.
func FromBytes(license []byte) (string, error) {
maxSize := base64.StdEncoding.DecodedLen(len(license))
decodedLicense := make([]byte, maxSize)
n, err := base64.StdEncoding.Decode(decodedLicense, readBytes)
n, err := base64.StdEncoding.Decode(decodedLicense, license)
if err != nil {
return "", fmt.Errorf("unable to base64 decode license file: %w", err)
}
@ -30,6 +23,5 @@ func FromFile(fileHandler file.Handler, path string) (string, error) {
return "", fmt.Errorf("license file corrupt: wrong length")
}
decodedLicense = decodedLicense[:n]
return string(decodedLicense), nil
}

View File

@ -9,72 +9,54 @@ package license
import (
"testing"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/file"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFromFile(t *testing.T) {
func TestFromBytes(t *testing.T) {
testCases := map[string]struct {
licenseFileBytes []byte
licenseFilePath string
dontCreate bool
wantLicense string
wantError bool
licenseBytes []byte
wantLicense string
wantErr bool
}{
"community license": {
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAw"),
licenseFilePath: constants.LicenseFilename,
wantLicense: "00000000-0000-0000-0000-000000000000",
licenseBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAw"),
wantLicense: CommunityLicense,
},
"license file corrupt: too short": {
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDA="),
licenseFilePath: constants.LicenseFilename,
wantError: true,
"too short": {
licenseBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDA="),
wantErr: true,
},
"license file corrupt: too short by 1 character": {
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDA="),
licenseFilePath: constants.LicenseFilename,
wantError: true,
"too long": {
licenseBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwMA=="),
wantErr: true,
},
"license file corrupt: too long by 1 character": {
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwMA=="),
licenseFilePath: constants.LicenseFilename,
wantError: true,
"not base64": {
licenseBytes: []byte("not base64"),
wantErr: true,
},
"license file corrupt: not base64": {
licenseFileBytes: []byte("I am a license file."),
licenseFilePath: constants.LicenseFilename,
wantError: true,
"empty": {
licenseBytes: []byte(""),
wantErr: true,
},
"license file missing": {
licenseFilePath: constants.LicenseFilename,
dontCreate: true,
wantError: true,
"nil": {
licenseBytes: nil,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
assert := assert.New(t)
testFS := file.NewHandler(afero.NewMemMapFs())
if !tc.dontCreate {
err := testFS.Write(tc.licenseFilePath, tc.licenseFileBytes)
out, err := FromBytes(tc.licenseBytes)
if tc.wantErr {
require.Error(err)
} else {
require.NoError(err)
}
license, err := FromFile(testFS, tc.licenseFilePath)
if tc.wantError {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantLicense, license)
assert.Equal(tc.wantLicense, out)
})
}
}