AB#2299 License check in CLI during init (#366)

* license server interaction
* logic to read from license file
* print license information during init
Signed-off-by: Fabian Kammel <fk@edgeless.systems>
Co-authored-by: Moritz Eckert <m1gh7ym0@gmail.com>
This commit is contained in:
Fabian Kammel 2022-08-16 16:06:38 +02:00 committed by GitHub
parent 170a8bf5e0
commit 82eb9f4544
8 changed files with 387 additions and 4 deletions

View File

@ -26,6 +26,7 @@ import (
"github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
grpcRetry "github.com/edgelesssys/constellation/internal/grpc/retry"
"github.com/edgelesssys/constellation/internal/license"
"github.com/edgelesssys/constellation/internal/retry"
"github.com/edgelesssys/constellation/internal/state"
kms "github.com/edgelesssys/constellation/kms/setup"
@ -57,12 +58,12 @@ func runInitialize(cmd *cobra.Command, args []string) error {
return dialer.New(nil, validator.V(cmd), &net.Dialer{})
}
helmLoader := &helm.ChartLoader{}
return initialize(cmd, newDialer, serviceAccountCreator, fileHandler, helmLoader)
return initialize(cmd, newDialer, serviceAccountCreator, fileHandler, helmLoader, license.NewClient())
}
// initialize initializes a Constellation.
func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer,
serviceAccCreator serviceAccountCreator, fileHandler file.Handler, helmLoader helmLoader,
serviceAccCreator serviceAccountCreator, fileHandler file.Handler, helmLoader helmLoader, licenseClient licenseClient,
) error {
flags, err := evalFlagArgs(cmd, fileHandler)
if err != nil {
@ -84,6 +85,23 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return fmt.Errorf("reading and validating config: %w", err)
}
licenseID, err := license.FromFile(fileHandler, constants.LicenseFilename)
if err != nil {
cmd.Println("Unable to find license file. Assuming community license.")
licenseID = license.CommunityLicense
}
quotaResp, err := licenseClient.CheckQuota(cmd.Context(), license.CheckQuotaRequest{
License: licenseID,
Action: license.Init,
})
if err != nil {
cmd.Println("Unable to contact license server.")
cmd.Println("Please keep your vCPU quota in mind.")
cmd.Printf("For community installation the vCPU quota is: %d.\n", license.CommunityQuota)
}
cmd.Printf("Constellation license found: %s\n", licenseID)
cmd.Printf("Please keep your vCPU quota (%d) in mind.\n", quotaResp.Quota)
var sshUsers []*ssh.UserKey
for _, user := range config.SSHUsers {
sshUsers = append(sshUsers, &ssh.UserKey{
@ -401,3 +419,7 @@ func initCompletion(cmd *cobra.Command, args []string, toComplete string) ([]str
type grpcDialer interface {
Dial(ctx context.Context, target string) (*grpc.ClientConn, error)
}
type licenseClient interface {
CheckQuota(ctx context.Context, checkRequest license.CheckQuotaRequest) (license.CheckQuotaResponse, error)
}

View File

@ -23,6 +23,7 @@ import (
"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
"github.com/edgelesssys/constellation/internal/grpc/dialer"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/internal/license"
"github.com/edgelesssys/constellation/internal/oid"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
@ -173,7 +174,7 @@ func TestInitialize(t *testing.T) {
defer cancel()
cmd.SetContext(ctx)
err := initialize(cmd, newDialer, &tc.serviceAccountCreator, fileHandler, &tc.helmLoader)
err := initialize(cmd, newDialer, &tc.serviceAccountCreator, fileHandler, &tc.helmLoader, &stubLicenseClient{})
if tc.wantErr {
assert.Error(err)
@ -452,7 +453,7 @@ func TestAttestation(t *testing.T) {
defer cancel()
cmd.SetContext(ctx)
err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler, &stubHelmLoader{})
err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler, &stubHelmLoader{}, &stubLicenseClient{})
assert.Error(err)
// make sure the error is actually a TLS handshake error
assert.Contains(err.Error(), "transport: authentication handshake failed")
@ -536,3 +537,11 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, csp cloudprovider.Provi
config.RemoveProviderExcept(csp)
return config
}
type stubLicenseClient struct{}
func (c *stubLicenseClient) CheckQuota(ctx context.Context, checkRequest license.CheckQuotaRequest) (license.CheckQuotaResponse, error) {
return license.CheckQuotaResponse{
Quota: license.CommunityQuota,
}, nil
}

View File

@ -50,6 +50,7 @@ const (
StateFilename = "constellation-state.json"
ClusterIDsFileName = "constellation-id.json"
ConfigFilename = "constellation-conf.yaml"
LicenseFilename = "constellation.license"
DebugdConfigFilename = "cdbg-conf.yaml"
AdminConfFilename = "constellation-admin.conf"
MasterSecretFilename = "constellation-mastersecret.json"

28
internal/license/file.go Normal file
View File

@ -0,0 +1,28 @@
package license
import (
"encoding/base64"
"fmt"
"github.com/edgelesssys/constellation/internal/file"
)
func FromFile(fileHandler file.Handler, path string) (string, error) {
readBytes, err := fileHandler.Read(path)
if err != nil {
return "", fmt.Errorf("unable to read from '%s': %w", path, err)
}
maxSize := base64.StdEncoding.DecodedLen(len(readBytes))
decodedLicense := make([]byte, maxSize)
n, err := base64.StdEncoding.Decode(decodedLicense, readBytes)
if err != nil {
return "", fmt.Errorf("unable to base64 decode license file: %w", err)
}
if n != 36 { // length of UUID
return "", fmt.Errorf("license file corrupt: wrong length")
}
decodedLicense = decodedLicense[:n]
return string(decodedLicense), nil
}

View File

@ -0,0 +1,74 @@
package license
import (
"testing"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFromFile(t *testing.T) {
testCases := map[string]struct {
licenseFileBytes []byte
licenseFilePath string
dontCreate bool
wantLicense string
wantError bool
}{
"community license": {
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAw"),
licenseFilePath: constants.LicenseFilename,
wantLicense: "00000000-0000-0000-0000-000000000000",
},
"license file corrupt: too short": {
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDA="),
licenseFilePath: constants.LicenseFilename,
wantError: true,
},
"license file corrupt: too short by 1 character": {
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDA="),
licenseFilePath: constants.LicenseFilename,
wantError: true,
},
"license file corrupt: too long by 1 character": {
licenseFileBytes: []byte("MDAwMDAwMDAtMDAwMC0wMDAwLTAwMDAtMDAwMDAwMDAwMDAwMA=="),
licenseFilePath: constants.LicenseFilename,
wantError: true,
},
"license file corrupt: not base64": {
licenseFileBytes: []byte("I am a license file."),
licenseFilePath: constants.LicenseFilename,
wantError: true,
},
"license file missing": {
licenseFilePath: constants.LicenseFilename,
dontCreate: true,
wantError: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
testFS := file.NewHandler(afero.NewMemMapFs())
if !tc.dontCreate {
err := testFS.Write(tc.licenseFilePath, tc.licenseFileBytes)
require.NoError(err)
}
license, err := FromFile(testFS, tc.licenseFilePath)
if tc.wantError {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantLicense, license)
})
}
}

View File

@ -0,0 +1,91 @@
package license
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
)
const (
// CommunityLicense is used by everyone who has not bought an enterprise license.
CommunityLicense = "00000000-0000-0000-0000-000000000000"
// CommunityQuota is the vCPU quota allowed for community installations of Constellation.
CommunityQuota = 8
apiHost = "license.confidential.cloud"
licensePath = "api/v1/license"
)
type Action string
const (
Init Action = "init"
test Action = "test"
)
// Client interacts with the ES license server.
type Client struct {
httpClient *http.Client
}
// NewClient creates a new client to interact with ES license server.
func NewClient() *Client {
return &Client{
httpClient: http.DefaultClient,
}
}
// CheckQuotaRequest is JSON request to license server to check quota for a given license and action.
type CheckQuotaRequest struct {
Action Action `json:"action"`
License string `json:"license"`
}
// CheckQuotaResponse is JSON response by license server.
type CheckQuotaResponse struct {
Quota int `json:"quota"`
}
// CheckQuota for a given license and action, passed via CheckQuotaRequest.
func (c *Client) CheckQuota(ctx context.Context, checkRequest CheckQuotaRequest) (CheckQuotaResponse, error) {
reqBody, err := json.Marshal(checkRequest)
if err != nil {
return CheckQuotaResponse{}, fmt.Errorf("unable to marshal input: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, licenseURL().String(), bytes.NewBuffer(reqBody))
if err != nil {
return CheckQuotaResponse{}, fmt.Errorf("unable to create request: %w", err)
}
resp, err := c.httpClient.Do(req)
if err != nil {
return CheckQuotaResponse{}, fmt.Errorf("unable to do request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return CheckQuotaResponse{}, fmt.Errorf("http error %d", resp.StatusCode)
}
responseContentType := resp.Header.Get("Content-Type")
if responseContentType != "application/json" {
return CheckQuotaResponse{}, fmt.Errorf("expected server JSON response but got '%s'", responseContentType)
}
var parsedResponse CheckQuotaResponse
err = json.NewDecoder(resp.Body).Decode(&parsedResponse)
if err != nil {
return CheckQuotaResponse{}, fmt.Errorf("unable to parse response: %w", err)
}
return parsedResponse, nil
}
func licenseURL() *url.URL {
return &url.URL{
Scheme: "https",
Host: apiHost,
Path: licensePath,
}
}

View File

@ -0,0 +1,64 @@
//go:build integration
package license
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestCheckQuotaIntegration(t *testing.T) {
testCases := map[string]struct {
license string
action string
wantQuota int
wantError bool
}{
"ES license has quota 256": {
license: "***REMOVED***",
action: test,
wantQuota: 256,
},
"OSS license has quota 8": {
license: CommunityLicense,
action: test,
wantQuota: 8,
},
"OSS license missing action": {
license: CommunityLicense,
action: "",
wantQuota: 8,
},
"Empty license assumes community": {
license: "",
action: test,
wantQuota: 8,
},
"Empty request": {
license: "",
action: "",
wantQuota: 8,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := NewClient()
resp, err := client.CheckQuota(CheckQuotaRequest{
Action: tc.action,
License: tc.license,
})
if tc.wantError {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantQuota, resp.Quota)
})
}
}

View File

@ -0,0 +1,94 @@
package license
import (
"bytes"
"context"
"io"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
// roundTripFunc .
type roundTripFunc func(req *http.Request) *http.Response
// RoundTrip .
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req), nil
}
// newTestClient returns *http.Client with Transport replaced to avoid making real calls.
func newTestClient(fn roundTripFunc) *Client {
return &Client{
httpClient: &http.Client{
Transport: fn,
},
}
}
func TestCheckQuota(t *testing.T) {
testCases := map[string]struct {
license string
serverResponse string
serverResponseCode int
serverResponseContent string
wantQuota int
wantError bool
}{
"success": {
license: "***REMOVED***",
serverResponse: "{\"quota\":256}",
serverResponseCode: http.StatusOK,
serverResponseContent: "application/json",
wantQuota: 256,
},
"404": {
serverResponseCode: http.StatusNotFound,
wantError: true,
},
"HTML not JSON": {
serverResponseCode: http.StatusOK,
serverResponseContent: "text/html",
wantError: true,
},
"promise JSON but actually HTML": {
serverResponseCode: http.StatusOK,
serverResponse: "<html><head></head></html>",
serverResponseContent: "application/json",
wantError: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := newTestClient(func(req *http.Request) *http.Response {
r := &http.Response{
StatusCode: tc.serverResponseCode,
Body: io.NopCloser(bytes.NewBufferString(tc.serverResponse)),
Header: make(http.Header),
}
r.Header.Set("Content-Type", tc.serverResponseContent)
return r
})
resp, err := client.CheckQuota(context.Background(), CheckQuotaRequest{
Action: test,
License: tc.license,
})
if tc.wantError {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantQuota, resp.Quota)
})
}
}
func Test_licenseURL(t *testing.T) {
assert.Equal(t, "https://license.confidential.cloud/api/v1/license", licenseURL().String())
}