mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-11 15:39:33 -05:00
AB#2299 License check in CLI during init (#366)
* license server interaction * logic to read from license file * print license information during init Signed-off-by: Fabian Kammel <fk@edgeless.systems> Co-authored-by: Moritz Eckert <m1gh7ym0@gmail.com>
This commit is contained in:
parent
170a8bf5e0
commit
82eb9f4544
@ -26,6 +26,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/dialer"
|
||||
grpcRetry "github.com/edgelesssys/constellation/internal/grpc/retry"
|
||||
"github.com/edgelesssys/constellation/internal/license"
|
||||
"github.com/edgelesssys/constellation/internal/retry"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
kms "github.com/edgelesssys/constellation/kms/setup"
|
||||
@ -57,12 +58,12 @@ func runInitialize(cmd *cobra.Command, args []string) error {
|
||||
return dialer.New(nil, validator.V(cmd), &net.Dialer{})
|
||||
}
|
||||
helmLoader := &helm.ChartLoader{}
|
||||
return initialize(cmd, newDialer, serviceAccountCreator, fileHandler, helmLoader)
|
||||
return initialize(cmd, newDialer, serviceAccountCreator, fileHandler, helmLoader, license.NewClient())
|
||||
}
|
||||
|
||||
// initialize initializes a Constellation.
|
||||
func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer,
|
||||
serviceAccCreator serviceAccountCreator, fileHandler file.Handler, helmLoader helmLoader,
|
||||
serviceAccCreator serviceAccountCreator, fileHandler file.Handler, helmLoader helmLoader, licenseClient licenseClient,
|
||||
) error {
|
||||
flags, err := evalFlagArgs(cmd, fileHandler)
|
||||
if err != nil {
|
||||
@ -84,6 +85,23 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
|
||||
return fmt.Errorf("reading and validating config: %w", err)
|
||||
}
|
||||
|
||||
licenseID, err := license.FromFile(fileHandler, constants.LicenseFilename)
|
||||
if err != nil {
|
||||
cmd.Println("Unable to find license file. Assuming community license.")
|
||||
licenseID = license.CommunityLicense
|
||||
}
|
||||
quotaResp, err := licenseClient.CheckQuota(cmd.Context(), license.CheckQuotaRequest{
|
||||
License: licenseID,
|
||||
Action: license.Init,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.Println("Unable to contact license server.")
|
||||
cmd.Println("Please keep your vCPU quota in mind.")
|
||||
cmd.Printf("For community installation the vCPU quota is: %d.\n", license.CommunityQuota)
|
||||
}
|
||||
cmd.Printf("Constellation license found: %s\n", licenseID)
|
||||
cmd.Printf("Please keep your vCPU quota (%d) in mind.\n", quotaResp.Quota)
|
||||
|
||||
var sshUsers []*ssh.UserKey
|
||||
for _, user := range config.SSHUsers {
|
||||
sshUsers = append(sshUsers, &ssh.UserKey{
|
||||
@ -401,3 +419,7 @@ func initCompletion(cmd *cobra.Command, args []string, toComplete string) ([]str
|
||||
type grpcDialer interface {
|
||||
Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
|
||||
}
|
||||
|
||||
type licenseClient interface {
|
||||
CheckQuota(ctx context.Context, checkRequest license.CheckQuotaRequest) (license.CheckQuotaResponse, error)
|
||||
}
|
||||
|
@ -23,6 +23,7 @@ import (
|
||||
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/dialer"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
|
||||
"github.com/edgelesssys/constellation/internal/license"
|
||||
"github.com/edgelesssys/constellation/internal/oid"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/spf13/afero"
|
||||
@ -173,7 +174,7 @@ func TestInitialize(t *testing.T) {
|
||||
defer cancel()
|
||||
cmd.SetContext(ctx)
|
||||
|
||||
err := initialize(cmd, newDialer, &tc.serviceAccountCreator, fileHandler, &tc.helmLoader)
|
||||
err := initialize(cmd, newDialer, &tc.serviceAccountCreator, fileHandler, &tc.helmLoader, &stubLicenseClient{})
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
@ -452,7 +453,7 @@ func TestAttestation(t *testing.T) {
|
||||
defer cancel()
|
||||
cmd.SetContext(ctx)
|
||||
|
||||
err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler, &stubHelmLoader{})
|
||||
err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler, &stubHelmLoader{}, &stubLicenseClient{})
|
||||
assert.Error(err)
|
||||
// make sure the error is actually a TLS handshake error
|
||||
assert.Contains(err.Error(), "transport: authentication handshake failed")
|
||||
@ -536,3 +537,11 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, csp cloudprovider.Provi
|
||||
config.RemoveProviderExcept(csp)
|
||||
return config
|
||||
}
|
||||
|
||||
type stubLicenseClient struct{}
|
||||
|
||||
func (c *stubLicenseClient) CheckQuota(ctx context.Context, checkRequest license.CheckQuotaRequest) (license.CheckQuotaResponse, error) {
|
||||
return license.CheckQuotaResponse{
|
||||
Quota: license.CommunityQuota,
|
||||
}, nil
|
||||
}
|
||||
|
@ -50,6 +50,7 @@ const (
|
||||
StateFilename = "constellation-state.json"
|
||||
ClusterIDsFileName = "constellation-id.json"
|
||||
ConfigFilename = "constellation-conf.yaml"
|
||||
LicenseFilename = "constellation.license"
|
||||
DebugdConfigFilename = "cdbg-conf.yaml"
|
||||
AdminConfFilename = "constellation-admin.conf"
|
||||
MasterSecretFilename = "constellation-mastersecret.json"
|
||||
|
28
internal/license/file.go
Normal file
28
internal/license/file.go
Normal file
@ -0,0 +1,28 @@
|
||||
package license
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
)
|
||||
|
||||
func FromFile(fileHandler file.Handler, path string) (string, error) {
|
||||
readBytes, err := fileHandler.Read(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to read from '%s': %w", path, err)
|
||||
}
|
||||
|
||||
maxSize := base64.StdEncoding.DecodedLen(len(readBytes))
|
||||
decodedLicense := make([]byte, maxSize)
|
||||
n, err := base64.StdEncoding.Decode(decodedLicense, readBytes)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to base64 decode license file: %w", err)
|
||||
}
|
||||
if n != 36 { // length of UUID
|
||||
return "", fmt.Errorf("license file corrupt: wrong length")
|
||||
}
|
||||
decodedLicense = decodedLicense[:n]
|
||||
|
||||
return string(decodedLicense), nil
|
||||
}
|
74
internal/license/file_test.go
Normal file
74
internal/license/file_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
package license
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFromFile(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
licenseFileBytes []byte
|
||||
licenseFilePath string
|
||||
dontCreate bool
|
||||
wantLicense string
|
||||
wantError bool
|
||||
}{
|
||||
"community license": {
|
||||
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAw"),
|
||||
licenseFilePath: constants.LicenseFilename,
|
||||
wantLicense: "00000000-0000-0000-0000-000000000000",
|
||||
},
|
||||
"license file corrupt: too short": {
|
||||
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDA="),
|
||||
licenseFilePath: constants.LicenseFilename,
|
||||
wantError: true,
|
||||
},
|
||||
"license file corrupt: too short by 1 character": {
|
||||
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDA="),
|
||||
licenseFilePath: constants.LicenseFilename,
|
||||
wantError: true,
|
||||
},
|
||||
"license file corrupt: too long by 1 character": {
|
||||
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwMA=="),
|
||||
licenseFilePath: constants.LicenseFilename,
|
||||
wantError: true,
|
||||
},
|
||||
"license file corrupt: not base64": {
|
||||
licenseFileBytes: []byte("I am a license file."),
|
||||
licenseFilePath: constants.LicenseFilename,
|
||||
wantError: true,
|
||||
},
|
||||
"license file missing": {
|
||||
licenseFilePath: constants.LicenseFilename,
|
||||
dontCreate: true,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
testFS := file.NewHandler(afero.NewMemMapFs())
|
||||
|
||||
if !tc.dontCreate {
|
||||
err := testFS.Write(tc.licenseFilePath, tc.licenseFileBytes)
|
||||
require.NoError(err)
|
||||
}
|
||||
|
||||
license, err := FromFile(testFS, tc.licenseFilePath)
|
||||
if tc.wantError {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantLicense, license)
|
||||
})
|
||||
}
|
||||
}
|
91
internal/license/license.go
Normal file
91
internal/license/license.go
Normal file
@ -0,0 +1,91 @@
|
||||
package license
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
const (
|
||||
// CommunityLicense is used by everyone who has not bought an enterprise license.
|
||||
CommunityLicense = "00000000-0000-0000-0000-000000000000"
|
||||
// CommunityQuota is the vCPU quota allowed for community installations of Constellation.
|
||||
CommunityQuota = 8
|
||||
apiHost = "license.confidential.cloud"
|
||||
licensePath = "api/v1/license"
|
||||
)
|
||||
|
||||
type Action string
|
||||
|
||||
const (
|
||||
Init Action = "init"
|
||||
test Action = "test"
|
||||
)
|
||||
|
||||
// Client interacts with the ES license server.
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewClient creates a new client to interact with ES license server.
|
||||
func NewClient() *Client {
|
||||
return &Client{
|
||||
httpClient: http.DefaultClient,
|
||||
}
|
||||
}
|
||||
|
||||
// CheckQuotaRequest is JSON request to license server to check quota for a given license and action.
|
||||
type CheckQuotaRequest struct {
|
||||
Action Action `json:"action"`
|
||||
License string `json:"license"`
|
||||
}
|
||||
|
||||
// CheckQuotaResponse is JSON response by license server.
|
||||
type CheckQuotaResponse struct {
|
||||
Quota int `json:"quota"`
|
||||
}
|
||||
|
||||
// CheckQuota for a given license and action, passed via CheckQuotaRequest.
|
||||
func (c *Client) CheckQuota(ctx context.Context, checkRequest CheckQuotaRequest) (CheckQuotaResponse, error) {
|
||||
reqBody, err := json.Marshal(checkRequest)
|
||||
if err != nil {
|
||||
return CheckQuotaResponse{}, fmt.Errorf("unable to marshal input: %w", err)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, licenseURL().String(), bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return CheckQuotaResponse{}, fmt.Errorf("unable to create request: %w", err)
|
||||
}
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return CheckQuotaResponse{}, fmt.Errorf("unable to do request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return CheckQuotaResponse{}, fmt.Errorf("http error %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
responseContentType := resp.Header.Get("Content-Type")
|
||||
if responseContentType != "application/json" {
|
||||
return CheckQuotaResponse{}, fmt.Errorf("expected server JSON response but got '%s'", responseContentType)
|
||||
}
|
||||
|
||||
var parsedResponse CheckQuotaResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&parsedResponse)
|
||||
if err != nil {
|
||||
return CheckQuotaResponse{}, fmt.Errorf("unable to parse response: %w", err)
|
||||
}
|
||||
|
||||
return parsedResponse, nil
|
||||
}
|
||||
|
||||
func licenseURL() *url.URL {
|
||||
return &url.URL{
|
||||
Scheme: "https",
|
||||
Host: apiHost,
|
||||
Path: licensePath,
|
||||
}
|
||||
}
|
64
internal/license/license_integration_test.go
Normal file
64
internal/license/license_integration_test.go
Normal file
@ -0,0 +1,64 @@
|
||||
//go:build integration
|
||||
|
||||
package license
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCheckQuotaIntegration(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
license string
|
||||
action string
|
||||
wantQuota int
|
||||
wantError bool
|
||||
}{
|
||||
"ES license has quota 256": {
|
||||
license: "***REMOVED***",
|
||||
action: test,
|
||||
wantQuota: 256,
|
||||
},
|
||||
"OSS license has quota 8": {
|
||||
license: CommunityLicense,
|
||||
action: test,
|
||||
wantQuota: 8,
|
||||
},
|
||||
"OSS license missing action": {
|
||||
license: CommunityLicense,
|
||||
action: "",
|
||||
wantQuota: 8,
|
||||
},
|
||||
"Empty license assumes community": {
|
||||
license: "",
|
||||
action: test,
|
||||
wantQuota: 8,
|
||||
},
|
||||
"Empty request": {
|
||||
license: "",
|
||||
action: "",
|
||||
wantQuota: 8,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := NewClient()
|
||||
|
||||
resp, err := client.CheckQuota(CheckQuotaRequest{
|
||||
Action: tc.action,
|
||||
License: tc.license,
|
||||
})
|
||||
|
||||
if tc.wantError {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantQuota, resp.Quota)
|
||||
})
|
||||
}
|
||||
}
|
94
internal/license/license_test.go
Normal file
94
internal/license/license_test.go
Normal file
@ -0,0 +1,94 @@
|
||||
package license
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// roundTripFunc .
|
||||
type roundTripFunc func(req *http.Request) *http.Response
|
||||
|
||||
// RoundTrip .
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req), nil
|
||||
}
|
||||
|
||||
// newTestClient returns *http.Client with Transport replaced to avoid making real calls.
|
||||
func newTestClient(fn roundTripFunc) *Client {
|
||||
return &Client{
|
||||
httpClient: &http.Client{
|
||||
Transport: fn,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckQuota(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
license string
|
||||
serverResponse string
|
||||
serverResponseCode int
|
||||
serverResponseContent string
|
||||
wantQuota int
|
||||
wantError bool
|
||||
}{
|
||||
"success": {
|
||||
license: "***REMOVED***",
|
||||
serverResponse: "{\"quota\":256}",
|
||||
serverResponseCode: http.StatusOK,
|
||||
serverResponseContent: "application/json",
|
||||
wantQuota: 256,
|
||||
},
|
||||
"404": {
|
||||
serverResponseCode: http.StatusNotFound,
|
||||
wantError: true,
|
||||
},
|
||||
"HTML not JSON": {
|
||||
serverResponseCode: http.StatusOK,
|
||||
serverResponseContent: "text/html",
|
||||
wantError: true,
|
||||
},
|
||||
"promise JSON but actually HTML": {
|
||||
serverResponseCode: http.StatusOK,
|
||||
serverResponse: "<html><head></head></html>",
|
||||
serverResponseContent: "application/json",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := newTestClient(func(req *http.Request) *http.Response {
|
||||
r := &http.Response{
|
||||
StatusCode: tc.serverResponseCode,
|
||||
Body: io.NopCloser(bytes.NewBufferString(tc.serverResponse)),
|
||||
Header: make(http.Header),
|
||||
}
|
||||
r.Header.Set("Content-Type", tc.serverResponseContent)
|
||||
return r
|
||||
})
|
||||
|
||||
resp, err := client.CheckQuota(context.Background(), CheckQuotaRequest{
|
||||
Action: test,
|
||||
License: tc.license,
|
||||
})
|
||||
|
||||
if tc.wantError {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantQuota, resp.Quota)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_licenseURL(t *testing.T) {
|
||||
assert.Equal(t, "https://license.confidential.cloud/api/v1/license", licenseURL().String())
|
||||
}
|
Loading…
Reference in New Issue
Block a user