Manually manage resource group on Azure

This commit is contained in:
katexochen 2022-08-25 15:12:08 +02:00 committed by Paul Meyer
parent e6ae54a25a
commit f15605cb45
25 changed files with 403 additions and 1162 deletions

View file

@ -1,184 +0,0 @@
package client
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/autorest/to"
"github.com/edgelesssys/constellation/internal/azureshared"
"github.com/google/uuid"
)
const (
adAppCredentialValidity = time.Hour * 24 * 365 * 5 // ~5 years
adReplicationLagCheckInterval = time.Second * 5 // 5 seconds
adReplicationLagCheckMaxRetries = int((15 * time.Minute) / adReplicationLagCheckInterval) // wait for up to 15 minutes for AD replication
ownerRoleDefinitionID = "8e3af657-a8ff-443c-a75c-2fe8c4bcb635" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#owner
)
// CreateServicePrincipal creates an Azure AD app with a service principal, gives it "Owner" role on the resource group and creates new credentials.
func (c *Client) CreateServicePrincipal(ctx context.Context) (string, error) {
createAppRes, err := c.createADApplication(ctx)
if err != nil {
return "", err
}
c.adAppObjectID = createAppRes.ObjectID
servicePrincipalObjectID, err := c.createAppServicePrincipal(ctx, createAppRes.AppID)
if err != nil {
return "", err
}
if err := c.assignResourceGroupRole(ctx, servicePrincipalObjectID, ownerRoleDefinitionID); err != nil {
return "", err
}
clientSecret, err := c.updateAppCredentials(ctx, createAppRes.ObjectID)
if err != nil {
return "", err
}
return azureshared.ApplicationCredentials{
TenantID: c.tenantID,
ClientID: createAppRes.AppID,
ClientSecret: clientSecret,
Location: c.location,
}.ToCloudServiceAccountURI(), nil
}
// TerminateServicePrincipal terminates an Azure AD app together with the service principal.
func (c *Client) TerminateServicePrincipal(ctx context.Context) error {
if c.adAppObjectID == "" {
return nil
}
if _, err := c.applicationsAPI.Delete(ctx, c.adAppObjectID); err != nil {
return err
}
c.adAppObjectID = ""
return nil
}
// createADApplication creates a new azure AD app.
func (c *Client) createADApplication(ctx context.Context) (createADApplicationOutput, error) {
createParameters := graphrbac.ApplicationCreateParameters{
AvailableToOtherTenants: to.BoolPtr(false),
DisplayName: to.StringPtr("constellation-app-" + c.name + "-" + c.uid),
}
app, err := c.applicationsAPI.Create(ctx, createParameters)
if err != nil {
return createADApplicationOutput{}, err
}
if app.AppID == nil || app.ObjectID == nil {
return createADApplicationOutput{}, errors.New("creating AD application did not result in valid app id and object id")
}
return createADApplicationOutput{
AppID: *app.AppID,
ObjectID: *app.ObjectID,
}, nil
}
// createAppServicePrincipal creates a new service principal for an azure AD app.
func (c *Client) createAppServicePrincipal(ctx context.Context, appID string) (string, error) {
createParameters := graphrbac.ServicePrincipalCreateParameters{
AppID: &appID,
AccountEnabled: to.BoolPtr(true),
}
servicePrincipal, err := c.servicePrincipalsAPI.Create(ctx, createParameters)
if err != nil {
return "", err
}
if servicePrincipal.ObjectID == nil {
return "", errors.New("creating AD service principal did not result in a valid object id")
}
return *servicePrincipal.ObjectID, nil
}
// updateAppCredentials sets app client-secret for authentication.
func (c *Client) updateAppCredentials(ctx context.Context, objectID string) (string, error) {
keyID := uuid.New().String()
clientSecret, err := generateClientSecret()
if err != nil {
return "", fmt.Errorf("generating client secret: %w", err)
}
updateParameters := graphrbac.PasswordCredentialsUpdateParameters{
Value: &[]graphrbac.PasswordCredential{
{
StartDate: &date.Time{Time: time.Now()},
EndDate: &date.Time{Time: time.Now().Add(adAppCredentialValidity)},
Value: to.StringPtr(clientSecret),
KeyID: to.StringPtr(keyID),
},
},
}
_, err = c.applicationsAPI.UpdatePasswordCredentials(ctx, objectID, updateParameters)
if err != nil {
return "", err
}
return clientSecret, nil
}
// assignResourceGroupRole assigns the service principal a role at resource group scope.
func (c *Client) assignResourceGroupRole(ctx context.Context, principalID, roleDefinitionID string) error {
resourceGroup, err := c.resourceGroupAPI.Get(ctx, c.resourceGroup, nil)
if err != nil || resourceGroup.ID == nil {
return fmt.Errorf("unable to retrieve resource group id for group %v: %w", c.resourceGroup, err)
}
roleAssignmentID := uuid.New().String()
createParameters := authorization.RoleAssignmentCreateParameters{
Properties: &authorization.RoleAssignmentProperties{
PrincipalID: to.StringPtr(principalID),
RoleDefinitionID: to.StringPtr(fmt.Sprintf("/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s", c.subscriptionID, roleDefinitionID)),
},
}
// due to an azure AD replication lag, retry role assignment if principal does not exist yet
// reference: https://docs.microsoft.com/en-us/azure/role-based-access-control/role-assignments-rest#new-service-principal
// proper fix: use API version 2018-09-01-preview or later
// azure go sdk currently uses version 2015-07-01: https://github.com/Azure/azure-sdk-for-go/blob/v62.0.0/services/authorization/mgmt/2015-07-01/authorization/roleassignments.go#L95
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
for i := 0; i < c.adReplicationLagCheckMaxRetries; i++ {
_, err = c.roleAssignmentsAPI.Create(ctx, *resourceGroup.ID, roleAssignmentID, createParameters)
var detailedErr autorest.DetailedError
var ok bool
if detailedErr, ok = err.(autorest.DetailedError); !ok {
return err
}
var requestErr *azure.RequestError
if requestErr, ok = detailedErr.Original.(*azure.RequestError); !ok || requestErr.ServiceError == nil {
return err
}
if requestErr.ServiceError.Code != "PrincipalNotFound" {
return err
}
time.Sleep(c.adReplicationLagCheckInterval)
}
return err
}
type createADApplicationOutput struct {
AppID string
ObjectID string
}
func generateClientSecret() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
pwLen := 64
pw := make([]byte, 0, pwLen)
for i := 0; i < pwLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
pw = append(pw, letters[n.Int64()])
}
return string(pw), nil
}

View file

@ -1,358 +0,0 @@
package client
import (
"context"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
func TestCreateServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
applicationsAPI applicationsAPI
servicePrincipalsAPI servicePrincipalsAPI
roleAssignmentsAPI roleAssignmentsAPI
resourceGroupAPI resourceGroupAPI
wantErr bool
}{
"successful create": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
},
"failed app create": {
applicationsAPI: stubApplicationsAPI{
createErr: someErr,
},
wantErr: true,
},
"failed service principal create": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{
createErr: someErr,
},
wantErr: true,
},
"failed role assignment": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{someErr},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
"failed update creds": {
applicationsAPI: stubApplicationsAPI{
updateCredentialsErr: someErr,
},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
applicationsAPI: tc.applicationsAPI,
servicePrincipalsAPI: tc.servicePrincipalsAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
adReplicationLagCheckMaxRetries: 2,
}
_, err := client.CreateServicePrincipal(ctx)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestTerminateServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
appObjectID string
applicationsAPI applicationsAPI
wantErr bool
}{
"successful terminate": {
appObjectID: "object-id",
applicationsAPI: stubApplicationsAPI{},
},
"nothing to terminate": {
applicationsAPI: stubApplicationsAPI{},
},
"failed delete": {
appObjectID: "object-id",
applicationsAPI: stubApplicationsAPI{
deleteErr: someErr,
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
adAppObjectID: tc.appObjectID,
applicationsAPI: tc.applicationsAPI,
}
err := client.TerminateServicePrincipal(ctx)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestCreateADApplication(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
applicationsAPI applicationsAPI
wantErr bool
}{
"successful create": {
applicationsAPI: stubApplicationsAPI{},
},
"failed app create": {
applicationsAPI: stubApplicationsAPI{
createErr: someErr,
},
wantErr: true,
},
"app create returns invalid appid": {
applicationsAPI: stubApplicationsAPI{
createApplication: &graphrbac.Application{
ObjectID: proto.String("00000000-0000-0000-0000-000000000001"),
},
},
wantErr: true,
},
"app create returns invalid objectid": {
applicationsAPI: stubApplicationsAPI{
createApplication: &graphrbac.Application{
AppID: proto.String("00000000-0000-0000-0000-000000000000"),
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
applicationsAPI: tc.applicationsAPI,
}
appCredentials, err := client.createADApplication(ctx)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.NotNil(appCredentials)
})
}
}
func TestCreateAppServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
servicePrincipalsAPI servicePrincipalsAPI
wantErr bool
}{
"successful create": {
servicePrincipalsAPI: stubServicePrincipalsAPI{},
},
"failed service principal create": {
servicePrincipalsAPI: stubServicePrincipalsAPI{
createErr: someErr,
},
wantErr: true,
},
"service principal create returns invalid objectid": {
servicePrincipalsAPI: stubServicePrincipalsAPI{
createServicePrincipal: &graphrbac.ServicePrincipal{},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
servicePrincipalsAPI: tc.servicePrincipalsAPI,
}
_, err := client.createAppServicePrincipal(ctx, "app-id")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestAssignOwnerOfResourceGroup(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
roleAssignmentsAPI roleAssignmentsAPI
resourceGroupAPI resourceGroupAPI
wantErr bool
}{
"successful assign": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
},
"failed role assignment": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{someErr},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
"failed resource group get": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getErr: someErr,
},
wantErr: true,
},
"resource group get returns invalid id": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{},
},
wantErr: true,
},
"create returns PrincipalNotFound the first time": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{
autorest.DetailedError{Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
Code: "PrincipalNotFound",
},
}},
nil,
},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
},
"create does not return request error": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{autorest.DetailedError{Original: someErr}},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
"create service error code is unknown": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{
autorest.DetailedError{Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
Code: "some-unknown-error-code",
},
}},
nil,
},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
roleAssignmentsAPI: tc.roleAssignmentsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
adReplicationLagCheckMaxRetries: 2,
}
err := client.assignResourceGroupRole(ctx, "principal-id", "role-definition-id")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}

View file

@ -64,15 +64,13 @@ type networkInterfacesAPI interface {
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error) ) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error)
} }
type resourceGroupAPI interface { type resourceAPI interface {
CreateOrUpdate(ctx context.Context, resourceGroupName string, NewListByResourceGroupPager(resourceGroupName string,
parameters armresources.ResourceGroup, options *armresources.ClientListByResourceGroupOptions,
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) ( ) *runtime.Pager[armresources.ClientListByResourceGroupResponse]
armresources.ResourceGroupsClientCreateOrUpdateResponse, error) BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string,
BeginDelete(ctx context.Context, resourceGroupName string, options *armresources.ClientBeginDeleteByIDOptions,
options *armresources.ResourceGroupsClientBeginDeleteOptions) ( ) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error)
*runtime.Poller[armresources.ResourceGroupsClientDeleteResponse], error)
Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error)
} }
type applicationsAPI interface { type applicationsAPI interface {

View file

@ -4,15 +4,11 @@ import (
"context" "context"
"net/http" "net/http"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2" armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
) )
type stubNetworksAPI struct { type stubNetworksAPI struct {
@ -94,44 +90,6 @@ func (a stubNetworkSecurityGroupsAPI) BeginCreateOrUpdate(ctx context.Context, r
return poller, a.createErr return poller, a.createErr
} }
type stubResourceGroupAPI struct {
terminateErr error
createErr error
getErr error
getResourceGroup armresources.ResourceGroup
pollErr error
}
func (a stubResourceGroupAPI) CreateOrUpdate(ctx context.Context, resourceGroupName string,
parameters armresources.ResourceGroup,
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
armresources.ResourceGroupsClientCreateOrUpdateResponse, error,
) {
return armresources.ResourceGroupsClientCreateOrUpdateResponse{}, a.createErr
}
func (a stubResourceGroupAPI) Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error) {
return armresources.ResourceGroupsClientGetResponse{
ResourceGroup: a.getResourceGroup,
}, a.getErr
}
func (a stubResourceGroupAPI) BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
*runtime.Poller[armresources.ResourceGroupsClientDeleteResponse], error,
) {
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ResourceGroupsClientDeleteResponse]{
Handler: &stubPoller[armresources.ResourceGroupsClientDeleteResponse]{
result: armresources.ResourceGroupsClientDeleteResponse{},
resultErr: a.pollErr,
},
})
if err != nil {
panic(err)
}
return poller, a.terminateErr
}
type stubScaleSetsAPI struct { type stubScaleSetsAPI struct {
createErr error createErr error
stubResponse armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse stubResponse armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse
@ -282,71 +240,6 @@ func (a stubNetworkInterfacesAPI) GetVirtualMachineScaleSetNetworkInterface(ctx
}, nil }, nil
} }
type stubApplicationsAPI struct {
createErr error
deleteErr error
updateCredentialsErr error
createApplication *graphrbac.Application
}
func (a stubApplicationsAPI) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
if a.createErr != nil {
return graphrbac.Application{}, a.createErr
}
if a.createApplication != nil {
return *a.createApplication, nil
}
return graphrbac.Application{
AppID: to.Ptr("00000000-0000-0000-0000-000000000000"),
ObjectID: to.Ptr("00000000-0000-0000-0000-000000000001"),
}, nil
}
func (a stubApplicationsAPI) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
if a.deleteErr != nil {
return autorest.Response{}, a.deleteErr
}
return autorest.Response{}, nil
}
func (a stubApplicationsAPI) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
if a.updateCredentialsErr != nil {
return autorest.Response{}, a.updateCredentialsErr
}
return autorest.Response{}, nil
}
type stubServicePrincipalsAPI struct {
createErr error
createServicePrincipal *graphrbac.ServicePrincipal
}
func (a stubServicePrincipalsAPI) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
if a.createErr != nil {
return graphrbac.ServicePrincipal{}, a.createErr
}
if a.createServicePrincipal != nil {
return *a.createServicePrincipal, nil
}
return graphrbac.ServicePrincipal{
AppID: to.Ptr("00000000-0000-0000-0000-000000000000"),
ObjectID: to.Ptr("00000000-0000-0000-0000-000000000002"),
}, nil
}
type stubRoleAssignmentsAPI struct {
createCounter int
createErrors []error
}
func (a *stubRoleAssignmentsAPI) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
a.createCounter++
if len(a.createErrors) == 0 {
return authorization.RoleAssignment{}, nil
}
return authorization.RoleAssignment{}, a.createErrors[(a.createCounter-1)%len(a.createErrors)]
}
type stubApplicationInsightsAPI struct { type stubApplicationInsightsAPI struct {
err error err error
} }

View file

@ -29,7 +29,7 @@ const (
type Client struct { type Client struct {
networksAPI networksAPI
networkSecurityGroupsAPI networkSecurityGroupsAPI
resourceGroupAPI resourceAPI
scaleSetsAPI scaleSetsAPI
publicIPAddressesAPI publicIPAddressesAPI
networkInterfacesAPI networkInterfacesAPI
@ -39,9 +39,7 @@ type Client struct {
roleAssignmentsAPI roleAssignmentsAPI
applicationInsightsAPI applicationInsightsAPI
pollFrequency time.Duration pollFrequency time.Duration
adReplicationLagCheckInterval time.Duration
adReplicationLagCheckMaxRetries int
workers cloudtypes.Instances workers cloudtypes.Instances
controlPlanes cloudtypes.Instances controlPlanes cloudtypes.Instances
@ -83,10 +81,6 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
resGroupAPI, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
scaleSetAPI, err := armcomputev2.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil) scaleSetAPI, err := armcomputev2.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
if err != nil { if err != nil {
return nil, err return nil, err
@ -107,6 +101,10 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
resourceAPI, err := armresources.NewClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
applicationsAPI := graphrbac.NewApplicationsClient(tenantID) applicationsAPI := graphrbac.NewApplicationsClient(tenantID)
applicationsAPI.Authorizer = graphAuthorizer applicationsAPI.Authorizer = graphAuthorizer
servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID) servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID)
@ -115,42 +113,41 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
roleAssignmentsAPI.Authorizer = managementAuthorizer roleAssignmentsAPI.Authorizer = managementAuthorizer
return &Client{ return &Client{
networksAPI: netAPI, networksAPI: netAPI,
networkSecurityGroupsAPI: netSecGrpAPI, networkSecurityGroupsAPI: netSecGrpAPI,
resourceGroupAPI: resGroupAPI, resourceAPI: resourceAPI,
scaleSetsAPI: scaleSetAPI, scaleSetsAPI: scaleSetAPI,
publicIPAddressesAPI: publicIPAddressesAPI, publicIPAddressesAPI: publicIPAddressesAPI,
networkInterfacesAPI: networkInterfacesAPI, networkInterfacesAPI: networkInterfacesAPI,
loadBalancersAPI: loadBalancersAPI, loadBalancersAPI: loadBalancersAPI,
applicationsAPI: applicationsAPI, applicationsAPI: applicationsAPI,
servicePrincipalsAPI: servicePrincipalsAPI, servicePrincipalsAPI: servicePrincipalsAPI,
roleAssignmentsAPI: roleAssignmentsAPI, roleAssignmentsAPI: roleAssignmentsAPI,
applicationInsightsAPI: applicationInsightsAPI, applicationInsightsAPI: applicationInsightsAPI,
subscriptionID: subscriptionID, subscriptionID: subscriptionID,
tenantID: tenantID, tenantID: tenantID,
workers: cloudtypes.Instances{}, workers: cloudtypes.Instances{},
controlPlanes: cloudtypes.Instances{}, controlPlanes: cloudtypes.Instances{},
pollFrequency: time.Second * 5, pollFrequency: time.Second * 5,
adReplicationLagCheckInterval: adReplicationLagCheckInterval,
adReplicationLagCheckMaxRetries: adReplicationLagCheckMaxRetries,
}, nil }, nil
} }
// NewInitialized creates and initializes client by setting the subscriptionID, location and name // NewInitialized creates and initializes client by setting the subscriptionID, location and name
// of the Constellation. // of the Constellation.
func NewInitialized(subscriptionID, tenantID, name, location string) (*Client, error) { func NewInitialized(subscriptionID, tenantID, name, location, resourceGroup string) (*Client, error) {
client, err := NewFromDefault(subscriptionID, tenantID) client, err := NewFromDefault(subscriptionID, tenantID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = client.init(location, name) err = client.init(location, name, resourceGroup)
return client, err return client, err
} }
// init initializes the client. // init initializes the client.
func (c *Client) init(location, name string) error { func (c *Client) init(location, name, resourceGroup string) error {
c.location = location c.location = location
c.name = name c.name = name
c.resourceGroup = resourceGroup
uid, err := c.generateUID() uid, err := c.generateUID()
if err != nil { if err != nil {
return err return err

View file

@ -84,8 +84,9 @@ func TestInit(t *testing.T) {
require := require.New(t) require := require.New(t)
client := Client{} client := Client{}
require.NoError(client.init("location", "name")) require.NoError(client.init("location", "name", "rGroup"))
assert.Equal("location", client.location) assert.Equal("location", client.location)
assert.Equal("name", client.name) assert.Equal("name", client.name)
assert.Equal("rGroup", client.resourceGroup)
assert.NotEmpty(client.uid) assert.NotEmpty(client.uid)
} }

View file

@ -10,8 +10,6 @@ import (
"time" "time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/cli/internal/azure" "github.com/edgelesssys/constellation/cli/internal/azure"
"github.com/edgelesssys/constellation/cli/internal/azure/internal/poller" "github.com/edgelesssys/constellation/cli/internal/azure/internal/poller"
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes" "github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
@ -213,45 +211,6 @@ type CreateScaleSetInput struct {
ConfidentialVM bool ConfidentialVM bool
} }
// CreateResourceGroup creates a resource group.
func (c *Client) CreateResourceGroup(ctx context.Context) error {
_, err := c.resourceGroupAPI.CreateOrUpdate(ctx, c.name+"-"+c.uid,
armresources.ResourceGroup{
Location: &c.location,
}, nil)
if err != nil {
return err
}
c.resourceGroup = c.name + "-" + c.uid
return nil
}
// TerminateResourceGroup terminates a resource group.
func (c *Client) TerminateResourceGroup(ctx context.Context) error {
if c.resourceGroup == "" {
return nil
}
poller, err := c.resourceGroupAPI.BeginDelete(ctx, c.resourceGroup, nil)
if err != nil {
return err
}
if _, err = poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{
Frequency: c.pollFrequency,
}); err != nil {
return err
}
c.workers = nil
c.controlPlanes = nil
c.resourceGroup = ""
c.subnetID = ""
c.networkSecurityGroup = ""
c.workerScaleSet = ""
c.controlPlaneScaleSet = ""
return nil
}
// scaleSetCreationPollingHandler is a custom poller used to check if a scale set was created successfully. // scaleSetCreationPollingHandler is a custom poller used to check if a scale set was created successfully.
type scaleSetCreationPollingHandler struct { type scaleSetCreationPollingHandler struct {
done bool done bool

View file

@ -7,126 +7,16 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2" armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes" "github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestCreateResourceGroup(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
resourceGroupAPI resourceGroupAPI
wantErr bool
}{
"successful create": {
resourceGroupAPI: stubResourceGroupAPI{},
},
"failed create": {
resourceGroupAPI: stubResourceGroupAPI{createErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroupAPI: tc.resourceGroupAPI,
workers: make(cloudtypes.Instances),
controlPlanes: make(cloudtypes.Instances),
}
if tc.wantErr {
assert.Error(client.CreateResourceGroup(ctx))
} else {
assert.NoError(client.CreateResourceGroup(ctx))
assert.Equal(client.name+"-"+client.uid, client.resourceGroup)
}
})
}
}
func TestTerminateResourceGroup(t *testing.T) {
someErr := errors.New("failed")
clientWithResourceGroup := Client{
resourceGroup: "name",
location: "location",
name: "name",
uid: "uid",
subnetID: "subnet",
workerScaleSet: "node-scale-set",
controlPlaneScaleSet: "controlplane-scale-set",
workers: cloudtypes.Instances{
"0": {
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
},
},
controlPlanes: cloudtypes.Instances{
"0": {
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
},
},
}
testCases := map[string]struct {
resourceGroup string
resourceGroupAPI resourceGroupAPI
client Client
wantErr bool
}{
"successful terminate": {
resourceGroupAPI: stubResourceGroupAPI{},
client: clientWithResourceGroup,
},
"no resource group to terminate": {
resourceGroupAPI: stubResourceGroupAPI{},
client: Client{},
resourceGroup: "",
},
"failed terminate": {
resourceGroupAPI: stubResourceGroupAPI{terminateErr: someErr},
client: clientWithResourceGroup,
wantErr: true,
},
"failed to poll terminate response": {
resourceGroupAPI: stubResourceGroupAPI{pollErr: someErr},
client: clientWithResourceGroup,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
tc.client.resourceGroupAPI = tc.resourceGroupAPI
ctx := context.Background()
if tc.wantErr {
assert.Error(tc.client.TerminateResourceGroup(ctx))
return
}
assert.NoError(tc.client.TerminateResourceGroup(ctx))
assert.Empty(tc.client.resourceGroup)
assert.Empty(tc.client.subnetID)
assert.Empty(tc.client.workers)
assert.Empty(tc.client.controlPlanes)
assert.Empty(tc.client.workerScaleSet)
assert.Empty(tc.client.controlPlaneScaleSet)
})
}
}
func TestCreateInstances(t *testing.T) { func TestCreateInstances(t *testing.T) {
someErr := errors.New("failed") someErr := errors.New("failed")
testCases := map[string]struct { testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI publicIPAddressesAPI publicIPAddressesAPI
networkInterfacesAPI networkInterfacesAPI networkInterfacesAPI networkInterfacesAPI
scaleSetsAPI scaleSetsAPI scaleSetsAPI scaleSetsAPI
resourceGroupAPI resourceGroupAPI
roleAssignmentsAPI roleAssignmentsAPI
createInstancesInput CreateInstancesInput createInstancesInput CreateInstancesInput
wantErr bool wantErr bool
}{ }{
@ -138,8 +28,6 @@ func TestCreateInstances(t *testing.T) {
VirtualMachineScaleSet: armcomputev2.VirtualMachineScaleSet{Identity: &armcomputev2.VirtualMachineScaleSetIdentity{PrincipalID: to.Ptr("principal-id")}}, VirtualMachineScaleSet: armcomputev2.VirtualMachineScaleSet{Identity: &armcomputev2.VirtualMachineScaleSetIdentity{PrincipalID: to.Ptr("principal-id")}},
}, },
}, },
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{ createInstancesInput: CreateInstancesInput{
CountControlPlanes: 3, CountControlPlanes: 3,
CountWorkers: 3, CountWorkers: 3,
@ -153,8 +41,6 @@ func TestCreateInstances(t *testing.T) {
publicIPAddressesAPI: stubPublicIPAddressesAPI{}, publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{}, networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{createErr: someErr}, scaleSetsAPI: stubScaleSetsAPI{createErr: someErr},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{ createInstancesInput: CreateInstancesInput{
CountControlPlanes: 3, CountControlPlanes: 3,
CountWorkers: 3, CountWorkers: 3,
@ -169,8 +55,6 @@ func TestCreateInstances(t *testing.T) {
publicIPAddressesAPI: stubPublicIPAddressesAPI{}, publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{}, networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{getErr: someErr}, scaleSetsAPI: stubScaleSetsAPI{getErr: someErr},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{ createInstancesInput: CreateInstancesInput{
CountControlPlanes: 3, CountControlPlanes: 3,
CountWorkers: 3, CountWorkers: 3,
@ -185,8 +69,6 @@ func TestCreateInstances(t *testing.T) {
publicIPAddressesAPI: stubPublicIPAddressesAPI{}, publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr}, networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr},
scaleSetsAPI: stubScaleSetsAPI{}, scaleSetsAPI: stubScaleSetsAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{ createInstancesInput: CreateInstancesInput{
CountWorkers: 3, CountWorkers: 3,
InstanceType: "type", InstanceType: "type",
@ -211,8 +93,6 @@ func TestCreateInstances(t *testing.T) {
publicIPAddressesAPI: tc.publicIPAddressesAPI, publicIPAddressesAPI: tc.publicIPAddressesAPI,
networkInterfacesAPI: tc.networkInterfacesAPI, networkInterfacesAPI: tc.networkInterfacesAPI,
scaleSetsAPI: tc.scaleSetsAPI, scaleSetsAPI: tc.scaleSetsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
workers: make(cloudtypes.Instances), workers: make(cloudtypes.Instances),
controlPlanes: make(cloudtypes.Instances), controlPlanes: make(cloudtypes.Instances),
loadBalancerPubIP: "lbip", loadBalancerPubIP: "lbip",
@ -232,11 +112,3 @@ func TestCreateInstances(t *testing.T) {
}) })
} }
} }
func newSuccessfulResourceGroupStub() *stubResourceGroupAPI {
return &stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
}
}

View file

@ -0,0 +1,133 @@
package client
import (
"context"
"fmt"
"regexp"
"strings"
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
)
// TerminateResourceGroupResources deletes all resources from the resource group.
func (c *Client) TerminateResourceGroupResources(ctx context.Context) error {
const timeOut = 10 * time.Minute
ctx, cancel := context.WithTimeout(ctx, timeOut)
defer cancel()
pollers := make(chan *runtime.Poller[armresources.ClientDeleteByIDResponse], 20)
delete := make(chan struct{}, 1)
wg := &sync.WaitGroup{}
wg.Add(2)
go func() { // This routine lists resources and starts their deletion, where possible.
defer wg.Done()
defer func() {
close(pollers)
for range delete { // drain channel
}
}()
for {
ids, err := c.getResourceIDList(ctx)
if err != nil {
time.Sleep(3 * time.Second)
continue
}
if len(ids) == 0 {
return
}
for _, id := range ids {
poller, err := c.deleteResourceByID(ctx, id)
if err != nil {
continue
}
pollers <- poller
}
select {
case <-ctx.Done():
return
case _, ok := <-delete:
if !ok { // channel was closed
return
}
}
}
}()
go func() { // This routine polls for for the deletions to complete.
defer wg.Done()
defer close(delete)
for poller := range pollers {
_, err := poller.PollUntilDone(ctx, nil)
if err != nil {
continue
}
select {
case delete <- struct{}{}:
default:
}
}
}()
wg.Wait()
return nil
}
func (c *Client) getResourceIDList(ctx context.Context) ([]string, error) {
var ids []string
pager := c.resourceAPI.NewListByResourceGroupPager(c.resourceGroup, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("getting next page of ListByResourceGroup: %w", err)
}
for _, resource := range page.Value {
if resource.ID == nil {
return nil, fmt.Errorf("resource %v has no ID", resource)
}
ids = append(ids, *resource.ID)
}
}
return ids, nil
}
func (c *Client) deleteResourceByID(ctx context.Context, id string,
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) {
apiVersion := "2020-02-02"
// First try, API version unknown, will fail.
poller, err := c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil)
if isVersionWrongErr(err) {
// bad hack, but easiest way to get the right API version
apiVersion = parseAPIVersionFromErr(err)
poller, err = c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil)
}
return poller, err
}
func isVersionWrongErr(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), "NoRegisteredProviderFound") &&
strings.Contains(err.Error(), "The supported api-versions are")
}
var apiVersionRegex = regexp.MustCompile(` (\d\d\d\d-\d\d-\d\d)'`)
func parseAPIVersionFromErr(err error) string {
if err == nil {
return ""
}
matches := apiVersionRegex.FindStringSubmatch(err.Error())
return matches[1]
}

View file

@ -0,0 +1,139 @@
package client
import (
"context"
"errors"
"fmt"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
func TestTerminateResourceGroupResources(t *testing.T) {
someErr := errors.New("failed")
apiVersionErr := errors.New("NoRegisteredProviderFound, The supported api-versions are: 2015-01-01'")
testCases := map[string]struct {
resourceAPI resourceAPI
}{
"no resources": {
resourceAPI: &fakeResourceAPI{},
},
"some resources": {
resourceAPI: &fakeResourceAPI{
resources: map[string]fakeResource{
"id-0": {beginDeleteByIDErr: apiVersionErr, pollErr: someErr},
"id-1": {beginDeleteByIDErr: apiVersionErr},
"id-2": {},
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
resourceAPI: tc.resourceAPI,
}
ctx := context.Background()
err := client.TerminateResourceGroupResources(ctx)
assert.NoError(err)
})
}
}
type fakeResourceAPI struct {
resources map[string]fakeResource
fetchErr error
}
type fakeResource struct {
beginDeleteByIDErr error
pollErr error
}
func (a fakeResourceAPI) NewListByResourceGroupPager(resourceGroupName string,
options *armresources.ClientListByResourceGroupOptions,
) *runtime.Pager[armresources.ClientListByResourceGroupResponse] {
pager := &stubClientListByResourceGroupResponsePager{
resources: a.resources,
fetchErr: a.fetchErr,
}
return runtime.NewPager(runtime.PagingHandler[armresources.ClientListByResourceGroupResponse]{
More: pager.moreFunc(),
Fetcher: pager.fetcherFunc(),
})
}
func (a fakeResourceAPI) BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string,
options *armresources.ClientBeginDeleteByIDOptions,
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) {
res := a.resources[resourceID]
pollErr := res.pollErr
if pollErr != nil {
res.pollErr = nil
}
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ClientDeleteByIDResponse]{
Handler: &stubPoller[armresources.ClientDeleteByIDResponse]{
result: armresources.ClientDeleteByIDResponse{},
resultErr: pollErr,
},
})
if err != nil {
panic(err)
}
beginDeleteByIDErr := res.beginDeleteByIDErr
if beginDeleteByIDErr != nil {
res.beginDeleteByIDErr = nil
}
if res.beginDeleteByIDErr == nil && res.pollErr == nil {
delete(a.resources, resourceID)
fmt.Printf("fake delete %s\n", resourceID)
} else {
a.resources[resourceID] = res
}
return poller, beginDeleteByIDErr
}
type stubClientListByResourceGroupResponsePager struct {
resources map[string]fakeResource
fetchErr error
more bool
}
func (p *stubClientListByResourceGroupResponsePager) moreFunc() func(
armresources.ClientListByResourceGroupResponse) bool {
return func(armresources.ClientListByResourceGroupResponse) bool {
return p.more
}
}
func (p *stubClientListByResourceGroupResponsePager) fetcherFunc() func(
context.Context, *armresources.ClientListByResourceGroupResponse) (
armresources.ClientListByResourceGroupResponse, error) {
return func(context.Context, *armresources.ClientListByResourceGroupResponse) (
armresources.ClientListByResourceGroupResponse, error,
) {
var resources []*armresources.GenericResourceExpanded
for id := range p.resources {
resources = append(resources, &armresources.GenericResourceExpanded{ID: proto.String(id)})
}
p.more = false
return armresources.ClientListByResourceGroupResponse{
ResourceListResult: armresources.ResourceListResult{
Value: resources,
},
}, p.fetchErr
}
}

View file

@ -26,12 +26,9 @@ type azureclient interface {
GetState() state.ConstellationState GetState() state.ConstellationState
SetState(state.ConstellationState) SetState(state.ConstellationState)
CreateApplicationInsight(ctx context.Context) error CreateApplicationInsight(ctx context.Context) error
CreateResourceGroup(ctx context.Context) error
CreateExternalLoadBalancer(ctx context.Context) error CreateExternalLoadBalancer(ctx context.Context) error
CreateVirtualNetwork(ctx context.Context) error CreateVirtualNetwork(ctx context.Context) error
CreateSecurityGroup(ctx context.Context, input azurecl.NetworkSecurityGroupInput) error CreateSecurityGroup(ctx context.Context, input azurecl.NetworkSecurityGroupInput) error
CreateInstances(ctx context.Context, input azurecl.CreateInstancesInput) error CreateInstances(ctx context.Context, input azurecl.CreateInstancesInput) error
CreateServicePrincipal(ctx context.Context) (string, error) TerminateResourceGroupResources(ctx context.Context) error
TerminateResourceGroup(ctx context.Context) error
TerminateServicePrincipal(ctx context.Context) error
} }

View file

@ -79,11 +79,6 @@ func (c *fakeAzureClient) CreateApplicationInsight(ctx context.Context) error {
return nil return nil
} }
func (c *fakeAzureClient) CreateResourceGroup(ctx context.Context) error {
c.resourceGroup = "resource-group"
return nil
}
func (c *fakeAzureClient) CreateVirtualNetwork(ctx context.Context) error { func (c *fakeAzureClient) CreateVirtualNetwork(ctx context.Context) error {
c.subnetID = "subnet" c.subnetID = "subnet"
return nil return nil
@ -123,17 +118,8 @@ func (c *fakeAzureClient) CreateServicePrincipal(ctx context.Context) (string, e
}.ToCloudServiceAccountURI(), nil }.ToCloudServiceAccountURI(), nil
} }
func (c *fakeAzureClient) TerminateResourceGroup(ctx context.Context) error { func (c *fakeAzureClient) TerminateResourceGroupResources(ctx context.Context) error {
if c.resourceGroup == "" { // TODO(katexochen)
return nil
}
c.workers = nil
c.controlPlanes = nil
c.resourceGroup = ""
c.subnetID = ""
c.networkSecurityGroup = ""
c.workerScaleSet = ""
c.controlPlaneScaleSet = ""
return nil return nil
} }
@ -146,18 +132,17 @@ func (c *fakeAzureClient) TerminateServicePrincipal(ctx context.Context) error {
} }
type stubAzureClient struct { type stubAzureClient struct {
terminateResourceGroupCalled bool terminateResourceGroupResourcesCalled bool
terminateServicePrincipalCalled bool terminateServicePrincipalCalled bool
createApplicationInsightErr error createApplicationInsightErr error
createResourceGroupErr error createVirtualNetworkErr error
createVirtualNetworkErr error createSecurityGroupErr error
createSecurityGroupErr error createLoadBalancerErr error
createLoadBalancerErr error createInstancesErr error
createInstancesErr error createServicePrincipalErr error
createServicePrincipalErr error terminateResourceGroupResourcesErr error
terminateResourceGroupErr error terminateServicePrincipalErr error
terminateServicePrincipalErr error
} }
func (c *stubAzureClient) GetState() state.ConstellationState { func (c *stubAzureClient) GetState() state.ConstellationState {
@ -175,10 +160,6 @@ func (c *stubAzureClient) CreateApplicationInsight(ctx context.Context) error {
return c.createApplicationInsightErr return c.createApplicationInsightErr
} }
func (c *stubAzureClient) CreateResourceGroup(ctx context.Context) error {
return c.createResourceGroupErr
}
func (c *stubAzureClient) CreateVirtualNetwork(ctx context.Context) error { func (c *stubAzureClient) CreateVirtualNetwork(ctx context.Context) error {
return c.createVirtualNetworkErr return c.createVirtualNetworkErr
} }
@ -198,9 +179,9 @@ func (c *stubAzureClient) CreateServicePrincipal(ctx context.Context) (string, e
}.ToCloudServiceAccountURI(), c.createServicePrincipalErr }.ToCloudServiceAccountURI(), c.createServicePrincipalErr
} }
func (c *stubAzureClient) TerminateResourceGroup(ctx context.Context) error { func (c *stubAzureClient) TerminateResourceGroupResources(ctx context.Context) error {
c.terminateResourceGroupCalled = true c.terminateResourceGroupResourcesCalled = true
return c.terminateResourceGroupErr return c.terminateResourceGroupResourcesErr
} }
func (c *stubAzureClient) TerminateServicePrincipal(ctx context.Context) error { func (c *stubAzureClient) TerminateServicePrincipal(ctx context.Context) error {

View file

@ -18,7 +18,7 @@ import (
type Creator struct { type Creator struct {
out io.Writer out io.Writer
newGCPClient func(ctx context.Context, project, zone, region, name string) (gcpclient, error) newGCPClient func(ctx context.Context, project, zone, region, name string) (gcpclient, error)
newAzureClient func(subscriptionID, tenantID, name, location string) (azureclient, error) newAzureClient func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error)
} }
// NewCreator creates a new creator. // NewCreator creates a new creator.
@ -28,8 +28,8 @@ func NewCreator(out io.Writer) *Creator {
newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) { newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) {
return gcpcl.NewInitialized(ctx, project, zone, region, name) return gcpcl.NewInitialized(ctx, project, zone, region, name)
}, },
newAzureClient: func(subscriptionID, tenantID, name, location string) (azureclient, error) { newAzureClient: func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error) {
return azurecl.NewInitialized(subscriptionID, tenantID, name, location) return azurecl.NewInitialized(subscriptionID, tenantID, name, location, resourceGroup)
}, },
} }
} }
@ -57,6 +57,7 @@ func (c *Creator) Create(ctx context.Context, provider cloudprovider.Provider, c
config.Provider.Azure.TenantID, config.Provider.Azure.TenantID,
name, name,
config.Provider.Azure.Location, config.Provider.Azure.Location,
config.Provider.Azure.ResourceGroup,
) )
if err != nil { if err != nil {
return state.ConstellationState{}, err return state.ConstellationState{}, err
@ -144,9 +145,6 @@ func (c *Creator) createAzure(ctx context.Context, cl azureclient, config *confi
) (stat state.ConstellationState, retErr error) { ) (stat state.ConstellationState, retErr error) {
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerAzure{client: cl}) defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerAzure{client: cl})
if err := cl.CreateResourceGroup(ctx); err != nil {
return state.ConstellationState{}, err
}
if err := cl.CreateApplicationInsight(ctx); err != nil { if err := cl.CreateApplicationInsight(ctx); err != nil {
return state.ConstellationState{}, err return state.ConstellationState{}, err
} }

View file

@ -51,7 +51,6 @@ func TestCreator(t *testing.T) {
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
}, },
AzureResourceGroup: "resource-group",
AzureSubnet: "subnet", AzureSubnet: "subnet",
AzureNetworkSecurityGroup: "network-security-group", AzureNetworkSecurityGroup: "network-security-group",
AzureWorkerScaleSet: "workers-scale-set", AzureWorkerScaleSet: "workers-scale-set",
@ -123,13 +122,6 @@ func TestCreator(t *testing.T) {
config: config.Default(), config: config.Default(),
wantErr: true, wantErr: true,
}, },
"azure CreateResourceGroup error": {
azureclient: &stubAzureClient{createResourceGroupErr: someErr},
provider: cloudprovider.Azure,
config: config.Default(),
wantErr: true,
wantRollback: true,
},
"azure CreateVirtualNetwork error": { "azure CreateVirtualNetwork error": {
azureclient: &stubAzureClient{createVirtualNetworkErr: someErr}, azureclient: &stubAzureClient{createVirtualNetworkErr: someErr},
provider: cloudprovider.Azure, provider: cloudprovider.Azure,
@ -167,7 +159,7 @@ func TestCreator(t *testing.T) {
newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) { newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) {
return tc.gcpclient, tc.newGCPClientErr return tc.gcpclient, tc.newGCPClientErr
}, },
newAzureClient: func(subscriptionID, tenantID, name, location string) (azureclient, error) { newAzureClient: func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error) {
return tc.azureclient, tc.newAzureClientErr return tc.azureclient, tc.newAzureClientErr
}, },
} }
@ -186,7 +178,7 @@ func TestCreator(t *testing.T) {
assert.True(cl.closeCalled) assert.True(cl.closeCalled)
case cloudprovider.Azure: case cloudprovider.Azure:
cl := tc.azureclient.(*stubAzureClient) cl := tc.azureclient.(*stubAzureClient)
assert.True(cl.terminateResourceGroupCalled) assert.True(cl.terminateResourceGroupResourcesCalled)
} }
} }
} else { } else {

View file

@ -46,5 +46,5 @@ type rollbackerAzure struct {
} }
func (r *rollbackerAzure) rollback(ctx context.Context) error { func (r *rollbackerAzure) rollback(ctx context.Context) error {
return r.client.TerminateResourceGroup(ctx) return r.client.TerminateResourceGroupResources(ctx)
} }

View file

@ -1,68 +0,0 @@
package cloudcmd
import (
"context"
"fmt"
azurecl "github.com/edgelesssys/constellation/cli/internal/azure/client"
gcpcl "github.com/edgelesssys/constellation/cli/internal/gcp/client"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
)
// ServiceAccountCreator creates service accounts.
type ServiceAccountCreator struct {
newGCPClient func(ctx context.Context) (gcpclient, error)
newAzureClient func(subscriptionID, tenantID string) (azureclient, error)
}
func NewServiceAccountCreator() *ServiceAccountCreator {
return &ServiceAccountCreator{
newGCPClient: func(ctx context.Context) (gcpclient, error) {
return gcpcl.NewFromDefault(ctx)
},
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
return azurecl.NewFromDefault(subscriptionID, tenantID)
},
}
}
// Create creates a new cloud provider service account with access to the created resources.
func (c *ServiceAccountCreator) Create(ctx context.Context, stat state.ConstellationState, config *config.Config,
) (string, state.ConstellationState, error) {
provider := cloudprovider.FromString(stat.CloudProvider)
switch provider {
case cloudprovider.GCP:
return "", state.ConstellationState{}, fmt.Errorf("creating service account not supported for GCP")
case cloudprovider.Azure:
cl, err := c.newAzureClient(stat.AzureSubscription, stat.AzureTenant)
if err != nil {
return "", state.ConstellationState{}, err
}
serviceAccount, stat, err := c.createServiceAccountAzure(ctx, cl, stat, config)
if err != nil {
return "", state.ConstellationState{}, err
}
return serviceAccount, stat, err
case cloudprovider.QEMU:
return "unsupported://qemu", stat, nil
default:
return "", state.ConstellationState{}, fmt.Errorf("unsupported provider: %s", provider)
}
}
func (c *ServiceAccountCreator) createServiceAccountAzure(ctx context.Context, cl azureclient,
stat state.ConstellationState, _ *config.Config,
) (string, state.ConstellationState, error) {
cl.SetState(stat)
serviceAccount, err := cl.CreateServicePrincipal(ctx)
if err != nil {
return "", state.ConstellationState{}, fmt.Errorf("creating service account: %w", err)
}
return serviceAccount, cl.GetState(), nil
}

View file

@ -1,89 +0,0 @@
package cloudcmd
import (
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
"github.com/stretchr/testify/assert"
)
func TestServiceAccountCreator(t *testing.T) {
someAzureState := func() state.ConstellationState {
return state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
}
}
someErr := errors.New("failed")
testCases := map[string]struct {
newGCPClient func(ctx context.Context) (gcpclient, error)
newAzureClient func(subscriptionID, tenantID string) (azureclient, error)
state state.ConstellationState
config *config.Config
wantErr bool
wantStateMutator func(*state.ConstellationState)
}{
"azure": {
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
return &fakeAzureClient{}, nil
},
state: someAzureState(),
config: config.Default(),
wantStateMutator: func(stat *state.ConstellationState) {
stat.AzureADAppObjectID = "00000000-0000-0000-0000-000000000001"
},
},
"azure newAzureClient error": {
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
return nil, someErr
},
state: someAzureState(),
config: config.Default(),
wantErr: true,
},
"azure client createServiceAccount error": {
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
return &stubAzureClient{createServicePrincipalErr: someErr}, nil
},
state: someAzureState(),
config: config.Default(),
wantErr: true,
},
"qemu": {
state: state.ConstellationState{CloudProvider: "qemu"},
wantStateMutator: func(cs *state.ConstellationState) {},
config: config.Default(),
},
"unknown cloud provider": {
state: state.ConstellationState{},
config: config.Default(),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
creator := &ServiceAccountCreator{
newGCPClient: tc.newGCPClient,
newAzureClient: tc.newAzureClient,
}
serviceAccount, state, err := creator.Create(context.Background(), tc.state, tc.config)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotEmpty(serviceAccount)
tc.wantStateMutator(&tc.state)
assert.Equal(tc.state, state)
}
})
}
}

View file

@ -72,8 +72,5 @@ func (t *Terminator) terminateGCP(ctx context.Context, cl gcpclient, state state
func (t *Terminator) terminateAzure(ctx context.Context, cl azureclient, state state.ConstellationState) error { func (t *Terminator) terminateAzure(ctx context.Context, cl azureclient, state state.ConstellationState) error {
cl.SetState(state) cl.SetState(state)
if err := cl.TerminateServicePrincipal(ctx); err != nil { return cl.TerminateResourceGroupResources(ctx)
return err
}
return cl.TerminateResourceGroup(ctx)
} }

View file

@ -41,7 +41,6 @@ func TestTerminator(t *testing.T) {
AzureControlPlaneInstances: cloudtypes.Instances{ AzureControlPlaneInstances: cloudtypes.Instances{
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
}, },
AzureResourceGroup: "group",
AzureADAppObjectID: "00000000-0000-0000-0000-000000000001", AzureADAppObjectID: "00000000-0000-0000-0000-000000000001",
} }
} }
@ -88,13 +87,8 @@ func TestTerminator(t *testing.T) {
state: someAzureState(), state: someAzureState(),
wantErr: true, wantErr: true,
}, },
"azure terminateServicePrincipal error": { "azure terminateResourceGroupResources error": {
azureclient: &stubAzureClient{terminateServicePrincipalErr: someErr}, azureclient: &stubAzureClient{terminateResourceGroupResourcesErr: someErr},
state: someAzureState(),
wantErr: true,
},
"azure terminateResourceGroup error": {
azureclient: &stubAzureClient{terminateResourceGroupErr: someErr},
state: someAzureState(), state: someAzureState(),
wantErr: true, wantErr: true,
}, },
@ -132,8 +126,7 @@ func TestTerminator(t *testing.T) {
assert.True(cl.closeCalled) assert.True(cl.closeCalled)
case cloudprovider.Azure: case cloudprovider.Azure:
cl := tc.azureclient.(*stubAzureClient) cl := tc.azureclient.(*stubAzureClient)
assert.True(cl.terminateResourceGroupCalled) assert.True(cl.terminateResourceGroupResourcesCalled)
assert.True(cl.terminateServicePrincipalCalled)
} }
} }
}) })

View file

@ -21,8 +21,3 @@ type cloudCreator interface {
type cloudTerminator interface { type cloudTerminator interface {
Terminate(context.Context, state.ConstellationState) error Terminate(context.Context, state.ConstellationState) error
} }
type serviceAccountCreator interface {
Create(ctx context.Context, stat state.ConstellationState, config *config.Config,
) (string, state.ConstellationState, error)
}

View file

@ -47,12 +47,3 @@ func (c *stubCloudTerminator) Terminate(context.Context, state.ConstellationStat
func (c *stubCloudTerminator) Called() bool { func (c *stubCloudTerminator) Called() bool {
return c.called return c.called
} }
type stubServiceAccountCreator struct {
cloudServiceAccountURI string
createErr error
}
func (c *stubServiceAccountCreator) Create(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
return c.cloudServiceAccountURI, stat, c.createErr
}

View file

@ -55,17 +55,16 @@ func NewInitCmd() *cobra.Command {
// runInitialize runs the initialize command. // runInitialize runs the initialize command.
func runInitialize(cmd *cobra.Command, args []string) error { func runInitialize(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs()) fileHandler := file.NewHandler(afero.NewOsFs())
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
newDialer := func(validator *cloudcmd.Validator) *dialer.Dialer { newDialer := func(validator *cloudcmd.Validator) *dialer.Dialer {
return dialer.New(nil, validator.V(cmd), &net.Dialer{}) return dialer.New(nil, validator.V(cmd), &net.Dialer{})
} }
helmLoader := &helm.ChartLoader{} helmLoader := &helm.ChartLoader{}
return initialize(cmd, newDialer, serviceAccountCreator, fileHandler, helmLoader, license.NewClient()) return initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient())
} }
// initialize initializes a Constellation. // initialize initializes a Constellation.
func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer, func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer,
serviceAccCreator serviceAccountCreator, fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker, fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker,
) error { ) error {
flags, err := evalFlagArgs(cmd, fileHandler) flags, err := evalFlagArgs(cmd, fileHandler)
if err != nil { if err != nil {
@ -105,22 +104,9 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return err return err
} }
var serviceAccURI string serviceAccURI, err := getMarschaledServiceAccountURI(provider, config, fileHandler)
// Temporary legacy flow for Azure. if err != nil {
if provider == cloudprovider.Azure { return err
cmd.Println("Creating service account ...")
serviceAccURI, stat, err = serviceAccCreator.Create(cmd.Context(), stat, config)
if err != nil {
return err
}
if err := fileHandler.WriteJSON(constants.StateFilename, stat, file.OptOverwrite); err != nil {
return err
}
} else {
serviceAccURI, err = getMarschaledServiceAccountURI(provider, config, fileHandler)
if err != nil {
return err
}
} }
workers, err := getScalingGroupsFromState(stat, config) workers, err := getScalingGroupsFromState(stat, config)

View file

@ -69,49 +69,54 @@ func TestInitialize(t *testing.T) {
someErr := errors.New("failed") someErr := errors.New("failed")
testCases := map[string]struct { testCases := map[string]struct {
state *state.ConstellationState state *state.ConstellationState
existingIDFile *clusterIDsFile idFile *clusterIDsFile
serviceAccCreator serviceAccountCreator configMutator func(*config.Config)
configMutator func(*config.Config) serviceAccKey *gcpshared.ServiceAccountKey
serviceAccKey *gcpshared.ServiceAccountKey helmLoader stubHelmLoader
helmLoader stubHelmLoader initServerAPI *stubInitServer
initServerAPI *stubInitServer endpointFlag string
endpointFlag string setAutoscaleFlag bool
setAutoscaleFlag bool wantErr bool
wantErr bool
}{ }{
"initialize some gcp instances": { "initialize some gcp instances": {
state: testGcpState, state: testGcpState,
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, idFile: &clusterIDsFile{IP: "192.0.2.1"},
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey, serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{initResp: testInitResp}, initServerAPI: &stubInitServer{initResp: testInitResp},
}, },
"initialize some azure instances": { "initialize some azure instances": {
state: testAzureState, state: testAzureState,
serviceAccCreator: &stubServiceAccountCreator{}, idFile: &clusterIDsFile{IP: "192.0.2.1"},
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, configMutator: func(c *config.Config) {
initServerAPI: &stubInitServer{initResp: testInitResp}, c.Provider.Azure.ResourceGroup = "resourceGroup"
c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity"
},
initServerAPI: &stubInitServer{initResp: testInitResp},
}, },
"initialize some qemu instances": { "initialize some qemu instances": {
state: testQemuState, state: testQemuState,
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, idFile: &clusterIDsFile{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{initResp: testInitResp}, initServerAPI: &stubInitServer{initResp: testInitResp},
}, },
"initialize gcp with autoscaling": { "initialize gcp with autoscaling": {
state: testGcpState, state: testGcpState,
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, idFile: &clusterIDsFile{IP: "192.0.2.1"},
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey, serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{initResp: testInitResp}, initServerAPI: &stubInitServer{initResp: testInitResp},
setAutoscaleFlag: true, setAutoscaleFlag: true,
}, },
"initialize azure with autoscaling": { "initialize azure with autoscaling": {
state: testAzureState, state: testAzureState,
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, idFile: &clusterIDsFile{IP: "192.0.2.1"},
serviceAccCreator: &stubServiceAccountCreator{}, configMutator: func(c *config.Config) {
initServerAPI: &stubInitServer{initResp: testInitResp}, c.Provider.Azure.ResourceGroup = "resourceGroup"
setAutoscaleFlag: true, c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity"
},
initServerAPI: &stubInitServer{initResp: testInitResp},
setAutoscaleFlag: true,
}, },
"initialize with endpoint flag": { "initialize with endpoint flag": {
state: testGcpState, state: testGcpState,
@ -121,27 +126,30 @@ func TestInitialize(t *testing.T) {
endpointFlag: "192.0.2.1", endpointFlag: "192.0.2.1",
}, },
"empty state": { "empty state": {
state: &state.ConstellationState{}, state: &state.ConstellationState{},
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, idFile: &clusterIDsFile{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{}, initServerAPI: &stubInitServer{},
wantErr: true, wantErr: true,
}, },
"neither endpoint flag nor id file": { "neither endpoint flag nor id file": {
state: &state.ConstellationState{}, state: &state.ConstellationState{},
wantErr: true, wantErr: true,
}, },
"init call fails": { "init call fails": {
state: testGcpState, state: testGcpState,
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, idFile: &clusterIDsFile{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{initErr: someErr}, initServerAPI: &stubInitServer{initErr: someErr},
wantErr: true, wantErr: true,
}, },
"fail to create service account": { "fail to create service account": {
state: testAzureState, state: testAzureState,
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, idFile: &clusterIDsFile{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{}, configMutator: func(c *config.Config) {
serviceAccCreator: &stubServiceAccountCreator{createErr: someErr}, c.Provider.Azure.ResourceGroup = "resourceGroup"
wantErr: true, c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity"
},
initServerAPI: &stubInitServer{},
wantErr: true,
}, },
"fail to load helm charts": { "fail to load helm charts": {
state: testGcpState, state: testGcpState,
@ -194,8 +202,8 @@ func TestInitialize(t *testing.T) {
if tc.state != nil { if tc.state != nil {
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.state, file.OptNone)) require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.state, file.OptNone))
} }
if tc.existingIDFile != nil { if tc.idFile != nil {
require.NoError(fileHandler.WriteJSON(constants.ClusterIDsFileName, tc.existingIDFile, file.OptNone)) require.NoError(fileHandler.WriteJSON(constants.ClusterIDsFileName, tc.idFile, file.OptNone))
} }
if tc.serviceAccKey != nil { if tc.serviceAccKey != nil {
require.NoError(fileHandler.WriteJSON(serviceAccPath, tc.serviceAccKey, file.OptNone)) require.NoError(fileHandler.WriteJSON(serviceAccPath, tc.serviceAccKey, file.OptNone))
@ -206,7 +214,7 @@ func TestInitialize(t *testing.T) {
defer cancel() defer cancel()
cmd.SetContext(ctx) cmd.SetContext(ctx)
err := initialize(cmd, newDialer, tc.serviceAccCreator, fileHandler, &tc.helmLoader, &stubLicenseClient{}) err := initialize(cmd, newDialer, fileHandler, &tc.helmLoader, &stubLicenseClient{})
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
@ -477,7 +485,7 @@ func TestAttestation(t *testing.T) {
defer cancel() defer cancel()
cmd.SetContext(ctx) cmd.SetContext(ctx)
err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler, &stubHelmLoader{}, &stubLicenseClient{}) err := initialize(cmd, newDialer, fileHandler, &stubHelmLoader{}, &stubLicenseClient{})
assert.Error(err) assert.Error(err)
// make sure the error is actually a TLS handshake error // make sure the error is actually a TLS handshake error
assert.Contains(err.Error(), "transport: authentication handshake failed") assert.Contains(err.Error(), "transport: authentication handshake failed")
@ -548,6 +556,7 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
conf.Provider.Azure.Location = "test-location" conf.Provider.Azure.Location = "test-location"
conf.Provider.Azure.UserAssignedIdentity = "test-identity" conf.Provider.Azure.UserAssignedIdentity = "test-identity"
conf.Provider.Azure.Image = "some/image/location" conf.Provider.Azure.Image = "some/image/location"
conf.Provider.Azure.ResourceGroup = "test-resource-group"
conf.Provider.Azure.Measurements[8] = []byte("00000000000000000000000000000000") conf.Provider.Azure.Measurements[8] = []byte("00000000000000000000000000000000")
conf.Provider.Azure.Measurements[9] = []byte("11111111111111111111111111111111") conf.Provider.Azure.Measurements[9] = []byte("11111111111111111111111111111111")
case cloudprovider.GCP: case cloudprovider.GCP:

View file

@ -158,6 +158,9 @@ type AzureConfig struct {
// Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure // Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure
UserAssignedIdentity string `yaml:"userAssignedIdentity" validate:"required"` UserAssignedIdentity string `yaml:"userAssignedIdentity" validate:"required"`
// description: | // description: |
// Resource group to use.
ResourceGroup string `yaml:"resourceGroup" validate:"required"`
// description: |
// Use VMs with security type Confidential VM. If set to false, Trusted Launch VMs will be used instead. See: https://docs.microsoft.com/en-us/azure/confidential-computing/confidential-vm-overview // Use VMs with security type Confidential VM. If set to false, Trusted Launch VMs will be used instead. See: https://docs.microsoft.com/en-us/azure/confidential-computing/confidential-vm-overview
ConfidentialVM *bool `yaml:"confidentialVM" validate:"required"` ConfidentialVM *bool `yaml:"confidentialVM" validate:"required"`
} }
@ -244,6 +247,7 @@ func Default() *Config {
TenantID: "", TenantID: "",
Location: "", Location: "",
UserAssignedIdentity: "", UserAssignedIdentity: "",
ResourceGroup: "",
Image: DefaultImageAzure, Image: DefaultImageAzure,
StateDiskType: "Premium_LRS", StateDiskType: "Premium_LRS",
Measurements: copyPCRMap(azurePCRs), Measurements: copyPCRMap(azurePCRs),

View file

@ -199,7 +199,7 @@ func init() {
FieldName: "azure", FieldName: "azure",
}, },
} }
AzureConfigDoc.Fields = make([]encoder.Doc, 9) AzureConfigDoc.Fields = make([]encoder.Doc, 10)
AzureConfigDoc.Fields[0].Name = "subscription" AzureConfigDoc.Fields[0].Name = "subscription"
AzureConfigDoc.Fields[0].Type = "string" AzureConfigDoc.Fields[0].Type = "string"
AzureConfigDoc.Fields[0].Note = "" AzureConfigDoc.Fields[0].Note = ""
@ -240,6 +240,11 @@ func init() {
AzureConfigDoc.Fields[7].Note = "" AzureConfigDoc.Fields[7].Note = ""
AzureConfigDoc.Fields[7].Description = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure" AzureConfigDoc.Fields[7].Description = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure"
AzureConfigDoc.Fields[7].Comments[encoder.LineComment] = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure" AzureConfigDoc.Fields[7].Comments[encoder.LineComment] = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure"
AzureConfigDoc.Fields[8].Name = "resourceGroup"
AzureConfigDoc.Fields[8].Type = "string"
AzureConfigDoc.Fields[8].Note = ""
AzureConfigDoc.Fields[8].Description = "Resource group to use."
AzureConfigDoc.Fields[8].Comments[encoder.LineComment] = "Resource group to use."
AzureConfigDoc.Fields[8].Name = "confidentialVM" AzureConfigDoc.Fields[8].Name = "confidentialVM"
AzureConfigDoc.Fields[8].Type = "bool" AzureConfigDoc.Fields[8].Type = "bool"
AzureConfigDoc.Fields[8].Note = "" AzureConfigDoc.Fields[8].Note = ""