mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-04-11 10:29:29 -04:00
Manually manage resource group on Azure
This commit is contained in:
parent
e6ae54a25a
commit
f15605cb45
@ -1,184 +0,0 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
|
||||
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/Azure/go-autorest/autorest/date"
|
||||
"github.com/Azure/go-autorest/autorest/to"
|
||||
"github.com/edgelesssys/constellation/internal/azureshared"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
adAppCredentialValidity = time.Hour * 24 * 365 * 5 // ~5 years
|
||||
adReplicationLagCheckInterval = time.Second * 5 // 5 seconds
|
||||
adReplicationLagCheckMaxRetries = int((15 * time.Minute) / adReplicationLagCheckInterval) // wait for up to 15 minutes for AD replication
|
||||
ownerRoleDefinitionID = "8e3af657-a8ff-443c-a75c-2fe8c4bcb635" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#owner
|
||||
)
|
||||
|
||||
// CreateServicePrincipal creates an Azure AD app with a service principal, gives it "Owner" role on the resource group and creates new credentials.
|
||||
func (c *Client) CreateServicePrincipal(ctx context.Context) (string, error) {
|
||||
createAppRes, err := c.createADApplication(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
c.adAppObjectID = createAppRes.ObjectID
|
||||
servicePrincipalObjectID, err := c.createAppServicePrincipal(ctx, createAppRes.AppID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := c.assignResourceGroupRole(ctx, servicePrincipalObjectID, ownerRoleDefinitionID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
clientSecret, err := c.updateAppCredentials(ctx, createAppRes.ObjectID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return azureshared.ApplicationCredentials{
|
||||
TenantID: c.tenantID,
|
||||
ClientID: createAppRes.AppID,
|
||||
ClientSecret: clientSecret,
|
||||
Location: c.location,
|
||||
}.ToCloudServiceAccountURI(), nil
|
||||
}
|
||||
|
||||
// TerminateServicePrincipal terminates an Azure AD app together with the service principal.
|
||||
func (c *Client) TerminateServicePrincipal(ctx context.Context) error {
|
||||
if c.adAppObjectID == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := c.applicationsAPI.Delete(ctx, c.adAppObjectID); err != nil {
|
||||
return err
|
||||
}
|
||||
c.adAppObjectID = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
// createADApplication creates a new azure AD app.
|
||||
func (c *Client) createADApplication(ctx context.Context) (createADApplicationOutput, error) {
|
||||
createParameters := graphrbac.ApplicationCreateParameters{
|
||||
AvailableToOtherTenants: to.BoolPtr(false),
|
||||
DisplayName: to.StringPtr("constellation-app-" + c.name + "-" + c.uid),
|
||||
}
|
||||
app, err := c.applicationsAPI.Create(ctx, createParameters)
|
||||
if err != nil {
|
||||
return createADApplicationOutput{}, err
|
||||
}
|
||||
if app.AppID == nil || app.ObjectID == nil {
|
||||
return createADApplicationOutput{}, errors.New("creating AD application did not result in valid app id and object id")
|
||||
}
|
||||
return createADApplicationOutput{
|
||||
AppID: *app.AppID,
|
||||
ObjectID: *app.ObjectID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createAppServicePrincipal creates a new service principal for an azure AD app.
|
||||
func (c *Client) createAppServicePrincipal(ctx context.Context, appID string) (string, error) {
|
||||
createParameters := graphrbac.ServicePrincipalCreateParameters{
|
||||
AppID: &appID,
|
||||
AccountEnabled: to.BoolPtr(true),
|
||||
}
|
||||
servicePrincipal, err := c.servicePrincipalsAPI.Create(ctx, createParameters)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if servicePrincipal.ObjectID == nil {
|
||||
return "", errors.New("creating AD service principal did not result in a valid object id")
|
||||
}
|
||||
return *servicePrincipal.ObjectID, nil
|
||||
}
|
||||
|
||||
// updateAppCredentials sets app client-secret for authentication.
|
||||
func (c *Client) updateAppCredentials(ctx context.Context, objectID string) (string, error) {
|
||||
keyID := uuid.New().String()
|
||||
clientSecret, err := generateClientSecret()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generating client secret: %w", err)
|
||||
}
|
||||
updateParameters := graphrbac.PasswordCredentialsUpdateParameters{
|
||||
Value: &[]graphrbac.PasswordCredential{
|
||||
{
|
||||
StartDate: &date.Time{Time: time.Now()},
|
||||
EndDate: &date.Time{Time: time.Now().Add(adAppCredentialValidity)},
|
||||
Value: to.StringPtr(clientSecret),
|
||||
KeyID: to.StringPtr(keyID),
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err = c.applicationsAPI.UpdatePasswordCredentials(ctx, objectID, updateParameters)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return clientSecret, nil
|
||||
}
|
||||
|
||||
// assignResourceGroupRole assigns the service principal a role at resource group scope.
|
||||
func (c *Client) assignResourceGroupRole(ctx context.Context, principalID, roleDefinitionID string) error {
|
||||
resourceGroup, err := c.resourceGroupAPI.Get(ctx, c.resourceGroup, nil)
|
||||
if err != nil || resourceGroup.ID == nil {
|
||||
return fmt.Errorf("unable to retrieve resource group id for group %v: %w", c.resourceGroup, err)
|
||||
}
|
||||
roleAssignmentID := uuid.New().String()
|
||||
createParameters := authorization.RoleAssignmentCreateParameters{
|
||||
Properties: &authorization.RoleAssignmentProperties{
|
||||
PrincipalID: to.StringPtr(principalID),
|
||||
RoleDefinitionID: to.StringPtr(fmt.Sprintf("/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s", c.subscriptionID, roleDefinitionID)),
|
||||
},
|
||||
}
|
||||
|
||||
// due to an azure AD replication lag, retry role assignment if principal does not exist yet
|
||||
// reference: https://docs.microsoft.com/en-us/azure/role-based-access-control/role-assignments-rest#new-service-principal
|
||||
// proper fix: use API version 2018-09-01-preview or later
|
||||
// azure go sdk currently uses version 2015-07-01: https://github.com/Azure/azure-sdk-for-go/blob/v62.0.0/services/authorization/mgmt/2015-07-01/authorization/roleassignments.go#L95
|
||||
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
|
||||
for i := 0; i < c.adReplicationLagCheckMaxRetries; i++ {
|
||||
_, err = c.roleAssignmentsAPI.Create(ctx, *resourceGroup.ID, roleAssignmentID, createParameters)
|
||||
var detailedErr autorest.DetailedError
|
||||
var ok bool
|
||||
if detailedErr, ok = err.(autorest.DetailedError); !ok {
|
||||
return err
|
||||
}
|
||||
var requestErr *azure.RequestError
|
||||
if requestErr, ok = detailedErr.Original.(*azure.RequestError); !ok || requestErr.ServiceError == nil {
|
||||
return err
|
||||
}
|
||||
if requestErr.ServiceError.Code != "PrincipalNotFound" {
|
||||
return err
|
||||
}
|
||||
time.Sleep(c.adReplicationLagCheckInterval)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type createADApplicationOutput struct {
|
||||
AppID string
|
||||
ObjectID string
|
||||
}
|
||||
|
||||
func generateClientSecret() (string, error) {
|
||||
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
||||
|
||||
pwLen := 64
|
||||
pw := make([]byte, 0, pwLen)
|
||||
for i := 0; i < pwLen; i++ {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
pw = append(pw, letters[n.Int64()])
|
||||
}
|
||||
return string(pw), nil
|
||||
}
|
@ -1,358 +0,0 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
"github.com/Azure/go-autorest/autorest/azure"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestCreateServicePrincipal(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
applicationsAPI applicationsAPI
|
||||
servicePrincipalsAPI servicePrincipalsAPI
|
||||
roleAssignmentsAPI roleAssignmentsAPI
|
||||
resourceGroupAPI resourceGroupAPI
|
||||
wantErr bool
|
||||
}{
|
||||
"successful create": {
|
||||
applicationsAPI: stubApplicationsAPI{},
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{},
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.Ptr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"failed app create": {
|
||||
applicationsAPI: stubApplicationsAPI{
|
||||
createErr: someErr,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"failed service principal create": {
|
||||
applicationsAPI: stubApplicationsAPI{},
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{
|
||||
createErr: someErr,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"failed role assignment": {
|
||||
applicationsAPI: stubApplicationsAPI{},
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{},
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
||||
createErrors: []error{someErr},
|
||||
},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.Ptr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"failed update creds": {
|
||||
applicationsAPI: stubApplicationsAPI{
|
||||
updateCredentialsErr: someErr,
|
||||
},
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{},
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.Ptr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
resourceGroup: "resource-group",
|
||||
applicationsAPI: tc.applicationsAPI,
|
||||
servicePrincipalsAPI: tc.servicePrincipalsAPI,
|
||||
roleAssignmentsAPI: tc.roleAssignmentsAPI,
|
||||
resourceGroupAPI: tc.resourceGroupAPI,
|
||||
adReplicationLagCheckMaxRetries: 2,
|
||||
}
|
||||
|
||||
_, err := client.CreateServicePrincipal(ctx)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTerminateServicePrincipal(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
appObjectID string
|
||||
applicationsAPI applicationsAPI
|
||||
wantErr bool
|
||||
}{
|
||||
"successful terminate": {
|
||||
appObjectID: "object-id",
|
||||
applicationsAPI: stubApplicationsAPI{},
|
||||
},
|
||||
"nothing to terminate": {
|
||||
applicationsAPI: stubApplicationsAPI{},
|
||||
},
|
||||
"failed delete": {
|
||||
appObjectID: "object-id",
|
||||
applicationsAPI: stubApplicationsAPI{
|
||||
deleteErr: someErr,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
resourceGroup: "resource-group",
|
||||
adAppObjectID: tc.appObjectID,
|
||||
applicationsAPI: tc.applicationsAPI,
|
||||
}
|
||||
|
||||
err := client.TerminateServicePrincipal(ctx)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateADApplication(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
applicationsAPI applicationsAPI
|
||||
wantErr bool
|
||||
}{
|
||||
"successful create": {
|
||||
applicationsAPI: stubApplicationsAPI{},
|
||||
},
|
||||
"failed app create": {
|
||||
applicationsAPI: stubApplicationsAPI{
|
||||
createErr: someErr,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"app create returns invalid appid": {
|
||||
applicationsAPI: stubApplicationsAPI{
|
||||
createApplication: &graphrbac.Application{
|
||||
ObjectID: proto.String("00000000-0000-0000-0000-000000000001"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"app create returns invalid objectid": {
|
||||
applicationsAPI: stubApplicationsAPI{
|
||||
createApplication: &graphrbac.Application{
|
||||
AppID: proto.String("00000000-0000-0000-0000-000000000000"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
applicationsAPI: tc.applicationsAPI,
|
||||
}
|
||||
|
||||
appCredentials, err := client.createADApplication(ctx)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
assert.NotNil(appCredentials)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateAppServicePrincipal(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
servicePrincipalsAPI servicePrincipalsAPI
|
||||
wantErr bool
|
||||
}{
|
||||
"successful create": {
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{},
|
||||
},
|
||||
"failed service principal create": {
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{
|
||||
createErr: someErr,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"service principal create returns invalid objectid": {
|
||||
servicePrincipalsAPI: stubServicePrincipalsAPI{
|
||||
createServicePrincipal: &graphrbac.ServicePrincipal{},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
servicePrincipalsAPI: tc.servicePrincipalsAPI,
|
||||
}
|
||||
|
||||
_, err := client.createAppServicePrincipal(ctx, "app-id")
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAssignOwnerOfResourceGroup(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
roleAssignmentsAPI roleAssignmentsAPI
|
||||
resourceGroupAPI resourceGroupAPI
|
||||
wantErr bool
|
||||
}{
|
||||
"successful assign": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.Ptr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"failed role assignment": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
||||
createErrors: []error{someErr},
|
||||
},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.Ptr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"failed resource group get": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getErr: someErr,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"resource group get returns invalid id": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"create returns PrincipalNotFound the first time": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
||||
createErrors: []error{
|
||||
autorest.DetailedError{Original: &azure.RequestError{
|
||||
ServiceError: &azure.ServiceError{
|
||||
Code: "PrincipalNotFound",
|
||||
},
|
||||
}},
|
||||
nil,
|
||||
},
|
||||
},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.Ptr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
},
|
||||
"create does not return request error": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
||||
createErrors: []error{autorest.DetailedError{Original: someErr}},
|
||||
},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.Ptr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"create service error code is unknown": {
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
||||
createErrors: []error{
|
||||
autorest.DetailedError{Original: &azure.RequestError{
|
||||
ServiceError: &azure.ServiceError{
|
||||
Code: "some-unknown-error-code",
|
||||
},
|
||||
}},
|
||||
nil,
|
||||
},
|
||||
},
|
||||
resourceGroupAPI: stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.Ptr("resource-group-id"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
resourceGroup: "resource-group",
|
||||
roleAssignmentsAPI: tc.roleAssignmentsAPI,
|
||||
resourceGroupAPI: tc.resourceGroupAPI,
|
||||
adReplicationLagCheckMaxRetries: 2,
|
||||
}
|
||||
|
||||
err := client.assignResourceGroupRole(ctx, "principal-id", "role-definition-id")
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
@ -64,15 +64,13 @@ type networkInterfacesAPI interface {
|
||||
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error)
|
||||
}
|
||||
|
||||
type resourceGroupAPI interface {
|
||||
CreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
parameters armresources.ResourceGroup,
|
||||
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
|
||||
armresources.ResourceGroupsClientCreateOrUpdateResponse, error)
|
||||
BeginDelete(ctx context.Context, resourceGroupName string,
|
||||
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
|
||||
*runtime.Poller[armresources.ResourceGroupsClientDeleteResponse], error)
|
||||
Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error)
|
||||
type resourceAPI interface {
|
||||
NewListByResourceGroupPager(resourceGroupName string,
|
||||
options *armresources.ClientListByResourceGroupOptions,
|
||||
) *runtime.Pager[armresources.ClientListByResourceGroupResponse]
|
||||
BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string,
|
||||
options *armresources.ClientBeginDeleteByIDOptions,
|
||||
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error)
|
||||
}
|
||||
|
||||
type applicationsAPI interface {
|
||||
|
@ -4,15 +4,11 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
|
||||
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
|
||||
"github.com/Azure/go-autorest/autorest"
|
||||
)
|
||||
|
||||
type stubNetworksAPI struct {
|
||||
@ -94,44 +90,6 @@ func (a stubNetworkSecurityGroupsAPI) BeginCreateOrUpdate(ctx context.Context, r
|
||||
return poller, a.createErr
|
||||
}
|
||||
|
||||
type stubResourceGroupAPI struct {
|
||||
terminateErr error
|
||||
createErr error
|
||||
getErr error
|
||||
getResourceGroup armresources.ResourceGroup
|
||||
pollErr error
|
||||
}
|
||||
|
||||
func (a stubResourceGroupAPI) CreateOrUpdate(ctx context.Context, resourceGroupName string,
|
||||
parameters armresources.ResourceGroup,
|
||||
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
|
||||
armresources.ResourceGroupsClientCreateOrUpdateResponse, error,
|
||||
) {
|
||||
return armresources.ResourceGroupsClientCreateOrUpdateResponse{}, a.createErr
|
||||
}
|
||||
|
||||
func (a stubResourceGroupAPI) Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error) {
|
||||
return armresources.ResourceGroupsClientGetResponse{
|
||||
ResourceGroup: a.getResourceGroup,
|
||||
}, a.getErr
|
||||
}
|
||||
|
||||
func (a stubResourceGroupAPI) BeginDelete(ctx context.Context, resourceGroupName string,
|
||||
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
|
||||
*runtime.Poller[armresources.ResourceGroupsClientDeleteResponse], error,
|
||||
) {
|
||||
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ResourceGroupsClientDeleteResponse]{
|
||||
Handler: &stubPoller[armresources.ResourceGroupsClientDeleteResponse]{
|
||||
result: armresources.ResourceGroupsClientDeleteResponse{},
|
||||
resultErr: a.pollErr,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return poller, a.terminateErr
|
||||
}
|
||||
|
||||
type stubScaleSetsAPI struct {
|
||||
createErr error
|
||||
stubResponse armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse
|
||||
@ -282,71 +240,6 @@ func (a stubNetworkInterfacesAPI) GetVirtualMachineScaleSetNetworkInterface(ctx
|
||||
}, nil
|
||||
}
|
||||
|
||||
type stubApplicationsAPI struct {
|
||||
createErr error
|
||||
deleteErr error
|
||||
updateCredentialsErr error
|
||||
createApplication *graphrbac.Application
|
||||
}
|
||||
|
||||
func (a stubApplicationsAPI) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
|
||||
if a.createErr != nil {
|
||||
return graphrbac.Application{}, a.createErr
|
||||
}
|
||||
if a.createApplication != nil {
|
||||
return *a.createApplication, nil
|
||||
}
|
||||
return graphrbac.Application{
|
||||
AppID: to.Ptr("00000000-0000-0000-0000-000000000000"),
|
||||
ObjectID: to.Ptr("00000000-0000-0000-0000-000000000001"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a stubApplicationsAPI) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
|
||||
if a.deleteErr != nil {
|
||||
return autorest.Response{}, a.deleteErr
|
||||
}
|
||||
return autorest.Response{}, nil
|
||||
}
|
||||
|
||||
func (a stubApplicationsAPI) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
|
||||
if a.updateCredentialsErr != nil {
|
||||
return autorest.Response{}, a.updateCredentialsErr
|
||||
}
|
||||
return autorest.Response{}, nil
|
||||
}
|
||||
|
||||
type stubServicePrincipalsAPI struct {
|
||||
createErr error
|
||||
createServicePrincipal *graphrbac.ServicePrincipal
|
||||
}
|
||||
|
||||
func (a stubServicePrincipalsAPI) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
|
||||
if a.createErr != nil {
|
||||
return graphrbac.ServicePrincipal{}, a.createErr
|
||||
}
|
||||
if a.createServicePrincipal != nil {
|
||||
return *a.createServicePrincipal, nil
|
||||
}
|
||||
return graphrbac.ServicePrincipal{
|
||||
AppID: to.Ptr("00000000-0000-0000-0000-000000000000"),
|
||||
ObjectID: to.Ptr("00000000-0000-0000-0000-000000000002"),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type stubRoleAssignmentsAPI struct {
|
||||
createCounter int
|
||||
createErrors []error
|
||||
}
|
||||
|
||||
func (a *stubRoleAssignmentsAPI) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
|
||||
a.createCounter++
|
||||
if len(a.createErrors) == 0 {
|
||||
return authorization.RoleAssignment{}, nil
|
||||
}
|
||||
return authorization.RoleAssignment{}, a.createErrors[(a.createCounter-1)%len(a.createErrors)]
|
||||
}
|
||||
|
||||
type stubApplicationInsightsAPI struct {
|
||||
err error
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ const (
|
||||
type Client struct {
|
||||
networksAPI
|
||||
networkSecurityGroupsAPI
|
||||
resourceGroupAPI
|
||||
resourceAPI
|
||||
scaleSetsAPI
|
||||
publicIPAddressesAPI
|
||||
networkInterfacesAPI
|
||||
@ -39,9 +39,7 @@ type Client struct {
|
||||
roleAssignmentsAPI
|
||||
applicationInsightsAPI
|
||||
|
||||
pollFrequency time.Duration
|
||||
adReplicationLagCheckInterval time.Duration
|
||||
adReplicationLagCheckMaxRetries int
|
||||
pollFrequency time.Duration
|
||||
|
||||
workers cloudtypes.Instances
|
||||
controlPlanes cloudtypes.Instances
|
||||
@ -83,10 +81,6 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resGroupAPI, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scaleSetAPI, err := armcomputev2.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -107,6 +101,10 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resourceAPI, err := armresources.NewClient(subscriptionID, cred, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applicationsAPI := graphrbac.NewApplicationsClient(tenantID)
|
||||
applicationsAPI.Authorizer = graphAuthorizer
|
||||
servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID)
|
||||
@ -115,42 +113,41 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
|
||||
roleAssignmentsAPI.Authorizer = managementAuthorizer
|
||||
|
||||
return &Client{
|
||||
networksAPI: netAPI,
|
||||
networkSecurityGroupsAPI: netSecGrpAPI,
|
||||
resourceGroupAPI: resGroupAPI,
|
||||
scaleSetsAPI: scaleSetAPI,
|
||||
publicIPAddressesAPI: publicIPAddressesAPI,
|
||||
networkInterfacesAPI: networkInterfacesAPI,
|
||||
loadBalancersAPI: loadBalancersAPI,
|
||||
applicationsAPI: applicationsAPI,
|
||||
servicePrincipalsAPI: servicePrincipalsAPI,
|
||||
roleAssignmentsAPI: roleAssignmentsAPI,
|
||||
applicationInsightsAPI: applicationInsightsAPI,
|
||||
subscriptionID: subscriptionID,
|
||||
tenantID: tenantID,
|
||||
workers: cloudtypes.Instances{},
|
||||
controlPlanes: cloudtypes.Instances{},
|
||||
pollFrequency: time.Second * 5,
|
||||
adReplicationLagCheckInterval: adReplicationLagCheckInterval,
|
||||
adReplicationLagCheckMaxRetries: adReplicationLagCheckMaxRetries,
|
||||
networksAPI: netAPI,
|
||||
networkSecurityGroupsAPI: netSecGrpAPI,
|
||||
resourceAPI: resourceAPI,
|
||||
scaleSetsAPI: scaleSetAPI,
|
||||
publicIPAddressesAPI: publicIPAddressesAPI,
|
||||
networkInterfacesAPI: networkInterfacesAPI,
|
||||
loadBalancersAPI: loadBalancersAPI,
|
||||
applicationsAPI: applicationsAPI,
|
||||
servicePrincipalsAPI: servicePrincipalsAPI,
|
||||
roleAssignmentsAPI: roleAssignmentsAPI,
|
||||
applicationInsightsAPI: applicationInsightsAPI,
|
||||
subscriptionID: subscriptionID,
|
||||
tenantID: tenantID,
|
||||
workers: cloudtypes.Instances{},
|
||||
controlPlanes: cloudtypes.Instances{},
|
||||
pollFrequency: time.Second * 5,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewInitialized creates and initializes client by setting the subscriptionID, location and name
|
||||
// of the Constellation.
|
||||
func NewInitialized(subscriptionID, tenantID, name, location string) (*Client, error) {
|
||||
func NewInitialized(subscriptionID, tenantID, name, location, resourceGroup string) (*Client, error) {
|
||||
client, err := NewFromDefault(subscriptionID, tenantID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = client.init(location, name)
|
||||
err = client.init(location, name, resourceGroup)
|
||||
return client, err
|
||||
}
|
||||
|
||||
// init initializes the client.
|
||||
func (c *Client) init(location, name string) error {
|
||||
func (c *Client) init(location, name, resourceGroup string) error {
|
||||
c.location = location
|
||||
c.name = name
|
||||
c.resourceGroup = resourceGroup
|
||||
uid, err := c.generateUID()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -84,8 +84,9 @@ func TestInit(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
client := Client{}
|
||||
require.NoError(client.init("location", "name"))
|
||||
require.NoError(client.init("location", "name", "rGroup"))
|
||||
assert.Equal("location", client.location)
|
||||
assert.Equal("name", client.name)
|
||||
assert.Equal("rGroup", client.resourceGroup)
|
||||
assert.NotEmpty(client.uid)
|
||||
}
|
||||
|
@ -10,8 +10,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
"github.com/edgelesssys/constellation/cli/internal/azure"
|
||||
"github.com/edgelesssys/constellation/cli/internal/azure/internal/poller"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
|
||||
@ -213,45 +211,6 @@ type CreateScaleSetInput struct {
|
||||
ConfidentialVM bool
|
||||
}
|
||||
|
||||
// CreateResourceGroup creates a resource group.
|
||||
func (c *Client) CreateResourceGroup(ctx context.Context) error {
|
||||
_, err := c.resourceGroupAPI.CreateOrUpdate(ctx, c.name+"-"+c.uid,
|
||||
armresources.ResourceGroup{
|
||||
Location: &c.location,
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.resourceGroup = c.name + "-" + c.uid
|
||||
return nil
|
||||
}
|
||||
|
||||
// TerminateResourceGroup terminates a resource group.
|
||||
func (c *Client) TerminateResourceGroup(ctx context.Context) error {
|
||||
if c.resourceGroup == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
poller, err := c.resourceGroupAPI.BeginDelete(ctx, c.resourceGroup, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{
|
||||
Frequency: c.pollFrequency,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
c.workers = nil
|
||||
c.controlPlanes = nil
|
||||
c.resourceGroup = ""
|
||||
c.subnetID = ""
|
||||
c.networkSecurityGroup = ""
|
||||
c.workerScaleSet = ""
|
||||
c.controlPlaneScaleSet = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
// scaleSetCreationPollingHandler is a custom poller used to check if a scale set was created successfully.
|
||||
type scaleSetCreationPollingHandler struct {
|
||||
done bool
|
||||
|
@ -7,126 +7,16 @@ import (
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestCreateResourceGroup(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
resourceGroupAPI resourceGroupAPI
|
||||
wantErr bool
|
||||
}{
|
||||
"successful create": {
|
||||
resourceGroupAPI: stubResourceGroupAPI{},
|
||||
},
|
||||
"failed create": {
|
||||
resourceGroupAPI: stubResourceGroupAPI{createErr: someErr},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
client := Client{
|
||||
location: "location",
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
resourceGroupAPI: tc.resourceGroupAPI,
|
||||
workers: make(cloudtypes.Instances),
|
||||
controlPlanes: make(cloudtypes.Instances),
|
||||
}
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(client.CreateResourceGroup(ctx))
|
||||
} else {
|
||||
assert.NoError(client.CreateResourceGroup(ctx))
|
||||
assert.Equal(client.name+"-"+client.uid, client.resourceGroup)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTerminateResourceGroup(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
clientWithResourceGroup := Client{
|
||||
resourceGroup: "name",
|
||||
location: "location",
|
||||
name: "name",
|
||||
uid: "uid",
|
||||
subnetID: "subnet",
|
||||
workerScaleSet: "node-scale-set",
|
||||
controlPlaneScaleSet: "controlplane-scale-set",
|
||||
workers: cloudtypes.Instances{
|
||||
"0": {
|
||||
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
controlPlanes: cloudtypes.Instances{
|
||||
"0": {
|
||||
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
testCases := map[string]struct {
|
||||
resourceGroup string
|
||||
resourceGroupAPI resourceGroupAPI
|
||||
client Client
|
||||
wantErr bool
|
||||
}{
|
||||
"successful terminate": {
|
||||
resourceGroupAPI: stubResourceGroupAPI{},
|
||||
client: clientWithResourceGroup,
|
||||
},
|
||||
"no resource group to terminate": {
|
||||
resourceGroupAPI: stubResourceGroupAPI{},
|
||||
client: Client{},
|
||||
resourceGroup: "",
|
||||
},
|
||||
"failed terminate": {
|
||||
resourceGroupAPI: stubResourceGroupAPI{terminateErr: someErr},
|
||||
client: clientWithResourceGroup,
|
||||
wantErr: true,
|
||||
},
|
||||
"failed to poll terminate response": {
|
||||
resourceGroupAPI: stubResourceGroupAPI{pollErr: someErr},
|
||||
client: clientWithResourceGroup,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
tc.client.resourceGroupAPI = tc.resourceGroupAPI
|
||||
ctx := context.Background()
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(tc.client.TerminateResourceGroup(ctx))
|
||||
return
|
||||
}
|
||||
assert.NoError(tc.client.TerminateResourceGroup(ctx))
|
||||
assert.Empty(tc.client.resourceGroup)
|
||||
assert.Empty(tc.client.subnetID)
|
||||
assert.Empty(tc.client.workers)
|
||||
assert.Empty(tc.client.controlPlanes)
|
||||
assert.Empty(tc.client.workerScaleSet)
|
||||
assert.Empty(tc.client.controlPlaneScaleSet)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateInstances(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
testCases := map[string]struct {
|
||||
publicIPAddressesAPI publicIPAddressesAPI
|
||||
networkInterfacesAPI networkInterfacesAPI
|
||||
scaleSetsAPI scaleSetsAPI
|
||||
resourceGroupAPI resourceGroupAPI
|
||||
roleAssignmentsAPI roleAssignmentsAPI
|
||||
createInstancesInput CreateInstancesInput
|
||||
wantErr bool
|
||||
}{
|
||||
@ -138,8 +28,6 @@ func TestCreateInstances(t *testing.T) {
|
||||
VirtualMachineScaleSet: armcomputev2.VirtualMachineScaleSet{Identity: &armcomputev2.VirtualMachineScaleSetIdentity{PrincipalID: to.Ptr("principal-id")}},
|
||||
},
|
||||
},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
CountControlPlanes: 3,
|
||||
CountWorkers: 3,
|
||||
@ -153,8 +41,6 @@ func TestCreateInstances(t *testing.T) {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||
scaleSetsAPI: stubScaleSetsAPI{createErr: someErr},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
CountControlPlanes: 3,
|
||||
CountWorkers: 3,
|
||||
@ -169,8 +55,6 @@ func TestCreateInstances(t *testing.T) {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||
scaleSetsAPI: stubScaleSetsAPI{getErr: someErr},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
CountControlPlanes: 3,
|
||||
CountWorkers: 3,
|
||||
@ -185,8 +69,6 @@ func TestCreateInstances(t *testing.T) {
|
||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||
networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr},
|
||||
scaleSetsAPI: stubScaleSetsAPI{},
|
||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
||||
createInstancesInput: CreateInstancesInput{
|
||||
CountWorkers: 3,
|
||||
InstanceType: "type",
|
||||
@ -211,8 +93,6 @@ func TestCreateInstances(t *testing.T) {
|
||||
publicIPAddressesAPI: tc.publicIPAddressesAPI,
|
||||
networkInterfacesAPI: tc.networkInterfacesAPI,
|
||||
scaleSetsAPI: tc.scaleSetsAPI,
|
||||
resourceGroupAPI: tc.resourceGroupAPI,
|
||||
roleAssignmentsAPI: tc.roleAssignmentsAPI,
|
||||
workers: make(cloudtypes.Instances),
|
||||
controlPlanes: make(cloudtypes.Instances),
|
||||
loadBalancerPubIP: "lbip",
|
||||
@ -232,11 +112,3 @@ func TestCreateInstances(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newSuccessfulResourceGroupStub() *stubResourceGroupAPI {
|
||||
return &stubResourceGroupAPI{
|
||||
getResourceGroup: armresources.ResourceGroup{
|
||||
ID: to.Ptr("resource-group-id"),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
133
cli/internal/azure/client/terminate.go
Normal file
133
cli/internal/azure/client/terminate.go
Normal file
@ -0,0 +1,133 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
)
|
||||
|
||||
// TerminateResourceGroupResources deletes all resources from the resource group.
|
||||
func (c *Client) TerminateResourceGroupResources(ctx context.Context) error {
|
||||
const timeOut = 10 * time.Minute
|
||||
ctx, cancel := context.WithTimeout(ctx, timeOut)
|
||||
defer cancel()
|
||||
|
||||
pollers := make(chan *runtime.Poller[armresources.ClientDeleteByIDResponse], 20)
|
||||
delete := make(chan struct{}, 1)
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
go func() { // This routine lists resources and starts their deletion, where possible.
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
close(pollers)
|
||||
for range delete { // drain channel
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
ids, err := c.getResourceIDList(ctx)
|
||||
if err != nil {
|
||||
time.Sleep(3 * time.Second)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(ids) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
poller, err := c.deleteResourceByID(ctx, id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
pollers <- poller
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case _, ok := <-delete:
|
||||
if !ok { // channel was closed
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() { // This routine polls for for the deletions to complete.
|
||||
defer wg.Done()
|
||||
defer close(delete)
|
||||
|
||||
for poller := range pollers {
|
||||
_, err := poller.PollUntilDone(ctx, nil)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case delete <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) getResourceIDList(ctx context.Context) ([]string, error) {
|
||||
var ids []string
|
||||
pager := c.resourceAPI.NewListByResourceGroupPager(c.resourceGroup, nil)
|
||||
for pager.More() {
|
||||
page, err := pager.NextPage(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting next page of ListByResourceGroup: %w", err)
|
||||
}
|
||||
for _, resource := range page.Value {
|
||||
if resource.ID == nil {
|
||||
return nil, fmt.Errorf("resource %v has no ID", resource)
|
||||
}
|
||||
ids = append(ids, *resource.ID)
|
||||
}
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (c *Client) deleteResourceByID(ctx context.Context, id string,
|
||||
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) {
|
||||
apiVersion := "2020-02-02"
|
||||
|
||||
// First try, API version unknown, will fail.
|
||||
poller, err := c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil)
|
||||
if isVersionWrongErr(err) {
|
||||
// bad hack, but easiest way to get the right API version
|
||||
apiVersion = parseAPIVersionFromErr(err)
|
||||
poller, err = c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil)
|
||||
}
|
||||
return poller, err
|
||||
}
|
||||
|
||||
func isVersionWrongErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(err.Error(), "NoRegisteredProviderFound") &&
|
||||
strings.Contains(err.Error(), "The supported api-versions are")
|
||||
}
|
||||
|
||||
var apiVersionRegex = regexp.MustCompile(` (\d\d\d\d-\d\d-\d\d)'`)
|
||||
|
||||
func parseAPIVersionFromErr(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
}
|
||||
matches := apiVersionRegex.FindStringSubmatch(err.Error())
|
||||
return matches[1]
|
||||
}
|
139
cli/internal/azure/client/terminate_test.go
Normal file
139
cli/internal/azure/client/terminate_test.go
Normal file
@ -0,0 +1,139 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestTerminateResourceGroupResources(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
apiVersionErr := errors.New("NoRegisteredProviderFound, The supported api-versions are: 2015-01-01'")
|
||||
|
||||
testCases := map[string]struct {
|
||||
resourceAPI resourceAPI
|
||||
}{
|
||||
"no resources": {
|
||||
resourceAPI: &fakeResourceAPI{},
|
||||
},
|
||||
"some resources": {
|
||||
resourceAPI: &fakeResourceAPI{
|
||||
resources: map[string]fakeResource{
|
||||
"id-0": {beginDeleteByIDErr: apiVersionErr, pollErr: someErr},
|
||||
"id-1": {beginDeleteByIDErr: apiVersionErr},
|
||||
"id-2": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
client := &Client{
|
||||
resourceAPI: tc.resourceAPI,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := client.TerminateResourceGroupResources(ctx)
|
||||
assert.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type fakeResourceAPI struct {
|
||||
resources map[string]fakeResource
|
||||
fetchErr error
|
||||
}
|
||||
|
||||
type fakeResource struct {
|
||||
beginDeleteByIDErr error
|
||||
pollErr error
|
||||
}
|
||||
|
||||
func (a fakeResourceAPI) NewListByResourceGroupPager(resourceGroupName string,
|
||||
options *armresources.ClientListByResourceGroupOptions,
|
||||
) *runtime.Pager[armresources.ClientListByResourceGroupResponse] {
|
||||
pager := &stubClientListByResourceGroupResponsePager{
|
||||
resources: a.resources,
|
||||
fetchErr: a.fetchErr,
|
||||
}
|
||||
return runtime.NewPager(runtime.PagingHandler[armresources.ClientListByResourceGroupResponse]{
|
||||
More: pager.moreFunc(),
|
||||
Fetcher: pager.fetcherFunc(),
|
||||
})
|
||||
}
|
||||
|
||||
func (a fakeResourceAPI) BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string,
|
||||
options *armresources.ClientBeginDeleteByIDOptions,
|
||||
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) {
|
||||
res := a.resources[resourceID]
|
||||
|
||||
pollErr := res.pollErr
|
||||
if pollErr != nil {
|
||||
res.pollErr = nil
|
||||
}
|
||||
|
||||
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ClientDeleteByIDResponse]{
|
||||
Handler: &stubPoller[armresources.ClientDeleteByIDResponse]{
|
||||
result: armresources.ClientDeleteByIDResponse{},
|
||||
resultErr: pollErr,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
beginDeleteByIDErr := res.beginDeleteByIDErr
|
||||
if beginDeleteByIDErr != nil {
|
||||
res.beginDeleteByIDErr = nil
|
||||
}
|
||||
|
||||
if res.beginDeleteByIDErr == nil && res.pollErr == nil {
|
||||
delete(a.resources, resourceID)
|
||||
fmt.Printf("fake delete %s\n", resourceID)
|
||||
} else {
|
||||
a.resources[resourceID] = res
|
||||
}
|
||||
|
||||
return poller, beginDeleteByIDErr
|
||||
}
|
||||
|
||||
type stubClientListByResourceGroupResponsePager struct {
|
||||
resources map[string]fakeResource
|
||||
fetchErr error
|
||||
more bool
|
||||
}
|
||||
|
||||
func (p *stubClientListByResourceGroupResponsePager) moreFunc() func(
|
||||
armresources.ClientListByResourceGroupResponse) bool {
|
||||
return func(armresources.ClientListByResourceGroupResponse) bool {
|
||||
return p.more
|
||||
}
|
||||
}
|
||||
|
||||
func (p *stubClientListByResourceGroupResponsePager) fetcherFunc() func(
|
||||
context.Context, *armresources.ClientListByResourceGroupResponse) (
|
||||
armresources.ClientListByResourceGroupResponse, error) {
|
||||
return func(context.Context, *armresources.ClientListByResourceGroupResponse) (
|
||||
armresources.ClientListByResourceGroupResponse, error,
|
||||
) {
|
||||
var resources []*armresources.GenericResourceExpanded
|
||||
for id := range p.resources {
|
||||
resources = append(resources, &armresources.GenericResourceExpanded{ID: proto.String(id)})
|
||||
}
|
||||
p.more = false
|
||||
return armresources.ClientListByResourceGroupResponse{
|
||||
ResourceListResult: armresources.ResourceListResult{
|
||||
Value: resources,
|
||||
},
|
||||
}, p.fetchErr
|
||||
}
|
||||
}
|
@ -26,12 +26,9 @@ type azureclient interface {
|
||||
GetState() state.ConstellationState
|
||||
SetState(state.ConstellationState)
|
||||
CreateApplicationInsight(ctx context.Context) error
|
||||
CreateResourceGroup(ctx context.Context) error
|
||||
CreateExternalLoadBalancer(ctx context.Context) error
|
||||
CreateVirtualNetwork(ctx context.Context) error
|
||||
CreateSecurityGroup(ctx context.Context, input azurecl.NetworkSecurityGroupInput) error
|
||||
CreateInstances(ctx context.Context, input azurecl.CreateInstancesInput) error
|
||||
CreateServicePrincipal(ctx context.Context) (string, error)
|
||||
TerminateResourceGroup(ctx context.Context) error
|
||||
TerminateServicePrincipal(ctx context.Context) error
|
||||
TerminateResourceGroupResources(ctx context.Context) error
|
||||
}
|
||||
|
@ -79,11 +79,6 @@ func (c *fakeAzureClient) CreateApplicationInsight(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeAzureClient) CreateResourceGroup(ctx context.Context) error {
|
||||
c.resourceGroup = "resource-group"
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeAzureClient) CreateVirtualNetwork(ctx context.Context) error {
|
||||
c.subnetID = "subnet"
|
||||
return nil
|
||||
@ -123,17 +118,8 @@ func (c *fakeAzureClient) CreateServicePrincipal(ctx context.Context) (string, e
|
||||
}.ToCloudServiceAccountURI(), nil
|
||||
}
|
||||
|
||||
func (c *fakeAzureClient) TerminateResourceGroup(ctx context.Context) error {
|
||||
if c.resourceGroup == "" {
|
||||
return nil
|
||||
}
|
||||
c.workers = nil
|
||||
c.controlPlanes = nil
|
||||
c.resourceGroup = ""
|
||||
c.subnetID = ""
|
||||
c.networkSecurityGroup = ""
|
||||
c.workerScaleSet = ""
|
||||
c.controlPlaneScaleSet = ""
|
||||
func (c *fakeAzureClient) TerminateResourceGroupResources(ctx context.Context) error {
|
||||
// TODO(katexochen)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -146,18 +132,17 @@ func (c *fakeAzureClient) TerminateServicePrincipal(ctx context.Context) error {
|
||||
}
|
||||
|
||||
type stubAzureClient struct {
|
||||
terminateResourceGroupCalled bool
|
||||
terminateServicePrincipalCalled bool
|
||||
terminateResourceGroupResourcesCalled bool
|
||||
terminateServicePrincipalCalled bool
|
||||
|
||||
createApplicationInsightErr error
|
||||
createResourceGroupErr error
|
||||
createVirtualNetworkErr error
|
||||
createSecurityGroupErr error
|
||||
createLoadBalancerErr error
|
||||
createInstancesErr error
|
||||
createServicePrincipalErr error
|
||||
terminateResourceGroupErr error
|
||||
terminateServicePrincipalErr error
|
||||
createApplicationInsightErr error
|
||||
createVirtualNetworkErr error
|
||||
createSecurityGroupErr error
|
||||
createLoadBalancerErr error
|
||||
createInstancesErr error
|
||||
createServicePrincipalErr error
|
||||
terminateResourceGroupResourcesErr error
|
||||
terminateServicePrincipalErr error
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) GetState() state.ConstellationState {
|
||||
@ -175,10 +160,6 @@ func (c *stubAzureClient) CreateApplicationInsight(ctx context.Context) error {
|
||||
return c.createApplicationInsightErr
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) CreateResourceGroup(ctx context.Context) error {
|
||||
return c.createResourceGroupErr
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) CreateVirtualNetwork(ctx context.Context) error {
|
||||
return c.createVirtualNetworkErr
|
||||
}
|
||||
@ -198,9 +179,9 @@ func (c *stubAzureClient) CreateServicePrincipal(ctx context.Context) (string, e
|
||||
}.ToCloudServiceAccountURI(), c.createServicePrincipalErr
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) TerminateResourceGroup(ctx context.Context) error {
|
||||
c.terminateResourceGroupCalled = true
|
||||
return c.terminateResourceGroupErr
|
||||
func (c *stubAzureClient) TerminateResourceGroupResources(ctx context.Context) error {
|
||||
c.terminateResourceGroupResourcesCalled = true
|
||||
return c.terminateResourceGroupResourcesErr
|
||||
}
|
||||
|
||||
func (c *stubAzureClient) TerminateServicePrincipal(ctx context.Context) error {
|
||||
|
@ -18,7 +18,7 @@ import (
|
||||
type Creator struct {
|
||||
out io.Writer
|
||||
newGCPClient func(ctx context.Context, project, zone, region, name string) (gcpclient, error)
|
||||
newAzureClient func(subscriptionID, tenantID, name, location string) (azureclient, error)
|
||||
newAzureClient func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error)
|
||||
}
|
||||
|
||||
// NewCreator creates a new creator.
|
||||
@ -28,8 +28,8 @@ func NewCreator(out io.Writer) *Creator {
|
||||
newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) {
|
||||
return gcpcl.NewInitialized(ctx, project, zone, region, name)
|
||||
},
|
||||
newAzureClient: func(subscriptionID, tenantID, name, location string) (azureclient, error) {
|
||||
return azurecl.NewInitialized(subscriptionID, tenantID, name, location)
|
||||
newAzureClient: func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error) {
|
||||
return azurecl.NewInitialized(subscriptionID, tenantID, name, location, resourceGroup)
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -57,6 +57,7 @@ func (c *Creator) Create(ctx context.Context, provider cloudprovider.Provider, c
|
||||
config.Provider.Azure.TenantID,
|
||||
name,
|
||||
config.Provider.Azure.Location,
|
||||
config.Provider.Azure.ResourceGroup,
|
||||
)
|
||||
if err != nil {
|
||||
return state.ConstellationState{}, err
|
||||
@ -144,9 +145,6 @@ func (c *Creator) createAzure(ctx context.Context, cl azureclient, config *confi
|
||||
) (stat state.ConstellationState, retErr error) {
|
||||
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerAzure{client: cl})
|
||||
|
||||
if err := cl.CreateResourceGroup(ctx); err != nil {
|
||||
return state.ConstellationState{}, err
|
||||
}
|
||||
if err := cl.CreateApplicationInsight(ctx); err != nil {
|
||||
return state.ConstellationState{}, err
|
||||
}
|
||||
|
@ -51,7 +51,6 @@ func TestCreator(t *testing.T) {
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
AzureResourceGroup: "resource-group",
|
||||
AzureSubnet: "subnet",
|
||||
AzureNetworkSecurityGroup: "network-security-group",
|
||||
AzureWorkerScaleSet: "workers-scale-set",
|
||||
@ -123,13 +122,6 @@ func TestCreator(t *testing.T) {
|
||||
config: config.Default(),
|
||||
wantErr: true,
|
||||
},
|
||||
"azure CreateResourceGroup error": {
|
||||
azureclient: &stubAzureClient{createResourceGroupErr: someErr},
|
||||
provider: cloudprovider.Azure,
|
||||
config: config.Default(),
|
||||
wantErr: true,
|
||||
wantRollback: true,
|
||||
},
|
||||
"azure CreateVirtualNetwork error": {
|
||||
azureclient: &stubAzureClient{createVirtualNetworkErr: someErr},
|
||||
provider: cloudprovider.Azure,
|
||||
@ -167,7 +159,7 @@ func TestCreator(t *testing.T) {
|
||||
newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) {
|
||||
return tc.gcpclient, tc.newGCPClientErr
|
||||
},
|
||||
newAzureClient: func(subscriptionID, tenantID, name, location string) (azureclient, error) {
|
||||
newAzureClient: func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error) {
|
||||
return tc.azureclient, tc.newAzureClientErr
|
||||
},
|
||||
}
|
||||
@ -186,7 +178,7 @@ func TestCreator(t *testing.T) {
|
||||
assert.True(cl.closeCalled)
|
||||
case cloudprovider.Azure:
|
||||
cl := tc.azureclient.(*stubAzureClient)
|
||||
assert.True(cl.terminateResourceGroupCalled)
|
||||
assert.True(cl.terminateResourceGroupResourcesCalled)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -46,5 +46,5 @@ type rollbackerAzure struct {
|
||||
}
|
||||
|
||||
func (r *rollbackerAzure) rollback(ctx context.Context) error {
|
||||
return r.client.TerminateResourceGroup(ctx)
|
||||
return r.client.TerminateResourceGroupResources(ctx)
|
||||
}
|
||||
|
@ -1,68 +0,0 @@
|
||||
package cloudcmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
azurecl "github.com/edgelesssys/constellation/cli/internal/azure/client"
|
||||
gcpcl "github.com/edgelesssys/constellation/cli/internal/gcp/client"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
// ServiceAccountCreator creates service accounts.
|
||||
type ServiceAccountCreator struct {
|
||||
newGCPClient func(ctx context.Context) (gcpclient, error)
|
||||
newAzureClient func(subscriptionID, tenantID string) (azureclient, error)
|
||||
}
|
||||
|
||||
func NewServiceAccountCreator() *ServiceAccountCreator {
|
||||
return &ServiceAccountCreator{
|
||||
newGCPClient: func(ctx context.Context) (gcpclient, error) {
|
||||
return gcpcl.NewFromDefault(ctx)
|
||||
},
|
||||
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
|
||||
return azurecl.NewFromDefault(subscriptionID, tenantID)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Create creates a new cloud provider service account with access to the created resources.
|
||||
func (c *ServiceAccountCreator) Create(ctx context.Context, stat state.ConstellationState, config *config.Config,
|
||||
) (string, state.ConstellationState, error) {
|
||||
provider := cloudprovider.FromString(stat.CloudProvider)
|
||||
switch provider {
|
||||
case cloudprovider.GCP:
|
||||
return "", state.ConstellationState{}, fmt.Errorf("creating service account not supported for GCP")
|
||||
case cloudprovider.Azure:
|
||||
cl, err := c.newAzureClient(stat.AzureSubscription, stat.AzureTenant)
|
||||
if err != nil {
|
||||
return "", state.ConstellationState{}, err
|
||||
}
|
||||
|
||||
serviceAccount, stat, err := c.createServiceAccountAzure(ctx, cl, stat, config)
|
||||
if err != nil {
|
||||
return "", state.ConstellationState{}, err
|
||||
}
|
||||
|
||||
return serviceAccount, stat, err
|
||||
case cloudprovider.QEMU:
|
||||
return "unsupported://qemu", stat, nil
|
||||
default:
|
||||
return "", state.ConstellationState{}, fmt.Errorf("unsupported provider: %s", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ServiceAccountCreator) createServiceAccountAzure(ctx context.Context, cl azureclient,
|
||||
stat state.ConstellationState, _ *config.Config,
|
||||
) (string, state.ConstellationState, error) {
|
||||
cl.SetState(stat)
|
||||
|
||||
serviceAccount, err := cl.CreateServicePrincipal(ctx)
|
||||
if err != nil {
|
||||
return "", state.ConstellationState{}, fmt.Errorf("creating service account: %w", err)
|
||||
}
|
||||
|
||||
return serviceAccount, cl.GetState(), nil
|
||||
}
|
@ -1,89 +0,0 @@
|
||||
package cloudcmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestServiceAccountCreator(t *testing.T) {
|
||||
someAzureState := func() state.ConstellationState {
|
||||
return state.ConstellationState{
|
||||
CloudProvider: cloudprovider.Azure.String(),
|
||||
}
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
newGCPClient func(ctx context.Context) (gcpclient, error)
|
||||
newAzureClient func(subscriptionID, tenantID string) (azureclient, error)
|
||||
state state.ConstellationState
|
||||
config *config.Config
|
||||
wantErr bool
|
||||
wantStateMutator func(*state.ConstellationState)
|
||||
}{
|
||||
"azure": {
|
||||
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
|
||||
return &fakeAzureClient{}, nil
|
||||
},
|
||||
state: someAzureState(),
|
||||
config: config.Default(),
|
||||
wantStateMutator: func(stat *state.ConstellationState) {
|
||||
stat.AzureADAppObjectID = "00000000-0000-0000-0000-000000000001"
|
||||
},
|
||||
},
|
||||
"azure newAzureClient error": {
|
||||
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
|
||||
return nil, someErr
|
||||
},
|
||||
state: someAzureState(),
|
||||
config: config.Default(),
|
||||
wantErr: true,
|
||||
},
|
||||
"azure client createServiceAccount error": {
|
||||
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
|
||||
return &stubAzureClient{createServicePrincipalErr: someErr}, nil
|
||||
},
|
||||
state: someAzureState(),
|
||||
config: config.Default(),
|
||||
wantErr: true,
|
||||
},
|
||||
"qemu": {
|
||||
state: state.ConstellationState{CloudProvider: "qemu"},
|
||||
wantStateMutator: func(cs *state.ConstellationState) {},
|
||||
config: config.Default(),
|
||||
},
|
||||
"unknown cloud provider": {
|
||||
state: state.ConstellationState{},
|
||||
config: config.Default(),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
creator := &ServiceAccountCreator{
|
||||
newGCPClient: tc.newGCPClient,
|
||||
newAzureClient: tc.newAzureClient,
|
||||
}
|
||||
|
||||
serviceAccount, state, err := creator.Create(context.Background(), tc.state, tc.config)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.NotEmpty(serviceAccount)
|
||||
tc.wantStateMutator(&tc.state)
|
||||
assert.Equal(tc.state, state)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -72,8 +72,5 @@ func (t *Terminator) terminateGCP(ctx context.Context, cl gcpclient, state state
|
||||
func (t *Terminator) terminateAzure(ctx context.Context, cl azureclient, state state.ConstellationState) error {
|
||||
cl.SetState(state)
|
||||
|
||||
if err := cl.TerminateServicePrincipal(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return cl.TerminateResourceGroup(ctx)
|
||||
return cl.TerminateResourceGroupResources(ctx)
|
||||
}
|
||||
|
@ -41,7 +41,6 @@ func TestTerminator(t *testing.T) {
|
||||
AzureControlPlaneInstances: cloudtypes.Instances{
|
||||
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
AzureResourceGroup: "group",
|
||||
AzureADAppObjectID: "00000000-0000-0000-0000-000000000001",
|
||||
}
|
||||
}
|
||||
@ -88,13 +87,8 @@ func TestTerminator(t *testing.T) {
|
||||
state: someAzureState(),
|
||||
wantErr: true,
|
||||
},
|
||||
"azure terminateServicePrincipal error": {
|
||||
azureclient: &stubAzureClient{terminateServicePrincipalErr: someErr},
|
||||
state: someAzureState(),
|
||||
wantErr: true,
|
||||
},
|
||||
"azure terminateResourceGroup error": {
|
||||
azureclient: &stubAzureClient{terminateResourceGroupErr: someErr},
|
||||
"azure terminateResourceGroupResources error": {
|
||||
azureclient: &stubAzureClient{terminateResourceGroupResourcesErr: someErr},
|
||||
state: someAzureState(),
|
||||
wantErr: true,
|
||||
},
|
||||
@ -132,8 +126,7 @@ func TestTerminator(t *testing.T) {
|
||||
assert.True(cl.closeCalled)
|
||||
case cloudprovider.Azure:
|
||||
cl := tc.azureclient.(*stubAzureClient)
|
||||
assert.True(cl.terminateResourceGroupCalled)
|
||||
assert.True(cl.terminateServicePrincipalCalled)
|
||||
assert.True(cl.terminateResourceGroupResourcesCalled)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
@ -21,8 +21,3 @@ type cloudCreator interface {
|
||||
type cloudTerminator interface {
|
||||
Terminate(context.Context, state.ConstellationState) error
|
||||
}
|
||||
|
||||
type serviceAccountCreator interface {
|
||||
Create(ctx context.Context, stat state.ConstellationState, config *config.Config,
|
||||
) (string, state.ConstellationState, error)
|
||||
}
|
||||
|
@ -47,12 +47,3 @@ func (c *stubCloudTerminator) Terminate(context.Context, state.ConstellationStat
|
||||
func (c *stubCloudTerminator) Called() bool {
|
||||
return c.called
|
||||
}
|
||||
|
||||
type stubServiceAccountCreator struct {
|
||||
cloudServiceAccountURI string
|
||||
createErr error
|
||||
}
|
||||
|
||||
func (c *stubServiceAccountCreator) Create(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
|
||||
return c.cloudServiceAccountURI, stat, c.createErr
|
||||
}
|
||||
|
@ -55,17 +55,16 @@ func NewInitCmd() *cobra.Command {
|
||||
// runInitialize runs the initialize command.
|
||||
func runInitialize(cmd *cobra.Command, args []string) error {
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
|
||||
newDialer := func(validator *cloudcmd.Validator) *dialer.Dialer {
|
||||
return dialer.New(nil, validator.V(cmd), &net.Dialer{})
|
||||
}
|
||||
helmLoader := &helm.ChartLoader{}
|
||||
return initialize(cmd, newDialer, serviceAccountCreator, fileHandler, helmLoader, license.NewClient())
|
||||
return initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient())
|
||||
}
|
||||
|
||||
// initialize initializes a Constellation.
|
||||
func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer,
|
||||
serviceAccCreator serviceAccountCreator, fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker,
|
||||
fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker,
|
||||
) error {
|
||||
flags, err := evalFlagArgs(cmd, fileHandler)
|
||||
if err != nil {
|
||||
@ -105,22 +104,9 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
|
||||
return err
|
||||
}
|
||||
|
||||
var serviceAccURI string
|
||||
// Temporary legacy flow for Azure.
|
||||
if provider == cloudprovider.Azure {
|
||||
cmd.Println("Creating service account ...")
|
||||
serviceAccURI, stat, err = serviceAccCreator.Create(cmd.Context(), stat, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileHandler.WriteJSON(constants.StateFilename, stat, file.OptOverwrite); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
serviceAccURI, err = getMarschaledServiceAccountURI(provider, config, fileHandler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
serviceAccURI, err := getMarschaledServiceAccountURI(provider, config, fileHandler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
workers, err := getScalingGroupsFromState(stat, config)
|
||||
|
@ -69,49 +69,54 @@ func TestInitialize(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
state *state.ConstellationState
|
||||
existingIDFile *clusterIDsFile
|
||||
serviceAccCreator serviceAccountCreator
|
||||
configMutator func(*config.Config)
|
||||
serviceAccKey *gcpshared.ServiceAccountKey
|
||||
helmLoader stubHelmLoader
|
||||
initServerAPI *stubInitServer
|
||||
endpointFlag string
|
||||
setAutoscaleFlag bool
|
||||
wantErr bool
|
||||
state *state.ConstellationState
|
||||
idFile *clusterIDsFile
|
||||
configMutator func(*config.Config)
|
||||
serviceAccKey *gcpshared.ServiceAccountKey
|
||||
helmLoader stubHelmLoader
|
||||
initServerAPI *stubInitServer
|
||||
endpointFlag string
|
||||
setAutoscaleFlag bool
|
||||
wantErr bool
|
||||
}{
|
||||
"initialize some gcp instances": {
|
||||
state: testGcpState,
|
||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||
serviceAccKey: gcpServiceAccKey,
|
||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||
state: testGcpState,
|
||||
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||
serviceAccKey: gcpServiceAccKey,
|
||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||
},
|
||||
"initialize some azure instances": {
|
||||
state: testAzureState,
|
||||
serviceAccCreator: &stubServiceAccountCreator{},
|
||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||
state: testAzureState,
|
||||
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
configMutator: func(c *config.Config) {
|
||||
c.Provider.Azure.ResourceGroup = "resourceGroup"
|
||||
c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity"
|
||||
},
|
||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||
},
|
||||
"initialize some qemu instances": {
|
||||
state: testQemuState,
|
||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||
state: testQemuState,
|
||||
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||
},
|
||||
"initialize gcp with autoscaling": {
|
||||
state: testGcpState,
|
||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||
serviceAccKey: gcpServiceAccKey,
|
||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||
setAutoscaleFlag: true,
|
||||
},
|
||||
"initialize azure with autoscaling": {
|
||||
state: testAzureState,
|
||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
serviceAccCreator: &stubServiceAccountCreator{},
|
||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||
setAutoscaleFlag: true,
|
||||
state: testAzureState,
|
||||
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
configMutator: func(c *config.Config) {
|
||||
c.Provider.Azure.ResourceGroup = "resourceGroup"
|
||||
c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity"
|
||||
},
|
||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||
setAutoscaleFlag: true,
|
||||
},
|
||||
"initialize with endpoint flag": {
|
||||
state: testGcpState,
|
||||
@ -121,27 +126,30 @@ func TestInitialize(t *testing.T) {
|
||||
endpointFlag: "192.0.2.1",
|
||||
},
|
||||
"empty state": {
|
||||
state: &state.ConstellationState{},
|
||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
initServerAPI: &stubInitServer{},
|
||||
wantErr: true,
|
||||
state: &state.ConstellationState{},
|
||||
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
initServerAPI: &stubInitServer{},
|
||||
wantErr: true,
|
||||
},
|
||||
"neither endpoint flag nor id file": {
|
||||
state: &state.ConstellationState{},
|
||||
wantErr: true,
|
||||
},
|
||||
"init call fails": {
|
||||
state: testGcpState,
|
||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
initServerAPI: &stubInitServer{initErr: someErr},
|
||||
wantErr: true,
|
||||
state: testGcpState,
|
||||
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
initServerAPI: &stubInitServer{initErr: someErr},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail to create service account": {
|
||||
state: testAzureState,
|
||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
initServerAPI: &stubInitServer{},
|
||||
serviceAccCreator: &stubServiceAccountCreator{createErr: someErr},
|
||||
wantErr: true,
|
||||
state: testAzureState,
|
||||
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||
configMutator: func(c *config.Config) {
|
||||
c.Provider.Azure.ResourceGroup = "resourceGroup"
|
||||
c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity"
|
||||
},
|
||||
initServerAPI: &stubInitServer{},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail to load helm charts": {
|
||||
state: testGcpState,
|
||||
@ -194,8 +202,8 @@ func TestInitialize(t *testing.T) {
|
||||
if tc.state != nil {
|
||||
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.state, file.OptNone))
|
||||
}
|
||||
if tc.existingIDFile != nil {
|
||||
require.NoError(fileHandler.WriteJSON(constants.ClusterIDsFileName, tc.existingIDFile, file.OptNone))
|
||||
if tc.idFile != nil {
|
||||
require.NoError(fileHandler.WriteJSON(constants.ClusterIDsFileName, tc.idFile, file.OptNone))
|
||||
}
|
||||
if tc.serviceAccKey != nil {
|
||||
require.NoError(fileHandler.WriteJSON(serviceAccPath, tc.serviceAccKey, file.OptNone))
|
||||
@ -206,7 +214,7 @@ func TestInitialize(t *testing.T) {
|
||||
defer cancel()
|
||||
cmd.SetContext(ctx)
|
||||
|
||||
err := initialize(cmd, newDialer, tc.serviceAccCreator, fileHandler, &tc.helmLoader, &stubLicenseClient{})
|
||||
err := initialize(cmd, newDialer, fileHandler, &tc.helmLoader, &stubLicenseClient{})
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
@ -477,7 +485,7 @@ func TestAttestation(t *testing.T) {
|
||||
defer cancel()
|
||||
cmd.SetContext(ctx)
|
||||
|
||||
err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler, &stubHelmLoader{}, &stubLicenseClient{})
|
||||
err := initialize(cmd, newDialer, fileHandler, &stubHelmLoader{}, &stubLicenseClient{})
|
||||
assert.Error(err)
|
||||
// make sure the error is actually a TLS handshake error
|
||||
assert.Contains(err.Error(), "transport: authentication handshake failed")
|
||||
@ -548,6 +556,7 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
|
||||
conf.Provider.Azure.Location = "test-location"
|
||||
conf.Provider.Azure.UserAssignedIdentity = "test-identity"
|
||||
conf.Provider.Azure.Image = "some/image/location"
|
||||
conf.Provider.Azure.ResourceGroup = "test-resource-group"
|
||||
conf.Provider.Azure.Measurements[8] = []byte("00000000000000000000000000000000")
|
||||
conf.Provider.Azure.Measurements[9] = []byte("11111111111111111111111111111111")
|
||||
case cloudprovider.GCP:
|
||||
|
@ -158,6 +158,9 @@ type AzureConfig struct {
|
||||
// Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure
|
||||
UserAssignedIdentity string `yaml:"userAssignedIdentity" validate:"required"`
|
||||
// description: |
|
||||
// Resource group to use.
|
||||
ResourceGroup string `yaml:"resourceGroup" validate:"required"`
|
||||
// description: |
|
||||
// Use VMs with security type Confidential VM. If set to false, Trusted Launch VMs will be used instead. See: https://docs.microsoft.com/en-us/azure/confidential-computing/confidential-vm-overview
|
||||
ConfidentialVM *bool `yaml:"confidentialVM" validate:"required"`
|
||||
}
|
||||
@ -244,6 +247,7 @@ func Default() *Config {
|
||||
TenantID: "",
|
||||
Location: "",
|
||||
UserAssignedIdentity: "",
|
||||
ResourceGroup: "",
|
||||
Image: DefaultImageAzure,
|
||||
StateDiskType: "Premium_LRS",
|
||||
Measurements: copyPCRMap(azurePCRs),
|
||||
|
@ -199,7 +199,7 @@ func init() {
|
||||
FieldName: "azure",
|
||||
},
|
||||
}
|
||||
AzureConfigDoc.Fields = make([]encoder.Doc, 9)
|
||||
AzureConfigDoc.Fields = make([]encoder.Doc, 10)
|
||||
AzureConfigDoc.Fields[0].Name = "subscription"
|
||||
AzureConfigDoc.Fields[0].Type = "string"
|
||||
AzureConfigDoc.Fields[0].Note = ""
|
||||
@ -240,6 +240,11 @@ func init() {
|
||||
AzureConfigDoc.Fields[7].Note = ""
|
||||
AzureConfigDoc.Fields[7].Description = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure"
|
||||
AzureConfigDoc.Fields[7].Comments[encoder.LineComment] = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure"
|
||||
AzureConfigDoc.Fields[8].Name = "resourceGroup"
|
||||
AzureConfigDoc.Fields[8].Type = "string"
|
||||
AzureConfigDoc.Fields[8].Note = ""
|
||||
AzureConfigDoc.Fields[8].Description = "Resource group to use."
|
||||
AzureConfigDoc.Fields[8].Comments[encoder.LineComment] = "Resource group to use."
|
||||
AzureConfigDoc.Fields[8].Name = "confidentialVM"
|
||||
AzureConfigDoc.Fields[8].Type = "bool"
|
||||
AzureConfigDoc.Fields[8].Note = ""
|
||||
|
Loading…
x
Reference in New Issue
Block a user