Manually manage resource group on Azure

This commit is contained in:
katexochen 2022-08-25 15:12:08 +02:00 committed by Paul Meyer
parent e6ae54a25a
commit f15605cb45
25 changed files with 403 additions and 1162 deletions

View File

@ -1,184 +0,0 @@
package client
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/autorest/to"
"github.com/edgelesssys/constellation/internal/azureshared"
"github.com/google/uuid"
)
const (
adAppCredentialValidity = time.Hour * 24 * 365 * 5 // ~5 years
adReplicationLagCheckInterval = time.Second * 5 // 5 seconds
adReplicationLagCheckMaxRetries = int((15 * time.Minute) / adReplicationLagCheckInterval) // wait for up to 15 minutes for AD replication
ownerRoleDefinitionID = "8e3af657-a8ff-443c-a75c-2fe8c4bcb635" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#owner
)
// CreateServicePrincipal creates an Azure AD app with a service principal, gives it "Owner" role on the resource group and creates new credentials.
func (c *Client) CreateServicePrincipal(ctx context.Context) (string, error) {
createAppRes, err := c.createADApplication(ctx)
if err != nil {
return "", err
}
c.adAppObjectID = createAppRes.ObjectID
servicePrincipalObjectID, err := c.createAppServicePrincipal(ctx, createAppRes.AppID)
if err != nil {
return "", err
}
if err := c.assignResourceGroupRole(ctx, servicePrincipalObjectID, ownerRoleDefinitionID); err != nil {
return "", err
}
clientSecret, err := c.updateAppCredentials(ctx, createAppRes.ObjectID)
if err != nil {
return "", err
}
return azureshared.ApplicationCredentials{
TenantID: c.tenantID,
ClientID: createAppRes.AppID,
ClientSecret: clientSecret,
Location: c.location,
}.ToCloudServiceAccountURI(), nil
}
// TerminateServicePrincipal terminates an Azure AD app together with the service principal.
func (c *Client) TerminateServicePrincipal(ctx context.Context) error {
if c.adAppObjectID == "" {
return nil
}
if _, err := c.applicationsAPI.Delete(ctx, c.adAppObjectID); err != nil {
return err
}
c.adAppObjectID = ""
return nil
}
// createADApplication creates a new azure AD app.
func (c *Client) createADApplication(ctx context.Context) (createADApplicationOutput, error) {
createParameters := graphrbac.ApplicationCreateParameters{
AvailableToOtherTenants: to.BoolPtr(false),
DisplayName: to.StringPtr("constellation-app-" + c.name + "-" + c.uid),
}
app, err := c.applicationsAPI.Create(ctx, createParameters)
if err != nil {
return createADApplicationOutput{}, err
}
if app.AppID == nil || app.ObjectID == nil {
return createADApplicationOutput{}, errors.New("creating AD application did not result in valid app id and object id")
}
return createADApplicationOutput{
AppID: *app.AppID,
ObjectID: *app.ObjectID,
}, nil
}
// createAppServicePrincipal creates a new service principal for an azure AD app.
func (c *Client) createAppServicePrincipal(ctx context.Context, appID string) (string, error) {
createParameters := graphrbac.ServicePrincipalCreateParameters{
AppID: &appID,
AccountEnabled: to.BoolPtr(true),
}
servicePrincipal, err := c.servicePrincipalsAPI.Create(ctx, createParameters)
if err != nil {
return "", err
}
if servicePrincipal.ObjectID == nil {
return "", errors.New("creating AD service principal did not result in a valid object id")
}
return *servicePrincipal.ObjectID, nil
}
// updateAppCredentials sets app client-secret for authentication.
func (c *Client) updateAppCredentials(ctx context.Context, objectID string) (string, error) {
keyID := uuid.New().String()
clientSecret, err := generateClientSecret()
if err != nil {
return "", fmt.Errorf("generating client secret: %w", err)
}
updateParameters := graphrbac.PasswordCredentialsUpdateParameters{
Value: &[]graphrbac.PasswordCredential{
{
StartDate: &date.Time{Time: time.Now()},
EndDate: &date.Time{Time: time.Now().Add(adAppCredentialValidity)},
Value: to.StringPtr(clientSecret),
KeyID: to.StringPtr(keyID),
},
},
}
_, err = c.applicationsAPI.UpdatePasswordCredentials(ctx, objectID, updateParameters)
if err != nil {
return "", err
}
return clientSecret, nil
}
// assignResourceGroupRole assigns the service principal a role at resource group scope.
func (c *Client) assignResourceGroupRole(ctx context.Context, principalID, roleDefinitionID string) error {
resourceGroup, err := c.resourceGroupAPI.Get(ctx, c.resourceGroup, nil)
if err != nil || resourceGroup.ID == nil {
return fmt.Errorf("unable to retrieve resource group id for group %v: %w", c.resourceGroup, err)
}
roleAssignmentID := uuid.New().String()
createParameters := authorization.RoleAssignmentCreateParameters{
Properties: &authorization.RoleAssignmentProperties{
PrincipalID: to.StringPtr(principalID),
RoleDefinitionID: to.StringPtr(fmt.Sprintf("/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s", c.subscriptionID, roleDefinitionID)),
},
}
// due to an azure AD replication lag, retry role assignment if principal does not exist yet
// reference: https://docs.microsoft.com/en-us/azure/role-based-access-control/role-assignments-rest#new-service-principal
// proper fix: use API version 2018-09-01-preview or later
// azure go sdk currently uses version 2015-07-01: https://github.com/Azure/azure-sdk-for-go/blob/v62.0.0/services/authorization/mgmt/2015-07-01/authorization/roleassignments.go#L95
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
for i := 0; i < c.adReplicationLagCheckMaxRetries; i++ {
_, err = c.roleAssignmentsAPI.Create(ctx, *resourceGroup.ID, roleAssignmentID, createParameters)
var detailedErr autorest.DetailedError
var ok bool
if detailedErr, ok = err.(autorest.DetailedError); !ok {
return err
}
var requestErr *azure.RequestError
if requestErr, ok = detailedErr.Original.(*azure.RequestError); !ok || requestErr.ServiceError == nil {
return err
}
if requestErr.ServiceError.Code != "PrincipalNotFound" {
return err
}
time.Sleep(c.adReplicationLagCheckInterval)
}
return err
}
type createADApplicationOutput struct {
AppID string
ObjectID string
}
func generateClientSecret() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
pwLen := 64
pw := make([]byte, 0, pwLen)
for i := 0; i < pwLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
pw = append(pw, letters[n.Int64()])
}
return string(pw), nil
}

View File

@ -1,358 +0,0 @@
package client
import (
"context"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
func TestCreateServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
applicationsAPI applicationsAPI
servicePrincipalsAPI servicePrincipalsAPI
roleAssignmentsAPI roleAssignmentsAPI
resourceGroupAPI resourceGroupAPI
wantErr bool
}{
"successful create": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
},
"failed app create": {
applicationsAPI: stubApplicationsAPI{
createErr: someErr,
},
wantErr: true,
},
"failed service principal create": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{
createErr: someErr,
},
wantErr: true,
},
"failed role assignment": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{someErr},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
"failed update creds": {
applicationsAPI: stubApplicationsAPI{
updateCredentialsErr: someErr,
},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
applicationsAPI: tc.applicationsAPI,
servicePrincipalsAPI: tc.servicePrincipalsAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
adReplicationLagCheckMaxRetries: 2,
}
_, err := client.CreateServicePrincipal(ctx)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestTerminateServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
appObjectID string
applicationsAPI applicationsAPI
wantErr bool
}{
"successful terminate": {
appObjectID: "object-id",
applicationsAPI: stubApplicationsAPI{},
},
"nothing to terminate": {
applicationsAPI: stubApplicationsAPI{},
},
"failed delete": {
appObjectID: "object-id",
applicationsAPI: stubApplicationsAPI{
deleteErr: someErr,
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
adAppObjectID: tc.appObjectID,
applicationsAPI: tc.applicationsAPI,
}
err := client.TerminateServicePrincipal(ctx)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestCreateADApplication(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
applicationsAPI applicationsAPI
wantErr bool
}{
"successful create": {
applicationsAPI: stubApplicationsAPI{},
},
"failed app create": {
applicationsAPI: stubApplicationsAPI{
createErr: someErr,
},
wantErr: true,
},
"app create returns invalid appid": {
applicationsAPI: stubApplicationsAPI{
createApplication: &graphrbac.Application{
ObjectID: proto.String("00000000-0000-0000-0000-000000000001"),
},
},
wantErr: true,
},
"app create returns invalid objectid": {
applicationsAPI: stubApplicationsAPI{
createApplication: &graphrbac.Application{
AppID: proto.String("00000000-0000-0000-0000-000000000000"),
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
applicationsAPI: tc.applicationsAPI,
}
appCredentials, err := client.createADApplication(ctx)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.NotNil(appCredentials)
})
}
}
func TestCreateAppServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
servicePrincipalsAPI servicePrincipalsAPI
wantErr bool
}{
"successful create": {
servicePrincipalsAPI: stubServicePrincipalsAPI{},
},
"failed service principal create": {
servicePrincipalsAPI: stubServicePrincipalsAPI{
createErr: someErr,
},
wantErr: true,
},
"service principal create returns invalid objectid": {
servicePrincipalsAPI: stubServicePrincipalsAPI{
createServicePrincipal: &graphrbac.ServicePrincipal{},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
servicePrincipalsAPI: tc.servicePrincipalsAPI,
}
_, err := client.createAppServicePrincipal(ctx, "app-id")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestAssignOwnerOfResourceGroup(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
roleAssignmentsAPI roleAssignmentsAPI
resourceGroupAPI resourceGroupAPI
wantErr bool
}{
"successful assign": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
},
"failed role assignment": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{someErr},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
"failed resource group get": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getErr: someErr,
},
wantErr: true,
},
"resource group get returns invalid id": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{},
},
wantErr: true,
},
"create returns PrincipalNotFound the first time": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{
autorest.DetailedError{Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
Code: "PrincipalNotFound",
},
}},
nil,
},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
},
"create does not return request error": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{autorest.DetailedError{Original: someErr}},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
"create service error code is unknown": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{
autorest.DetailedError{Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
Code: "some-unknown-error-code",
},
}},
nil,
},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
roleAssignmentsAPI: tc.roleAssignmentsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
adReplicationLagCheckMaxRetries: 2,
}
err := client.assignResourceGroupRole(ctx, "principal-id", "role-definition-id")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}

View File

@ -64,15 +64,13 @@ type networkInterfacesAPI interface {
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error)
}
type resourceGroupAPI interface {
CreateOrUpdate(ctx context.Context, resourceGroupName string,
parameters armresources.ResourceGroup,
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
armresources.ResourceGroupsClientCreateOrUpdateResponse, error)
BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
*runtime.Poller[armresources.ResourceGroupsClientDeleteResponse], error)
Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error)
type resourceAPI interface {
NewListByResourceGroupPager(resourceGroupName string,
options *armresources.ClientListByResourceGroupOptions,
) *runtime.Pager[armresources.ClientListByResourceGroupResponse]
BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string,
options *armresources.ClientBeginDeleteByIDOptions,
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error)
}
type applicationsAPI interface {

View File

@ -4,15 +4,11 @@ import (
"context"
"net/http"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
)
type stubNetworksAPI struct {
@ -94,44 +90,6 @@ func (a stubNetworkSecurityGroupsAPI) BeginCreateOrUpdate(ctx context.Context, r
return poller, a.createErr
}
type stubResourceGroupAPI struct {
terminateErr error
createErr error
getErr error
getResourceGroup armresources.ResourceGroup
pollErr error
}
func (a stubResourceGroupAPI) CreateOrUpdate(ctx context.Context, resourceGroupName string,
parameters armresources.ResourceGroup,
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
armresources.ResourceGroupsClientCreateOrUpdateResponse, error,
) {
return armresources.ResourceGroupsClientCreateOrUpdateResponse{}, a.createErr
}
func (a stubResourceGroupAPI) Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error) {
return armresources.ResourceGroupsClientGetResponse{
ResourceGroup: a.getResourceGroup,
}, a.getErr
}
func (a stubResourceGroupAPI) BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
*runtime.Poller[armresources.ResourceGroupsClientDeleteResponse], error,
) {
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ResourceGroupsClientDeleteResponse]{
Handler: &stubPoller[armresources.ResourceGroupsClientDeleteResponse]{
result: armresources.ResourceGroupsClientDeleteResponse{},
resultErr: a.pollErr,
},
})
if err != nil {
panic(err)
}
return poller, a.terminateErr
}
type stubScaleSetsAPI struct {
createErr error
stubResponse armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse
@ -282,71 +240,6 @@ func (a stubNetworkInterfacesAPI) GetVirtualMachineScaleSetNetworkInterface(ctx
}, nil
}
type stubApplicationsAPI struct {
createErr error
deleteErr error
updateCredentialsErr error
createApplication *graphrbac.Application
}
func (a stubApplicationsAPI) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
if a.createErr != nil {
return graphrbac.Application{}, a.createErr
}
if a.createApplication != nil {
return *a.createApplication, nil
}
return graphrbac.Application{
AppID: to.Ptr("00000000-0000-0000-0000-000000000000"),
ObjectID: to.Ptr("00000000-0000-0000-0000-000000000001"),
}, nil
}
func (a stubApplicationsAPI) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
if a.deleteErr != nil {
return autorest.Response{}, a.deleteErr
}
return autorest.Response{}, nil
}
func (a stubApplicationsAPI) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
if a.updateCredentialsErr != nil {
return autorest.Response{}, a.updateCredentialsErr
}
return autorest.Response{}, nil
}
type stubServicePrincipalsAPI struct {
createErr error
createServicePrincipal *graphrbac.ServicePrincipal
}
func (a stubServicePrincipalsAPI) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
if a.createErr != nil {
return graphrbac.ServicePrincipal{}, a.createErr
}
if a.createServicePrincipal != nil {
return *a.createServicePrincipal, nil
}
return graphrbac.ServicePrincipal{
AppID: to.Ptr("00000000-0000-0000-0000-000000000000"),
ObjectID: to.Ptr("00000000-0000-0000-0000-000000000002"),
}, nil
}
type stubRoleAssignmentsAPI struct {
createCounter int
createErrors []error
}
func (a *stubRoleAssignmentsAPI) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
a.createCounter++
if len(a.createErrors) == 0 {
return authorization.RoleAssignment{}, nil
}
return authorization.RoleAssignment{}, a.createErrors[(a.createCounter-1)%len(a.createErrors)]
}
type stubApplicationInsightsAPI struct {
err error
}

View File

@ -29,7 +29,7 @@ const (
type Client struct {
networksAPI
networkSecurityGroupsAPI
resourceGroupAPI
resourceAPI
scaleSetsAPI
publicIPAddressesAPI
networkInterfacesAPI
@ -39,9 +39,7 @@ type Client struct {
roleAssignmentsAPI
applicationInsightsAPI
pollFrequency time.Duration
adReplicationLagCheckInterval time.Duration
adReplicationLagCheckMaxRetries int
pollFrequency time.Duration
workers cloudtypes.Instances
controlPlanes cloudtypes.Instances
@ -83,10 +81,6 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
if err != nil {
return nil, err
}
resGroupAPI, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
scaleSetAPI, err := armcomputev2.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
@ -107,6 +101,10 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
if err != nil {
return nil, err
}
resourceAPI, err := armresources.NewClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
applicationsAPI := graphrbac.NewApplicationsClient(tenantID)
applicationsAPI.Authorizer = graphAuthorizer
servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID)
@ -115,42 +113,41 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
roleAssignmentsAPI.Authorizer = managementAuthorizer
return &Client{
networksAPI: netAPI,
networkSecurityGroupsAPI: netSecGrpAPI,
resourceGroupAPI: resGroupAPI,
scaleSetsAPI: scaleSetAPI,
publicIPAddressesAPI: publicIPAddressesAPI,
networkInterfacesAPI: networkInterfacesAPI,
loadBalancersAPI: loadBalancersAPI,
applicationsAPI: applicationsAPI,
servicePrincipalsAPI: servicePrincipalsAPI,
roleAssignmentsAPI: roleAssignmentsAPI,
applicationInsightsAPI: applicationInsightsAPI,
subscriptionID: subscriptionID,
tenantID: tenantID,
workers: cloudtypes.Instances{},
controlPlanes: cloudtypes.Instances{},
pollFrequency: time.Second * 5,
adReplicationLagCheckInterval: adReplicationLagCheckInterval,
adReplicationLagCheckMaxRetries: adReplicationLagCheckMaxRetries,
networksAPI: netAPI,
networkSecurityGroupsAPI: netSecGrpAPI,
resourceAPI: resourceAPI,
scaleSetsAPI: scaleSetAPI,
publicIPAddressesAPI: publicIPAddressesAPI,
networkInterfacesAPI: networkInterfacesAPI,
loadBalancersAPI: loadBalancersAPI,
applicationsAPI: applicationsAPI,
servicePrincipalsAPI: servicePrincipalsAPI,
roleAssignmentsAPI: roleAssignmentsAPI,
applicationInsightsAPI: applicationInsightsAPI,
subscriptionID: subscriptionID,
tenantID: tenantID,
workers: cloudtypes.Instances{},
controlPlanes: cloudtypes.Instances{},
pollFrequency: time.Second * 5,
}, nil
}
// NewInitialized creates and initializes client by setting the subscriptionID, location and name
// of the Constellation.
func NewInitialized(subscriptionID, tenantID, name, location string) (*Client, error) {
func NewInitialized(subscriptionID, tenantID, name, location, resourceGroup string) (*Client, error) {
client, err := NewFromDefault(subscriptionID, tenantID)
if err != nil {
return nil, err
}
err = client.init(location, name)
err = client.init(location, name, resourceGroup)
return client, err
}
// init initializes the client.
func (c *Client) init(location, name string) error {
func (c *Client) init(location, name, resourceGroup string) error {
c.location = location
c.name = name
c.resourceGroup = resourceGroup
uid, err := c.generateUID()
if err != nil {
return err

View File

@ -84,8 +84,9 @@ func TestInit(t *testing.T) {
require := require.New(t)
client := Client{}
require.NoError(client.init("location", "name"))
require.NoError(client.init("location", "name", "rGroup"))
assert.Equal("location", client.location)
assert.Equal("name", client.name)
assert.Equal("rGroup", client.resourceGroup)
assert.NotEmpty(client.uid)
}

View File

@ -10,8 +10,6 @@ import (
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/cli/internal/azure"
"github.com/edgelesssys/constellation/cli/internal/azure/internal/poller"
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
@ -213,45 +211,6 @@ type CreateScaleSetInput struct {
ConfidentialVM bool
}
// CreateResourceGroup creates a resource group.
func (c *Client) CreateResourceGroup(ctx context.Context) error {
_, err := c.resourceGroupAPI.CreateOrUpdate(ctx, c.name+"-"+c.uid,
armresources.ResourceGroup{
Location: &c.location,
}, nil)
if err != nil {
return err
}
c.resourceGroup = c.name + "-" + c.uid
return nil
}
// TerminateResourceGroup terminates a resource group.
func (c *Client) TerminateResourceGroup(ctx context.Context) error {
if c.resourceGroup == "" {
return nil
}
poller, err := c.resourceGroupAPI.BeginDelete(ctx, c.resourceGroup, nil)
if err != nil {
return err
}
if _, err = poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{
Frequency: c.pollFrequency,
}); err != nil {
return err
}
c.workers = nil
c.controlPlanes = nil
c.resourceGroup = ""
c.subnetID = ""
c.networkSecurityGroup = ""
c.workerScaleSet = ""
c.controlPlaneScaleSet = ""
return nil
}
// scaleSetCreationPollingHandler is a custom poller used to check if a scale set was created successfully.
type scaleSetCreationPollingHandler struct {
done bool

View File

@ -7,126 +7,16 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
"github.com/stretchr/testify/assert"
)
func TestCreateResourceGroup(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
resourceGroupAPI resourceGroupAPI
wantErr bool
}{
"successful create": {
resourceGroupAPI: stubResourceGroupAPI{},
},
"failed create": {
resourceGroupAPI: stubResourceGroupAPI{createErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroupAPI: tc.resourceGroupAPI,
workers: make(cloudtypes.Instances),
controlPlanes: make(cloudtypes.Instances),
}
if tc.wantErr {
assert.Error(client.CreateResourceGroup(ctx))
} else {
assert.NoError(client.CreateResourceGroup(ctx))
assert.Equal(client.name+"-"+client.uid, client.resourceGroup)
}
})
}
}
func TestTerminateResourceGroup(t *testing.T) {
someErr := errors.New("failed")
clientWithResourceGroup := Client{
resourceGroup: "name",
location: "location",
name: "name",
uid: "uid",
subnetID: "subnet",
workerScaleSet: "node-scale-set",
controlPlaneScaleSet: "controlplane-scale-set",
workers: cloudtypes.Instances{
"0": {
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
},
},
controlPlanes: cloudtypes.Instances{
"0": {
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
},
},
}
testCases := map[string]struct {
resourceGroup string
resourceGroupAPI resourceGroupAPI
client Client
wantErr bool
}{
"successful terminate": {
resourceGroupAPI: stubResourceGroupAPI{},
client: clientWithResourceGroup,
},
"no resource group to terminate": {
resourceGroupAPI: stubResourceGroupAPI{},
client: Client{},
resourceGroup: "",
},
"failed terminate": {
resourceGroupAPI: stubResourceGroupAPI{terminateErr: someErr},
client: clientWithResourceGroup,
wantErr: true,
},
"failed to poll terminate response": {
resourceGroupAPI: stubResourceGroupAPI{pollErr: someErr},
client: clientWithResourceGroup,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
tc.client.resourceGroupAPI = tc.resourceGroupAPI
ctx := context.Background()
if tc.wantErr {
assert.Error(tc.client.TerminateResourceGroup(ctx))
return
}
assert.NoError(tc.client.TerminateResourceGroup(ctx))
assert.Empty(tc.client.resourceGroup)
assert.Empty(tc.client.subnetID)
assert.Empty(tc.client.workers)
assert.Empty(tc.client.controlPlanes)
assert.Empty(tc.client.workerScaleSet)
assert.Empty(tc.client.controlPlaneScaleSet)
})
}
}
func TestCreateInstances(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
networkInterfacesAPI networkInterfacesAPI
scaleSetsAPI scaleSetsAPI
resourceGroupAPI resourceGroupAPI
roleAssignmentsAPI roleAssignmentsAPI
createInstancesInput CreateInstancesInput
wantErr bool
}{
@ -138,8 +28,6 @@ func TestCreateInstances(t *testing.T) {
VirtualMachineScaleSet: armcomputev2.VirtualMachineScaleSet{Identity: &armcomputev2.VirtualMachineScaleSetIdentity{PrincipalID: to.Ptr("principal-id")}},
},
},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountControlPlanes: 3,
CountWorkers: 3,
@ -153,8 +41,6 @@ func TestCreateInstances(t *testing.T) {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{createErr: someErr},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountControlPlanes: 3,
CountWorkers: 3,
@ -169,8 +55,6 @@ func TestCreateInstances(t *testing.T) {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{getErr: someErr},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountControlPlanes: 3,
CountWorkers: 3,
@ -185,8 +69,6 @@ func TestCreateInstances(t *testing.T) {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr},
scaleSetsAPI: stubScaleSetsAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountWorkers: 3,
InstanceType: "type",
@ -211,8 +93,6 @@ func TestCreateInstances(t *testing.T) {
publicIPAddressesAPI: tc.publicIPAddressesAPI,
networkInterfacesAPI: tc.networkInterfacesAPI,
scaleSetsAPI: tc.scaleSetsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
workers: make(cloudtypes.Instances),
controlPlanes: make(cloudtypes.Instances),
loadBalancerPubIP: "lbip",
@ -232,11 +112,3 @@ func TestCreateInstances(t *testing.T) {
})
}
}
func newSuccessfulResourceGroupStub() *stubResourceGroupAPI {
return &stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
}
}

View File

@ -0,0 +1,133 @@
package client
import (
"context"
"fmt"
"regexp"
"strings"
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
)
// TerminateResourceGroupResources deletes all resources from the resource group.
func (c *Client) TerminateResourceGroupResources(ctx context.Context) error {
const timeOut = 10 * time.Minute
ctx, cancel := context.WithTimeout(ctx, timeOut)
defer cancel()
pollers := make(chan *runtime.Poller[armresources.ClientDeleteByIDResponse], 20)
delete := make(chan struct{}, 1)
wg := &sync.WaitGroup{}
wg.Add(2)
go func() { // This routine lists resources and starts their deletion, where possible.
defer wg.Done()
defer func() {
close(pollers)
for range delete { // drain channel
}
}()
for {
ids, err := c.getResourceIDList(ctx)
if err != nil {
time.Sleep(3 * time.Second)
continue
}
if len(ids) == 0 {
return
}
for _, id := range ids {
poller, err := c.deleteResourceByID(ctx, id)
if err != nil {
continue
}
pollers <- poller
}
select {
case <-ctx.Done():
return
case _, ok := <-delete:
if !ok { // channel was closed
return
}
}
}
}()
go func() { // This routine polls for for the deletions to complete.
defer wg.Done()
defer close(delete)
for poller := range pollers {
_, err := poller.PollUntilDone(ctx, nil)
if err != nil {
continue
}
select {
case delete <- struct{}{}:
default:
}
}
}()
wg.Wait()
return nil
}
func (c *Client) getResourceIDList(ctx context.Context) ([]string, error) {
var ids []string
pager := c.resourceAPI.NewListByResourceGroupPager(c.resourceGroup, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("getting next page of ListByResourceGroup: %w", err)
}
for _, resource := range page.Value {
if resource.ID == nil {
return nil, fmt.Errorf("resource %v has no ID", resource)
}
ids = append(ids, *resource.ID)
}
}
return ids, nil
}
func (c *Client) deleteResourceByID(ctx context.Context, id string,
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) {
apiVersion := "2020-02-02"
// First try, API version unknown, will fail.
poller, err := c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil)
if isVersionWrongErr(err) {
// bad hack, but easiest way to get the right API version
apiVersion = parseAPIVersionFromErr(err)
poller, err = c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil)
}
return poller, err
}
func isVersionWrongErr(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), "NoRegisteredProviderFound") &&
strings.Contains(err.Error(), "The supported api-versions are")
}
var apiVersionRegex = regexp.MustCompile(` (\d\d\d\d-\d\d-\d\d)'`)
func parseAPIVersionFromErr(err error) string {
if err == nil {
return ""
}
matches := apiVersionRegex.FindStringSubmatch(err.Error())
return matches[1]
}

View File

@ -0,0 +1,139 @@
package client
import (
"context"
"errors"
"fmt"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
func TestTerminateResourceGroupResources(t *testing.T) {
someErr := errors.New("failed")
apiVersionErr := errors.New("NoRegisteredProviderFound, The supported api-versions are: 2015-01-01'")
testCases := map[string]struct {
resourceAPI resourceAPI
}{
"no resources": {
resourceAPI: &fakeResourceAPI{},
},
"some resources": {
resourceAPI: &fakeResourceAPI{
resources: map[string]fakeResource{
"id-0": {beginDeleteByIDErr: apiVersionErr, pollErr: someErr},
"id-1": {beginDeleteByIDErr: apiVersionErr},
"id-2": {},
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
resourceAPI: tc.resourceAPI,
}
ctx := context.Background()
err := client.TerminateResourceGroupResources(ctx)
assert.NoError(err)
})
}
}
type fakeResourceAPI struct {
resources map[string]fakeResource
fetchErr error
}
type fakeResource struct {
beginDeleteByIDErr error
pollErr error
}
func (a fakeResourceAPI) NewListByResourceGroupPager(resourceGroupName string,
options *armresources.ClientListByResourceGroupOptions,
) *runtime.Pager[armresources.ClientListByResourceGroupResponse] {
pager := &stubClientListByResourceGroupResponsePager{
resources: a.resources,
fetchErr: a.fetchErr,
}
return runtime.NewPager(runtime.PagingHandler[armresources.ClientListByResourceGroupResponse]{
More: pager.moreFunc(),
Fetcher: pager.fetcherFunc(),
})
}
func (a fakeResourceAPI) BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string,
options *armresources.ClientBeginDeleteByIDOptions,
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) {
res := a.resources[resourceID]
pollErr := res.pollErr
if pollErr != nil {
res.pollErr = nil
}
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ClientDeleteByIDResponse]{
Handler: &stubPoller[armresources.ClientDeleteByIDResponse]{
result: armresources.ClientDeleteByIDResponse{},
resultErr: pollErr,
},
})
if err != nil {
panic(err)
}
beginDeleteByIDErr := res.beginDeleteByIDErr
if beginDeleteByIDErr != nil {
res.beginDeleteByIDErr = nil
}
if res.beginDeleteByIDErr == nil && res.pollErr == nil {
delete(a.resources, resourceID)
fmt.Printf("fake delete %s\n", resourceID)
} else {
a.resources[resourceID] = res
}
return poller, beginDeleteByIDErr
}
type stubClientListByResourceGroupResponsePager struct {
resources map[string]fakeResource
fetchErr error
more bool
}
func (p *stubClientListByResourceGroupResponsePager) moreFunc() func(
armresources.ClientListByResourceGroupResponse) bool {
return func(armresources.ClientListByResourceGroupResponse) bool {
return p.more
}
}
func (p *stubClientListByResourceGroupResponsePager) fetcherFunc() func(
context.Context, *armresources.ClientListByResourceGroupResponse) (
armresources.ClientListByResourceGroupResponse, error) {
return func(context.Context, *armresources.ClientListByResourceGroupResponse) (
armresources.ClientListByResourceGroupResponse, error,
) {
var resources []*armresources.GenericResourceExpanded
for id := range p.resources {
resources = append(resources, &armresources.GenericResourceExpanded{ID: proto.String(id)})
}
p.more = false
return armresources.ClientListByResourceGroupResponse{
ResourceListResult: armresources.ResourceListResult{
Value: resources,
},
}, p.fetchErr
}
}

View File

@ -26,12 +26,9 @@ type azureclient interface {
GetState() state.ConstellationState
SetState(state.ConstellationState)
CreateApplicationInsight(ctx context.Context) error
CreateResourceGroup(ctx context.Context) error
CreateExternalLoadBalancer(ctx context.Context) error
CreateVirtualNetwork(ctx context.Context) error
CreateSecurityGroup(ctx context.Context, input azurecl.NetworkSecurityGroupInput) error
CreateInstances(ctx context.Context, input azurecl.CreateInstancesInput) error
CreateServicePrincipal(ctx context.Context) (string, error)
TerminateResourceGroup(ctx context.Context) error
TerminateServicePrincipal(ctx context.Context) error
TerminateResourceGroupResources(ctx context.Context) error
}

View File

@ -79,11 +79,6 @@ func (c *fakeAzureClient) CreateApplicationInsight(ctx context.Context) error {
return nil
}
func (c *fakeAzureClient) CreateResourceGroup(ctx context.Context) error {
c.resourceGroup = "resource-group"
return nil
}
func (c *fakeAzureClient) CreateVirtualNetwork(ctx context.Context) error {
c.subnetID = "subnet"
return nil
@ -123,17 +118,8 @@ func (c *fakeAzureClient) CreateServicePrincipal(ctx context.Context) (string, e
}.ToCloudServiceAccountURI(), nil
}
func (c *fakeAzureClient) TerminateResourceGroup(ctx context.Context) error {
if c.resourceGroup == "" {
return nil
}
c.workers = nil
c.controlPlanes = nil
c.resourceGroup = ""
c.subnetID = ""
c.networkSecurityGroup = ""
c.workerScaleSet = ""
c.controlPlaneScaleSet = ""
func (c *fakeAzureClient) TerminateResourceGroupResources(ctx context.Context) error {
// TODO(katexochen)
return nil
}
@ -146,18 +132,17 @@ func (c *fakeAzureClient) TerminateServicePrincipal(ctx context.Context) error {
}
type stubAzureClient struct {
terminateResourceGroupCalled bool
terminateServicePrincipalCalled bool
terminateResourceGroupResourcesCalled bool
terminateServicePrincipalCalled bool
createApplicationInsightErr error
createResourceGroupErr error
createVirtualNetworkErr error
createSecurityGroupErr error
createLoadBalancerErr error
createInstancesErr error
createServicePrincipalErr error
terminateResourceGroupErr error
terminateServicePrincipalErr error
createApplicationInsightErr error
createVirtualNetworkErr error
createSecurityGroupErr error
createLoadBalancerErr error
createInstancesErr error
createServicePrincipalErr error
terminateResourceGroupResourcesErr error
terminateServicePrincipalErr error
}
func (c *stubAzureClient) GetState() state.ConstellationState {
@ -175,10 +160,6 @@ func (c *stubAzureClient) CreateApplicationInsight(ctx context.Context) error {
return c.createApplicationInsightErr
}
func (c *stubAzureClient) CreateResourceGroup(ctx context.Context) error {
return c.createResourceGroupErr
}
func (c *stubAzureClient) CreateVirtualNetwork(ctx context.Context) error {
return c.createVirtualNetworkErr
}
@ -198,9 +179,9 @@ func (c *stubAzureClient) CreateServicePrincipal(ctx context.Context) (string, e
}.ToCloudServiceAccountURI(), c.createServicePrincipalErr
}
func (c *stubAzureClient) TerminateResourceGroup(ctx context.Context) error {
c.terminateResourceGroupCalled = true
return c.terminateResourceGroupErr
func (c *stubAzureClient) TerminateResourceGroupResources(ctx context.Context) error {
c.terminateResourceGroupResourcesCalled = true
return c.terminateResourceGroupResourcesErr
}
func (c *stubAzureClient) TerminateServicePrincipal(ctx context.Context) error {

View File

@ -18,7 +18,7 @@ import (
type Creator struct {
out io.Writer
newGCPClient func(ctx context.Context, project, zone, region, name string) (gcpclient, error)
newAzureClient func(subscriptionID, tenantID, name, location string) (azureclient, error)
newAzureClient func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error)
}
// NewCreator creates a new creator.
@ -28,8 +28,8 @@ func NewCreator(out io.Writer) *Creator {
newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) {
return gcpcl.NewInitialized(ctx, project, zone, region, name)
},
newAzureClient: func(subscriptionID, tenantID, name, location string) (azureclient, error) {
return azurecl.NewInitialized(subscriptionID, tenantID, name, location)
newAzureClient: func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error) {
return azurecl.NewInitialized(subscriptionID, tenantID, name, location, resourceGroup)
},
}
}
@ -57,6 +57,7 @@ func (c *Creator) Create(ctx context.Context, provider cloudprovider.Provider, c
config.Provider.Azure.TenantID,
name,
config.Provider.Azure.Location,
config.Provider.Azure.ResourceGroup,
)
if err != nil {
return state.ConstellationState{}, err
@ -144,9 +145,6 @@ func (c *Creator) createAzure(ctx context.Context, cl azureclient, config *confi
) (stat state.ConstellationState, retErr error) {
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerAzure{client: cl})
if err := cl.CreateResourceGroup(ctx); err != nil {
return state.ConstellationState{}, err
}
if err := cl.CreateApplicationInsight(ctx); err != nil {
return state.ConstellationState{}, err
}

View File

@ -51,7 +51,6 @@ func TestCreator(t *testing.T) {
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureResourceGroup: "resource-group",
AzureSubnet: "subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureWorkerScaleSet: "workers-scale-set",
@ -123,13 +122,6 @@ func TestCreator(t *testing.T) {
config: config.Default(),
wantErr: true,
},
"azure CreateResourceGroup error": {
azureclient: &stubAzureClient{createResourceGroupErr: someErr},
provider: cloudprovider.Azure,
config: config.Default(),
wantErr: true,
wantRollback: true,
},
"azure CreateVirtualNetwork error": {
azureclient: &stubAzureClient{createVirtualNetworkErr: someErr},
provider: cloudprovider.Azure,
@ -167,7 +159,7 @@ func TestCreator(t *testing.T) {
newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) {
return tc.gcpclient, tc.newGCPClientErr
},
newAzureClient: func(subscriptionID, tenantID, name, location string) (azureclient, error) {
newAzureClient: func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error) {
return tc.azureclient, tc.newAzureClientErr
},
}
@ -186,7 +178,7 @@ func TestCreator(t *testing.T) {
assert.True(cl.closeCalled)
case cloudprovider.Azure:
cl := tc.azureclient.(*stubAzureClient)
assert.True(cl.terminateResourceGroupCalled)
assert.True(cl.terminateResourceGroupResourcesCalled)
}
}
} else {

View File

@ -46,5 +46,5 @@ type rollbackerAzure struct {
}
func (r *rollbackerAzure) rollback(ctx context.Context) error {
return r.client.TerminateResourceGroup(ctx)
return r.client.TerminateResourceGroupResources(ctx)
}

View File

@ -1,68 +0,0 @@
package cloudcmd
import (
"context"
"fmt"
azurecl "github.com/edgelesssys/constellation/cli/internal/azure/client"
gcpcl "github.com/edgelesssys/constellation/cli/internal/gcp/client"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
)
// ServiceAccountCreator creates service accounts.
type ServiceAccountCreator struct {
newGCPClient func(ctx context.Context) (gcpclient, error)
newAzureClient func(subscriptionID, tenantID string) (azureclient, error)
}
func NewServiceAccountCreator() *ServiceAccountCreator {
return &ServiceAccountCreator{
newGCPClient: func(ctx context.Context) (gcpclient, error) {
return gcpcl.NewFromDefault(ctx)
},
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
return azurecl.NewFromDefault(subscriptionID, tenantID)
},
}
}
// Create creates a new cloud provider service account with access to the created resources.
func (c *ServiceAccountCreator) Create(ctx context.Context, stat state.ConstellationState, config *config.Config,
) (string, state.ConstellationState, error) {
provider := cloudprovider.FromString(stat.CloudProvider)
switch provider {
case cloudprovider.GCP:
return "", state.ConstellationState{}, fmt.Errorf("creating service account not supported for GCP")
case cloudprovider.Azure:
cl, err := c.newAzureClient(stat.AzureSubscription, stat.AzureTenant)
if err != nil {
return "", state.ConstellationState{}, err
}
serviceAccount, stat, err := c.createServiceAccountAzure(ctx, cl, stat, config)
if err != nil {
return "", state.ConstellationState{}, err
}
return serviceAccount, stat, err
case cloudprovider.QEMU:
return "unsupported://qemu", stat, nil
default:
return "", state.ConstellationState{}, fmt.Errorf("unsupported provider: %s", provider)
}
}
func (c *ServiceAccountCreator) createServiceAccountAzure(ctx context.Context, cl azureclient,
stat state.ConstellationState, _ *config.Config,
) (string, state.ConstellationState, error) {
cl.SetState(stat)
serviceAccount, err := cl.CreateServicePrincipal(ctx)
if err != nil {
return "", state.ConstellationState{}, fmt.Errorf("creating service account: %w", err)
}
return serviceAccount, cl.GetState(), nil
}

View File

@ -1,89 +0,0 @@
package cloudcmd
import (
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
"github.com/stretchr/testify/assert"
)
func TestServiceAccountCreator(t *testing.T) {
someAzureState := func() state.ConstellationState {
return state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
}
}
someErr := errors.New("failed")
testCases := map[string]struct {
newGCPClient func(ctx context.Context) (gcpclient, error)
newAzureClient func(subscriptionID, tenantID string) (azureclient, error)
state state.ConstellationState
config *config.Config
wantErr bool
wantStateMutator func(*state.ConstellationState)
}{
"azure": {
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
return &fakeAzureClient{}, nil
},
state: someAzureState(),
config: config.Default(),
wantStateMutator: func(stat *state.ConstellationState) {
stat.AzureADAppObjectID = "00000000-0000-0000-0000-000000000001"
},
},
"azure newAzureClient error": {
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
return nil, someErr
},
state: someAzureState(),
config: config.Default(),
wantErr: true,
},
"azure client createServiceAccount error": {
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
return &stubAzureClient{createServicePrincipalErr: someErr}, nil
},
state: someAzureState(),
config: config.Default(),
wantErr: true,
},
"qemu": {
state: state.ConstellationState{CloudProvider: "qemu"},
wantStateMutator: func(cs *state.ConstellationState) {},
config: config.Default(),
},
"unknown cloud provider": {
state: state.ConstellationState{},
config: config.Default(),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
creator := &ServiceAccountCreator{
newGCPClient: tc.newGCPClient,
newAzureClient: tc.newAzureClient,
}
serviceAccount, state, err := creator.Create(context.Background(), tc.state, tc.config)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotEmpty(serviceAccount)
tc.wantStateMutator(&tc.state)
assert.Equal(tc.state, state)
}
})
}
}

View File

@ -72,8 +72,5 @@ func (t *Terminator) terminateGCP(ctx context.Context, cl gcpclient, state state
func (t *Terminator) terminateAzure(ctx context.Context, cl azureclient, state state.ConstellationState) error {
cl.SetState(state)
if err := cl.TerminateServicePrincipal(ctx); err != nil {
return err
}
return cl.TerminateResourceGroup(ctx)
return cl.TerminateResourceGroupResources(ctx)
}

View File

@ -41,7 +41,6 @@ func TestTerminator(t *testing.T) {
AzureControlPlaneInstances: cloudtypes.Instances{
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureResourceGroup: "group",
AzureADAppObjectID: "00000000-0000-0000-0000-000000000001",
}
}
@ -88,13 +87,8 @@ func TestTerminator(t *testing.T) {
state: someAzureState(),
wantErr: true,
},
"azure terminateServicePrincipal error": {
azureclient: &stubAzureClient{terminateServicePrincipalErr: someErr},
state: someAzureState(),
wantErr: true,
},
"azure terminateResourceGroup error": {
azureclient: &stubAzureClient{terminateResourceGroupErr: someErr},
"azure terminateResourceGroupResources error": {
azureclient: &stubAzureClient{terminateResourceGroupResourcesErr: someErr},
state: someAzureState(),
wantErr: true,
},
@ -132,8 +126,7 @@ func TestTerminator(t *testing.T) {
assert.True(cl.closeCalled)
case cloudprovider.Azure:
cl := tc.azureclient.(*stubAzureClient)
assert.True(cl.terminateResourceGroupCalled)
assert.True(cl.terminateServicePrincipalCalled)
assert.True(cl.terminateResourceGroupResourcesCalled)
}
}
})

View File

@ -21,8 +21,3 @@ type cloudCreator interface {
type cloudTerminator interface {
Terminate(context.Context, state.ConstellationState) error
}
type serviceAccountCreator interface {
Create(ctx context.Context, stat state.ConstellationState, config *config.Config,
) (string, state.ConstellationState, error)
}

View File

@ -47,12 +47,3 @@ func (c *stubCloudTerminator) Terminate(context.Context, state.ConstellationStat
func (c *stubCloudTerminator) Called() bool {
return c.called
}
type stubServiceAccountCreator struct {
cloudServiceAccountURI string
createErr error
}
func (c *stubServiceAccountCreator) Create(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
return c.cloudServiceAccountURI, stat, c.createErr
}

View File

@ -55,17 +55,16 @@ func NewInitCmd() *cobra.Command {
// runInitialize runs the initialize command.
func runInitialize(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
newDialer := func(validator *cloudcmd.Validator) *dialer.Dialer {
return dialer.New(nil, validator.V(cmd), &net.Dialer{})
}
helmLoader := &helm.ChartLoader{}
return initialize(cmd, newDialer, serviceAccountCreator, fileHandler, helmLoader, license.NewClient())
return initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient())
}
// initialize initializes a Constellation.
func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer,
serviceAccCreator serviceAccountCreator, fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker,
fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker,
) error {
flags, err := evalFlagArgs(cmd, fileHandler)
if err != nil {
@ -105,22 +104,9 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
return err
}
var serviceAccURI string
// Temporary legacy flow for Azure.
if provider == cloudprovider.Azure {
cmd.Println("Creating service account ...")
serviceAccURI, stat, err = serviceAccCreator.Create(cmd.Context(), stat, config)
if err != nil {
return err
}
if err := fileHandler.WriteJSON(constants.StateFilename, stat, file.OptOverwrite); err != nil {
return err
}
} else {
serviceAccURI, err = getMarschaledServiceAccountURI(provider, config, fileHandler)
if err != nil {
return err
}
serviceAccURI, err := getMarschaledServiceAccountURI(provider, config, fileHandler)
if err != nil {
return err
}
workers, err := getScalingGroupsFromState(stat, config)

View File

@ -69,49 +69,54 @@ func TestInitialize(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
state *state.ConstellationState
existingIDFile *clusterIDsFile
serviceAccCreator serviceAccountCreator
configMutator func(*config.Config)
serviceAccKey *gcpshared.ServiceAccountKey
helmLoader stubHelmLoader
initServerAPI *stubInitServer
endpointFlag string
setAutoscaleFlag bool
wantErr bool
state *state.ConstellationState
idFile *clusterIDsFile
configMutator func(*config.Config)
serviceAccKey *gcpshared.ServiceAccountKey
helmLoader stubHelmLoader
initServerAPI *stubInitServer
endpointFlag string
setAutoscaleFlag bool
wantErr bool
}{
"initialize some gcp instances": {
state: testGcpState,
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{initResp: testInitResp},
state: testGcpState,
idFile: &clusterIDsFile{IP: "192.0.2.1"},
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{initResp: testInitResp},
},
"initialize some azure instances": {
state: testAzureState,
serviceAccCreator: &stubServiceAccountCreator{},
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{initResp: testInitResp},
state: testAzureState,
idFile: &clusterIDsFile{IP: "192.0.2.1"},
configMutator: func(c *config.Config) {
c.Provider.Azure.ResourceGroup = "resourceGroup"
c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity"
},
initServerAPI: &stubInitServer{initResp: testInitResp},
},
"initialize some qemu instances": {
state: testQemuState,
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{initResp: testInitResp},
state: testQemuState,
idFile: &clusterIDsFile{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{initResp: testInitResp},
},
"initialize gcp with autoscaling": {
state: testGcpState,
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
idFile: &clusterIDsFile{IP: "192.0.2.1"},
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
serviceAccKey: gcpServiceAccKey,
initServerAPI: &stubInitServer{initResp: testInitResp},
setAutoscaleFlag: true,
},
"initialize azure with autoscaling": {
state: testAzureState,
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
serviceAccCreator: &stubServiceAccountCreator{},
initServerAPI: &stubInitServer{initResp: testInitResp},
setAutoscaleFlag: true,
state: testAzureState,
idFile: &clusterIDsFile{IP: "192.0.2.1"},
configMutator: func(c *config.Config) {
c.Provider.Azure.ResourceGroup = "resourceGroup"
c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity"
},
initServerAPI: &stubInitServer{initResp: testInitResp},
setAutoscaleFlag: true,
},
"initialize with endpoint flag": {
state: testGcpState,
@ -121,27 +126,30 @@ func TestInitialize(t *testing.T) {
endpointFlag: "192.0.2.1",
},
"empty state": {
state: &state.ConstellationState{},
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{},
wantErr: true,
state: &state.ConstellationState{},
idFile: &clusterIDsFile{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{},
wantErr: true,
},
"neither endpoint flag nor id file": {
state: &state.ConstellationState{},
wantErr: true,
},
"init call fails": {
state: testGcpState,
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{initErr: someErr},
wantErr: true,
state: testGcpState,
idFile: &clusterIDsFile{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{initErr: someErr},
wantErr: true,
},
"fail to create service account": {
state: testAzureState,
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
initServerAPI: &stubInitServer{},
serviceAccCreator: &stubServiceAccountCreator{createErr: someErr},
wantErr: true,
state: testAzureState,
idFile: &clusterIDsFile{IP: "192.0.2.1"},
configMutator: func(c *config.Config) {
c.Provider.Azure.ResourceGroup = "resourceGroup"
c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity"
},
initServerAPI: &stubInitServer{},
wantErr: true,
},
"fail to load helm charts": {
state: testGcpState,
@ -194,8 +202,8 @@ func TestInitialize(t *testing.T) {
if tc.state != nil {
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.state, file.OptNone))
}
if tc.existingIDFile != nil {
require.NoError(fileHandler.WriteJSON(constants.ClusterIDsFileName, tc.existingIDFile, file.OptNone))
if tc.idFile != nil {
require.NoError(fileHandler.WriteJSON(constants.ClusterIDsFileName, tc.idFile, file.OptNone))
}
if tc.serviceAccKey != nil {
require.NoError(fileHandler.WriteJSON(serviceAccPath, tc.serviceAccKey, file.OptNone))
@ -206,7 +214,7 @@ func TestInitialize(t *testing.T) {
defer cancel()
cmd.SetContext(ctx)
err := initialize(cmd, newDialer, tc.serviceAccCreator, fileHandler, &tc.helmLoader, &stubLicenseClient{})
err := initialize(cmd, newDialer, fileHandler, &tc.helmLoader, &stubLicenseClient{})
if tc.wantErr {
assert.Error(err)
@ -477,7 +485,7 @@ func TestAttestation(t *testing.T) {
defer cancel()
cmd.SetContext(ctx)
err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler, &stubHelmLoader{}, &stubLicenseClient{})
err := initialize(cmd, newDialer, fileHandler, &stubHelmLoader{}, &stubLicenseClient{})
assert.Error(err)
// make sure the error is actually a TLS handshake error
assert.Contains(err.Error(), "transport: authentication handshake failed")
@ -548,6 +556,7 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
conf.Provider.Azure.Location = "test-location"
conf.Provider.Azure.UserAssignedIdentity = "test-identity"
conf.Provider.Azure.Image = "some/image/location"
conf.Provider.Azure.ResourceGroup = "test-resource-group"
conf.Provider.Azure.Measurements[8] = []byte("00000000000000000000000000000000")
conf.Provider.Azure.Measurements[9] = []byte("11111111111111111111111111111111")
case cloudprovider.GCP:

View File

@ -158,6 +158,9 @@ type AzureConfig struct {
// Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure
UserAssignedIdentity string `yaml:"userAssignedIdentity" validate:"required"`
// description: |
// Resource group to use.
ResourceGroup string `yaml:"resourceGroup" validate:"required"`
// description: |
// Use VMs with security type Confidential VM. If set to false, Trusted Launch VMs will be used instead. See: https://docs.microsoft.com/en-us/azure/confidential-computing/confidential-vm-overview
ConfidentialVM *bool `yaml:"confidentialVM" validate:"required"`
}
@ -244,6 +247,7 @@ func Default() *Config {
TenantID: "",
Location: "",
UserAssignedIdentity: "",
ResourceGroup: "",
Image: DefaultImageAzure,
StateDiskType: "Premium_LRS",
Measurements: copyPCRMap(azurePCRs),

View File

@ -199,7 +199,7 @@ func init() {
FieldName: "azure",
},
}
AzureConfigDoc.Fields = make([]encoder.Doc, 9)
AzureConfigDoc.Fields = make([]encoder.Doc, 10)
AzureConfigDoc.Fields[0].Name = "subscription"
AzureConfigDoc.Fields[0].Type = "string"
AzureConfigDoc.Fields[0].Note = ""
@ -240,6 +240,11 @@ func init() {
AzureConfigDoc.Fields[7].Note = ""
AzureConfigDoc.Fields[7].Description = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure"
AzureConfigDoc.Fields[7].Comments[encoder.LineComment] = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure"
AzureConfigDoc.Fields[8].Name = "resourceGroup"
AzureConfigDoc.Fields[8].Type = "string"
AzureConfigDoc.Fields[8].Note = ""
AzureConfigDoc.Fields[8].Description = "Resource group to use."
AzureConfigDoc.Fields[8].Comments[encoder.LineComment] = "Resource group to use."
AzureConfigDoc.Fields[8].Name = "confidentialVM"
AzureConfigDoc.Fields[8].Type = "bool"
AzureConfigDoc.Fields[8].Note = ""