From f15605cb4580eaee5e6f4d9fb7ddb92d0d7c3c1d Mon Sep 17 00:00:00 2001 From: katexochen <49727155+katexochen@users.noreply.github.com> Date: Thu, 25 Aug 2022 15:12:08 +0200 Subject: [PATCH] Manually manage resource group on Azure --- cli/internal/azure/client/activedirectory.go | 184 --------- .../azure/client/activedirectory_test.go | 358 ------------------ cli/internal/azure/client/api.go | 16 +- cli/internal/azure/client/api_test.go | 107 ------ cli/internal/azure/client/client.go | 55 ++- cli/internal/azure/client/client_test.go | 3 +- cli/internal/azure/client/compute.go | 41 -- cli/internal/azure/client/compute_test.go | 128 ------- cli/internal/azure/client/terminate.go | 133 +++++++ cli/internal/azure/client/terminate_test.go | 139 +++++++ cli/internal/cloudcmd/clients.go | 5 +- cli/internal/cloudcmd/clients_test.go | 49 +-- cli/internal/cloudcmd/create.go | 10 +- cli/internal/cloudcmd/create_test.go | 12 +- cli/internal/cloudcmd/rollback.go | 2 +- cli/internal/cloudcmd/serviceaccount.go | 68 ---- cli/internal/cloudcmd/serviceaccount_test.go | 89 ----- cli/internal/cloudcmd/terminate.go | 5 +- cli/internal/cloudcmd/terminate_test.go | 13 +- cli/internal/cmd/cloud.go | 5 - cli/internal/cmd/cloud_test.go | 9 - cli/internal/cmd/init.go | 24 +- cli/internal/cmd/init_test.go | 99 ++--- internal/config/config.go | 4 + internal/config/config_doc.go | 7 +- 25 files changed, 403 insertions(+), 1162 deletions(-) delete mode 100644 cli/internal/azure/client/activedirectory.go delete mode 100644 cli/internal/azure/client/activedirectory_test.go create mode 100644 cli/internal/azure/client/terminate.go create mode 100644 cli/internal/azure/client/terminate_test.go delete mode 100644 cli/internal/cloudcmd/serviceaccount.go delete mode 100644 cli/internal/cloudcmd/serviceaccount_test.go diff --git a/cli/internal/azure/client/activedirectory.go b/cli/internal/azure/client/activedirectory.go deleted file mode 100644 index 9f0e03d14..000000000 --- a/cli/internal/azure/client/activedirectory.go +++ /dev/null @@ -1,184 +0,0 @@ -package client - -import ( - "context" - "crypto/rand" - "errors" - "fmt" - "math/big" - "time" - - "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" - "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/azure" - "github.com/Azure/go-autorest/autorest/date" - "github.com/Azure/go-autorest/autorest/to" - "github.com/edgelesssys/constellation/internal/azureshared" - "github.com/google/uuid" -) - -const ( - adAppCredentialValidity = time.Hour * 24 * 365 * 5 // ~5 years - adReplicationLagCheckInterval = time.Second * 5 // 5 seconds - adReplicationLagCheckMaxRetries = int((15 * time.Minute) / adReplicationLagCheckInterval) // wait for up to 15 minutes for AD replication - ownerRoleDefinitionID = "8e3af657-a8ff-443c-a75c-2fe8c4bcb635" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#owner -) - -// CreateServicePrincipal creates an Azure AD app with a service principal, gives it "Owner" role on the resource group and creates new credentials. -func (c *Client) CreateServicePrincipal(ctx context.Context) (string, error) { - createAppRes, err := c.createADApplication(ctx) - if err != nil { - return "", err - } - c.adAppObjectID = createAppRes.ObjectID - servicePrincipalObjectID, err := c.createAppServicePrincipal(ctx, createAppRes.AppID) - if err != nil { - return "", err - } - - if err := c.assignResourceGroupRole(ctx, servicePrincipalObjectID, ownerRoleDefinitionID); err != nil { - return "", err - } - - clientSecret, err := c.updateAppCredentials(ctx, createAppRes.ObjectID) - if err != nil { - return "", err - } - - return azureshared.ApplicationCredentials{ - TenantID: c.tenantID, - ClientID: createAppRes.AppID, - ClientSecret: clientSecret, - Location: c.location, - }.ToCloudServiceAccountURI(), nil -} - -// TerminateServicePrincipal terminates an Azure AD app together with the service principal. -func (c *Client) TerminateServicePrincipal(ctx context.Context) error { - if c.adAppObjectID == "" { - return nil - } - if _, err := c.applicationsAPI.Delete(ctx, c.adAppObjectID); err != nil { - return err - } - c.adAppObjectID = "" - return nil -} - -// createADApplication creates a new azure AD app. -func (c *Client) createADApplication(ctx context.Context) (createADApplicationOutput, error) { - createParameters := graphrbac.ApplicationCreateParameters{ - AvailableToOtherTenants: to.BoolPtr(false), - DisplayName: to.StringPtr("constellation-app-" + c.name + "-" + c.uid), - } - app, err := c.applicationsAPI.Create(ctx, createParameters) - if err != nil { - return createADApplicationOutput{}, err - } - if app.AppID == nil || app.ObjectID == nil { - return createADApplicationOutput{}, errors.New("creating AD application did not result in valid app id and object id") - } - return createADApplicationOutput{ - AppID: *app.AppID, - ObjectID: *app.ObjectID, - }, nil -} - -// createAppServicePrincipal creates a new service principal for an azure AD app. -func (c *Client) createAppServicePrincipal(ctx context.Context, appID string) (string, error) { - createParameters := graphrbac.ServicePrincipalCreateParameters{ - AppID: &appID, - AccountEnabled: to.BoolPtr(true), - } - servicePrincipal, err := c.servicePrincipalsAPI.Create(ctx, createParameters) - if err != nil { - return "", err - } - if servicePrincipal.ObjectID == nil { - return "", errors.New("creating AD service principal did not result in a valid object id") - } - return *servicePrincipal.ObjectID, nil -} - -// updateAppCredentials sets app client-secret for authentication. -func (c *Client) updateAppCredentials(ctx context.Context, objectID string) (string, error) { - keyID := uuid.New().String() - clientSecret, err := generateClientSecret() - if err != nil { - return "", fmt.Errorf("generating client secret: %w", err) - } - updateParameters := graphrbac.PasswordCredentialsUpdateParameters{ - Value: &[]graphrbac.PasswordCredential{ - { - StartDate: &date.Time{Time: time.Now()}, - EndDate: &date.Time{Time: time.Now().Add(adAppCredentialValidity)}, - Value: to.StringPtr(clientSecret), - KeyID: to.StringPtr(keyID), - }, - }, - } - _, err = c.applicationsAPI.UpdatePasswordCredentials(ctx, objectID, updateParameters) - if err != nil { - return "", err - } - return clientSecret, nil -} - -// assignResourceGroupRole assigns the service principal a role at resource group scope. -func (c *Client) assignResourceGroupRole(ctx context.Context, principalID, roleDefinitionID string) error { - resourceGroup, err := c.resourceGroupAPI.Get(ctx, c.resourceGroup, nil) - if err != nil || resourceGroup.ID == nil { - return fmt.Errorf("unable to retrieve resource group id for group %v: %w", c.resourceGroup, err) - } - roleAssignmentID := uuid.New().String() - createParameters := authorization.RoleAssignmentCreateParameters{ - Properties: &authorization.RoleAssignmentProperties{ - PrincipalID: to.StringPtr(principalID), - RoleDefinitionID: to.StringPtr(fmt.Sprintf("/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s", c.subscriptionID, roleDefinitionID)), - }, - } - - // due to an azure AD replication lag, retry role assignment if principal does not exist yet - // reference: https://docs.microsoft.com/en-us/azure/role-based-access-control/role-assignments-rest#new-service-principal - // proper fix: use API version 2018-09-01-preview or later - // azure go sdk currently uses version 2015-07-01: https://github.com/Azure/azure-sdk-for-go/blob/v62.0.0/services/authorization/mgmt/2015-07-01/authorization/roleassignments.go#L95 - // the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071 - for i := 0; i < c.adReplicationLagCheckMaxRetries; i++ { - _, err = c.roleAssignmentsAPI.Create(ctx, *resourceGroup.ID, roleAssignmentID, createParameters) - var detailedErr autorest.DetailedError - var ok bool - if detailedErr, ok = err.(autorest.DetailedError); !ok { - return err - } - var requestErr *azure.RequestError - if requestErr, ok = detailedErr.Original.(*azure.RequestError); !ok || requestErr.ServiceError == nil { - return err - } - if requestErr.ServiceError.Code != "PrincipalNotFound" { - return err - } - time.Sleep(c.adReplicationLagCheckInterval) - } - return err -} - -type createADApplicationOutput struct { - AppID string - ObjectID string -} - -func generateClientSecret() (string, error) { - letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") - - pwLen := 64 - pw := make([]byte, 0, pwLen) - for i := 0; i < pwLen; i++ { - n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters)))) - if err != nil { - return "", err - } - pw = append(pw, letters[n.Int64()]) - } - return string(pw), nil -} diff --git a/cli/internal/azure/client/activedirectory_test.go b/cli/internal/azure/client/activedirectory_test.go deleted file mode 100644 index dd8bc727c..000000000 --- a/cli/internal/azure/client/activedirectory_test.go +++ /dev/null @@ -1,358 +0,0 @@ -package client - -import ( - "context" - "errors" - "testing" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" - "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/go-autorest/autorest" - "github.com/Azure/go-autorest/autorest/azure" - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/proto" -) - -func TestCreateServicePrincipal(t *testing.T) { - someErr := errors.New("failed") - testCases := map[string]struct { - applicationsAPI applicationsAPI - servicePrincipalsAPI servicePrincipalsAPI - roleAssignmentsAPI roleAssignmentsAPI - resourceGroupAPI resourceGroupAPI - wantErr bool - }{ - "successful create": { - applicationsAPI: stubApplicationsAPI{}, - servicePrincipalsAPI: stubServicePrincipalsAPI{}, - roleAssignmentsAPI: &stubRoleAssignmentsAPI{}, - resourceGroupAPI: stubResourceGroupAPI{ - getResourceGroup: armresources.ResourceGroup{ - ID: to.Ptr("resource-group-id"), - }, - }, - }, - "failed app create": { - applicationsAPI: stubApplicationsAPI{ - createErr: someErr, - }, - wantErr: true, - }, - "failed service principal create": { - applicationsAPI: stubApplicationsAPI{}, - servicePrincipalsAPI: stubServicePrincipalsAPI{ - createErr: someErr, - }, - wantErr: true, - }, - "failed role assignment": { - applicationsAPI: stubApplicationsAPI{}, - servicePrincipalsAPI: stubServicePrincipalsAPI{}, - roleAssignmentsAPI: &stubRoleAssignmentsAPI{ - createErrors: []error{someErr}, - }, - resourceGroupAPI: stubResourceGroupAPI{ - getResourceGroup: armresources.ResourceGroup{ - ID: to.Ptr("resource-group-id"), - }, - }, - wantErr: true, - }, - "failed update creds": { - applicationsAPI: stubApplicationsAPI{ - updateCredentialsErr: someErr, - }, - servicePrincipalsAPI: stubServicePrincipalsAPI{}, - roleAssignmentsAPI: &stubRoleAssignmentsAPI{}, - resourceGroupAPI: stubResourceGroupAPI{ - getResourceGroup: armresources.ResourceGroup{ - ID: to.Ptr("resource-group-id"), - }, - }, - wantErr: true, - }, - } - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - - ctx := context.Background() - - client := Client{ - name: "name", - uid: "uid", - resourceGroup: "resource-group", - applicationsAPI: tc.applicationsAPI, - servicePrincipalsAPI: tc.servicePrincipalsAPI, - roleAssignmentsAPI: tc.roleAssignmentsAPI, - resourceGroupAPI: tc.resourceGroupAPI, - adReplicationLagCheckMaxRetries: 2, - } - - _, err := client.CreateServicePrincipal(ctx) - if tc.wantErr { - assert.Error(err) - return - } - assert.NoError(err) - }) - } -} - -func TestTerminateServicePrincipal(t *testing.T) { - someErr := errors.New("failed") - testCases := map[string]struct { - appObjectID string - applicationsAPI applicationsAPI - wantErr bool - }{ - "successful terminate": { - appObjectID: "object-id", - applicationsAPI: stubApplicationsAPI{}, - }, - "nothing to terminate": { - applicationsAPI: stubApplicationsAPI{}, - }, - "failed delete": { - appObjectID: "object-id", - applicationsAPI: stubApplicationsAPI{ - deleteErr: someErr, - }, - wantErr: true, - }, - } - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - - ctx := context.Background() - - client := Client{ - name: "name", - uid: "uid", - resourceGroup: "resource-group", - adAppObjectID: tc.appObjectID, - applicationsAPI: tc.applicationsAPI, - } - - err := client.TerminateServicePrincipal(ctx) - if tc.wantErr { - assert.Error(err) - return - } - assert.NoError(err) - }) - } -} - -func TestCreateADApplication(t *testing.T) { - someErr := errors.New("failed") - testCases := map[string]struct { - applicationsAPI applicationsAPI - wantErr bool - }{ - "successful create": { - applicationsAPI: stubApplicationsAPI{}, - }, - "failed app create": { - applicationsAPI: stubApplicationsAPI{ - createErr: someErr, - }, - wantErr: true, - }, - "app create returns invalid appid": { - applicationsAPI: stubApplicationsAPI{ - createApplication: &graphrbac.Application{ - ObjectID: proto.String("00000000-0000-0000-0000-000000000001"), - }, - }, - wantErr: true, - }, - "app create returns invalid objectid": { - applicationsAPI: stubApplicationsAPI{ - createApplication: &graphrbac.Application{ - AppID: proto.String("00000000-0000-0000-0000-000000000000"), - }, - }, - wantErr: true, - }, - } - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - - ctx := context.Background() - - client := Client{ - name: "name", - uid: "uid", - applicationsAPI: tc.applicationsAPI, - } - - appCredentials, err := client.createADApplication(ctx) - if tc.wantErr { - assert.Error(err) - return - } - assert.NoError(err) - assert.NotNil(appCredentials) - }) - } -} - -func TestCreateAppServicePrincipal(t *testing.T) { - someErr := errors.New("failed") - testCases := map[string]struct { - servicePrincipalsAPI servicePrincipalsAPI - wantErr bool - }{ - "successful create": { - servicePrincipalsAPI: stubServicePrincipalsAPI{}, - }, - "failed service principal create": { - servicePrincipalsAPI: stubServicePrincipalsAPI{ - createErr: someErr, - }, - wantErr: true, - }, - "service principal create returns invalid objectid": { - servicePrincipalsAPI: stubServicePrincipalsAPI{ - createServicePrincipal: &graphrbac.ServicePrincipal{}, - }, - wantErr: true, - }, - } - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - - ctx := context.Background() - - client := Client{ - name: "name", - uid: "uid", - servicePrincipalsAPI: tc.servicePrincipalsAPI, - } - - _, err := client.createAppServicePrincipal(ctx, "app-id") - if tc.wantErr { - assert.Error(err) - return - } - assert.NoError(err) - }) - } -} - -func TestAssignOwnerOfResourceGroup(t *testing.T) { - someErr := errors.New("failed") - testCases := map[string]struct { - roleAssignmentsAPI roleAssignmentsAPI - resourceGroupAPI resourceGroupAPI - wantErr bool - }{ - "successful assign": { - roleAssignmentsAPI: &stubRoleAssignmentsAPI{}, - resourceGroupAPI: stubResourceGroupAPI{ - getResourceGroup: armresources.ResourceGroup{ - ID: to.Ptr("resource-group-id"), - }, - }, - }, - "failed role assignment": { - roleAssignmentsAPI: &stubRoleAssignmentsAPI{ - createErrors: []error{someErr}, - }, - resourceGroupAPI: stubResourceGroupAPI{ - getResourceGroup: armresources.ResourceGroup{ - ID: to.Ptr("resource-group-id"), - }, - }, - wantErr: true, - }, - "failed resource group get": { - roleAssignmentsAPI: &stubRoleAssignmentsAPI{}, - resourceGroupAPI: stubResourceGroupAPI{ - getErr: someErr, - }, - wantErr: true, - }, - "resource group get returns invalid id": { - roleAssignmentsAPI: &stubRoleAssignmentsAPI{}, - resourceGroupAPI: stubResourceGroupAPI{ - getResourceGroup: armresources.ResourceGroup{}, - }, - wantErr: true, - }, - "create returns PrincipalNotFound the first time": { - roleAssignmentsAPI: &stubRoleAssignmentsAPI{ - createErrors: []error{ - autorest.DetailedError{Original: &azure.RequestError{ - ServiceError: &azure.ServiceError{ - Code: "PrincipalNotFound", - }, - }}, - nil, - }, - }, - resourceGroupAPI: stubResourceGroupAPI{ - getResourceGroup: armresources.ResourceGroup{ - ID: to.Ptr("resource-group-id"), - }, - }, - }, - "create does not return request error": { - roleAssignmentsAPI: &stubRoleAssignmentsAPI{ - createErrors: []error{autorest.DetailedError{Original: someErr}}, - }, - resourceGroupAPI: stubResourceGroupAPI{ - getResourceGroup: armresources.ResourceGroup{ - ID: to.Ptr("resource-group-id"), - }, - }, - wantErr: true, - }, - "create service error code is unknown": { - roleAssignmentsAPI: &stubRoleAssignmentsAPI{ - createErrors: []error{ - autorest.DetailedError{Original: &azure.RequestError{ - ServiceError: &azure.ServiceError{ - Code: "some-unknown-error-code", - }, - }}, - nil, - }, - }, - resourceGroupAPI: stubResourceGroupAPI{ - getResourceGroup: armresources.ResourceGroup{ - ID: to.Ptr("resource-group-id"), - }, - }, - wantErr: true, - }, - } - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - - ctx := context.Background() - - client := Client{ - name: "name", - uid: "uid", - resourceGroup: "resource-group", - roleAssignmentsAPI: tc.roleAssignmentsAPI, - resourceGroupAPI: tc.resourceGroupAPI, - adReplicationLagCheckMaxRetries: 2, - } - - err := client.assignResourceGroupRole(ctx, "principal-id", "role-definition-id") - if tc.wantErr { - assert.Error(err) - return - } - assert.NoError(err) - }) - } -} diff --git a/cli/internal/azure/client/api.go b/cli/internal/azure/client/api.go index 998464d71..a0f109deb 100644 --- a/cli/internal/azure/client/api.go +++ b/cli/internal/azure/client/api.go @@ -64,15 +64,13 @@ type networkInterfacesAPI interface { ) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error) } -type resourceGroupAPI interface { - CreateOrUpdate(ctx context.Context, resourceGroupName string, - parameters armresources.ResourceGroup, - options *armresources.ResourceGroupsClientCreateOrUpdateOptions) ( - armresources.ResourceGroupsClientCreateOrUpdateResponse, error) - BeginDelete(ctx context.Context, resourceGroupName string, - options *armresources.ResourceGroupsClientBeginDeleteOptions) ( - *runtime.Poller[armresources.ResourceGroupsClientDeleteResponse], error) - Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error) +type resourceAPI interface { + NewListByResourceGroupPager(resourceGroupName string, + options *armresources.ClientListByResourceGroupOptions, + ) *runtime.Pager[armresources.ClientListByResourceGroupResponse] + BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string, + options *armresources.ClientBeginDeleteByIDOptions, + ) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) } type applicationsAPI interface { diff --git a/cli/internal/azure/client/api_test.go b/cli/internal/azure/client/api_test.go index 7d439f6a5..d896cdcf3 100644 --- a/cli/internal/azure/client/api_test.go +++ b/cli/internal/azure/client/api_test.go @@ -4,15 +4,11 @@ import ( "context" "net/http" - "github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights" armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" - "github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac" - "github.com/Azure/go-autorest/autorest" ) type stubNetworksAPI struct { @@ -94,44 +90,6 @@ func (a stubNetworkSecurityGroupsAPI) BeginCreateOrUpdate(ctx context.Context, r return poller, a.createErr } -type stubResourceGroupAPI struct { - terminateErr error - createErr error - getErr error - getResourceGroup armresources.ResourceGroup - pollErr error -} - -func (a stubResourceGroupAPI) CreateOrUpdate(ctx context.Context, resourceGroupName string, - parameters armresources.ResourceGroup, - options *armresources.ResourceGroupsClientCreateOrUpdateOptions) ( - armresources.ResourceGroupsClientCreateOrUpdateResponse, error, -) { - return armresources.ResourceGroupsClientCreateOrUpdateResponse{}, a.createErr -} - -func (a stubResourceGroupAPI) Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error) { - return armresources.ResourceGroupsClientGetResponse{ - ResourceGroup: a.getResourceGroup, - }, a.getErr -} - -func (a stubResourceGroupAPI) BeginDelete(ctx context.Context, resourceGroupName string, - options *armresources.ResourceGroupsClientBeginDeleteOptions) ( - *runtime.Poller[armresources.ResourceGroupsClientDeleteResponse], error, -) { - poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ResourceGroupsClientDeleteResponse]{ - Handler: &stubPoller[armresources.ResourceGroupsClientDeleteResponse]{ - result: armresources.ResourceGroupsClientDeleteResponse{}, - resultErr: a.pollErr, - }, - }) - if err != nil { - panic(err) - } - return poller, a.terminateErr -} - type stubScaleSetsAPI struct { createErr error stubResponse armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse @@ -282,71 +240,6 @@ func (a stubNetworkInterfacesAPI) GetVirtualMachineScaleSetNetworkInterface(ctx }, nil } -type stubApplicationsAPI struct { - createErr error - deleteErr error - updateCredentialsErr error - createApplication *graphrbac.Application -} - -func (a stubApplicationsAPI) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) { - if a.createErr != nil { - return graphrbac.Application{}, a.createErr - } - if a.createApplication != nil { - return *a.createApplication, nil - } - return graphrbac.Application{ - AppID: to.Ptr("00000000-0000-0000-0000-000000000000"), - ObjectID: to.Ptr("00000000-0000-0000-0000-000000000001"), - }, nil -} - -func (a stubApplicationsAPI) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) { - if a.deleteErr != nil { - return autorest.Response{}, a.deleteErr - } - return autorest.Response{}, nil -} - -func (a stubApplicationsAPI) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) { - if a.updateCredentialsErr != nil { - return autorest.Response{}, a.updateCredentialsErr - } - return autorest.Response{}, nil -} - -type stubServicePrincipalsAPI struct { - createErr error - createServicePrincipal *graphrbac.ServicePrincipal -} - -func (a stubServicePrincipalsAPI) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) { - if a.createErr != nil { - return graphrbac.ServicePrincipal{}, a.createErr - } - if a.createServicePrincipal != nil { - return *a.createServicePrincipal, nil - } - return graphrbac.ServicePrincipal{ - AppID: to.Ptr("00000000-0000-0000-0000-000000000000"), - ObjectID: to.Ptr("00000000-0000-0000-0000-000000000002"), - }, nil -} - -type stubRoleAssignmentsAPI struct { - createCounter int - createErrors []error -} - -func (a *stubRoleAssignmentsAPI) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) { - a.createCounter++ - if len(a.createErrors) == 0 { - return authorization.RoleAssignment{}, nil - } - return authorization.RoleAssignment{}, a.createErrors[(a.createCounter-1)%len(a.createErrors)] -} - type stubApplicationInsightsAPI struct { err error } diff --git a/cli/internal/azure/client/client.go b/cli/internal/azure/client/client.go index e92f4e8a1..b7d4e0d0b 100644 --- a/cli/internal/azure/client/client.go +++ b/cli/internal/azure/client/client.go @@ -29,7 +29,7 @@ const ( type Client struct { networksAPI networkSecurityGroupsAPI - resourceGroupAPI + resourceAPI scaleSetsAPI publicIPAddressesAPI networkInterfacesAPI @@ -39,9 +39,7 @@ type Client struct { roleAssignmentsAPI applicationInsightsAPI - pollFrequency time.Duration - adReplicationLagCheckInterval time.Duration - adReplicationLagCheckMaxRetries int + pollFrequency time.Duration workers cloudtypes.Instances controlPlanes cloudtypes.Instances @@ -83,10 +81,6 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) { if err != nil { return nil, err } - resGroupAPI, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) - if err != nil { - return nil, err - } scaleSetAPI, err := armcomputev2.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil) if err != nil { return nil, err @@ -107,6 +101,10 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) { if err != nil { return nil, err } + resourceAPI, err := armresources.NewClient(subscriptionID, cred, nil) + if err != nil { + return nil, err + } applicationsAPI := graphrbac.NewApplicationsClient(tenantID) applicationsAPI.Authorizer = graphAuthorizer servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID) @@ -115,42 +113,41 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) { roleAssignmentsAPI.Authorizer = managementAuthorizer return &Client{ - networksAPI: netAPI, - networkSecurityGroupsAPI: netSecGrpAPI, - resourceGroupAPI: resGroupAPI, - scaleSetsAPI: scaleSetAPI, - publicIPAddressesAPI: publicIPAddressesAPI, - networkInterfacesAPI: networkInterfacesAPI, - loadBalancersAPI: loadBalancersAPI, - applicationsAPI: applicationsAPI, - servicePrincipalsAPI: servicePrincipalsAPI, - roleAssignmentsAPI: roleAssignmentsAPI, - applicationInsightsAPI: applicationInsightsAPI, - subscriptionID: subscriptionID, - tenantID: tenantID, - workers: cloudtypes.Instances{}, - controlPlanes: cloudtypes.Instances{}, - pollFrequency: time.Second * 5, - adReplicationLagCheckInterval: adReplicationLagCheckInterval, - adReplicationLagCheckMaxRetries: adReplicationLagCheckMaxRetries, + networksAPI: netAPI, + networkSecurityGroupsAPI: netSecGrpAPI, + resourceAPI: resourceAPI, + scaleSetsAPI: scaleSetAPI, + publicIPAddressesAPI: publicIPAddressesAPI, + networkInterfacesAPI: networkInterfacesAPI, + loadBalancersAPI: loadBalancersAPI, + applicationsAPI: applicationsAPI, + servicePrincipalsAPI: servicePrincipalsAPI, + roleAssignmentsAPI: roleAssignmentsAPI, + applicationInsightsAPI: applicationInsightsAPI, + subscriptionID: subscriptionID, + tenantID: tenantID, + workers: cloudtypes.Instances{}, + controlPlanes: cloudtypes.Instances{}, + pollFrequency: time.Second * 5, }, nil } // NewInitialized creates and initializes client by setting the subscriptionID, location and name // of the Constellation. -func NewInitialized(subscriptionID, tenantID, name, location string) (*Client, error) { +func NewInitialized(subscriptionID, tenantID, name, location, resourceGroup string) (*Client, error) { client, err := NewFromDefault(subscriptionID, tenantID) if err != nil { return nil, err } - err = client.init(location, name) + err = client.init(location, name, resourceGroup) return client, err } // init initializes the client. -func (c *Client) init(location, name string) error { +func (c *Client) init(location, name, resourceGroup string) error { c.location = location c.name = name + c.resourceGroup = resourceGroup uid, err := c.generateUID() if err != nil { return err diff --git a/cli/internal/azure/client/client_test.go b/cli/internal/azure/client/client_test.go index 50ac80bee..a23c11e94 100644 --- a/cli/internal/azure/client/client_test.go +++ b/cli/internal/azure/client/client_test.go @@ -84,8 +84,9 @@ func TestInit(t *testing.T) { require := require.New(t) client := Client{} - require.NoError(client.init("location", "name")) + require.NoError(client.init("location", "name", "rGroup")) assert.Equal("location", client.location) assert.Equal("name", client.name) + assert.Equal("rGroup", client.resourceGroup) assert.NotEmpty(client.uid) } diff --git a/cli/internal/azure/client/compute.go b/cli/internal/azure/client/compute.go index dc9815b64..1226fdd0d 100644 --- a/cli/internal/azure/client/compute.go +++ b/cli/internal/azure/client/compute.go @@ -10,8 +10,6 @@ import ( "time" "github.com/Azure/azure-sdk-for-go/sdk/azcore" - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" "github.com/edgelesssys/constellation/cli/internal/azure" "github.com/edgelesssys/constellation/cli/internal/azure/internal/poller" "github.com/edgelesssys/constellation/internal/cloud/cloudtypes" @@ -213,45 +211,6 @@ type CreateScaleSetInput struct { ConfidentialVM bool } -// CreateResourceGroup creates a resource group. -func (c *Client) CreateResourceGroup(ctx context.Context) error { - _, err := c.resourceGroupAPI.CreateOrUpdate(ctx, c.name+"-"+c.uid, - armresources.ResourceGroup{ - Location: &c.location, - }, nil) - if err != nil { - return err - } - c.resourceGroup = c.name + "-" + c.uid - return nil -} - -// TerminateResourceGroup terminates a resource group. -func (c *Client) TerminateResourceGroup(ctx context.Context) error { - if c.resourceGroup == "" { - return nil - } - - poller, err := c.resourceGroupAPI.BeginDelete(ctx, c.resourceGroup, nil) - if err != nil { - return err - } - - if _, err = poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{ - Frequency: c.pollFrequency, - }); err != nil { - return err - } - c.workers = nil - c.controlPlanes = nil - c.resourceGroup = "" - c.subnetID = "" - c.networkSecurityGroup = "" - c.workerScaleSet = "" - c.controlPlaneScaleSet = "" - return nil -} - // scaleSetCreationPollingHandler is a custom poller used to check if a scale set was created successfully. type scaleSetCreationPollingHandler struct { done bool diff --git a/cli/internal/azure/client/compute_test.go b/cli/internal/azure/client/compute_test.go index e1d2c62c2..5bc1fec63 100644 --- a/cli/internal/azure/client/compute_test.go +++ b/cli/internal/azure/client/compute_test.go @@ -7,126 +7,16 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2" - "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" "github.com/edgelesssys/constellation/internal/cloud/cloudtypes" "github.com/stretchr/testify/assert" ) -func TestCreateResourceGroup(t *testing.T) { - someErr := errors.New("failed") - testCases := map[string]struct { - resourceGroupAPI resourceGroupAPI - wantErr bool - }{ - "successful create": { - resourceGroupAPI: stubResourceGroupAPI{}, - }, - "failed create": { - resourceGroupAPI: stubResourceGroupAPI{createErr: someErr}, - wantErr: true, - }, - } - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - - ctx := context.Background() - - client := Client{ - location: "location", - name: "name", - uid: "uid", - resourceGroupAPI: tc.resourceGroupAPI, - workers: make(cloudtypes.Instances), - controlPlanes: make(cloudtypes.Instances), - } - - if tc.wantErr { - assert.Error(client.CreateResourceGroup(ctx)) - } else { - assert.NoError(client.CreateResourceGroup(ctx)) - assert.Equal(client.name+"-"+client.uid, client.resourceGroup) - } - }) - } -} - -func TestTerminateResourceGroup(t *testing.T) { - someErr := errors.New("failed") - clientWithResourceGroup := Client{ - resourceGroup: "name", - location: "location", - name: "name", - uid: "uid", - subnetID: "subnet", - workerScaleSet: "node-scale-set", - controlPlaneScaleSet: "controlplane-scale-set", - workers: cloudtypes.Instances{ - "0": { - PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1", - }, - }, - controlPlanes: cloudtypes.Instances{ - "0": { - PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1", - }, - }, - } - testCases := map[string]struct { - resourceGroup string - resourceGroupAPI resourceGroupAPI - client Client - wantErr bool - }{ - "successful terminate": { - resourceGroupAPI: stubResourceGroupAPI{}, - client: clientWithResourceGroup, - }, - "no resource group to terminate": { - resourceGroupAPI: stubResourceGroupAPI{}, - client: Client{}, - resourceGroup: "", - }, - "failed terminate": { - resourceGroupAPI: stubResourceGroupAPI{terminateErr: someErr}, - client: clientWithResourceGroup, - wantErr: true, - }, - "failed to poll terminate response": { - resourceGroupAPI: stubResourceGroupAPI{pollErr: someErr}, - client: clientWithResourceGroup, - wantErr: true, - }, - } - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - tc.client.resourceGroupAPI = tc.resourceGroupAPI - ctx := context.Background() - - if tc.wantErr { - assert.Error(tc.client.TerminateResourceGroup(ctx)) - return - } - assert.NoError(tc.client.TerminateResourceGroup(ctx)) - assert.Empty(tc.client.resourceGroup) - assert.Empty(tc.client.subnetID) - assert.Empty(tc.client.workers) - assert.Empty(tc.client.controlPlanes) - assert.Empty(tc.client.workerScaleSet) - assert.Empty(tc.client.controlPlaneScaleSet) - }) - } -} - func TestCreateInstances(t *testing.T) { someErr := errors.New("failed") testCases := map[string]struct { publicIPAddressesAPI publicIPAddressesAPI networkInterfacesAPI networkInterfacesAPI scaleSetsAPI scaleSetsAPI - resourceGroupAPI resourceGroupAPI - roleAssignmentsAPI roleAssignmentsAPI createInstancesInput CreateInstancesInput wantErr bool }{ @@ -138,8 +28,6 @@ func TestCreateInstances(t *testing.T) { VirtualMachineScaleSet: armcomputev2.VirtualMachineScaleSet{Identity: &armcomputev2.VirtualMachineScaleSetIdentity{PrincipalID: to.Ptr("principal-id")}}, }, }, - resourceGroupAPI: newSuccessfulResourceGroupStub(), - roleAssignmentsAPI: &stubRoleAssignmentsAPI{}, createInstancesInput: CreateInstancesInput{ CountControlPlanes: 3, CountWorkers: 3, @@ -153,8 +41,6 @@ func TestCreateInstances(t *testing.T) { publicIPAddressesAPI: stubPublicIPAddressesAPI{}, networkInterfacesAPI: stubNetworkInterfacesAPI{}, scaleSetsAPI: stubScaleSetsAPI{createErr: someErr}, - resourceGroupAPI: newSuccessfulResourceGroupStub(), - roleAssignmentsAPI: &stubRoleAssignmentsAPI{}, createInstancesInput: CreateInstancesInput{ CountControlPlanes: 3, CountWorkers: 3, @@ -169,8 +55,6 @@ func TestCreateInstances(t *testing.T) { publicIPAddressesAPI: stubPublicIPAddressesAPI{}, networkInterfacesAPI: stubNetworkInterfacesAPI{}, scaleSetsAPI: stubScaleSetsAPI{getErr: someErr}, - resourceGroupAPI: newSuccessfulResourceGroupStub(), - roleAssignmentsAPI: &stubRoleAssignmentsAPI{}, createInstancesInput: CreateInstancesInput{ CountControlPlanes: 3, CountWorkers: 3, @@ -185,8 +69,6 @@ func TestCreateInstances(t *testing.T) { publicIPAddressesAPI: stubPublicIPAddressesAPI{}, networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr}, scaleSetsAPI: stubScaleSetsAPI{}, - resourceGroupAPI: newSuccessfulResourceGroupStub(), - roleAssignmentsAPI: &stubRoleAssignmentsAPI{}, createInstancesInput: CreateInstancesInput{ CountWorkers: 3, InstanceType: "type", @@ -211,8 +93,6 @@ func TestCreateInstances(t *testing.T) { publicIPAddressesAPI: tc.publicIPAddressesAPI, networkInterfacesAPI: tc.networkInterfacesAPI, scaleSetsAPI: tc.scaleSetsAPI, - resourceGroupAPI: tc.resourceGroupAPI, - roleAssignmentsAPI: tc.roleAssignmentsAPI, workers: make(cloudtypes.Instances), controlPlanes: make(cloudtypes.Instances), loadBalancerPubIP: "lbip", @@ -232,11 +112,3 @@ func TestCreateInstances(t *testing.T) { }) } } - -func newSuccessfulResourceGroupStub() *stubResourceGroupAPI { - return &stubResourceGroupAPI{ - getResourceGroup: armresources.ResourceGroup{ - ID: to.Ptr("resource-group-id"), - }, - } -} diff --git a/cli/internal/azure/client/terminate.go b/cli/internal/azure/client/terminate.go new file mode 100644 index 000000000..2e364e293 --- /dev/null +++ b/cli/internal/azure/client/terminate.go @@ -0,0 +1,133 @@ +package client + +import ( + "context" + "fmt" + "regexp" + "strings" + "sync" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" +) + +// TerminateResourceGroupResources deletes all resources from the resource group. +func (c *Client) TerminateResourceGroupResources(ctx context.Context) error { + const timeOut = 10 * time.Minute + ctx, cancel := context.WithTimeout(ctx, timeOut) + defer cancel() + + pollers := make(chan *runtime.Poller[armresources.ClientDeleteByIDResponse], 20) + delete := make(chan struct{}, 1) + wg := &sync.WaitGroup{} + wg.Add(2) + + go func() { // This routine lists resources and starts their deletion, where possible. + defer wg.Done() + defer func() { + close(pollers) + for range delete { // drain channel + } + }() + + for { + ids, err := c.getResourceIDList(ctx) + if err != nil { + time.Sleep(3 * time.Second) + continue + } + + if len(ids) == 0 { + return + } + + for _, id := range ids { + poller, err := c.deleteResourceByID(ctx, id) + if err != nil { + continue + } + pollers <- poller + } + + select { + case <-ctx.Done(): + return + case _, ok := <-delete: + if !ok { // channel was closed + return + } + } + } + }() + + go func() { // This routine polls for for the deletions to complete. + defer wg.Done() + defer close(delete) + + for poller := range pollers { + _, err := poller.PollUntilDone(ctx, nil) + if err != nil { + continue + } + select { + case delete <- struct{}{}: + default: + } + } + }() + + wg.Wait() + + return nil +} + +func (c *Client) getResourceIDList(ctx context.Context) ([]string, error) { + var ids []string + pager := c.resourceAPI.NewListByResourceGroupPager(c.resourceGroup, nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, fmt.Errorf("getting next page of ListByResourceGroup: %w", err) + } + for _, resource := range page.Value { + if resource.ID == nil { + return nil, fmt.Errorf("resource %v has no ID", resource) + } + ids = append(ids, *resource.ID) + } + } + return ids, nil +} + +func (c *Client) deleteResourceByID(ctx context.Context, id string, +) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) { + apiVersion := "2020-02-02" + + // First try, API version unknown, will fail. + poller, err := c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil) + if isVersionWrongErr(err) { + // bad hack, but easiest way to get the right API version + apiVersion = parseAPIVersionFromErr(err) + poller, err = c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil) + } + return poller, err +} + +func isVersionWrongErr(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "NoRegisteredProviderFound") && + strings.Contains(err.Error(), "The supported api-versions are") +} + +var apiVersionRegex = regexp.MustCompile(` (\d\d\d\d-\d\d-\d\d)'`) + +func parseAPIVersionFromErr(err error) string { + if err == nil { + return "" + } + matches := apiVersionRegex.FindStringSubmatch(err.Error()) + return matches[1] +} diff --git a/cli/internal/azure/client/terminate_test.go b/cli/internal/azure/client/terminate_test.go new file mode 100644 index 000000000..683fadd4e --- /dev/null +++ b/cli/internal/azure/client/terminate_test.go @@ -0,0 +1,139 @@ +package client + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" +) + +func TestTerminateResourceGroupResources(t *testing.T) { + someErr := errors.New("failed") + apiVersionErr := errors.New("NoRegisteredProviderFound, The supported api-versions are: 2015-01-01'") + + testCases := map[string]struct { + resourceAPI resourceAPI + }{ + "no resources": { + resourceAPI: &fakeResourceAPI{}, + }, + "some resources": { + resourceAPI: &fakeResourceAPI{ + resources: map[string]fakeResource{ + "id-0": {beginDeleteByIDErr: apiVersionErr, pollErr: someErr}, + "id-1": {beginDeleteByIDErr: apiVersionErr}, + "id-2": {}, + }, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + client := &Client{ + resourceAPI: tc.resourceAPI, + } + + ctx := context.Background() + err := client.TerminateResourceGroupResources(ctx) + assert.NoError(err) + }) + } +} + +type fakeResourceAPI struct { + resources map[string]fakeResource + fetchErr error +} + +type fakeResource struct { + beginDeleteByIDErr error + pollErr error +} + +func (a fakeResourceAPI) NewListByResourceGroupPager(resourceGroupName string, + options *armresources.ClientListByResourceGroupOptions, +) *runtime.Pager[armresources.ClientListByResourceGroupResponse] { + pager := &stubClientListByResourceGroupResponsePager{ + resources: a.resources, + fetchErr: a.fetchErr, + } + return runtime.NewPager(runtime.PagingHandler[armresources.ClientListByResourceGroupResponse]{ + More: pager.moreFunc(), + Fetcher: pager.fetcherFunc(), + }) +} + +func (a fakeResourceAPI) BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string, + options *armresources.ClientBeginDeleteByIDOptions, +) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) { + res := a.resources[resourceID] + + pollErr := res.pollErr + if pollErr != nil { + res.pollErr = nil + } + + poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ClientDeleteByIDResponse]{ + Handler: &stubPoller[armresources.ClientDeleteByIDResponse]{ + result: armresources.ClientDeleteByIDResponse{}, + resultErr: pollErr, + }, + }) + if err != nil { + panic(err) + } + + beginDeleteByIDErr := res.beginDeleteByIDErr + if beginDeleteByIDErr != nil { + res.beginDeleteByIDErr = nil + } + + if res.beginDeleteByIDErr == nil && res.pollErr == nil { + delete(a.resources, resourceID) + fmt.Printf("fake delete %s\n", resourceID) + } else { + a.resources[resourceID] = res + } + + return poller, beginDeleteByIDErr +} + +type stubClientListByResourceGroupResponsePager struct { + resources map[string]fakeResource + fetchErr error + more bool +} + +func (p *stubClientListByResourceGroupResponsePager) moreFunc() func( + armresources.ClientListByResourceGroupResponse) bool { + return func(armresources.ClientListByResourceGroupResponse) bool { + return p.more + } +} + +func (p *stubClientListByResourceGroupResponsePager) fetcherFunc() func( + context.Context, *armresources.ClientListByResourceGroupResponse) ( + armresources.ClientListByResourceGroupResponse, error) { + return func(context.Context, *armresources.ClientListByResourceGroupResponse) ( + armresources.ClientListByResourceGroupResponse, error, + ) { + var resources []*armresources.GenericResourceExpanded + for id := range p.resources { + resources = append(resources, &armresources.GenericResourceExpanded{ID: proto.String(id)}) + } + p.more = false + return armresources.ClientListByResourceGroupResponse{ + ResourceListResult: armresources.ResourceListResult{ + Value: resources, + }, + }, p.fetchErr + } +} diff --git a/cli/internal/cloudcmd/clients.go b/cli/internal/cloudcmd/clients.go index d9c7fb04b..d5ce912db 100644 --- a/cli/internal/cloudcmd/clients.go +++ b/cli/internal/cloudcmd/clients.go @@ -26,12 +26,9 @@ type azureclient interface { GetState() state.ConstellationState SetState(state.ConstellationState) CreateApplicationInsight(ctx context.Context) error - CreateResourceGroup(ctx context.Context) error CreateExternalLoadBalancer(ctx context.Context) error CreateVirtualNetwork(ctx context.Context) error CreateSecurityGroup(ctx context.Context, input azurecl.NetworkSecurityGroupInput) error CreateInstances(ctx context.Context, input azurecl.CreateInstancesInput) error - CreateServicePrincipal(ctx context.Context) (string, error) - TerminateResourceGroup(ctx context.Context) error - TerminateServicePrincipal(ctx context.Context) error + TerminateResourceGroupResources(ctx context.Context) error } diff --git a/cli/internal/cloudcmd/clients_test.go b/cli/internal/cloudcmd/clients_test.go index 8abe2795c..fe8a7aff1 100644 --- a/cli/internal/cloudcmd/clients_test.go +++ b/cli/internal/cloudcmd/clients_test.go @@ -79,11 +79,6 @@ func (c *fakeAzureClient) CreateApplicationInsight(ctx context.Context) error { return nil } -func (c *fakeAzureClient) CreateResourceGroup(ctx context.Context) error { - c.resourceGroup = "resource-group" - return nil -} - func (c *fakeAzureClient) CreateVirtualNetwork(ctx context.Context) error { c.subnetID = "subnet" return nil @@ -123,17 +118,8 @@ func (c *fakeAzureClient) CreateServicePrincipal(ctx context.Context) (string, e }.ToCloudServiceAccountURI(), nil } -func (c *fakeAzureClient) TerminateResourceGroup(ctx context.Context) error { - if c.resourceGroup == "" { - return nil - } - c.workers = nil - c.controlPlanes = nil - c.resourceGroup = "" - c.subnetID = "" - c.networkSecurityGroup = "" - c.workerScaleSet = "" - c.controlPlaneScaleSet = "" +func (c *fakeAzureClient) TerminateResourceGroupResources(ctx context.Context) error { + // TODO(katexochen) return nil } @@ -146,18 +132,17 @@ func (c *fakeAzureClient) TerminateServicePrincipal(ctx context.Context) error { } type stubAzureClient struct { - terminateResourceGroupCalled bool - terminateServicePrincipalCalled bool + terminateResourceGroupResourcesCalled bool + terminateServicePrincipalCalled bool - createApplicationInsightErr error - createResourceGroupErr error - createVirtualNetworkErr error - createSecurityGroupErr error - createLoadBalancerErr error - createInstancesErr error - createServicePrincipalErr error - terminateResourceGroupErr error - terminateServicePrincipalErr error + createApplicationInsightErr error + createVirtualNetworkErr error + createSecurityGroupErr error + createLoadBalancerErr error + createInstancesErr error + createServicePrincipalErr error + terminateResourceGroupResourcesErr error + terminateServicePrincipalErr error } func (c *stubAzureClient) GetState() state.ConstellationState { @@ -175,10 +160,6 @@ func (c *stubAzureClient) CreateApplicationInsight(ctx context.Context) error { return c.createApplicationInsightErr } -func (c *stubAzureClient) CreateResourceGroup(ctx context.Context) error { - return c.createResourceGroupErr -} - func (c *stubAzureClient) CreateVirtualNetwork(ctx context.Context) error { return c.createVirtualNetworkErr } @@ -198,9 +179,9 @@ func (c *stubAzureClient) CreateServicePrincipal(ctx context.Context) (string, e }.ToCloudServiceAccountURI(), c.createServicePrincipalErr } -func (c *stubAzureClient) TerminateResourceGroup(ctx context.Context) error { - c.terminateResourceGroupCalled = true - return c.terminateResourceGroupErr +func (c *stubAzureClient) TerminateResourceGroupResources(ctx context.Context) error { + c.terminateResourceGroupResourcesCalled = true + return c.terminateResourceGroupResourcesErr } func (c *stubAzureClient) TerminateServicePrincipal(ctx context.Context) error { diff --git a/cli/internal/cloudcmd/create.go b/cli/internal/cloudcmd/create.go index f978884b3..c953e90c7 100644 --- a/cli/internal/cloudcmd/create.go +++ b/cli/internal/cloudcmd/create.go @@ -18,7 +18,7 @@ import ( type Creator struct { out io.Writer newGCPClient func(ctx context.Context, project, zone, region, name string) (gcpclient, error) - newAzureClient func(subscriptionID, tenantID, name, location string) (azureclient, error) + newAzureClient func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error) } // NewCreator creates a new creator. @@ -28,8 +28,8 @@ func NewCreator(out io.Writer) *Creator { newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) { return gcpcl.NewInitialized(ctx, project, zone, region, name) }, - newAzureClient: func(subscriptionID, tenantID, name, location string) (azureclient, error) { - return azurecl.NewInitialized(subscriptionID, tenantID, name, location) + newAzureClient: func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error) { + return azurecl.NewInitialized(subscriptionID, tenantID, name, location, resourceGroup) }, } } @@ -57,6 +57,7 @@ func (c *Creator) Create(ctx context.Context, provider cloudprovider.Provider, c config.Provider.Azure.TenantID, name, config.Provider.Azure.Location, + config.Provider.Azure.ResourceGroup, ) if err != nil { return state.ConstellationState{}, err @@ -144,9 +145,6 @@ func (c *Creator) createAzure(ctx context.Context, cl azureclient, config *confi ) (stat state.ConstellationState, retErr error) { defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerAzure{client: cl}) - if err := cl.CreateResourceGroup(ctx); err != nil { - return state.ConstellationState{}, err - } if err := cl.CreateApplicationInsight(ctx); err != nil { return state.ConstellationState{}, err } diff --git a/cli/internal/cloudcmd/create_test.go b/cli/internal/cloudcmd/create_test.go index 74bace924..84a91578b 100644 --- a/cli/internal/cloudcmd/create_test.go +++ b/cli/internal/cloudcmd/create_test.go @@ -51,7 +51,6 @@ func TestCreator(t *testing.T) { "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, }, - AzureResourceGroup: "resource-group", AzureSubnet: "subnet", AzureNetworkSecurityGroup: "network-security-group", AzureWorkerScaleSet: "workers-scale-set", @@ -123,13 +122,6 @@ func TestCreator(t *testing.T) { config: config.Default(), wantErr: true, }, - "azure CreateResourceGroup error": { - azureclient: &stubAzureClient{createResourceGroupErr: someErr}, - provider: cloudprovider.Azure, - config: config.Default(), - wantErr: true, - wantRollback: true, - }, "azure CreateVirtualNetwork error": { azureclient: &stubAzureClient{createVirtualNetworkErr: someErr}, provider: cloudprovider.Azure, @@ -167,7 +159,7 @@ func TestCreator(t *testing.T) { newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) { return tc.gcpclient, tc.newGCPClientErr }, - newAzureClient: func(subscriptionID, tenantID, name, location string) (azureclient, error) { + newAzureClient: func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error) { return tc.azureclient, tc.newAzureClientErr }, } @@ -186,7 +178,7 @@ func TestCreator(t *testing.T) { assert.True(cl.closeCalled) case cloudprovider.Azure: cl := tc.azureclient.(*stubAzureClient) - assert.True(cl.terminateResourceGroupCalled) + assert.True(cl.terminateResourceGroupResourcesCalled) } } } else { diff --git a/cli/internal/cloudcmd/rollback.go b/cli/internal/cloudcmd/rollback.go index 100c1b204..043f783f2 100644 --- a/cli/internal/cloudcmd/rollback.go +++ b/cli/internal/cloudcmd/rollback.go @@ -46,5 +46,5 @@ type rollbackerAzure struct { } func (r *rollbackerAzure) rollback(ctx context.Context) error { - return r.client.TerminateResourceGroup(ctx) + return r.client.TerminateResourceGroupResources(ctx) } diff --git a/cli/internal/cloudcmd/serviceaccount.go b/cli/internal/cloudcmd/serviceaccount.go deleted file mode 100644 index 727c3c46e..000000000 --- a/cli/internal/cloudcmd/serviceaccount.go +++ /dev/null @@ -1,68 +0,0 @@ -package cloudcmd - -import ( - "context" - "fmt" - - azurecl "github.com/edgelesssys/constellation/cli/internal/azure/client" - gcpcl "github.com/edgelesssys/constellation/cli/internal/gcp/client" - "github.com/edgelesssys/constellation/internal/cloud/cloudprovider" - "github.com/edgelesssys/constellation/internal/config" - "github.com/edgelesssys/constellation/internal/state" -) - -// ServiceAccountCreator creates service accounts. -type ServiceAccountCreator struct { - newGCPClient func(ctx context.Context) (gcpclient, error) - newAzureClient func(subscriptionID, tenantID string) (azureclient, error) -} - -func NewServiceAccountCreator() *ServiceAccountCreator { - return &ServiceAccountCreator{ - newGCPClient: func(ctx context.Context) (gcpclient, error) { - return gcpcl.NewFromDefault(ctx) - }, - newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) { - return azurecl.NewFromDefault(subscriptionID, tenantID) - }, - } -} - -// Create creates a new cloud provider service account with access to the created resources. -func (c *ServiceAccountCreator) Create(ctx context.Context, stat state.ConstellationState, config *config.Config, -) (string, state.ConstellationState, error) { - provider := cloudprovider.FromString(stat.CloudProvider) - switch provider { - case cloudprovider.GCP: - return "", state.ConstellationState{}, fmt.Errorf("creating service account not supported for GCP") - case cloudprovider.Azure: - cl, err := c.newAzureClient(stat.AzureSubscription, stat.AzureTenant) - if err != nil { - return "", state.ConstellationState{}, err - } - - serviceAccount, stat, err := c.createServiceAccountAzure(ctx, cl, stat, config) - if err != nil { - return "", state.ConstellationState{}, err - } - - return serviceAccount, stat, err - case cloudprovider.QEMU: - return "unsupported://qemu", stat, nil - default: - return "", state.ConstellationState{}, fmt.Errorf("unsupported provider: %s", provider) - } -} - -func (c *ServiceAccountCreator) createServiceAccountAzure(ctx context.Context, cl azureclient, - stat state.ConstellationState, _ *config.Config, -) (string, state.ConstellationState, error) { - cl.SetState(stat) - - serviceAccount, err := cl.CreateServicePrincipal(ctx) - if err != nil { - return "", state.ConstellationState{}, fmt.Errorf("creating service account: %w", err) - } - - return serviceAccount, cl.GetState(), nil -} diff --git a/cli/internal/cloudcmd/serviceaccount_test.go b/cli/internal/cloudcmd/serviceaccount_test.go deleted file mode 100644 index 946000374..000000000 --- a/cli/internal/cloudcmd/serviceaccount_test.go +++ /dev/null @@ -1,89 +0,0 @@ -package cloudcmd - -import ( - "context" - "errors" - "testing" - - "github.com/edgelesssys/constellation/internal/cloud/cloudprovider" - "github.com/edgelesssys/constellation/internal/config" - "github.com/edgelesssys/constellation/internal/state" - "github.com/stretchr/testify/assert" -) - -func TestServiceAccountCreator(t *testing.T) { - someAzureState := func() state.ConstellationState { - return state.ConstellationState{ - CloudProvider: cloudprovider.Azure.String(), - } - } - someErr := errors.New("failed") - - testCases := map[string]struct { - newGCPClient func(ctx context.Context) (gcpclient, error) - newAzureClient func(subscriptionID, tenantID string) (azureclient, error) - state state.ConstellationState - config *config.Config - wantErr bool - wantStateMutator func(*state.ConstellationState) - }{ - "azure": { - newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) { - return &fakeAzureClient{}, nil - }, - state: someAzureState(), - config: config.Default(), - wantStateMutator: func(stat *state.ConstellationState) { - stat.AzureADAppObjectID = "00000000-0000-0000-0000-000000000001" - }, - }, - "azure newAzureClient error": { - newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) { - return nil, someErr - }, - state: someAzureState(), - config: config.Default(), - wantErr: true, - }, - "azure client createServiceAccount error": { - newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) { - return &stubAzureClient{createServicePrincipalErr: someErr}, nil - }, - state: someAzureState(), - config: config.Default(), - wantErr: true, - }, - "qemu": { - state: state.ConstellationState{CloudProvider: "qemu"}, - wantStateMutator: func(cs *state.ConstellationState) {}, - config: config.Default(), - }, - "unknown cloud provider": { - state: state.ConstellationState{}, - config: config.Default(), - wantErr: true, - }, - } - - for name, tc := range testCases { - t.Run(name, func(t *testing.T) { - assert := assert.New(t) - - creator := &ServiceAccountCreator{ - newGCPClient: tc.newGCPClient, - newAzureClient: tc.newAzureClient, - } - - serviceAccount, state, err := creator.Create(context.Background(), tc.state, tc.config) - - if tc.wantErr { - assert.Error(err) - } else { - assert.NoError(err) - assert.NotEmpty(serviceAccount) - tc.wantStateMutator(&tc.state) - assert.Equal(tc.state, state) - } - }) - } -} diff --git a/cli/internal/cloudcmd/terminate.go b/cli/internal/cloudcmd/terminate.go index 0b7a58f86..07ddaf5a0 100644 --- a/cli/internal/cloudcmd/terminate.go +++ b/cli/internal/cloudcmd/terminate.go @@ -72,8 +72,5 @@ func (t *Terminator) terminateGCP(ctx context.Context, cl gcpclient, state state func (t *Terminator) terminateAzure(ctx context.Context, cl azureclient, state state.ConstellationState) error { cl.SetState(state) - if err := cl.TerminateServicePrincipal(ctx); err != nil { - return err - } - return cl.TerminateResourceGroup(ctx) + return cl.TerminateResourceGroupResources(ctx) } diff --git a/cli/internal/cloudcmd/terminate_test.go b/cli/internal/cloudcmd/terminate_test.go index 3366396bc..83462ae85 100644 --- a/cli/internal/cloudcmd/terminate_test.go +++ b/cli/internal/cloudcmd/terminate_test.go @@ -41,7 +41,6 @@ func TestTerminator(t *testing.T) { AzureControlPlaneInstances: cloudtypes.Instances{ "id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, }, - AzureResourceGroup: "group", AzureADAppObjectID: "00000000-0000-0000-0000-000000000001", } } @@ -88,13 +87,8 @@ func TestTerminator(t *testing.T) { state: someAzureState(), wantErr: true, }, - "azure terminateServicePrincipal error": { - azureclient: &stubAzureClient{terminateServicePrincipalErr: someErr}, - state: someAzureState(), - wantErr: true, - }, - "azure terminateResourceGroup error": { - azureclient: &stubAzureClient{terminateResourceGroupErr: someErr}, + "azure terminateResourceGroupResources error": { + azureclient: &stubAzureClient{terminateResourceGroupResourcesErr: someErr}, state: someAzureState(), wantErr: true, }, @@ -132,8 +126,7 @@ func TestTerminator(t *testing.T) { assert.True(cl.closeCalled) case cloudprovider.Azure: cl := tc.azureclient.(*stubAzureClient) - assert.True(cl.terminateResourceGroupCalled) - assert.True(cl.terminateServicePrincipalCalled) + assert.True(cl.terminateResourceGroupResourcesCalled) } } }) diff --git a/cli/internal/cmd/cloud.go b/cli/internal/cmd/cloud.go index 86cd16efc..f8a8895d8 100644 --- a/cli/internal/cmd/cloud.go +++ b/cli/internal/cmd/cloud.go @@ -21,8 +21,3 @@ type cloudCreator interface { type cloudTerminator interface { Terminate(context.Context, state.ConstellationState) error } - -type serviceAccountCreator interface { - Create(ctx context.Context, stat state.ConstellationState, config *config.Config, - ) (string, state.ConstellationState, error) -} diff --git a/cli/internal/cmd/cloud_test.go b/cli/internal/cmd/cloud_test.go index 958f2a5c4..4a2f2bdae 100644 --- a/cli/internal/cmd/cloud_test.go +++ b/cli/internal/cmd/cloud_test.go @@ -47,12 +47,3 @@ func (c *stubCloudTerminator) Terminate(context.Context, state.ConstellationStat func (c *stubCloudTerminator) Called() bool { return c.called } - -type stubServiceAccountCreator struct { - cloudServiceAccountURI string - createErr error -} - -func (c *stubServiceAccountCreator) Create(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) { - return c.cloudServiceAccountURI, stat, c.createErr -} diff --git a/cli/internal/cmd/init.go b/cli/internal/cmd/init.go index 9de0fe2e1..7b160c1f3 100644 --- a/cli/internal/cmd/init.go +++ b/cli/internal/cmd/init.go @@ -55,17 +55,16 @@ func NewInitCmd() *cobra.Command { // runInitialize runs the initialize command. func runInitialize(cmd *cobra.Command, args []string) error { fileHandler := file.NewHandler(afero.NewOsFs()) - serviceAccountCreator := cloudcmd.NewServiceAccountCreator() newDialer := func(validator *cloudcmd.Validator) *dialer.Dialer { return dialer.New(nil, validator.V(cmd), &net.Dialer{}) } helmLoader := &helm.ChartLoader{} - return initialize(cmd, newDialer, serviceAccountCreator, fileHandler, helmLoader, license.NewClient()) + return initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient()) } // initialize initializes a Constellation. func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer, - serviceAccCreator serviceAccountCreator, fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker, + fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker, ) error { flags, err := evalFlagArgs(cmd, fileHandler) if err != nil { @@ -105,22 +104,9 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator return err } - var serviceAccURI string - // Temporary legacy flow for Azure. - if provider == cloudprovider.Azure { - cmd.Println("Creating service account ...") - serviceAccURI, stat, err = serviceAccCreator.Create(cmd.Context(), stat, config) - if err != nil { - return err - } - if err := fileHandler.WriteJSON(constants.StateFilename, stat, file.OptOverwrite); err != nil { - return err - } - } else { - serviceAccURI, err = getMarschaledServiceAccountURI(provider, config, fileHandler) - if err != nil { - return err - } + serviceAccURI, err := getMarschaledServiceAccountURI(provider, config, fileHandler) + if err != nil { + return err } workers, err := getScalingGroupsFromState(stat, config) diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index 41c66e3f6..9edb5d7f8 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -69,49 +69,54 @@ func TestInitialize(t *testing.T) { someErr := errors.New("failed") testCases := map[string]struct { - state *state.ConstellationState - existingIDFile *clusterIDsFile - serviceAccCreator serviceAccountCreator - configMutator func(*config.Config) - serviceAccKey *gcpshared.ServiceAccountKey - helmLoader stubHelmLoader - initServerAPI *stubInitServer - endpointFlag string - setAutoscaleFlag bool - wantErr bool + state *state.ConstellationState + idFile *clusterIDsFile + configMutator func(*config.Config) + serviceAccKey *gcpshared.ServiceAccountKey + helmLoader stubHelmLoader + initServerAPI *stubInitServer + endpointFlag string + setAutoscaleFlag bool + wantErr bool }{ "initialize some gcp instances": { - state: testGcpState, - existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, - configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, - serviceAccKey: gcpServiceAccKey, - initServerAPI: &stubInitServer{initResp: testInitResp}, + state: testGcpState, + idFile: &clusterIDsFile{IP: "192.0.2.1"}, + configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, + serviceAccKey: gcpServiceAccKey, + initServerAPI: &stubInitServer{initResp: testInitResp}, }, "initialize some azure instances": { - state: testAzureState, - serviceAccCreator: &stubServiceAccountCreator{}, - existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, - initServerAPI: &stubInitServer{initResp: testInitResp}, + state: testAzureState, + idFile: &clusterIDsFile{IP: "192.0.2.1"}, + configMutator: func(c *config.Config) { + c.Provider.Azure.ResourceGroup = "resourceGroup" + c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity" + }, + initServerAPI: &stubInitServer{initResp: testInitResp}, }, "initialize some qemu instances": { - state: testQemuState, - existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, - initServerAPI: &stubInitServer{initResp: testInitResp}, + state: testQemuState, + idFile: &clusterIDsFile{IP: "192.0.2.1"}, + initServerAPI: &stubInitServer{initResp: testInitResp}, }, "initialize gcp with autoscaling": { state: testGcpState, - existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, + idFile: &clusterIDsFile{IP: "192.0.2.1"}, configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath }, serviceAccKey: gcpServiceAccKey, initServerAPI: &stubInitServer{initResp: testInitResp}, setAutoscaleFlag: true, }, "initialize azure with autoscaling": { - state: testAzureState, - existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, - serviceAccCreator: &stubServiceAccountCreator{}, - initServerAPI: &stubInitServer{initResp: testInitResp}, - setAutoscaleFlag: true, + state: testAzureState, + idFile: &clusterIDsFile{IP: "192.0.2.1"}, + configMutator: func(c *config.Config) { + c.Provider.Azure.ResourceGroup = "resourceGroup" + c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity" + }, + initServerAPI: &stubInitServer{initResp: testInitResp}, + setAutoscaleFlag: true, }, "initialize with endpoint flag": { state: testGcpState, @@ -121,27 +126,30 @@ func TestInitialize(t *testing.T) { endpointFlag: "192.0.2.1", }, "empty state": { - state: &state.ConstellationState{}, - existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, - initServerAPI: &stubInitServer{}, - wantErr: true, + state: &state.ConstellationState{}, + idFile: &clusterIDsFile{IP: "192.0.2.1"}, + initServerAPI: &stubInitServer{}, + wantErr: true, }, "neither endpoint flag nor id file": { state: &state.ConstellationState{}, wantErr: true, }, "init call fails": { - state: testGcpState, - existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, - initServerAPI: &stubInitServer{initErr: someErr}, - wantErr: true, + state: testGcpState, + idFile: &clusterIDsFile{IP: "192.0.2.1"}, + initServerAPI: &stubInitServer{initErr: someErr}, + wantErr: true, }, "fail to create service account": { - state: testAzureState, - existingIDFile: &clusterIDsFile{IP: "192.0.2.1"}, - initServerAPI: &stubInitServer{}, - serviceAccCreator: &stubServiceAccountCreator{createErr: someErr}, - wantErr: true, + state: testAzureState, + idFile: &clusterIDsFile{IP: "192.0.2.1"}, + configMutator: func(c *config.Config) { + c.Provider.Azure.ResourceGroup = "resourceGroup" + c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity" + }, + initServerAPI: &stubInitServer{}, + wantErr: true, }, "fail to load helm charts": { state: testGcpState, @@ -194,8 +202,8 @@ func TestInitialize(t *testing.T) { if tc.state != nil { require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.state, file.OptNone)) } - if tc.existingIDFile != nil { - require.NoError(fileHandler.WriteJSON(constants.ClusterIDsFileName, tc.existingIDFile, file.OptNone)) + if tc.idFile != nil { + require.NoError(fileHandler.WriteJSON(constants.ClusterIDsFileName, tc.idFile, file.OptNone)) } if tc.serviceAccKey != nil { require.NoError(fileHandler.WriteJSON(serviceAccPath, tc.serviceAccKey, file.OptNone)) @@ -206,7 +214,7 @@ func TestInitialize(t *testing.T) { defer cancel() cmd.SetContext(ctx) - err := initialize(cmd, newDialer, tc.serviceAccCreator, fileHandler, &tc.helmLoader, &stubLicenseClient{}) + err := initialize(cmd, newDialer, fileHandler, &tc.helmLoader, &stubLicenseClient{}) if tc.wantErr { assert.Error(err) @@ -477,7 +485,7 @@ func TestAttestation(t *testing.T) { defer cancel() cmd.SetContext(ctx) - err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler, &stubHelmLoader{}, &stubLicenseClient{}) + err := initialize(cmd, newDialer, fileHandler, &stubHelmLoader{}, &stubLicenseClient{}) assert.Error(err) // make sure the error is actually a TLS handshake error assert.Contains(err.Error(), "transport: authentication handshake failed") @@ -548,6 +556,7 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs conf.Provider.Azure.Location = "test-location" conf.Provider.Azure.UserAssignedIdentity = "test-identity" conf.Provider.Azure.Image = "some/image/location" + conf.Provider.Azure.ResourceGroup = "test-resource-group" conf.Provider.Azure.Measurements[8] = []byte("00000000000000000000000000000000") conf.Provider.Azure.Measurements[9] = []byte("11111111111111111111111111111111") case cloudprovider.GCP: diff --git a/internal/config/config.go b/internal/config/config.go index 004879860..134f3a1f3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -158,6 +158,9 @@ type AzureConfig struct { // Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure UserAssignedIdentity string `yaml:"userAssignedIdentity" validate:"required"` // description: | + // Resource group to use. + ResourceGroup string `yaml:"resourceGroup" validate:"required"` + // description: | // Use VMs with security type Confidential VM. If set to false, Trusted Launch VMs will be used instead. See: https://docs.microsoft.com/en-us/azure/confidential-computing/confidential-vm-overview ConfidentialVM *bool `yaml:"confidentialVM" validate:"required"` } @@ -244,6 +247,7 @@ func Default() *Config { TenantID: "", Location: "", UserAssignedIdentity: "", + ResourceGroup: "", Image: DefaultImageAzure, StateDiskType: "Premium_LRS", Measurements: copyPCRMap(azurePCRs), diff --git a/internal/config/config_doc.go b/internal/config/config_doc.go index 8f076fe6b..4837d1605 100644 --- a/internal/config/config_doc.go +++ b/internal/config/config_doc.go @@ -199,7 +199,7 @@ func init() { FieldName: "azure", }, } - AzureConfigDoc.Fields = make([]encoder.Doc, 9) + AzureConfigDoc.Fields = make([]encoder.Doc, 10) AzureConfigDoc.Fields[0].Name = "subscription" AzureConfigDoc.Fields[0].Type = "string" AzureConfigDoc.Fields[0].Note = "" @@ -240,6 +240,11 @@ func init() { AzureConfigDoc.Fields[7].Note = "" AzureConfigDoc.Fields[7].Description = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure" AzureConfigDoc.Fields[7].Comments[encoder.LineComment] = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure" + AzureConfigDoc.Fields[8].Name = "resourceGroup" + AzureConfigDoc.Fields[8].Type = "string" + AzureConfigDoc.Fields[8].Note = "" + AzureConfigDoc.Fields[8].Description = "Resource group to use." + AzureConfigDoc.Fields[8].Comments[encoder.LineComment] = "Resource group to use." AzureConfigDoc.Fields[8].Name = "confidentialVM" AzureConfigDoc.Fields[8].Type = "bool" AzureConfigDoc.Fields[8].Note = ""