Manually manage resource group on Azure

This commit is contained in:
katexochen 2022-08-25 15:12:08 +02:00 committed by Paul Meyer
parent e6ae54a25a
commit f15605cb45
25 changed files with 403 additions and 1162 deletions

View file

@ -1,184 +0,0 @@
package client
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/autorest/to"
"github.com/edgelesssys/constellation/internal/azureshared"
"github.com/google/uuid"
)
const (
adAppCredentialValidity = time.Hour * 24 * 365 * 5 // ~5 years
adReplicationLagCheckInterval = time.Second * 5 // 5 seconds
adReplicationLagCheckMaxRetries = int((15 * time.Minute) / adReplicationLagCheckInterval) // wait for up to 15 minutes for AD replication
ownerRoleDefinitionID = "8e3af657-a8ff-443c-a75c-2fe8c4bcb635" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#owner
)
// CreateServicePrincipal creates an Azure AD app with a service principal, gives it "Owner" role on the resource group and creates new credentials.
func (c *Client) CreateServicePrincipal(ctx context.Context) (string, error) {
createAppRes, err := c.createADApplication(ctx)
if err != nil {
return "", err
}
c.adAppObjectID = createAppRes.ObjectID
servicePrincipalObjectID, err := c.createAppServicePrincipal(ctx, createAppRes.AppID)
if err != nil {
return "", err
}
if err := c.assignResourceGroupRole(ctx, servicePrincipalObjectID, ownerRoleDefinitionID); err != nil {
return "", err
}
clientSecret, err := c.updateAppCredentials(ctx, createAppRes.ObjectID)
if err != nil {
return "", err
}
return azureshared.ApplicationCredentials{
TenantID: c.tenantID,
ClientID: createAppRes.AppID,
ClientSecret: clientSecret,
Location: c.location,
}.ToCloudServiceAccountURI(), nil
}
// TerminateServicePrincipal terminates an Azure AD app together with the service principal.
func (c *Client) TerminateServicePrincipal(ctx context.Context) error {
if c.adAppObjectID == "" {
return nil
}
if _, err := c.applicationsAPI.Delete(ctx, c.adAppObjectID); err != nil {
return err
}
c.adAppObjectID = ""
return nil
}
// createADApplication creates a new azure AD app.
func (c *Client) createADApplication(ctx context.Context) (createADApplicationOutput, error) {
createParameters := graphrbac.ApplicationCreateParameters{
AvailableToOtherTenants: to.BoolPtr(false),
DisplayName: to.StringPtr("constellation-app-" + c.name + "-" + c.uid),
}
app, err := c.applicationsAPI.Create(ctx, createParameters)
if err != nil {
return createADApplicationOutput{}, err
}
if app.AppID == nil || app.ObjectID == nil {
return createADApplicationOutput{}, errors.New("creating AD application did not result in valid app id and object id")
}
return createADApplicationOutput{
AppID: *app.AppID,
ObjectID: *app.ObjectID,
}, nil
}
// createAppServicePrincipal creates a new service principal for an azure AD app.
func (c *Client) createAppServicePrincipal(ctx context.Context, appID string) (string, error) {
createParameters := graphrbac.ServicePrincipalCreateParameters{
AppID: &appID,
AccountEnabled: to.BoolPtr(true),
}
servicePrincipal, err := c.servicePrincipalsAPI.Create(ctx, createParameters)
if err != nil {
return "", err
}
if servicePrincipal.ObjectID == nil {
return "", errors.New("creating AD service principal did not result in a valid object id")
}
return *servicePrincipal.ObjectID, nil
}
// updateAppCredentials sets app client-secret for authentication.
func (c *Client) updateAppCredentials(ctx context.Context, objectID string) (string, error) {
keyID := uuid.New().String()
clientSecret, err := generateClientSecret()
if err != nil {
return "", fmt.Errorf("generating client secret: %w", err)
}
updateParameters := graphrbac.PasswordCredentialsUpdateParameters{
Value: &[]graphrbac.PasswordCredential{
{
StartDate: &date.Time{Time: time.Now()},
EndDate: &date.Time{Time: time.Now().Add(adAppCredentialValidity)},
Value: to.StringPtr(clientSecret),
KeyID: to.StringPtr(keyID),
},
},
}
_, err = c.applicationsAPI.UpdatePasswordCredentials(ctx, objectID, updateParameters)
if err != nil {
return "", err
}
return clientSecret, nil
}
// assignResourceGroupRole assigns the service principal a role at resource group scope.
func (c *Client) assignResourceGroupRole(ctx context.Context, principalID, roleDefinitionID string) error {
resourceGroup, err := c.resourceGroupAPI.Get(ctx, c.resourceGroup, nil)
if err != nil || resourceGroup.ID == nil {
return fmt.Errorf("unable to retrieve resource group id for group %v: %w", c.resourceGroup, err)
}
roleAssignmentID := uuid.New().String()
createParameters := authorization.RoleAssignmentCreateParameters{
Properties: &authorization.RoleAssignmentProperties{
PrincipalID: to.StringPtr(principalID),
RoleDefinitionID: to.StringPtr(fmt.Sprintf("/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s", c.subscriptionID, roleDefinitionID)),
},
}
// due to an azure AD replication lag, retry role assignment if principal does not exist yet
// reference: https://docs.microsoft.com/en-us/azure/role-based-access-control/role-assignments-rest#new-service-principal
// proper fix: use API version 2018-09-01-preview or later
// azure go sdk currently uses version 2015-07-01: https://github.com/Azure/azure-sdk-for-go/blob/v62.0.0/services/authorization/mgmt/2015-07-01/authorization/roleassignments.go#L95
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
for i := 0; i < c.adReplicationLagCheckMaxRetries; i++ {
_, err = c.roleAssignmentsAPI.Create(ctx, *resourceGroup.ID, roleAssignmentID, createParameters)
var detailedErr autorest.DetailedError
var ok bool
if detailedErr, ok = err.(autorest.DetailedError); !ok {
return err
}
var requestErr *azure.RequestError
if requestErr, ok = detailedErr.Original.(*azure.RequestError); !ok || requestErr.ServiceError == nil {
return err
}
if requestErr.ServiceError.Code != "PrincipalNotFound" {
return err
}
time.Sleep(c.adReplicationLagCheckInterval)
}
return err
}
type createADApplicationOutput struct {
AppID string
ObjectID string
}
func generateClientSecret() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
pwLen := 64
pw := make([]byte, 0, pwLen)
for i := 0; i < pwLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
pw = append(pw, letters[n.Int64()])
}
return string(pw), nil
}

View file

@ -1,358 +0,0 @@
package client
import (
"context"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
func TestCreateServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
applicationsAPI applicationsAPI
servicePrincipalsAPI servicePrincipalsAPI
roleAssignmentsAPI roleAssignmentsAPI
resourceGroupAPI resourceGroupAPI
wantErr bool
}{
"successful create": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
},
"failed app create": {
applicationsAPI: stubApplicationsAPI{
createErr: someErr,
},
wantErr: true,
},
"failed service principal create": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{
createErr: someErr,
},
wantErr: true,
},
"failed role assignment": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{someErr},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
"failed update creds": {
applicationsAPI: stubApplicationsAPI{
updateCredentialsErr: someErr,
},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
applicationsAPI: tc.applicationsAPI,
servicePrincipalsAPI: tc.servicePrincipalsAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
adReplicationLagCheckMaxRetries: 2,
}
_, err := client.CreateServicePrincipal(ctx)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestTerminateServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
appObjectID string
applicationsAPI applicationsAPI
wantErr bool
}{
"successful terminate": {
appObjectID: "object-id",
applicationsAPI: stubApplicationsAPI{},
},
"nothing to terminate": {
applicationsAPI: stubApplicationsAPI{},
},
"failed delete": {
appObjectID: "object-id",
applicationsAPI: stubApplicationsAPI{
deleteErr: someErr,
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
adAppObjectID: tc.appObjectID,
applicationsAPI: tc.applicationsAPI,
}
err := client.TerminateServicePrincipal(ctx)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestCreateADApplication(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
applicationsAPI applicationsAPI
wantErr bool
}{
"successful create": {
applicationsAPI: stubApplicationsAPI{},
},
"failed app create": {
applicationsAPI: stubApplicationsAPI{
createErr: someErr,
},
wantErr: true,
},
"app create returns invalid appid": {
applicationsAPI: stubApplicationsAPI{
createApplication: &graphrbac.Application{
ObjectID: proto.String("00000000-0000-0000-0000-000000000001"),
},
},
wantErr: true,
},
"app create returns invalid objectid": {
applicationsAPI: stubApplicationsAPI{
createApplication: &graphrbac.Application{
AppID: proto.String("00000000-0000-0000-0000-000000000000"),
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
applicationsAPI: tc.applicationsAPI,
}
appCredentials, err := client.createADApplication(ctx)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.NotNil(appCredentials)
})
}
}
func TestCreateAppServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
servicePrincipalsAPI servicePrincipalsAPI
wantErr bool
}{
"successful create": {
servicePrincipalsAPI: stubServicePrincipalsAPI{},
},
"failed service principal create": {
servicePrincipalsAPI: stubServicePrincipalsAPI{
createErr: someErr,
},
wantErr: true,
},
"service principal create returns invalid objectid": {
servicePrincipalsAPI: stubServicePrincipalsAPI{
createServicePrincipal: &graphrbac.ServicePrincipal{},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
servicePrincipalsAPI: tc.servicePrincipalsAPI,
}
_, err := client.createAppServicePrincipal(ctx, "app-id")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestAssignOwnerOfResourceGroup(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
roleAssignmentsAPI roleAssignmentsAPI
resourceGroupAPI resourceGroupAPI
wantErr bool
}{
"successful assign": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
},
"failed role assignment": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{someErr},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
"failed resource group get": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getErr: someErr,
},
wantErr: true,
},
"resource group get returns invalid id": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{},
},
wantErr: true,
},
"create returns PrincipalNotFound the first time": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{
autorest.DetailedError{Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
Code: "PrincipalNotFound",
},
}},
nil,
},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
},
"create does not return request error": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{autorest.DetailedError{Original: someErr}},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
"create service error code is unknown": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{
autorest.DetailedError{Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
Code: "some-unknown-error-code",
},
}},
nil,
},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
roleAssignmentsAPI: tc.roleAssignmentsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
adReplicationLagCheckMaxRetries: 2,
}
err := client.assignResourceGroupRole(ctx, "principal-id", "role-definition-id")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}

View file

@ -64,15 +64,13 @@ type networkInterfacesAPI interface {
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error)
}
type resourceGroupAPI interface {
CreateOrUpdate(ctx context.Context, resourceGroupName string,
parameters armresources.ResourceGroup,
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
armresources.ResourceGroupsClientCreateOrUpdateResponse, error)
BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
*runtime.Poller[armresources.ResourceGroupsClientDeleteResponse], error)
Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error)
type resourceAPI interface {
NewListByResourceGroupPager(resourceGroupName string,
options *armresources.ClientListByResourceGroupOptions,
) *runtime.Pager[armresources.ClientListByResourceGroupResponse]
BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string,
options *armresources.ClientBeginDeleteByIDOptions,
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error)
}
type applicationsAPI interface {

View file

@ -4,15 +4,11 @@ import (
"context"
"net/http"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
)
type stubNetworksAPI struct {
@ -94,44 +90,6 @@ func (a stubNetworkSecurityGroupsAPI) BeginCreateOrUpdate(ctx context.Context, r
return poller, a.createErr
}
type stubResourceGroupAPI struct {
terminateErr error
createErr error
getErr error
getResourceGroup armresources.ResourceGroup
pollErr error
}
func (a stubResourceGroupAPI) CreateOrUpdate(ctx context.Context, resourceGroupName string,
parameters armresources.ResourceGroup,
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
armresources.ResourceGroupsClientCreateOrUpdateResponse, error,
) {
return armresources.ResourceGroupsClientCreateOrUpdateResponse{}, a.createErr
}
func (a stubResourceGroupAPI) Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error) {
return armresources.ResourceGroupsClientGetResponse{
ResourceGroup: a.getResourceGroup,
}, a.getErr
}
func (a stubResourceGroupAPI) BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
*runtime.Poller[armresources.ResourceGroupsClientDeleteResponse], error,
) {
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ResourceGroupsClientDeleteResponse]{
Handler: &stubPoller[armresources.ResourceGroupsClientDeleteResponse]{
result: armresources.ResourceGroupsClientDeleteResponse{},
resultErr: a.pollErr,
},
})
if err != nil {
panic(err)
}
return poller, a.terminateErr
}
type stubScaleSetsAPI struct {
createErr error
stubResponse armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse
@ -282,71 +240,6 @@ func (a stubNetworkInterfacesAPI) GetVirtualMachineScaleSetNetworkInterface(ctx
}, nil
}
type stubApplicationsAPI struct {
createErr error
deleteErr error
updateCredentialsErr error
createApplication *graphrbac.Application
}
func (a stubApplicationsAPI) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
if a.createErr != nil {
return graphrbac.Application{}, a.createErr
}
if a.createApplication != nil {
return *a.createApplication, nil
}
return graphrbac.Application{
AppID: to.Ptr("00000000-0000-0000-0000-000000000000"),
ObjectID: to.Ptr("00000000-0000-0000-0000-000000000001"),
}, nil
}
func (a stubApplicationsAPI) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
if a.deleteErr != nil {
return autorest.Response{}, a.deleteErr
}
return autorest.Response{}, nil
}
func (a stubApplicationsAPI) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
if a.updateCredentialsErr != nil {
return autorest.Response{}, a.updateCredentialsErr
}
return autorest.Response{}, nil
}
type stubServicePrincipalsAPI struct {
createErr error
createServicePrincipal *graphrbac.ServicePrincipal
}
func (a stubServicePrincipalsAPI) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
if a.createErr != nil {
return graphrbac.ServicePrincipal{}, a.createErr
}
if a.createServicePrincipal != nil {
return *a.createServicePrincipal, nil
}
return graphrbac.ServicePrincipal{
AppID: to.Ptr("00000000-0000-0000-0000-000000000000"),
ObjectID: to.Ptr("00000000-0000-0000-0000-000000000002"),
}, nil
}
type stubRoleAssignmentsAPI struct {
createCounter int
createErrors []error
}
func (a *stubRoleAssignmentsAPI) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
a.createCounter++
if len(a.createErrors) == 0 {
return authorization.RoleAssignment{}, nil
}
return authorization.RoleAssignment{}, a.createErrors[(a.createCounter-1)%len(a.createErrors)]
}
type stubApplicationInsightsAPI struct {
err error
}

View file

@ -29,7 +29,7 @@ const (
type Client struct {
networksAPI
networkSecurityGroupsAPI
resourceGroupAPI
resourceAPI
scaleSetsAPI
publicIPAddressesAPI
networkInterfacesAPI
@ -39,9 +39,7 @@ type Client struct {
roleAssignmentsAPI
applicationInsightsAPI
pollFrequency time.Duration
adReplicationLagCheckInterval time.Duration
adReplicationLagCheckMaxRetries int
pollFrequency time.Duration
workers cloudtypes.Instances
controlPlanes cloudtypes.Instances
@ -83,10 +81,6 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
if err != nil {
return nil, err
}
resGroupAPI, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
scaleSetAPI, err := armcomputev2.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
@ -107,6 +101,10 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
if err != nil {
return nil, err
}
resourceAPI, err := armresources.NewClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
applicationsAPI := graphrbac.NewApplicationsClient(tenantID)
applicationsAPI.Authorizer = graphAuthorizer
servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID)
@ -115,42 +113,41 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
roleAssignmentsAPI.Authorizer = managementAuthorizer
return &Client{
networksAPI: netAPI,
networkSecurityGroupsAPI: netSecGrpAPI,
resourceGroupAPI: resGroupAPI,
scaleSetsAPI: scaleSetAPI,
publicIPAddressesAPI: publicIPAddressesAPI,
networkInterfacesAPI: networkInterfacesAPI,
loadBalancersAPI: loadBalancersAPI,
applicationsAPI: applicationsAPI,
servicePrincipalsAPI: servicePrincipalsAPI,
roleAssignmentsAPI: roleAssignmentsAPI,
applicationInsightsAPI: applicationInsightsAPI,
subscriptionID: subscriptionID,
tenantID: tenantID,
workers: cloudtypes.Instances{},
controlPlanes: cloudtypes.Instances{},
pollFrequency: time.Second * 5,
adReplicationLagCheckInterval: adReplicationLagCheckInterval,
adReplicationLagCheckMaxRetries: adReplicationLagCheckMaxRetries,
networksAPI: netAPI,
networkSecurityGroupsAPI: netSecGrpAPI,
resourceAPI: resourceAPI,
scaleSetsAPI: scaleSetAPI,
publicIPAddressesAPI: publicIPAddressesAPI,
networkInterfacesAPI: networkInterfacesAPI,
loadBalancersAPI: loadBalancersAPI,
applicationsAPI: applicationsAPI,
servicePrincipalsAPI: servicePrincipalsAPI,
roleAssignmentsAPI: roleAssignmentsAPI,
applicationInsightsAPI: applicationInsightsAPI,
subscriptionID: subscriptionID,
tenantID: tenantID,
workers: cloudtypes.Instances{},
controlPlanes: cloudtypes.Instances{},
pollFrequency: time.Second * 5,
}, nil
}
// NewInitialized creates and initializes client by setting the subscriptionID, location and name
// of the Constellation.
func NewInitialized(subscriptionID, tenantID, name, location string) (*Client, error) {
func NewInitialized(subscriptionID, tenantID, name, location, resourceGroup string) (*Client, error) {
client, err := NewFromDefault(subscriptionID, tenantID)
if err != nil {
return nil, err
}
err = client.init(location, name)
err = client.init(location, name, resourceGroup)
return client, err
}
// init initializes the client.
func (c *Client) init(location, name string) error {
func (c *Client) init(location, name, resourceGroup string) error {
c.location = location
c.name = name
c.resourceGroup = resourceGroup
uid, err := c.generateUID()
if err != nil {
return err

View file

@ -84,8 +84,9 @@ func TestInit(t *testing.T) {
require := require.New(t)
client := Client{}
require.NoError(client.init("location", "name"))
require.NoError(client.init("location", "name", "rGroup"))
assert.Equal("location", client.location)
assert.Equal("name", client.name)
assert.Equal("rGroup", client.resourceGroup)
assert.NotEmpty(client.uid)
}

View file

@ -10,8 +10,6 @@ import (
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/cli/internal/azure"
"github.com/edgelesssys/constellation/cli/internal/azure/internal/poller"
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
@ -213,45 +211,6 @@ type CreateScaleSetInput struct {
ConfidentialVM bool
}
// CreateResourceGroup creates a resource group.
func (c *Client) CreateResourceGroup(ctx context.Context) error {
_, err := c.resourceGroupAPI.CreateOrUpdate(ctx, c.name+"-"+c.uid,
armresources.ResourceGroup{
Location: &c.location,
}, nil)
if err != nil {
return err
}
c.resourceGroup = c.name + "-" + c.uid
return nil
}
// TerminateResourceGroup terminates a resource group.
func (c *Client) TerminateResourceGroup(ctx context.Context) error {
if c.resourceGroup == "" {
return nil
}
poller, err := c.resourceGroupAPI.BeginDelete(ctx, c.resourceGroup, nil)
if err != nil {
return err
}
if _, err = poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{
Frequency: c.pollFrequency,
}); err != nil {
return err
}
c.workers = nil
c.controlPlanes = nil
c.resourceGroup = ""
c.subnetID = ""
c.networkSecurityGroup = ""
c.workerScaleSet = ""
c.controlPlaneScaleSet = ""
return nil
}
// scaleSetCreationPollingHandler is a custom poller used to check if a scale set was created successfully.
type scaleSetCreationPollingHandler struct {
done bool

View file

@ -7,126 +7,16 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
"github.com/stretchr/testify/assert"
)
func TestCreateResourceGroup(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
resourceGroupAPI resourceGroupAPI
wantErr bool
}{
"successful create": {
resourceGroupAPI: stubResourceGroupAPI{},
},
"failed create": {
resourceGroupAPI: stubResourceGroupAPI{createErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroupAPI: tc.resourceGroupAPI,
workers: make(cloudtypes.Instances),
controlPlanes: make(cloudtypes.Instances),
}
if tc.wantErr {
assert.Error(client.CreateResourceGroup(ctx))
} else {
assert.NoError(client.CreateResourceGroup(ctx))
assert.Equal(client.name+"-"+client.uid, client.resourceGroup)
}
})
}
}
func TestTerminateResourceGroup(t *testing.T) {
someErr := errors.New("failed")
clientWithResourceGroup := Client{
resourceGroup: "name",
location: "location",
name: "name",
uid: "uid",
subnetID: "subnet",
workerScaleSet: "node-scale-set",
controlPlaneScaleSet: "controlplane-scale-set",
workers: cloudtypes.Instances{
"0": {
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
},
},
controlPlanes: cloudtypes.Instances{
"0": {
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
},
},
}
testCases := map[string]struct {
resourceGroup string
resourceGroupAPI resourceGroupAPI
client Client
wantErr bool
}{
"successful terminate": {
resourceGroupAPI: stubResourceGroupAPI{},
client: clientWithResourceGroup,
},
"no resource group to terminate": {
resourceGroupAPI: stubResourceGroupAPI{},
client: Client{},
resourceGroup: "",
},
"failed terminate": {
resourceGroupAPI: stubResourceGroupAPI{terminateErr: someErr},
client: clientWithResourceGroup,
wantErr: true,
},
"failed to poll terminate response": {
resourceGroupAPI: stubResourceGroupAPI{pollErr: someErr},
client: clientWithResourceGroup,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
tc.client.resourceGroupAPI = tc.resourceGroupAPI
ctx := context.Background()
if tc.wantErr {
assert.Error(tc.client.TerminateResourceGroup(ctx))
return
}
assert.NoError(tc.client.TerminateResourceGroup(ctx))
assert.Empty(tc.client.resourceGroup)
assert.Empty(tc.client.subnetID)
assert.Empty(tc.client.workers)
assert.Empty(tc.client.controlPlanes)
assert.Empty(tc.client.workerScaleSet)
assert.Empty(tc.client.controlPlaneScaleSet)
})
}
}
func TestCreateInstances(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
networkInterfacesAPI networkInterfacesAPI
scaleSetsAPI scaleSetsAPI
resourceGroupAPI resourceGroupAPI
roleAssignmentsAPI roleAssignmentsAPI
createInstancesInput CreateInstancesInput
wantErr bool
}{
@ -138,8 +28,6 @@ func TestCreateInstances(t *testing.T) {
VirtualMachineScaleSet: armcomputev2.VirtualMachineScaleSet{Identity: &armcomputev2.VirtualMachineScaleSetIdentity{PrincipalID: to.Ptr("principal-id")}},
},
},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountControlPlanes: 3,
CountWorkers: 3,
@ -153,8 +41,6 @@ func TestCreateInstances(t *testing.T) {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{createErr: someErr},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountControlPlanes: 3,
CountWorkers: 3,
@ -169,8 +55,6 @@ func TestCreateInstances(t *testing.T) {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{getErr: someErr},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountControlPlanes: 3,
CountWorkers: 3,
@ -185,8 +69,6 @@ func TestCreateInstances(t *testing.T) {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr},
scaleSetsAPI: stubScaleSetsAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountWorkers: 3,
InstanceType: "type",
@ -211,8 +93,6 @@ func TestCreateInstances(t *testing.T) {
publicIPAddressesAPI: tc.publicIPAddressesAPI,
networkInterfacesAPI: tc.networkInterfacesAPI,
scaleSetsAPI: tc.scaleSetsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
workers: make(cloudtypes.Instances),
controlPlanes: make(cloudtypes.Instances),
loadBalancerPubIP: "lbip",
@ -232,11 +112,3 @@ func TestCreateInstances(t *testing.T) {
})
}
}
func newSuccessfulResourceGroupStub() *stubResourceGroupAPI {
return &stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.Ptr("resource-group-id"),
},
}
}

View file

@ -0,0 +1,133 @@
package client
import (
"context"
"fmt"
"regexp"
"strings"
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
)
// TerminateResourceGroupResources deletes all resources from the resource group.
func (c *Client) TerminateResourceGroupResources(ctx context.Context) error {
const timeOut = 10 * time.Minute
ctx, cancel := context.WithTimeout(ctx, timeOut)
defer cancel()
pollers := make(chan *runtime.Poller[armresources.ClientDeleteByIDResponse], 20)
delete := make(chan struct{}, 1)
wg := &sync.WaitGroup{}
wg.Add(2)
go func() { // This routine lists resources and starts their deletion, where possible.
defer wg.Done()
defer func() {
close(pollers)
for range delete { // drain channel
}
}()
for {
ids, err := c.getResourceIDList(ctx)
if err != nil {
time.Sleep(3 * time.Second)
continue
}
if len(ids) == 0 {
return
}
for _, id := range ids {
poller, err := c.deleteResourceByID(ctx, id)
if err != nil {
continue
}
pollers <- poller
}
select {
case <-ctx.Done():
return
case _, ok := <-delete:
if !ok { // channel was closed
return
}
}
}
}()
go func() { // This routine polls for for the deletions to complete.
defer wg.Done()
defer close(delete)
for poller := range pollers {
_, err := poller.PollUntilDone(ctx, nil)
if err != nil {
continue
}
select {
case delete <- struct{}{}:
default:
}
}
}()
wg.Wait()
return nil
}
func (c *Client) getResourceIDList(ctx context.Context) ([]string, error) {
var ids []string
pager := c.resourceAPI.NewListByResourceGroupPager(c.resourceGroup, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("getting next page of ListByResourceGroup: %w", err)
}
for _, resource := range page.Value {
if resource.ID == nil {
return nil, fmt.Errorf("resource %v has no ID", resource)
}
ids = append(ids, *resource.ID)
}
}
return ids, nil
}
func (c *Client) deleteResourceByID(ctx context.Context, id string,
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) {
apiVersion := "2020-02-02"
// First try, API version unknown, will fail.
poller, err := c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil)
if isVersionWrongErr(err) {
// bad hack, but easiest way to get the right API version
apiVersion = parseAPIVersionFromErr(err)
poller, err = c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil)
}
return poller, err
}
func isVersionWrongErr(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), "NoRegisteredProviderFound") &&
strings.Contains(err.Error(), "The supported api-versions are")
}
var apiVersionRegex = regexp.MustCompile(` (\d\d\d\d-\d\d-\d\d)'`)
func parseAPIVersionFromErr(err error) string {
if err == nil {
return ""
}
matches := apiVersionRegex.FindStringSubmatch(err.Error())
return matches[1]
}

View file

@ -0,0 +1,139 @@
package client
import (
"context"
"errors"
"fmt"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
func TestTerminateResourceGroupResources(t *testing.T) {
someErr := errors.New("failed")
apiVersionErr := errors.New("NoRegisteredProviderFound, The supported api-versions are: 2015-01-01'")
testCases := map[string]struct {
resourceAPI resourceAPI
}{
"no resources": {
resourceAPI: &fakeResourceAPI{},
},
"some resources": {
resourceAPI: &fakeResourceAPI{
resources: map[string]fakeResource{
"id-0": {beginDeleteByIDErr: apiVersionErr, pollErr: someErr},
"id-1": {beginDeleteByIDErr: apiVersionErr},
"id-2": {},
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
resourceAPI: tc.resourceAPI,
}
ctx := context.Background()
err := client.TerminateResourceGroupResources(ctx)
assert.NoError(err)
})
}
}
type fakeResourceAPI struct {
resources map[string]fakeResource
fetchErr error
}
type fakeResource struct {
beginDeleteByIDErr error
pollErr error
}
func (a fakeResourceAPI) NewListByResourceGroupPager(resourceGroupName string,
options *armresources.ClientListByResourceGroupOptions,
) *runtime.Pager[armresources.ClientListByResourceGroupResponse] {
pager := &stubClientListByResourceGroupResponsePager{
resources: a.resources,
fetchErr: a.fetchErr,
}
return runtime.NewPager(runtime.PagingHandler[armresources.ClientListByResourceGroupResponse]{
More: pager.moreFunc(),
Fetcher: pager.fetcherFunc(),
})
}
func (a fakeResourceAPI) BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string,
options *armresources.ClientBeginDeleteByIDOptions,
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) {
res := a.resources[resourceID]
pollErr := res.pollErr
if pollErr != nil {
res.pollErr = nil
}
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ClientDeleteByIDResponse]{
Handler: &stubPoller[armresources.ClientDeleteByIDResponse]{
result: armresources.ClientDeleteByIDResponse{},
resultErr: pollErr,
},
})
if err != nil {
panic(err)
}
beginDeleteByIDErr := res.beginDeleteByIDErr
if beginDeleteByIDErr != nil {
res.beginDeleteByIDErr = nil
}
if res.beginDeleteByIDErr == nil && res.pollErr == nil {
delete(a.resources, resourceID)
fmt.Printf("fake delete %s\n", resourceID)
} else {
a.resources[resourceID] = res
}
return poller, beginDeleteByIDErr
}
type stubClientListByResourceGroupResponsePager struct {
resources map[string]fakeResource
fetchErr error
more bool
}
func (p *stubClientListByResourceGroupResponsePager) moreFunc() func(
armresources.ClientListByResourceGroupResponse) bool {
return func(armresources.ClientListByResourceGroupResponse) bool {
return p.more
}
}
func (p *stubClientListByResourceGroupResponsePager) fetcherFunc() func(
context.Context, *armresources.ClientListByResourceGroupResponse) (
armresources.ClientListByResourceGroupResponse, error) {
return func(context.Context, *armresources.ClientListByResourceGroupResponse) (
armresources.ClientListByResourceGroupResponse, error,
) {
var resources []*armresources.GenericResourceExpanded
for id := range p.resources {
resources = append(resources, &armresources.GenericResourceExpanded{ID: proto.String(id)})
}
p.more = false
return armresources.ClientListByResourceGroupResponse{
ResourceListResult: armresources.ResourceListResult{
Value: resources,
},
}, p.fetchErr
}
}