diff --git a/cli/internal/cmd/init.go b/cli/internal/cmd/init.go index 8494be45f..6b1e8faa6 100644 --- a/cli/internal/cmd/init.go +++ b/cli/internal/cmd/init.go @@ -26,6 +26,7 @@ import ( "github.com/edgelesssys/constellation/internal/file" "github.com/edgelesssys/constellation/internal/grpc/dialer" grpcRetry "github.com/edgelesssys/constellation/internal/grpc/retry" + "github.com/edgelesssys/constellation/internal/license" "github.com/edgelesssys/constellation/internal/retry" "github.com/edgelesssys/constellation/internal/state" kms "github.com/edgelesssys/constellation/kms/setup" @@ -57,12 +58,12 @@ func runInitialize(cmd *cobra.Command, args []string) error { return dialer.New(nil, validator.V(cmd), &net.Dialer{}) } helmLoader := &helm.ChartLoader{} - return initialize(cmd, newDialer, serviceAccountCreator, fileHandler, helmLoader) + return initialize(cmd, newDialer, serviceAccountCreator, fileHandler, helmLoader, license.NewClient()) } // initialize initializes a Constellation. func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer, - serviceAccCreator serviceAccountCreator, fileHandler file.Handler, helmLoader helmLoader, + serviceAccCreator serviceAccountCreator, fileHandler file.Handler, helmLoader helmLoader, licenseClient licenseClient, ) error { flags, err := evalFlagArgs(cmd, fileHandler) if err != nil { @@ -84,6 +85,23 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator return fmt.Errorf("reading and validating config: %w", err) } + licenseID, err := license.FromFile(fileHandler, constants.LicenseFilename) + if err != nil { + cmd.Println("Unable to find license file. Assuming community license.") + licenseID = license.CommunityLicense + } + quotaResp, err := licenseClient.CheckQuota(cmd.Context(), license.CheckQuotaRequest{ + License: licenseID, + Action: license.Init, + }) + if err != nil { + cmd.Println("Unable to contact license server.") + cmd.Println("Please keep your vCPU quota in mind.") + cmd.Printf("For community installation the vCPU quota is: %d.\n", license.CommunityQuota) + } + cmd.Printf("Constellation license found: %s\n", licenseID) + cmd.Printf("Please keep your vCPU quota (%d) in mind.\n", quotaResp.Quota) + var sshUsers []*ssh.UserKey for _, user := range config.SSHUsers { sshUsers = append(sshUsers, &ssh.UserKey{ @@ -401,3 +419,7 @@ func initCompletion(cmd *cobra.Command, args []string, toComplete string) ([]str type grpcDialer interface { Dial(ctx context.Context, target string) (*grpc.ClientConn, error) } + +type licenseClient interface { + CheckQuota(ctx context.Context, checkRequest license.CheckQuotaRequest) (license.CheckQuotaResponse, error) +} diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index 13cc5eefc..f58ea9b08 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -23,6 +23,7 @@ import ( "github.com/edgelesssys/constellation/internal/grpc/atlscredentials" "github.com/edgelesssys/constellation/internal/grpc/dialer" "github.com/edgelesssys/constellation/internal/grpc/testdialer" + "github.com/edgelesssys/constellation/internal/license" "github.com/edgelesssys/constellation/internal/oid" "github.com/edgelesssys/constellation/internal/state" "github.com/spf13/afero" @@ -173,7 +174,7 @@ func TestInitialize(t *testing.T) { defer cancel() cmd.SetContext(ctx) - err := initialize(cmd, newDialer, &tc.serviceAccountCreator, fileHandler, &tc.helmLoader) + err := initialize(cmd, newDialer, &tc.serviceAccountCreator, fileHandler, &tc.helmLoader, &stubLicenseClient{}) if tc.wantErr { assert.Error(err) @@ -452,7 +453,7 @@ func TestAttestation(t *testing.T) { defer cancel() cmd.SetContext(ctx) - err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler, &stubHelmLoader{}) + err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler, &stubHelmLoader{}, &stubLicenseClient{}) assert.Error(err) // make sure the error is actually a TLS handshake error assert.Contains(err.Error(), "transport: authentication handshake failed") @@ -536,3 +537,11 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, csp cloudprovider.Provi config.RemoveProviderExcept(csp) return config } + +type stubLicenseClient struct{} + +func (c *stubLicenseClient) CheckQuota(ctx context.Context, checkRequest license.CheckQuotaRequest) (license.CheckQuotaResponse, error) { + return license.CheckQuotaResponse{ + Quota: license.CommunityQuota, + }, nil +} diff --git a/internal/constants/constants.go b/internal/constants/constants.go index e0992e66e..b2a6f93da 100644 --- a/internal/constants/constants.go +++ b/internal/constants/constants.go @@ -50,6 +50,7 @@ const ( StateFilename = "constellation-state.json" ClusterIDsFileName = "constellation-id.json" ConfigFilename = "constellation-conf.yaml" + LicenseFilename = "constellation.license" DebugdConfigFilename = "cdbg-conf.yaml" AdminConfFilename = "constellation-admin.conf" MasterSecretFilename = "constellation-mastersecret.json" diff --git a/internal/license/file.go b/internal/license/file.go new file mode 100644 index 000000000..00da18a86 --- /dev/null +++ b/internal/license/file.go @@ -0,0 +1,28 @@ +package license + +import ( + "encoding/base64" + "fmt" + + "github.com/edgelesssys/constellation/internal/file" +) + +func FromFile(fileHandler file.Handler, path string) (string, error) { + readBytes, err := fileHandler.Read(path) + if err != nil { + return "", fmt.Errorf("unable to read from '%s': %w", path, err) + } + + maxSize := base64.StdEncoding.DecodedLen(len(readBytes)) + decodedLicense := make([]byte, maxSize) + n, err := base64.StdEncoding.Decode(decodedLicense, readBytes) + if err != nil { + return "", fmt.Errorf("unable to base64 decode license file: %w", err) + } + if n != 36 { // length of UUID + return "", fmt.Errorf("license file corrupt: wrong length") + } + decodedLicense = decodedLicense[:n] + + return string(decodedLicense), nil +} diff --git a/internal/license/file_test.go b/internal/license/file_test.go new file mode 100644 index 000000000..1ffb9f6c7 --- /dev/null +++ b/internal/license/file_test.go @@ -0,0 +1,74 @@ +package license + +import ( + "testing" + + "github.com/edgelesssys/constellation/internal/constants" + "github.com/edgelesssys/constellation/internal/file" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFromFile(t *testing.T) { + testCases := map[string]struct { + licenseFileBytes []byte + licenseFilePath string + dontCreate bool + wantLicense string + wantError bool + }{ + "community license": { + licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAw"), + licenseFilePath: constants.LicenseFilename, + wantLicense: "00000000-0000-0000-0000-000000000000", + }, + "license file corrupt: too short": { + licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDA="), + licenseFilePath: constants.LicenseFilename, + wantError: true, + }, + "license file corrupt: too short by 1 character": { + licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDA="), + licenseFilePath: constants.LicenseFilename, + wantError: true, + }, + "license file corrupt: too long by 1 character": { + licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwMA=="), + licenseFilePath: constants.LicenseFilename, + wantError: true, + }, + "license file corrupt: not base64": { + licenseFileBytes: []byte("I am a license file."), + licenseFilePath: constants.LicenseFilename, + wantError: true, + }, + "license file missing": { + licenseFilePath: constants.LicenseFilename, + dontCreate: true, + wantError: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + testFS := file.NewHandler(afero.NewMemMapFs()) + + if !tc.dontCreate { + err := testFS.Write(tc.licenseFilePath, tc.licenseFileBytes) + require.NoError(err) + } + + license, err := FromFile(testFS, tc.licenseFilePath) + if tc.wantError { + assert.Error(err) + return + } + assert.NoError(err) + assert.Equal(tc.wantLicense, license) + }) + } +} diff --git a/internal/license/license.go b/internal/license/license.go new file mode 100644 index 000000000..f8d357e8b --- /dev/null +++ b/internal/license/license.go @@ -0,0 +1,91 @@ +package license + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" +) + +const ( + // CommunityLicense is used by everyone who has not bought an enterprise license. + CommunityLicense = "00000000-0000-0000-0000-000000000000" + // CommunityQuota is the vCPU quota allowed for community installations of Constellation. + CommunityQuota = 8 + apiHost = "license.confidential.cloud" + licensePath = "api/v1/license" +) + +type Action string + +const ( + Init Action = "init" + test Action = "test" +) + +// Client interacts with the ES license server. +type Client struct { + httpClient *http.Client +} + +// NewClient creates a new client to interact with ES license server. +func NewClient() *Client { + return &Client{ + httpClient: http.DefaultClient, + } +} + +// CheckQuotaRequest is JSON request to license server to check quota for a given license and action. +type CheckQuotaRequest struct { + Action Action `json:"action"` + License string `json:"license"` +} + +// CheckQuotaResponse is JSON response by license server. +type CheckQuotaResponse struct { + Quota int `json:"quota"` +} + +// CheckQuota for a given license and action, passed via CheckQuotaRequest. +func (c *Client) CheckQuota(ctx context.Context, checkRequest CheckQuotaRequest) (CheckQuotaResponse, error) { + reqBody, err := json.Marshal(checkRequest) + if err != nil { + return CheckQuotaResponse{}, fmt.Errorf("unable to marshal input: %w", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, licenseURL().String(), bytes.NewBuffer(reqBody)) + if err != nil { + return CheckQuotaResponse{}, fmt.Errorf("unable to create request: %w", err) + } + resp, err := c.httpClient.Do(req) + if err != nil { + return CheckQuotaResponse{}, fmt.Errorf("unable to do request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return CheckQuotaResponse{}, fmt.Errorf("http error %d", resp.StatusCode) + } + + responseContentType := resp.Header.Get("Content-Type") + if responseContentType != "application/json" { + return CheckQuotaResponse{}, fmt.Errorf("expected server JSON response but got '%s'", responseContentType) + } + + var parsedResponse CheckQuotaResponse + err = json.NewDecoder(resp.Body).Decode(&parsedResponse) + if err != nil { + return CheckQuotaResponse{}, fmt.Errorf("unable to parse response: %w", err) + } + + return parsedResponse, nil +} + +func licenseURL() *url.URL { + return &url.URL{ + Scheme: "https", + Host: apiHost, + Path: licensePath, + } +} diff --git a/internal/license/license_integration_test.go b/internal/license/license_integration_test.go new file mode 100644 index 000000000..815a9628e --- /dev/null +++ b/internal/license/license_integration_test.go @@ -0,0 +1,64 @@ +//go:build integration + +package license + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCheckQuotaIntegration(t *testing.T) { + testCases := map[string]struct { + license string + action string + wantQuota int + wantError bool + }{ + "ES license has quota 256": { + license: "***REMOVED***", + action: test, + wantQuota: 256, + }, + "OSS license has quota 8": { + license: CommunityLicense, + action: test, + wantQuota: 8, + }, + "OSS license missing action": { + license: CommunityLicense, + action: "", + wantQuota: 8, + }, + "Empty license assumes community": { + license: "", + action: test, + wantQuota: 8, + }, + "Empty request": { + license: "", + action: "", + wantQuota: 8, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + client := NewClient() + + resp, err := client.CheckQuota(CheckQuotaRequest{ + Action: tc.action, + License: tc.license, + }) + + if tc.wantError { + assert.Error(err) + return + } + assert.NoError(err) + assert.Equal(tc.wantQuota, resp.Quota) + }) + } +} diff --git a/internal/license/license_test.go b/internal/license/license_test.go new file mode 100644 index 000000000..9d3d37fc9 --- /dev/null +++ b/internal/license/license_test.go @@ -0,0 +1,94 @@ +package license + +import ( + "bytes" + "context" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +// roundTripFunc . +type roundTripFunc func(req *http.Request) *http.Response + +// RoundTrip . +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req), nil +} + +// newTestClient returns *http.Client with Transport replaced to avoid making real calls. +func newTestClient(fn roundTripFunc) *Client { + return &Client{ + httpClient: &http.Client{ + Transport: fn, + }, + } +} + +func TestCheckQuota(t *testing.T) { + testCases := map[string]struct { + license string + serverResponse string + serverResponseCode int + serverResponseContent string + wantQuota int + wantError bool + }{ + "success": { + license: "***REMOVED***", + serverResponse: "{\"quota\":256}", + serverResponseCode: http.StatusOK, + serverResponseContent: "application/json", + wantQuota: 256, + }, + "404": { + serverResponseCode: http.StatusNotFound, + wantError: true, + }, + "HTML not JSON": { + serverResponseCode: http.StatusOK, + serverResponseContent: "text/html", + wantError: true, + }, + "promise JSON but actually HTML": { + serverResponseCode: http.StatusOK, + serverResponse: "", + serverResponseContent: "application/json", + wantError: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + client := newTestClient(func(req *http.Request) *http.Response { + r := &http.Response{ + StatusCode: tc.serverResponseCode, + Body: io.NopCloser(bytes.NewBufferString(tc.serverResponse)), + Header: make(http.Header), + } + r.Header.Set("Content-Type", tc.serverResponseContent) + return r + }) + + resp, err := client.CheckQuota(context.Background(), CheckQuotaRequest{ + Action: test, + License: tc.license, + }) + + if tc.wantError { + assert.Error(err) + return + } + assert.NoError(err) + assert.Equal(tc.wantQuota, resp.Quota) + }) + } +} + +func Test_licenseURL(t *testing.T) { + assert.Equal(t, "https://license.confidential.cloud/api/v1/license", licenseURL().String()) +}