Co-authored-by: Malte Poll <mp@edgeless.systems>
Co-authored-by: katexochen <katexochen@users.noreply.github.com>
Co-authored-by: Daniel Weiße <dw@edgeless.systems>
Co-authored-by: Thomas Tendyck <tt@edgeless.systems>
Co-authored-by: Benedict Schlueter <bs@edgeless.systems>
Co-authored-by: leongross <leon.gross@rub.de>
Co-authored-by: Moritz Eckert <m1gh7ym0@gmail.com>
This commit is contained in:
Leonard Cohnen 2022-03-22 16:03:15 +01:00
commit 2d8fcd9bf4
362 changed files with 50980 additions and 0 deletions

28
cli/README.md Normal file
View file

@ -0,0 +1,28 @@
# CLI to spawn a confidential kubernetes cluster
## Usage
0. (optional) replace the responsible in `cli/cmd/defaults.go` with yourself.
1. Build the CLI and authenticate with <AWS/Azure/GCP> according to the [README.md](https://github.com/edgelesssys/constellation-coordinator/blob/main/README.md#cloud-credentials).
2. Execute `constellation create <aws/azure/gcp> 2 <4xlarge|n2d-standard-2>`.
3. Execute `wg genkey | tee privatekey | wg pubkey > publickey` to generate a WireGuard keypair.
4. Execute `constellation init --publickey publickey`. Since the CLI waits for all nodes to be ready, this step can take up to 5 minutes.
5. Use the output from `constellation init` and the wireguard template below to create `/etc/wireguard/wg0.conf`, then execute `wg-quick up wg0`.
6. Execute `export KUBECONFIG=<path/to/admin.conf>`.
7. Use `kubectl get nodes` to inspect your cluster.
8. Execute `constellation terminate` to terminate your Constellation.
```bash
[Interface]
Address = <address from the init output>
PrivateKey = <your base64 encoded private key>
ListenPort = 51820
[Peer]
PublicKey = <public key from the init output>
AllowedIPs = 10.118.0.1/32 # IP set on the peer's wg interface
Endpoint = <public IPv4 address from the activated coordinator>:51820 # address where the peer listens on
PersistentKeepalive = 10
```
Note: Skip the manual configuration of WireGuard by executing Step 2 as root. Then, replace steps 4 and 5 with `sudo constellation init --privatekey <path/to/your/privatekey>`. This will automatically configure a new WireGuard interface named wg0 with the coordinator as peer.

View file

@ -0,0 +1,203 @@
package client
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"net/url"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/autorest/to"
"github.com/google/uuid"
)
const (
adAppCredentialValidity = time.Hour * 24 * 365 * 5 // ~5 years
adReplicationLagCheckInterval = time.Second * 5 // 5 seconds
adReplicationLagCheckMaxRetries = int((15 * time.Minute) / adReplicationLagCheckInterval) // wait for up to 15 minutes for AD replication
ownerRoleDefinitionID = "8e3af657-a8ff-443c-a75c-2fe8c4bcb635" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#owner
virtualMachineContributorRoleDefinitionID = "9980e02c-c2be-4d73-94e8-173b1dc7cf3c" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#virtual-machine-contributor
)
// CreateServicePrincipal creates an Azure AD app with a service principal, gives it "Owner" role on the resource group and creates new credentials.
func (c *Client) CreateServicePrincipal(ctx context.Context) (string, error) {
createAppRes, err := c.createADApplication(ctx)
if err != nil {
return "", err
}
c.adAppObjectID = createAppRes.ObjectID
servicePrincipalObjectID, err := c.createAppServicePrincipal(ctx, createAppRes.AppID)
if err != nil {
return "", err
}
if err := c.assignResourceGroupRole(ctx, servicePrincipalObjectID, ownerRoleDefinitionID); err != nil {
return "", err
}
clientSecret, err := c.updateAppCredentials(ctx, createAppRes.ObjectID)
if err != nil {
return "", err
}
return ApplicationCredentials{
ClientID: createAppRes.AppID,
ClientSecret: clientSecret,
}.ConvertToCloudServiceAccountURI(), nil
}
// TerminateServicePrincipal terminates an Azure AD app together with the service principal.
func (c *Client) TerminateServicePrincipal(ctx context.Context) error {
if c.adAppObjectID == "" {
return nil
}
if _, err := c.applicationsAPI.Delete(ctx, c.adAppObjectID); err != nil {
return err
}
c.adAppObjectID = ""
return nil
}
// createADApplication creates a new azure AD app.
func (c *Client) createADApplication(ctx context.Context) (createADApplicationOutput, error) {
createParameters := graphrbac.ApplicationCreateParameters{
AvailableToOtherTenants: to.BoolPtr(false),
DisplayName: to.StringPtr("constellation-app-" + c.name + "-" + c.uid),
}
app, err := c.applicationsAPI.Create(ctx, createParameters)
if err != nil {
return createADApplicationOutput{}, err
}
if app.AppID == nil || app.ObjectID == nil {
return createADApplicationOutput{}, errors.New("creating AD application did not result in valid app id and object id")
}
return createADApplicationOutput{
AppID: *app.AppID,
ObjectID: *app.ObjectID,
}, nil
}
// createAppServicePrincipal creates a new service principal for an azure AD app.
func (c *Client) createAppServicePrincipal(ctx context.Context, appID string) (string, error) {
createParameters := graphrbac.ServicePrincipalCreateParameters{
AppID: &appID,
AccountEnabled: to.BoolPtr(true),
}
servicePrincipal, err := c.servicePrincipalsAPI.Create(ctx, createParameters)
if err != nil {
return "", err
}
if servicePrincipal.ObjectID == nil {
return "", errors.New("creating AD service principal did not result in a valid object id")
}
return *servicePrincipal.ObjectID, nil
}
// updateAppCredentials sets app client-secret for authentication.
func (c *Client) updateAppCredentials(ctx context.Context, objectID string) (string, error) {
keyID := uuid.New().String()
clientSecret, err := generateClientSecret()
if err != nil {
return "", fmt.Errorf("generating client secret failed: %w", err)
}
updateParameters := graphrbac.PasswordCredentialsUpdateParameters{
Value: &[]graphrbac.PasswordCredential{
{
StartDate: &date.Time{Time: time.Now()},
EndDate: &date.Time{Time: time.Now().Add(adAppCredentialValidity)},
Value: to.StringPtr(clientSecret),
KeyID: to.StringPtr(keyID),
},
},
}
_, err = c.applicationsAPI.UpdatePasswordCredentials(ctx, objectID, updateParameters)
if err != nil {
return "", err
}
return clientSecret, nil
}
// assignResourceGroupRole assigns the service principal a role at resource group scope.
func (c *Client) assignResourceGroupRole(ctx context.Context, principalID, roleDefinitionID string) error {
resourceGroup, err := c.resourceGroupAPI.Get(ctx, c.resourceGroup, nil)
if err != nil || resourceGroup.ID == nil {
return fmt.Errorf("unable to retrieve resource group id for group %v: %w", c.resourceGroup, err)
}
roleAssignmentID := uuid.New().String()
createParameters := authorization.RoleAssignmentCreateParameters{
Properties: &authorization.RoleAssignmentProperties{
PrincipalID: to.StringPtr(principalID),
RoleDefinitionID: to.StringPtr(fmt.Sprintf("/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s", c.subscriptionID, roleDefinitionID)),
},
}
// due to an azure AD replication lag, retry role assignment if principal does not exist yet
// reference: https://docs.microsoft.com/en-us/azure/role-based-access-control/role-assignments-rest#new-service-principal
// proper fix: use API version 2018-09-01-preview or later
// azure go sdk currently uses version 2015-07-01: https://github.com/Azure/azure-sdk-for-go/blob/v62.0.0/services/authorization/mgmt/2015-07-01/authorization/roleassignments.go#L95
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
for i := 0; i < c.adReplicationLagCheckMaxRetries; i++ {
_, err = c.roleAssignmentsAPI.Create(ctx, *resourceGroup.ID, roleAssignmentID, createParameters)
var detailedErr autorest.DetailedError
var ok bool
if detailedErr, ok = err.(autorest.DetailedError); !ok {
return err
}
var requestErr *azure.RequestError
if requestErr, ok = detailedErr.Original.(*azure.RequestError); !ok || requestErr.ServiceError == nil {
return err
}
if requestErr.ServiceError.Code != "PrincipalNotFound" {
return err
}
time.Sleep(c.adReplicationLagCheckInterval)
}
return err
}
// ApplicationCredentials is a set of Azure AD application credentials.
// It is the equivalent of a service account key in other cloud providers.
type ApplicationCredentials struct {
ClientID string
ClientSecret string
}
// ConvertToCloudServiceAccountURI converts the ApplicationCredentials into a cloud service account URI.
func (c ApplicationCredentials) ConvertToCloudServiceAccountURI() string {
query := url.Values{}
query.Add("client_id", c.ClientID)
query.Add("client_secret", c.ClientSecret)
uri := url.URL{
Scheme: "serviceaccount",
Host: "azure",
RawQuery: query.Encode(),
}
return uri.String()
}
type createADApplicationOutput struct {
AppID string
ObjectID string
}
func generateClientSecret() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
pwLen := 64
pw := make([]byte, 0, pwLen)
for i := 0; i < pwLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
pw = append(pw, letters[n.Int64()])
}
return string(pw), nil
}

View file

@ -0,0 +1,380 @@
package client
import (
"context"
"errors"
"net/url"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)
func TestCreateServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
applicationsAPI applicationsAPI
servicePrincipalsAPI servicePrincipalsAPI
roleAssignmentsAPI roleAssignmentsAPI
resourceGroupAPI resourceGroupAPI
errExpected bool
}{
"successful create": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
},
"failed app create": {
applicationsAPI: stubApplicationsAPI{
createErr: someErr,
},
errExpected: true,
},
"failed service principal create": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{
createErr: someErr,
},
errExpected: true,
},
"failed role assignment": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{someErr},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
errExpected: true,
},
"failed update creds": {
applicationsAPI: stubApplicationsAPI{
updateCredentialsErr: someErr,
},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
applicationsAPI: tc.applicationsAPI,
servicePrincipalsAPI: tc.servicePrincipalsAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
adReplicationLagCheckMaxRetries: 2,
}
_, err := client.CreateServicePrincipal(ctx)
if tc.errExpected {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestTerminateServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
appObjectID string
applicationsAPI applicationsAPI
errExpected bool
}{
"successful terminate": {
appObjectID: "object-id",
applicationsAPI: stubApplicationsAPI{},
},
"nothing to terminate": {
applicationsAPI: stubApplicationsAPI{},
},
"failed delete": {
appObjectID: "object-id",
applicationsAPI: stubApplicationsAPI{
deleteErr: someErr,
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
adAppObjectID: tc.appObjectID,
applicationsAPI: tc.applicationsAPI,
}
err := client.TerminateServicePrincipal(ctx)
if tc.errExpected {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestCreateADApplication(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
applicationsAPI applicationsAPI
errExpected bool
}{
"successful create": {
applicationsAPI: stubApplicationsAPI{},
},
"failed app create": {
applicationsAPI: stubApplicationsAPI{
createErr: someErr,
},
errExpected: true,
},
"app create returns invalid appid": {
applicationsAPI: stubApplicationsAPI{
createApplication: &graphrbac.Application{
ObjectID: proto.String("00000000-0000-0000-0000-000000000001"),
},
},
errExpected: true,
},
"app create returns invalid objectid": {
applicationsAPI: stubApplicationsAPI{
createApplication: &graphrbac.Application{
AppID: proto.String("00000000-0000-0000-0000-000000000000"),
},
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
applicationsAPI: tc.applicationsAPI,
}
appCredentials, err := client.createADApplication(ctx)
if tc.errExpected {
assert.Error(err)
return
}
assert.NoError(err)
assert.NotNil(appCredentials)
})
}
}
func TestCreateAppServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
servicePrincipalsAPI servicePrincipalsAPI
errExpected bool
}{
"successful create": {
servicePrincipalsAPI: stubServicePrincipalsAPI{},
},
"failed service principal create": {
servicePrincipalsAPI: stubServicePrincipalsAPI{
createErr: someErr,
},
errExpected: true,
},
"service principal create returns invalid objectid": {
servicePrincipalsAPI: stubServicePrincipalsAPI{
createServicePrincipal: &graphrbac.ServicePrincipal{},
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
servicePrincipalsAPI: tc.servicePrincipalsAPI,
}
_, err := client.createAppServicePrincipal(ctx, "app-id")
if tc.errExpected {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestAssignOwnerOfResourceGroup(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
roleAssignmentsAPI roleAssignmentsAPI
resourceGroupAPI resourceGroupAPI
errExpected bool
}{
"successful assign": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
},
"failed role assignment": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{someErr},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
errExpected: true,
},
"failed resource group get": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getErr: someErr,
},
errExpected: true,
},
"resource group get returns invalid id": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{},
},
errExpected: true,
},
"create returns PrincipalNotFound the first time": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{
autorest.DetailedError{Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
Code: "PrincipalNotFound",
},
}},
nil,
},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
},
"create does not return request error": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{autorest.DetailedError{Original: someErr}},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
errExpected: true,
},
"create service error code is unknown": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{
autorest.DetailedError{Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
Code: "some-unknown-error-code",
},
}},
nil,
},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
roleAssignmentsAPI: tc.roleAssignmentsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
adReplicationLagCheckMaxRetries: 2,
}
err := client.assignResourceGroupRole(ctx, "principal-id", "role-definition-id")
if tc.errExpected {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestConvertToCloudServiceAccountURI(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
key := ApplicationCredentials{
ClientID: "client-id",
ClientSecret: "client-secret",
}
cloudServiceAccountURI := key.ConvertToCloudServiceAccountURI()
uri, err := url.Parse(cloudServiceAccountURI)
require.NoError(err)
query := uri.Query()
assert.Equal("serviceaccount", uri.Scheme)
assert.Equal("azure", uri.Host)
assert.Equal(url.Values{
"client_id": []string{"client-id"},
"client_secret": []string{"client-secret"},
}, query)
}

129
cli/azure/client/api.go Normal file
View file

@ -0,0 +1,129 @@
package client
import (
"context"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
)
type virtualNetworksCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.VirtualNetworksClientCreateOrUpdateResponse, error)
}
type networksAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
virtualNetworksCreateOrUpdatePollerResponse, error)
}
type networkSecurityGroupsCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.SecurityGroupsClientCreateOrUpdateResponse, error)
}
type networkSecurityGroupsAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
networkSecurityGroupsCreateOrUpdatePollerResponse, error)
}
type virtualMachineScaleSetsCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse, error)
}
type scaleSetsAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
vmScaleSetName string, parameters armcompute.VirtualMachineScaleSet,
options *armcompute.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
virtualMachineScaleSetsCreateOrUpdatePollerResponse, error)
}
type publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager interface {
NextPage(ctx context.Context) bool
PageResponse() armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse
}
// TODO: deprecate as soon as scale sets are available.
type publicIPAddressesClientCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.PublicIPAddressesClientCreateOrUpdateResponse, error)
}
type publicIPAddressesAPI interface {
ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string,
networkInterfaceName string, ipConfigurationName string,
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
) publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager
// TODO: deprecate as soon as scale sets are available.
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
publicIPAddressesClientCreateOrUpdatePollerResponse, error)
// TODO: deprecate as soon as scale sets are available.
Get(ctx context.Context, resourceGroupName string, publicIPAddressName string, options *armnetwork.PublicIPAddressesClientGetOptions) (
armnetwork.PublicIPAddressesClientGetResponse, error)
}
// TODO: deprecate as soon as scale sets are available.
type interfacesClientCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.InterfacesClientCreateOrUpdateResponse, error)
}
type networkInterfacesAPI interface {
GetVirtualMachineScaleSetNetworkInterface(ctx context.Context, resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string, networkInterfaceName string,
options *armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceOptions,
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error)
// TODO: deprecate as soon as scale sets are available
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, networkInterfaceName string,
parameters armnetwork.Interface, options *armnetwork.InterfacesClientBeginCreateOrUpdateOptions) (
interfacesClientCreateOrUpdatePollerResponse, error)
}
type resourceGroupsDeletePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armresources.ResourceGroupsClientDeleteResponse, error)
}
type resourceGroupAPI interface {
CreateOrUpdate(ctx context.Context, resourceGroupName string,
parameters armresources.ResourceGroup,
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
armresources.ResourceGroupsClientCreateOrUpdateResponse, error)
BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
resourceGroupsDeletePollerResponse, error)
Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error)
}
type applicationsAPI interface {
Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error)
Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error)
UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error)
}
type servicePrincipalsAPI interface {
Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error)
}
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
// TODO: switch to "armauthorization.RoleAssignmentsClient" if issue is resolved.
type roleAssignmentsAPI interface {
Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error)
}
// TODO: deprecate as soon as scale sets are available.
type virtualMachinesClientCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armcompute.VirtualMachinesClientCreateOrUpdateResponse, error)
}
// TODO: deprecate as soon as scale sets are available.
type virtualMachinesAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine,
options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions) (virtualMachinesClientCreateOrUpdatePollerResponse, error)
}

View file

@ -0,0 +1,388 @@
package client
import (
"context"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
)
type stubNetworksAPI struct {
createErr error
stubResponse stubVirtualNetworksCreateOrUpdatePollerResponse
}
type stubVirtualNetworksCreateOrUpdatePollerResponse struct {
armnetwork.VirtualNetworksClientCreateOrUpdatePollerResponse
pollerErr error
}
func (r stubVirtualNetworksCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration,
) (armnetwork.VirtualNetworksClientCreateOrUpdateResponse, error) {
return armnetwork.VirtualNetworksClientCreateOrUpdateResponse{
VirtualNetworksClientCreateOrUpdateResult: armnetwork.VirtualNetworksClientCreateOrUpdateResult{
VirtualNetwork: armnetwork.VirtualNetwork{
Properties: &armnetwork.VirtualNetworkPropertiesFormat{
Subnets: []*armnetwork.Subnet{
{
ID: to.StringPtr("virtual-network-subnet-id"),
},
},
},
},
},
}, r.pollerErr
}
func (a stubNetworksAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
virtualNetworksCreateOrUpdatePollerResponse, error,
) {
return a.stubResponse, a.createErr
}
type stubNetworkSecurityGroupsCreateOrUpdatePollerResponse struct {
armnetwork.SecurityGroupsClientCreateOrUpdatePollerResponse
pollerErr error
}
func (r stubNetworkSecurityGroupsCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration,
) (armnetwork.SecurityGroupsClientCreateOrUpdateResponse, error) {
return armnetwork.SecurityGroupsClientCreateOrUpdateResponse{
SecurityGroupsClientCreateOrUpdateResult: armnetwork.SecurityGroupsClientCreateOrUpdateResult{
SecurityGroup: armnetwork.SecurityGroup{
ID: to.StringPtr("network-security-group-id"),
},
},
}, r.pollerErr
}
type stubNetworkSecurityGroupsAPI struct {
createErr error
stubPoller stubNetworkSecurityGroupsCreateOrUpdatePollerResponse
}
func (a stubNetworkSecurityGroupsAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
networkSecurityGroupsCreateOrUpdatePollerResponse, error,
) {
return a.stubPoller, a.createErr
}
type stubResourceGroupAPI struct {
terminateErr error
createErr error
getErr error
getResourceGroup armresources.ResourceGroup
stubResponse stubResourceGroupsDeletePollerResponse
}
func (a stubResourceGroupAPI) CreateOrUpdate(ctx context.Context, resourceGroupName string,
parameters armresources.ResourceGroup,
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
armresources.ResourceGroupsClientCreateOrUpdateResponse, error,
) {
return armresources.ResourceGroupsClientCreateOrUpdateResponse{}, a.createErr
}
func (a stubResourceGroupAPI) Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error) {
return armresources.ResourceGroupsClientGetResponse{
ResourceGroupsClientGetResult: armresources.ResourceGroupsClientGetResult{
ResourceGroup: a.getResourceGroup,
},
}, a.getErr
}
type stubResourceGroupsDeletePollerResponse struct {
armresources.ResourceGroupsClientDeletePollerResponse
pollerErr error
}
func (r stubResourceGroupsDeletePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armresources.ResourceGroupsClientDeleteResponse, error,
) {
return armresources.ResourceGroupsClientDeleteResponse{}, r.pollerErr
}
func (a stubResourceGroupAPI) BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
resourceGroupsDeletePollerResponse, error,
) {
return a.stubResponse, a.terminateErr
}
type stubScaleSetsAPI struct {
createErr error
stubResponse stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse
}
type stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse struct {
pollResponse armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse
pollErr error
}
func (r stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse, error,
) {
return r.pollResponse, r.pollErr
}
func (a stubScaleSetsAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
vmScaleSetName string, parameters armcompute.VirtualMachineScaleSet,
options *armcompute.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
virtualMachineScaleSetsCreateOrUpdatePollerResponse, error,
) {
return a.stubResponse, a.createErr
}
// TODO: deprecate as soon as scale sets are available.
type stubPublicIPAddressesAPI struct {
// TODO: deprecate as soon as scale sets are available.
createErr error
// TODO: deprecate as soon as scale sets are available.
getErr error
// TODO: deprecate as soon as scale sets are available.
stubCreateResponse stubPublicIPAddressesClientCreateOrUpdatePollerResponse
}
// TODO: deprecate as soon as scale sets are available.
type stubPublicIPAddressesClientCreateOrUpdatePollerResponse struct {
armnetwork.PublicIPAddressesClientCreateOrUpdatePollerResponse
pollErr error
}
// TODO: deprecate as soon as scale sets are available.
func (r stubPublicIPAddressesClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armnetwork.PublicIPAddressesClientCreateOrUpdateResponse, error,
) {
return armnetwork.PublicIPAddressesClientCreateOrUpdateResponse{
PublicIPAddressesClientCreateOrUpdateResult: armnetwork.PublicIPAddressesClientCreateOrUpdateResult{
PublicIPAddress: armnetwork.PublicIPAddress{
ID: to.StringPtr("pubIP-id"),
},
},
}, r.pollErr
}
type stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager struct {
pagesCounter int
PagesMax int
}
func (p *stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager) NextPage(ctx context.Context) bool {
p.pagesCounter++
return p.pagesCounter <= p.PagesMax
}
func (p *stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager) PageResponse() armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse {
return armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse{
PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResult: armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResult{
PublicIPAddressListResult: armnetwork.PublicIPAddressListResult{
Value: []*armnetwork.PublicIPAddress{
{
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: to.StringPtr("192.0.2.1"),
},
},
},
},
},
}
}
func (a stubPublicIPAddressesAPI) ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string,
networkInterfaceName string, ipConfigurationName string,
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
) publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager {
return &stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager{pagesCounter: 0, PagesMax: 1}
}
// TODO: deprecate as soon as scale sets are available.
func (a stubPublicIPAddressesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
publicIPAddressesClientCreateOrUpdatePollerResponse, error,
) {
return a.stubCreateResponse, a.createErr
}
// TODO: deprecate as soon as scale sets are available.
func (a stubPublicIPAddressesAPI) Get(ctx context.Context, resourceGroupName string, publicIPAddressName string, options *armnetwork.PublicIPAddressesClientGetOptions) (
armnetwork.PublicIPAddressesClientGetResponse, error,
) {
return armnetwork.PublicIPAddressesClientGetResponse{
PublicIPAddressesClientGetResult: armnetwork.PublicIPAddressesClientGetResult{
PublicIPAddress: armnetwork.PublicIPAddress{
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: to.StringPtr("192.0.2.1"),
},
},
},
}, a.getErr
}
type stubNetworkInterfacesAPI struct {
getErr error
// TODO: deprecate as soon as scale sets are available
createErr error
// TODO: deprecate as soon as scale sets are available
stubResp stubInterfacesClientCreateOrUpdatePollerResponse
}
func (a stubNetworkInterfacesAPI) GetVirtualMachineScaleSetNetworkInterface(ctx context.Context, resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string, networkInterfaceName string,
options *armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceOptions,
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error) {
if a.getErr != nil {
return armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse{}, a.getErr
}
return armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse{
InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResult: armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResult{
Interface: armnetwork.Interface{
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.1"),
},
},
},
},
},
},
}, nil
}
// TODO: deprecate as soon as scale sets are available.
type stubInterfacesClientCreateOrUpdatePollerResponse struct {
pollErr error
}
// TODO: deprecate as soon as scale sets are available.
func (r stubInterfacesClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armnetwork.InterfacesClientCreateOrUpdateResponse, error,
) {
return armnetwork.InterfacesClientCreateOrUpdateResponse{
InterfacesClientCreateOrUpdateResult: armnetwork.InterfacesClientCreateOrUpdateResult{
Interface: armnetwork.Interface{
ID: to.StringPtr("interface-id"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.1"),
},
},
},
},
},
},
}, r.pollErr
}
// TODO: deprecate as soon as scale sets are available.
func (a stubNetworkInterfacesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, networkInterfaceName string,
parameters armnetwork.Interface, options *armnetwork.InterfacesClientBeginCreateOrUpdateOptions) (
interfacesClientCreateOrUpdatePollerResponse, error,
) {
return a.stubResp, a.createErr
}
// TODO: deprecate as soon as scale sets are available.
type stubVirtualMachinesAPI struct {
stubResponse stubVirtualMachinesClientCreateOrUpdatePollerResponse
createErr error
}
// TODO: deprecate as soon as scale sets are available.
type stubVirtualMachinesClientCreateOrUpdatePollerResponse struct {
pollResponse armcompute.VirtualMachinesClientCreateOrUpdateResponse
pollErr error
}
// TODO: deprecate as soon as scale sets are available.
func (r stubVirtualMachinesClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armcompute.VirtualMachinesClientCreateOrUpdateResponse, error,
) {
return r.pollResponse, r.pollErr
}
// TODO: deprecate as soon as scale sets are available.
func (a stubVirtualMachinesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine,
options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions,
) (virtualMachinesClientCreateOrUpdatePollerResponse, error) {
return a.stubResponse, a.createErr
}
type stubApplicationsAPI struct {
createErr error
deleteErr error
updateCredentialsErr error
createApplication *graphrbac.Application
}
func (a stubApplicationsAPI) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
if a.createErr != nil {
return graphrbac.Application{}, a.createErr
}
if a.createApplication != nil {
return *a.createApplication, nil
}
return graphrbac.Application{
AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"),
ObjectID: to.StringPtr("00000000-0000-0000-0000-000000000001"),
}, nil
}
func (a stubApplicationsAPI) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
if a.deleteErr != nil {
return autorest.Response{}, a.deleteErr
}
return autorest.Response{}, nil
}
func (a stubApplicationsAPI) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
if a.updateCredentialsErr != nil {
return autorest.Response{}, a.updateCredentialsErr
}
return autorest.Response{}, nil
}
type stubServicePrincipalsAPI struct {
createErr error
createServicePrincipal *graphrbac.ServicePrincipal
}
func (a stubServicePrincipalsAPI) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
if a.createErr != nil {
return graphrbac.ServicePrincipal{}, a.createErr
}
if a.createServicePrincipal != nil {
return *a.createServicePrincipal, nil
}
return graphrbac.ServicePrincipal{
AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"),
ObjectID: to.StringPtr("00000000-0000-0000-0000-000000000002"),
}, nil
}
type stubRoleAssignmentsAPI struct {
createCounter int
createErrors []error
}
func (a *stubRoleAssignmentsAPI) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
a.createCounter += 1
if len(a.createErrors) == 0 {
return authorization.RoleAssignment{}, nil
}
return authorization.RoleAssignment{}, a.createErrors[(a.createCounter-1)%len(a.createErrors)]
}

View file

@ -0,0 +1,136 @@
package client
import (
"context"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
)
type networksClient struct {
*armnetwork.VirtualNetworksClient
}
func (c *networksClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
virtualNetworksCreateOrUpdatePollerResponse, error,
) {
return c.VirtualNetworksClient.BeginCreateOrUpdate(ctx, resourceGroupName, virtualNetworkName, parameters, options)
}
// TODO: deprecate as soon as scale sets are available.
type networkInterfacesClient struct {
*armnetwork.InterfacesClient
}
// TODO: deprecate as soon as scale sets are available.
func (c *networkInterfacesClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, networkInterfaceName string,
parameters armnetwork.Interface, options *armnetwork.InterfacesClientBeginCreateOrUpdateOptions,
) (interfacesClientCreateOrUpdatePollerResponse, error) {
return c.InterfacesClient.BeginCreateOrUpdate(ctx, resourceGroupName, networkInterfaceName, parameters, options)
}
type networkSecurityGroupsClient struct {
*armnetwork.SecurityGroupsClient
}
func (c *networkSecurityGroupsClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
networkSecurityGroupsCreateOrUpdatePollerResponse, error,
) {
return c.SecurityGroupsClient.BeginCreateOrUpdate(ctx, resourceGroupName, networkSecurityGroupName, parameters, options)
}
type publicIPAddressesClient struct {
*armnetwork.PublicIPAddressesClient
}
func (c *publicIPAddressesClient) ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string,
networkInterfaceName string, ipConfigurationName string,
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
) publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager {
return c.PublicIPAddressesClient.ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName, virtualMachineScaleSetName,
virtualmachineIndex, networkInterfaceName, ipConfigurationName, options)
}
// TODO: deprecate as soon as scale sets are available.
func (c *publicIPAddressesClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
publicIPAddressesClientCreateOrUpdatePollerResponse, error,
) {
return c.PublicIPAddressesClient.BeginCreateOrUpdate(ctx, resourceGroupName, publicIPAddressName, parameters, options)
}
type virtualMachineScaleSetsClient struct {
*armcompute.VirtualMachineScaleSetsClient
}
func (c *virtualMachineScaleSetsClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
vmScaleSetName string, parameters armcompute.VirtualMachineScaleSet,
options *armcompute.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
virtualMachineScaleSetsCreateOrUpdatePollerResponse, error,
) {
return c.VirtualMachineScaleSetsClient.BeginCreateOrUpdate(ctx, resourceGroupName, vmScaleSetName, parameters, options)
}
type resourceGroupsClient struct {
*armresources.ResourceGroupsClient
}
func (c *resourceGroupsClient) BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
resourceGroupsDeletePollerResponse, error,
) {
return c.ResourceGroupsClient.BeginDelete(ctx, resourceGroupName, options)
}
// TODO: deprecate as soon as scale sets are available.
type virtualMachinesClient struct {
*armcompute.VirtualMachinesClient
}
// TODO: deprecate as soon as scale sets are available.
func (c *virtualMachinesClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine,
options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions,
) (virtualMachinesClientCreateOrUpdatePollerResponse, error) {
return c.VirtualMachinesClient.BeginCreateOrUpdate(ctx, resourceGroupName, vmName, parameters, options)
}
type applicationsClient struct {
*graphrbac.ApplicationsClient
}
func (c *applicationsClient) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
return c.ApplicationsClient.Create(ctx, parameters)
}
func (c *applicationsClient) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
return c.ApplicationsClient.Delete(ctx, applicationObjectID)
}
func (c *applicationsClient) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
return c.ApplicationsClient.UpdatePasswordCredentials(ctx, objectID, parameters)
}
type servicePrincipalsClient struct {
*graphrbac.ServicePrincipalsClient
}
func (c *servicePrincipalsClient) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
return c.ServicePrincipalsClient.Create(ctx, parameters)
}
type roleAssignmentsClient struct {
*authorization.RoleAssignmentsClient
}
func (c *roleAssignmentsClient) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
return c.RoleAssignmentsClient.Create(ctx, scope, roleAssignmentName, parameters)
}

277
cli/azure/client/client.go Normal file
View file

@ -0,0 +1,277 @@
package client
import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/internal/state"
)
const (
graphAPIResource = "https://graph.windows.net"
managementAPIResource = "https://management.azure.com"
)
// Client is a client for Azure.
type Client struct {
networksAPI
networkSecurityGroupsAPI
resourceGroupAPI
scaleSetsAPI
publicIPAddressesAPI
networkInterfacesAPI
virtualMachinesAPI
applicationsAPI
servicePrincipalsAPI
roleAssignmentsAPI
adReplicationLagCheckInterval time.Duration
adReplicationLagCheckMaxRetries int
nodes azure.Instances
coordinators azure.Instances
name string
uid string
resourceGroup string
location string
subscriptionID string
tenantID string
subnetID string
coordinatorsScaleSet string
nodesScaleSet string
networkSecurityGroup string
adAppObjectID string
}
// NewFromDefault creates a client with initialized clients.
func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return nil, err
}
graphAuthorizer, err := getAuthorizer(graphAPIResource)
if err != nil {
return nil, err
}
managementAuthorizer, err := getAuthorizer(managementAPIResource)
if err != nil {
return nil, err
}
netAPI := armnetwork.NewVirtualNetworksClient(subscriptionID, cred, nil)
netSecGrpAPI := armnetwork.NewSecurityGroupsClient(subscriptionID, cred, nil)
resGroupAPI := armresources.NewResourceGroupsClient(subscriptionID, cred, nil)
scaleSetAPI := armcompute.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
publicIPAddressesAPI := armnetwork.NewPublicIPAddressesClient(subscriptionID, cred, nil)
networkInterfacesAPI := armnetwork.NewInterfacesClient(subscriptionID, cred, nil)
virtualMachinesAPI := armcompute.NewVirtualMachinesClient(subscriptionID, cred, nil)
applicationsAPI := graphrbac.NewApplicationsClient(tenantID)
applicationsAPI.Authorizer = graphAuthorizer
servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID)
servicePrincipalsAPI.Authorizer = graphAuthorizer
roleAssignmentsAPI := authorization.NewRoleAssignmentsClient(subscriptionID)
roleAssignmentsAPI.Authorizer = managementAuthorizer
return &Client{
networksAPI: &networksClient{netAPI},
networkSecurityGroupsAPI: &networkSecurityGroupsClient{netSecGrpAPI},
resourceGroupAPI: &resourceGroupsClient{resGroupAPI},
scaleSetsAPI: &virtualMachineScaleSetsClient{scaleSetAPI},
publicIPAddressesAPI: &publicIPAddressesClient{publicIPAddressesAPI},
networkInterfacesAPI: &networkInterfacesClient{networkInterfacesAPI},
applicationsAPI: &applicationsClient{&applicationsAPI},
servicePrincipalsAPI: &servicePrincipalsClient{&servicePrincipalsAPI},
roleAssignmentsAPI: &roleAssignmentsClient{&roleAssignmentsAPI},
virtualMachinesAPI: &virtualMachinesClient{virtualMachinesAPI},
subscriptionID: subscriptionID,
tenantID: tenantID,
nodes: azure.Instances{},
coordinators: azure.Instances{},
adReplicationLagCheckInterval: adReplicationLagCheckInterval,
adReplicationLagCheckMaxRetries: adReplicationLagCheckMaxRetries,
}, nil
}
// NewInitialized creates and initializes client by setting the subscriptionID, location and name
// of the Constellation.
func NewInitialized(subscriptionID, tenantID, name, location string) (*Client, error) {
client, err := NewFromDefault(subscriptionID, tenantID)
if err != nil {
return nil, err
}
err = client.init(location, name)
return client, err
}
// init initializes the client.
func (c *Client) init(location, name string) error {
c.location = location
c.name = name
uid, err := c.generateUID()
if err != nil {
return err
}
c.uid = uid
return nil
}
// GetState returns the state of the client as ConstellationState.
func (c *Client) GetState() (state.ConstellationState, error) {
var stat state.ConstellationState
stat.CloudProvider = cloudprovider.Azure.String()
if len(c.resourceGroup) == 0 {
return state.ConstellationState{}, errors.New("client has no resource group")
}
stat.AzureResourceGroup = c.resourceGroup
if c.name == "" {
return state.ConstellationState{}, errors.New("client has no name")
}
stat.Name = c.name
if len(c.uid) == 0 {
return state.ConstellationState{}, errors.New("client has no uid")
}
stat.UID = c.uid
if len(c.location) == 0 {
return state.ConstellationState{}, errors.New("client has no location")
}
stat.AzureLocation = c.location
if len(c.subscriptionID) == 0 {
return state.ConstellationState{}, errors.New("client has no subscription")
}
stat.AzureSubscription = c.subscriptionID
if len(c.tenantID) == 0 {
return state.ConstellationState{}, errors.New("client has no tenant")
}
stat.AzureTenant = c.tenantID
if len(c.subnetID) == 0 {
return state.ConstellationState{}, errors.New("client has no subnet")
}
stat.AzureSubnet = c.subnetID
if len(c.networkSecurityGroup) == 0 {
return state.ConstellationState{}, errors.New("client has no network security group")
}
stat.AzureNetworkSecurityGroup = c.networkSecurityGroup
// TODO: un-deprecate as soon as scale sets are available
// if len(c.nodesScaleSet) == 0 {
// return state.ConstellationState{}, errors.New("client has no nodes scale set")
// }
// stat.AzureNodesScaleSet = c.nodesScaleSet
// if len(c.coordinatorsScaleSet) == 0 {
// return state.ConstellationState{}, errors.New("client has no coordinators scale set")
// }
// stat.AzureCoordinatorsScaleSet = c.coordinatorsScaleSet
if len(c.nodes) == 0 {
return state.ConstellationState{}, errors.New("client has no nodes")
}
stat.AzureNodes = c.nodes
if len(c.coordinators) == 0 {
return state.ConstellationState{}, errors.New("client has no coordinators")
}
stat.AzureCoordinators = c.coordinators
// AD App Object ID does not have to be set at all times
stat.AzureADAppObjectID = c.adAppObjectID
return stat, nil
}
// SetState sets the state of the client to the handed ConstellationState.
func (c *Client) SetState(stat state.ConstellationState) error {
if stat.CloudProvider != cloudprovider.Azure.String() {
return errors.New("state is not azure state")
}
if len(stat.AzureResourceGroup) == 0 {
return errors.New("state has no resource group")
}
c.resourceGroup = stat.AzureResourceGroup
if stat.Name == "" {
return errors.New("state has no name")
}
c.name = stat.Name
if len(stat.UID) == 0 {
return errors.New("state has no uuid")
}
c.uid = stat.UID
if len(stat.AzureLocation) == 0 {
return errors.New("state has no location")
}
c.location = stat.AzureLocation
if len(stat.AzureSubscription) == 0 {
return errors.New("state has no subscription")
}
c.subscriptionID = stat.AzureSubscription
if len(stat.AzureTenant) == 0 {
return errors.New("state has no tenant")
}
c.tenantID = stat.AzureTenant
if len(stat.AzureSubnet) == 0 {
return errors.New("state has no subnet")
}
c.subnetID = stat.AzureSubnet
if len(stat.AzureNetworkSecurityGroup) == 0 {
return errors.New("state has no subnet")
}
c.networkSecurityGroup = stat.AzureNetworkSecurityGroup
// TODO: un-deprecate as soon as scale sets are available
//if len(stat.AzureNodesScaleSet) == 0 {
// return errors.New("state has no nodes scale set")
//}
//c.nodesScaleSet = stat.AzureNodesScaleSet
//if len(stat.AzureCoordinatorsScaleSet) == 0 {
// return errors.New("state has no nodes scale set")
//}
//c.coordinatorsScaleSet = stat.AzureCoordinatorsScaleSet
if len(stat.AzureNodes) == 0 {
return errors.New("state has no coordinator scale set")
}
c.nodes = stat.AzureNodes
if len(stat.AzureCoordinators) == 0 {
return errors.New("state has no coordinators")
}
c.coordinators = stat.AzureCoordinators
// AD App Object ID does not have to be set at all times
c.adAppObjectID = stat.AzureADAppObjectID
return nil
}
func (c *Client) generateUID() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
const uidLen = 5
uid := make([]byte, uidLen)
for i := 0; i < uidLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
uid[i] = letters[n.Int64()]
}
return string(uid), nil
}
// getAuthorizer creates an autorest.Authorizer for different Azure AD APIs using either environment variables or azure cli credentials.
func getAuthorizer(resource string) (autorest.Authorizer, error) {
authorizer, cliErr := auth.NewAuthorizerFromCLIWithResource(resource)
if cliErr == nil {
return authorizer, nil
}
authorizer, envErr := auth.NewAuthorizerFromEnvironmentWithResource(resource)
if envErr == nil {
return authorizer, nil
}
return nil, fmt.Errorf("unable to create authorizer from env or cli: %v %v", envErr, cliErr)
}

View file

@ -0,0 +1,486 @@
package client
import (
"testing"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/internal/state"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSetGetState(t *testing.T) {
testCases := map[string]struct {
state state.ConstellationState
errExpected bool
}{
"valid state": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
// TODO: un-deprecate as soon as scale sets are available
// AzureNodesScaleSet: "node-scale-set",
// AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
},
"missing nodes": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing coordinator": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing name": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing uid": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing resource group": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing location": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing subscription": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureTenant: "tenant",
AzureLocation: "location",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing tenant": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureSubscription: "subscription",
AzureLocation: "location",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing subnet": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
"missing network security group": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
errExpected: true,
},
// TODO: un-deprecate as soon as scale sets are available
// "missing node scale set": {
// state: state.ConstellationState{
// CloudProvider: cloudprovider.Azure.String(),
// AzureNodes: azure.Instances{
// "0": {
// PublicIP: "ip1",
// PrivateIP: "ip2",
// },
// },
// AzureCoordinators: azure.Instances{
// "0": {
// PublicIP: "ip3",
// PrivateIP: "ip4",
// },
// },
// Name: "name",
// UID: "uid",
// AzureResourceGroup: "resource-group",
// AzureLocation: "location",
// AzureSubscription: "subscription",
// AzureTenant: "tenant",
// AzureSubnet: "azure-subnet",
// AzureNetworkSecurityGroup: "network-security-group",
// AzureCoordinatorsScaleSet: "coordinator-scale-set",
// },
// errExpected: true,
// },
// "missing coordinator scale set": {
// state: state.ConstellationState{
// CloudProvider: cloudprovider.Azure.String(),
// AzureNodes: azure.Instances{
// "0": {
// PublicIP: "ip1",
// PrivateIP: "ip2",
// },
// },
// AzureCoordinators: azure.Instances{
// "0": {
// PublicIP: "ip3",
// PrivateIP: "ip4",
// },
// },
// Name: "name",
// UID: "uid",
// AzureResourceGroup: "resource-group",
// AzureLocation: "location",
// AzureSubscription: "subscription",
// AzureTenant: "tenant",
// AzureSubnet: "azure-subnet",
// AzureNetworkSecurityGroup: "network-security-group",
// AzureNodesScaleSet: "node-scale-set",
// },
// errExpected: true,
// },
}
t.Run("SetState", func(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := Client{}
if tc.errExpected {
assert.Error(client.SetState(tc.state))
} else {
assert.NoError(client.SetState(tc.state))
assert.Equal(tc.state.AzureNodes, client.nodes)
assert.Equal(tc.state.AzureCoordinators, client.coordinators)
assert.Equal(tc.state.Name, client.name)
assert.Equal(tc.state.UID, client.uid)
assert.Equal(tc.state.AzureResourceGroup, client.resourceGroup)
assert.Equal(tc.state.AzureLocation, client.location)
assert.Equal(tc.state.AzureSubscription, client.subscriptionID)
assert.Equal(tc.state.AzureTenant, client.tenantID)
assert.Equal(tc.state.AzureSubnet, client.subnetID)
assert.Equal(tc.state.AzureNetworkSecurityGroup, client.networkSecurityGroup)
assert.Equal(tc.state.AzureNodesScaleSet, client.nodesScaleSet)
assert.Equal(tc.state.AzureCoordinatorsScaleSet, client.coordinatorsScaleSet)
}
})
}
})
t.Run("GetState", func(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := Client{
nodes: tc.state.AzureNodes,
coordinators: tc.state.AzureCoordinators,
name: tc.state.Name,
uid: tc.state.UID,
resourceGroup: tc.state.AzureResourceGroup,
location: tc.state.AzureLocation,
subscriptionID: tc.state.AzureSubscription,
tenantID: tc.state.AzureTenant,
subnetID: tc.state.AzureSubnet,
networkSecurityGroup: tc.state.AzureNetworkSecurityGroup,
nodesScaleSet: tc.state.AzureNodesScaleSet,
coordinatorsScaleSet: tc.state.AzureCoordinatorsScaleSet,
}
if tc.errExpected {
_, err := client.GetState()
assert.Error(err)
} else {
state, err := client.GetState()
assert.NoError(err)
assert.Equal(tc.state, state)
}
})
}
})
}
func TestSetStateCloudProvider(t *testing.T) {
assert := assert.New(t)
client := Client{}
stateMissingCloudProvider := state.ConstellationState{
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
}
assert.Error(client.SetState(stateMissingCloudProvider))
stateIncorrectCloudProvider := state.ConstellationState{
CloudProvider: "incorrect",
AzureNodes: azure.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: azure.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
}
assert.Error(client.SetState(stateIncorrectCloudProvider))
}
func TestInit(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{}
require.NoError(client.init("location", "name"))
assert.Equal("location", client.location)
assert.Equal("name", client.name)
assert.NotEmpty(client.uid)
}

281
cli/azure/client/compute.go Normal file
View file

@ -0,0 +1,281 @@
package client
import (
"context"
"errors"
"strconv"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/cli/azure"
)
func (c *Client) CreateInstances(ctx context.Context, input CreateInstancesInput) error {
// Create nodes scale set
createNodesInput := CreateScaleSetInput{
Name: "constellation-scale-set-nodes-" + c.uid,
NamePrefix: c.name + "-worker-" + c.uid + "-",
Count: input.Count - 1,
InstanceType: input.InstanceType,
Image: input.Image,
UserAssingedIdentity: input.UserAssingedIdentity,
}
if err := c.createScaleSet(ctx, createNodesInput); err != nil {
return err
}
c.nodesScaleSet = createNodesInput.Name
// Create coordinator scale set
createCoordinatorsInput := CreateScaleSetInput{
Name: "constellation-scale-set-coordinators-" + c.uid,
NamePrefix: c.name + "-control-plane-" + c.uid + "-",
Count: 1,
InstanceType: input.InstanceType,
Image: input.Image,
UserAssingedIdentity: input.UserAssingedIdentity,
}
if err := c.createScaleSet(ctx, createCoordinatorsInput); err != nil {
return err
}
// Get nodes IPs
instances, err := c.getInstanceIPs(ctx, createNodesInput.Name, createNodesInput.Count)
if err != nil {
return err
}
c.nodes = instances
// Get coordinators IPs
c.coordinatorsScaleSet = createCoordinatorsInput.Name
instances, err = c.getInstanceIPs(ctx, createCoordinatorsInput.Name, createCoordinatorsInput.Count)
if err != nil {
return err
}
c.coordinators = instances
return nil
}
// CreateInstancesInput is the input for a CreateInstances operation.
type CreateInstancesInput struct {
Count int
InstanceType string
Image string
UserAssingedIdentity string
}
// CreateInstancesVMs creates instances based on standalone VMs.
// TODO: deprecate as soon as scale sets are available.
func (c *Client) CreateInstancesVMs(ctx context.Context, input CreateInstancesInput) error {
pw, err := azure.GeneratePassword()
if err != nil {
return err
}
vm := azure.VMInstance{
Name: c.name + "-control-plane-" + c.uid,
Username: "constell",
Password: pw,
Location: c.location,
InstanceType: input.InstanceType,
Image: input.Image,
}
instance, err := c.createInstanceVM(ctx, vm)
if err != nil {
return err
}
c.coordinators = azure.Instances{"0": instance}
for i := 0; i < input.Count-1; i++ {
vm := azure.VMInstance{
Name: c.name + "-node-" + strconv.Itoa(i) + c.uid,
Username: "constell",
Password: pw,
Location: c.location,
InstanceType: input.InstanceType,
Image: input.Image,
}
instance, err := c.createInstanceVM(ctx, vm)
if err != nil {
return err
}
c.nodes[strconv.Itoa(i)] = instance
}
return nil
}
// createInstanceVM creates a single VM with a public IP address
// and a network interface.
// TODO: deprecate as soon as scale sets are available.
func (c *Client) createInstanceVM(ctx context.Context, input azure.VMInstance) (azure.Instance, error) {
pubIPName := input.Name + "-pubIP"
pubIPID, err := c.createPublicIPAddress(ctx, pubIPName)
if err != nil {
return azure.Instance{}, err
}
nicName := input.Name + "-NIC"
privIP, nicID, err := c.createNIC(ctx, nicName, pubIPID)
if err != nil {
return azure.Instance{}, err
}
input.NIC = nicID
poller, err := c.virtualMachinesAPI.BeginCreateOrUpdate(ctx, c.resourceGroup, input.Name, input.Azure(), nil)
if err != nil {
return azure.Instance{}, err
}
vm, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return azure.Instance{}, err
}
if vm.Identity == nil || vm.Identity.PrincipalID == nil {
return azure.Instance{}, errors.New("virtual machine was created without system managed identity")
}
if err := c.assignResourceGroupRole(ctx, *vm.Identity.PrincipalID, virtualMachineContributorRoleDefinitionID); err != nil {
return azure.Instance{}, err
}
res, err := c.publicIPAddressesAPI.Get(ctx, c.resourceGroup, pubIPName, nil)
if err != nil {
return azure.Instance{}, err
}
return azure.Instance{PublicIP: *res.PublicIPAddressesClientGetResult.PublicIPAddress.Properties.IPAddress, PrivateIP: privIP}, nil
}
func (c *Client) createScaleSet(ctx context.Context, input CreateScaleSetInput) error {
// TODO: Generating a random password to be able
// to create the scale set. This is a temporary fix.
// We need to think about azure access at some point.
pw, err := azure.GeneratePassword()
if err != nil {
return err
}
scaleSet := azure.ScaleSet{
Name: input.Name,
NamePrefix: input.NamePrefix,
Location: c.location,
InstanceType: input.InstanceType,
Count: int64(input.Count),
Username: "constellation",
SubnetID: c.subnetID,
NetworkSecurityGroup: c.networkSecurityGroup,
Image: input.Image,
Password: pw,
UserAssignedIdentity: input.UserAssingedIdentity,
}.Azure()
poller, err := c.scaleSetsAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, input.Name,
scaleSet,
nil,
)
if err != nil {
return err
}
_, err = poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return err
}
return nil
}
func (c *Client) getInstanceIPs(ctx context.Context, scaleSet string, count int) (azure.Instances, error) {
instances := azure.Instances{}
for i := 0; i < count; i++ {
// get public ip address
var publicIPAddress string
pager := c.publicIPAddressesAPI.ListVirtualMachineScaleSetVMPublicIPAddresses(
c.resourceGroup, scaleSet, strconv.Itoa(i), scaleSet, scaleSet, nil)
// We always need one pager.NextPage, since calling
// pager.PageResponse() directly return no result.
// We expect to get one page with one entry for each VM.
for pager.NextPage(ctx) {
for _, v := range pager.PageResponse().Value {
if v.Properties != nil && v.Properties.IPAddress != nil {
publicIPAddress = *v.Properties.IPAddress
break
}
}
}
// get private ip address
var privateIPAddress string
res, err := c.networkInterfacesAPI.GetVirtualMachineScaleSetNetworkInterface(
ctx, c.resourceGroup, scaleSet, strconv.Itoa(i), scaleSet, nil)
if err != nil {
return nil, err
}
configs := res.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResult.Interface.Properties.IPConfigurations
for _, config := range configs {
privateIPAddress = *config.Properties.PrivateIPAddress
break
}
instance := azure.Instance{
PrivateIP: privateIPAddress,
PublicIP: publicIPAddress,
}
instances[strconv.Itoa(i)] = instance
}
return instances, nil
}
// CreateScaleSetInput is the input for a CreateScaleSet operation.
type CreateScaleSetInput struct {
Name string
NamePrefix string
Count int
InstanceType string
Image string
UserAssingedIdentity string
}
// CreateResourceGroup creates a resource group.
func (c *Client) CreateResourceGroup(ctx context.Context) error {
_, err := c.resourceGroupAPI.CreateOrUpdate(ctx, c.name+"-"+c.uid,
armresources.ResourceGroup{
Location: &c.location,
}, nil)
if err != nil {
return err
}
c.resourceGroup = c.name + "-" + c.uid
return nil
}
// TerminateResourceGroup terminates a resource group.
func (c *Client) TerminateResourceGroup(ctx context.Context) error {
if c.resourceGroup == "" {
return nil
}
poller, err := c.resourceGroupAPI.BeginDelete(ctx, c.resourceGroup, nil)
if err != nil {
return err
}
if _, err = poller.PollUntilDone(ctx, 30*time.Second); err != nil {
return err
}
c.nodes = nil
c.coordinators = nil
c.resourceGroup = ""
c.subnetID = ""
c.networkSecurityGroup = ""
c.nodesScaleSet = ""
c.coordinatorsScaleSet = ""
return nil
}

View file

@ -0,0 +1,374 @@
package client
import (
"context"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateResourceGroup(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
resourceGroupAPI resourceGroupAPI
errExpected bool
}{
"successful create": {
resourceGroupAPI: stubResourceGroupAPI{},
},
"failed create": {
resourceGroupAPI: stubResourceGroupAPI{createErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroupAPI: tc.resourceGroupAPI,
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
}
if tc.errExpected {
assert.Error(client.CreateResourceGroup(ctx))
} else {
assert.NoError(client.CreateResourceGroup(ctx))
assert.Equal(client.name+"-"+client.uid, client.resourceGroup)
}
})
}
}
func TestTerminateResourceGroup(t *testing.T) {
someErr := errors.New("failed")
clientWithResourceGroup := Client{
resourceGroup: "name",
location: "location",
name: "name",
uid: "uid",
subnetID: "subnet",
nodesScaleSet: "node-scale-set",
coordinatorsScaleSet: "coordinator-scale-set",
nodes: azure.Instances{
"0": {
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
},
},
coordinators: azure.Instances{
"0": {
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
},
},
}
testCases := map[string]struct {
resourceGroup string
resourceGroupAPI resourceGroupAPI
client Client
errExpected bool
}{
"successful terminate": {
resourceGroupAPI: stubResourceGroupAPI{},
client: clientWithResourceGroup,
},
"no resource group to terminate": {
resourceGroupAPI: stubResourceGroupAPI{},
client: Client{},
resourceGroup: "",
},
"failed terminate": {
resourceGroupAPI: stubResourceGroupAPI{terminateErr: someErr},
client: clientWithResourceGroup,
errExpected: true,
},
"failed to poll terminate response": {
resourceGroupAPI: stubResourceGroupAPI{stubResponse: stubResourceGroupsDeletePollerResponse{pollerErr: someErr}},
client: clientWithResourceGroup,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
tc.client.resourceGroupAPI = tc.resourceGroupAPI
ctx := context.Background()
if tc.errExpected {
assert.Error(tc.client.TerminateResourceGroup(ctx))
return
}
assert.NoError(tc.client.TerminateResourceGroup(ctx))
assert.Empty(tc.client.resourceGroup)
assert.Empty(tc.client.subnetID)
assert.Empty(tc.client.nodes)
assert.Empty(tc.client.coordinators)
assert.Empty(tc.client.nodesScaleSet)
assert.Empty(tc.client.coordinatorsScaleSet)
})
}
}
func TestCreateInstances(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
networkInterfacesAPI networkInterfacesAPI
scaleSetsAPI scaleSetsAPI
resourceGroupAPI resourceGroupAPI
roleAssignmentsAPI roleAssignmentsAPI
createInstancesInput CreateInstancesInput
errExpected bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{
stubResponse: stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse{
pollResponse: armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse{
VirtualMachineScaleSetsClientCreateOrUpdateResult: armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResult{
VirtualMachineScaleSet: armcompute.VirtualMachineScaleSet{Identity: &armcompute.VirtualMachineScaleSetIdentity{PrincipalID: to.StringPtr("principal-id")}},
},
},
},
},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
},
},
"error when creating scale set": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{createErr: someErr},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
},
errExpected: true,
},
"error when polling create scale set response": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{stubResponse: stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse{pollErr: someErr}},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
},
errExpected: true,
},
"error when retrieving private IPs": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr},
scaleSetsAPI: stubScaleSetsAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroup: "name",
publicIPAddressesAPI: tc.publicIPAddressesAPI,
networkInterfacesAPI: tc.networkInterfacesAPI,
scaleSetsAPI: tc.scaleSetsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
}
if tc.errExpected {
assert.Error(client.CreateInstances(ctx, tc.createInstancesInput))
} else {
assert.NoError(client.CreateInstances(ctx, tc.createInstancesInput))
assert.Equal(1, len(client.coordinators))
assert.Equal(tc.createInstancesInput.Count-1, len(client.nodes))
assert.NotEmpty(client.nodes["0"].PrivateIP)
assert.NotEmpty(client.nodes["0"].PublicIP)
assert.NotEmpty(client.coordinators["0"].PrivateIP)
assert.NotEmpty(client.coordinators["0"].PublicIP)
}
})
}
}
// TODO: deprecate as soon as scale sets are available.
func TestCreateInstancesVMs(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
networkInterfacesAPI networkInterfacesAPI
virtualMachinesAPI virtualMachinesAPI
resourceGroupAPI resourceGroupAPI
roleAssignmentsAPI roleAssignmentsAPI
createInstancesInput CreateInstancesInput
errExpected bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{
stubResponse: stubVirtualMachinesClientCreateOrUpdatePollerResponse{
pollResponse: armcompute.VirtualMachinesClientCreateOrUpdateResponse{VirtualMachinesClientCreateOrUpdateResult: armcompute.VirtualMachinesClientCreateOrUpdateResult{
VirtualMachine: armcompute.VirtualMachine{
Identity: &armcompute.VirtualMachineIdentity{PrincipalID: to.StringPtr("principal-id")},
},
}},
},
},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
},
},
"error when creating scale set": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{createErr: someErr},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
},
errExpected: true,
},
"error when polling create scale set response": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{stubResponse: stubVirtualMachinesClientCreateOrUpdatePollerResponse{pollErr: someErr}},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
},
errExpected: true,
},
"error when creating NIC": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{createErr: someErr},
virtualMachinesAPI: stubVirtualMachinesAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
},
errExpected: true,
},
"error when creating public IP": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{createErr: someErr},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
},
errExpected: true,
},
"error when retrieving public IP": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{getErr: someErr},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
Count: 3,
InstanceType: "type",
Image: "image",
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroup: "name",
publicIPAddressesAPI: tc.publicIPAddressesAPI,
networkInterfacesAPI: tc.networkInterfacesAPI,
virtualMachinesAPI: tc.virtualMachinesAPI,
resourceGroupAPI: tc.resourceGroupAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
}
if tc.errExpected {
assert.Error(client.CreateInstancesVMs(ctx, tc.createInstancesInput))
return
}
require.NoError(client.CreateInstancesVMs(ctx, tc.createInstancesInput))
assert.Equal(1, len(client.coordinators))
assert.Equal(tc.createInstancesInput.Count-1, len(client.nodes))
assert.NotEmpty(client.nodes["0"].PrivateIP)
assert.NotEmpty(client.nodes["0"].PublicIP)
assert.NotEmpty(client.coordinators["0"].PrivateIP)
assert.NotEmpty(client.coordinators["0"].PublicIP)
})
}
}
func newSuccessfulResourceGroupStub() *stubResourceGroupAPI {
return &stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
}
}

166
cli/azure/client/network.go Normal file
View file

@ -0,0 +1,166 @@
package client
import (
"context"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
)
type createNetworkInput struct {
name string
location string
addressSpace string
}
// CreateVirtualNetwork creates a virtual network.
func (c *Client) CreateVirtualNetwork(ctx context.Context) error {
createNetworkInput := createNetworkInput{
name: "constellation-" + c.uid,
location: c.location,
addressSpace: "172.20.0.0/16",
}
poller, err := c.networksAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, createNetworkInput.name,
armnetwork.VirtualNetwork{
Name: to.StringPtr(createNetworkInput.name), // this is supposed to be read-only
Location: to.StringPtr(createNetworkInput.location),
Properties: &armnetwork.VirtualNetworkPropertiesFormat{
AddressSpace: &armnetwork.AddressSpace{
AddressPrefixes: []*string{
to.StringPtr(createNetworkInput.addressSpace),
},
},
Subnets: []*armnetwork.Subnet{
{
Name: to.StringPtr("default"),
Properties: &armnetwork.SubnetPropertiesFormat{
AddressPrefix: to.StringPtr(createNetworkInput.addressSpace),
},
},
},
},
},
nil,
)
if err != nil {
return err
}
resp, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return err
}
c.subnetID = *resp.VirtualNetworksClientCreateOrUpdateResult.VirtualNetwork.Properties.Subnets[0].ID
return nil
}
type createNetworkSecurityGroupInput struct {
name string
location string
rules []*armnetwork.SecurityRule
}
// CreateSecurityGroup creates a security group containing firewall rules.
func (c *Client) CreateSecurityGroup(ctx context.Context, input NetworkSecurityGroupInput) error {
rules := input.Ingress.Azure()
createNetworkSecurityGroupInput := createNetworkSecurityGroupInput{
name: "constellation-security-group-" + c.uid,
location: c.location,
rules: rules,
}
poller, err := c.networkSecurityGroupsAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, createNetworkSecurityGroupInput.name,
armnetwork.SecurityGroup{
Name: to.StringPtr(createNetworkSecurityGroupInput.name),
Location: to.StringPtr(createNetworkSecurityGroupInput.location),
Properties: &armnetwork.SecurityGroupPropertiesFormat{
SecurityRules: createNetworkSecurityGroupInput.rules,
},
},
nil,
)
if err != nil {
return err
}
pollerResp, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return err
}
c.networkSecurityGroup = *pollerResp.SecurityGroupsClientCreateOrUpdateResult.SecurityGroup.ID
return nil
}
// createNIC creates a network interface that references a public IP address.
// TODO: deprecate as soon as scale sets are available.
func (c *Client) createNIC(ctx context.Context, name, publicIPAddressID string) (ip string, id string, err error) {
poller, err := c.networkInterfacesAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, name,
armnetwork.Interface{
Location: to.StringPtr(c.location),
Properties: &armnetwork.InterfacePropertiesFormat{
NetworkSecurityGroup: &armnetwork.SecurityGroup{
ID: to.StringPtr(c.networkSecurityGroup),
},
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Name: to.StringPtr(name),
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
Subnet: &armnetwork.Subnet{
ID: to.StringPtr(c.subnetID),
},
PublicIPAddress: &armnetwork.PublicIPAddress{
ID: to.StringPtr(publicIPAddressID),
},
},
},
},
},
},
nil,
)
if err != nil {
return "", "", err
}
pollerResp, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return "", "", err
}
netInterface := pollerResp.InterfacesClientCreateOrUpdateResult.Interface
return *netInterface.Properties.IPConfigurations[0].Properties.PrivateIPAddress,
*netInterface.ID,
nil
}
// createPublicIPAddress creates a public IP address.
// TODO: deprecate as soon as scale sets are available.
func (c *Client) createPublicIPAddress(ctx context.Context, name string) (string, error) {
poller, err := c.publicIPAddressesAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, name,
armnetwork.PublicIPAddress{
Location: to.StringPtr(c.location),
},
nil,
)
if err != nil {
return "", err
}
pollerResp, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return "", err
}
return *pollerResp.PublicIPAddressesClientCreateOrUpdateResult.PublicIPAddress.ID, nil
}
// NetworkSecurityGroupInput defines firewall rules to be set.
type NetworkSecurityGroupInput struct {
Ingress cloudtypes.Firewall
Egress cloudtypes.Firewall
}

View file

@ -0,0 +1,220 @@
package client
import (
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/stretchr/testify/assert"
)
func TestCreateVirtualNetwork(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
networksAPI networksAPI
errExpected bool
}{
"successful create": {
networksAPI: stubNetworksAPI{},
},
"failed to get response from successful create": {
networksAPI: stubNetworksAPI{stubResponse: stubVirtualNetworksCreateOrUpdatePollerResponse{pollerErr: someErr}},
errExpected: true,
},
"failed create": {
networksAPI: stubNetworksAPI{createErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
networksAPI: tc.networksAPI,
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
}
if tc.errExpected {
assert.Error(client.CreateVirtualNetwork(ctx))
} else {
assert.NoError(client.CreateVirtualNetwork(ctx))
assert.NotEmpty(client.subnetID)
}
})
}
}
func TestCreateSecurityGroup(t *testing.T) {
someErr := errors.New("failed")
testNetworkSecurityGroupInput := NetworkSecurityGroupInput{
Ingress: cloudtypes.Firewall{
{
Name: "test-1",
Description: "test-1 description",
Protocol: "tcp",
IPRange: "192.0.2.0/24",
Port: 9000,
},
{
Name: "test-2",
Description: "test-2 description",
Protocol: "udp",
IPRange: "192.0.2.0/24",
Port: 51820,
},
},
Egress: cloudtypes.Firewall{},
}
testCases := map[string]struct {
networkSecurityGroupsAPI networkSecurityGroupsAPI
errExpected bool
}{
"successful create": {
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{},
},
"failed to get response from successful create": {
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{stubPoller: stubNetworkSecurityGroupsCreateOrUpdatePollerResponse{pollerErr: someErr}},
errExpected: true,
},
"failed create": {
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{createErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
networkSecurityGroupsAPI: tc.networkSecurityGroupsAPI,
}
if tc.errExpected {
assert.Error(client.CreateSecurityGroup(ctx, testNetworkSecurityGroupInput))
} else {
assert.NoError(client.CreateSecurityGroup(ctx, testNetworkSecurityGroupInput))
assert.Equal("network-security-group-id", client.networkSecurityGroup)
}
})
}
}
// TODO: deprecate as soon as scale sets are available.
func TestCreateNIC(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
networkInterfacesAPI networkInterfacesAPI
name string
publicIPAddressID string
errExpected bool
}{
"successful create": {
networkInterfacesAPI: stubNetworkInterfacesAPI{},
name: "nic-name",
publicIPAddressID: "pubIP-id",
},
"failed to get response from successful create": {
networkInterfacesAPI: stubNetworkInterfacesAPI{stubResp: stubInterfacesClientCreateOrUpdatePollerResponse{pollErr: someErr}},
errExpected: true,
},
"failed create": {
networkInterfacesAPI: stubNetworkInterfacesAPI{createErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
networkInterfacesAPI: tc.networkInterfacesAPI,
}
ip, id, err := client.createNIC(ctx, tc.name, tc.publicIPAddressID)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotEmpty(ip)
assert.NotEmpty(id)
}
})
}
}
// TODO: deprecate as soon as scale sets are available.
func TestCreatePublicIPAddress(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
name string
errExpected bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
name: "nic-name",
},
"failed to get response from successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{stubCreateResponse: stubPublicIPAddressesClientCreateOrUpdatePollerResponse{pollErr: someErr}},
errExpected: true,
},
"failed create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{createErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
nodes: make(azure.Instances),
coordinators: make(azure.Instances),
publicIPAddressesAPI: tc.publicIPAddressesAPI,
}
id, err := client.createPublicIPAddress(ctx, tc.name)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotEmpty(id)
}
})
}
}

135
cli/azure/instances.go Normal file
View file

@ -0,0 +1,135 @@
package azure
// copy of ec2/instances.go
// TODO(katexochen): refactor into mulitcloud package.
import (
"errors"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
)
// Instance is a azure instance.
type Instance struct {
PublicIP string
PrivateIP string
}
// Instances is a map of azure Instances. The ID of an instance is used as key.
type Instances map[string]Instance
// IDs returns the IDs of all instances of the Constellation.
func (i Instances) IDs() []string {
var ids []string
for id := range i {
ids = append(ids, id)
}
return ids
}
// PublicIPs returns the public IPs of all the instances of the Constellation.
func (i Instances) PublicIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PublicIP)
}
return ips
}
// PrivateIPs returns the private IPs of all the instances of the Constellation.
func (i Instances) PrivateIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PrivateIP)
}
return ips
}
// GetOne return anyone instance out of the instances and its ID.
func (i Instances) GetOne() (string, Instance, error) {
for id, instance := range i {
return id, instance, nil
}
return "", Instance{}, errors.New("map is empty")
}
// GetOthers returns all instances but the one with the handed ID.
func (i Instances) GetOthers(id string) Instances {
others := make(Instances)
for key, instance := range i {
if key != id {
others[key] = instance
}
}
return others
}
// TODO: deprecate as soon as scale sets are available.
type VMInstance struct {
Name string
Location string
InstanceType string
Username string
Password string
NIC string
Image string
}
// TODO: deprecate as soon as scale sets are available.
func (i VMInstance) Azure() armcompute.VirtualMachine {
return armcompute.VirtualMachine{
Name: to.StringPtr(i.Name),
Location: to.StringPtr(i.Location),
Properties: &armcompute.VirtualMachineProperties{
HardwareProfile: &armcompute.HardwareProfile{
VMSize: (*armcompute.VirtualMachineSizeTypes)(to.StringPtr(i.InstanceType)),
},
OSProfile: &armcompute.OSProfile{
ComputerName: to.StringPtr(i.Name),
AdminPassword: to.StringPtr(i.Password),
AdminUsername: to.StringPtr(i.Username),
},
SecurityProfile: &armcompute.SecurityProfile{
UefiSettings: &armcompute.UefiSettings{
SecureBootEnabled: to.BoolPtr(true),
VTpmEnabled: to.BoolPtr(true),
},
SecurityType: armcompute.SecurityTypesConfidentialVM.ToPtr(),
},
NetworkProfile: &armcompute.NetworkProfile{
NetworkInterfaces: []*armcompute.NetworkInterfaceReference{
{
ID: to.StringPtr(i.NIC),
},
},
},
StorageProfile: &armcompute.StorageProfile{
OSDisk: &armcompute.OSDisk{
CreateOption: armcompute.DiskCreateOptionTypesFromImage.ToPtr(),
ManagedDisk: &armcompute.ManagedDiskParameters{
StorageAccountType: armcompute.StorageAccountTypesPremiumLRS.ToPtr(),
SecurityProfile: &armcompute.VMDiskSecurityProfile{
SecurityEncryptionType: armcompute.SecurityEncryptionTypesVMGuestStateOnly.ToPtr(),
},
},
},
ImageReference: &armcompute.ImageReference{
Publisher: to.StringPtr("0001-com-ubuntu-confidential-vm-focal"),
Offer: to.StringPtr("canonical"),
SKU: to.StringPtr("20_04-lts-gen2"),
Version: to.StringPtr("latest"),
},
},
DiagnosticsProfile: &armcompute.DiagnosticsProfile{
BootDiagnostics: &armcompute.BootDiagnostics{
Enabled: to.BoolPtr(true),
},
},
},
Identity: &armcompute.VirtualMachineIdentity{
Type: armcompute.ResourceIdentityTypeSystemAssigned.ToPtr(),
},
}
}

View file

@ -0,0 +1,71 @@
package azure
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIDs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIDs := []string{"id-9", "id-10", "id-11", "id-12"}
assert.ElementsMatch(expectedIDs, testState.IDs())
}
func TestPublicIPs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIPs := []string{"192.0.2.1", "192.0.2.3", "192.0.2.5", "192.0.2.7"}
assert.ElementsMatch(expectedIPs, testState.PublicIPs())
}
func TestPrivateIPs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIPs := []string{"192.0.2.2", "192.0.2.4", "192.0.2.6", "192.0.2.8"}
assert.ElementsMatch(expectedIPs, testState.PrivateIPs())
}
func TestGetOne(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
id, instance, err := testState.GetOne()
assert.NoError(err)
assert.Contains(testState, id)
assert.Equal(testState[id], instance)
}
func TestGetOthers(t *testing.T) {
assert := assert.New(t)
testCases := testInstances().IDs()
for _, id := range testCases {
others := testInstances().GetOthers(id)
assert.NotContains(others, id)
expectedInstances := testInstances()
delete(expectedInstances, id)
assert.ElementsMatch(others.IDs(), expectedInstances.IDs())
}
}
func testInstances() Instances {
return Instances{
"id-9": {
PublicIP: "192.0.2.1",
PrivateIP: "192.0.2.2",
},
"id-10": {
PublicIP: "192.0.2.3",
PrivateIP: "192.0.2.4",
},
"id-11": {
PublicIP: "192.0.2.5",
PrivateIP: "192.0.2.6",
},
"id-12": {
PublicIP: "192.0.2.7",
PrivateIP: "192.0.2.8",
},
}
}

View file

@ -0,0 +1,13 @@
package azure
import "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
// InstanceTypes are valid Azure instance types.
// Normally, this would be string(armcompute.VirtualMachineSizeTypesStandardD4SV3),
// but currently needed instances are not in SDK.
var InstanceTypes = []string{
string(armcompute.VirtualMachineSizeTypesStandardD4SV3),
"Standard_DC2as_v5",
"Standard_DC4as_v5",
"Standard_DC8as_v5",
}

123
cli/azure/scaleset.go Normal file
View file

@ -0,0 +1,123 @@
package azure
import (
"crypto/rand"
"math/big"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
)
// ScaleSet defines a Azure scale set.
type ScaleSet struct {
Name string
NamePrefix string
Location string
InstanceType string
Count int64
Username string
SubnetID string
NetworkSecurityGroup string
Password string
Image string
UserAssignedIdentity string
}
// Azure returns the Azure representation of ScaleSet.
func (s ScaleSet) Azure() armcompute.VirtualMachineScaleSet {
return armcompute.VirtualMachineScaleSet{
Name: to.StringPtr(s.Name),
Location: to.StringPtr(s.Location),
SKU: &armcompute.SKU{
Name: to.StringPtr(s.InstanceType),
Capacity: to.Int64Ptr(s.Count),
},
Properties: &armcompute.VirtualMachineScaleSetProperties{
Overprovision: to.BoolPtr(false),
UpgradePolicy: &armcompute.UpgradePolicy{
Mode: armcompute.UpgradeModeManual.ToPtr(),
AutomaticOSUpgradePolicy: &armcompute.AutomaticOSUpgradePolicy{
EnableAutomaticOSUpgrade: to.BoolPtr(false),
DisableAutomaticRollback: to.BoolPtr(false),
},
},
VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{
OSProfile: &armcompute.VirtualMachineScaleSetOSProfile{
ComputerNamePrefix: to.StringPtr(s.NamePrefix),
AdminUsername: to.StringPtr(s.Username),
AdminPassword: to.StringPtr(s.Password),
LinuxConfiguration: &armcompute.LinuxConfiguration{},
},
StorageProfile: &armcompute.VirtualMachineScaleSetStorageProfile{
ImageReference: &armcompute.ImageReference{
ID: to.StringPtr(s.Image),
},
},
NetworkProfile: &armcompute.VirtualMachineScaleSetNetworkProfile{
NetworkInterfaceConfigurations: []*armcompute.VirtualMachineScaleSetNetworkConfiguration{
{
Name: to.StringPtr(s.Name),
Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{
Primary: to.BoolPtr(true),
EnableIPForwarding: to.BoolPtr(true),
IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{
{
Name: to.StringPtr(s.Name),
Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{
Subnet: &armcompute.APIEntityReference{
ID: to.StringPtr(s.SubnetID),
},
PublicIPAddressConfiguration: &armcompute.VirtualMachineScaleSetPublicIPAddressConfiguration{
Name: to.StringPtr(s.Name),
Properties: &armcompute.VirtualMachineScaleSetPublicIPAddressConfigurationProperties{
IdleTimeoutInMinutes: to.Int32Ptr(15), // default per https://docs.microsoft.com/en-us/azure/virtual-machine-scale-sets/virtual-machine-scale-sets-networking#creating-a-scale-set-with-public-ip-per-virtual-machine
},
},
},
},
},
NetworkSecurityGroup: &armcompute.SubResource{
ID: to.StringPtr(s.NetworkSecurityGroup),
},
},
},
},
},
SecurityProfile: &armcompute.SecurityProfile{
SecurityType: armcompute.SecurityTypesTrustedLaunch.ToPtr(),
UefiSettings: &armcompute.UefiSettings{VTpmEnabled: to.BoolPtr(true)},
},
DiagnosticsProfile: &armcompute.DiagnosticsProfile{
BootDiagnostics: &armcompute.BootDiagnostics{
Enabled: to.BoolPtr(true),
},
},
},
},
Identity: &armcompute.VirtualMachineScaleSetIdentity{
Type: armcompute.ResourceIdentityTypeUserAssigned.ToPtr(),
UserAssignedIdentities: map[string]*armcompute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue{
s.UserAssignedIdentity: {},
},
},
}
}
// GeneratePassword is a helper function to generate a random password
// for Azure's scale set.
func GeneratePassword() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
pwLen := 16
pw := make([]byte, 0, pwLen)
for i := 0; i < pwLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
pw = append(pw, letters[n.Int64()])
}
// bypass password rules
pw = append(pw, []byte("Aa1!")...)
return string(pw), nil
}

111
cli/azure/scaleset_test.go Normal file
View file

@ -0,0 +1,111 @@
package azure
import (
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFirewallPermissions(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
scaleSet := ScaleSet{
Name: "name",
NamePrefix: "constellation-",
Location: "UK South",
InstanceType: "Standard_D2s_v3",
Count: 3,
Username: "constellation",
SubnetID: "subnet-id",
NetworkSecurityGroup: "network-security-group",
Password: "password",
Image: "image",
UserAssignedIdentity: "user-identity",
}
scaleSetAzure := scaleSet.Azure()
require.NotNil(scaleSetAzure.Name)
assert.Equal(scaleSet.Name, *scaleSetAzure.Name)
require.NotNil(scaleSetAzure.Location)
assert.Equal(scaleSet.Location, *scaleSetAzure.Location)
require.NotNil(scaleSetAzure.SKU)
require.NotNil(scaleSetAzure.SKU.Name)
assert.Equal(scaleSet.InstanceType, *scaleSetAzure.SKU.Name)
require.NotNil(scaleSetAzure.SKU.Capacity)
assert.Equal(scaleSet.Count, *scaleSetAzure.SKU.Capacity)
require.NotNil(scaleSetAzure.Properties)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.ComputerNamePrefix)
assert.Equal(scaleSet.NamePrefix, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.ComputerNamePrefix)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminUsername)
assert.Equal(scaleSet.Username, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminUsername)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminPassword)
assert.Equal(scaleSet.Password, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminPassword)
// Verify image
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference.ID)
assert.Equal(scaleSet.Image, *scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference.ID)
// Verify network
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile)
require.Len(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations, 1)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations[0])
networkConfig := scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations[0]
require.NotNil(networkConfig.Name)
assert.Equal(scaleSet.Name, *networkConfig.Name)
require.NotNil(networkConfig.Properties)
require.Len(networkConfig.Properties.IPConfigurations, 1)
require.NotNil(networkConfig.Properties.IPConfigurations[0])
ipConfig := networkConfig.Properties.IPConfigurations[0]
require.NotNil(ipConfig.Name)
assert.Equal(scaleSet.Name, *ipConfig.Name)
require.NotNil(ipConfig.Properties)
require.NotNil(ipConfig.Properties.Subnet)
require.NotNil(ipConfig.Properties.Subnet.ID)
assert.Equal(scaleSet.SubnetID, *ipConfig.Properties.Subnet.ID)
require.NotNil(networkConfig.Properties.NetworkSecurityGroup)
assert.Equal(scaleSet.NetworkSecurityGroup, *networkConfig.Properties.NetworkSecurityGroup.ID)
// Verify vTPM
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.SecurityType)
assert.Equal(armcompute.SecurityTypesTrustedLaunch, *scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.SecurityType)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings.VTpmEnabled)
assert.True(*scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings.VTpmEnabled)
// Verify UserAssignedIdentity
require.NotNil(scaleSetAzure.Identity)
require.NotNil(scaleSetAzure.Identity.Type)
assert.Equal(armcompute.ResourceIdentityTypeUserAssigned, *scaleSetAzure.Identity.Type)
require.Len(scaleSetAzure.Identity.UserAssignedIdentities, 1)
assert.Contains(scaleSetAzure.Identity.UserAssignedIdentities, scaleSet.UserAssignedIdentity)
}
func TestGeneratePassword(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
pw, err := GeneratePassword()
require.NoError(err)
assert.Len(pw, 20)
}

View file

@ -0,0 +1,88 @@
package cloudtypes
import (
"fmt"
"strconv"
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
computepb "google.golang.org/genproto/googleapis/cloud/compute/v1"
"google.golang.org/protobuf/proto"
)
type FirewallRule struct {
Name string
Description string
Protocol string
IPRange string
Port int
}
type Firewall []FirewallRule
func (f Firewall) GCP() []*computepb.Firewall {
var fw []*computepb.Firewall
for _, rule := range f {
var destRange []string = nil
if rule.IPRange != "" {
destRange = append(destRange, rule.IPRange)
}
fw = append(fw, &computepb.Firewall{
Allowed: []*computepb.Allowed{
{
IPProtocol: proto.String(rule.Protocol),
Ports: []string{fmt.Sprint(rule.Port)},
},
},
Description: proto.String(rule.Description),
DestinationRanges: destRange,
Name: proto.String(rule.Name),
})
}
return fw
}
func (f Firewall) Azure() []*armnetwork.SecurityRule {
var fw []*armnetwork.SecurityRule
for i, rule := range f {
// format string according to armnetwork.SecurityRuleProtocol specification
protocol := strings.Title(strings.ToLower(rule.Protocol))
fw = append(fw, &armnetwork.SecurityRule{
Name: proto.String(rule.Name),
Properties: &armnetwork.SecurityRulePropertiesFormat{
Description: proto.String(rule.Description),
Protocol: (*armnetwork.SecurityRuleProtocol)(proto.String(protocol)),
SourceAddressPrefix: proto.String(rule.IPRange),
SourcePortRange: proto.String("*"),
DestinationAddressPrefix: proto.String(rule.IPRange),
DestinationPortRange: proto.String(strconv.Itoa(rule.Port)),
Access: armnetwork.SecurityRuleAccessAllow.ToPtr(),
Direction: armnetwork.SecurityRuleDirectionInbound.ToPtr(),
// Each security role needs a unique priority
Priority: proto.Int32(int32(100 * (i + 1))),
},
})
}
return fw
}
func (f Firewall) AWS() []ec2types.IpPermission {
var fw []ec2types.IpPermission
for _, rule := range f {
fw = append(fw, ec2types.IpPermission{
FromPort: proto.Int32(int32(rule.Port)),
ToPort: proto.Int32(int32(rule.Port)),
IpProtocol: proto.String(rule.Protocol),
IpRanges: []ec2types.IpRange{
{
CidrIp: proto.String(rule.IPRange),
Description: proto.String(rule.Description),
},
},
})
}
return fw
}

View file

@ -0,0 +1,188 @@
package cloudtypes
import (
"strconv"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)
func TestFirewallGCP(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
testFw := Firewall{
{
Name: "test-1",
Description: "This is the Test-1 Permission",
Protocol: "tcp",
IPRange: "",
Port: 9000,
},
{
Name: "test-2",
Description: "This is the Test-2 Permission",
Protocol: "udp",
IPRange: "",
Port: 51820,
},
}
firewalls := testFw.GCP()
assert.Equal(2, len(firewalls))
// Check permissions
for i := 0; i < len(testFw); i++ {
firewall1 := firewalls[i]
actualPermission1 := firewall1.Allowed[0]
actualPort, err := strconv.Atoi(actualPermission1.GetPorts()[0])
require.NoError(err)
assert.Equal(testFw[i].Port, actualPort)
assert.Equal(testFw[i].Protocol, actualPermission1.GetIPProtocol())
assert.Equal(testFw[i].Name, firewall1.GetName())
assert.Equal(testFw[i].Description, firewall1.GetDescription())
}
}
func TestFirewallAzure(t *testing.T) {
assert := assert.New(t)
input := Firewall{
{
Name: "perm1",
Description: "perm1 description",
Protocol: "TCP",
IPRange: "192.0.2.0/24",
Port: 22,
},
{
Name: "perm2",
Description: "perm2 description",
Protocol: "udp",
IPRange: "192.0.2.0/24",
Port: 4433,
},
{
Name: "perm3",
Description: "perm3 description",
Protocol: "tcp",
IPRange: "192.0.2.0/24",
Port: 4433,
},
}
expectedOutput := []*armnetwork.SecurityRule{
{
Name: proto.String("perm1"),
Properties: &armnetwork.SecurityRulePropertiesFormat{
Description: proto.String("perm1 description"),
Protocol: armnetwork.SecurityRuleProtocolTCP.ToPtr(),
SourceAddressPrefix: proto.String("192.0.2.0/24"),
SourcePortRange: proto.String("*"),
DestinationAddressPrefix: proto.String("192.0.2.0/24"),
DestinationPortRange: proto.String("22"),
Access: armnetwork.SecurityRuleAccessAllow.ToPtr(),
Direction: armnetwork.SecurityRuleDirectionInbound.ToPtr(),
Priority: proto.Int32(100),
},
},
{
Name: proto.String("perm2"),
Properties: &armnetwork.SecurityRulePropertiesFormat{
Description: proto.String("perm2 description"),
Protocol: armnetwork.SecurityRuleProtocolUDP.ToPtr(),
SourceAddressPrefix: proto.String("192.0.2.0/24"),
SourcePortRange: proto.String("*"),
DestinationAddressPrefix: proto.String("192.0.2.0/24"),
DestinationPortRange: proto.String("4433"),
Access: armnetwork.SecurityRuleAccessAllow.ToPtr(),
Direction: armnetwork.SecurityRuleDirectionInbound.ToPtr(),
Priority: proto.Int32(200),
},
},
{
Name: proto.String("perm3"),
Properties: &armnetwork.SecurityRulePropertiesFormat{
Description: proto.String("perm3 description"),
Protocol: armnetwork.SecurityRuleProtocolTCP.ToPtr(),
SourceAddressPrefix: proto.String("192.0.2.0/24"),
SourcePortRange: proto.String("*"),
DestinationAddressPrefix: proto.String("192.0.2.0/24"),
DestinationPortRange: proto.String("4433"),
Access: armnetwork.SecurityRuleAccessAllow.ToPtr(),
Direction: armnetwork.SecurityRuleDirectionInbound.ToPtr(),
Priority: proto.Int32(300),
},
},
}
out := input.Azure()
assert.Equal(expectedOutput, out)
}
func TestIPPermissonsToAWS(t *testing.T) {
assert := assert.New(t)
input := Firewall{
{
Description: "perm1",
Protocol: "TCP",
IPRange: "192.0.2.0/24",
Port: 22,
},
{
Description: "perm2",
Protocol: "UDP",
IPRange: "192.0.2.0/24",
Port: 4433,
},
{
Description: "perm3",
Protocol: "TCP",
IPRange: "192.0.2.0/24",
Port: 4433,
},
}
expectedOutput := []ec2types.IpPermission{
{
FromPort: proto.Int32(int32(22)),
ToPort: proto.Int32(int32(22)),
IpProtocol: proto.String("TCP"),
IpRanges: []ec2types.IpRange{
{
CidrIp: proto.String("192.0.2.0/24"),
Description: proto.String("perm1"),
},
},
},
{
FromPort: proto.Int32(int32(4433)),
ToPort: proto.Int32(int32(4433)),
IpProtocol: proto.String("UDP"),
IpRanges: []ec2types.IpRange{
{
CidrIp: proto.String("192.0.2.0/24"),
Description: proto.String("perm2"),
},
},
},
{
FromPort: proto.Int32(int32(4433)),
ToPort: proto.Int32(int32(4433)),
IpProtocol: proto.String("TCP"),
IpRanges: []ec2types.IpRange{
{
CidrIp: proto.String("192.0.2.0/24"),
Description: proto.String("perm3"),
},
},
},
}
out := input.AWS()
assert.Equal(expectedOutput, out)
}

View file

@ -0,0 +1,13 @@
package cloudprovider
//go:generate stringer -type=CloudProvider
// CloudProvider is cloud provider used by the CLI.
type CloudProvider uint32
const (
Unknown CloudProvider = iota
AWS
Azure
GCP
)

View file

@ -0,0 +1,26 @@
// Code generated by "stringer -type=CloudProvider"; DO NOT EDIT.
package cloudprovider
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Unknown-0]
_ = x[AWS-1]
_ = x[Azure-2]
_ = x[GCP-3]
}
const _CloudProvider_name = "UnknownAWSAzureGCP"
var _CloudProvider_index = [...]uint8{0, 7, 10, 15, 18}
func (i CloudProvider) String() string {
if i >= CloudProvider(len(_CloudProvider_index)-1) {
return "CloudProvider(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _CloudProvider_name[_CloudProvider_index[i]:_CloudProvider_index[i+1]]
}

22
cli/cmd/azureclient.go Normal file
View file

@ -0,0 +1,22 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/cli/azure/client"
"github.com/edgelesssys/constellation/internal/state"
)
type azureclient interface {
GetState() (state.ConstellationState, error)
SetState(state.ConstellationState) error
CreateResourceGroup(ctx context.Context) error
CreateVirtualNetwork(ctx context.Context) error
CreateSecurityGroup(ctx context.Context, input client.NetworkSecurityGroupInput) error
CreateInstances(ctx context.Context, input client.CreateInstancesInput) error
// TODO: deprecate as soon as scale sets are available
CreateInstancesVMs(ctx context.Context, input client.CreateInstancesInput) error
CreateServicePrincipal(ctx context.Context) (string, error)
TerminateResourceGroup(ctx context.Context) error
TerminateServicePrincipal(ctx context.Context) error
}

194
cli/cmd/azureclient_test.go Normal file
View file

@ -0,0 +1,194 @@
package cmd
import (
"context"
"strconv"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/azure/client"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/internal/state"
)
type fakeAzureClient struct {
nodes azure.Instances
coordinators azure.Instances
resourceGroup string
name string
uid string
location string
subscriptionID string
tenantID string
subnetID string
coordinatorsScaleSet string
nodesScaleSet string
networkSecurityGroup string
adAppObjectID string
}
func (c *fakeAzureClient) GetState() (state.ConstellationState, error) {
stat := state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: c.nodes,
AzureCoordinators: c.coordinators,
Name: c.name,
UID: c.uid,
AzureResourceGroup: c.resourceGroup,
AzureLocation: c.location,
AzureSubscription: c.subscriptionID,
AzureTenant: c.tenantID,
AzureSubnet: c.subnetID,
AzureNetworkSecurityGroup: c.networkSecurityGroup,
AzureNodesScaleSet: c.nodesScaleSet,
AzureCoordinatorsScaleSet: c.coordinatorsScaleSet,
AzureADAppObjectID: c.adAppObjectID,
}
return stat, nil
}
func (c *fakeAzureClient) SetState(stat state.ConstellationState) error {
c.nodes = stat.AzureNodes
c.coordinators = stat.AzureCoordinators
c.name = stat.Name
c.uid = stat.UID
c.resourceGroup = stat.AzureResourceGroup
c.location = stat.AzureLocation
c.subscriptionID = stat.AzureSubscription
c.tenantID = stat.AzureTenant
c.subnetID = stat.AzureSubnet
c.networkSecurityGroup = stat.AzureNetworkSecurityGroup
c.nodesScaleSet = stat.AzureNodesScaleSet
c.coordinatorsScaleSet = stat.AzureCoordinatorsScaleSet
c.adAppObjectID = stat.AzureADAppObjectID
return nil
}
func (c *fakeAzureClient) CreateResourceGroup(ctx context.Context) error {
c.resourceGroup = "resource-group"
return nil
}
func (c *fakeAzureClient) CreateVirtualNetwork(ctx context.Context) error {
c.subnetID = "subnet"
return nil
}
func (c *fakeAzureClient) CreateSecurityGroup(ctx context.Context, input client.NetworkSecurityGroupInput) error {
c.networkSecurityGroup = "network-security-group"
return nil
}
func (c *fakeAzureClient) CreateInstances(ctx context.Context, input client.CreateInstancesInput) error {
c.coordinatorsScaleSet = "coordinators-scale-set"
c.nodesScaleSet = "nodes-scale-set"
c.nodes = make(azure.Instances)
for i := 0; i < input.Count-1; i++ {
id := strconv.Itoa(i)
c.nodes[id] = azure.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
}
c.coordinators = make(azure.Instances)
c.coordinators["0"] = azure.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
return nil
}
// TODO: deprecate as soon as scale sets are available.
func (c *fakeAzureClient) CreateInstancesVMs(ctx context.Context, input client.CreateInstancesInput) error {
c.nodes = make(azure.Instances)
for i := 0; i < input.Count-1; i++ {
id := strconv.Itoa(i)
c.nodes[id] = azure.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
}
c.coordinators = make(azure.Instances)
c.coordinators["0"] = azure.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
return nil
}
func (c *fakeAzureClient) CreateServicePrincipal(ctx context.Context) (string, error) {
c.adAppObjectID = "00000000-0000-0000-0000-000000000001"
return client.ApplicationCredentials{
ClientID: "client-id",
ClientSecret: "client-secret",
}.ConvertToCloudServiceAccountURI(), nil
}
func (c *fakeAzureClient) TerminateResourceGroup(ctx context.Context) error {
if c.resourceGroup == "" {
return nil
}
c.nodes = nil
c.coordinators = nil
c.resourceGroup = ""
c.subnetID = ""
c.networkSecurityGroup = ""
c.nodesScaleSet = ""
c.coordinatorsScaleSet = ""
return nil
}
func (c *fakeAzureClient) TerminateServicePrincipal(ctx context.Context) error {
if c.adAppObjectID == "" {
return nil
}
c.adAppObjectID = ""
return nil
}
type stubAzureClient struct {
terminateResourceGroupCalled bool
getStateErr error
setStateErr error
createResourceGroupErr error
createVirtualNetworkErr error
createSecurityGroupErr error
createInstancesErr error
createServicePrincipalErr error
terminateResourceGroupErr error
terminateServicePrincipalErr error
}
func (c *stubAzureClient) GetState() (state.ConstellationState, error) {
return state.ConstellationState{}, c.getStateErr
}
func (c *stubAzureClient) SetState(state.ConstellationState) error {
return c.setStateErr
}
func (c *stubAzureClient) CreateResourceGroup(ctx context.Context) error {
return c.createResourceGroupErr
}
func (c *stubAzureClient) CreateVirtualNetwork(ctx context.Context) error {
return c.createVirtualNetworkErr
}
func (c *stubAzureClient) CreateSecurityGroup(ctx context.Context, input client.NetworkSecurityGroupInput) error {
return c.createSecurityGroupErr
}
func (c *stubAzureClient) CreateInstances(ctx context.Context, input client.CreateInstancesInput) error {
return c.createInstancesErr
}
// TODO: deprecate as soon as scale sets are available.
func (c *stubAzureClient) CreateInstancesVMs(ctx context.Context, input client.CreateInstancesInput) error {
return c.createInstancesErr
}
func (c *stubAzureClient) CreateServicePrincipal(ctx context.Context) (string, error) {
return client.ApplicationCredentials{
ClientID: "00000000-0000-0000-0000-000000000000",
ClientSecret: "secret",
}.ConvertToCloudServiceAccountURI(), c.createServicePrincipalErr
}
func (c *stubAzureClient) TerminateResourceGroup(ctx context.Context) error {
c.terminateResourceGroupCalled = true
return c.terminateResourceGroupErr
}
func (c *stubAzureClient) TerminateServicePrincipal(ctx context.Context) error {
return c.terminateServicePrincipalErr
}

13
cli/cmd/constants.go Normal file
View file

@ -0,0 +1,13 @@
package cmd
// wireguardKeyLength is the length of a WireGuard key in byte.
const wireguardKeyLength = 32
// masterSecretLengthDefault is the default length in bytes for CLI generated master secrets.
const masterSecretLengthDefault = 32
// masterSecretLengthMin is the minimal length in bytes for user provided master secrets.
const masterSecretLengthMin = 16
// constellationNameLength is the maximum length of a Constellation's name.
const constellationNameLength = 37

41
cli/cmd/create.go Normal file
View file

@ -0,0 +1,41 @@
package cmd
import (
"errors"
"fmt"
"io/fs"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/cobra"
)
func newCreateCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "create aws|gcp|azure",
Short: "Create instances on a cloud platform for your Constellation.",
Long: "Create instances on a cloud platform for your Constellation.",
}
cmd.PersistentFlags().String("name", "constell", "Set this flag to create the Constellation with the specified name.")
cmd.PersistentFlags().BoolP("yes", "y", false, "Set this flag to create the Constellation without further confirmation.")
cmd.AddCommand(newCreateAWSCmd())
cmd.AddCommand(newCreateGCPCmd())
cmd.AddCommand(newCreateAzureCmd())
return cmd
}
// checkDirClean checks if files of a previous Constellation are left in the current working dir.
func checkDirClean(fileHandler file.Handler, config *config.Config) error {
if _, err := fileHandler.Stat(*config.StatePath); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory, run 'constellation terminate' before creating a new one", *config.StatePath)
}
if _, err := fileHandler.Stat(*config.AdminConfPath); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory, run 'constellation terminate' before creating a new one", *config.AdminConfPath)
}
if _, err := fileHandler.Stat(*config.MasterSecretPath); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory, clean it up first", *config.MasterSecretPath)
}
return nil
}

138
cli/cmd/create_aws.go Normal file
View file

@ -0,0 +1,138 @@
package cmd
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/ec2/client"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/internal/config"
)
func newCreateAWSCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "aws NUMBER SIZE",
Short: "Create a Constellation of NUMBER nodes of SIZE on AWS.",
Long: "Create a Constellation of NUMBER nodes of SIZE on AWS.",
Example: "aws 4 2xlarge",
Args: cobra.MatchAll(
cobra.ExactArgs(2),
isIntGreaterArg(0, 1),
isEC2InstanceType(1),
),
ValidArgsFunction: createAWSCompletion,
RunE: runCreateAWS,
}
return cmd
}
// runCreateAWS runs the create command.
func runCreateAWS(cmd *cobra.Command, args []string) error {
count, _ := strconv.Atoi(args[0]) // err already checked in args validation
size := strings.ToLower(args[1])
name, err := cmd.Flags().GetString("name")
if err != nil {
return err
}
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
fileHandler := file.NewHandler(afero.NewOsFs())
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
client, err := client.NewFromDefault(cmd.Context())
if err != nil {
return err
}
return createAWS(cmd, client, fileHandler, config, size, name, count)
}
// createAWS uses the given client to create 'count' instances of 'size'.
// After the instances are running, they are tagged with the default tags.
// On success, the state of the client is saved to the state file.
func createAWS(cmd *cobra.Command, cl ec2client, fileHandler file.Handler, config *config.Config, size, name string, count int) (retErr error) {
if err := checkDirClean(fileHandler, config); err != nil {
return err
}
const maxLength = 255
if len(name) > maxLength {
return fmt.Errorf("name for constellation too long, maximum length is %d: %s", maxLength, name)
}
ec2Tags := append([]ec2.Tag{}, *config.Provider.EC2.Tags...)
ec2Tags = append(ec2Tags, ec2.Tag{Key: "Name", Value: name})
ok, err := cmd.Flags().GetBool("yes")
if err != nil {
return err
}
if !ok {
// Ask user to confirm action.
cmd.Printf("The following Constellation will be created:\n")
cmd.Printf("%d nodes of size %s will be created.\n", count, size)
ok, err := askToConfirm(cmd, "Do you want to create this Constellation?")
if err != nil {
return err
}
if !ok {
cmd.Println("The creation of the Constellation was aborted.")
return nil
}
}
defer rollbackOnError(context.Background(), cmd.OutOrStdout(), &retErr, &rollbackerAWS{client: cl})
if err := cl.CreateSecurityGroup(cmd.Context(), *config.Provider.EC2.SecurityGroupInput); err != nil {
return err
}
createInput := client.CreateInput{
ImageId: *config.Provider.EC2.Image,
InstanceType: size,
Count: count,
Tags: ec2Tags,
}
if err := cl.CreateInstances(cmd.Context(), createInput); err != nil {
return err
}
stat, err := cl.GetState()
if err != nil {
return err
}
if err := fileHandler.WriteJSON(*config.StatePath, stat, false); err != nil {
return err
}
cmd.Println("Your Constellation was created successfully.")
return nil
}
// createAWSCompletion handels the completion of CLI arguments. It is frequently called
// while the user types arguments of the command to suggest completion.
func createAWSCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0:
return []string{}, cobra.ShellCompDirectiveNoFileComp
case 1:
return []string{
"4xlarge",
"8xlarge",
"12xlarge",
"16xlarge",
"24xlarge",
}, cobra.ShellCompDirectiveDefault
default:
return []string{}, cobra.ShellCompDirectiveError
}
}

205
cli/cmd/create_aws_test.go Normal file
View file

@ -0,0 +1,205 @@
package cmd
import (
"bytes"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateAWSCmdArgumentValidation(t *testing.T) {
testCases := map[string]struct {
args []string
expectErr bool
}{
"valid size 4XL": {[]string{"5", "4xlarge"}, false},
"valid size 8XL": {[]string{"4", "8xlarge"}, false},
"valid size 12XL": {[]string{"3", "12xlarge"}, false},
"valid size 16XL": {[]string{"2", "16xlarge"}, false},
"valid size 24XL": {[]string{"2", "24xlarge"}, false},
"valid short 12XL": {[]string{"4", "12xl"}, false},
"valid short 24XL": {[]string{"2", "24xl"}, false},
"valid capitalized": {[]string{"3", "24XlARge"}, false},
"valid short capitalized": {[]string{"4", "16XL"}, false},
"invalid to many arguments": {[]string{"2", "4xl", "2xl"}, true},
"invalid to many arguments 2": {[]string{"2", "4xl", "2"}, true},
"invalidOnlyOneInstance": {[]string{"1", "4xl"}, true},
"invalid first is no int": {[]string{"xl", "4xl"}, true},
"invalid second is no size": {[]string{"2", "2"}, true},
"invalid wrong order": {[]string{"4xl", "2"}, true},
}
cmd := newCreateAWSCmd()
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := cmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestCreateAWS(t *testing.T) {
testState := state.ConstellationState{
CloudProvider: cloudprovider.AWS.String(),
EC2Instances: ec2.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-2": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
},
EC2SecurityGroup: "sg-test",
}
someErr := errors.New("failed")
config := config.Default()
testCases := map[string]struct {
existingState *state.ConstellationState
client ec2client
interactive bool
interactiveStdin string
stateExpected state.ConstellationState
errExpected bool
}{
"create some instances": {
client: &fakeEc2Client{},
stateExpected: testState,
errExpected: false,
},
"state already exists": {
existingState: &testState,
client: &fakeEc2Client{},
errExpected: true,
},
"create some instances interactive": {
client: &fakeEc2Client{},
interactive: true,
interactiveStdin: "y\n",
stateExpected: testState,
errExpected: false,
},
"fail CreateSecurityGroup": {
client: &stubEc2Client{createSecurityGroupErr: someErr},
errExpected: true,
},
"fail CreateInstances": {
client: &stubEc2Client{createInstancesErr: someErr},
errExpected: true,
},
"fail GetState": {
client: &stubEc2Client{getStateErr: someErr},
errExpected: true,
},
"error on rollback": {
client: &stubEc2Client{createInstancesErr: someErr, deleteSecurityGroupErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newCreateAWSCmd()
cmd.Flags().BoolP("yes", "y", false, "")
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
in := bytes.NewBufferString(tc.interactiveStdin)
cmd.SetIn(in)
if !tc.interactive {
require.NoError(cmd.Flags().Set("yes", "true")) // disable interactivity
}
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
if tc.existingState != nil {
require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, false))
}
err := createAWS(cmd, tc.client, fileHandler, config, "xlarge", "name", 3)
if tc.errExpected {
assert.Error(err)
if stubClient, ok := tc.client.(*stubEc2Client); ok {
// Should have made a rollback on error.
assert.True(stubClient.terminateInstancesCalled)
assert.True(stubClient.deleteSecurityGroupCalled)
}
} else {
assert.NoError(err)
var stat state.ConstellationState
err := fileHandler.ReadJSON(*config.StatePath, &stat)
assert.NoError(err)
assert.Equal(tc.stateExpected, stat)
}
})
}
}
func TestCreateAWSCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
toComplete string
resultExpected []string
shellCDExpected cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
toComplete: "21",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
},
"second arg": {
args: []string{"23"},
toComplete: "4xl",
resultExpected: []string{
"4xlarge",
"8xlarge",
"12xlarge",
"16xlarge",
"24xlarge",
},
shellCDExpected: cobra.ShellCompDirectiveDefault,
},
"third arg": {
args: []string{"23", "4xlarge"},
toComplete: "xl",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := createAWSCompletion(cmd, tc.args, tc.toComplete)
assert.Equal(tc.resultExpected, result)
assert.Equal(tc.shellCDExpected, shellCD)
})
}
}

137
cli/cmd/create_azure.go Normal file
View file

@ -0,0 +1,137 @@
package cmd
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/azure/client"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
func newCreateAzureCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "azure",
Short: "Create a Constellation of NUMBER nodes of SIZE on Azure.",
Long: "Create a Constellation of NUMBER nodes of SIZE on Azure.",
Args: cobra.MatchAll(
cobra.ExactArgs(2),
isIntGreaterArg(0, 1),
isAzureInstanceType(1),
),
ValidArgsFunction: createAzureCompletion,
RunE: runCreateAzure,
}
return cmd
}
// runCreateAzure runs the create command.
func runCreateAzure(cmd *cobra.Command, args []string) error {
count, _ := strconv.Atoi(args[0]) // err already checked in args validation
size := strings.ToLower(args[1])
subscriptionID := "0d202bbb-4fa7-4af8-8125-58c269a05435" // TODO: This will be user input
tenantID := "adb650a8-5da3-4b15-b4b0-3daf65ff7626" // TODO: This will be user input
location := "North Europe" // TODO: This will be user input
name, err := cmd.Flags().GetString("name")
if err != nil {
return err
}
if len(name) > constellationNameLength {
return fmt.Errorf("name for constellation too long, maximum length is %d got %d: %s", constellationNameLength, len(name), name)
}
client, err := client.NewInitialized(
subscriptionID,
tenantID,
name,
location,
)
if err != nil {
return err
}
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
fileHandler := file.NewHandler(afero.NewOsFs())
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
return createAzure(cmd, client, fileHandler, config, size, count)
}
func createAzure(cmd *cobra.Command, cl azureclient, fileHandler file.Handler, config *config.Config, size string, count int) (retErr error) {
if err := checkDirClean(fileHandler, config); err != nil {
return err
}
ok, err := cmd.Flags().GetBool("yes")
if err != nil {
return err
}
if !ok {
// Ask user to confirm action.
cmd.Printf("The following Constellation will be created:\n")
cmd.Printf("%d nodes of size %s will be created.\n", count, size)
ok, err := askToConfirm(cmd, "Do you want to create this Constellation?")
if err != nil {
return err
}
if !ok {
cmd.Println("The creation of the Constellation was aborted.")
return nil
}
}
// Create all azure resources
defer rollbackOnError(context.Background(), cmd.OutOrStdout(), &retErr, &rollbackerAzure{client: cl})
if err := cl.CreateResourceGroup(cmd.Context()); err != nil {
return err
}
if err := cl.CreateVirtualNetwork(cmd.Context()); err != nil {
return err
}
if err := cl.CreateSecurityGroup(cmd.Context(), *config.Provider.Azure.NetworkSecurityGroupInput); err != nil {
return err
}
if err := cl.CreateInstances(cmd.Context(), client.CreateInstancesInput{
Count: count,
InstanceType: size,
Image: *config.Provider.Azure.Image,
UserAssingedIdentity: *config.Provider.Azure.UserAssignedIdentity,
}); err != nil {
return err
}
stat, err := cl.GetState()
if err != nil {
return err
}
if err := fileHandler.WriteJSON(*config.StatePath, stat, false); err != nil {
return err
}
cmd.Println("Your Constellation was created successfully.")
return nil
}
// createAzureCompletion handels the completion of CLI arguments. It is frequently called
// while the user types arguments of the command to suggest completion.
func createAzureCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0:
return []string{}, cobra.ShellCompDirectiveNoFileComp
case 1:
return azure.InstanceTypes, cobra.ShellCompDirectiveDefault
default:
return []string{}, cobra.ShellCompDirectiveError
}
}

View file

@ -0,0 +1,206 @@
package cmd
import (
"bytes"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateAzureCmdArgumentValidation(t *testing.T) {
testCases := map[string]struct {
args []string
expectErr bool
}{
"valid create 1": {[]string{"3", "Standard_DC2as_v5"}, false},
"valid create 2": {[]string{"7", "Standard_DC4as_v5"}, false},
"valid create 3": {[]string{"2", "Standard_DC8as_v5"}, false},
"invalid to many arguments": {[]string{"2", "Standard_DC2as_v5", "Standard_DC2as_v5"}, true},
"invalid to many arguments 2": {[]string{"2", "Standard_DC2as_v5", "2"}, true},
"invalidOnlyOneInstance": {[]string{"1", "Standard_DC2as_v5"}, true},
"invalid first is no int": {[]string{"Standard_DC2as_v5", "Standard_DC2as_v5"}, true},
"invalid second is no size": {[]string{"2", "2"}, true},
"invalid wrong order": {[]string{"Standard_DC2as_v5", "2"}, true},
}
cmd := newCreateAzureCmd()
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := cmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestCreateAzure(t *testing.T) {
testState := state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureCoordinators: azure.Instances{
"0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureResourceGroup: "resource-group",
AzureSubnet: "subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "nodes-scale-set",
AzureCoordinatorsScaleSet: "coordinators-scale-set",
}
someErr := errors.New("failed")
config := config.Default()
testCases := map[string]struct {
existingState *state.ConstellationState
client azureclient
interactive bool
interactiveStdin string
stateExpected state.ConstellationState
errExpected bool
}{
"create some instances": {
client: &fakeAzureClient{},
stateExpected: testState,
},
"state already exists": {
existingState: &testState,
client: &fakeAzureClient{},
errExpected: true,
},
"create some instances interactive": {
client: &fakeAzureClient{},
interactive: true,
interactiveStdin: "y\n",
stateExpected: testState,
errExpected: false,
},
"fail getState": {
client: &stubAzureClient{getStateErr: someErr},
errExpected: true,
},
"fail createVirtualNetwork": {
client: &stubAzureClient{createVirtualNetworkErr: someErr},
errExpected: true,
},
"fail createSecurityGroup": {
client: &stubAzureClient{createSecurityGroupErr: someErr},
errExpected: true,
},
"fail createInstances": {
client: &stubAzureClient{createInstancesErr: someErr},
errExpected: true,
},
"fail createResourceGroup": {
client: &stubAzureClient{createResourceGroupErr: someErr},
errExpected: true,
},
"error on rollback": {
client: &stubAzureClient{createInstancesErr: someErr, terminateResourceGroupErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newCreateAzureCmd()
cmd.Flags().BoolP("yes", "y", false, "")
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
in := bytes.NewBufferString(tc.interactiveStdin)
cmd.SetIn(in)
if !tc.interactive {
require.NoError(cmd.Flags().Set("yes", "true")) // disable interactivity
}
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
if tc.existingState != nil {
require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, false))
}
err := createAzure(cmd, tc.client, fileHandler, config, "Standard_D2s_v3", 3)
if tc.errExpected {
assert.Error(err)
if stubClient, ok := tc.client.(*stubAzureClient); ok {
// Should have made a rollback on error.
assert.True(stubClient.terminateResourceGroupCalled)
}
} else {
assert.NoError(err)
var state state.ConstellationState
err := fileHandler.ReadJSON(*config.StatePath, &state)
assert.NoError(err)
assert.Equal(tc.stateExpected, state)
}
})
}
}
func TestCreateAzureCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
toComplete string
resultExpected []string
shellCDExpected cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
toComplete: "21",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
},
"second arg": {
args: []string{"23"},
toComplete: "Standard_D",
resultExpected: azure.InstanceTypes,
shellCDExpected: cobra.ShellCompDirectiveDefault,
},
"third arg": {
args: []string{"23", "Standard_D2s_v3"},
toComplete: "Standard_D",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := createAzureCompletion(cmd, tc.args, tc.toComplete)
assert.Equal(tc.resultExpected, result)
assert.Equal(tc.shellCDExpected, shellCD)
})
}
}

132
cli/cmd/create_gcp.go Normal file
View file

@ -0,0 +1,132 @@
package cmd
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/cli/gcp/client"
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
func newCreateGCPCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "gcp",
Short: "Create a Constellation of NUMBER nodes of SIZE on Google Cloud Platform.",
Long: "Create a Constellation of NUMBER nodes of SIZE on Google Cloud Platform.",
Args: cobra.MatchAll(
cobra.ExactArgs(2),
isIntGreaterArg(0, 1),
isGCPInstanceType(1),
),
ValidArgsFunction: createGCPCompletion,
RunE: runCreateGCP,
}
return cmd
}
// runCreateGCP runs the create command.
func runCreateGCP(cmd *cobra.Command, args []string) error {
count, _ := strconv.Atoi(args[0]) // err already checked in args validation
size := strings.ToLower(args[1])
project := "constellation-331613" // TODO: This will be user input
zone := "us-central1-c" // TODO: This will be user input
region := "us-central1" // TODO: This will be user input
name, err := cmd.Flags().GetString("name")
if err != nil {
return err
}
if len(name) > constellationNameLength {
return fmt.Errorf("name for constellation too long, maximum length is %d got %d: %s", constellationNameLength, len(name), name)
}
client, err := client.NewInitialized(cmd.Context(), project, zone, region, name)
if err != nil {
return err
}
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
fileHandler := file.NewHandler(afero.NewOsFs())
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
return createGCP(cmd, client, fileHandler, config, size, count)
}
func createGCP(cmd *cobra.Command, cl gcpclient, fileHandler file.Handler, config *config.Config, size string, count int) (retErr error) {
if err := checkDirClean(fileHandler, config); err != nil {
return err
}
createInput := client.CreateInstancesInput{
Count: count,
ImageId: *config.Provider.GCP.Image,
InstanceType: size,
KubeEnv: gcp.KubeEnv,
DisableCVM: *config.Provider.GCP.DisableCVM,
}
ok, err := cmd.Flags().GetBool("yes")
if err != nil {
return err
}
if !ok {
// Ask user to confirm action.
cmd.Printf("The following Constellation will be created:\n")
cmd.Printf("%d nodes of size %s will be created.\n", count, size)
ok, err := askToConfirm(cmd, "Do you want to create this Constellation?")
if err != nil {
return err
}
if !ok {
cmd.Println("The creation of the Constellation was aborted.")
return nil
}
}
// Create all gcp resources
defer rollbackOnError(context.Background(), cmd.OutOrStdout(), &retErr, &rollbackerGCP{client: cl})
if err := cl.CreateVPCs(cmd.Context(), *config.Provider.GCP.VPCsInput); err != nil {
return err
}
if err := cl.CreateFirewall(cmd.Context(), *config.Provider.GCP.FirewallInput); err != nil {
return err
}
if err := cl.CreateInstances(cmd.Context(), createInput); err != nil {
return err
}
stat, err := cl.GetState()
if err != nil {
return err
}
if err := fileHandler.WriteJSON(*config.StatePath, stat, false); err != nil {
return err
}
cmd.Println("Your Constellation was created successfully.")
return nil
}
// createGCPCompletion handels the completion of CLI arguments. It is frequently called
// while the user types arguments of the command to suggest completion.
func createGCPCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0:
return []string{}, cobra.ShellCompDirectiveNoFileComp
case 1:
return gcp.InstanceTypes, cobra.ShellCompDirectiveDefault
default:
return []string{}, cobra.ShellCompDirectiveError
}
}

206
cli/cmd/create_gcp_test.go Normal file
View file

@ -0,0 +1,206 @@
package cmd
import (
"bytes"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateGCPCmdArgumentValidation(t *testing.T) {
testCases := map[string]struct {
args []string
expectErr bool
}{
"valid create 1": {[]string{"3", "n2d-standard-2"}, false},
"valid create 2": {[]string{"7", "n2d-standard-16"}, false},
"valid create 3": {[]string{"2", "n2d-standard-96"}, false},
"invalid to many arguments": {[]string{"2", "n2d-standard-2", "n2d-standard-2"}, true},
"invalid to many arguments 2": {[]string{"2", "n2d-standard-2", "2"}, true},
"invalidOnlyOneInstance": {[]string{"1", "n2d-standard-2"}, true},
"invalid first is no int": {[]string{"n2d-standard-2", "n2d-standard-2"}, true},
"invalid second is no size": {[]string{"2", "2"}, true},
"invalid wrong order": {[]string{"n2d-standard-2", "2"}, true},
}
cmd := newCreateGCPCmd()
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := cmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestCreateGCP(t *testing.T) {
testState := state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
GCPCoordinators: gcp.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
GCPNodeInstanceGroup: "nodes-group",
GCPCoordinatorInstanceGroup: "coordinator-group",
GCPNodeInstanceTemplate: "node-template",
GCPCoordinatorInstanceTemplate: "coordinator-template",
GCPNetwork: "network",
GCPSubnetwork: "subnetwork",
GCPFirewalls: []string{"coordinator", "wireguard", "ssh"},
}
someErr := errors.New("failed")
config := config.Default()
testCases := map[string]struct {
existingState *state.ConstellationState
client gcpclient
interactive bool
interactiveStdin string
stateExpected state.ConstellationState
errExpected bool
}{
"create some instances": {
client: &fakeGcpClient{},
stateExpected: testState,
},
"state already exists": {
existingState: &testState,
client: &fakeGcpClient{},
errExpected: true,
},
"create some instances interactive": {
client: &fakeGcpClient{},
interactive: true,
interactiveStdin: "y\n",
stateExpected: testState,
errExpected: false,
},
"fail getState": {
client: &stubGcpClient{getStateErr: someErr},
errExpected: true,
},
"fail createVPCs": {
client: &stubGcpClient{createVPCsErr: someErr},
errExpected: true,
},
"fail createFirewall": {
client: &stubGcpClient{createFirewallErr: someErr},
errExpected: true,
},
"fail createInstances": {
client: &stubGcpClient{createInstancesErr: someErr},
errExpected: true,
},
"error on rollback": {
client: &stubGcpClient{createInstancesErr: someErr, terminateVPCsErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newCreateGCPCmd()
cmd.Flags().BoolP("yes", "y", false, "")
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
in := bytes.NewBufferString(tc.interactiveStdin)
cmd.SetIn(in)
if !tc.interactive {
require.NoError(cmd.Flags().Set("yes", "true")) // disable interactivity
}
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
if tc.existingState != nil {
require.NoError(fileHandler.WriteJSON(*config.StatePath, *tc.existingState, false))
}
err := createGCP(cmd, tc.client, fileHandler, config, "n2d-standard-2", 3)
if tc.errExpected {
assert.Error(err)
if stubClient, ok := tc.client.(*stubGcpClient); ok {
// Should have made a rollback on error.
assert.True(stubClient.terminateFirewallCalled)
assert.True(stubClient.terminateInstancesCalled)
assert.True(stubClient.terminateVPCsCalled)
}
} else {
assert.NoError(err)
var stat state.ConstellationState
err := fileHandler.ReadJSON(*config.StatePath, &stat)
assert.NoError(err)
assert.Equal(tc.stateExpected, stat)
}
})
}
}
func TestCreateGCPCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
toComplete string
resultExpected []string
shellCDExpected cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
toComplete: "21",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
},
"second arg": {
args: []string{"23"},
toComplete: "n2d-stan",
resultExpected: gcp.InstanceTypes,
shellCDExpected: cobra.ShellCompDirectiveDefault,
},
"third arg": {
args: []string{"23", "n2d-standard-2"},
toComplete: "n2d-stan",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := createGCPCompletion(cmd, tc.args, tc.toComplete)
assert.Equal(tc.resultExpected, result)
assert.Equal(tc.shellCDExpected, shellCD)
})
}
}

64
cli/cmd/create_test.go Normal file
View file

@ -0,0 +1,64 @@
package cmd
import (
"testing"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCheckDirClean(t *testing.T) {
config := config.Default()
testCases := map[string]struct {
fileHandler file.Handler
existingFiles []string
wantErr bool
}{
"no file exists": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
},
"adminconf exists": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
existingFiles: []string{*config.AdminConfPath},
wantErr: true,
},
"master secret exists": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
existingFiles: []string{*config.MasterSecretPath},
wantErr: true,
},
"state file exists": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
existingFiles: []string{*config.StatePath},
wantErr: true,
},
"multiple exist": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
existingFiles: []string{*config.AdminConfPath, *config.MasterSecretPath, *config.StatePath},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
for _, f := range tc.existingFiles {
require.NoError(tc.fileHandler.Write(f, []byte{1, 2, 3}, false))
}
err := checkDirClean(tc.fileHandler, config)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

17
cli/cmd/ec2client.go Normal file
View file

@ -0,0 +1,17 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/cli/ec2/client"
"github.com/edgelesssys/constellation/internal/state"
)
type ec2client interface {
GetState() (state.ConstellationState, error)
SetState(stat state.ConstellationState) error
CreateInstances(ctx context.Context, input client.CreateInput) error
TerminateInstances(ctx context.Context) error
CreateSecurityGroup(ctx context.Context, input client.SecurityGroupInput) error
DeleteSecurityGroup(ctx context.Context) error
}

139
cli/cmd/ec2client_test.go Normal file
View file

@ -0,0 +1,139 @@
package cmd
import (
"context"
"errors"
"strconv"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/ec2/client"
"github.com/edgelesssys/constellation/internal/state"
)
type fakeEc2Client struct {
instances ec2.Instances
securityGroup string
ec2state []fakeEc2Instance
}
func (c *fakeEc2Client) GetState() (state.ConstellationState, error) {
if len(c.instances) == 0 {
return state.ConstellationState{}, errors.New("client has no instances")
}
stat := state.ConstellationState{
CloudProvider: cloudprovider.AWS.String(),
EC2Instances: c.instances,
EC2SecurityGroup: c.securityGroup,
}
for id, instance := range c.instances {
instance.PrivateIP = "192.0.2.1"
instance.PublicIP = "192.0.2.2"
c.instances[id] = instance
}
return stat, nil
}
func (c *fakeEc2Client) SetState(stat state.ConstellationState) error {
if len(stat.EC2Instances) == 0 {
return errors.New("state has no instances")
}
c.instances = stat.EC2Instances
c.securityGroup = stat.EC2SecurityGroup
return nil
}
func (c *fakeEc2Client) CreateInstances(_ context.Context, input client.CreateInput) error {
if c.securityGroup == "" {
return errors.New("client has no security group")
}
if c.instances == nil {
c.instances = make(ec2.Instances)
}
for i := 0; i < input.Count; i++ {
id := "id-" + strconv.Itoa(len(c.ec2state))
c.ec2state = append(c.ec2state, fakeEc2Instance{
state: running,
instanceID: id,
securityGroup: c.securityGroup,
tags: input.Tags,
})
c.instances[id] = ec2.Instance{}
}
return nil
}
func (c *fakeEc2Client) TerminateInstances(_ context.Context) error {
if len(c.instances) == 0 {
return nil
}
for _, instance := range c.ec2state {
instance.state = terminated
}
return nil
}
func (c *fakeEc2Client) CreateSecurityGroup(_ context.Context, input client.SecurityGroupInput) error {
if c.securityGroup != "" {
return errors.New("client already has a security group")
}
c.securityGroup = "sg-test"
return nil
}
func (c *fakeEc2Client) DeleteSecurityGroup(_ context.Context) error {
c.securityGroup = ""
return nil
}
type ec2InstanceState int
const (
running = iota
terminated
)
type fakeEc2Instance struct {
state ec2InstanceState
instanceID string
tags ec2.Tags
securityGroup string
}
type stubEc2Client struct {
terminateInstancesCalled bool
deleteSecurityGroupCalled bool
getStateErr error
setStateErr error
createInstancesErr error
terminateInstancesErr error
createSecurityGroupErr error
deleteSecurityGroupErr error
}
func (c *stubEc2Client) GetState() (state.ConstellationState, error) {
return state.ConstellationState{}, c.getStateErr
}
func (c *stubEc2Client) SetState(stat state.ConstellationState) error {
return c.setStateErr
}
func (c *stubEc2Client) CreateInstances(_ context.Context, input client.CreateInput) error {
return c.createInstancesErr
}
func (c *stubEc2Client) TerminateInstances(_ context.Context) error {
c.terminateInstancesCalled = true
return c.terminateInstancesErr
}
func (c *stubEc2Client) CreateSecurityGroup(_ context.Context, input client.SecurityGroupInput) error {
return c.createSecurityGroupErr
}
func (c *stubEc2Client) DeleteSecurityGroup(_ context.Context) error {
c.deleteSecurityGroupCalled = true
return c.deleteSecurityGroupErr
}

22
cli/cmd/gcpclient.go Normal file
View file

@ -0,0 +1,22 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/cli/gcp/client"
"github.com/edgelesssys/constellation/internal/state"
)
type gcpclient interface {
GetState() (state.ConstellationState, error)
SetState(state.ConstellationState) error
CreateVPCs(ctx context.Context, input client.VPCsInput) error
CreateFirewall(ctx context.Context, input client.FirewallInput) error
CreateInstances(ctx context.Context, input client.CreateInstancesInput) error
CreateServiceAccount(ctx context.Context, input client.ServiceAccountInput) (string, error)
TerminateFirewall(ctx context.Context) error
TerminateVPCs(context.Context) error
TerminateInstances(context.Context) error
TerminateServiceAccount(ctx context.Context) error
Close() error
}

219
cli/cmd/gcpclient_test.go Normal file
View file

@ -0,0 +1,219 @@
package cmd
import (
"context"
"errors"
"strconv"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/cli/gcp/client"
"github.com/edgelesssys/constellation/internal/state"
)
type fakeGcpClient struct {
nodes gcp.Instances
coordinators gcp.Instances
nodesInstanceGroup string
coordinatorInstanceGroup string
coordinatorTemplate string
nodeTemplate string
network string
subnetwork string
firewalls []string
project string
uid string
name string
zone string
serviceAccount string
}
func (c *fakeGcpClient) GetState() (state.ConstellationState, error) {
stat := state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: c.nodes,
GCPCoordinators: c.coordinators,
GCPNodeInstanceGroup: c.nodesInstanceGroup,
GCPCoordinatorInstanceGroup: c.coordinatorInstanceGroup,
GCPNodeInstanceTemplate: c.nodeTemplate,
GCPCoordinatorInstanceTemplate: c.coordinatorTemplate,
GCPNetwork: c.network,
GCPSubnetwork: c.subnetwork,
GCPFirewalls: c.firewalls,
GCPProject: c.project,
Name: c.name,
UID: c.uid,
GCPZone: c.zone,
GCPServiceAccount: c.serviceAccount,
}
return stat, nil
}
func (c *fakeGcpClient) SetState(stat state.ConstellationState) error {
c.nodes = stat.GCPNodes
c.coordinators = stat.GCPCoordinators
c.nodesInstanceGroup = stat.GCPNodeInstanceGroup
c.coordinatorInstanceGroup = stat.GCPCoordinatorInstanceGroup
c.nodeTemplate = stat.GCPNodeInstanceTemplate
c.coordinatorTemplate = stat.GCPCoordinatorInstanceTemplate
c.network = stat.GCPNetwork
c.subnetwork = stat.GCPSubnetwork
c.firewalls = stat.GCPFirewalls
c.project = stat.GCPProject
c.name = stat.Name
c.uid = stat.UID
c.zone = stat.GCPZone
c.serviceAccount = stat.GCPServiceAccount
return nil
}
func (c *fakeGcpClient) CreateVPCs(ctx context.Context, input client.VPCsInput) error {
c.network = "network"
c.subnetwork = "subnetwork"
return nil
}
func (c *fakeGcpClient) CreateFirewall(ctx context.Context, input client.FirewallInput) error {
if c.network == "" {
return errors.New("client has not network")
}
var firewalls []string
for _, rule := range input.Ingress {
firewalls = append(firewalls, rule.Name)
}
c.firewalls = firewalls
return nil
}
func (c *fakeGcpClient) CreateInstances(ctx context.Context, input client.CreateInstancesInput) error {
c.coordinatorInstanceGroup = "coordinator-group"
c.nodesInstanceGroup = "nodes-group"
c.nodeTemplate = "node-template"
c.coordinatorTemplate = "coordinator-template"
c.nodes = make(gcp.Instances)
for i := 0; i < input.Count-1; i++ {
id := "id-" + strconv.Itoa(len(c.nodes))
c.nodes[id] = gcp.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
}
c.coordinators = make(gcp.Instances)
c.coordinators["id-c"] = gcp.Instance{PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1"}
return nil
}
func (c *fakeGcpClient) CreateServiceAccount(ctx context.Context, input client.ServiceAccountInput) (string, error) {
c.serviceAccount = "service-account@" + c.project + ".iam.gserviceaccount.com"
return client.ServiceAccountKey{
Type: "service_account",
ProjectID: c.project,
PrivateKeyID: "key-id",
PrivateKey: "-----BEGIN PRIVATE KEY-----\nprivate-key\n-----END PRIVATE KEY-----\n",
ClientEmail: c.serviceAccount,
ClientID: "client-id",
AuthURI: "https://accounts.google.com/o/oauth2/auth",
TokenURI: "https://accounts.google.com/o/oauth2/token",
AuthProviderX509CertURL: "https://www.googleapis.com/oauth2/v1/certs",
ClientX509CertURL: "https://www.googleapis.com/robot/v1/metadata/x509/service-account-email",
}.ConvertToCloudServiceAccountURI(), nil
}
func (c *fakeGcpClient) TerminateFirewall(ctx context.Context) error {
if len(c.firewalls) == 0 {
return nil
}
c.firewalls = nil
return nil
}
func (c *fakeGcpClient) TerminateVPCs(context.Context) error {
if len(c.firewalls) != 0 {
return errors.New("client has firewalls, which must be deleted first")
}
c.network = ""
c.subnetwork = ""
return nil
}
func (c *fakeGcpClient) TerminateInstances(context.Context) error {
c.nodeTemplate = ""
c.coordinatorTemplate = ""
c.nodesInstanceGroup = ""
c.coordinatorInstanceGroup = ""
c.nodes = nil
c.coordinators = nil
return nil
}
func (c *fakeGcpClient) TerminateServiceAccount(context.Context) error {
c.serviceAccount = ""
return nil
}
func (c *fakeGcpClient) Close() error {
return nil
}
type stubGcpClient struct {
terminateFirewallCalled bool
terminateInstancesCalled bool
terminateVPCsCalled bool
getStateErr error
setStateErr error
createVPCsErr error
createFirewallErr error
createInstancesErr error
createServiceAccountErr error
terminateFirewallErr error
terminateVPCsErr error
terminateInstancesErr error
terminateServiceAccountErr error
closeErr error
}
func (c *stubGcpClient) GetState() (state.ConstellationState, error) {
return state.ConstellationState{}, c.getStateErr
}
func (c *stubGcpClient) SetState(state.ConstellationState) error {
return c.setStateErr
}
func (c *stubGcpClient) CreateVPCs(ctx context.Context, input client.VPCsInput) error {
return c.createVPCsErr
}
func (c *stubGcpClient) CreateFirewall(ctx context.Context, input client.FirewallInput) error {
return c.createFirewallErr
}
func (c *stubGcpClient) CreateInstances(ctx context.Context, input client.CreateInstancesInput) error {
return c.createInstancesErr
}
func (c *stubGcpClient) CreateServiceAccount(ctx context.Context, input client.ServiceAccountInput) (string, error) {
return client.ServiceAccountKey{}.ConvertToCloudServiceAccountURI(), c.createServiceAccountErr
}
func (c *stubGcpClient) TerminateFirewall(ctx context.Context) error {
c.terminateFirewallCalled = true
return c.terminateFirewallErr
}
func (c *stubGcpClient) TerminateVPCs(context.Context) error {
c.terminateVPCsCalled = true
return c.terminateVPCsErr
}
func (c *stubGcpClient) TerminateInstances(context.Context) error {
c.terminateInstancesCalled = true
return c.terminateInstancesErr
}
func (c *stubGcpClient) TerminateServiceAccount(context.Context) error {
return c.terminateServiceAccountErr
}
func (c *stubGcpClient) Close() error {
return c.closeErr
}

484
cli/cmd/init.go Normal file
View file

@ -0,0 +1,484 @@
package cmd
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"io/fs"
"net"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/cli/proto"
"github.com/edgelesssys/constellation/cli/status"
"github.com/edgelesssys/constellation/cli/vpn"
coordinatorstate "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
)
func newInitCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "init",
Short: "Initialize the Constellation. Start your confidential Kubernetes cluster.",
Long: "Initialize the Constellation. Start your confidential Kubernetes cluster.",
ValidArgsFunction: initCompletion,
Args: cobra.ExactArgs(0),
RunE: runInitialize,
}
cmd.Flags().String("privatekey", "", "path to your private key.")
cmd.Flags().String("publickey", "", "path to your public key.")
cmd.Flags().String("master-secret", "", "path to base64 encoded master secret.")
cmd.Flags().Bool("autoscale", false, "enable kubernetes cluster-autoscaler")
return cmd
}
// runInitialize runs the initialize command.
func runInitialize(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
protoClient := proto.NewClient(*config.Provider.GCP.PCRs)
defer protoClient.Close()
vpnClient, err := vpn.NewWithDefaults()
if err != nil {
return err
}
// We have to parse the context separately, since cmd.Context()
// returns nil during the tests otherwise.
return initialize(cmd.Context(), cmd, protoClient, vpnClient, serviceAccountClient{}, fileHandler, config, status.NewWaiter(*config.Provider.GCP.PCRs))
}
// initialize initializes a Constellation. Coordinator instances are activated as Coordinators and will
// themself activate the other peers as nodes.
func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, vpnCl vpnConfigurer, serviceAccountCr serviceAccountCreator,
fileHandler file.Handler, config *config.Config, waiter statusWaiter,
) error {
flagArgs, err := evalFlagArgs(cmd, fileHandler, config)
if err != nil {
return err
}
var stat state.ConstellationState
err = fileHandler.ReadJSON(*config.StatePath, &stat)
if errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("nothing to initialize: %w", err)
} else if err != nil {
return err
}
switch stat.CloudProvider {
case "GCP":
if err := warnAboutPCRs(cmd, *config.Provider.GCP.PCRs, true); err != nil {
return err
}
case "Azure":
if err := warnAboutPCRs(cmd, *config.Provider.Azure.PCRs, true); err != nil {
return err
}
}
serviceAccount, stat, err := serviceAccountCr.createServiceAccount(ctx, stat, config)
if err != nil {
return err
}
if err := fileHandler.WriteJSON(*config.StatePath, stat, true); err != nil {
return err
}
coordinators, nodes, err := getScalingGroupsFromConfig(stat, config)
if err != nil {
return err
}
endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort)
if err := waiter.WaitForAll(ctx, coordinatorstate.AcceptingInit, endpoints); err != nil {
return fmt.Errorf("failed to wait for peer status: %w", err)
}
var autoscalingNodeGroups []string
if flagArgs.autoscale {
autoscalingNodeGroups = append(autoscalingNodeGroups, nodes.GroupID)
}
input := activationInput{
coordinatorPubIP: coordinators.PublicIPs()[0],
pubKey: flagArgs.userPubKey,
masterSecret: flagArgs.masterSecret,
nodePrivIPs: nodes.PrivateIPs(),
autoscalingNodeGroups: autoscalingNodeGroups,
cloudServiceAccountURI: serviceAccount,
}
result, err := activate(ctx, cmd, protCl, input, config)
if err != nil {
return err
}
err = result.writeOutput(cmd.OutOrStdout(), fileHandler, config)
if err != nil {
return err
}
if flagArgs.autoconfigureWG {
if err := configureVpn(vpnCl, result.clientVpnIP, result.coordinatorPubKey, result.coordinatorPubIP, flagArgs.userPrivKey); err != nil {
return err
}
}
return nil
}
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config) (activationResult, error) {
if err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort); err != nil {
return activationResult{}, err
}
respCl, err := client.Activate(ctx, input.pubKey, input.masterSecret, ipsToEndpoints(input.nodePrivIPs, *config.CoordinatorPort), input.autoscalingNodeGroups, input.cloudServiceAccountURI)
if err != nil {
return activationResult{}, err
}
if err := respCl.WriteLogStream(cmd.OutOrStdout()); err != nil {
return activationResult{}, err
}
clientVpnIp, err := respCl.GetClientVpnIp()
if err != nil {
return activationResult{}, err
}
coordinatorPubKey, err := respCl.GetCoordinatorVpnKey()
if err != nil {
return activationResult{}, err
}
kubeconfig, err := respCl.GetKubeconfig()
if err != nil {
return activationResult{}, err
}
ownerID, err := respCl.GetOwnerID()
if err != nil {
return activationResult{}, err
}
clusterID, err := respCl.GetClusterID()
if err != nil {
return activationResult{}, err
}
return activationResult{
clientVpnIP: clientVpnIp,
coordinatorPubKey: coordinatorPubKey,
coordinatorPubIP: input.coordinatorPubIP,
kubeconfig: kubeconfig,
ownerID: ownerID,
clusterID: clusterID,
}, nil
}
type activationInput struct {
coordinatorPubIP string
pubKey []byte
masterSecret []byte
nodePrivIPs []string
autoscalingNodeGroups []string
cloudServiceAccountURI string
}
type activationResult struct {
clientVpnIP string
coordinatorPubKey string
coordinatorPubIP string
kubeconfig string
ownerID string
clusterID string
}
func (res activationResult) writeOutput(w io.Writer, fileHandler file.Handler, config *config.Config) error {
fmt.Fprintln(w, "Your Constellation was successfully initialized.")
fmt.Fprintf(w, "Your WireGuard IP is %s\n", res.clientVpnIP)
fmt.Fprintf(w, "The Coordinator's public IP is %s\n", res.coordinatorPubIP)
fmt.Fprintf(w, "The Coordinator's public key is %s\n", res.coordinatorPubKey)
fmt.Fprintf(w, "The Constellation's owner identifier is %s\n", res.ownerID)
fmt.Fprintf(w, "The Constellation's unique identifier is %s\n", res.clusterID)
if err := fileHandler.Write(*config.AdminConfPath, []byte(res.kubeconfig), false); err != nil {
return err
}
fmt.Fprintf(w, "Your Constellation Kubernetes configuration was successfully written to %s\n", *config.AdminConfPath)
return nil
}
// evalFlagArgs gets the flag values and does preprocessing of these values like
// reading the content from file path flags and deriving other values from flag combinations.
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler, config *config.Config) (flagArgs, error) {
userPrivKeyPath, err := cmd.Flags().GetString("privatekey")
if err != nil {
return flagArgs{}, err
}
userPublicKeyPath, err := cmd.Flags().GetString("publickey")
if err != nil {
return flagArgs{}, err
}
userPrivKey, userPubKey, err := readVpnKey(fileHandler, userPrivKeyPath, userPublicKeyPath)
if err != nil {
return flagArgs{}, err
}
masterSecretPath, err := cmd.Flags().GetString("master-secret")
if err != nil {
return flagArgs{}, err
}
masterSecret, err := readOrGeneratedMasterSecret(cmd.OutOrStdout(), fileHandler, masterSecretPath, config)
if err != nil {
return flagArgs{}, err
}
autoscale, err := cmd.Flags().GetBool("autoscale")
if err != nil {
return flagArgs{}, err
}
return flagArgs{
userPrivKey: userPrivKey,
userPubKey: userPubKey,
autoconfigureWG: userPrivKeyPath != "",
autoscale: autoscale,
masterSecret: masterSecret,
}, nil
}
// flagArgs are the resulting values of flag preprocessing.
type flagArgs struct {
userPrivKey []byte
userPubKey []byte
masterSecret []byte
autoconfigureWG bool
autoscale bool
}
func readVpnKey(fileHandler file.Handler, privKeyPath, publicKeyPath string) (privKey, pubKey []byte, err error) {
if privKeyPath != "" {
privKey, err = fileHandler.Read(privKeyPath)
if err != nil {
return nil, nil, err
}
privKeyParsed, err := wgtypes.ParseKey(string(privKey))
if err != nil {
return nil, nil, err
}
pubKey = []byte(privKeyParsed.PublicKey().String())
} else if publicKeyPath != "" {
pubKey, err = fileHandler.Read(publicKeyPath)
if err != nil {
return nil, nil, err
}
if err := checkBase64WGKey(pubKey); err != nil {
return nil, nil, fmt.Errorf("wireguard public key is invalid: %w", err)
}
} else {
return nil, nil, errors.New("neither path to public nor private key provided")
}
return privKey, pubKey, nil
}
func configureVpn(vpnCl vpnConfigurer, clientVpnIp, coordinatorPubKey, coordinatorPublicIp string, privKey []byte) error {
err := vpnCl.Configure(clientVpnIp, coordinatorPubKey, coordinatorPublicIp, string(privKey))
if err != nil {
return fmt.Errorf("could not configure WireGuard automatically: %w", err)
}
return nil
}
func ipsToEndpoints(ips []string, port string) []string {
var endpoints []string
for _, ip := range ips {
endpoints = append(endpoints, net.JoinHostPort(ip, port))
}
return endpoints
}
func checkBase64WGKey(b []byte) error {
keyStr, err := base64.StdEncoding.DecodeString(string(b))
if err != nil {
return errors.New("key can't be decoded")
}
if length := len(keyStr); length != wireguardKeyLength {
return fmt.Errorf("key has invalid length %d", length)
}
return nil
}
// readOrGeneratedMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
func readOrGeneratedMasterSecret(w io.Writer, fileHandler file.Handler, filename string, config *config.Config) ([]byte, error) {
if filename != "" {
// Try to read the base64 secret from file
encodedSecret, err := fileHandler.Read(filename)
if err != nil {
return nil, err
}
decoded, err := base64.StdEncoding.DecodeString(string(encodedSecret))
if err != nil {
return nil, err
}
if len(decoded) < masterSecretLengthMin {
return nil, errors.New("provided master secret is smaller than the required minimum of 16 Bytes")
}
return decoded, nil
}
// No file given, generate a new secret, and save it to disk
masterSecret, err := util.GenerateRandomBytes(masterSecretLengthDefault)
if err != nil {
return nil, err
}
if err := fileHandler.Write(*config.MasterSecretPath, []byte(base64.StdEncoding.EncodeToString(masterSecret)), false); err != nil {
return nil, err
}
fmt.Fprintf(w, "Your Constellation master secret was successfully written to ./%s\n", *config.MasterSecretPath)
return masterSecret, nil
}
func getScalingGroupsFromConfig(stat state.ConstellationState, config *config.Config) (coordinators, nodes ScalingGroup, err error) {
switch {
case len(stat.EC2Instances) != 0:
return getAWSInstances(stat)
case len(stat.GCPCoordinators) != 0:
return getGCPInstances(stat, config)
case len(stat.AzureCoordinators) != 0:
return getAzureInstances(stat)
default:
return ScalingGroup{}, ScalingGroup{}, errors.New("no instances to init")
}
}
func getAWSInstances(stat state.ConstellationState) (coordinators, nodes ScalingGroup, err error) {
coordinatorID, coordinator, err := stat.EC2Instances.GetOne()
if err != nil {
return
}
// GroupID of coordinators is empty, since they currently do not scale.
coordinators = ScalingGroup{Instances: Instances{Instance(coordinator)}, GroupID: ""}
nodeMap := stat.EC2Instances.GetOthers(coordinatorID)
if len(nodeMap) == 0 {
return ScalingGroup{}, ScalingGroup{}, errors.New("no nodes available, can't create Constellation with one instance")
}
var nodeInstances Instances
for _, node := range nodeMap {
nodeInstances = append(nodeInstances, Instance(node))
}
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
// TODO: GroupID of nodes is empty, since they currently do not scale.
nodes = ScalingGroup{Instances: nodeInstances, GroupID: ""}
return
}
func getGCPInstances(stat state.ConstellationState, config *config.Config) (coordinators, nodes ScalingGroup, err error) {
_, coordinator, err := stat.GCPCoordinators.GetOne()
if err != nil {
return
}
// GroupID of coordinators is empty, since they currently do not scale.
coordinators = ScalingGroup{Instances: Instances{Instance(coordinator)}, GroupID: ""}
nodeMap := stat.GCPNodes
if len(nodeMap) == 0 {
return ScalingGroup{}, ScalingGroup{}, errors.New("no nodes available, can't create Constellation with one instance")
}
var nodeInstances Instances
for _, node := range nodeMap {
nodeInstances = append(nodeInstances, Instance(node))
}
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
nodes = ScalingGroup{
Instances: nodeInstances,
GroupID: gcp.AutoscalingNodeGroup(stat.GCPProject, stat.GCPZone, stat.GCPNodeInstanceGroup, *config.AutoscalingNodeGroupsMin, *config.AutoscalingNodeGroupsMax),
}
return
}
func getAzureInstances(stat state.ConstellationState) (coordinators, nodes ScalingGroup, err error) {
_, coordinator, err := stat.AzureCoordinators.GetOne()
if err != nil {
return
}
// GroupID of coordinators is empty, since they currently do not scale.
coordinators = ScalingGroup{Instances: Instances{Instance(coordinator)}, GroupID: ""}
nodeMap := stat.AzureNodes
if len(nodeMap) == 0 {
return ScalingGroup{}, ScalingGroup{}, errors.New("no nodes available, can't create Constellation with one instance")
}
var nodeInstances Instances
for _, node := range nodeMap {
nodeInstances = append(nodeInstances, Instance(node))
}
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
nodes = ScalingGroup{
Instances: nodeInstances,
GroupID: "",
}
return
}
// initCompletion handels the completion of CLI arguments. It is frequently called
// while the user types arguments of the command to suggest completion.
func initCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if len(args) != 0 {
return []string{}, cobra.ShellCompDirectiveError
}
return []string{}, cobra.ShellCompDirectiveDefault
}
//
// TODO: Code below is target of multicloud refactoring.
//
// Instance is a cloud instance.
type Instance struct {
PublicIP string
PrivateIP string
}
type Instances []Instance
type ScalingGroup struct {
Instances
GroupID string
}
// PublicIPs returns the public IPs of all the instances.
func (i Instances) PublicIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PublicIP)
}
return ips
}
// PrivateIPs returns the private IPs of all the instances of the Constellation.
func (i Instances) PrivateIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PrivateIP)
}
return ips
}

686
cli/cmd/init_test.go Normal file
View file

@ -0,0 +1,686 @@
package cmd
import (
"bytes"
"context"
"encoding/base64"
"errors"
"strconv"
"strings"
"testing"
"time"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestInitArgumentValidation(t *testing.T) {
assert := assert.New(t)
cmd := newInitCmd()
assert.NoError(cmd.ValidateArgs(nil))
assert.Error(cmd.ValidateArgs([]string{"something"}))
assert.Error(cmd.ValidateArgs([]string{"sth", "sth"}))
}
func TestInitialize(t *testing.T) {
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
config := config.Default()
testEc2State := state.ConstellationState{
CloudProvider: "AWS",
EC2Instances: ec2.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-2": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
},
EC2SecurityGroup: "sg-test",
}
testGcpState := state.ConstellationState{
GCPNodes: gcp.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
GCPCoordinators: gcp.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
}
testAzureState := state.ConstellationState{
CloudProvider: "Azure",
AzureNodes: azure.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureCoordinators: azure.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureResourceGroup: "test",
}
testActivationResps := []fakeActivationRespMessage{
{log: "testlog1"},
{log: "testlog2"},
{
kubeconfig: "kubeconfig",
clientVpnIp: "vpnIp",
coordinatorVpnKey: "coordKey",
ownerID: "ownerID",
clusterID: "clusterID",
},
{log: "testlog3"},
}
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client protoClient
serviceAccountCreator stubServiceAccountCreator
waiter statusWaiter
pubKey string
errExpected bool
}{
"initialize some ec2 instances": {
existingState: testEc2State,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some gcp instances": {
existingState: testGcpState,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some azure instances": {
existingState: testAzureState,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"no state exists": {
existingState: state.ConstellationState{},
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"no instances to pick one": {
existingState: state.ConstellationState{
EC2Instances: ec2.Instances{},
EC2SecurityGroup: "sg-test",
},
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"only one instance": {
existingState: state.ConstellationState{
EC2Instances: ec2.Instances{"id-1": {}},
EC2SecurityGroup: "sg-test",
},
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"public key to short": {
existingState: testEc2State,
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
pubKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
errExpected: true,
},
"public key to long": {
existingState: testEc2State,
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
pubKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
errExpected: true,
},
"public key not base64": {
existingState: testEc2State,
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
pubKey: "this is not base64 encoded",
errExpected: true,
},
"fail Connect": {
existingState: testEc2State,
client: &stubProtoClient{connectErr: someErr},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail Activate": {
existingState: testEc2State,
client: &stubProtoClient{activateErr: someErr},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail respClient WriteLogStream": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail respClient getKubeconfig": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail respClient getCoordinatorVpnKey": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail respClient getClientVpnIp": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail respClient getOwnerID": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail respClient getClusterID": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
"fail to wait for required status": {
existingState: testGcpState,
client: &stubProtoClient{},
waiter: stubStatusWaiter{waitForAllErr: someErr},
pubKey: testKey,
errExpected: true,
},
"fail to create service account": {
existingState: testGcpState,
client: &stubProtoClient{},
serviceAccountCreator: stubServiceAccountCreator{
createErr: someErr,
},
waiter: stubStatusWaiter{},
pubKey: testKey,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newInitCmd()
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false))
// Write key file to filesystem and set path in flag.
require.NoError(afero.Afero{Fs: fs}.WriteFile("pubKPath", []byte(tc.pubKey), 0o600))
require.NoError(cmd.Flags().Set("publickey", "pubKPath"))
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
err := initialize(ctx, cmd, tc.client, &dummyVPNConfigurer{}, &tc.serviceAccountCreator, fileHandler, config, tc.waiter)
if tc.errExpected {
assert.Error(err)
} else {
require.NoError(err)
assert.Contains(out.String(), "vpnIp")
assert.Contains(out.String(), "coordKey")
assert.Contains(out.String(), "ownerID")
assert.Contains(out.String(), "clusterID")
}
})
}
}
func TestConfigureVPN(t *testing.T) {
assert := assert.New(t)
key := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
ip := "192.0.2.1"
someErr := errors.New("failed")
configurer := stubVPNConfigurer{}
assert.NoError(configureVpn(&configurer, ip, string(key), ip, key))
assert.True(configurer.configured)
configurer = stubVPNConfigurer{configureErr: someErr}
assert.Error(configureVpn(&configurer, ip, string(key), ip, key))
}
func TestWriteOutput(t *testing.T) {
assert := assert.New(t)
result := activationResult{
clientVpnIP: "foo-qq",
coordinatorPubKey: "bar-qq",
coordinatorPubIP: "baz-qq",
kubeconfig: "foo-bar-baz-qq",
}
var out bytes.Buffer
testFs := afero.NewMemMapFs()
fileHandler := file.NewHandler(testFs)
config := config.Default()
err := result.writeOutput(&out, fileHandler, config)
assert.NoError(err)
assert.Contains(out.String(), result.clientVpnIP)
assert.Contains(out.String(), result.coordinatorPubIP)
assert.Contains(out.String(), result.coordinatorPubKey)
afs := afero.Afero{Fs: testFs}
adminConf, err := afs.ReadFile(*config.AdminConfPath)
assert.NoError(err)
assert.Equal(result.kubeconfig, string(adminConf))
}
func TestIpsToEndpoints(t *testing.T) {
assert := assert.New(t)
ips := []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}
port := "8080"
endpoints := ipsToEndpoints(ips, port)
assert.Equal([]string{"192.0.2.1:8080", "192.0.2.2:8080", "192.0.2.3:8080"}, endpoints)
}
func TestCheckBase64WGKEy(t *testing.T) {
assert := assert.New(t)
key := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
assert.NoError(checkBase64WGKey(key))
key = []byte(base64.StdEncoding.EncodeToString([]byte("shortKey")))
assert.Error(checkBase64WGKey(key))
key = []byte(base64.StdEncoding.EncodeToString([]byte("looooooooooongKeyWithMoreThan32Bytes")))
assert.Error(checkBase64WGKey(key))
key = []byte("noBase 64")
assert.Error(checkBase64WGKey(key))
}
func TestInitCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
toComplete string
resultExpected []string
shellCDExpected cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
toComplete: "hello",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveDefault,
},
"secnod arg": {
args: []string{"23"},
toComplete: "/test/h",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveError,
},
"third arg": {
args: []string{"./file", "sth"},
toComplete: "./file",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := initCompletion(cmd, tc.args, tc.toComplete)
assert.Equal(tc.resultExpected, result)
assert.Equal(tc.shellCDExpected, shellCD)
})
}
}
func TestReadVpnKey(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
testKeyA := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
testKeyB := []byte(base64.StdEncoding.EncodeToString([]byte("anotherWireGuardKeyForTheTesting")))
fileHandler := file.NewHandler(afero.NewMemMapFs())
require.NoError(fileHandler.Write("testKeyA", testKeyA, false))
require.NoError(fileHandler.Write("testKeyB", testKeyB, false))
// provide privK
privK, _, err := readVpnKey(fileHandler, "testKeyA", "")
assert.NoError(err)
assert.Equal(testKeyA, privK)
// provide pubK
_, pubK, err := readVpnKey(fileHandler, "", "testKeyA")
assert.NoError(err)
assert.Equal(testKeyA, pubK)
// provide both, privK should be used, pubK ignored
privK, pubK, err = readVpnKey(fileHandler, "testKeyA", "testKeyB")
assert.NoError(err)
assert.Equal(testKeyA, privK)
assert.NotEqual(testKeyB, pubK)
// no path provided
_, _, err = readVpnKey(fileHandler, "", "")
assert.Error(err)
}
func TestReadOrGeneratedMasterSecret(t *testing.T) {
testCases := map[string]struct {
filename string
filecontent string
createFile bool
fs func() afero.Fs
errExpected bool
}{
"file with secret exists": {
filename: "someSecret",
filecontent: base64.StdEncoding.EncodeToString([]byte("ConstellationSecret")),
createFile: true,
fs: afero.NewMemMapFs,
errExpected: false,
},
"no file given": {
filename: "",
filecontent: "",
fs: afero.NewMemMapFs,
errExpected: false,
},
"file does not exist": {
filename: "nonExistingSecret",
filecontent: "",
createFile: false,
fs: afero.NewMemMapFs,
errExpected: true,
},
"file is empty": {
filename: "emptySecret",
filecontent: "",
createFile: true,
fs: afero.NewMemMapFs,
errExpected: true,
},
"secret too short": {
filename: "shortSecret",
filecontent: base64.StdEncoding.EncodeToString([]byte("short")),
createFile: true,
fs: afero.NewMemMapFs,
errExpected: true,
},
"secret not encoded": {
filename: "unencodedSecret",
filecontent: "Constellation",
createFile: true,
fs: afero.NewMemMapFs,
errExpected: true,
},
"file not writeable": {
filename: "",
filecontent: "",
createFile: false,
fs: func() afero.Fs { return afero.NewReadOnlyFs(afero.NewMemMapFs()) },
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(tc.fs())
config := config.Default()
if tc.createFile {
require.NoError(fileHandler.Write(tc.filename, []byte(tc.filecontent), false))
}
var out bytes.Buffer
secret, err := readOrGeneratedMasterSecret(&out, fileHandler, tc.filename, config)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
if tc.filename == "" {
require.Contains(out.String(), *config.MasterSecretPath)
filename := strings.Split(out.String(), "./")
tc.filename = strings.Trim(filename[1], "\n")
}
content, err := fileHandler.Read(tc.filename)
require.NoError(err)
assert.Equal(content, []byte(base64.StdEncoding.EncodeToString(secret)))
}
})
}
}
func TestAutoscaleFlag(t *testing.T) {
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
config := config.Default()
testEc2State := state.ConstellationState{
EC2Instances: ec2.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-2": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
},
EC2SecurityGroup: "sg-test",
}
testGcpState := state.ConstellationState{
GCPNodes: gcp.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
GCPCoordinators: gcp.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
}
testAzureState := state.ConstellationState{
AzureNodes: azure.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureCoordinators: azure.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureResourceGroup: "test",
}
testActivationResps := []fakeActivationRespMessage{
{log: "testlog1"},
{log: "testlog2"},
{
kubeconfig: "kubeconfig",
clientVpnIp: "vpnIp",
coordinatorVpnKey: "coordKey",
ownerID: "ownerID",
clusterID: "clusterID",
},
{log: "testlog3"},
}
testCases := map[string]struct {
autoscaleFlag bool
existingState state.ConstellationState
client *stubProtoClient
serviceAccountCreator stubServiceAccountCreator
waiter statusWaiter
pubKey string
}{
"initialize some ec2 instances without autoscale flag": {
autoscaleFlag: false,
existingState: testEc2State,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some gcp instances without autoscale flag": {
autoscaleFlag: false,
existingState: testGcpState,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some azure instances without autoscale flag": {
autoscaleFlag: false,
existingState: testAzureState,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some ec2 instances with autoscale flag": {
autoscaleFlag: true,
existingState: testEc2State,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some gcp instances with autoscale flag": {
autoscaleFlag: true,
existingState: testGcpState,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
"initialize some azure instances with autoscale flag": {
autoscaleFlag: true,
existingState: testAzureState,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
pubKey: testKey,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newInitCmd()
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.WriteJSON(*config.StatePath, tc.existingState, false))
// Write key file to filesystem and set path in flag.
require.NoError(afero.Afero{Fs: fs}.WriteFile("pubKPath", []byte(tc.pubKey), 0o600))
require.NoError(cmd.Flags().Set("publickey", "pubKPath"))
require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
ctx := context.Background()
require.NoError(initialize(ctx, cmd, tc.client, &dummyVPNConfigurer{}, &tc.serviceAccountCreator, fileHandler, config, tc.waiter))
if tc.autoscaleFlag {
assert.Len(tc.client.activateAutoscalingNodeGroups, 1)
} else {
assert.Len(tc.client.activateAutoscalingNodeGroups, 0)
}
})
}
}

13
cli/cmd/protoclient.go Normal file
View file

@ -0,0 +1,13 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/cli/proto"
)
type protoClient interface {
Connect(ip string, port string) error
Close() error
Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
}

188
cli/cmd/protoclient_test.go Normal file
View file

@ -0,0 +1,188 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"github.com/edgelesssys/constellation/cli/proto"
)
type stubProtoClient struct {
conn bool
respClient proto.ActivationResponseClient
connectErr error
closeErr error
activateErr error
activateUserPublicKey []byte
activateMasterSecret []byte
activateEndpoints []string
activateAutoscalingNodeGroups []string
cloudServiceAccountURI string
}
func (c *stubProtoClient) Connect(ip string, port string) error {
c.conn = true
return c.connectErr
}
func (c *stubProtoClient) Close() error {
c.conn = false
return c.closeErr
}
func (c *stubProtoClient) Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints []string, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error) {
c.activateUserPublicKey = userPublicKey
c.activateMasterSecret = masterSecret
c.activateEndpoints = endpoints
c.activateAutoscalingNodeGroups = autoscalingNodeGroups
c.cloudServiceAccountURI = cloudServiceAccountURI
return c.respClient, c.activateErr
}
type stubActivationRespClient struct {
nextLogErr *error
getKubeconfigErr error
getCoordinatorVpnKeyErr error
getClientVpnIpErr error
getOwnerIDErr error
getClusterIDErr error
writeLogStreamErr error
}
func (s *stubActivationRespClient) NextLog() (string, error) {
if s.nextLogErr == nil {
return "", io.EOF
}
return "", *s.nextLogErr
}
func (s *stubActivationRespClient) WriteLogStream(io.Writer) error {
return s.writeLogStreamErr
}
func (s *stubActivationRespClient) GetKubeconfig() (string, error) {
return "", s.getKubeconfigErr
}
func (s *stubActivationRespClient) GetCoordinatorVpnKey() (string, error) {
return "", s.getCoordinatorVpnKeyErr
}
func (s *stubActivationRespClient) GetClientVpnIp() (string, error) {
return "", s.getClientVpnIpErr
}
func (s *stubActivationRespClient) GetOwnerID() (string, error) {
return "", s.getOwnerIDErr
}
func (s *stubActivationRespClient) GetClusterID() (string, error) {
return "", s.getClusterIDErr
}
type fakeProtoClient struct {
conn bool
respClient proto.ActivationResponseClient
}
func (c *fakeProtoClient) Connect(ip string, port string) error {
c.conn = true
return nil
}
func (c *fakeProtoClient) Close() error {
c.conn = false
return nil
}
func (c *fakeProtoClient) Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints []string, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error) {
if !c.conn {
return nil, errors.New("client is not connected")
}
return c.respClient, nil
}
type fakeActivationRespClient struct {
responses []fakeActivationRespMessage
kubeconfig string
coordinatorVpnKey string
clientVpnIp string
ownerID string
clusterID string
}
func (c *fakeActivationRespClient) NextLog() (string, error) {
for len(c.responses) > 0 {
resp := c.responses[0]
c.responses = c.responses[1:]
if len(resp.log) > 0 {
return resp.log, nil
}
c.kubeconfig = resp.kubeconfig
c.coordinatorVpnKey = resp.coordinatorVpnKey
c.clientVpnIp = resp.clientVpnIp
c.ownerID = resp.ownerID
c.clusterID = resp.clusterID
}
return "", io.EOF
}
func (c *fakeActivationRespClient) WriteLogStream(w io.Writer) error {
log, err := c.NextLog()
for err == nil {
fmt.Fprint(w, log)
log, err = c.NextLog()
}
if !errors.Is(err, io.EOF) {
return err
}
return nil
}
func (c *fakeActivationRespClient) GetKubeconfig() (string, error) {
if c.kubeconfig == "" {
return "", errors.New("kubeconfig is empty")
}
return c.kubeconfig, nil
}
func (c *fakeActivationRespClient) GetCoordinatorVpnKey() (string, error) {
if c.coordinatorVpnKey == "" {
return "", errors.New("coordinator public VPN key is empty")
}
return c.coordinatorVpnKey, nil
}
func (c *fakeActivationRespClient) GetClientVpnIp() (string, error) {
if c.clientVpnIp == "" {
return "", errors.New("client VPN IP is empty")
}
return c.clientVpnIp, nil
}
func (c *fakeActivationRespClient) GetOwnerID() (string, error) {
if c.ownerID == "" {
return "", errors.New("init secret is empty")
}
return c.ownerID, nil
}
func (c *fakeActivationRespClient) GetClusterID() (string, error) {
if c.clusterID == "" {
return "", errors.New("cluster identifier is empty")
}
return c.clusterID, nil
}
type fakeActivationRespMessage struct {
log string
kubeconfig string
coordinatorVpnKey string
clientVpnIp string
ownerID string
clusterID string
}

60
cli/cmd/rollback.go Normal file
View file

@ -0,0 +1,60 @@
package cmd
import (
"context"
"fmt"
"io"
"go.uber.org/multierr"
)
// rollbacker does a rollback.
type rollbacker interface {
rollback(ctx context.Context) error
}
// rollbackOnError calls rollback on the rollbacker if the handed error is not nil,
// and writes logs to the writer w.
func rollbackOnError(ctx context.Context, w io.Writer, onErr *error, roll rollbacker) {
if *onErr == nil {
return
}
fmt.Fprintf(w, "An error occurred: %s\n", *onErr)
fmt.Fprintln(w, "Attempting to roll back.")
if err := roll.rollback(ctx); err != nil {
*onErr = multierr.Append(*onErr, fmt.Errorf("on rollback: %w", err)) // TODO: print the error, or retrun it?
return
}
fmt.Fprintln(w, "Rollback succeeded.")
}
type rollbackerGCP struct {
client gcpclient
}
func (r *rollbackerGCP) rollback(ctx context.Context) error {
var err error
err = multierr.Append(err, r.client.TerminateInstances(ctx))
err = multierr.Append(err, r.client.TerminateFirewall(ctx))
err = multierr.Append(err, r.client.TerminateVPCs(ctx))
return err
}
type rollbackerAzure struct {
client azureclient
}
func (r *rollbackerAzure) rollback(ctx context.Context) error {
return r.client.TerminateResourceGroup(ctx)
}
type rollbackerAWS struct {
client ec2client
}
func (r *rollbackerAWS) rollback(ctx context.Context) error {
var err error
err = multierr.Append(err, r.client.TerminateInstances(ctx))
err = multierr.Append(err, r.client.DeleteSecurityGroup(ctx))
return err
}

64
cli/cmd/root.go Normal file
View file

@ -0,0 +1,64 @@
package cmd
import (
"context"
"fmt"
"os"
"os/signal"
"github.com/spf13/cobra"
)
var rootCmd = &cobra.Command{
Use: "constellation",
Short: "Set up your Constellation cluster.",
Long: "Set up your Constellation cluster.",
SilenceUsage: true,
}
// Execute starts the CLI.
func Execute() error {
ctx, cancel := signalContext(context.Background(), os.Interrupt)
defer cancel()
return rootCmd.ExecuteContext(ctx)
}
// signalContext returns a context that is canceled on the handed signal.
// The signal isn't watched after its first occurrence. Call the cancel
// function to ensure the internal goroutine is stopped and the signal isn't
// watched any longer.
func signalContext(ctx context.Context, sig os.Signal) (context.Context, context.CancelFunc) {
sigCtx, stop := signal.NotifyContext(ctx, sig)
done := make(chan struct{}, 1)
stopDone := make(chan struct{}, 1)
go func() {
defer func() { stopDone <- struct{}{} }()
defer stop()
select {
case <-sigCtx.Done():
fmt.Println(" Signal caught. Press ctrl+c again to terminate the program immediately.")
case <-done:
}
}()
cancelFunc := func() {
done <- struct{}{}
<-stopDone
}
return sigCtx, cancelFunc
}
func init() {
// Set output of cmd.Print to stdout.
rootCmd.SetOut(os.Stdout)
// Disable --no-description flag for completion command.
rootCmd.CompletionOptions.DisableNoDescFlag = true
rootCmd.PersistentFlags().String("dev-config", "", "Set this flag to create the Constellation using settings from a development config.")
rootCmd.AddCommand(newVersionCmd())
rootCmd.AddCommand(newCreateCmd())
rootCmd.AddCommand(newInitCmd())
rootCmd.AddCommand(newTerminateCmd())
rootCmd.AddCommand(newVerifyCmd())
}

View file

@ -0,0 +1,95 @@
package cmd
import (
"context"
"fmt"
azurecl "github.com/edgelesssys/constellation/cli/azure/client"
"github.com/edgelesssys/constellation/cli/cloudprovider"
ec2cl "github.com/edgelesssys/constellation/cli/ec2/client"
gcpcl "github.com/edgelesssys/constellation/cli/gcp/client"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
)
type serviceAccountCreator interface {
createServiceAccount(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error)
}
type serviceAccountClient struct{}
// createServiceAccount creates a new cloud provider service account with access to the created resources.
func (c serviceAccountClient) createServiceAccount(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
switch stat.CloudProvider {
case cloudprovider.AWS.String():
// TODO: implement
ec2client, err := ec2cl.NewFromDefault(ctx)
if err != nil {
return "", state.ConstellationState{}, err
}
return c.createServiceAccountEC2(ctx, ec2client, stat, config)
case cloudprovider.GCP.String():
gcpclient, err := gcpcl.NewFromDefault(ctx)
if err != nil {
return "", state.ConstellationState{}, err
}
serviceAccount, stat, err := c.createServiceAccountGCP(ctx, gcpclient, stat, config)
if err != nil {
return "", state.ConstellationState{}, err
}
return serviceAccount, stat, gcpclient.Close()
case cloudprovider.Azure.String():
azureclient, err := azurecl.NewFromDefault(stat.AzureSubscription, stat.AzureTenant)
if err != nil {
return "", state.ConstellationState{}, err
}
return c.createServiceAccountAzure(ctx, azureclient, stat)
}
return "", state.ConstellationState{}, fmt.Errorf("unknown cloud provider %v", stat.CloudProvider)
}
func (c serviceAccountClient) createServiceAccountAzure(ctx context.Context, cl azureclient, stat state.ConstellationState) (string, state.ConstellationState, error) {
if err := cl.SetState(stat); err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to set state while creating service account: %w", err)
}
serviceAccount, err := cl.CreateServicePrincipal(ctx)
if err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to create service account: %w", err)
}
stat, err = cl.GetState()
if err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to get state after creating service account: %w", err)
}
return serviceAccount, stat, nil
}
func (c serviceAccountClient) createServiceAccountGCP(ctx context.Context, cl gcpclient, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
if err := cl.SetState(stat); err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to set state while creating service account: %w", err)
}
input := gcpcl.ServiceAccountInput{
Roles: *config.Provider.GCP.ServiceAccountRoles,
}
serviceAccount, err := cl.CreateServiceAccount(ctx, input)
if err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to create service account: %w", err)
}
stat, err = cl.GetState()
if err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to get state after creating service account: %w", err)
}
return serviceAccount, stat, nil
}
//nolint:unparam
func (c serviceAccountClient) createServiceAccountEC2(ctx context.Context, cl ec2client, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
// TODO: implement
if err := cl.SetState(stat); err != nil {
return "", state.ConstellationState{}, fmt.Errorf("failed to set state while creating service account: %w", err)
}
return "", stat, nil
}

View file

@ -0,0 +1,136 @@
package cmd
import (
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
"github.com/stretchr/testify/assert"
)
func TestCreateServiceAccountAzure(t *testing.T) {
testState := state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
}
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client azureclient
errExpected bool
}{
"create service account works": {
existingState: testState,
client: &fakeAzureClient{},
},
"fail setState": {
existingState: testState,
client: &stubAzureClient{setStateErr: someErr},
errExpected: true,
},
"fail create": {
existingState: testState,
client: &stubAzureClient{createServicePrincipalErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := serviceAccountClient{}
serviceAccount, _, err := client.createServiceAccountAzure(context.Background(), tc.client, tc.existingState)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotNil(serviceAccount)
stat, err := tc.client.GetState()
assert.NoError(err)
assert.Equal(state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureADAppObjectID: "00000000-0000-0000-0000-000000000001",
}, stat)
}
})
}
}
func TestCreateServiceAccountGCP(t *testing.T) {
testState := state.ConstellationState{
GCPProject: "project",
GCPNodes: gcp.Instances{},
GCPCoordinators: gcp.Instances{},
GCPNodeInstanceGroup: "nodes-group",
GCPCoordinatorInstanceGroup: "coordinator-group",
GCPNodeInstanceTemplate: "template",
GCPCoordinatorInstanceTemplate: "template",
GCPNetwork: "network",
GCPFirewalls: []string{},
}
config := config.Default()
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client gcpclient
errExpected bool
}{
"create service account works": {
existingState: testState,
client: &fakeGcpClient{},
},
"fail setState": {
existingState: testState,
client: &stubGcpClient{setStateErr: someErr},
errExpected: true,
},
"fail create": {
existingState: testState,
client: &stubGcpClient{createServiceAccountErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := serviceAccountClient{}
serviceAccount, _, err := client.createServiceAccountGCP(context.Background(), tc.client, tc.existingState, config)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotNil(serviceAccount)
stat, err := tc.client.GetState()
assert.NoError(err)
assert.Equal(state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPProject: "project",
GCPNodes: gcp.Instances{},
GCPCoordinators: gcp.Instances{},
GCPNodeInstanceGroup: "nodes-group",
GCPCoordinatorInstanceGroup: "coordinator-group",
GCPNodeInstanceTemplate: "template",
GCPCoordinatorInstanceTemplate: "template",
GCPNetwork: "network",
GCPFirewalls: []string{},
GCPServiceAccount: "service-account@project.iam.gserviceaccount.com",
}, stat)
}
})
}
}
type stubServiceAccountCreator struct {
cloudServiceAccountURI string
createErr error
}
func (c *stubServiceAccountCreator) createServiceAccount(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
return c.cloudServiceAccountURI, stat, c.createErr
}

11
cli/cmd/statuswaiter.go Normal file
View file

@ -0,0 +1,11 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/coordinator/state"
)
type statusWaiter interface {
WaitForAll(ctx context.Context, status state.State, endpoints []string) error
}

View file

@ -0,0 +1,15 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/coordinator/state"
)
type stubStatusWaiter struct {
waitForAllErr error
}
func (w stubStatusWaiter) WaitForAll(ctx context.Context, status state.State, endpoints []string) error {
return w.waitForAllErr
}

131
cli/cmd/terminate.go Normal file
View file

@ -0,0 +1,131 @@
package cmd
import (
"errors"
"fmt"
"io/fs"
"github.com/spf13/afero"
"github.com/spf13/cobra"
azure "github.com/edgelesssys/constellation/cli/azure/client"
ec2 "github.com/edgelesssys/constellation/cli/ec2/client"
"github.com/edgelesssys/constellation/cli/file"
gcp "github.com/edgelesssys/constellation/cli/gcp/client"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
)
func newTerminateCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "terminate",
Short: "Terminate an existing Constellation.",
Long: "Terminate an existing Constellation. The Constellation can't be started again, and all persistent storage will be lost.",
Args: cobra.NoArgs,
RunE: runTerminate,
}
return cmd
}
// runTerminate runs the terminate command.
func runTerminate(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
return terminate(cmd, fileHandler, config)
}
func terminate(cmd *cobra.Command, fileHandler file.Handler, config *config.Config) error {
var stat state.ConstellationState
if err := fileHandler.ReadJSON(*config.StatePath, &stat); err != nil {
return err
}
if len(stat.EC2Instances) != 0 || stat.EC2SecurityGroup != "" {
ec2client, err := ec2.NewFromDefault(cmd.Context())
if err != nil {
return err
}
if err := terminateEC2(cmd, ec2client, stat); err != nil {
return err
}
}
// TODO: improve check, also look for other resources that might need to be terminated
if len(stat.GCPNodes) != 0 {
gcpclient, err := gcp.NewFromDefault(cmd.Context())
if err != nil {
return err
}
if err := terminateGCP(cmd, gcpclient, stat); err != nil {
return err
}
}
if len(stat.AzureResourceGroup) != 0 {
azureclient, err := azure.NewFromDefault(stat.AzureSubscription, stat.AzureTenant)
if err != nil {
return err
}
if err := terminateAzure(cmd, azureclient, stat); err != nil {
return err
}
}
cmd.Println("Your Constellation was terminated successfully.")
if err := fileHandler.Remove(*config.StatePath); err != nil {
return fmt.Errorf("failed to remove file '%s', please remove manually", *config.StatePath)
}
if err := fileHandler.Remove(*config.AdminConfPath); err != nil && !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("failed to remove file '%s', please remove manually", *config.AdminConfPath)
}
return nil
}
func terminateAzure(cmd *cobra.Command, cl azureclient, stat state.ConstellationState) error {
if err := cl.SetState(stat); err != nil {
return fmt.Errorf("failed to terminate the Constellation: %w", err)
}
if err := cl.TerminateServicePrincipal(cmd.Context()); err != nil {
return err
}
return cl.TerminateResourceGroup(cmd.Context())
}
func terminateGCP(cmd *cobra.Command, cl gcpclient, stat state.ConstellationState) error {
if err := cl.SetState(stat); err != nil {
return fmt.Errorf("failed to terminate the Constellation: %w", err)
}
if err := cl.TerminateInstances(cmd.Context()); err != nil {
return err
}
if err := cl.TerminateFirewall(cmd.Context()); err != nil {
return err
}
if err := cl.TerminateVPCs(cmd.Context()); err != nil {
return err
}
return cl.TerminateServiceAccount(cmd.Context())
}
// terminateEC2 and remove the existing Constellation form the state file.
func terminateEC2(cmd *cobra.Command, cl ec2client, stat state.ConstellationState) error {
if err := cl.SetState(stat); err != nil {
return fmt.Errorf("failed to terminate the Constellation: %w", err)
}
if err := cl.TerminateInstances(cmd.Context()); err != nil {
return fmt.Errorf("failed to terminate the Constellation: %w", err)
}
return cl.DeleteSecurityGroup(cmd.Context())
}

288
cli/cmd/terminate_test.go Normal file
View file

@ -0,0 +1,288 @@
package cmd
import (
"bytes"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/internal/state"
"github.com/stretchr/testify/assert"
)
func TestTerminateCmdArgumentValidation(t *testing.T) {
testCases := map[string]struct {
args []string
expectErr bool
}{
"no args": {[]string{}, false},
"some args": {[]string{"hello", "test"}, true},
"some other args": {[]string{"12", "2"}, true},
}
cmd := newTerminateCmd()
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := cmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestTerminateEC2(t *testing.T) {
testState := state.ConstellationState{
CloudProvider: cloudprovider.AWS.String(),
EC2Instances: ec2.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-3": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
EC2SecurityGroup: "sg-test",
}
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client ec2client
errExpected bool
}{
"terminate existing instances": {
existingState: testState,
client: &fakeEc2Client{},
errExpected: false,
},
"state without instances": {
existingState: state.ConstellationState{
CloudProvider: cloudprovider.AWS.String(),
EC2Instances: ec2.Instances{},
},
client: &fakeEc2Client{},
errExpected: true,
},
"fail TerminateInstances": {
existingState: testState,
client: &stubEc2Client{terminateInstancesErr: someErr},
errExpected: true,
},
"fail DeleteSecurityGroup": {
existingState: testState,
client: &stubEc2Client{deleteSecurityGroupErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := newTerminateCmd()
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
err := terminateEC2(cmd, tc.client, tc.existingState)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestTerminateGCP(t *testing.T) {
testState := state.ConstellationState{
GCPNodes: gcp.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
GCPCoordinators: gcp.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
GCPNodeInstanceGroup: "nodes-group",
GCPCoordinatorInstanceGroup: "coordinator-group",
GCPNodeInstanceTemplate: "template",
GCPCoordinatorInstanceTemplate: "template",
GCPNetwork: "network",
GCPFirewalls: []string{"coordinator", "wireguard", "ssh"},
}
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client gcpclient
errExpected bool
}{
"terminate existing instances": {
existingState: testState,
client: &fakeGcpClient{},
},
"state without instances": {
existingState: state.ConstellationState{EC2Instances: ec2.Instances{}},
client: &fakeGcpClient{},
},
"state not found": {
existingState: testState,
client: &fakeGcpClient{},
},
"fail setState": {
existingState: testState,
client: &stubGcpClient{setStateErr: someErr},
errExpected: true,
},
"fail terminateFirewall": {
existingState: testState,
client: &stubGcpClient{terminateFirewallErr: someErr},
errExpected: true,
},
"fail terminateVPC": {
existingState: testState,
client: &stubGcpClient{terminateVPCsErr: someErr},
errExpected: true,
},
"fail terminateInstances": {
existingState: testState,
client: &stubGcpClient{terminateInstancesErr: someErr},
errExpected: true,
},
"fail terminateServiceAccount": {
existingState: testState,
client: &stubGcpClient{terminateServiceAccountErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := newTerminateCmd()
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
err := terminateGCP(cmd, tc.client, tc.existingState)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
stat, err := tc.client.GetState()
assert.NoError(err)
assert.Equal(state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
}, stat)
}
})
}
}
func TestTerminateAzure(t *testing.T) {
testState := state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: azure.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureCoordinators: azure.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
},
AzureResourceGroup: "test",
}
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client azureclient
errExpected bool
}{
"terminate existing instances": {
existingState: testState,
client: &fakeAzureClient{},
},
"state resource group": {
existingState: state.ConstellationState{AzureResourceGroup: ""},
client: &fakeAzureClient{},
},
"state not found": {
existingState: testState,
client: &fakeAzureClient{},
},
"fail setState": {
existingState: testState,
client: &stubAzureClient{setStateErr: someErr},
errExpected: true,
},
"fail resource group termination": {
existingState: testState,
client: &stubAzureClient{terminateResourceGroupErr: someErr},
errExpected: true,
},
"fail service principal termination": {
existingState: testState,
client: &stubAzureClient{terminateServicePrincipalErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := newTerminateCmd()
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
err := terminateAzure(cmd, tc.client, tc.existingState)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
stat, err := tc.client.GetState()
assert.NoError(err)
assert.Equal(state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
}, stat)
}
})
}
}

View file

@ -0,0 +1,80 @@
package cmd
import (
"bufio"
"errors"
"fmt"
"strings"
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/spf13/cobra"
)
var (
// ErrInvalidInput is an error where user entered invalid input.
ErrInvalidInput = errors.New("user made invalid input")
warningStr = "Warning: not verifying the Constellation's %s measurements\n"
)
// askToConfirm asks user to confirm an action.
// The user will be asked the handed question and can answer with
// yes or no.
func askToConfirm(cmd *cobra.Command, question string) (bool, error) {
reader := bufio.NewReader(cmd.InOrStdin())
cmd.Printf("%s [y/n]: ", question)
for i := 0; i < 3; i++ {
resp, err := reader.ReadString('\n')
if err != nil {
return false, err
}
resp = strings.ToLower(strings.TrimSpace(resp))
if resp == "n" || resp == "no" {
return false, nil
}
if resp == "y" || resp == "yes" {
return true, nil
}
cmd.Printf("Type 'y' or 'yes' to confirm, or abort action with 'n' or 'no': ")
}
return false, ErrInvalidInput
}
// warnAboutPCRs displays warnings if specifc PCR values are not verified.
//
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
func warnAboutPCRs(cmd *cobra.Command, pcrs map[uint32][]byte, checkInit bool) error {
for k, v := range pcrs {
if len(v) != 32 {
return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v))
}
}
if pcrs[0] == nil || pcrs[1] == nil {
cmd.PrintErrf(warningStr, "BIOS")
}
if pcrs[2] == nil || pcrs[3] == nil {
cmd.PrintErrf(warningStr, "OPROM")
}
if pcrs[4] == nil || pcrs[5] == nil {
cmd.PrintErrf(warningStr, "MBR")
}
// GRUB measures kernel command line and initrd into pcrs 8 and 9
if pcrs[8] == nil {
cmd.PrintErrf(warningStr, "kernel command line")
}
if pcrs[9] == nil {
cmd.PrintErrf(warningStr, "initrd")
}
// Only warn about initialization PCRs if necessary
if checkInit {
if pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
cmd.PrintErrf(warningStr, "initialization status")
}
}
return nil
}

View file

@ -0,0 +1,249 @@
package cmd
import (
"bytes"
"errors"
"io"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
func TestAskToConfirm(t *testing.T) {
// errAborted is an error where the user aborted the action.
errAborted := errors.New("user aborted")
cmd := &cobra.Command{
Use: "test",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
ok, err := askToConfirm(cmd, "777")
if err != nil {
return err
}
if !ok {
return errAborted
}
return nil
},
}
testCases := map[string]struct {
input string
expectedErr error
}{
"user confirms": {"y\n", nil},
"user confirms long": {"yes\n", nil},
"user disagrees": {"n\n", errAborted},
"user disagrees long": {"no\n", errAborted},
"user is first unsure, but agrees": {"what?\ny\n", nil},
"user is first unsure, but disagrees": {"wait.\nn\n", errAborted},
"repeated invalid input": {"h\nb\nq\n", ErrInvalidInput},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
out := bytes.NewBufferString("")
cmd.SetOut(out)
errOut := bytes.NewBufferString("")
cmd.SetErr(errOut)
in := bytes.NewBufferString(tc.input)
cmd.SetIn(in)
err := cmd.Execute()
assert.ErrorIs(err, tc.expectedErr)
output, err := io.ReadAll(out)
assert.NoError(err)
assert.Contains(string(output), "777")
})
}
}
func TestWarnAboutPCRs(t *testing.T) {
zero := []byte("00000000000000000000000000000000")
testCases := map[string]struct {
pcrs map[uint32][]byte
dontWarnInit bool
expectedWarnings []string
errExpected bool
}{
"no warnings": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
6: zero,
7: zero,
8: zero,
9: zero,
10: zero,
11: zero,
12: zero,
},
},
"no warnings for missing non critical values": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
12: zero,
},
},
"warn for BIOS": {
pcrs: map[uint32][]byte{
0: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
12: zero,
},
expectedWarnings: []string{"BIOS"},
},
"warn for OPROM": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
12: zero,
},
expectedWarnings: []string{"OPROM"},
},
"warn for MBR": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
5: zero,
8: zero,
9: zero,
11: zero,
12: zero,
},
expectedWarnings: []string{"MBR"},
},
"warn for kernel": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
9: zero,
11: zero,
12: zero,
},
expectedWarnings: []string{"kernel"},
},
"warn for initrd": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
11: zero,
12: zero,
},
expectedWarnings: []string{"initrd"},
},
"warn for initialization": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
},
dontWarnInit: false,
expectedWarnings: []string{"initialization"},
},
"don't warn for initialization": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
},
dontWarnInit: true,
},
"multi warning": {
pcrs: map[uint32][]byte{},
expectedWarnings: []string{
"BIOS",
"OPROM",
"MBR",
"initialization",
"initrd",
"kernel",
},
},
"bad config": {
pcrs: map[uint32][]byte{
0: []byte("000"),
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := newInitCmd()
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
err := warnAboutPCRs(cmd, tc.pcrs, !tc.dontWarnInit)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
if len(tc.expectedWarnings) == 0 {
assert.Empty(errOut.String())
} else {
for _, warning := range tc.expectedWarnings {
assert.Contains(errOut.String(), warning)
}
}
}
})
}
}

70
cli/cmd/validargs.go Normal file
View file

@ -0,0 +1,70 @@
package cmd
import (
"fmt"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/spf13/cobra"
)
// isIntArg checks if argument at position arg is an integer.
func isIntArg(arg int) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
if _, err := strconv.Atoi(args[arg]); err != nil {
return fmt.Errorf("argument %d must be an integer", arg)
}
return nil
}
}
// isIntGreaterArg checks if argument at position arg is and integer and greater i.
func isIntGreaterArg(arg int, i int) cobra.PositionalArgs {
return cobra.MatchAll(isIntArg(arg), func(cmd *cobra.Command, args []string) error {
if v, _ := strconv.Atoi(args[arg]); v <= i {
return fmt.Errorf("argument %d must be greater %d, but it's %d", arg, i, v)
}
return nil
})
}
// isIntGreaterZeroArg checks if argument at position arg is a positive non zero integer.
func isIntGreaterZeroArg(arg int) cobra.PositionalArgs {
return cobra.MatchAll(isIntGreaterArg(arg, 0))
}
// isEC2InstanceType checks if argument at position arg is a key in m.
// The argument will always be converted to lower case letters.
func isEC2InstanceType(arg int) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
if _, ok := ec2.InstanceTypes[strings.ToLower(args[arg])]; !ok {
return fmt.Errorf("'%s' isn't an AWS EC2 instance type", args[arg])
}
return nil
}
}
func isGCPInstanceType(arg int) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
for _, instanceType := range gcp.InstanceTypes {
if args[arg] == instanceType {
return nil
}
}
return fmt.Errorf("argument %s isn't a valid GCP instance type", args[arg])
}
}
func isAzureInstanceType(arg int) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
for _, instanceType := range azure.InstanceTypes {
if args[arg] == instanceType {
return nil
}
}
return fmt.Errorf("argument %s isn't a valid Azure instance type", args[arg])
}
}

197
cli/cmd/validargs_test.go Normal file
View file

@ -0,0 +1,197 @@
package cmd
import (
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
func TestIsIntArg(t *testing.T) {
testCmd := &cobra.Command{
Use: "test",
Args: isIntArg(0),
Run: func(cmd *cobra.Command, args []string) {},
}
testCases := map[string]struct {
args []string
expectErr bool
}{
"valid int 1": {[]string{"1"}, false},
"valid int 2": {[]string{"42"}, false},
"valid int 3": {[]string{"987987498"}, false},
"valid int and other args": {[]string{"3", "hello"}, false},
"valid int and other args 2": {[]string{"3", "4"}, false},
"invalid 1": {[]string{"hello world"}, true},
"invalid 2": {[]string{"98798d749f8"}, true},
"invalid 3": {[]string{"three"}, true},
"invalid 4": {[]string{"0.3"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := testCmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestIsIntGreaterArg(t *testing.T) {
testCmd := &cobra.Command{
Use: "test",
Args: isIntGreaterArg(0, 12),
Run: func(cmd *cobra.Command, args []string) {},
}
testCases := map[string]struct {
args []string
expectErr bool
}{
"valid int 1": {[]string{"13"}, false},
"valid int 2": {[]string{"42"}, false},
"valid int 3": {[]string{"987987498"}, false},
"invalid int 1": {[]string{"1"}, true},
"invalid int and other args": {[]string{"3", "hello"}, true},
"invalid int and other args 2": {[]string{"-14", "4"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := testCmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestIsIntGreaterZeroArg(t *testing.T) {
testCmd := &cobra.Command{
Use: "test",
Args: isIntGreaterZeroArg(0),
Run: func(cmd *cobra.Command, args []string) {},
}
testCases := map[string]struct {
args []string
expectErr bool
}{
"valid int 1": {[]string{"13"}, false},
"valid int 2": {[]string{"42"}, false},
"valid int 3": {[]string{"987987498"}, false},
"invalid": {[]string{"0"}, true},
"invalid int 1": {[]string{"-42", "hello"}, true},
"invalid int and other args": {[]string{"-9487239847", "4"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := testCmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestIsEC2InstanceType(t *testing.T) {
testCmd := &cobra.Command{
Use: "test",
Args: isEC2InstanceType(0),
Run: func(cmd *cobra.Command, args []string) {},
}
testCases := map[string]struct {
args []string
expectErr bool
}{
"is instance type 1": {[]string{"4xl"}, false},
"is instance type 2": {[]string{"12xlarge", "something else"}, false},
"isn't instance type 1": {[]string{"notanInstanceType"}, true},
"isn't instance type 2": {[]string{"Hello World!"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := testCmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestIsGCPInstanceType(t *testing.T) {
testCmd := &cobra.Command{
Use: "test",
Args: isGCPInstanceType(0),
Run: func(cmd *cobra.Command, args []string) {},
}
testCases := map[string]struct {
args []string
expectErr bool
}{
"is instance type 1": {[]string{"n2d-standard-4"}, false},
"is instance type 2": {[]string{"n2d-standard-16", "something else"}, false},
"isn't instance type 1": {[]string{"notanInstanceType"}, true},
"isn't instance type 2": {[]string{"Hello World!"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := testCmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestIsAzureInstanceType(t *testing.T) {
testCmd := &cobra.Command{
Use: "test",
Args: isAzureInstanceType(0),
Run: func(cmd *cobra.Command, args []string) {},
}
testCases := map[string]struct {
args []string
expectErr bool
}{
"is instance type 1": {[]string{"Standard_DC2as_v5"}, false},
"is instance type 2": {[]string{"Standard_DC8as_v5", "something else"}, false},
"isn't instance type 1": {[]string{"notanInstanceType"}, true},
"isn't instance type 2": {[]string{"Hello World!"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := testCmd.ValidateArgs(tc.args)
if tc.expectErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

141
cli/cmd/verify.go Normal file
View file

@ -0,0 +1,141 @@
package cmd
import (
"context"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"github.com/edgelesssys/constellation/cli/status"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
rpcStatus "google.golang.org/grpc/status"
)
func newVerifyCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "verify azure|gcp",
Short: "Verify the confidential properties of your Constellation.",
Long: "Verify the confidential properties of your Constellation.",
}
cmd.PersistentFlags().String("owner-id", "", "verify the Constellation using the owner identity derived from the master secret.")
cmd.PersistentFlags().String("unique-id", "", "verify the Constellation using the unique cluster identity.")
cmd.AddCommand(newVerifyGCPCmd())
cmd.AddCommand(newVerifyAzureCmd())
cmd.AddCommand(newVerifyGCPNonCVMCmd())
return cmd
}
func runVerify(cmd *cobra.Command, args []string, pcrs map[uint32][]byte, validator atls.Validator) error {
if err := warnAboutPCRs(cmd, pcrs, false); err != nil {
return err
}
verifier := verifier{
newConn: newVerifiedConn,
newClient: pubproto.NewAPIClient,
}
return verify(cmd.Context(), cmd.OutOrStdout(), net.JoinHostPort(args[0], args[1]), []atls.Validator{validator}, verifier)
}
func verify(ctx context.Context, w io.Writer, target string, validators []atls.Validator, verifier verifier) error {
conn, err := verifier.newConn(ctx, target, validators)
if err != nil {
return err
}
defer conn.Close()
client := verifier.newClient(conn)
if _, err := client.GetState(ctx, &pubproto.GetStateRequest{}); err != nil {
if err, ok := rpcStatus.FromError(err); ok {
return fmt.Errorf("unable to verify Constellation cluster: %s", err.Message())
}
return err
}
fmt.Fprintln(w, "OK")
return nil
}
// prepareValidator parses parameters and updates the PCR map.
func prepareValidator(cmd *cobra.Command, pcrs map[uint32][]byte) error {
ownerID, err := cmd.Flags().GetString("owner-id")
if err != nil {
return err
}
clusterID, err := cmd.Flags().GetString("unique-id")
if err != nil {
return err
}
if ownerID == "" && clusterID == "" {
return errors.New("neither owner identity nor unique identity provided to verify the Constellation")
}
return updatePCRMap(pcrs, ownerID, clusterID)
}
func updatePCRMap(pcrs map[uint32][]byte, ownerID, clusterID string) error {
if err := addOrSkipPCR(pcrs, uint32(vtpm.PCRIndexOwnerID), ownerID); err != nil {
return err
}
return addOrSkipPCR(pcrs, uint32(vtpm.PCRIndexClusterID), clusterID)
}
// addOrSkipPCR adds a new entry to the map, or removes the key if the input is an empty string.
//
// When adding, the input is first decoded from base64.
// We then calculate the expected PCR by hashing the input using SHA256,
// appending expected PCR for initialization, and then hashing once more.
func addOrSkipPCR(toAdd map[uint32][]byte, pcrIndex uint32, encoded string) error {
if encoded == "" {
delete(toAdd, pcrIndex)
return nil
}
decoded, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return fmt.Errorf("input [%s] is not base64 encoded: %w", encoded, err)
}
// new_pcr_value := hash(old_pcr_value || data_to_extend)
// Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input
hashedInput := sha256.Sum256(decoded)
expectedPcr := sha256.Sum256(append(toAdd[pcrIndex], hashedInput[:]...))
toAdd[pcrIndex] = expectedPcr[:]
return nil
}
type verifier struct {
newConn func(context.Context, string, []atls.Validator) (status.ClientConn, error)
newClient func(cc grpc.ClientConnInterface) pubproto.APIClient
}
// newVerifiedConn creates a grpc over aTLS connection to the target, using the provided PCR values to verify the server.
func newVerifiedConn(ctx context.Context, target string, validators []atls.Validator) (status.ClientConn, error) {
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
if err != nil {
return nil, err
}
return grpc.DialContext(
ctx, target, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
)
}
// verifyCompletion handels the completion of CLI arguments. It is frequently called
// while the user types arguments of the command to suggest completion.
func verifyCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0, 1:
return []string{}, cobra.ShellCompDirectiveNoFileComp
default:
return []string{}, cobra.ShellCompDirectiveError
}
}

51
cli/cmd/verify_azure.go Normal file
View file

@ -0,0 +1,51 @@
package cmd
import (
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
func newVerifyAzureCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "azure IP PORT",
Short: "Verify the confidential properties of your Constellation on Azure.",
Long: "Verify the confidential properties of your Constellation on Azure.",
Args: cobra.ExactArgs(2),
ValidArgsFunction: verifyCompletion,
RunE: runVerifyAzure,
}
return cmd
}
func runVerifyAzure(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
validators, err := getAzureValidator(cmd, *config.Provider.GCP.PCRs)
if err != nil {
return err
}
return runVerify(cmd, args, *config.Provider.GCP.PCRs, validators)
}
// getAzureValidator returns an Azure validator.
func getAzureValidator(cmd *cobra.Command, pcrs map[uint32][]byte) (atls.Validator, error) {
if err := prepareValidator(cmd, pcrs); err != nil {
return nil, err
}
return azure.NewValidator(pcrs), nil
}

View file

@ -0,0 +1,66 @@
package cmd
import (
"bytes"
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetAzureValidator(t *testing.T) {
testCases := map[string]struct {
ownerID string
clusterID string
errExpected bool
}{
"no input": {
ownerID: "",
clusterID: "",
errExpected: true,
},
"unencoded secret ID": {
ownerID: "owner-id",
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: true,
},
"unencoded cluster ID": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: "unique-id",
errExpected: true,
},
"correct input": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newVerifyAzureCmd()
cmd.Flags().String("owner-id", "", "")
cmd.Flags().String("unique-id", "", "")
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
_, err := getAzureValidator(cmd, map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
})
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

51
cli/cmd/verify_gcp.go Normal file
View file

@ -0,0 +1,51 @@
package cmd
import (
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
func newVerifyGCPCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "gcp IP PORT",
Short: "Verify the confidential properties of your Constellation on Google Cloud Platform.",
Long: "Verify the confidential properties of your Constellation on Google Cloud Platform.",
Args: cobra.ExactArgs(2),
ValidArgsFunction: verifyCompletion,
RunE: runVerifyGCP,
}
return cmd
}
func runVerifyGCP(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
validators, err := getGCPValidator(cmd, *config.Provider.GCP.PCRs)
if err != nil {
return err
}
return runVerify(cmd, args, *config.Provider.GCP.PCRs, validators)
}
// getValidators returns a GCP validator.
func getGCPValidator(cmd *cobra.Command, pcrs map[uint32][]byte) (atls.Validator, error) {
if err := prepareValidator(cmd, pcrs); err != nil {
return nil, err
}
return gcp.NewValidator(pcrs), nil
}

View file

@ -0,0 +1,40 @@
package cmd
import (
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/spf13/cobra"
)
// TODO: Remove this command once we no longer use non cvms.
func newVerifyGCPNonCVMCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "gcp-non-cvm IP PORT",
Short: "Verify the TPM attestation of your shielded VM Constellation on Google Cloud Platform.",
Long: "Verify the TPM attestation of your shielded VM Constellation on Google Cloud Platform.",
Args: cobra.ExactArgs(2),
ValidArgsFunction: verifyCompletion,
RunE: runVerifyGCPNonCVM,
}
return cmd
}
func runVerifyGCPNonCVM(cmd *cobra.Command, args []string) error {
pcrs := map[uint32][]byte{}
validator, err := getGCPNonCVMValidator(cmd, pcrs)
if err != nil {
return err
}
return runVerify(cmd, args, pcrs, validator)
}
// getGCPNonCVMValidator returns a GCP validator for regular shielded VMs.
func getGCPNonCVMValidator(cmd *cobra.Command, pcrs map[uint32][]byte) (atls.Validator, error) {
if err := prepareValidator(cmd, pcrs); err != nil {
return nil, err
}
return gcp.NewNonCVMValidator(pcrs), nil
}

View file

@ -0,0 +1,63 @@
package cmd
import (
"bytes"
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetGCPNonCVMValidator(t *testing.T) {
testCases := map[string]struct {
ownerID string
clusterID string
errExpected bool
}{
"no input": {
ownerID: "",
clusterID: "",
errExpected: true,
},
"unencoded secret ID": {
ownerID: "owner-id",
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: true,
},
"unencoded cluster ID": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: "unique-id",
errExpected: true,
},
"correct input": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newVerifyGCPNonCVMCmd()
cmd.Flags().String("owner-id", "", "")
cmd.Flags().String("unique-id", "", "")
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
_, err := getGCPNonCVMValidator(cmd, map[uint32][]byte{})
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

View file

@ -0,0 +1,66 @@
package cmd
import (
"bytes"
"encoding/base64"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetGCPValidator(t *testing.T) {
testCases := map[string]struct {
ownerID string
clusterID string
errExpected bool
}{
"no input": {
ownerID: "",
clusterID: "",
errExpected: true,
},
"unencoded secret ID": {
ownerID: "owner-id",
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: true,
},
"unencoded cluster ID": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: "unique-id",
errExpected: true,
},
"correct input": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newVerifyGCPCmd()
cmd.Flags().String("owner-id", "", "")
cmd.Flags().String("unique-id", "", "")
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
_, err := getGCPValidator(cmd, map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
})
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

325
cli/cmd/verify_test.go Normal file
View file

@ -0,0 +1,325 @@
package cmd
import (
"bytes"
"context"
"encoding/base64"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/status"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
rpcStatus "google.golang.org/grpc/status"
)
func TestVerify(t *testing.T) {
testCases := map[string]struct {
connErr error
checkErr error
state state.State
errExpected bool
}{
"connection error": {
connErr: errors.New("connection error"),
checkErr: nil,
state: 0,
errExpected: true,
},
"check error": {
connErr: nil,
checkErr: errors.New("check error"),
state: 0,
errExpected: true,
},
"check error, rpc status": {
connErr: nil,
checkErr: rpcStatus.Error(codes.Unavailable, "check error"),
state: 0,
errExpected: true,
},
"verify on worker node": {
connErr: nil,
checkErr: nil,
state: state.IsNode,
errExpected: false,
},
"verify on master node": {
connErr: nil,
checkErr: nil,
state: state.ActivatingNodes,
errExpected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
var out bytes.Buffer
verifier := verifier{
newConn: stubNewConnFunc(tc.connErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{
state: tc.state,
checkErr: tc.checkErr,
}),
}
pcrs := map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
}
err := verify(ctx, &out, "", []atls.Validator{gcp.NewValidator(pcrs)}, verifier)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Contains(out.String(), "OK")
}
})
}
}
func stubNewConnFunc(errStub error) func(ctx context.Context, target string, validators []atls.Validator) (status.ClientConn, error) {
return func(ctx context.Context, target string, validators []atls.Validator) (status.ClientConn, error) {
return &stubClientConn{}, errStub
}
}
type stubClientConn struct{}
func (c *stubClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error {
return nil
}
func (c *stubClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
return nil, nil
}
func (c *stubClientConn) Close() error {
return nil
}
func stubNewClientFunc(stubClient pubproto.APIClient) func(cc grpc.ClientConnInterface) pubproto.APIClient {
return func(cc grpc.ClientConnInterface) pubproto.APIClient {
return stubClient
}
}
type stubPeerStatusClient struct {
state state.State
checkErr error
pubproto.APIClient
}
func (c *stubPeerStatusClient) GetState(ctx context.Context, in *pubproto.GetStateRequest, opts ...grpc.CallOption) (*pubproto.GetStateResponse, error) {
resp := &pubproto.GetStateResponse{State: uint32(c.state)}
return resp, c.checkErr
}
func TestPrepareValidator(t *testing.T) {
testCases := map[string]struct {
ownerID string
clusterID string
errExpected bool
}{
"no input": {
ownerID: "",
clusterID: "",
errExpected: true,
},
"unencoded secret ID": {
ownerID: "owner-id",
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: true,
},
"unencoded cluster ID": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: "unique-id",
errExpected: true,
},
"correct input": {
ownerID: base64.StdEncoding.EncodeToString([]byte("owner-id")),
clusterID: base64.StdEncoding.EncodeToString([]byte("unique-id")),
errExpected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newVerifyCmd()
cmd.Flags().String("owner-id", "", "")
cmd.Flags().String("unique-id", "", "")
require.NoError(cmd.Flags().Set("owner-id", tc.ownerID))
require.NoError(cmd.Flags().Set("unique-id", tc.clusterID))
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
pcrs := map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
}
err := prepareValidator(cmd, pcrs)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
if tc.clusterID != "" {
assert.Len(pcrs[uint32(vtpm.PCRIndexClusterID)], 32)
} else {
assert.Nil(pcrs[uint32(vtpm.PCRIndexClusterID)])
}
if tc.ownerID != "" {
assert.Len(pcrs[uint32(vtpm.PCRIndexOwnerID)], 32)
} else {
assert.Nil(pcrs[uint32(vtpm.PCRIndexOwnerID)])
}
}
})
}
}
func TestAddOrSkipPcr(t *testing.T) {
emptyMap := map[uint32][]byte{}
defaultMap := map[uint32][]byte{
0: []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"),
1: []byte("BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"),
}
testCases := map[string]struct {
pcrMap map[uint32][]byte
pcrIndex uint32
encoded string
expectedEntries int
errExpected bool
}{
"empty input, empty map": {
pcrMap: emptyMap,
pcrIndex: 10,
encoded: "",
expectedEntries: 0,
errExpected: false,
},
"empty input, default map": {
pcrMap: defaultMap,
pcrIndex: 10,
encoded: "",
expectedEntries: len(defaultMap),
errExpected: false,
},
"correct input, empty map": {
pcrMap: emptyMap,
pcrIndex: 10,
encoded: base64.StdEncoding.EncodeToString([]byte("Constellation")),
expectedEntries: 1,
errExpected: false,
},
"correct input, default map": {
pcrMap: defaultMap,
pcrIndex: 10,
encoded: base64.StdEncoding.EncodeToString([]byte("Constellation")),
expectedEntries: len(defaultMap) + 1,
errExpected: false,
},
"unencoded input, empty map": {
pcrMap: emptyMap,
pcrIndex: 10,
encoded: "Constellation",
expectedEntries: 0,
errExpected: true,
},
"unencoded input, default map": {
pcrMap: defaultMap,
pcrIndex: 10,
encoded: "Constellation",
expectedEntries: len(defaultMap),
errExpected: true,
},
"empty input at occupied index": {
pcrMap: defaultMap,
pcrIndex: 0,
encoded: "",
expectedEntries: len(defaultMap) - 1,
errExpected: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
res := make(map[uint32][]byte)
for k, v := range tc.pcrMap {
res[k] = v
}
err := addOrSkipPCR(res, tc.pcrIndex, tc.encoded)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
assert.Len(res, tc.expectedEntries)
for _, v := range res {
assert.Len(v, 32)
}
})
}
}
func TestVerifyCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
toComplete string
resultExpected []string
shellCDExpected cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
toComplete: "192.0.2.1",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
},
"second arg": {
args: []string{"192.0.2.1"},
toComplete: "443",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveNoFileComp,
},
"third arg": {
args: []string{"192.0.2.1", "443"},
toComplete: "./file",
resultExpected: []string{},
shellCDExpected: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := verifyCompletion(cmd, tc.args, tc.toComplete)
assert.Equal(tc.resultExpected, result)
assert.Equal(tc.shellCDExpected, shellCD)
})
}
}

19
cli/cmd/version.go Normal file
View file

@ -0,0 +1,19 @@
package cmd
import (
"github.com/edgelesssys/constellation/internal/config"
"github.com/spf13/cobra"
)
func newVersionCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "version",
Short: "Display version of this CLI",
Long: `Display version of this CLI`,
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
cmd.Printf("CLI Version: v%s \n", config.Version)
},
}
return cmd
}

25
cli/cmd/version_test.go Normal file
View file

@ -0,0 +1,25 @@
package cmd
import (
"bytes"
"io"
"testing"
"github.com/edgelesssys/constellation/internal/config"
"github.com/stretchr/testify/assert"
)
func TestVersionCmd(t *testing.T) {
assert := assert.New(t)
cmd := newVersionCmd()
b := bytes.NewBufferString("")
cmd.SetOut(b)
err := cmd.Execute()
assert.NoError(err)
s, err := io.ReadAll(b)
assert.NoError(err)
assert.Contains(string(s), config.Version)
}

5
cli/cmd/vpnconfigurer.go Normal file
View file

@ -0,0 +1,5 @@
package cmd
type vpnConfigurer interface {
Configure(clientVpnIp string, coordinatorPubKey string, coordinatorPubIP string, clientPrivKey string) error
}

View file

@ -0,0 +1,17 @@
package cmd
type stubVPNConfigurer struct {
configured bool
configureErr error
}
func (c *stubVPNConfigurer) Configure(clientVpnIp, coordinatorPubKey, coordinatorPubIP, clientPrivKey string) error {
c.configured = true
return c.configureErr
}
type dummyVPNConfigurer struct{}
func (c *dummyVPNConfigurer) Configure(clientVpnIp, coordinatorPubKey, coordinatorPubIP, clientPrivKey string) error {
panic("dummy doesn't implement this function")
}

42
cli/ec2/client/api.go Normal file
View file

@ -0,0 +1,42 @@
package client
import (
"context"
"github.com/aws/aws-sdk-go-v2/service/ec2"
)
// api collects used functions of AWS' ec2.Client as interfaces to enable testing.
type api interface {
ec2.DescribeInstancesAPIClient
// Instances
RunInstances(ctx context.Context,
params *ec2.RunInstancesInput,
optFns ...func(*ec2.Options)) (*ec2.RunInstancesOutput, error)
TerminateInstances(ctx context.Context,
params *ec2.TerminateInstancesInput,
optFns ...func(*ec2.Options)) (*ec2.TerminateInstancesOutput, error)
CreateTags(ctx context.Context,
params *ec2.CreateTagsInput,
optFns ...func(*ec2.Options)) (*ec2.CreateTagsOutput, error)
// SecurityGroup
CreateSecurityGroup(ctx context.Context,
params *ec2.CreateSecurityGroupInput,
optFns ...func(*ec2.Options)) (*ec2.CreateSecurityGroupOutput, error)
DeleteSecurityGroup(ctx context.Context,
params *ec2.DeleteSecurityGroupInput,
optFns ...func(*ec2.Options)) (*ec2.DeleteSecurityGroupOutput, error)
AuthorizeSecurityGroupIngress(ctx context.Context,
params *ec2.AuthorizeSecurityGroupIngressInput,
optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupIngressOutput, error)
AuthorizeSecurityGroupEgress(ctx context.Context,
params *ec2.AuthorizeSecurityGroupEgressInput,
optFns ...func(*ec2.Options)) (*ec2.AuthorizeSecurityGroupEgressOutput, error)
}

137
cli/ec2/client/api_test.go Normal file
View file

@ -0,0 +1,137 @@
package client
import (
"context"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/aws/smithy-go"
)
// stubAPI is a stub ec2 api for testing.
type stubAPI struct {
instances []types.Instance
securityGroup types.SecurityGroup
describeInstancesErr error
runInstancesErr error
runInstancesDryRunErr *error
terminateInstancesErr error
terminateInstancesDryRunErr *error
createTagsErr error
createSecurityGroupErr error
createSecurityGroupDryRunErr *error
deleteSecurityGroupErr error
deleteSecurityGroupDryRunErr *error
authorizeSecurityGroupIngressErr error
authorizeSecurityGroupIngressDryRunErr *error
authorizeSecurityGroupEgressErr error
authorizeSecurityGroupEgressDryRunErr *error
}
func (a stubAPI) DescribeInstances(ctx context.Context,
params *ec2.DescribeInstancesInput,
optFns ...func(*ec2.Options),
) (*ec2.DescribeInstancesOutput, error) {
return &ec2.DescribeInstancesOutput{
Reservations: []types.Reservation{
{Instances: a.instances},
},
}, a.describeInstancesErr
}
func (a stubAPI) RunInstances(ctx context.Context,
params *ec2.RunInstancesInput,
optFns ...func(*ec2.Options),
) (*ec2.RunInstancesOutput, error) {
if err := getDryRunErr(params.DryRun, a.runInstancesDryRunErr); err != nil {
return nil, err
}
return &ec2.RunInstancesOutput{Instances: a.instances}, a.runInstancesErr
}
func (a stubAPI) CreateTags(ctx context.Context,
params *ec2.CreateTagsInput,
optFns ...func(*ec2.Options),
) (*ec2.CreateTagsOutput, error) {
return nil, a.createTagsErr
}
func (a stubAPI) TerminateInstances(ctx context.Context,
params *ec2.TerminateInstancesInput,
optFns ...func(*ec2.Options),
) (*ec2.TerminateInstancesOutput, error) {
if err := getDryRunErr(params.DryRun, a.terminateInstancesDryRunErr); err != nil {
return nil, err
}
return nil, a.terminateInstancesErr
}
func (a stubAPI) CreateSecurityGroup(ctx context.Context,
params *ec2.CreateSecurityGroupInput,
optFns ...func(*ec2.Options),
) (*ec2.CreateSecurityGroupOutput, error) {
if err := getDryRunErr(params.DryRun, a.createSecurityGroupDryRunErr); err != nil {
return nil, err
}
return &ec2.CreateSecurityGroupOutput{
GroupId: a.securityGroup.GroupId,
}, a.createSecurityGroupErr
}
func (a stubAPI) DeleteSecurityGroup(ctx context.Context,
params *ec2.DeleteSecurityGroupInput,
optFns ...func(*ec2.Options),
) (*ec2.DeleteSecurityGroupOutput, error) {
if err := getDryRunErr(params.DryRun, a.deleteSecurityGroupDryRunErr); err != nil {
return nil, err
}
return nil, a.deleteSecurityGroupErr
}
func (a stubAPI) AuthorizeSecurityGroupIngress(ctx context.Context,
params *ec2.AuthorizeSecurityGroupIngressInput,
optFns ...func(*ec2.Options),
) (*ec2.AuthorizeSecurityGroupIngressOutput, error) {
if err := getDryRunErr(params.DryRun, a.authorizeSecurityGroupIngressDryRunErr); err != nil {
return nil, err
}
return nil, a.authorizeSecurityGroupIngressErr
}
func (a stubAPI) AuthorizeSecurityGroupEgress(ctx context.Context,
params *ec2.AuthorizeSecurityGroupEgressInput,
optFns ...func(*ec2.Options),
) (*ec2.AuthorizeSecurityGroupEgressOutput, error) {
if err := getDryRunErr(params.DryRun, a.authorizeSecurityGroupEgressDryRunErr); err != nil {
return nil, err
}
return nil, a.authorizeSecurityGroupEgressErr
}
func getDryRunErr(dryRun *bool, stubErr *error) error {
if dryRun == nil || !*dryRun {
return nil
}
if stubErr != nil {
return *stubErr
}
return &smithy.GenericAPIError{Code: "DryRunOperation"}
}
var stateRunning = types.InstanceState{
Code: aws.Int32(int32(16)),
Name: types.InstanceStateNameRunning,
}
var stateTerminated = types.InstanceState{
Code: aws.Int32(48),
Name: types.InstanceStateNameTerminated,
}

71
cli/ec2/client/client.go Normal file
View file

@ -0,0 +1,71 @@
package client
import (
"context"
"errors"
"time"
awsconfig "github.com/aws/aws-sdk-go-v2/config"
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/internal/state"
)
// Client for the AWS EC2 API.
type Client struct {
api api
instances ec2.Instances
securityGroup string
timeout time.Duration
}
func newClient(api api) (*Client, error) {
return &Client{
api: api,
instances: make(map[string]ec2.Instance),
timeout: 2 * time.Minute,
}, nil
}
// NewFromDefault creates a Client from the default config.
func NewFromDefault(ctx context.Context) (*Client, error) {
cfg, err := awsconfig.LoadDefaultConfig(ctx)
if err != nil {
return nil, err
}
return newClient(awsec2.NewFromConfig(cfg))
}
// GetState returns the current configuration of the Constellation,
// which can be stored and used through later CLI commands.
func (c *Client) GetState() (state.ConstellationState, error) {
if len(c.instances) == 0 {
return state.ConstellationState{}, errors.New("client has no instances")
}
if c.securityGroup == "" {
return state.ConstellationState{}, errors.New("client has no security group")
}
return state.ConstellationState{
CloudProvider: cloudprovider.AWS.String(),
EC2Instances: c.instances,
EC2SecurityGroup: c.securityGroup,
}, nil
}
// SetState sets a Client to an existing configuration.
func (c *Client) SetState(stat state.ConstellationState) error {
if stat.CloudProvider != cloudprovider.AWS.String() {
return errors.New("state is not aws state")
}
if len(stat.EC2Instances) == 0 {
return errors.New("state has no instances")
}
if stat.EC2SecurityGroup == "" {
return errors.New("state has no security group")
}
c.instances = stat.EC2Instances
c.securityGroup = stat.EC2SecurityGroup
return nil
}

View file

@ -0,0 +1,120 @@
package client
import (
"testing"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/internal/state"
"github.com/stretchr/testify/assert"
)
func TestGetState(t *testing.T) {
testCases := map[string]struct {
client Client
wantState state.ConstellationState
wantErr bool
}{
"successful get": {
client: Client{
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
securityGroup: "sg",
},
wantState: state.ConstellationState{
CloudProvider: "AWS",
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
EC2SecurityGroup: "sg",
},
},
"client without security group": {
client: Client{
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
},
wantErr: true,
},
"client without instances": {
client: Client{
securityGroup: "sg",
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
stat, err := tc.client.GetState()
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.wantState, stat)
}
})
}
}
func TestSetState(t *testing.T) {
testCases := map[string]struct {
state state.ConstellationState
wantInstances ec2.Instances
wantSecurityGroup string
wantErr bool
}{
"successful set": {
state: state.ConstellationState{
CloudProvider: "AWS",
EC2SecurityGroup: "sg-test",
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
},
wantInstances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
wantSecurityGroup: "sg-test",
},
"state without cloudprovider": {
state: state.ConstellationState{
EC2SecurityGroup: "sg-test",
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
},
wantErr: true,
},
"state with incorrect cloudprovider": {
state: state.ConstellationState{
CloudProvider: "incorrect",
EC2SecurityGroup: "sg-test",
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
},
wantErr: true,
},
"state without instances": {
state: state.ConstellationState{
CloudProvider: "AWS",
EC2SecurityGroup: "sg-test",
},
wantErr: true,
},
"state without security group": {
state: state.ConstellationState{
CloudProvider: "AWS",
EC2Instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{}
err := client.SetState(tc.state)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.wantInstances, client.instances)
assert.Equal(tc.wantSecurityGroup, client.securityGroup)
}
})
}
}

199
cli/ec2/client/instances.go Normal file
View file

@ -0,0 +1,199 @@
package client
import (
"context"
"errors"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/edgelesssys/constellation/cli/ec2"
)
// CreateInstances creates the instances defined in input.
//
// An existing security group is needed to create instances.
func (c *Client) CreateInstances(ctx context.Context, input CreateInput) error {
if c.securityGroup == "" {
return errors.New("no security group set")
}
input.securityGroupIds = []string{c.securityGroup}
if err := c.createDryRun(ctx, input); err != nil {
return err
}
resp, err := c.api.RunInstances(ctx, input.AWS())
if err != nil {
return fmt.Errorf("failed to create instances: %w", err)
}
for _, instance := range resp.Instances {
id := instance.InstanceId
if id == nil {
return errors.New("instanceId is nil pointer")
}
c.instances[*id] = ec2.Instance{}
}
if err := c.waitStateRunning(ctx); err != nil {
return err
}
if err := c.tagInstances(ctx, input.Tags); err != nil {
return err
}
if err := c.getInstanceIPs(ctx); err != nil {
return err
}
return nil
}
// TerminateInstances terminates all instances of a Client.
func (c *Client) TerminateInstances(ctx context.Context) error {
if len(c.instances) == 0 {
return nil
}
input := &awsec2.TerminateInstancesInput{
InstanceIds: c.instances.IDs(),
}
if err := c.terminateDryRun(ctx, *input); err != nil {
return err
}
if _, err := c.api.TerminateInstances(ctx, input); err != nil {
return err
}
if err := c.waitStateTerminated(ctx); err != nil {
return err
}
c.instances = ec2.Instances{}
return nil
}
// waitStateRunning waits until all the client's instances reached the running state.
//
// A set of instances is also considered to be running if at least one of the
// instances' state is 'running' and the other instances have a nil state.
func (c *Client) waitStateRunning(ctx context.Context) error {
if len(c.instances) == 0 {
return errors.New("client has no instances")
}
describeInput := &awsec2.DescribeInstancesInput{
InstanceIds: c.instances.IDs(),
}
waiter := awsec2.NewInstanceRunningWaiter(c.api)
return waiter.Wait(ctx, describeInput, c.timeout)
}
// waitStateTerminated waits until all the client's instances reached the terminated state.
//
// A set of instances is also considered to be terminated if at least one of the
// instances' state is 'terminated' and the other instances have a nil state.
func (c *Client) waitStateTerminated(ctx context.Context) error {
if len(c.instances) == 0 {
return errors.New("client has no instances")
}
describeInput := &awsec2.DescribeInstancesInput{
InstanceIds: c.instances.IDs(),
}
waiter := awsec2.NewInstanceTerminatedWaiter(c.api)
return waiter.Wait(ctx, describeInput, c.timeout)
}
// tagInstances tags all instances of a client with a given set of tags.
func (c *Client) tagInstances(ctx context.Context, tags ec2.Tags) error {
if len(c.instances) == 0 {
return errors.New("client has no instances")
}
tagInput := &awsec2.CreateTagsInput{
Resources: c.instances.IDs(),
Tags: tags.AWS(),
}
if _, err := c.api.CreateTags(ctx, tagInput); err != nil {
return fmt.Errorf("failed to tag instances: %w", err)
}
return nil
}
// createDryRun checks if user has the privilege to create the instances
// which were defined in input.
func (c *Client) createDryRun(ctx context.Context, input CreateInput) error {
runInput := input.AWS()
runInput.DryRun = aws.Bool(true)
_, err := c.api.RunInstances(ctx, runInput)
return checkDryRunError(err)
}
// terminateDryRun checks if user has the privilege to terminate the instances
// which were defined in input.
func (c *Client) terminateDryRun(ctx context.Context, input awsec2.TerminateInstancesInput) error {
input.DryRun = aws.Bool(true)
_, err := c.api.TerminateInstances(ctx, &input)
return checkDryRunError(err)
}
// getInstanceIPs queries the private and public IP addresses
// and adds the information to each instance.
//
// The instances must be in 'running' state.
func (c *Client) getInstanceIPs(ctx context.Context) error {
describeInput := &awsec2.DescribeInstancesInput{
InstanceIds: c.instances.IDs(),
}
paginator := awsec2.NewDescribeInstancesPaginator(c.api, describeInput)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
return err
}
for _, reservation := range output.Reservations {
for _, instanceDescription := range reservation.Instances {
if instanceDescription.InstanceId == nil {
return errors.New("instanceId is nil pointer")
}
if instanceDescription.PublicIpAddress == nil {
return errors.New("publicIpAddress is nil pointer")
}
if instanceDescription.PrivateIpAddress == nil {
return errors.New("privateIpAddress is nil pointer")
}
instance, ok := c.instances[*instanceDescription.InstanceId]
if !ok {
return errors.New("got an instance description to an unknown instanceId")
}
instance.PublicIP = *instanceDescription.PublicIpAddress
instance.PrivateIP = *instanceDescription.PrivateIpAddress
c.instances[*instanceDescription.InstanceId] = instance
}
}
}
return nil
}
// CreateInput defines the propertis of the instances to create.
type CreateInput struct {
ImageId string
InstanceType string
Count int
Tags ec2.Tags
securityGroupIds []string
}
// AWS creates a AWS ec2.RunInstancesInput from an CreateInput.
func (ci *CreateInput) AWS() *awsec2.RunInstancesInput {
return &awsec2.RunInstancesInput{
ImageId: aws.String(ci.ImageId),
InstanceType: ec2.InstanceTypes[ci.InstanceType],
MaxCount: aws.Int32(int32(ci.Count)),
MinCount: aws.Int32(int32(ci.Count)),
EnclaveOptions: &types.EnclaveOptionsRequest{Enabled: aws.Bool(true)},
SecurityGroupIds: ci.securityGroupIds,
}
}

View file

@ -0,0 +1,493 @@
package client
import (
"context"
"errors"
"testing"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/stretchr/testify/assert"
)
func TestCreateInstances(t *testing.T) {
testInstances := []types.Instance{
{
InstanceId: aws.String("id-1"),
PublicIpAddress: aws.String("192.0.2.1"),
PrivateIpAddress: aws.String("192.0.2.2"),
State: &stateRunning,
},
{
InstanceId: aws.String("id-2"),
PublicIpAddress: aws.String("192.0.2.3"),
PrivateIpAddress: aws.String("192.0.2.4"),
State: &stateRunning,
},
{
InstanceId: aws.String("id-3"),
PublicIpAddress: aws.String("192.0.2.5"),
PrivateIpAddress: aws.String("192.0.2.6"),
State: &stateRunning,
},
}
someErr := errors.New("failed")
var noErr error
testCases := map[string]struct {
api stubAPI
instances ec2.Instances
securityGroup string
errExpected bool
wantInstances ec2.Instances
}{
"create": {
api: stubAPI{instances: testInstances},
securityGroup: "sg-test",
wantInstances: ec2.Instances{
"id-1": {PublicIP: "192.0.2.1", PrivateIP: "192.0.2.2"},
"id-2": {PublicIP: "192.0.2.3", PrivateIP: "192.0.2.4"},
"id-3": {PublicIP: "192.0.2.5", PrivateIP: "192.0.2.6"},
},
},
"client already has instances": {
api: stubAPI{instances: testInstances},
instances: ec2.Instances{"id-4": {}, "id-5": {}},
securityGroup: "sg-test",
wantInstances: ec2.Instances{
"id-1": {PublicIP: "192.0.2.1", PrivateIP: "192.0.2.2"},
"id-2": {PublicIP: "192.0.2.3", PrivateIP: "192.0.2.4"},
"id-3": {PublicIP: "192.0.2.5", PrivateIP: "192.0.2.6"},
"id-4": {},
"id-5": {},
},
},
"client already has same instance id": {
api: stubAPI{instances: testInstances},
instances: ec2.Instances{"id-1": {}, "id-4": {}, "id-5": {}},
securityGroup: "sg-test",
errExpected: false,
wantInstances: ec2.Instances{
"id-1": {PublicIP: "192.0.2.1", PrivateIP: "192.0.2.2"},
"id-2": {PublicIP: "192.0.2.3", PrivateIP: "192.0.2.4"},
"id-3": {PublicIP: "192.0.2.5", PrivateIP: "192.0.2.6"},
"id-4": {},
"id-5": {},
},
},
"client has no security group": {
api: stubAPI{},
errExpected: true,
},
"run API error": {
api: stubAPI{runInstancesErr: someErr},
securityGroup: "sg-test",
errExpected: true,
},
"runDryRun API error": {
api: stubAPI{runInstancesDryRunErr: &someErr},
securityGroup: "sg-test",
errExpected: true,
},
"runDryRun missing expected API error": {
api: stubAPI{runInstancesDryRunErr: &noErr},
securityGroup: "sg-test",
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
api: tc.api,
instances: tc.instances,
timeout: time.Millisecond,
securityGroup: tc.securityGroup,
}
if client.instances == nil {
client.instances = make(map[string]ec2.Instance)
}
input := CreateInput{
ImageId: "test-image",
InstanceType: "",
Count: 13,
}
err := client.CreateInstances(context.Background(), input)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.ElementsMatch(tc.wantInstances.IDs(), client.instances.IDs())
assert.ElementsMatch(tc.wantInstances.PublicIPs(), client.instances.PublicIPs())
assert.ElementsMatch(tc.wantInstances.PrivateIPs(), client.instances.PrivateIPs())
}
})
}
}
func TestTerminateInstances(t *testing.T) {
testAWSInstances := []types.Instance{
{InstanceId: aws.String("id-1"), State: &stateTerminated},
{InstanceId: aws.String("id-2"), State: &stateTerminated},
{InstanceId: aws.String("id-3"), State: &stateTerminated},
}
someErr := errors.New("failed")
var noErr error
testCases := map[string]struct {
api stubAPI
instances ec2.Instances
errExpected bool
}{
"client with instances": {
api: stubAPI{instances: testAWSInstances},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: false,
},
"client no instances set": {
api: stubAPI{},
},
"terminate API error": {
api: stubAPI{terminateInstancesErr: someErr},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"terminateDryRun API error": {
api: stubAPI{terminateInstancesDryRunErr: &someErr},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"terminateDryRun miss expected API error": {
api: stubAPI{terminateInstancesDryRunErr: &noErr},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
api: tc.api,
instances: tc.instances,
timeout: time.Millisecond,
}
if client.instances == nil {
client.instances = make(map[string]ec2.Instance)
}
err := client.TerminateInstances(context.Background())
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Empty(client.instances)
}
})
}
}
func TestWaitStateRunning(t *testing.T) {
testCases := map[string]struct {
api api
instances ec2.Instances
errExpected bool
}{
"instances are running": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateRunning,
},
{
InstanceId: aws.String("id-2"),
State: &stateRunning,
},
{
InstanceId: aws.String("id-3"),
State: &stateRunning,
},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: false,
},
"one instance running, rest nil": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateRunning,
},
{InstanceId: aws.String("id-2")},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: false,
},
"one instance terminated, rest nil": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateTerminated,
},
{InstanceId: aws.String("id-2")},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"instances with different state": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateTerminated,
},
{
InstanceId: aws.String("id-2"),
State: &stateRunning,
},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"all instances have nil state": {
api: stubAPI{instances: []types.Instance{
{InstanceId: aws.String("id-1")},
{InstanceId: aws.String("id-2")},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"client has no instances": {
api: &stubAPI{},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
api: tc.api,
instances: tc.instances,
timeout: time.Millisecond,
}
if client.instances == nil {
client.instances = make(map[string]ec2.Instance)
}
err := client.waitStateRunning(context.Background())
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestWaitStateTerminated(t *testing.T) {
testCases := map[string]struct {
api api
instances ec2.Instances
errExpected bool
}{
"instances are terminated": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateTerminated,
},
{
InstanceId: aws.String("id-2"),
State: &stateTerminated,
},
{
InstanceId: aws.String("id-3"),
State: &stateTerminated,
},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: false,
},
"one instance terminated, rest nil": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateTerminated,
},
{InstanceId: aws.String("id-2")},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: false,
},
"one instance running, rest nil": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateRunning,
},
{InstanceId: aws.String("id-2")},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"instances with different state": {
api: stubAPI{instances: []types.Instance{
{
InstanceId: aws.String("id-1"),
State: &stateTerminated,
},
{
InstanceId: aws.String("id-2"),
State: &stateRunning,
},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"all instances have nil state": {
api: stubAPI{instances: []types.Instance{
{InstanceId: aws.String("id-1")},
{InstanceId: aws.String("id-2")},
{InstanceId: aws.String("id-3")},
}},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
"client has no instances": {
api: &stubAPI{},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
api: tc.api,
instances: tc.instances,
timeout: time.Millisecond,
}
if client.instances == nil {
client.instances = make(map[string]ec2.Instance)
}
err := client.waitStateTerminated(context.Background())
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestTagInstances(t *testing.T) {
testTags := ec2.Tags{
{Key: "Name", Value: "Test"},
{Key: "Foo", Value: "Bar"},
}
testCases := map[string]struct {
api stubAPI
instances ec2.Instances
errExpected bool
}{
"tag": {
api: stubAPI{},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: false,
},
"client without instances": {
api: stubAPI{createTagsErr: errors.New("failed")},
errExpected: true,
},
"tag API error": {
api: stubAPI{createTagsErr: errors.New("failed")},
instances: ec2.Instances{"id-1": {}, "id-2": {}, "id-3": {}},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
api: tc.api,
instances: tc.instances,
timeout: time.Millisecond,
}
if client.instances == nil {
client.instances = make(map[string]ec2.Instance)
}
err := client.tagInstances(context.Background(), testTags)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestEc2RunInstanceInput(t *testing.T) {
assert := assert.New(t)
testCases := []struct {
in CreateInput
outExpected awsec2.RunInstancesInput
}{
{
in: CreateInput{
ImageId: "test-image",
InstanceType: "4xlarge",
Count: 13,
securityGroupIds: []string{"test-sec-group"},
},
outExpected: awsec2.RunInstancesInput{
ImageId: aws.String("test-image"),
InstanceType: types.InstanceTypeC5a4xlarge,
MinCount: aws.Int32(int32(13)),
MaxCount: aws.Int32(int32(13)),
EnclaveOptions: &types.EnclaveOptionsRequest{Enabled: aws.Bool(true)},
SecurityGroupIds: []string{"test-sec-group"},
},
},
{
in: CreateInput{
ImageId: "test-image-2",
InstanceType: "12xlarge",
Count: 2,
securityGroupIds: []string{"test-sec-group-2"},
},
outExpected: awsec2.RunInstancesInput{
ImageId: aws.String("test-image-2"),
InstanceType: types.InstanceTypeC5a12xlarge,
MinCount: aws.Int32(int32(2)),
MaxCount: aws.Int32(int32(2)),
EnclaveOptions: &types.EnclaveOptionsRequest{Enabled: aws.Bool(true)},
SecurityGroupIds: []string{"test-sec-group-2"},
},
},
}
for _, tc := range testCases {
out := tc.in.AWS()
assert.Equal(tc.outExpected, *out)
}
}

View file

@ -0,0 +1,136 @@
package client
import (
"context"
"errors"
"github.com/aws/aws-sdk-go-v2/aws"
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/google/uuid"
)
// CreateSecurityGroup creates a AWS security group with the handed properties.
func (c *Client) CreateSecurityGroup(ctx context.Context, input SecurityGroupInput) error {
if c.securityGroup != "" {
return errors.New("client already has a security group")
}
id := uuid.New()
createInput := &awsec2.CreateSecurityGroupInput{
Description: aws.String("Security group of Constellation. This group was generated through the Constellation CLI."),
GroupName: aws.String("Constellation-" + id.String()),
DryRun: aws.Bool(true),
}
// DryRun
_, err := c.api.CreateSecurityGroup(ctx, createInput)
if err := checkDryRunError(err); err != nil {
return err
}
createInput.DryRun = aws.Bool(false)
// Create
out, err := c.api.CreateSecurityGroup(ctx, createInput)
if err != nil {
return err
}
if out.GroupId == nil {
return errors.New("security group creation didn't return an id")
}
c.securityGroup = *out.GroupId
// Authorize.
return c.authorizeSecurityGroup(ctx, input)
}
// DeleteSecurityGroup deletes the security group of the client.
// The deletion is idempotent, no error is returned if the client has
// an empty securityGroupID.
func (c *Client) DeleteSecurityGroup(ctx context.Context) error {
if c.securityGroup == "" {
return nil
}
input := &awsec2.DeleteSecurityGroupInput{
GroupId: aws.String(c.securityGroup),
DryRun: aws.Bool(true),
}
// DryRun
_, err := c.api.DeleteSecurityGroup(ctx, input)
if err := checkDryRunError(err); err != nil {
return err
}
input.DryRun = aws.Bool(false)
// Delete
if _, err := c.api.DeleteSecurityGroup(ctx, input); err != nil {
return err
}
c.securityGroup = ""
return nil
}
func (c *Client) authorizeSecurityGroup(ctx context.Context, input SecurityGroupInput) error {
if c.securityGroup == "" {
return errors.New("client hasn't got a security group id")
}
if err := c.authorizeSecurityGroupIngress(ctx, input.Inbound); err != nil {
return err
}
return c.authorizeSecurityGroupEgress(ctx, input.Outbound)
}
func (c *Client) authorizeSecurityGroupIngress(ctx context.Context, perms cloudtypes.Firewall) error {
if len(perms) == 0 {
return nil
}
authInput := &awsec2.AuthorizeSecurityGroupIngressInput{
GroupId: aws.String(c.securityGroup),
IpPermissions: perms.AWS(),
DryRun: aws.Bool(true),
}
// DryRun
_, err := c.api.AuthorizeSecurityGroupIngress(ctx, authInput)
if err := checkDryRunError(err); err != nil {
return err
}
authInput.DryRun = aws.Bool(false)
// Authorize
_, err = c.api.AuthorizeSecurityGroupIngress(ctx, authInput)
return err
}
func (c *Client) authorizeSecurityGroupEgress(ctx context.Context, perms cloudtypes.Firewall) error {
if len(perms) == 0 {
return nil
}
authInput := &awsec2.AuthorizeSecurityGroupEgressInput{
GroupId: aws.String(c.securityGroup),
IpPermissions: perms.AWS(),
DryRun: aws.Bool(true),
}
// DryRun
_, err := c.api.AuthorizeSecurityGroupEgress(ctx, authInput)
if err := checkDryRunError(err); err != nil {
return err
}
authInput.DryRun = aws.Bool(false)
// Authorize
_, err = c.api.AuthorizeSecurityGroupEgress(ctx, authInput)
return err
}
type SecurityGroupInput struct {
Inbound cloudtypes.Firewall
Outbound cloudtypes.Firewall
}

View file

@ -0,0 +1,269 @@
package client
import (
"context"
"errors"
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateSecurityGroup(t *testing.T) {
testInput := SecurityGroupInput{
Inbound: cloudtypes.Firewall{
{
Description: "perm1",
Protocol: "TCP",
IPRange: "192.0.2.0/24",
Port: 22,
},
{
Description: "perm2",
Protocol: "UDP",
IPRange: "192.0.2.0/24",
Port: 4433,
},
},
Outbound: cloudtypes.Firewall{
{
Description: "perm3",
Protocol: "TCP",
IPRange: "192.0.2.0/24",
Port: 4040,
},
},
}
someErr := errors.New("failed")
var noErr error
testCases := map[string]struct {
api stubAPI
securityGroup string
input SecurityGroupInput
errExpected bool
securityGroupExpected string
}{
"create security group": {
api: stubAPI{securityGroup: types.SecurityGroup{GroupId: aws.String("sg-test")}},
input: testInput,
securityGroupExpected: "sg-test",
},
"create security group without permissions": {
api: stubAPI{securityGroup: types.SecurityGroup{GroupId: aws.String("sg-test")}},
input: SecurityGroupInput{},
securityGroupExpected: "sg-test",
},
"client already has security group": {
api: stubAPI{},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
"create returns nil security group ID": {
api: stubAPI{securityGroup: types.SecurityGroup{GroupId: nil}},
input: testInput,
errExpected: true,
},
"create API error": {
api: stubAPI{createSecurityGroupErr: someErr},
input: testInput,
errExpected: true,
},
"create DryRun API error": {
api: stubAPI{createSecurityGroupDryRunErr: &someErr},
input: testInput,
errExpected: true,
},
"create DryRun missing expected error": {
api: stubAPI{createSecurityGroupDryRunErr: &noErr},
input: testInput,
errExpected: true,
},
"authorize error": {
api: stubAPI{
securityGroup: types.SecurityGroup{GroupId: aws.String("sg-test")},
authorizeSecurityGroupIngressErr: someErr,
},
input: testInput,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client, err := newClient(tc.api)
require.NoError(err)
client.securityGroup = tc.securityGroup
err = client.CreateSecurityGroup(context.Background(), tc.input)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.securityGroupExpected, client.securityGroup)
}
})
}
}
func TestDeleteSecurityGroup(t *testing.T) {
someErr := errors.New("failed")
var noErr error
testCases := map[string]struct {
api stubAPI
securityGroup string
errExpected bool
}{
"delete security group": {
api: stubAPI{},
securityGroup: "sg-test",
},
"client without security group": {
api: stubAPI{},
},
"delete API error": {
api: stubAPI{deleteSecurityGroupErr: someErr},
securityGroup: "sg-test",
errExpected: true,
},
"delete DryRun API error": {
api: stubAPI{deleteSecurityGroupDryRunErr: &someErr},
securityGroup: "sg-test",
errExpected: true,
},
"delete DryRun missing expected error": {
api: stubAPI{deleteSecurityGroupDryRunErr: &noErr},
securityGroup: "sg-test",
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client, err := newClient(tc.api)
require.NoError(err)
client.securityGroup = tc.securityGroup
err = client.DeleteSecurityGroup(context.Background())
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Empty(client.securityGroup)
}
})
}
}
func TestAuthorizeSecurityGroup(t *testing.T) {
testInput := SecurityGroupInput{
Inbound: cloudtypes.Firewall{
{
Description: "perm1",
Protocol: "TCP",
IPRange: " 192.0.2.0/24",
Port: 22,
},
{
Description: "perm2",
Protocol: "UDP",
IPRange: "192.0.2.0/24",
Port: 4433,
},
},
Outbound: cloudtypes.Firewall{
{
Description: "perm3",
Protocol: "TCP",
IPRange: "192.0.2.0/24",
Port: 4040,
},
},
}
someErr := errors.New("failed")
var noErr error
testCases := map[string]struct {
api stubAPI
securityGroup string
input SecurityGroupInput
errExpected bool
}{
"authorize": {
api: stubAPI{},
securityGroup: "sg-test",
input: testInput,
errExpected: false,
},
"client without security group": {
api: stubAPI{},
input: testInput,
errExpected: true,
},
"authorizeIngress API error": {
api: stubAPI{authorizeSecurityGroupIngressErr: someErr},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
"authorizeIngress DryRun API error": {
api: stubAPI{authorizeSecurityGroupIngressDryRunErr: &someErr},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
"authorizeIngress DryRun missing expected error": {
api: stubAPI{authorizeSecurityGroupIngressDryRunErr: &noErr},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
"authorizeEgress API error": {
api: stubAPI{authorizeSecurityGroupEgressErr: someErr},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
"authorizeEgress DryRun API error": {
api: stubAPI{authorizeSecurityGroupEgressDryRunErr: &someErr},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
"authorizeEgress DryRun missing expected error": {
api: stubAPI{authorizeSecurityGroupEgressDryRunErr: &noErr},
securityGroup: "sg-test",
input: testInput,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client, err := newClient(tc.api)
require.NoError(err)
client.securityGroup = tc.securityGroup
err = client.authorizeSecurityGroup(context.Background(), tc.input)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

21
cli/ec2/client/util.go Normal file
View file

@ -0,0 +1,21 @@
package client
import (
"errors"
"github.com/aws/smithy-go"
)
// checkDryRunError error checks if an error is a DryRun error.
// If the error is nil, an error is returned, since a DryRun error
// is the expected result of a DryRun operation.
func checkDryRunError(err error) error {
var apiErr smithy.APIError
if errors.As(err, &apiErr) && apiErr.ErrorCode() == "DryRunOperation" {
return nil
}
if err != nil {
return err
}
return errors.New("expected APIError: DryRunOperation, but got no error at all")
}

View file

@ -0,0 +1,22 @@
package client
import (
"errors"
"testing"
"github.com/aws/smithy-go"
"github.com/stretchr/testify/assert"
)
func TestCheckDryRunError(t *testing.T) {
assert := assert.New(t)
someErr := errors.New("failed")
assert.ErrorIs(checkDryRunError(someErr), someErr)
dryRunErr := smithy.GenericAPIError{Code: "DryRunOperation"}
assert.NoError(checkDryRunError(&dryRunErr))
var nilErr error
assert.Error(checkDryRunError(nilErr))
}

58
cli/ec2/instances.go Normal file
View file

@ -0,0 +1,58 @@
package ec2
import "errors"
// Instance is an ec2 instance.
type Instance struct {
PublicIP string
PrivateIP string
}
// Instances is a map of ec2 Instances. The ID of an instance is used as key.
type Instances map[string]Instance
// IDs returns the IDs of all instances of the Constellation.
func (i Instances) IDs() []string {
var ids []string
for id := range i {
ids = append(ids, id)
}
return ids
}
// PublicIPs returns the public IPs of all the instances of the Constellation.
func (i Instances) PublicIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PublicIP)
}
return ips
}
// PrivateIPs returns the private IPs of all the instances of the Constellation.
func (i Instances) PrivateIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PrivateIP)
}
return ips
}
// GetOne return anyone instance out of the instances and its ID.
func (i Instances) GetOne() (string, Instance, error) {
for id, instance := range i {
return id, instance, nil
}
return "", Instance{}, errors.New("map is empty")
}
// GetOthers returns all instances but the one with the handed ID.
func (i Instances) GetOthers(id string) Instances {
others := make(Instances)
for key, instance := range i {
if key != id {
others[key] = instance
}
}
return others
}

71
cli/ec2/instances_test.go Normal file
View file

@ -0,0 +1,71 @@
package ec2
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIDs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIDs := []string{"id-9", "id-10", "id-11", "id-12"}
assert.ElementsMatch(expectedIDs, testState.IDs())
}
func TestPublicIPs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIPs := []string{"192.0.2.1", "192.0.2.3", "192.0.2.5", "192.0.2.7"}
assert.ElementsMatch(expectedIPs, testState.PublicIPs())
}
func TestPrivateIPs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIPs := []string{"192.0.2.2", "192.0.2.4", "192.0.2.6", "192.0.2.8"}
assert.ElementsMatch(expectedIPs, testState.PrivateIPs())
}
func TestGetOne(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
id, instance, err := testState.GetOne()
assert.NoError(err)
assert.Contains(testState, id)
assert.Equal(testState[id], instance)
}
func TestGetOthers(t *testing.T) {
assert := assert.New(t)
testCases := testInstances().IDs()
for _, id := range testCases {
others := testInstances().GetOthers(id)
assert.NotContains(others, id)
expectedInstances := testInstances()
delete(expectedInstances, id)
assert.ElementsMatch(others.IDs(), expectedInstances.IDs())
}
}
func testInstances() Instances {
return Instances{
"id-9": {
PublicIP: "192.0.2.1",
PrivateIP: "192.0.2.2",
},
"id-10": {
PublicIP: "192.0.2.3",
PrivateIP: "192.0.2.4",
},
"id-11": {
PublicIP: "192.0.2.5",
PrivateIP: "192.0.2.6",
},
"id-12": {
PublicIP: "192.0.2.7",
PrivateIP: "192.0.2.8",
},
}
}

18
cli/ec2/instancetypes.go Normal file
View file

@ -0,0 +1,18 @@
package ec2
import "github.com/aws/aws-sdk-go-v2/service/ec2/types"
// InstanceTypes defines possible values for the SIZE positional argument.
var InstanceTypes = map[string]types.InstanceType{
"4xlarge": types.InstanceTypeC5a4xlarge,
"8xlarge": types.InstanceTypeC5a8xlarge,
"12xlarge": types.InstanceTypeC5a12xlarge,
"16xlarge": types.InstanceTypeC5a16xlarge,
"24xlarge": types.InstanceTypeC5a24xlarge,
// shorthands
"4xl": types.InstanceTypeC5a4xlarge,
"8xl": types.InstanceTypeC5a8xlarge,
"12xl": types.InstanceTypeC5a12xlarge,
"16xl": types.InstanceTypeC5a16xlarge,
"24xl": types.InstanceTypeC5a24xlarge,
}

27
cli/ec2/tags.go Normal file
View file

@ -0,0 +1,27 @@
package ec2
import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
)
// Tag is a ec2 tag. It consits of a key and a value.
type Tag struct {
Key string
Value string
}
// Tags is a set of Tags.
type Tags []Tag
// AWS returns a AWS representation of tags.
func (t Tags) AWS() []types.Tag {
var awsTags []types.Tag
for _, tag := range t {
awsTags = append(awsTags, types.Tag{
Key: aws.String(tag.Key),
Value: aws.String(tag.Value),
})
}
return awsTags
}

37
cli/ec2/tags_test.go Normal file
View file

@ -0,0 +1,37 @@
package ec2
import (
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/stretchr/testify/assert"
)
func TestTagsAws(t *testing.T) {
assert := assert.New(t)
testTags := Tags{
{
Key: "Name",
Value: "Test",
},
{
Key: "Foo",
Value: "Bar",
},
}
expected := []types.Tag{
{
Key: aws.String("Name"),
Value: aws.String("Test"),
},
{
Key: aws.String("Foo"),
Value: aws.String("Bar"),
},
}
awsTags := testTags.AWS()
assert.Equal(expected, awsTags)
}

86
cli/file/file.go Normal file
View file

@ -0,0 +1,86 @@
/*
Package file provides functions that combine file handling, JSON marshaling
and file system abstraction.
*/
package file
import (
"encoding/json"
"io"
"io/fs"
"os"
"github.com/spf13/afero"
)
// Handler handles file interaction.
type Handler struct {
fs *afero.Afero
}
// NewHandler returns a new file handler.
func NewHandler(fs afero.Fs) Handler {
afs := &afero.Afero{Fs: fs}
return Handler{fs: afs}
}
// Read reads the file given name and returns the bytes read.
func (h *Handler) Read(name string) ([]byte, error) {
file, err := h.fs.OpenFile(name, os.O_RDONLY, 0o644)
if err != nil {
return nil, err
}
defer file.Close()
return io.ReadAll(file)
}
// Write writes the data bytes into the file with the given name.
// If a file already exists at path and overwrite is true, the file will be
// overwritten. Otherwise, an error is returned.
func (h *Handler) Write(name string, data []byte, overwrite bool) error {
flags := os.O_WRONLY | os.O_CREATE | os.O_EXCL
if overwrite {
flags = os.O_WRONLY | os.O_CREATE | os.O_TRUNC
}
file, err := h.fs.OpenFile(name, flags, 0o644)
if err != nil {
return err
}
_, err = file.Write(data)
if errTmp := file.Close(); errTmp != nil && err == nil {
err = errTmp
}
return err
}
// ReadJSON reads a JSON file from name and unmarshals it into the content interface.
// The interface content must be a pointer to a JSON marchalable object.
func (h *Handler) ReadJSON(name string, content interface{}) error {
data, err := h.Read(name)
if err != nil {
return err
}
return json.Unmarshal(data, content)
}
// WriteJSON marshals the content interface to JSON and writes it to the path with the given name.
// If a file already exists and overwrite is true, the file will be
// overwritten. Otherwise, an error is returned.
func (h *Handler) WriteJSON(name string, content interface{}, overwrite bool) error {
jsonData, err := json.MarshalIndent(content, "", "\t")
if err != nil {
return err
}
return h.Write(name, jsonData, overwrite)
}
// Remove deletes the file with the given name.
func (h *Handler) Remove(name string) error {
return h.fs.Remove(name)
}
// Stat returns a FileInfo describing the named file, or an error, if any
// happens.
func (h *Handler) Stat(name string) (fs.FileInfo, error) {
return h.fs.Stat(name)
}

168
cli/file/file_test.go Normal file
View file

@ -0,0 +1,168 @@
package file
import (
"encoding/json"
"testing"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReadJSON(t *testing.T) {
type testContent struct {
First string
Second int
}
someContent := testContent{
First: "first",
Second: 2,
}
jsonContent, err := json.MarshalIndent(someContent, "", "\t")
require.NoError(t, err)
testCases := map[string]struct {
fs afero.Fs
setupFs func(fs *afero.Afero) error
name string
wantContent interface{}
wantErr bool
}{
"successful read": {
fs: afero.NewMemMapFs(),
name: "test/statefile",
setupFs: func(fs *afero.Afero) error { return fs.WriteFile("test/statefile", jsonContent, 0o755) },
wantContent: someContent,
},
"file not existent": {
fs: afero.NewMemMapFs(),
name: "test/statefile",
wantErr: true,
},
"file not json": {
fs: afero.NewMemMapFs(),
name: "test/statefile",
setupFs: func(fs *afero.Afero) error { return fs.WriteFile("test/statefile", []byte{0x1}, 0o755) },
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
handler := NewHandler(tc.fs)
if tc.setupFs != nil {
require.NoError(tc.setupFs(handler.fs))
}
resultContent := &testContent{}
if tc.wantErr {
assert.Error(handler.ReadJSON(tc.name, resultContent))
} else {
assert.NoError(handler.ReadJSON(tc.name, resultContent))
assert.Equal(tc.wantContent, *resultContent)
}
})
}
}
func TestWriteJSON(t *testing.T) {
type testContent struct {
First string
Second int
}
someContent := testContent{
First: "first",
Second: 2,
}
notMarshalableContent := struct{ Foo chan int }{Foo: make(chan int)}
testCases := map[string]struct {
fs afero.Fs
setupFs func(af afero.Afero) error
name string
content interface{}
overwrite bool
wantErr bool
}{
"successful write": {
fs: afero.NewMemMapFs(),
name: "test/statefile",
content: someContent,
},
"successful overwrite": {
fs: afero.NewMemMapFs(),
setupFs: func(af afero.Afero) error { return af.WriteFile("test/statefile", []byte{}, 0o644) },
name: "test/statefile",
content: someContent,
overwrite: true,
},
"read only fs": {
fs: afero.NewReadOnlyFs(afero.NewMemMapFs()),
name: "test/statefile",
content: someContent,
wantErr: true,
},
"file already exists": {
fs: afero.NewMemMapFs(),
setupFs: func(af afero.Afero) error { return af.WriteFile("test/statefile", []byte{}, 0o644) },
name: "test/statefile",
content: someContent,
wantErr: true,
},
"marshal error": {
fs: afero.NewMemMapFs(),
name: "test/statefile",
content: notMarshalableContent,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
handler := NewHandler(tc.fs)
if tc.setupFs != nil {
require.NoError(tc.setupFs(afero.Afero{Fs: tc.fs}))
}
if tc.wantErr {
assert.Error(handler.WriteJSON(tc.name, tc.content, tc.overwrite))
} else {
assert.NoError(handler.WriteJSON(tc.name, tc.content, tc.overwrite))
resultContent := &testContent{}
assert.NoError(handler.ReadJSON(tc.name, resultContent))
assert.Equal(tc.content, *resultContent)
}
})
}
}
func TestRemove(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fs := afero.NewMemMapFs()
handler := NewHandler(fs)
aferoHelper := afero.Afero{Fs: fs}
require.NoError(aferoHelper.WriteFile("a", []byte{0xa}, 0o644))
require.NoError(aferoHelper.WriteFile("b", []byte{0xb}, 0o644))
require.NoError(aferoHelper.WriteFile("c", []byte{0xc}, 0o644))
assert.NoError(handler.Remove("a"))
assert.NoError(handler.Remove("b"))
assert.NoError(handler.Remove("c"))
_, err := handler.fs.Stat("a")
assert.ErrorIs(err, afero.ErrFileNotFound)
_, err = handler.fs.Stat("b")
assert.ErrorIs(err, afero.ErrFileNotFound)
_, err = handler.fs.Stat("c")
assert.ErrorIs(err, afero.ErrFileNotFound)
assert.Error(handler.Remove("d"))
}

View file

@ -0,0 +1,7 @@
package gcp
import "fmt"
func AutoscalingNodeGroup(project string, zone string, nodeInstanceGroup string, min int, max int) string {
return fmt.Sprintf("%d:%d:https://www.googleapis.com/compute/v1/projects/%s/zones/%s/instanceGroups/%s", min, max, project, zone, nodeInstanceGroup)
}

View file

@ -0,0 +1,14 @@
package gcp
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAutoscalingNodeGroup(t *testing.T) {
assert := assert.New(t)
nodeGroups := AutoscalingNodeGroup("some-project", "some-zone", "some-group", 0, 100)
expectedNodeGroups := "0:100:https://www.googleapis.com/compute/v1/projects/some-project/zones/some-zone/instanceGroups/some-group"
assert.Equal(expectedNodeGroups, nodeGroups)
}

101
cli/gcp/client/api.go Normal file
View file

@ -0,0 +1,101 @@
package client
import (
"context"
"github.com/googleapis/gax-go/v2"
computepb "google.golang.org/genproto/googleapis/cloud/compute/v1"
adminpb "google.golang.org/genproto/googleapis/iam/admin/v1"
iampb "google.golang.org/genproto/googleapis/iam/v1"
)
type instanceAPI interface {
Close() error
List(ctx context.Context, req *computepb.ListInstancesRequest,
opts ...gax.CallOption) InstanceIterator
}
type operationRegionAPI interface {
Close() error
Wait(ctx context.Context, req *computepb.WaitRegionOperationRequest,
opts ...gax.CallOption) (*computepb.Operation, error)
}
type operationZoneAPI interface {
Close() error
Wait(ctx context.Context, req *computepb.WaitZoneOperationRequest,
opts ...gax.CallOption) (*computepb.Operation, error)
}
type operationGlobalAPI interface {
Close() error
Wait(ctx context.Context, req *computepb.WaitGlobalOperationRequest,
opts ...gax.CallOption) (*computepb.Operation, error)
}
type firewallsAPI interface {
Close() error
Delete(ctx context.Context, req *computepb.DeleteFirewallRequest,
opts ...gax.CallOption) (Operation, error)
Insert(ctx context.Context, req *computepb.InsertFirewallRequest,
opts ...gax.CallOption) (Operation, error)
}
type networksAPI interface {
Close() error
Delete(ctx context.Context, req *computepb.DeleteNetworkRequest,
opts ...gax.CallOption) (Operation, error)
Insert(ctx context.Context, req *computepb.InsertNetworkRequest,
opts ...gax.CallOption) (Operation, error)
}
type subnetworksAPI interface {
Close() error
Delete(ctx context.Context, req *computepb.DeleteSubnetworkRequest,
opts ...gax.CallOption) (Operation, error)
Insert(ctx context.Context, req *computepb.InsertSubnetworkRequest,
opts ...gax.CallOption) (Operation, error)
}
type instanceTemplateAPI interface {
Close() error
Delete(ctx context.Context, req *computepb.DeleteInstanceTemplateRequest,
opts ...gax.CallOption) (Operation, error)
Insert(ctx context.Context, req *computepb.InsertInstanceTemplateRequest,
opts ...gax.CallOption) (Operation, error)
}
type instanceGroupManagersAPI interface {
Close() error
Delete(ctx context.Context, req *computepb.DeleteInstanceGroupManagerRequest,
opts ...gax.CallOption) (Operation, error)
Insert(ctx context.Context, req *computepb.InsertInstanceGroupManagerRequest,
opts ...gax.CallOption) (Operation, error)
ListManagedInstances(ctx context.Context, req *computepb.ListManagedInstancesInstanceGroupManagersRequest,
opts ...gax.CallOption) ManagedInstanceIterator
}
type iamAPI interface {
Close() error
CreateServiceAccount(ctx context.Context, req *adminpb.CreateServiceAccountRequest, opts ...gax.CallOption) (*adminpb.ServiceAccount, error)
CreateServiceAccountKey(ctx context.Context, req *adminpb.CreateServiceAccountKeyRequest, opts ...gax.CallOption) (*adminpb.ServiceAccountKey, error)
DeleteServiceAccount(ctx context.Context, req *adminpb.DeleteServiceAccountRequest, opts ...gax.CallOption) error
}
type projectsAPI interface {
Close() error
GetIamPolicy(ctx context.Context, req *iampb.GetIamPolicyRequest, opts ...gax.CallOption) (*iampb.Policy, error)
SetIamPolicy(ctx context.Context, req *iampb.SetIamPolicyRequest, opts ...gax.CallOption) (*iampb.Policy, error)
}
type Operation interface {
Proto() *computepb.Operation
}
type ManagedInstanceIterator interface {
Next() (*computepb.ManagedInstance, error)
}
type InstanceIterator interface {
Next() (*computepb.Instance, error)
}

413
cli/gcp/client/api_test.go Normal file
View file

@ -0,0 +1,413 @@
package client
import (
"context"
"time"
"github.com/googleapis/gax-go/v2"
"google.golang.org/api/iterator"
computepb "google.golang.org/genproto/googleapis/cloud/compute/v1"
adminpb "google.golang.org/genproto/googleapis/iam/admin/v1"
iampb "google.golang.org/genproto/googleapis/iam/v1"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
type stubOperation struct {
*computepb.Operation
}
func (o *stubOperation) Proto() *computepb.Operation {
return o.Operation
}
type stubInstanceAPI struct {
listIterator *stubInstanceIterator
}
func (a stubInstanceAPI) Close() error {
return nil
}
func (a stubInstanceAPI) List(ctx context.Context, req *computepb.ListInstancesRequest,
opts ...gax.CallOption,
) InstanceIterator {
return a.listIterator
}
type stubInstanceIterator struct {
instances []*computepb.Instance
nextErr error
internalCounter int
}
func (i *stubInstanceIterator) Next() (*computepb.Instance, error) {
if i.nextErr != nil {
return nil, i.nextErr
}
if i.internalCounter >= len(i.instances) {
i.internalCounter = 0
return nil, iterator.Done
}
resp := i.instances[i.internalCounter]
i.internalCounter++
return resp, nil
}
type stubOperationZoneAPI struct {
waitErr error
}
func (a stubOperationZoneAPI) Close() error {
return nil
}
func (a stubOperationZoneAPI) Wait(ctx context.Context, req *computepb.WaitZoneOperationRequest,
opts ...gax.CallOption,
) (*computepb.Operation, error) {
if a.waitErr != nil {
return nil, a.waitErr
}
return &computepb.Operation{
Status: computepb.Operation_DONE.Enum(),
}, nil
}
type stubOperationRegionAPI struct {
waitErr error
}
func (a stubOperationRegionAPI) Close() error {
return nil
}
func (a stubOperationRegionAPI) Wait(ctx context.Context, req *computepb.WaitRegionOperationRequest,
opts ...gax.CallOption,
) (*computepb.Operation, error) {
if a.waitErr != nil {
return nil, a.waitErr
}
return &computepb.Operation{
Status: computepb.Operation_DONE.Enum(),
}, nil
}
type stubOperationGlobalAPI struct {
waitErr error
}
func (a stubOperationGlobalAPI) Close() error {
return nil
}
func (a stubOperationGlobalAPI) Wait(ctx context.Context, req *computepb.WaitGlobalOperationRequest,
opts ...gax.CallOption,
) (*computepb.Operation, error) {
if a.waitErr != nil {
return nil, a.waitErr
}
return &computepb.Operation{
Status: computepb.Operation_DONE.Enum(),
}, nil
}
type stubFirewallsAPI struct {
deleteErr error
insertErr error
}
func (a stubFirewallsAPI) Close() error {
return nil
}
func (a stubFirewallsAPI) Delete(ctx context.Context, req *computepb.DeleteFirewallRequest,
opts ...gax.CallOption,
) (Operation, error) {
if a.deleteErr != nil {
return nil, a.deleteErr
}
return &stubOperation{
&computepb.Operation{
Name: proto.String("name"),
},
}, nil
}
func (a stubFirewallsAPI) Insert(ctx context.Context, req *computepb.InsertFirewallRequest,
opts ...gax.CallOption,
) (Operation, error) {
if a.insertErr != nil {
return nil, a.insertErr
}
return &stubOperation{
&computepb.Operation{
Name: proto.String("name"),
},
}, nil
}
type stubNetworksAPI struct {
insertErr error
deleteErr error
}
func (a stubNetworksAPI) Close() error {
return nil
}
func (a stubNetworksAPI) Insert(ctx context.Context, req *computepb.InsertNetworkRequest,
opts ...gax.CallOption,
) (Operation, error) {
if a.insertErr != nil {
return nil, a.insertErr
}
return &stubOperation{
&computepb.Operation{
Name: proto.String("name"),
},
}, nil
}
func (a stubNetworksAPI) Delete(ctx context.Context, req *computepb.DeleteNetworkRequest,
opts ...gax.CallOption,
) (Operation, error) {
if a.deleteErr != nil {
return nil, a.deleteErr
}
return &stubOperation{
&computepb.Operation{
Name: proto.String("name"),
},
}, nil
}
type stubSubnetworksAPI struct {
insertErr error
deleteErr error
}
func (a stubSubnetworksAPI) Close() error {
return nil
}
func (a stubSubnetworksAPI) Insert(ctx context.Context, req *computepb.InsertSubnetworkRequest,
opts ...gax.CallOption,
) (Operation, error) {
if a.insertErr != nil {
return nil, a.insertErr
}
return &stubOperation{
&computepb.Operation{
Name: proto.String("name"),
Region: proto.String("region"),
},
}, nil
}
func (a stubSubnetworksAPI) Delete(ctx context.Context, req *computepb.DeleteSubnetworkRequest,
opts ...gax.CallOption,
) (Operation, error) {
if a.deleteErr != nil {
return nil, a.deleteErr
}
return &stubOperation{
&computepb.Operation{
Name: proto.String("name"),
Region: proto.String("region"),
},
}, nil
}
type stubInstanceTemplateAPI struct {
deleteErr error
insertErr error
}
func (a stubInstanceTemplateAPI) Close() error {
return nil
}
func (a stubInstanceTemplateAPI) Delete(ctx context.Context, req *computepb.DeleteInstanceTemplateRequest,
opts ...gax.CallOption,
) (Operation, error) {
if a.deleteErr != nil {
return nil, a.deleteErr
}
return &stubOperation{
&computepb.Operation{
Name: proto.String("name"),
},
}, nil
}
func (a stubInstanceTemplateAPI) Insert(ctx context.Context, req *computepb.InsertInstanceTemplateRequest,
opts ...gax.CallOption,
) (Operation, error) {
if a.insertErr != nil {
return nil, a.insertErr
}
return &stubOperation{
&computepb.Operation{
Name: proto.String("name"),
},
}, nil
}
type stubInstanceGroupManagersAPI struct {
listIterator *stubManagedInstanceIterator
deleteErr error
insertErr error
}
func (a stubInstanceGroupManagersAPI) Close() error {
return nil
}
func (a stubInstanceGroupManagersAPI) Delete(ctx context.Context, req *computepb.DeleteInstanceGroupManagerRequest,
opts ...gax.CallOption,
) (Operation, error) {
if a.deleteErr != nil {
return nil, a.deleteErr
}
return &stubOperation{
&computepb.Operation{
Zone: proto.String("zone"),
Name: proto.String("name"),
},
}, nil
}
func (a stubInstanceGroupManagersAPI) Insert(ctx context.Context, req *computepb.InsertInstanceGroupManagerRequest,
opts ...gax.CallOption,
) (Operation, error) {
if a.insertErr != nil {
return nil, a.insertErr
}
return &stubOperation{
&computepb.Operation{
Zone: proto.String("zone"),
Name: proto.String("name"),
},
}, nil
}
func (a stubInstanceGroupManagersAPI) ListManagedInstances(ctx context.Context, req *computepb.ListManagedInstancesInstanceGroupManagersRequest,
opts ...gax.CallOption,
) ManagedInstanceIterator {
return a.listIterator
}
type stubIAMAPI struct {
serviceAccountKeyData []byte
createErr error
createKeyErr error
deleteServiceAccountErr error
}
func (a stubIAMAPI) Close() error {
return nil
}
func (a stubIAMAPI) CreateServiceAccount(ctx context.Context, req *adminpb.CreateServiceAccountRequest, opts ...gax.CallOption) (*adminpb.ServiceAccount, error) {
if a.createErr != nil {
return nil, a.createErr
}
return &adminpb.ServiceAccount{
Name: "name",
ProjectId: "project-id",
UniqueId: "unique-id",
Email: "email",
DisplayName: "display-name",
Description: "description",
Oauth2ClientId: "oauth2-client-id",
Disabled: false,
}, nil
}
func (a stubIAMAPI) CreateServiceAccountKey(ctx context.Context, req *adminpb.CreateServiceAccountKeyRequest, opts ...gax.CallOption) (*adminpb.ServiceAccountKey, error) {
if a.createKeyErr != nil {
return nil, a.createKeyErr
}
return &adminpb.ServiceAccountKey{
Name: "name",
PrivateKeyType: adminpb.ServiceAccountPrivateKeyType_TYPE_GOOGLE_CREDENTIALS_FILE,
KeyAlgorithm: adminpb.ServiceAccountKeyAlgorithm_KEY_ALG_RSA_2048,
PrivateKeyData: a.serviceAccountKeyData,
PublicKeyData: []byte("public-key-data"),
ValidAfterTime: timestamppb.New(time.Time{}),
ValidBeforeTime: timestamppb.New(time.Time{}),
KeyOrigin: adminpb.ServiceAccountKeyOrigin_GOOGLE_PROVIDED,
KeyType: adminpb.ListServiceAccountKeysRequest_USER_MANAGED,
}, nil
}
func (a stubIAMAPI) DeleteServiceAccount(ctx context.Context, req *adminpb.DeleteServiceAccountRequest, opts ...gax.CallOption) error {
return a.deleteServiceAccountErr
}
type stubProjectsAPI struct {
getPolicyErr error
setPolicyErr error
}
func (a stubProjectsAPI) Close() error {
return nil
}
func (a stubProjectsAPI) GetIamPolicy(ctx context.Context, req *iampb.GetIamPolicyRequest, opts ...gax.CallOption) (*iampb.Policy, error) {
if a.getPolicyErr != nil {
return nil, a.getPolicyErr
}
return &iampb.Policy{
Version: 3,
Bindings: []*iampb.Binding{
{
Role: "role",
Members: []string{
"member",
},
},
},
Etag: []byte("etag"),
}, nil
}
func (a stubProjectsAPI) SetIamPolicy(ctx context.Context, req *iampb.SetIamPolicyRequest, opts ...gax.CallOption) (*iampb.Policy, error) {
if a.setPolicyErr != nil {
return nil, a.setPolicyErr
}
return &iampb.Policy{
Version: 3,
Bindings: []*iampb.Binding{
{
Role: "role",
Members: []string{
"member",
},
},
},
Etag: []byte("etag"),
}, nil
}
type stubManagedInstanceIterator struct {
instances []*computepb.ManagedInstance
nextErr error
internalCounter int
}
func (i *stubManagedInstanceIterator) Next() (*computepb.ManagedInstance, error) {
if i.nextErr != nil {
return nil, i.nextErr
}
if i.internalCounter >= len(i.instances) {
i.internalCounter = 0
return nil, iterator.Done
}
resp := i.instances[i.internalCounter]
i.internalCounter++
return resp, nil
}

384
cli/gcp/client/client.go Normal file
View file

@ -0,0 +1,384 @@
package client
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"strings"
compute "cloud.google.com/go/compute/apiv1"
admin "cloud.google.com/go/iam/admin/apiv1"
resourcemanager "cloud.google.com/go/resourcemanager/apiv3"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/internal/state"
)
// Client is a client for the Google Compute Engine.
type Client struct {
instanceAPI
operationRegionAPI
operationZoneAPI
operationGlobalAPI
networksAPI
subnetworksAPI
firewallsAPI
instanceTemplateAPI
instanceGroupManagersAPI
iamAPI
projectsAPI
nodes gcp.Instances
coordinators gcp.Instances
nodesInstanceGroup string
coordinatorInstanceGroup string
coordinatorTemplate string
nodeTemplate string
network string
subnetwork string
secondarySubnetworkRange string
firewalls []string
name string
project string
uid string
zone string
region string
serviceAccount string
}
// NewFromDefault creates an uninitialized client.
func NewFromDefault(ctx context.Context) (*Client, error) {
var closers []closer
insAPI, err := compute.NewInstancesRESTClient(ctx)
if err != nil {
return nil, err
}
closers = append(closers, insAPI)
opZoneAPI, err := compute.NewZoneOperationsRESTClient(ctx)
if err != nil {
_ = closeAll(closers)
return nil, err
}
closers = append(closers, opZoneAPI)
opRegionAPI, err := compute.NewRegionOperationsRESTClient(ctx)
if err != nil {
_ = closeAll(closers)
return nil, err
}
closers = append(closers, opRegionAPI)
opGlobalAPI, err := compute.NewGlobalOperationsRESTClient(ctx)
if err != nil {
_ = closeAll(closers)
return nil, err
}
closers = append(closers, opGlobalAPI)
netAPI, err := compute.NewNetworksRESTClient(ctx)
if err != nil {
_ = closeAll(closers)
return nil, err
}
closers = append(closers, netAPI)
subnetAPI, err := compute.NewSubnetworksRESTClient(ctx)
if err != nil {
_ = closeAll(closers)
return nil, err
}
closers = append(closers, subnetAPI)
fwAPI, err := compute.NewFirewallsRESTClient(ctx)
if err != nil {
_ = closeAll(closers)
return nil, err
}
closers = append(closers, fwAPI)
templAPI, err := compute.NewInstanceTemplatesRESTClient(ctx)
if err != nil {
_ = closeAll(closers)
return nil, err
}
closers = append(closers, templAPI)
groupAPI, err := compute.NewInstanceGroupManagersRESTClient(ctx)
if err != nil {
_ = closeAll(closers)
return nil, err
}
closers = append(closers, groupAPI)
iamAPI, err := admin.NewIamClient(ctx)
if err != nil {
_ = closeAll(closers)
return nil, err
}
closers = append(closers, iamAPI)
projectsAPI, err := resourcemanager.NewProjectsClient(ctx)
if err != nil {
_ = closeAll(closers)
return nil, err
}
return &Client{
instanceAPI: &instanceClient{insAPI},
operationRegionAPI: opRegionAPI,
operationZoneAPI: opZoneAPI,
operationGlobalAPI: opGlobalAPI,
networksAPI: &networksClient{netAPI},
subnetworksAPI: &subnetworksClient{subnetAPI},
firewallsAPI: &firewallsClient{fwAPI},
instanceTemplateAPI: &instanceTemplateClient{templAPI},
instanceGroupManagersAPI: &instanceGroupManagersClient{groupAPI},
iamAPI: &iamClient{iamAPI},
projectsAPI: &projectsClient{projectsAPI},
nodes: make(gcp.Instances),
coordinators: make(gcp.Instances),
}, nil
}
// NewInitialized creates an initialized client.
func NewInitialized(ctx context.Context, project, zone, region, name string) (*Client, error) {
client, err := NewFromDefault(ctx)
if err != nil {
return nil, err
}
err = client.init(project, zone, region, name)
return client, err
}
// Close closes the client's connection.
func (c *Client) Close() error {
closers := []closer{
c.instanceAPI,
c.operationZoneAPI,
c.operationGlobalAPI,
c.networksAPI,
c.firewallsAPI,
c.instanceTemplateAPI,
c.instanceGroupManagersAPI,
}
return closeAll(closers)
}
// init initializes the client.
func (c *Client) init(project, zone, region, name string) error {
c.project = project
c.zone = zone
c.name = name
c.region = region
uid, err := c.generateUID()
if err != nil {
return err
}
c.uid = uid
return nil
}
// GetState returns the state of the client as ConstellationState.
func (c *Client) GetState() (state.ConstellationState, error) {
var stat state.ConstellationState
stat.CloudProvider = cloudprovider.GCP.String()
if len(c.nodes) == 0 {
return state.ConstellationState{}, errors.New("client has no nodes")
}
stat.GCPNodes = c.nodes
if len(c.coordinators) == 0 {
return state.ConstellationState{}, errors.New("client has no coordinators")
}
stat.GCPCoordinators = c.coordinators
if c.nodesInstanceGroup == "" {
return state.ConstellationState{}, errors.New("client has no nodeInstanceGroup")
}
stat.GCPNodeInstanceGroup = c.nodesInstanceGroup
if c.coordinatorInstanceGroup == "" {
return state.ConstellationState{}, errors.New("client has no coordinatorInstanceGroup")
}
stat.GCPCoordinatorInstanceGroup = c.coordinatorInstanceGroup
if c.project == "" {
return state.ConstellationState{}, errors.New("client has no project")
}
stat.GCPProject = c.project
if c.zone == "" {
return state.ConstellationState{}, errors.New("client has no zone")
}
stat.GCPZone = c.zone
if c.region == "" {
return state.ConstellationState{}, errors.New("client has no region")
}
stat.GCPRegion = c.region
if c.name == "" {
return state.ConstellationState{}, errors.New("client has no name")
}
stat.Name = c.name
if c.uid == "" {
return state.ConstellationState{}, errors.New("client has no uid")
}
stat.UID = c.uid
if len(c.firewalls) == 0 {
return state.ConstellationState{}, errors.New("client has no firewalls")
}
stat.GCPFirewalls = c.firewalls
if c.network == "" {
return state.ConstellationState{}, errors.New("client has no network")
}
stat.GCPNetwork = c.network
if c.subnetwork == "" {
return state.ConstellationState{}, errors.New("client has no subnetwork")
}
stat.GCPSubnetwork = c.subnetwork
if c.nodeTemplate == "" {
return state.ConstellationState{}, errors.New("client has no node instance template")
}
stat.GCPNodeInstanceTemplate = c.nodeTemplate
if c.coordinatorTemplate == "" {
return state.ConstellationState{}, errors.New("client has no coordinator instance template")
}
stat.GCPCoordinatorInstanceTemplate = c.coordinatorTemplate
// service account does not have to be set at all times
stat.GCPServiceAccount = c.serviceAccount
return stat, nil
}
// SetState sets the state of the client to the handed ConstellationState.
func (c *Client) SetState(stat state.ConstellationState) error {
if stat.CloudProvider != cloudprovider.GCP.String() {
return errors.New("state is not gcp state")
}
if len(stat.GCPNodes) == 0 {
return errors.New("state has no nodes")
}
c.nodes = stat.GCPNodes
if len(stat.GCPCoordinators) == 0 {
return errors.New("state has no coordinator")
}
c.coordinators = stat.GCPCoordinators
if stat.GCPNodeInstanceGroup == "" {
return errors.New("state has no nodeInstanceGroup")
}
c.nodesInstanceGroup = stat.GCPNodeInstanceGroup
if stat.GCPCoordinatorInstanceGroup == "" {
return errors.New("state has no coordinatorInstanceGroup")
}
c.coordinatorInstanceGroup = stat.GCPCoordinatorInstanceGroup
if stat.GCPProject == "" {
return errors.New("state has no project")
}
c.project = stat.GCPProject
if stat.GCPZone == "" {
return errors.New("state has no zone")
}
c.zone = stat.GCPZone
if stat.GCPRegion == "" {
return errors.New("state has no region")
}
c.region = stat.GCPRegion
if stat.Name == "" {
return errors.New("state has no name")
}
c.name = stat.Name
if stat.UID == "" {
return errors.New("state has no uid")
}
c.uid = stat.UID
if len(stat.GCPFirewalls) == 0 {
return errors.New("state has no firewalls")
}
c.firewalls = stat.GCPFirewalls
if stat.GCPNetwork == "" {
return errors.New("state has no network")
}
c.network = stat.GCPNetwork
if stat.GCPSubnetwork == "" {
return errors.New("state has no subnetwork")
}
c.subnetwork = stat.GCPSubnetwork
if stat.GCPNodeInstanceTemplate == "" {
return errors.New("state has no node instance template")
}
c.nodeTemplate = stat.GCPNodeInstanceTemplate
if stat.GCPCoordinatorInstanceTemplate == "" {
return errors.New("state has no coordinator instance template")
}
c.coordinatorTemplate = stat.GCPCoordinatorInstanceTemplate
// service account does not have to be set at all times
c.serviceAccount = stat.GCPServiceAccount
return nil
}
func (c *Client) generateUID() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyz0123456789")
const uidLen = 5
uid := make([]byte, uidLen)
for i := 0; i < uidLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
uid[i] = letters[n.Int64()]
}
return string(uid), nil
}
type closer interface {
Close() error
}
// closeAll closes all closers, even if an error occurs.
//
// Errors are collected and a composed error is returned.
func closeAll(closers []closer) error {
// Since this function is intended to be deferred, it will always call all
// close operations, even if a previous operation failed. The if multiple
// errors occur, the returned error will be composed of the error messages
// of those errors.
var errs []error
for _, closer := range closers {
errs = append(errs, closer.Close())
}
return composeErr(errs)
}
// composeErr composes a list of errors to a single error.
//
// If all errs are nil, the returned error is also nil.
func composeErr(errs []error) error {
var composed strings.Builder
for i, err := range errs {
if err != nil {
composed.WriteString(fmt.Sprintf("%d: %s", i, err.Error()))
}
}
if composed.Len() != 0 {
return errors.New(composed.String())
}
return nil
}

View file

@ -0,0 +1,684 @@
package client
import (
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/internal/state"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSetGetState(t *testing.T) {
testCases := map[string]struct {
state state.ConstellationState
errExpected bool
}{
"valid state": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
GCPServiceAccount: "service-account",
},
},
"missing nodes": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing coordinator": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing node group": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing coordinator group": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing project id": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing zone": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing region": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing name": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
UID: "uid",
GCPRegion: "region-id",
GCPNetwork: "net-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing uid": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
Name: "name",
GCPRegion: "region-id",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing firewalls": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing network": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing external network": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing subnetwork": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing external subnetwork": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing node template": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPCoordinatorInstanceTemplate: "temp-id",
},
errExpected: true,
},
"missing coordinator template": {
state: state.ConstellationState{
CloudProvider: cloudprovider.GCP.String(),
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
},
errExpected: true,
},
}
t.Run("SetState", func(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := Client{}
if tc.errExpected {
assert.Error(client.SetState(tc.state))
} else {
assert.NoError(client.SetState(tc.state))
assert.Equal(tc.state.GCPNodes, client.nodes)
assert.Equal(tc.state.GCPCoordinators, client.coordinators)
assert.Equal(tc.state.GCPNodeInstanceGroup, client.nodesInstanceGroup)
assert.Equal(tc.state.GCPCoordinatorInstanceGroup, client.coordinatorInstanceGroup)
assert.Equal(tc.state.GCPProject, client.project)
assert.Equal(tc.state.GCPZone, client.zone)
assert.Equal(tc.state.Name, client.name)
assert.Equal(tc.state.UID, client.uid)
assert.Equal(tc.state.GCPNetwork, client.network)
assert.Equal(tc.state.GCPFirewalls, client.firewalls)
assert.Equal(tc.state.GCPCoordinatorInstanceTemplate, client.coordinatorTemplate)
assert.Equal(tc.state.GCPNodeInstanceTemplate, client.nodeTemplate)
assert.Equal(tc.state.GCPServiceAccount, client.serviceAccount)
}
})
}
})
t.Run("GetState", func(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := Client{
nodes: tc.state.GCPNodes,
coordinators: tc.state.GCPCoordinators,
nodesInstanceGroup: tc.state.GCPNodeInstanceGroup,
coordinatorInstanceGroup: tc.state.GCPCoordinatorInstanceGroup,
project: tc.state.GCPProject,
zone: tc.state.GCPZone,
region: tc.state.GCPRegion,
name: tc.state.Name,
uid: tc.state.UID,
network: tc.state.GCPNetwork,
subnetwork: tc.state.GCPSubnetwork,
firewalls: tc.state.GCPFirewalls,
nodeTemplate: tc.state.GCPNodeInstanceTemplate,
coordinatorTemplate: tc.state.GCPCoordinatorInstanceTemplate,
serviceAccount: tc.state.GCPServiceAccount,
}
if tc.errExpected {
_, err := client.GetState()
assert.Error(err)
} else {
stat, err := client.GetState()
assert.NoError(err)
assert.Equal(tc.state, stat)
}
})
}
})
}
func TestSetStateCloudProvider(t *testing.T) {
assert := assert.New(t)
client := Client{}
stateMissingCloudProvider := state.ConstellationState{
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
}
assert.Error(client.SetState(stateMissingCloudProvider))
stateIncorrectCloudProvider := state.ConstellationState{
CloudProvider: "incorrect",
GCPNodes: gcp.Instances{
"id-1": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
GCPCoordinators: gcp.Instances{
"id-1": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
GCPNodeInstanceGroup: "group-id",
GCPCoordinatorInstanceGroup: "group-id",
GCPProject: "proj-id",
GCPZone: "zone-id",
GCPRegion: "region-id",
Name: "name",
UID: "uid",
GCPNetwork: "net-id",
GCPSubnetwork: "subnet-id",
GCPFirewalls: []string{"fw-1", "fw-2"},
GCPNodeInstanceTemplate: "temp-id",
GCPCoordinatorInstanceTemplate: "temp-id",
}
assert.Error(client.SetState(stateIncorrectCloudProvider))
}
func TestInit(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{}
require.NoError(client.init("project", "zone", "region", "name"))
assert.Equal("project", client.project)
assert.Equal("zone", client.zone)
assert.Equal("region", client.region)
assert.Equal("name", client.name)
}
func TestCloseAll(t *testing.T) {
assert := assert.New(t)
closers := []closer{&someCloser{}, &someCloser{}, &someCloser{}}
assert.NoError(closeAll(closers))
for _, c := range closers {
assert.True(c.(*someCloser).closed)
}
someErr := errors.New("failed")
closers = []closer{&someCloser{}, &someCloser{closeErr: someErr}, &someCloser{}}
assert.Error(closeAll(closers))
for _, c := range closers {
assert.True(c.(*someCloser).closed)
}
}
type someCloser struct {
closeErr error
closed bool
}
func (c *someCloser) Close() error {
c.closed = true
return c.closeErr
}
func TestComposedErr(t *testing.T) {
assert := assert.New(t)
noErrs := []error{nil, nil, nil}
assert.NoError(composeErr(noErrs))
someErrs := []error{
errors.New("failed 4"),
errors.New("failed 7"),
nil,
nil,
errors.New("failed 9"),
}
err := composeErr(someErrs)
assert.Error(err)
assert.Contains(err.Error(), "4")
assert.Contains(err.Error(), "7")
assert.Contains(err.Error(), "9")
}

View file

@ -0,0 +1,169 @@
package client
import (
"context"
compute "cloud.google.com/go/compute/apiv1"
admin "cloud.google.com/go/iam/admin/apiv1"
resourcemanager "cloud.google.com/go/resourcemanager/apiv3"
"github.com/googleapis/gax-go/v2"
computepb "google.golang.org/genproto/googleapis/cloud/compute/v1"
adminpb "google.golang.org/genproto/googleapis/iam/admin/v1"
iampb "google.golang.org/genproto/googleapis/iam/v1"
)
type instanceClient struct {
*compute.InstancesClient
}
func (c *instanceClient) Close() error {
return c.InstancesClient.Close()
}
func (c *instanceClient) List(ctx context.Context, req *computepb.ListInstancesRequest,
opts ...gax.CallOption,
) InstanceIterator {
return c.InstancesClient.List(ctx, req)
}
type firewallsClient struct {
*compute.FirewallsClient
}
func (c *firewallsClient) Close() error {
return c.FirewallsClient.Close()
}
func (c *firewallsClient) Delete(ctx context.Context, req *computepb.DeleteFirewallRequest,
opts ...gax.CallOption,
) (Operation, error) {
return c.FirewallsClient.Delete(ctx, req)
}
func (c *firewallsClient) Insert(ctx context.Context, req *computepb.InsertFirewallRequest,
opts ...gax.CallOption,
) (Operation, error) {
return c.FirewallsClient.Insert(ctx, req)
}
type networksClient struct {
*compute.NetworksClient
}
func (c *networksClient) Close() error {
return c.NetworksClient.Close()
}
func (c *networksClient) Insert(ctx context.Context, req *computepb.InsertNetworkRequest,
opts ...gax.CallOption,
) (Operation, error) {
return c.NetworksClient.Insert(ctx, req)
}
func (c *networksClient) Delete(ctx context.Context, req *computepb.DeleteNetworkRequest,
opts ...gax.CallOption,
) (Operation, error) {
return c.NetworksClient.Delete(ctx, req)
}
type subnetworksClient struct {
*compute.SubnetworksClient
}
func (c *subnetworksClient) Close() error {
return c.SubnetworksClient.Close()
}
func (c *subnetworksClient) Insert(ctx context.Context, req *computepb.InsertSubnetworkRequest,
opts ...gax.CallOption,
) (Operation, error) {
return c.SubnetworksClient.Insert(ctx, req)
}
func (c *subnetworksClient) Delete(ctx context.Context, req *computepb.DeleteSubnetworkRequest,
opts ...gax.CallOption,
) (Operation, error) {
return c.SubnetworksClient.Delete(ctx, req)
}
type instanceTemplateClient struct {
*compute.InstanceTemplatesClient
}
func (c *instanceTemplateClient) Close() error {
return c.InstanceTemplatesClient.Close()
}
func (c *instanceTemplateClient) Delete(ctx context.Context, req *computepb.DeleteInstanceTemplateRequest,
opts ...gax.CallOption,
) (Operation, error) {
return c.InstanceTemplatesClient.Delete(ctx, req)
}
func (c *instanceTemplateClient) Insert(ctx context.Context, req *computepb.InsertInstanceTemplateRequest,
opts ...gax.CallOption,
) (Operation, error) {
return c.InstanceTemplatesClient.Insert(ctx, req)
}
type instanceGroupManagersClient struct {
*compute.InstanceGroupManagersClient
}
func (c *instanceGroupManagersClient) Close() error {
return c.InstanceGroupManagersClient.Close()
}
func (c *instanceGroupManagersClient) Delete(ctx context.Context, req *computepb.DeleteInstanceGroupManagerRequest,
opts ...gax.CallOption,
) (Operation, error) {
return c.InstanceGroupManagersClient.Delete(ctx, req)
}
func (c *instanceGroupManagersClient) Insert(ctx context.Context, req *computepb.InsertInstanceGroupManagerRequest,
opts ...gax.CallOption,
) (Operation, error) {
return c.InstanceGroupManagersClient.Insert(ctx, req)
}
func (c *instanceGroupManagersClient) ListManagedInstances(ctx context.Context, req *computepb.ListManagedInstancesInstanceGroupManagersRequest,
opts ...gax.CallOption,
) ManagedInstanceIterator {
return c.InstanceGroupManagersClient.ListManagedInstances(ctx, req)
}
type iamClient struct {
*admin.IamClient
}
func (c *iamClient) Close() error {
return c.IamClient.Close()
}
func (c *iamClient) CreateServiceAccount(ctx context.Context, req *adminpb.CreateServiceAccountRequest, opts ...gax.CallOption) (*adminpb.ServiceAccount, error) {
return c.IamClient.CreateServiceAccount(ctx, req)
}
func (c *iamClient) CreateServiceAccountKey(ctx context.Context, req *adminpb.CreateServiceAccountKeyRequest, opts ...gax.CallOption) (*adminpb.ServiceAccountKey, error) {
return c.IamClient.CreateServiceAccountKey(ctx, req)
}
func (c *iamClient) DeleteServiceAccount(ctx context.Context, req *adminpb.DeleteServiceAccountRequest, opts ...gax.CallOption) error {
return c.IamClient.DeleteServiceAccount(ctx, req)
}
type projectsClient struct {
*resourcemanager.ProjectsClient
}
func (c *projectsClient) Close() error {
return c.ProjectsClient.Close()
}
func (c *projectsClient) GetIamPolicy(ctx context.Context, req *iampb.GetIamPolicyRequest, opts ...gax.CallOption) (*iampb.Policy, error) {
return c.ProjectsClient.GetIamPolicy(ctx, req)
}
func (c *projectsClient) SetIamPolicy(ctx context.Context, req *iampb.SetIamPolicyRequest, opts ...gax.CallOption) (*iampb.Policy, error) {
return c.ProjectsClient.SetIamPolicy(ctx, req)
}

393
cli/gcp/client/instances.go Normal file
View file

@ -0,0 +1,393 @@
package client
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/edgelesssys/constellation/cli/gcp"
"google.golang.org/api/iterator"
computepb "google.golang.org/genproto/googleapis/cloud/compute/v1"
"google.golang.org/protobuf/proto"
)
// CreateInstances creates instances (virtual machines) on Google Compute Engine.
//
// A separate managed instance group is created for coordinators and nodes, the function
// waits until the instances are up and stores the public and private IPs of the instances
// in the client. If the client's network must be set before instances can be created.
func (c *Client) CreateInstances(ctx context.Context, input CreateInstancesInput) error {
if c.network == "" {
return errors.New("client has no network")
}
ops := []Operation{}
nodeTemplateInput := insertInstanceTemplateInput{
Name: c.name + "-worker-" + c.uid,
Network: c.network,
SecondarySubnetworkRangeName: c.secondarySubnetworkRange,
Subnetwork: c.subnetwork,
ImageId: input.ImageId,
InstanceType: input.InstanceType,
KubeEnv: input.KubeEnv,
Project: c.project,
Zone: c.zone,
Region: c.region,
UID: c.uid,
DisableCVM: input.DisableCVM,
}
op, err := c.insertInstanceTemplate(ctx, nodeTemplateInput)
if err != nil {
return fmt.Errorf("inserting instanceTemplate failed: %w", err)
}
ops = append(ops, op)
c.nodeTemplate = nodeTemplateInput.Name
coordinatorTemplateInput := insertInstanceTemplateInput{
Name: c.name + "-control-plane-" + c.uid,
Network: c.network,
Subnetwork: c.subnetwork,
ImageId: input.ImageId,
InstanceType: input.InstanceType,
KubeEnv: input.KubeEnv,
Project: c.project,
Zone: c.zone,
Region: c.region,
UID: c.uid,
DisableCVM: input.DisableCVM,
}
op, err = c.insertInstanceTemplate(ctx, coordinatorTemplateInput)
if err != nil {
return fmt.Errorf("inserting instanceTemplate failed: %w", err)
}
ops = append(ops, op)
c.coordinatorTemplate = coordinatorTemplateInput.Name
if err := c.waitForOperations(ctx, ops); err != nil {
return err
}
ops = []Operation{}
nodeGroupInput := instanceGroupManagerInput{
Count: input.Count - 1,
Name: strings.Join([]string{c.name, "worker", c.uid}, "-"),
Template: c.nodeTemplate,
UID: c.uid,
Project: c.project,
Zone: c.zone,
}
op, err = c.insertInstanceGroupManger(ctx, nodeGroupInput)
if err != nil {
return fmt.Errorf("inserting instanceGroupManager failed: %w", err)
}
ops = append(ops, op)
c.nodesInstanceGroup = nodeGroupInput.Name
coordinatorGroupInput := instanceGroupManagerInput{
Count: 1,
Name: strings.Join([]string{c.name, "control-plane", c.uid}, "-"),
Template: c.coordinatorTemplate,
UID: c.uid,
Project: c.project,
Zone: c.zone,
}
op, err = c.insertInstanceGroupManger(ctx, coordinatorGroupInput)
if err != nil {
return fmt.Errorf("inserting instanceGroupManager failed: %w", err)
}
ops = append(ops, op)
c.coordinatorInstanceGroup = coordinatorGroupInput.Name
if err := c.waitForOperations(ctx, ops); err != nil {
return err
}
if err := c.waitForInstanceGroupScaling(ctx, c.nodesInstanceGroup); err != nil {
return fmt.Errorf("waiting for instanceGroupScaling failed: %w", err)
}
if err := c.waitForInstanceGroupScaling(ctx, c.coordinatorInstanceGroup); err != nil {
return fmt.Errorf("waiting for instanceGroupScaling failed: %w", err)
}
if err := c.getInstanceIPs(ctx, c.nodesInstanceGroup, c.nodes); err != nil {
return fmt.Errorf("failed to get instanceIPs: %w", err)
}
if err := c.getInstanceIPs(ctx, c.coordinatorInstanceGroup, c.coordinators); err != nil {
return fmt.Errorf("failed to get instanceIPs: %w", err)
}
return nil
}
// TerminateInstances terminates the clients instances.
func (c *Client) TerminateInstances(ctx context.Context) error {
ops := []Operation{}
if c.nodesInstanceGroup != "" {
op, err := c.deleteInstanceGroupManager(ctx, c.nodesInstanceGroup)
if err != nil {
return fmt.Errorf("deleting instanceGroupManager '%s' failed: %w", c.nodesInstanceGroup, err)
}
ops = append(ops, op)
c.nodesInstanceGroup = ""
c.nodes = make(gcp.Instances)
}
if c.coordinatorInstanceGroup != "" {
op, err := c.deleteInstanceGroupManager(ctx, c.coordinatorInstanceGroup)
if err != nil {
return fmt.Errorf("deleting instanceGroupManager '%s' failed: %w", c.coordinatorInstanceGroup, err)
}
ops = append(ops, op)
c.coordinatorInstanceGroup = ""
c.coordinators = make(gcp.Instances)
}
if err := c.waitForOperations(ctx, ops); err != nil {
return err
}
ops = []Operation{}
if c.nodeTemplate != "" {
op, err := c.deleteInstanceTemplate(ctx, c.nodeTemplate)
if err != nil {
return fmt.Errorf("deleting instanceTemplate failed: %w", err)
}
ops = append(ops, op)
c.nodeTemplate = ""
}
if c.coordinatorTemplate != "" {
op, err := c.deleteInstanceTemplate(ctx, c.coordinatorTemplate)
if err != nil {
return fmt.Errorf("deleting instanceTemplate failed: %w", err)
}
ops = append(ops, op)
c.coordinatorTemplate = ""
}
return c.waitForOperations(ctx, ops)
}
func (c *Client) insertInstanceTemplate(ctx context.Context, input insertInstanceTemplateInput) (Operation, error) {
req := input.insertInstanceTemplateRequest()
return c.instanceTemplateAPI.Insert(ctx, req)
}
func (c *Client) deleteInstanceTemplate(ctx context.Context, name string) (Operation, error) {
req := &computepb.DeleteInstanceTemplateRequest{
InstanceTemplate: name,
Project: c.project,
}
return c.instanceTemplateAPI.Delete(ctx, req)
}
func (c *Client) insertInstanceGroupManger(ctx context.Context, input instanceGroupManagerInput) (Operation, error) {
req := input.InsertInstanceGroupManagerRequest()
return c.instanceGroupManagersAPI.Insert(ctx, &req)
}
func (c *Client) deleteInstanceGroupManager(ctx context.Context, instanceGroupManagerName string) (Operation, error) {
req := &computepb.DeleteInstanceGroupManagerRequest{
InstanceGroupManager: instanceGroupManagerName,
Project: c.project,
Zone: c.zone,
}
return c.instanceGroupManagersAPI.Delete(ctx, req)
}
func (c *Client) waitForInstanceGroupScaling(ctx context.Context, groupId string) error {
for {
if err := ctx.Err(); err != nil {
return err
}
listReq := &computepb.ListManagedInstancesInstanceGroupManagersRequest{
InstanceGroupManager: groupId,
Project: c.project,
Zone: c.zone,
}
it := c.instanceGroupManagersAPI.ListManagedInstances(ctx, listReq)
for {
resp, err := it.Next()
if errors.Is(err, iterator.Done) {
return nil
}
if err != nil {
return err
}
if resp.CurrentAction == nil {
return errors.New("currentAction is nil")
}
if *resp.CurrentAction != computepb.ManagedInstance_NONE.String() {
time.Sleep(5 * time.Second)
break
}
}
}
}
// getInstanceIPs requests the IPs of the client's instances.
func (c *Client) getInstanceIPs(ctx context.Context, groupId string, list gcp.Instances) error {
req := &computepb.ListInstancesRequest{
Filter: proto.String("name=" + groupId + "*"),
Project: c.project,
Zone: c.zone,
}
it := c.instanceAPI.List(ctx, req)
for {
resp, err := it.Next()
if errors.Is(err, iterator.Done) {
return nil
}
if err != nil {
return err
}
if resp.Name == nil {
return errors.New("instance name is nil pointer")
}
if len(resp.NetworkInterfaces) == 0 {
return errors.New("network interface is empty")
}
if resp.NetworkInterfaces[0].NetworkIP == nil {
return errors.New("networkIP is nil")
}
if len(resp.NetworkInterfaces[0].AccessConfigs) == 0 {
return errors.New("access configs is empty")
}
if resp.NetworkInterfaces[0].AccessConfigs[0].NatIP == nil {
return errors.New("natIP is nil")
}
instance := gcp.Instance{
PrivateIP: *resp.NetworkInterfaces[0].NetworkIP,
PublicIP: *resp.NetworkInterfaces[0].AccessConfigs[0].NatIP,
}
list[*resp.Name] = instance
}
}
type instanceGroupManagerInput struct {
Count int
Name string
Template string
Project string
Zone string
UID string
}
func (i *instanceGroupManagerInput) InsertInstanceGroupManagerRequest() computepb.InsertInstanceGroupManagerRequest {
return computepb.InsertInstanceGroupManagerRequest{
InstanceGroupManagerResource: &computepb.InstanceGroupManager{
BaseInstanceName: proto.String(i.Name),
InstanceTemplate: proto.String("projects/" + i.Project + "/global/instanceTemplates/" + i.Template),
Name: proto.String(i.Name),
TargetSize: proto.Int32(int32(i.Count)),
},
Project: i.Project,
Zone: i.Zone,
}
}
// CreateInstancesInput is the input for a CreatInstances operation.
type CreateInstancesInput struct {
Count int
ImageId string
InstanceType string
KubeEnv string
DisableCVM bool
}
type insertInstanceTemplateInput struct {
Name string
Network string
Subnetwork string
SecondarySubnetworkRangeName string
ImageId string
InstanceType string
KubeEnv string
Project string
Zone string
Region string
UID string
DisableCVM bool
}
func (i insertInstanceTemplateInput) insertInstanceTemplateRequest() *computepb.InsertInstanceTemplateRequest {
req := computepb.InsertInstanceTemplateRequest{
InstanceTemplateResource: &computepb.InstanceTemplate{
Description: proto.String("This instance belongs to a Constellation."),
Name: proto.String(i.Name),
Properties: &computepb.InstanceProperties{
ConfidentialInstanceConfig: &computepb.ConfidentialInstanceConfig{
EnableConfidentialCompute: proto.Bool(!i.DisableCVM),
},
Description: proto.String("This instance belongs to a Constellation."),
Disks: []*computepb.AttachedDisk{
{
InitializeParams: &computepb.AttachedDiskInitializeParams{
DiskSizeGb: proto.Int64(10),
SourceImage: proto.String("projects/" + i.Project + "/global/images/" + i.ImageId),
},
AutoDelete: proto.Bool(true),
Boot: proto.Bool(true),
Mode: proto.String(computepb.AttachedDisk_READ_WRITE.String()),
},
},
MachineType: proto.String(i.InstanceType),
Metadata: &computepb.Metadata{
Items: []*computepb.Items{
{
Key: proto.String("kube-env"),
Value: proto.String(i.KubeEnv),
},
{
Key: proto.String("constellation-uid"),
Value: proto.String(i.UID),
},
},
},
NetworkInterfaces: []*computepb.NetworkInterface{
{
Network: proto.String("projects/" + i.Project + "/global/networks/" + i.Network),
Subnetwork: proto.String("regions/" + i.Region + "/subnetworks/" + i.Subnetwork),
AccessConfigs: []*computepb.AccessConfig{
{Type: proto.String(computepb.AccessConfig_ONE_TO_ONE_NAT.String())},
},
},
},
Scheduling: &computepb.Scheduling{
OnHostMaintenance: proto.String(computepb.Scheduling_TERMINATE.String()),
},
ServiceAccounts: []*computepb.ServiceAccount{
{
Scopes: []string{
"https://www.googleapis.com/auth/compute",
"https://www.googleapis.com/auth/servicecontrol",
"https://www.googleapis.com/auth/service.management.readonly",
"https://www.googleapis.com/auth/devstorage.read_only",
"https://www.googleapis.com/auth/logging.write",
"https://www.googleapis.com/auth/monitoring.write",
"https://www.googleapis.com/auth/trace.append",
},
},
},
ShieldedInstanceConfig: &computepb.ShieldedInstanceConfig{
EnableIntegrityMonitoring: proto.Bool(true),
EnableSecureBoot: proto.Bool(true),
EnableVtpm: proto.Bool(true),
},
Tags: &computepb.Tags{
Items: []string{"constellation"},
},
},
},
Project: i.Project,
}
// if there is an secondary IP range defined, we use it as an alias IP range
if i.SecondarySubnetworkRangeName != "" {
req.InstanceTemplateResource.Properties.NetworkInterfaces[0].AliasIpRanges = []*computepb.AliasIpRange{
{
IpCidrRange: proto.String("/24"),
SubnetworkRangeName: proto.String(i.SecondarySubnetworkRangeName),
},
}
}
return &req
}

View file

@ -0,0 +1,262 @@
package client
import (
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/stretchr/testify/assert"
computepb "google.golang.org/genproto/googleapis/cloud/compute/v1"
"google.golang.org/protobuf/proto"
)
func TestCreateInstances(t *testing.T) {
testInstances := []*computepb.Instance{
{
Name: proto.String("instance-name-1"),
NetworkInterfaces: []*computepb.NetworkInterface{
{
AccessConfigs: []*computepb.AccessConfig{
{NatIP: proto.String("public-ip")},
},
NetworkIP: proto.String("private-ip"),
},
},
},
{
Name: proto.String("instance-name-2"),
NetworkInterfaces: []*computepb.NetworkInterface{
{
AccessConfigs: []*computepb.AccessConfig{
{NatIP: proto.String("public-ip")},
},
NetworkIP: proto.String("private-ip"),
},
},
},
}
testManagedInstances := []*computepb.ManagedInstance{
{CurrentAction: proto.String(computepb.ManagedInstance_NONE.String())},
{CurrentAction: proto.String(computepb.ManagedInstance_NONE.String())},
}
testInput := CreateInstancesInput{
Count: 3,
ImageId: "img",
InstanceType: "n2d-standard-2",
KubeEnv: "kube-env",
}
someErr := errors.New("failed")
testCases := map[string]struct {
instanceAPI instanceAPI
operationZoneAPI operationZoneAPI
operationGlobalAPI operationGlobalAPI
instanceTemplateAPI instanceTemplateAPI
instanceGroupManagersAPI instanceGroupManagersAPI
input CreateInstancesInput
network string
errExpected bool
}{
"successful create": {
instanceAPI: stubInstanceAPI{listIterator: &stubInstanceIterator{instances: testInstances}},
operationZoneAPI: stubOperationZoneAPI{},
operationGlobalAPI: stubOperationGlobalAPI{},
instanceTemplateAPI: stubInstanceTemplateAPI{},
instanceGroupManagersAPI: stubInstanceGroupManagersAPI{listIterator: &stubManagedInstanceIterator{instances: testManagedInstances}},
network: "network",
input: testInput,
},
"failed no network": {
instanceAPI: stubInstanceAPI{listIterator: &stubInstanceIterator{instances: testInstances}},
operationZoneAPI: stubOperationZoneAPI{waitErr: someErr},
operationGlobalAPI: stubOperationGlobalAPI{},
instanceTemplateAPI: stubInstanceTemplateAPI{},
instanceGroupManagersAPI: stubInstanceGroupManagersAPI{listIterator: &stubManagedInstanceIterator{instances: testManagedInstances}},
input: testInput,
errExpected: true,
},
"failed wait zonal op": {
instanceAPI: stubInstanceAPI{listIterator: &stubInstanceIterator{instances: testInstances}},
operationZoneAPI: stubOperationZoneAPI{waitErr: someErr},
operationGlobalAPI: stubOperationGlobalAPI{},
instanceTemplateAPI: stubInstanceTemplateAPI{},
instanceGroupManagersAPI: stubInstanceGroupManagersAPI{listIterator: &stubManagedInstanceIterator{instances: testManagedInstances}},
network: "network",
input: testInput,
errExpected: true,
},
"failed wait global op": {
instanceAPI: stubInstanceAPI{listIterator: &stubInstanceIterator{instances: testInstances}},
operationZoneAPI: stubOperationZoneAPI{},
operationGlobalAPI: stubOperationGlobalAPI{waitErr: someErr},
instanceTemplateAPI: stubInstanceTemplateAPI{},
instanceGroupManagersAPI: stubInstanceGroupManagersAPI{listIterator: &stubManagedInstanceIterator{instances: testManagedInstances}},
network: "network",
input: testInput,
errExpected: true,
},
"failed insert template": {
instanceAPI: stubInstanceAPI{listIterator: &stubInstanceIterator{instances: testInstances}},
operationZoneAPI: stubOperationZoneAPI{},
operationGlobalAPI: stubOperationGlobalAPI{},
instanceTemplateAPI: stubInstanceTemplateAPI{insertErr: someErr},
instanceGroupManagersAPI: stubInstanceGroupManagersAPI{listIterator: &stubManagedInstanceIterator{instances: testManagedInstances}},
input: testInput,
network: "network",
errExpected: true,
},
"failed insert instanceGroupManager": {
instanceAPI: stubInstanceAPI{listIterator: &stubInstanceIterator{instances: testInstances}},
operationZoneAPI: stubOperationZoneAPI{},
operationGlobalAPI: stubOperationGlobalAPI{},
instanceTemplateAPI: stubInstanceTemplateAPI{},
instanceGroupManagersAPI: stubInstanceGroupManagersAPI{insertErr: someErr},
network: "network",
input: testInput,
errExpected: true,
},
"failed instanceGroupManager iterator": {
instanceAPI: stubInstanceAPI{listIterator: &stubInstanceIterator{instances: testInstances}},
operationZoneAPI: stubOperationZoneAPI{},
operationGlobalAPI: stubOperationGlobalAPI{},
instanceTemplateAPI: stubInstanceTemplateAPI{},
instanceGroupManagersAPI: stubInstanceGroupManagersAPI{listIterator: &stubManagedInstanceIterator{nextErr: someErr}},
network: "network",
input: testInput,
errExpected: true,
},
"failed instance iterator": {
instanceAPI: stubInstanceAPI{listIterator: &stubInstanceIterator{nextErr: someErr}},
operationZoneAPI: stubOperationZoneAPI{},
operationGlobalAPI: stubOperationGlobalAPI{},
instanceTemplateAPI: stubInstanceTemplateAPI{},
instanceGroupManagersAPI: stubInstanceGroupManagersAPI{listIterator: &stubManagedInstanceIterator{instances: testManagedInstances}},
network: "network",
input: testInput,
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
project: "project",
zone: "zone",
name: "name",
uid: "uid",
network: tc.network,
subnetwork: "subnetwork",
secondarySubnetworkRange: "secondary-range",
instanceAPI: tc.instanceAPI,
operationZoneAPI: tc.operationZoneAPI,
operationGlobalAPI: tc.operationGlobalAPI,
instanceTemplateAPI: tc.instanceTemplateAPI,
instanceGroupManagersAPI: tc.instanceGroupManagersAPI,
nodes: make(gcp.Instances),
coordinators: make(gcp.Instances),
}
if tc.errExpected {
assert.Error(client.CreateInstances(ctx, tc.input))
} else {
assert.NoError(client.CreateInstances(ctx, tc.input))
assert.Equal([]string{"public-ip", "public-ip"}, client.nodes.PublicIPs())
assert.Equal([]string{"private-ip", "private-ip"}, client.nodes.PrivateIPs())
assert.Equal([]string{"public-ip", "public-ip"}, client.coordinators.PublicIPs())
assert.Equal([]string{"private-ip", "private-ip"}, client.coordinators.PrivateIPs())
assert.NotNil(client.nodesInstanceGroup)
assert.NotNil(client.coordinatorInstanceGroup)
assert.NotNil(client.coordinatorTemplate)
assert.NotNil(client.nodeTemplate)
}
})
}
}
func TestTerminateInstances(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
operationZoneAPI operationZoneAPI
operationGlobalAPI operationGlobalAPI
instanceTemplateAPI instanceTemplateAPI
instanceGroupManagersAPI instanceGroupManagersAPI
missingNodeInstanceGroup bool
errExpected bool
}{
"successful terminate": {
operationZoneAPI: stubOperationZoneAPI{},
operationGlobalAPI: stubOperationGlobalAPI{},
instanceTemplateAPI: stubInstanceTemplateAPI{},
instanceGroupManagersAPI: stubInstanceGroupManagersAPI{},
},
"successful terminate with missing node instance group": {
operationZoneAPI: stubOperationZoneAPI{},
operationGlobalAPI: stubOperationGlobalAPI{},
instanceTemplateAPI: stubInstanceTemplateAPI{},
instanceGroupManagersAPI: stubInstanceGroupManagersAPI{},
missingNodeInstanceGroup: true,
},
"fail delete instanceGroupManager": {
operationZoneAPI: stubOperationZoneAPI{},
operationGlobalAPI: stubOperationGlobalAPI{},
instanceTemplateAPI: stubInstanceTemplateAPI{},
instanceGroupManagersAPI: stubInstanceGroupManagersAPI{deleteErr: someErr},
errExpected: true,
},
"fail delete instanceTemplate": {
operationZoneAPI: stubOperationZoneAPI{},
operationGlobalAPI: stubOperationGlobalAPI{},
instanceTemplateAPI: stubInstanceTemplateAPI{deleteErr: someErr},
instanceGroupManagersAPI: stubInstanceGroupManagersAPI{},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
project: "project",
zone: "zone",
name: "name",
uid: "uid",
operationZoneAPI: tc.operationZoneAPI,
operationGlobalAPI: tc.operationGlobalAPI,
instanceTemplateAPI: tc.instanceTemplateAPI,
instanceGroupManagersAPI: tc.instanceGroupManagersAPI,
nodes: gcp.Instances{"node-id-1": gcp.Instance{}, "node-id-2": gcp.Instance{}},
coordinators: gcp.Instances{"coordinator-id-1": gcp.Instance{}},
firewalls: []string{"firewall-1", "firewall-2"},
network: "network-id-1",
nodesInstanceGroup: "nodeInstanceGroup-id-1",
coordinatorInstanceGroup: "coordinatorInstanceGroup-id-1",
nodeTemplate: "template-id-1",
coordinatorTemplate: "template-id-1",
}
if tc.missingNodeInstanceGroup {
client.nodesInstanceGroup = ""
client.nodes = gcp.Instances{}
}
if tc.errExpected {
assert.Error(client.TerminateInstances(ctx))
} else {
assert.NoError(client.TerminateInstances(ctx))
assert.Nil(client.nodes.PublicIPs())
assert.Nil(client.nodes.PrivateIPs())
assert.Nil(client.coordinators.PublicIPs())
assert.Nil(client.coordinators.PrivateIPs())
assert.Empty(client.nodesInstanceGroup)
assert.Empty(client.coordinatorInstanceGroup)
assert.Empty(client.coordinatorTemplate)
assert.Empty(client.nodeTemplate)
}
})
}
}

202
cli/gcp/client/network.go Normal file
View file

@ -0,0 +1,202 @@
package client
import (
"context"
"errors"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
computepb "google.golang.org/genproto/googleapis/cloud/compute/v1"
"google.golang.org/protobuf/proto"
)
// CreateFirewall creates a set of firewall rules for the client's network.
//
// The client must have a VPC network to set firewall rules.
func (c *Client) CreateFirewall(ctx context.Context, input FirewallInput) error {
if c.network == "" {
return errors.New("client has not network")
}
firewallRules := input.Ingress.GCP()
var ops []Operation
for _, rule := range firewallRules {
c.firewalls = append(c.firewalls, rule.GetName())
rule.Network = proto.String("global/networks/" + c.network)
rule.Name = proto.String(rule.GetName() + "-" + c.uid)
req := &computepb.InsertFirewallRequest{
FirewallResource: rule,
Project: c.project,
}
resp, err := c.firewallsAPI.Insert(ctx, req)
if err != nil {
return err
}
if resp.Proto().Name == nil {
return errors.New("operation name is nil")
}
ops = append(ops, resp)
}
return c.waitForOperations(ctx, ops)
}
// TerminateFirewall deletes firewall rules from the client's network.
//
// The client must have a VPC network to set firewall rules.
func (c *Client) TerminateFirewall(ctx context.Context) error {
if len(c.firewalls) == 0 {
return nil
}
var ops []Operation
for _, name := range c.firewalls {
ruleName := name + "-" + c.uid
req := &computepb.DeleteFirewallRequest{
Firewall: ruleName,
Project: c.project,
}
resp, err := c.firewallsAPI.Delete(ctx, req)
if err != nil {
return err
}
if resp.Proto().Name == nil {
return errors.New("operation name is nil")
}
ops = append(ops, resp)
}
if err := c.waitForOperations(ctx, ops); err != nil {
return err
}
c.firewalls = []string{}
return nil
}
// FirewallInput defines firewall rules to be set.
type FirewallInput struct {
Ingress cloudtypes.Firewall
Egress cloudtypes.Firewall
}
// VPCsInput defines the VPC configuration.
type VPCsInput struct {
SubnetCIDR string
SubnetExtCIDR string
}
// CreateVPCs creates all necessary VPC networks.
func (c *Client) CreateVPCs(ctx context.Context, input VPCsInput) error {
c.network = c.name + "-" + c.uid
op, err := c.createVPC(ctx, c.network)
if err != nil {
return err
}
if err := c.waitForOperations(ctx, []Operation{op}); err != nil {
return err
}
if err := c.createSubnets(ctx, input.SubnetCIDR); err != nil {
return err
}
return nil
}
// createVPC creates a VPC network.
func (c *Client) createVPC(ctx context.Context, name string) (Operation, error) {
req := &computepb.InsertNetworkRequest{
NetworkResource: &computepb.Network{
AutoCreateSubnetworks: proto.Bool(false),
Description: proto.String("Constellation VPC"),
Name: proto.String(name),
},
Project: c.project,
}
return c.networksAPI.Insert(ctx, req)
}
// TerminateVPCs terminates all VPC networks.
//
// If the any network has firewall rules, these must be terminated first.
func (c *Client) TerminateVPCs(ctx context.Context) error {
if len(c.firewalls) != 0 {
return errors.New("client has firewalls, which must be deleted first")
}
if err := c.terminateSubnets(ctx); err != nil {
return err
}
var op Operation
var err error
if c.network != "" {
op, err = c.terminateVPC(ctx, c.network)
if err != nil {
return err
}
c.network = ""
}
return c.waitForOperations(ctx, []Operation{op})
}
// terminateVPC terminates a VPC network.
//
// If the network has firewall rules, these must be terminated first.
func (c *Client) terminateVPC(ctx context.Context, network string) (Operation, error) {
req := &computepb.DeleteNetworkRequest{
Project: c.project,
Network: network,
}
return c.networksAPI.Delete(ctx, req)
}
func (c *Client) createSubnets(ctx context.Context, subnetCIDR string) error {
c.subnetwork = "node-net-" + c.uid
c.secondarySubnetworkRange = "net-ext" + c.uid
op, err := c.createSubnet(ctx, c.subnetwork, subnetCIDR, c.network, c.secondarySubnetworkRange)
if err != nil {
return err
}
return c.waitForOperations(ctx, []Operation{op})
}
func (c *Client) createSubnet(ctx context.Context, name, cidr, network, secondaryRangeName string) (Operation, error) {
req := &computepb.InsertSubnetworkRequest{
Project: c.project,
Region: c.region,
SubnetworkResource: &computepb.Subnetwork{
IpCidrRange: proto.String(cidr),
Name: proto.String(name),
Network: proto.String("projects/" + c.project + "/global/networks/" + network),
SecondaryIpRanges: []*computepb.SubnetworkSecondaryRange{
{
RangeName: proto.String(secondaryRangeName),
IpCidrRange: proto.String("10.10.0.0/16"),
},
},
},
}
return c.subnetworksAPI.Insert(ctx, req)
}
func (c *Client) terminateSubnets(ctx context.Context) error {
var op Operation
var err error
if c.subnetwork != "" {
op, err = c.terminateSubnet(ctx, c.subnetwork)
if err != nil {
return err
}
}
return c.waitForOperations(ctx, []Operation{op})
}
func (c *Client) terminateSubnet(ctx context.Context, name string) (Operation, error) {
req := &computepb.DeleteSubnetworkRequest{
Project: c.project,
Region: c.region,
Subnetwork: name,
}
return c.subnetworksAPI.Delete(ctx, req)
}

View file

@ -0,0 +1,302 @@
package client
import (
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/stretchr/testify/assert"
)
func TestCreateVPCs(t *testing.T) {
someErr := errors.New("failed")
testInput := VPCsInput{
SubnetCIDR: "192.0.2.0/24",
SubnetExtCIDR: "198.51.100.0/24",
}
testCases := map[string]struct {
operationGlobalAPI operationGlobalAPI
operationRegionAPI operationRegionAPI
networksAPI networksAPI
subnetworksAPI subnetworksAPI
errExpected bool
}{
"successful create": {
operationGlobalAPI: stubOperationGlobalAPI{},
operationRegionAPI: stubOperationRegionAPI{},
networksAPI: stubNetworksAPI{},
subnetworksAPI: stubSubnetworksAPI{},
},
"failed wait global op": {
operationGlobalAPI: stubOperationGlobalAPI{waitErr: someErr},
operationRegionAPI: stubOperationRegionAPI{},
networksAPI: stubNetworksAPI{},
subnetworksAPI: stubSubnetworksAPI{},
errExpected: true,
},
"failed wait region op": {
operationGlobalAPI: stubOperationGlobalAPI{},
operationRegionAPI: stubOperationRegionAPI{waitErr: someErr},
networksAPI: stubNetworksAPI{},
subnetworksAPI: stubSubnetworksAPI{},
errExpected: true,
},
"failed insert networks": {
operationGlobalAPI: stubOperationGlobalAPI{},
operationRegionAPI: stubOperationRegionAPI{},
networksAPI: stubNetworksAPI{insertErr: someErr},
subnetworksAPI: stubSubnetworksAPI{},
errExpected: true,
},
"failed insert subnetworks": {
operationGlobalAPI: stubOperationGlobalAPI{},
operationRegionAPI: stubOperationRegionAPI{},
networksAPI: stubNetworksAPI{},
subnetworksAPI: stubSubnetworksAPI{insertErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
project: "project",
zone: "zone",
name: "name",
uid: "uid",
operationGlobalAPI: tc.operationGlobalAPI,
operationRegionAPI: tc.operationRegionAPI,
networksAPI: tc.networksAPI,
subnetworksAPI: tc.subnetworksAPI,
nodes: make(gcp.Instances),
coordinators: make(gcp.Instances),
}
if tc.errExpected {
assert.Error(client.CreateVPCs(ctx, testInput))
} else {
assert.NoError(client.CreateVPCs(ctx, testInput))
assert.NotNil(client.network)
}
})
}
}
func TestTerminateVPCs(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
operationGlobalAPI operationGlobalAPI
operationRegionAPI operationRegionAPI
networksAPI networksAPI
subnetworksAPI subnetworksAPI
firewalls []string
errExpected bool
}{
"successful terminate": {
operationGlobalAPI: stubOperationGlobalAPI{},
operationRegionAPI: stubOperationRegionAPI{},
networksAPI: stubNetworksAPI{},
subnetworksAPI: stubSubnetworksAPI{},
},
"failed wait global op": {
operationGlobalAPI: stubOperationGlobalAPI{waitErr: someErr},
operationRegionAPI: stubOperationRegionAPI{},
networksAPI: stubNetworksAPI{},
subnetworksAPI: stubSubnetworksAPI{},
errExpected: true,
},
"failed delete networks": {
operationGlobalAPI: stubOperationGlobalAPI{},
operationRegionAPI: stubOperationRegionAPI{},
networksAPI: stubNetworksAPI{deleteErr: someErr},
subnetworksAPI: stubSubnetworksAPI{},
errExpected: true,
},
"failed delete subnetworks": {
operationGlobalAPI: stubOperationGlobalAPI{},
operationRegionAPI: stubOperationRegionAPI{},
networksAPI: stubNetworksAPI{},
subnetworksAPI: stubSubnetworksAPI{deleteErr: someErr},
errExpected: true,
},
"must delete firewalls first": {
firewalls: []string{"firewall-1", "firewall-2"},
operationRegionAPI: stubOperationRegionAPI{},
operationGlobalAPI: stubOperationGlobalAPI{},
networksAPI: stubNetworksAPI{},
subnetworksAPI: stubSubnetworksAPI{},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
project: "project",
zone: "zone",
name: "name",
uid: "uid",
operationGlobalAPI: tc.operationGlobalAPI,
operationRegionAPI: tc.operationRegionAPI,
networksAPI: tc.networksAPI,
subnetworksAPI: tc.subnetworksAPI,
firewalls: tc.firewalls,
network: "network-id-1",
subnetwork: "subnetwork-id-1",
}
if tc.errExpected {
assert.Error(client.TerminateVPCs(ctx))
} else {
assert.NoError(client.TerminateVPCs(ctx))
assert.Empty(client.network)
}
})
}
}
func TestCreateFirewall(t *testing.T) {
someErr := errors.New("failed")
testFirewallInput := FirewallInput{
Ingress: cloudtypes.Firewall{
cloudtypes.FirewallRule{
Name: "test-1",
Description: "test-1 description",
Protocol: "tcp",
IPRange: "192.0.2.0/24",
Port: 9000,
},
cloudtypes.FirewallRule{
Name: "test-2",
Description: "test-2 description",
Protocol: "udp",
IPRange: "192.0.2.0/24",
Port: 51820,
},
},
Egress: cloudtypes.Firewall{},
}
testCases := map[string]struct {
network string
operationGlobalAPI operationGlobalAPI
firewallsAPI firewallsAPI
firewallInput FirewallInput
errExpected bool
}{
"successful create": {
network: "network",
operationGlobalAPI: stubOperationGlobalAPI{},
firewallsAPI: stubFirewallsAPI{},
},
"failed wait global op": {
network: "network",
operationGlobalAPI: stubOperationGlobalAPI{waitErr: someErr},
firewallsAPI: stubFirewallsAPI{},
errExpected: true,
},
"failed insert networks": {
network: "network",
operationGlobalAPI: stubOperationGlobalAPI{},
firewallsAPI: stubFirewallsAPI{insertErr: someErr},
errExpected: true,
},
"no network set": {
operationGlobalAPI: stubOperationGlobalAPI{},
firewallsAPI: stubFirewallsAPI{},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
project: "project",
zone: "zone",
name: "name",
uid: "uid",
network: tc.network,
operationGlobalAPI: tc.operationGlobalAPI,
firewallsAPI: tc.firewallsAPI,
}
if tc.errExpected {
assert.Error(client.CreateFirewall(ctx, testFirewallInput))
} else {
assert.NoError(client.CreateFirewall(ctx, testFirewallInput))
assert.ElementsMatch([]string{"test-1", "test-2"}, client.firewalls)
}
})
}
}
func TestTerminateFirewall(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
operationGlobalAPI operationGlobalAPI
firewallsAPI firewallsAPI
firewalls []string
errExpected bool
}{
"successful terminate": {
operationGlobalAPI: stubOperationGlobalAPI{},
firewallsAPI: stubFirewallsAPI{},
firewalls: []string{"firewall-1", "firewall-2"},
},
"successful terminate when no firewall exists": {
operationGlobalAPI: stubOperationGlobalAPI{},
firewallsAPI: stubFirewallsAPI{},
firewalls: []string{},
},
"failed to wait on global operation": {
operationGlobalAPI: stubOperationGlobalAPI{waitErr: someErr},
firewallsAPI: stubFirewallsAPI{},
firewalls: []string{"firewall-1", "firewall-2"},
errExpected: true,
},
"failed to delete firewalls": {
operationGlobalAPI: stubOperationGlobalAPI{},
firewallsAPI: stubFirewallsAPI{deleteErr: someErr},
firewalls: []string{"firewall-1", "firewall-2"},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
project: "project",
zone: "zone",
name: "name",
uid: "uid",
firewalls: tc.firewalls,
operationGlobalAPI: tc.operationGlobalAPI,
firewallsAPI: tc.firewallsAPI,
}
if tc.errExpected {
assert.Error(client.TerminateFirewall(ctx))
} else {
assert.NoError(client.TerminateFirewall(ctx))
assert.Empty(client.firewalls)
}
})
}
}

View file

@ -0,0 +1,98 @@
package client
import (
"context"
"fmt"
computepb "google.golang.org/genproto/googleapis/cloud/compute/v1"
)
// waitForOperations waits until every operation in the opIDs slice is
// done or returns the first occurring error.
func (c *Client) waitForOperations(ctx context.Context, ops []Operation) error {
for _, op := range ops {
switch {
case op.Proto().Zone != nil:
if err := c.waitForZoneOperation(ctx, op); err != nil {
return err
}
case op.Proto().Region != nil:
if err := c.waitForRegionOperation(ctx, op); err != nil {
return err
}
default:
if err := c.waitForGlobalOperation(ctx, op); err != nil {
return err
}
}
}
return nil
}
func (c *Client) waitForGlobalOperation(ctx context.Context, op Operation) error {
for {
if err := ctx.Err(); err != nil {
return err
}
waitReq := &computepb.WaitGlobalOperationRequest{
Operation: *op.Proto().Name,
Project: c.project,
}
zoneOp, err := c.operationGlobalAPI.Wait(ctx, waitReq)
if err != nil {
return fmt.Errorf("unable to wait for the operation: %w", err)
}
if *zoneOp.Status.Enum() == computepb.Operation_DONE {
if opErr := zoneOp.Error; opErr != nil {
return fmt.Errorf("operation failed: %s", opErr.String())
}
return nil
}
}
}
func (c *Client) waitForZoneOperation(ctx context.Context, op Operation) error {
for {
if err := ctx.Err(); err != nil {
return err
}
waitReq := &computepb.WaitZoneOperationRequest{
Operation: *op.Proto().Name,
Project: c.project,
Zone: c.zone,
}
zoneOp, err := c.operationZoneAPI.Wait(ctx, waitReq)
if err != nil {
return fmt.Errorf("unable to wait for the operation: %w", err)
}
if *zoneOp.Status.Enum() == computepb.Operation_DONE {
if opErr := zoneOp.Error; opErr != nil {
return fmt.Errorf("operation failed: %s", opErr.String())
}
return nil
}
}
}
func (c *Client) waitForRegionOperation(ctx context.Context, op Operation) error {
for {
if err := ctx.Err(); err != nil {
return err
}
waitReq := &computepb.WaitRegionOperationRequest{
Operation: *op.Proto().Name,
Project: c.project,
Region: c.region,
}
regionOp, err := c.operationRegionAPI.Wait(ctx, waitReq)
if err != nil {
return fmt.Errorf("unable to wait for the operation: %w", err)
}
if *regionOp.Status.Enum() == computepb.Operation_DONE {
if opErr := regionOp.Error; opErr != nil {
return fmt.Errorf("operation failed: %s", opErr.String())
}
return nil
}
}
}

71
cli/gcp/client/project.go Normal file
View file

@ -0,0 +1,71 @@
package client
import (
"context"
"fmt"
iampb "google.golang.org/genproto/googleapis/iam/v1"
)
// addIAMPolicyBindings adds a GCP service account to roles specified in the input.
func (c *Client) addIAMPolicyBindings(ctx context.Context, input AddIAMPolicyBindingInput) error {
getReq := &iampb.GetIamPolicyRequest{
Resource: "projects/" + c.project,
}
policy, err := c.projectsAPI.GetIamPolicy(ctx, getReq)
if err != nil {
return fmt.Errorf("retrieving current iam policy failed: %w", err)
}
for _, binding := range input.Bindings {
addIAMPolicy(policy, binding)
}
setReq := &iampb.SetIamPolicyRequest{
Resource: "projects/" + c.project,
Policy: policy,
}
if _, err := c.projectsAPI.SetIamPolicy(ctx, setReq); err != nil {
return fmt.Errorf("setting new iam policy failed: %w", err)
}
return nil
}
// PolicyBinding is a GCP IAM policy binding.
type PolicyBinding struct {
ServiceAccount string
Role string
}
// addIAMPolicy inserts policy binding for service account and role to an existing iam policy.
func addIAMPolicy(policy *iampb.Policy, policyBinding PolicyBinding) {
var binding *iampb.Binding
for _, existingBinding := range policy.Bindings {
if existingBinding.Role == policyBinding.Role && existingBinding.Condition == nil {
binding = existingBinding
break
}
}
if binding == nil {
binding = &iampb.Binding{
Role: policyBinding.Role,
}
policy.Bindings = append(policy.Bindings, binding)
}
// add service account to role, if not already a member
member := "serviceAccount:" + policyBinding.ServiceAccount
var alreadyMember bool
for _, existingMember := range binding.Members {
if member == existingMember {
alreadyMember = true
break
}
}
if !alreadyMember {
binding.Members = append(binding.Members, member)
}
}
// AddIAMPolicyBindingInput is the input for an AddIAMPolicyBinding operation.
type AddIAMPolicyBindingInput struct {
Bindings []PolicyBinding
}

View file

@ -0,0 +1,177 @@
package client
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
iampb "google.golang.org/genproto/googleapis/iam/v1"
"google.golang.org/protobuf/proto"
)
func TestAddIAMPolicyBindings(t *testing.T) {
someErr := errors.New("someErr")
testCases := map[string]struct {
projectsAPI stubProjectsAPI
input AddIAMPolicyBindingInput
errExpected bool
}{
"successful set without new bindings": {
input: AddIAMPolicyBindingInput{
Bindings: []PolicyBinding{},
},
},
"successful set with bindings": {
input: AddIAMPolicyBindingInput{
Bindings: []PolicyBinding{
{
ServiceAccount: "service-account",
Role: "role",
},
},
},
},
"retrieving iam policy fails": {
projectsAPI: stubProjectsAPI{
getPolicyErr: someErr,
},
errExpected: true,
},
"setting iam policy fails": {
projectsAPI: stubProjectsAPI{
setPolicyErr: someErr,
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
project: "project",
zone: "zone",
name: "name",
uid: "uid",
projectsAPI: tc.projectsAPI,
}
err := client.addIAMPolicyBindings(ctx, tc.input)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestAddIAMPolicy(t *testing.T) {
testCases := map[string]struct {
binding PolicyBinding
policy *iampb.Policy
errExpected bool
policyExpected *iampb.Policy
}{
"successful on empty policy": {
binding: PolicyBinding{
ServiceAccount: "service-account",
Role: "role",
},
policy: &iampb.Policy{
Bindings: []*iampb.Binding{},
},
policyExpected: &iampb.Policy{
Bindings: []*iampb.Binding{
{
Role: "role",
Members: []string{"serviceAccount:service-account"},
},
},
},
},
"successful on existing policy with different role": {
binding: PolicyBinding{
ServiceAccount: "service-account",
Role: "role",
},
policy: &iampb.Policy{
Bindings: []*iampb.Binding{
{
Role: "other-role",
Members: []string{"other-member"},
},
},
},
policyExpected: &iampb.Policy{
Bindings: []*iampb.Binding{
{
Role: "other-role",
Members: []string{"other-member"},
},
{
Role: "role",
Members: []string{"serviceAccount:service-account"},
},
},
},
},
"successful on existing policy with existing role": {
binding: PolicyBinding{
ServiceAccount: "service-account",
Role: "role",
},
policy: &iampb.Policy{
Bindings: []*iampb.Binding{
{
Role: "role",
Members: []string{"other-member"},
},
},
},
policyExpected: &iampb.Policy{
Bindings: []*iampb.Binding{
{
Role: "role",
Members: []string{"other-member", "serviceAccount:service-account"},
},
},
},
},
"already a member": {
binding: PolicyBinding{
ServiceAccount: "service-account",
Role: "role",
},
policy: &iampb.Policy{
Bindings: []*iampb.Binding{
{
Role: "role",
Members: []string{"serviceAccount:service-account"},
},
},
},
policyExpected: &iampb.Policy{
Bindings: []*iampb.Binding{
{
Role: "role",
Members: []string{"serviceAccount:service-account"},
},
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
addIAMPolicy(tc.policy, tc.binding)
assert.True(proto.Equal(tc.policyExpected, tc.policy))
})
}
}

View file

@ -0,0 +1,152 @@
package client
import (
"context"
"encoding/json"
"fmt"
"net/url"
adminpb "google.golang.org/genproto/googleapis/iam/admin/v1"
)
// CreateServiceAccount creates a new GCP service account and returns an account key as service account URI.
func (c *Client) CreateServiceAccount(ctx context.Context, input ServiceAccountInput) (string, error) {
insertInput := insertServiceAccountInput{
Project: c.project,
AccountID: "constellation-app-" + c.uid,
DisplayName: "constellation-app-" + c.uid,
Description: "This service account belongs to a Constellation.",
}
email, err := c.insertServiceAccount(ctx, insertInput)
if err != nil {
return "", err
}
c.serviceAccount = email
iamInput := input.addIAMPolicyBindingInput(c.serviceAccount)
if err := c.addIAMPolicyBindings(ctx, iamInput); err != nil {
return "", err
}
key, err := c.createServiceAccountKey(ctx, email)
if err != nil {
return "", err
}
return key.ConvertToCloudServiceAccountURI(), nil
}
func (c *Client) TerminateServiceAccount(ctx context.Context) error {
if c.serviceAccount != "" {
req := &adminpb.DeleteServiceAccountRequest{
Name: "projects/-/serviceAccounts/" + c.serviceAccount,
}
if err := c.iamAPI.DeleteServiceAccount(ctx, req); err != nil {
return fmt.Errorf("deleting service account failed: %w", err)
}
c.serviceAccount = ""
}
return nil
}
type ServiceAccountInput struct {
Roles []string
}
func (i ServiceAccountInput) addIAMPolicyBindingInput(serviceAccount string) AddIAMPolicyBindingInput {
iamPolicyBindingInput := AddIAMPolicyBindingInput{
Bindings: make([]PolicyBinding, len(i.Roles)),
}
for i, role := range i.Roles {
iamPolicyBindingInput.Bindings[i] = PolicyBinding{
ServiceAccount: serviceAccount,
Role: role,
}
}
return iamPolicyBindingInput
}
// ServiceAccountKey is a GCP service account key.
type ServiceAccountKey struct {
Type string `json:"type"`
ProjectID string `json:"project_id"`
PrivateKeyID string `json:"private_key_id"`
PrivateKey string `json:"private_key"`
ClientEmail string `json:"client_email"`
ClientID string `json:"client_id"`
AuthURI string `json:"auth_uri"`
TokenURI string `json:"token_uri"`
AuthProviderX509CertURL string `json:"auth_provider_x509_cert_url"`
ClientX509CertURL string `json:"client_x509_cert_url"`
}
// ConvertToCloudServiceAccountURI converts the ServiceAccountKey into a cloud service account URI.
func (k ServiceAccountKey) ConvertToCloudServiceAccountURI() string {
query := url.Values{}
query.Add("type", k.Type)
query.Add("project_id", k.ProjectID)
query.Add("private_key_id", k.PrivateKeyID)
query.Add("private_key", k.PrivateKey)
query.Add("client_email", k.ClientEmail)
query.Add("client_id", k.ClientID)
query.Add("auth_uri", k.AuthURI)
query.Add("token_uri", k.TokenURI)
query.Add("auth_provider_x509_cert_url", k.AuthProviderX509CertURL)
query.Add("client_x509_cert_url", k.ClientX509CertURL)
uri := url.URL{
Scheme: "serviceaccount",
Host: "gcp",
RawQuery: query.Encode(),
}
return uri.String()
}
func (c *Client) insertServiceAccount(ctx context.Context, input insertServiceAccountInput) (string, error) {
req := input.createServiceAccountRequest()
account, err := c.iamAPI.CreateServiceAccount(ctx, req)
if err != nil {
return "", err
}
return account.Email, nil
}
func (c *Client) createServiceAccountKey(ctx context.Context, email string) (ServiceAccountKey, error) {
req := createServiceAccountKeyRequest(email)
key, err := c.iamAPI.CreateServiceAccountKey(ctx, req)
if err != nil {
return ServiceAccountKey{}, fmt.Errorf("creating service account key failed: %w", err)
}
var serviceAccountKey ServiceAccountKey
if err := json.Unmarshal(key.PrivateKeyData, &serviceAccountKey); err != nil {
return ServiceAccountKey{}, fmt.Errorf("decoding service account key JSON failed: %w", err)
}
return serviceAccountKey, nil
}
// insertServiceAccountInput is the input for a createServiceAccount operation.
type insertServiceAccountInput struct {
Project string
AccountID string
DisplayName string
Description string
}
func (c insertServiceAccountInput) createServiceAccountRequest() *adminpb.CreateServiceAccountRequest {
return &adminpb.CreateServiceAccountRequest{
Name: "projects/" + c.Project,
AccountId: c.AccountID,
ServiceAccount: &adminpb.ServiceAccount{
DisplayName: c.DisplayName,
Description: c.Description,
},
}
}
func createServiceAccountKeyRequest(email string) *adminpb.CreateServiceAccountKeyRequest {
return &adminpb.CreateServiceAccountKeyRequest{
Name: "projects/-/serviceAccounts/" + email,
}
}

View file

@ -0,0 +1,174 @@
package client
import (
"context"
"encoding/json"
"errors"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateServiceAccount(t *testing.T) {
require := require.New(t)
someErr := errors.New("someErr")
key := ServiceAccountKey{
Type: "type",
ProjectID: "project-id",
PrivateKeyID: "private-key-id",
PrivateKey: "private-key",
ClientEmail: "client-email",
ClientID: "client-id",
AuthURI: "auth-uri",
TokenURI: "token-uri",
AuthProviderX509CertURL: "auth-provider-x509-cert-url",
ClientX509CertURL: "client-x509-cert-url",
}
keyData, err := json.Marshal(key)
require.NoError(err)
testCases := map[string]struct {
iamAPI iamAPI
projectsAPI stubProjectsAPI
input ServiceAccountInput
errExpected bool
}{
"successful create": {
iamAPI: stubIAMAPI{serviceAccountKeyData: keyData},
input: ServiceAccountInput{
Roles: []string{"someRole"},
},
},
"successful create with roles": {
iamAPI: stubIAMAPI{serviceAccountKeyData: keyData},
},
"creating account fails": {
iamAPI: stubIAMAPI{createErr: someErr},
errExpected: true,
},
"creating account key fails": {
iamAPI: stubIAMAPI{createKeyErr: someErr},
errExpected: true,
},
"key data missing": {
iamAPI: stubIAMAPI{},
errExpected: true,
},
"key data corrupt": {
iamAPI: stubIAMAPI{serviceAccountKeyData: []byte("invalid key data")},
errExpected: true,
},
"retrieving iam policy bindings fails": {
iamAPI: stubIAMAPI{},
projectsAPI: stubProjectsAPI{getPolicyErr: someErr},
errExpected: true,
},
"setting iam policy bindings fails": {
iamAPI: stubIAMAPI{},
projectsAPI: stubProjectsAPI{setPolicyErr: someErr},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
project: "project",
zone: "zone",
name: "name",
uid: "uid",
iamAPI: tc.iamAPI,
projectsAPI: tc.projectsAPI,
}
serviceAccountKey, err := client.CreateServiceAccount(ctx, tc.input)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(key.ConvertToCloudServiceAccountURI(), serviceAccountKey)
assert.Equal("email", client.serviceAccount)
}
})
}
}
func TestTerminateServiceAccount(t *testing.T) {
testCases := map[string]struct {
iamAPI iamAPI
errExpected bool
}{
"delete works": {
iamAPI: stubIAMAPI{},
},
"delete fails": {
iamAPI: stubIAMAPI{
deleteServiceAccountErr: errors.New("someErr"),
},
errExpected: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
project: "project",
zone: "zone",
name: "name",
uid: "uid",
serviceAccount: "service-account",
iamAPI: tc.iamAPI,
}
err := client.TerminateServiceAccount(ctx)
if tc.errExpected {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestConvertToCloudServiceAccountURI(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
key := ServiceAccountKey{
Type: "type",
ProjectID: "project-id",
PrivateKeyID: "private-key-id",
PrivateKey: "private-key",
ClientEmail: "client-email",
ClientID: "client-id",
AuthURI: "auth-uri",
TokenURI: "token-uri",
AuthProviderX509CertURL: "auth-provider-x509-cert-url",
ClientX509CertURL: "client-x509-cert-url",
}
cloudServiceAccountURI := key.ConvertToCloudServiceAccountURI()
uri, err := url.Parse(cloudServiceAccountURI)
require.NoError(err)
query := uri.Query()
assert.Equal("serviceaccount", uri.Scheme)
assert.Equal("gcp", uri.Host)
assert.Equal(url.Values{
"type": []string{"type"},
"project_id": []string{"project-id"},
"private_key_id": []string{"private-key-id"},
"private_key": []string{"private-key"},
"client_email": []string{"client-email"},
"client_id": []string{"client-id"},
"auth_uri": []string{"auth-uri"},
"token_uri": []string{"token-uri"},
"auth_provider_x509_cert_url": []string{"auth-provider-x509-cert-url"},
"client_x509_cert_url": []string{"client-x509-cert-url"},
}, query)
}

62
cli/gcp/instances.go Normal file
View file

@ -0,0 +1,62 @@
package gcp
// copy of ec2/instances.go
// TODO(katexochen): refactor into mulitcloud package.
import "errors"
// Instance is a gcp instance.
type Instance struct {
PublicIP string
PrivateIP string
}
// Instances is a map of gcp Instances. The ID of an instance is used as key.
type Instances map[string]Instance
// IDs returns the IDs of all instances of the Constellation.
func (i Instances) IDs() []string {
var ids []string
for id := range i {
ids = append(ids, id)
}
return ids
}
// PublicIPs returns the public IPs of all the instances of the Constellation.
func (i Instances) PublicIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PublicIP)
}
return ips
}
// PrivateIPs returns the private IPs of all the instances of the Constellation.
func (i Instances) PrivateIPs() []string {
var ips []string
for _, instance := range i {
ips = append(ips, instance.PrivateIP)
}
return ips
}
// GetOne return anyone instance out of the instances and its ID.
func (i Instances) GetOne() (string, Instance, error) {
for id, instance := range i {
return id, instance, nil
}
return "", Instance{}, errors.New("map is empty")
}
// GetOthers returns all instances but the one with the handed ID.
func (i Instances) GetOthers(id string) Instances {
others := make(Instances)
for key, instance := range i {
if key != id {
others[key] = instance
}
}
return others
}

71
cli/gcp/instances_test.go Normal file
View file

@ -0,0 +1,71 @@
package gcp
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestIDs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIDs := []string{"id-9", "id-10", "id-11", "id-12"}
assert.ElementsMatch(expectedIDs, testState.IDs())
}
func TestPublicIPs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIPs := []string{"192.0.2.1", "192.0.2.3", "192.0.2.5", "192.0.2.7"}
assert.ElementsMatch(expectedIPs, testState.PublicIPs())
}
func TestPrivateIPs(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
expectedIPs := []string{"192.0.2.2", "192.0.2.4", "192.0.2.6", "192.0.2.8"}
assert.ElementsMatch(expectedIPs, testState.PrivateIPs())
}
func TestGetOne(t *testing.T) {
assert := assert.New(t)
testState := testInstances()
id, instance, err := testState.GetOne()
assert.NoError(err)
assert.Contains(testState, id)
assert.Equal(testState[id], instance)
}
func TestGetOthers(t *testing.T) {
assert := assert.New(t)
testCases := testInstances().IDs()
for _, id := range testCases {
others := testInstances().GetOthers(id)
assert.NotContains(others, id)
expectedInstances := testInstances()
delete(expectedInstances, id)
assert.ElementsMatch(others.IDs(), expectedInstances.IDs())
}
}
func testInstances() Instances {
return Instances{
"id-9": {
PublicIP: "192.0.2.1",
PrivateIP: "192.0.2.2",
},
"id-10": {
PublicIP: "192.0.2.3",
PrivateIP: "192.0.2.4",
},
"id-11": {
PublicIP: "192.0.2.5",
PrivateIP: "192.0.2.6",
},
"id-12": {
PublicIP: "192.0.2.7",
PrivateIP: "192.0.2.8",
},
}
}

14
cli/gcp/instancetypes.go Normal file
View file

@ -0,0 +1,14 @@
package gcp
// InstanceTypes are valid GCP instance types.
var InstanceTypes = []string{
"n2d-standard-2",
"n2d-standard-4",
"n2d-standard-8",
"n2d-standard-16",
"n2d-standard-32",
"n2d-standard-48",
"n2d-standard-64",
"n2d-standard-80",
"n2d-standard-96",
}

Some files were not shown because too many files have changed in this diff Show more