mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-06-07 14:12:57 -04:00
Manually manage resource group on Azure
This commit is contained in:
parent
e6ae54a25a
commit
f15605cb45
25 changed files with 403 additions and 1162 deletions
|
@ -1,184 +0,0 @@
|
||||||
package client
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"math/big"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
|
|
||||||
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
|
|
||||||
"github.com/Azure/go-autorest/autorest"
|
|
||||||
"github.com/Azure/go-autorest/autorest/azure"
|
|
||||||
"github.com/Azure/go-autorest/autorest/date"
|
|
||||||
"github.com/Azure/go-autorest/autorest/to"
|
|
||||||
"github.com/edgelesssys/constellation/internal/azureshared"
|
|
||||||
"github.com/google/uuid"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
adAppCredentialValidity = time.Hour * 24 * 365 * 5 // ~5 years
|
|
||||||
adReplicationLagCheckInterval = time.Second * 5 // 5 seconds
|
|
||||||
adReplicationLagCheckMaxRetries = int((15 * time.Minute) / adReplicationLagCheckInterval) // wait for up to 15 minutes for AD replication
|
|
||||||
ownerRoleDefinitionID = "8e3af657-a8ff-443c-a75c-2fe8c4bcb635" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#owner
|
|
||||||
)
|
|
||||||
|
|
||||||
// CreateServicePrincipal creates an Azure AD app with a service principal, gives it "Owner" role on the resource group and creates new credentials.
|
|
||||||
func (c *Client) CreateServicePrincipal(ctx context.Context) (string, error) {
|
|
||||||
createAppRes, err := c.createADApplication(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
c.adAppObjectID = createAppRes.ObjectID
|
|
||||||
servicePrincipalObjectID, err := c.createAppServicePrincipal(ctx, createAppRes.AppID)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.assignResourceGroupRole(ctx, servicePrincipalObjectID, ownerRoleDefinitionID); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
clientSecret, err := c.updateAppCredentials(ctx, createAppRes.ObjectID)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
return azureshared.ApplicationCredentials{
|
|
||||||
TenantID: c.tenantID,
|
|
||||||
ClientID: createAppRes.AppID,
|
|
||||||
ClientSecret: clientSecret,
|
|
||||||
Location: c.location,
|
|
||||||
}.ToCloudServiceAccountURI(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TerminateServicePrincipal terminates an Azure AD app together with the service principal.
|
|
||||||
func (c *Client) TerminateServicePrincipal(ctx context.Context) error {
|
|
||||||
if c.adAppObjectID == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if _, err := c.applicationsAPI.Delete(ctx, c.adAppObjectID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.adAppObjectID = ""
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// createADApplication creates a new azure AD app.
|
|
||||||
func (c *Client) createADApplication(ctx context.Context) (createADApplicationOutput, error) {
|
|
||||||
createParameters := graphrbac.ApplicationCreateParameters{
|
|
||||||
AvailableToOtherTenants: to.BoolPtr(false),
|
|
||||||
DisplayName: to.StringPtr("constellation-app-" + c.name + "-" + c.uid),
|
|
||||||
}
|
|
||||||
app, err := c.applicationsAPI.Create(ctx, createParameters)
|
|
||||||
if err != nil {
|
|
||||||
return createADApplicationOutput{}, err
|
|
||||||
}
|
|
||||||
if app.AppID == nil || app.ObjectID == nil {
|
|
||||||
return createADApplicationOutput{}, errors.New("creating AD application did not result in valid app id and object id")
|
|
||||||
}
|
|
||||||
return createADApplicationOutput{
|
|
||||||
AppID: *app.AppID,
|
|
||||||
ObjectID: *app.ObjectID,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// createAppServicePrincipal creates a new service principal for an azure AD app.
|
|
||||||
func (c *Client) createAppServicePrincipal(ctx context.Context, appID string) (string, error) {
|
|
||||||
createParameters := graphrbac.ServicePrincipalCreateParameters{
|
|
||||||
AppID: &appID,
|
|
||||||
AccountEnabled: to.BoolPtr(true),
|
|
||||||
}
|
|
||||||
servicePrincipal, err := c.servicePrincipalsAPI.Create(ctx, createParameters)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
if servicePrincipal.ObjectID == nil {
|
|
||||||
return "", errors.New("creating AD service principal did not result in a valid object id")
|
|
||||||
}
|
|
||||||
return *servicePrincipal.ObjectID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateAppCredentials sets app client-secret for authentication.
|
|
||||||
func (c *Client) updateAppCredentials(ctx context.Context, objectID string) (string, error) {
|
|
||||||
keyID := uuid.New().String()
|
|
||||||
clientSecret, err := generateClientSecret()
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("generating client secret: %w", err)
|
|
||||||
}
|
|
||||||
updateParameters := graphrbac.PasswordCredentialsUpdateParameters{
|
|
||||||
Value: &[]graphrbac.PasswordCredential{
|
|
||||||
{
|
|
||||||
StartDate: &date.Time{Time: time.Now()},
|
|
||||||
EndDate: &date.Time{Time: time.Now().Add(adAppCredentialValidity)},
|
|
||||||
Value: to.StringPtr(clientSecret),
|
|
||||||
KeyID: to.StringPtr(keyID),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
_, err = c.applicationsAPI.UpdatePasswordCredentials(ctx, objectID, updateParameters)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return clientSecret, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// assignResourceGroupRole assigns the service principal a role at resource group scope.
|
|
||||||
func (c *Client) assignResourceGroupRole(ctx context.Context, principalID, roleDefinitionID string) error {
|
|
||||||
resourceGroup, err := c.resourceGroupAPI.Get(ctx, c.resourceGroup, nil)
|
|
||||||
if err != nil || resourceGroup.ID == nil {
|
|
||||||
return fmt.Errorf("unable to retrieve resource group id for group %v: %w", c.resourceGroup, err)
|
|
||||||
}
|
|
||||||
roleAssignmentID := uuid.New().String()
|
|
||||||
createParameters := authorization.RoleAssignmentCreateParameters{
|
|
||||||
Properties: &authorization.RoleAssignmentProperties{
|
|
||||||
PrincipalID: to.StringPtr(principalID),
|
|
||||||
RoleDefinitionID: to.StringPtr(fmt.Sprintf("/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s", c.subscriptionID, roleDefinitionID)),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// due to an azure AD replication lag, retry role assignment if principal does not exist yet
|
|
||||||
// reference: https://docs.microsoft.com/en-us/azure/role-based-access-control/role-assignments-rest#new-service-principal
|
|
||||||
// proper fix: use API version 2018-09-01-preview or later
|
|
||||||
// azure go sdk currently uses version 2015-07-01: https://github.com/Azure/azure-sdk-for-go/blob/v62.0.0/services/authorization/mgmt/2015-07-01/authorization/roleassignments.go#L95
|
|
||||||
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
|
|
||||||
for i := 0; i < c.adReplicationLagCheckMaxRetries; i++ {
|
|
||||||
_, err = c.roleAssignmentsAPI.Create(ctx, *resourceGroup.ID, roleAssignmentID, createParameters)
|
|
||||||
var detailedErr autorest.DetailedError
|
|
||||||
var ok bool
|
|
||||||
if detailedErr, ok = err.(autorest.DetailedError); !ok {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var requestErr *azure.RequestError
|
|
||||||
if requestErr, ok = detailedErr.Original.(*azure.RequestError); !ok || requestErr.ServiceError == nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if requestErr.ServiceError.Code != "PrincipalNotFound" {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
time.Sleep(c.adReplicationLagCheckInterval)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
type createADApplicationOutput struct {
|
|
||||||
AppID string
|
|
||||||
ObjectID string
|
|
||||||
}
|
|
||||||
|
|
||||||
func generateClientSecret() (string, error) {
|
|
||||||
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
|
|
||||||
|
|
||||||
pwLen := 64
|
|
||||||
pw := make([]byte, 0, pwLen)
|
|
||||||
for i := 0; i < pwLen; i++ {
|
|
||||||
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
pw = append(pw, letters[n.Int64()])
|
|
||||||
}
|
|
||||||
return string(pw), nil
|
|
||||||
}
|
|
|
@ -1,358 +0,0 @@
|
||||||
package client
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
|
||||||
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
|
|
||||||
"github.com/Azure/go-autorest/autorest"
|
|
||||||
"github.com/Azure/go-autorest/autorest/azure"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCreateServicePrincipal(t *testing.T) {
|
|
||||||
someErr := errors.New("failed")
|
|
||||||
testCases := map[string]struct {
|
|
||||||
applicationsAPI applicationsAPI
|
|
||||||
servicePrincipalsAPI servicePrincipalsAPI
|
|
||||||
roleAssignmentsAPI roleAssignmentsAPI
|
|
||||||
resourceGroupAPI resourceGroupAPI
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
"successful create": {
|
|
||||||
applicationsAPI: stubApplicationsAPI{},
|
|
||||||
servicePrincipalsAPI: stubServicePrincipalsAPI{},
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{
|
|
||||||
getResourceGroup: armresources.ResourceGroup{
|
|
||||||
ID: to.Ptr("resource-group-id"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"failed app create": {
|
|
||||||
applicationsAPI: stubApplicationsAPI{
|
|
||||||
createErr: someErr,
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"failed service principal create": {
|
|
||||||
applicationsAPI: stubApplicationsAPI{},
|
|
||||||
servicePrincipalsAPI: stubServicePrincipalsAPI{
|
|
||||||
createErr: someErr,
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"failed role assignment": {
|
|
||||||
applicationsAPI: stubApplicationsAPI{},
|
|
||||||
servicePrincipalsAPI: stubServicePrincipalsAPI{},
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
|
||||||
createErrors: []error{someErr},
|
|
||||||
},
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{
|
|
||||||
getResourceGroup: armresources.ResourceGroup{
|
|
||||||
ID: to.Ptr("resource-group-id"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"failed update creds": {
|
|
||||||
applicationsAPI: stubApplicationsAPI{
|
|
||||||
updateCredentialsErr: someErr,
|
|
||||||
},
|
|
||||||
servicePrincipalsAPI: stubServicePrincipalsAPI{},
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{
|
|
||||||
getResourceGroup: armresources.ResourceGroup{
|
|
||||||
ID: to.Ptr("resource-group-id"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
client := Client{
|
|
||||||
name: "name",
|
|
||||||
uid: "uid",
|
|
||||||
resourceGroup: "resource-group",
|
|
||||||
applicationsAPI: tc.applicationsAPI,
|
|
||||||
servicePrincipalsAPI: tc.servicePrincipalsAPI,
|
|
||||||
roleAssignmentsAPI: tc.roleAssignmentsAPI,
|
|
||||||
resourceGroupAPI: tc.resourceGroupAPI,
|
|
||||||
adReplicationLagCheckMaxRetries: 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := client.CreateServicePrincipal(ctx)
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.NoError(err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTerminateServicePrincipal(t *testing.T) {
|
|
||||||
someErr := errors.New("failed")
|
|
||||||
testCases := map[string]struct {
|
|
||||||
appObjectID string
|
|
||||||
applicationsAPI applicationsAPI
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
"successful terminate": {
|
|
||||||
appObjectID: "object-id",
|
|
||||||
applicationsAPI: stubApplicationsAPI{},
|
|
||||||
},
|
|
||||||
"nothing to terminate": {
|
|
||||||
applicationsAPI: stubApplicationsAPI{},
|
|
||||||
},
|
|
||||||
"failed delete": {
|
|
||||||
appObjectID: "object-id",
|
|
||||||
applicationsAPI: stubApplicationsAPI{
|
|
||||||
deleteErr: someErr,
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
client := Client{
|
|
||||||
name: "name",
|
|
||||||
uid: "uid",
|
|
||||||
resourceGroup: "resource-group",
|
|
||||||
adAppObjectID: tc.appObjectID,
|
|
||||||
applicationsAPI: tc.applicationsAPI,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := client.TerminateServicePrincipal(ctx)
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.NoError(err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateADApplication(t *testing.T) {
|
|
||||||
someErr := errors.New("failed")
|
|
||||||
testCases := map[string]struct {
|
|
||||||
applicationsAPI applicationsAPI
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
"successful create": {
|
|
||||||
applicationsAPI: stubApplicationsAPI{},
|
|
||||||
},
|
|
||||||
"failed app create": {
|
|
||||||
applicationsAPI: stubApplicationsAPI{
|
|
||||||
createErr: someErr,
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"app create returns invalid appid": {
|
|
||||||
applicationsAPI: stubApplicationsAPI{
|
|
||||||
createApplication: &graphrbac.Application{
|
|
||||||
ObjectID: proto.String("00000000-0000-0000-0000-000000000001"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"app create returns invalid objectid": {
|
|
||||||
applicationsAPI: stubApplicationsAPI{
|
|
||||||
createApplication: &graphrbac.Application{
|
|
||||||
AppID: proto.String("00000000-0000-0000-0000-000000000000"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
client := Client{
|
|
||||||
name: "name",
|
|
||||||
uid: "uid",
|
|
||||||
applicationsAPI: tc.applicationsAPI,
|
|
||||||
}
|
|
||||||
|
|
||||||
appCredentials, err := client.createADApplication(ctx)
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.NoError(err)
|
|
||||||
assert.NotNil(appCredentials)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateAppServicePrincipal(t *testing.T) {
|
|
||||||
someErr := errors.New("failed")
|
|
||||||
testCases := map[string]struct {
|
|
||||||
servicePrincipalsAPI servicePrincipalsAPI
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
"successful create": {
|
|
||||||
servicePrincipalsAPI: stubServicePrincipalsAPI{},
|
|
||||||
},
|
|
||||||
"failed service principal create": {
|
|
||||||
servicePrincipalsAPI: stubServicePrincipalsAPI{
|
|
||||||
createErr: someErr,
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"service principal create returns invalid objectid": {
|
|
||||||
servicePrincipalsAPI: stubServicePrincipalsAPI{
|
|
||||||
createServicePrincipal: &graphrbac.ServicePrincipal{},
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
client := Client{
|
|
||||||
name: "name",
|
|
||||||
uid: "uid",
|
|
||||||
servicePrincipalsAPI: tc.servicePrincipalsAPI,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := client.createAppServicePrincipal(ctx, "app-id")
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.NoError(err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAssignOwnerOfResourceGroup(t *testing.T) {
|
|
||||||
someErr := errors.New("failed")
|
|
||||||
testCases := map[string]struct {
|
|
||||||
roleAssignmentsAPI roleAssignmentsAPI
|
|
||||||
resourceGroupAPI resourceGroupAPI
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
"successful assign": {
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{
|
|
||||||
getResourceGroup: armresources.ResourceGroup{
|
|
||||||
ID: to.Ptr("resource-group-id"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"failed role assignment": {
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
|
||||||
createErrors: []error{someErr},
|
|
||||||
},
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{
|
|
||||||
getResourceGroup: armresources.ResourceGroup{
|
|
||||||
ID: to.Ptr("resource-group-id"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"failed resource group get": {
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{
|
|
||||||
getErr: someErr,
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"resource group get returns invalid id": {
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{
|
|
||||||
getResourceGroup: armresources.ResourceGroup{},
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"create returns PrincipalNotFound the first time": {
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
|
||||||
createErrors: []error{
|
|
||||||
autorest.DetailedError{Original: &azure.RequestError{
|
|
||||||
ServiceError: &azure.ServiceError{
|
|
||||||
Code: "PrincipalNotFound",
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{
|
|
||||||
getResourceGroup: armresources.ResourceGroup{
|
|
||||||
ID: to.Ptr("resource-group-id"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"create does not return request error": {
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
|
||||||
createErrors: []error{autorest.DetailedError{Original: someErr}},
|
|
||||||
},
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{
|
|
||||||
getResourceGroup: armresources.ResourceGroup{
|
|
||||||
ID: to.Ptr("resource-group-id"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"create service error code is unknown": {
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
|
|
||||||
createErrors: []error{
|
|
||||||
autorest.DetailedError{Original: &azure.RequestError{
|
|
||||||
ServiceError: &azure.ServiceError{
|
|
||||||
Code: "some-unknown-error-code",
|
|
||||||
},
|
|
||||||
}},
|
|
||||||
nil,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{
|
|
||||||
getResourceGroup: armresources.ResourceGroup{
|
|
||||||
ID: to.Ptr("resource-group-id"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
client := Client{
|
|
||||||
name: "name",
|
|
||||||
uid: "uid",
|
|
||||||
resourceGroup: "resource-group",
|
|
||||||
roleAssignmentsAPI: tc.roleAssignmentsAPI,
|
|
||||||
resourceGroupAPI: tc.resourceGroupAPI,
|
|
||||||
adReplicationLagCheckMaxRetries: 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
err := client.assignResourceGroupRole(ctx, "principal-id", "role-definition-id")
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.NoError(err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -64,15 +64,13 @@ type networkInterfacesAPI interface {
|
||||||
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error)
|
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type resourceGroupAPI interface {
|
type resourceAPI interface {
|
||||||
CreateOrUpdate(ctx context.Context, resourceGroupName string,
|
NewListByResourceGroupPager(resourceGroupName string,
|
||||||
parameters armresources.ResourceGroup,
|
options *armresources.ClientListByResourceGroupOptions,
|
||||||
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
|
) *runtime.Pager[armresources.ClientListByResourceGroupResponse]
|
||||||
armresources.ResourceGroupsClientCreateOrUpdateResponse, error)
|
BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string,
|
||||||
BeginDelete(ctx context.Context, resourceGroupName string,
|
options *armresources.ClientBeginDeleteByIDOptions,
|
||||||
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
|
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error)
|
||||||
*runtime.Poller[armresources.ResourceGroupsClientDeleteResponse], error)
|
|
||||||
Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type applicationsAPI interface {
|
type applicationsAPI interface {
|
||||||
|
|
|
@ -4,15 +4,11 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
|
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
|
||||||
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
|
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
|
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
|
||||||
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
|
|
||||||
"github.com/Azure/go-autorest/autorest"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type stubNetworksAPI struct {
|
type stubNetworksAPI struct {
|
||||||
|
@ -94,44 +90,6 @@ func (a stubNetworkSecurityGroupsAPI) BeginCreateOrUpdate(ctx context.Context, r
|
||||||
return poller, a.createErr
|
return poller, a.createErr
|
||||||
}
|
}
|
||||||
|
|
||||||
type stubResourceGroupAPI struct {
|
|
||||||
terminateErr error
|
|
||||||
createErr error
|
|
||||||
getErr error
|
|
||||||
getResourceGroup armresources.ResourceGroup
|
|
||||||
pollErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a stubResourceGroupAPI) CreateOrUpdate(ctx context.Context, resourceGroupName string,
|
|
||||||
parameters armresources.ResourceGroup,
|
|
||||||
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
|
|
||||||
armresources.ResourceGroupsClientCreateOrUpdateResponse, error,
|
|
||||||
) {
|
|
||||||
return armresources.ResourceGroupsClientCreateOrUpdateResponse{}, a.createErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a stubResourceGroupAPI) Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error) {
|
|
||||||
return armresources.ResourceGroupsClientGetResponse{
|
|
||||||
ResourceGroup: a.getResourceGroup,
|
|
||||||
}, a.getErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a stubResourceGroupAPI) BeginDelete(ctx context.Context, resourceGroupName string,
|
|
||||||
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
|
|
||||||
*runtime.Poller[armresources.ResourceGroupsClientDeleteResponse], error,
|
|
||||||
) {
|
|
||||||
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ResourceGroupsClientDeleteResponse]{
|
|
||||||
Handler: &stubPoller[armresources.ResourceGroupsClientDeleteResponse]{
|
|
||||||
result: armresources.ResourceGroupsClientDeleteResponse{},
|
|
||||||
resultErr: a.pollErr,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return poller, a.terminateErr
|
|
||||||
}
|
|
||||||
|
|
||||||
type stubScaleSetsAPI struct {
|
type stubScaleSetsAPI struct {
|
||||||
createErr error
|
createErr error
|
||||||
stubResponse armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse
|
stubResponse armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse
|
||||||
|
@ -282,71 +240,6 @@ func (a stubNetworkInterfacesAPI) GetVirtualMachineScaleSetNetworkInterface(ctx
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type stubApplicationsAPI struct {
|
|
||||||
createErr error
|
|
||||||
deleteErr error
|
|
||||||
updateCredentialsErr error
|
|
||||||
createApplication *graphrbac.Application
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a stubApplicationsAPI) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
|
|
||||||
if a.createErr != nil {
|
|
||||||
return graphrbac.Application{}, a.createErr
|
|
||||||
}
|
|
||||||
if a.createApplication != nil {
|
|
||||||
return *a.createApplication, nil
|
|
||||||
}
|
|
||||||
return graphrbac.Application{
|
|
||||||
AppID: to.Ptr("00000000-0000-0000-0000-000000000000"),
|
|
||||||
ObjectID: to.Ptr("00000000-0000-0000-0000-000000000001"),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a stubApplicationsAPI) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
|
|
||||||
if a.deleteErr != nil {
|
|
||||||
return autorest.Response{}, a.deleteErr
|
|
||||||
}
|
|
||||||
return autorest.Response{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a stubApplicationsAPI) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
|
|
||||||
if a.updateCredentialsErr != nil {
|
|
||||||
return autorest.Response{}, a.updateCredentialsErr
|
|
||||||
}
|
|
||||||
return autorest.Response{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type stubServicePrincipalsAPI struct {
|
|
||||||
createErr error
|
|
||||||
createServicePrincipal *graphrbac.ServicePrincipal
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a stubServicePrincipalsAPI) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
|
|
||||||
if a.createErr != nil {
|
|
||||||
return graphrbac.ServicePrincipal{}, a.createErr
|
|
||||||
}
|
|
||||||
if a.createServicePrincipal != nil {
|
|
||||||
return *a.createServicePrincipal, nil
|
|
||||||
}
|
|
||||||
return graphrbac.ServicePrincipal{
|
|
||||||
AppID: to.Ptr("00000000-0000-0000-0000-000000000000"),
|
|
||||||
ObjectID: to.Ptr("00000000-0000-0000-0000-000000000002"),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type stubRoleAssignmentsAPI struct {
|
|
||||||
createCounter int
|
|
||||||
createErrors []error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *stubRoleAssignmentsAPI) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
|
|
||||||
a.createCounter++
|
|
||||||
if len(a.createErrors) == 0 {
|
|
||||||
return authorization.RoleAssignment{}, nil
|
|
||||||
}
|
|
||||||
return authorization.RoleAssignment{}, a.createErrors[(a.createCounter-1)%len(a.createErrors)]
|
|
||||||
}
|
|
||||||
|
|
||||||
type stubApplicationInsightsAPI struct {
|
type stubApplicationInsightsAPI struct {
|
||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,7 +29,7 @@ const (
|
||||||
type Client struct {
|
type Client struct {
|
||||||
networksAPI
|
networksAPI
|
||||||
networkSecurityGroupsAPI
|
networkSecurityGroupsAPI
|
||||||
resourceGroupAPI
|
resourceAPI
|
||||||
scaleSetsAPI
|
scaleSetsAPI
|
||||||
publicIPAddressesAPI
|
publicIPAddressesAPI
|
||||||
networkInterfacesAPI
|
networkInterfacesAPI
|
||||||
|
@ -40,8 +40,6 @@ type Client struct {
|
||||||
applicationInsightsAPI
|
applicationInsightsAPI
|
||||||
|
|
||||||
pollFrequency time.Duration
|
pollFrequency time.Duration
|
||||||
adReplicationLagCheckInterval time.Duration
|
|
||||||
adReplicationLagCheckMaxRetries int
|
|
||||||
|
|
||||||
workers cloudtypes.Instances
|
workers cloudtypes.Instances
|
||||||
controlPlanes cloudtypes.Instances
|
controlPlanes cloudtypes.Instances
|
||||||
|
@ -83,10 +81,6 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
resGroupAPI, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
scaleSetAPI, err := armcomputev2.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
|
scaleSetAPI, err := armcomputev2.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -107,6 +101,10 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
resourceAPI, err := armresources.NewClient(subscriptionID, cred, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
applicationsAPI := graphrbac.NewApplicationsClient(tenantID)
|
applicationsAPI := graphrbac.NewApplicationsClient(tenantID)
|
||||||
applicationsAPI.Authorizer = graphAuthorizer
|
applicationsAPI.Authorizer = graphAuthorizer
|
||||||
servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID)
|
servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID)
|
||||||
|
@ -117,7 +115,7 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
|
||||||
return &Client{
|
return &Client{
|
||||||
networksAPI: netAPI,
|
networksAPI: netAPI,
|
||||||
networkSecurityGroupsAPI: netSecGrpAPI,
|
networkSecurityGroupsAPI: netSecGrpAPI,
|
||||||
resourceGroupAPI: resGroupAPI,
|
resourceAPI: resourceAPI,
|
||||||
scaleSetsAPI: scaleSetAPI,
|
scaleSetsAPI: scaleSetAPI,
|
||||||
publicIPAddressesAPI: publicIPAddressesAPI,
|
publicIPAddressesAPI: publicIPAddressesAPI,
|
||||||
networkInterfacesAPI: networkInterfacesAPI,
|
networkInterfacesAPI: networkInterfacesAPI,
|
||||||
|
@ -131,26 +129,25 @@ func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
|
||||||
workers: cloudtypes.Instances{},
|
workers: cloudtypes.Instances{},
|
||||||
controlPlanes: cloudtypes.Instances{},
|
controlPlanes: cloudtypes.Instances{},
|
||||||
pollFrequency: time.Second * 5,
|
pollFrequency: time.Second * 5,
|
||||||
adReplicationLagCheckInterval: adReplicationLagCheckInterval,
|
|
||||||
adReplicationLagCheckMaxRetries: adReplicationLagCheckMaxRetries,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewInitialized creates and initializes client by setting the subscriptionID, location and name
|
// NewInitialized creates and initializes client by setting the subscriptionID, location and name
|
||||||
// of the Constellation.
|
// of the Constellation.
|
||||||
func NewInitialized(subscriptionID, tenantID, name, location string) (*Client, error) {
|
func NewInitialized(subscriptionID, tenantID, name, location, resourceGroup string) (*Client, error) {
|
||||||
client, err := NewFromDefault(subscriptionID, tenantID)
|
client, err := NewFromDefault(subscriptionID, tenantID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = client.init(location, name)
|
err = client.init(location, name, resourceGroup)
|
||||||
return client, err
|
return client, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// init initializes the client.
|
// init initializes the client.
|
||||||
func (c *Client) init(location, name string) error {
|
func (c *Client) init(location, name, resourceGroup string) error {
|
||||||
c.location = location
|
c.location = location
|
||||||
c.name = name
|
c.name = name
|
||||||
|
c.resourceGroup = resourceGroup
|
||||||
uid, err := c.generateUID()
|
uid, err := c.generateUID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -84,8 +84,9 @@ func TestInit(t *testing.T) {
|
||||||
require := require.New(t)
|
require := require.New(t)
|
||||||
|
|
||||||
client := Client{}
|
client := Client{}
|
||||||
require.NoError(client.init("location", "name"))
|
require.NoError(client.init("location", "name", "rGroup"))
|
||||||
assert.Equal("location", client.location)
|
assert.Equal("location", client.location)
|
||||||
assert.Equal("name", client.name)
|
assert.Equal("name", client.name)
|
||||||
|
assert.Equal("rGroup", client.resourceGroup)
|
||||||
assert.NotEmpty(client.uid)
|
assert.NotEmpty(client.uid)
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,8 +10,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
|
||||||
"github.com/edgelesssys/constellation/cli/internal/azure"
|
"github.com/edgelesssys/constellation/cli/internal/azure"
|
||||||
"github.com/edgelesssys/constellation/cli/internal/azure/internal/poller"
|
"github.com/edgelesssys/constellation/cli/internal/azure/internal/poller"
|
||||||
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
|
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
|
||||||
|
@ -213,45 +211,6 @@ type CreateScaleSetInput struct {
|
||||||
ConfidentialVM bool
|
ConfidentialVM bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateResourceGroup creates a resource group.
|
|
||||||
func (c *Client) CreateResourceGroup(ctx context.Context) error {
|
|
||||||
_, err := c.resourceGroupAPI.CreateOrUpdate(ctx, c.name+"-"+c.uid,
|
|
||||||
armresources.ResourceGroup{
|
|
||||||
Location: &c.location,
|
|
||||||
}, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.resourceGroup = c.name + "-" + c.uid
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TerminateResourceGroup terminates a resource group.
|
|
||||||
func (c *Client) TerminateResourceGroup(ctx context.Context) error {
|
|
||||||
if c.resourceGroup == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
poller, err := c.resourceGroupAPI.BeginDelete(ctx, c.resourceGroup, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err = poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{
|
|
||||||
Frequency: c.pollFrequency,
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
c.workers = nil
|
|
||||||
c.controlPlanes = nil
|
|
||||||
c.resourceGroup = ""
|
|
||||||
c.subnetID = ""
|
|
||||||
c.networkSecurityGroup = ""
|
|
||||||
c.workerScaleSet = ""
|
|
||||||
c.controlPlaneScaleSet = ""
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// scaleSetCreationPollingHandler is a custom poller used to check if a scale set was created successfully.
|
// scaleSetCreationPollingHandler is a custom poller used to check if a scale set was created successfully.
|
||||||
type scaleSetCreationPollingHandler struct {
|
type scaleSetCreationPollingHandler struct {
|
||||||
done bool
|
done bool
|
||||||
|
|
|
@ -7,126 +7,16 @@ import (
|
||||||
|
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
|
||||||
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
|
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
|
||||||
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
|
||||||
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
|
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCreateResourceGroup(t *testing.T) {
|
|
||||||
someErr := errors.New("failed")
|
|
||||||
testCases := map[string]struct {
|
|
||||||
resourceGroupAPI resourceGroupAPI
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
"successful create": {
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{},
|
|
||||||
},
|
|
||||||
"failed create": {
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{createErr: someErr},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
client := Client{
|
|
||||||
location: "location",
|
|
||||||
name: "name",
|
|
||||||
uid: "uid",
|
|
||||||
resourceGroupAPI: tc.resourceGroupAPI,
|
|
||||||
workers: make(cloudtypes.Instances),
|
|
||||||
controlPlanes: make(cloudtypes.Instances),
|
|
||||||
}
|
|
||||||
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(client.CreateResourceGroup(ctx))
|
|
||||||
} else {
|
|
||||||
assert.NoError(client.CreateResourceGroup(ctx))
|
|
||||||
assert.Equal(client.name+"-"+client.uid, client.resourceGroup)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTerminateResourceGroup(t *testing.T) {
|
|
||||||
someErr := errors.New("failed")
|
|
||||||
clientWithResourceGroup := Client{
|
|
||||||
resourceGroup: "name",
|
|
||||||
location: "location",
|
|
||||||
name: "name",
|
|
||||||
uid: "uid",
|
|
||||||
subnetID: "subnet",
|
|
||||||
workerScaleSet: "node-scale-set",
|
|
||||||
controlPlaneScaleSet: "controlplane-scale-set",
|
|
||||||
workers: cloudtypes.Instances{
|
|
||||||
"0": {
|
|
||||||
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
controlPlanes: cloudtypes.Instances{
|
|
||||||
"0": {
|
|
||||||
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
testCases := map[string]struct {
|
|
||||||
resourceGroup string
|
|
||||||
resourceGroupAPI resourceGroupAPI
|
|
||||||
client Client
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
"successful terminate": {
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{},
|
|
||||||
client: clientWithResourceGroup,
|
|
||||||
},
|
|
||||||
"no resource group to terminate": {
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{},
|
|
||||||
client: Client{},
|
|
||||||
resourceGroup: "",
|
|
||||||
},
|
|
||||||
"failed terminate": {
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{terminateErr: someErr},
|
|
||||||
client: clientWithResourceGroup,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"failed to poll terminate response": {
|
|
||||||
resourceGroupAPI: stubResourceGroupAPI{pollErr: someErr},
|
|
||||||
client: clientWithResourceGroup,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
tc.client.resourceGroupAPI = tc.resourceGroupAPI
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(tc.client.TerminateResourceGroup(ctx))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.NoError(tc.client.TerminateResourceGroup(ctx))
|
|
||||||
assert.Empty(tc.client.resourceGroup)
|
|
||||||
assert.Empty(tc.client.subnetID)
|
|
||||||
assert.Empty(tc.client.workers)
|
|
||||||
assert.Empty(tc.client.controlPlanes)
|
|
||||||
assert.Empty(tc.client.workerScaleSet)
|
|
||||||
assert.Empty(tc.client.controlPlaneScaleSet)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateInstances(t *testing.T) {
|
func TestCreateInstances(t *testing.T) {
|
||||||
someErr := errors.New("failed")
|
someErr := errors.New("failed")
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
publicIPAddressesAPI publicIPAddressesAPI
|
publicIPAddressesAPI publicIPAddressesAPI
|
||||||
networkInterfacesAPI networkInterfacesAPI
|
networkInterfacesAPI networkInterfacesAPI
|
||||||
scaleSetsAPI scaleSetsAPI
|
scaleSetsAPI scaleSetsAPI
|
||||||
resourceGroupAPI resourceGroupAPI
|
|
||||||
roleAssignmentsAPI roleAssignmentsAPI
|
|
||||||
createInstancesInput CreateInstancesInput
|
createInstancesInput CreateInstancesInput
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
|
@ -138,8 +28,6 @@ func TestCreateInstances(t *testing.T) {
|
||||||
VirtualMachineScaleSet: armcomputev2.VirtualMachineScaleSet{Identity: &armcomputev2.VirtualMachineScaleSetIdentity{PrincipalID: to.Ptr("principal-id")}},
|
VirtualMachineScaleSet: armcomputev2.VirtualMachineScaleSet{Identity: &armcomputev2.VirtualMachineScaleSetIdentity{PrincipalID: to.Ptr("principal-id")}},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
|
||||||
createInstancesInput: CreateInstancesInput{
|
createInstancesInput: CreateInstancesInput{
|
||||||
CountControlPlanes: 3,
|
CountControlPlanes: 3,
|
||||||
CountWorkers: 3,
|
CountWorkers: 3,
|
||||||
|
@ -153,8 +41,6 @@ func TestCreateInstances(t *testing.T) {
|
||||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||||
scaleSetsAPI: stubScaleSetsAPI{createErr: someErr},
|
scaleSetsAPI: stubScaleSetsAPI{createErr: someErr},
|
||||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
|
||||||
createInstancesInput: CreateInstancesInput{
|
createInstancesInput: CreateInstancesInput{
|
||||||
CountControlPlanes: 3,
|
CountControlPlanes: 3,
|
||||||
CountWorkers: 3,
|
CountWorkers: 3,
|
||||||
|
@ -169,8 +55,6 @@ func TestCreateInstances(t *testing.T) {
|
||||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||||
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
networkInterfacesAPI: stubNetworkInterfacesAPI{},
|
||||||
scaleSetsAPI: stubScaleSetsAPI{getErr: someErr},
|
scaleSetsAPI: stubScaleSetsAPI{getErr: someErr},
|
||||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
|
||||||
createInstancesInput: CreateInstancesInput{
|
createInstancesInput: CreateInstancesInput{
|
||||||
CountControlPlanes: 3,
|
CountControlPlanes: 3,
|
||||||
CountWorkers: 3,
|
CountWorkers: 3,
|
||||||
|
@ -185,8 +69,6 @@ func TestCreateInstances(t *testing.T) {
|
||||||
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
|
||||||
networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr},
|
networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr},
|
||||||
scaleSetsAPI: stubScaleSetsAPI{},
|
scaleSetsAPI: stubScaleSetsAPI{},
|
||||||
resourceGroupAPI: newSuccessfulResourceGroupStub(),
|
|
||||||
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
|
|
||||||
createInstancesInput: CreateInstancesInput{
|
createInstancesInput: CreateInstancesInput{
|
||||||
CountWorkers: 3,
|
CountWorkers: 3,
|
||||||
InstanceType: "type",
|
InstanceType: "type",
|
||||||
|
@ -211,8 +93,6 @@ func TestCreateInstances(t *testing.T) {
|
||||||
publicIPAddressesAPI: tc.publicIPAddressesAPI,
|
publicIPAddressesAPI: tc.publicIPAddressesAPI,
|
||||||
networkInterfacesAPI: tc.networkInterfacesAPI,
|
networkInterfacesAPI: tc.networkInterfacesAPI,
|
||||||
scaleSetsAPI: tc.scaleSetsAPI,
|
scaleSetsAPI: tc.scaleSetsAPI,
|
||||||
resourceGroupAPI: tc.resourceGroupAPI,
|
|
||||||
roleAssignmentsAPI: tc.roleAssignmentsAPI,
|
|
||||||
workers: make(cloudtypes.Instances),
|
workers: make(cloudtypes.Instances),
|
||||||
controlPlanes: make(cloudtypes.Instances),
|
controlPlanes: make(cloudtypes.Instances),
|
||||||
loadBalancerPubIP: "lbip",
|
loadBalancerPubIP: "lbip",
|
||||||
|
@ -232,11 +112,3 @@ func TestCreateInstances(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSuccessfulResourceGroupStub() *stubResourceGroupAPI {
|
|
||||||
return &stubResourceGroupAPI{
|
|
||||||
getResourceGroup: armresources.ResourceGroup{
|
|
||||||
ID: to.Ptr("resource-group-id"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
133
cli/internal/azure/client/terminate.go
Normal file
133
cli/internal/azure/client/terminate.go
Normal file
|
@ -0,0 +1,133 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
||||||
|
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TerminateResourceGroupResources deletes all resources from the resource group.
|
||||||
|
func (c *Client) TerminateResourceGroupResources(ctx context.Context) error {
|
||||||
|
const timeOut = 10 * time.Minute
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, timeOut)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
pollers := make(chan *runtime.Poller[armresources.ClientDeleteByIDResponse], 20)
|
||||||
|
delete := make(chan struct{}, 1)
|
||||||
|
wg := &sync.WaitGroup{}
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
go func() { // This routine lists resources and starts their deletion, where possible.
|
||||||
|
defer wg.Done()
|
||||||
|
defer func() {
|
||||||
|
close(pollers)
|
||||||
|
for range delete { // drain channel
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
ids, err := c.getResourceIDList(ctx)
|
||||||
|
if err != nil {
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ids) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, id := range ids {
|
||||||
|
poller, err := c.deleteResourceByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
pollers <- poller
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case _, ok := <-delete:
|
||||||
|
if !ok { // channel was closed
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() { // This routine polls for for the deletions to complete.
|
||||||
|
defer wg.Done()
|
||||||
|
defer close(delete)
|
||||||
|
|
||||||
|
for poller := range pollers {
|
||||||
|
_, err := poller.PollUntilDone(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case delete <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getResourceIDList(ctx context.Context) ([]string, error) {
|
||||||
|
var ids []string
|
||||||
|
pager := c.resourceAPI.NewListByResourceGroupPager(c.resourceGroup, nil)
|
||||||
|
for pager.More() {
|
||||||
|
page, err := pager.NextPage(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("getting next page of ListByResourceGroup: %w", err)
|
||||||
|
}
|
||||||
|
for _, resource := range page.Value {
|
||||||
|
if resource.ID == nil {
|
||||||
|
return nil, fmt.Errorf("resource %v has no ID", resource)
|
||||||
|
}
|
||||||
|
ids = append(ids, *resource.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) deleteResourceByID(ctx context.Context, id string,
|
||||||
|
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) {
|
||||||
|
apiVersion := "2020-02-02"
|
||||||
|
|
||||||
|
// First try, API version unknown, will fail.
|
||||||
|
poller, err := c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil)
|
||||||
|
if isVersionWrongErr(err) {
|
||||||
|
// bad hack, but easiest way to get the right API version
|
||||||
|
apiVersion = parseAPIVersionFromErr(err)
|
||||||
|
poller, err = c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil)
|
||||||
|
}
|
||||||
|
return poller, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func isVersionWrongErr(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.Contains(err.Error(), "NoRegisteredProviderFound") &&
|
||||||
|
strings.Contains(err.Error(), "The supported api-versions are")
|
||||||
|
}
|
||||||
|
|
||||||
|
var apiVersionRegex = regexp.MustCompile(` (\d\d\d\d-\d\d-\d\d)'`)
|
||||||
|
|
||||||
|
func parseAPIVersionFromErr(err error) string {
|
||||||
|
if err == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
matches := apiVersionRegex.FindStringSubmatch(err.Error())
|
||||||
|
return matches[1]
|
||||||
|
}
|
139
cli/internal/azure/client/terminate_test.go
Normal file
139
cli/internal/azure/client/terminate_test.go
Normal file
|
@ -0,0 +1,139 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
||||||
|
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTerminateResourceGroupResources(t *testing.T) {
|
||||||
|
someErr := errors.New("failed")
|
||||||
|
apiVersionErr := errors.New("NoRegisteredProviderFound, The supported api-versions are: 2015-01-01'")
|
||||||
|
|
||||||
|
testCases := map[string]struct {
|
||||||
|
resourceAPI resourceAPI
|
||||||
|
}{
|
||||||
|
"no resources": {
|
||||||
|
resourceAPI: &fakeResourceAPI{},
|
||||||
|
},
|
||||||
|
"some resources": {
|
||||||
|
resourceAPI: &fakeResourceAPI{
|
||||||
|
resources: map[string]fakeResource{
|
||||||
|
"id-0": {beginDeleteByIDErr: apiVersionErr, pollErr: someErr},
|
||||||
|
"id-1": {beginDeleteByIDErr: apiVersionErr},
|
||||||
|
"id-2": {},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, tc := range testCases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
client := &Client{
|
||||||
|
resourceAPI: tc.resourceAPI,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err := client.TerminateResourceGroupResources(ctx)
|
||||||
|
assert.NoError(err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeResourceAPI struct {
|
||||||
|
resources map[string]fakeResource
|
||||||
|
fetchErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakeResource struct {
|
||||||
|
beginDeleteByIDErr error
|
||||||
|
pollErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a fakeResourceAPI) NewListByResourceGroupPager(resourceGroupName string,
|
||||||
|
options *armresources.ClientListByResourceGroupOptions,
|
||||||
|
) *runtime.Pager[armresources.ClientListByResourceGroupResponse] {
|
||||||
|
pager := &stubClientListByResourceGroupResponsePager{
|
||||||
|
resources: a.resources,
|
||||||
|
fetchErr: a.fetchErr,
|
||||||
|
}
|
||||||
|
return runtime.NewPager(runtime.PagingHandler[armresources.ClientListByResourceGroupResponse]{
|
||||||
|
More: pager.moreFunc(),
|
||||||
|
Fetcher: pager.fetcherFunc(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a fakeResourceAPI) BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string,
|
||||||
|
options *armresources.ClientBeginDeleteByIDOptions,
|
||||||
|
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) {
|
||||||
|
res := a.resources[resourceID]
|
||||||
|
|
||||||
|
pollErr := res.pollErr
|
||||||
|
if pollErr != nil {
|
||||||
|
res.pollErr = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ClientDeleteByIDResponse]{
|
||||||
|
Handler: &stubPoller[armresources.ClientDeleteByIDResponse]{
|
||||||
|
result: armresources.ClientDeleteByIDResponse{},
|
||||||
|
resultErr: pollErr,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
beginDeleteByIDErr := res.beginDeleteByIDErr
|
||||||
|
if beginDeleteByIDErr != nil {
|
||||||
|
res.beginDeleteByIDErr = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.beginDeleteByIDErr == nil && res.pollErr == nil {
|
||||||
|
delete(a.resources, resourceID)
|
||||||
|
fmt.Printf("fake delete %s\n", resourceID)
|
||||||
|
} else {
|
||||||
|
a.resources[resourceID] = res
|
||||||
|
}
|
||||||
|
|
||||||
|
return poller, beginDeleteByIDErr
|
||||||
|
}
|
||||||
|
|
||||||
|
type stubClientListByResourceGroupResponsePager struct {
|
||||||
|
resources map[string]fakeResource
|
||||||
|
fetchErr error
|
||||||
|
more bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *stubClientListByResourceGroupResponsePager) moreFunc() func(
|
||||||
|
armresources.ClientListByResourceGroupResponse) bool {
|
||||||
|
return func(armresources.ClientListByResourceGroupResponse) bool {
|
||||||
|
return p.more
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *stubClientListByResourceGroupResponsePager) fetcherFunc() func(
|
||||||
|
context.Context, *armresources.ClientListByResourceGroupResponse) (
|
||||||
|
armresources.ClientListByResourceGroupResponse, error) {
|
||||||
|
return func(context.Context, *armresources.ClientListByResourceGroupResponse) (
|
||||||
|
armresources.ClientListByResourceGroupResponse, error,
|
||||||
|
) {
|
||||||
|
var resources []*armresources.GenericResourceExpanded
|
||||||
|
for id := range p.resources {
|
||||||
|
resources = append(resources, &armresources.GenericResourceExpanded{ID: proto.String(id)})
|
||||||
|
}
|
||||||
|
p.more = false
|
||||||
|
return armresources.ClientListByResourceGroupResponse{
|
||||||
|
ResourceListResult: armresources.ResourceListResult{
|
||||||
|
Value: resources,
|
||||||
|
},
|
||||||
|
}, p.fetchErr
|
||||||
|
}
|
||||||
|
}
|
|
@ -26,12 +26,9 @@ type azureclient interface {
|
||||||
GetState() state.ConstellationState
|
GetState() state.ConstellationState
|
||||||
SetState(state.ConstellationState)
|
SetState(state.ConstellationState)
|
||||||
CreateApplicationInsight(ctx context.Context) error
|
CreateApplicationInsight(ctx context.Context) error
|
||||||
CreateResourceGroup(ctx context.Context) error
|
|
||||||
CreateExternalLoadBalancer(ctx context.Context) error
|
CreateExternalLoadBalancer(ctx context.Context) error
|
||||||
CreateVirtualNetwork(ctx context.Context) error
|
CreateVirtualNetwork(ctx context.Context) error
|
||||||
CreateSecurityGroup(ctx context.Context, input azurecl.NetworkSecurityGroupInput) error
|
CreateSecurityGroup(ctx context.Context, input azurecl.NetworkSecurityGroupInput) error
|
||||||
CreateInstances(ctx context.Context, input azurecl.CreateInstancesInput) error
|
CreateInstances(ctx context.Context, input azurecl.CreateInstancesInput) error
|
||||||
CreateServicePrincipal(ctx context.Context) (string, error)
|
TerminateResourceGroupResources(ctx context.Context) error
|
||||||
TerminateResourceGroup(ctx context.Context) error
|
|
||||||
TerminateServicePrincipal(ctx context.Context) error
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,11 +79,6 @@ func (c *fakeAzureClient) CreateApplicationInsight(ctx context.Context) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *fakeAzureClient) CreateResourceGroup(ctx context.Context) error {
|
|
||||||
c.resourceGroup = "resource-group"
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *fakeAzureClient) CreateVirtualNetwork(ctx context.Context) error {
|
func (c *fakeAzureClient) CreateVirtualNetwork(ctx context.Context) error {
|
||||||
c.subnetID = "subnet"
|
c.subnetID = "subnet"
|
||||||
return nil
|
return nil
|
||||||
|
@ -123,17 +118,8 @@ func (c *fakeAzureClient) CreateServicePrincipal(ctx context.Context) (string, e
|
||||||
}.ToCloudServiceAccountURI(), nil
|
}.ToCloudServiceAccountURI(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *fakeAzureClient) TerminateResourceGroup(ctx context.Context) error {
|
func (c *fakeAzureClient) TerminateResourceGroupResources(ctx context.Context) error {
|
||||||
if c.resourceGroup == "" {
|
// TODO(katexochen)
|
||||||
return nil
|
|
||||||
}
|
|
||||||
c.workers = nil
|
|
||||||
c.controlPlanes = nil
|
|
||||||
c.resourceGroup = ""
|
|
||||||
c.subnetID = ""
|
|
||||||
c.networkSecurityGroup = ""
|
|
||||||
c.workerScaleSet = ""
|
|
||||||
c.controlPlaneScaleSet = ""
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,17 +132,16 @@ func (c *fakeAzureClient) TerminateServicePrincipal(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
type stubAzureClient struct {
|
type stubAzureClient struct {
|
||||||
terminateResourceGroupCalled bool
|
terminateResourceGroupResourcesCalled bool
|
||||||
terminateServicePrincipalCalled bool
|
terminateServicePrincipalCalled bool
|
||||||
|
|
||||||
createApplicationInsightErr error
|
createApplicationInsightErr error
|
||||||
createResourceGroupErr error
|
|
||||||
createVirtualNetworkErr error
|
createVirtualNetworkErr error
|
||||||
createSecurityGroupErr error
|
createSecurityGroupErr error
|
||||||
createLoadBalancerErr error
|
createLoadBalancerErr error
|
||||||
createInstancesErr error
|
createInstancesErr error
|
||||||
createServicePrincipalErr error
|
createServicePrincipalErr error
|
||||||
terminateResourceGroupErr error
|
terminateResourceGroupResourcesErr error
|
||||||
terminateServicePrincipalErr error
|
terminateServicePrincipalErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,10 +160,6 @@ func (c *stubAzureClient) CreateApplicationInsight(ctx context.Context) error {
|
||||||
return c.createApplicationInsightErr
|
return c.createApplicationInsightErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *stubAzureClient) CreateResourceGroup(ctx context.Context) error {
|
|
||||||
return c.createResourceGroupErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubAzureClient) CreateVirtualNetwork(ctx context.Context) error {
|
func (c *stubAzureClient) CreateVirtualNetwork(ctx context.Context) error {
|
||||||
return c.createVirtualNetworkErr
|
return c.createVirtualNetworkErr
|
||||||
}
|
}
|
||||||
|
@ -198,9 +179,9 @@ func (c *stubAzureClient) CreateServicePrincipal(ctx context.Context) (string, e
|
||||||
}.ToCloudServiceAccountURI(), c.createServicePrincipalErr
|
}.ToCloudServiceAccountURI(), c.createServicePrincipalErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *stubAzureClient) TerminateResourceGroup(ctx context.Context) error {
|
func (c *stubAzureClient) TerminateResourceGroupResources(ctx context.Context) error {
|
||||||
c.terminateResourceGroupCalled = true
|
c.terminateResourceGroupResourcesCalled = true
|
||||||
return c.terminateResourceGroupErr
|
return c.terminateResourceGroupResourcesErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *stubAzureClient) TerminateServicePrincipal(ctx context.Context) error {
|
func (c *stubAzureClient) TerminateServicePrincipal(ctx context.Context) error {
|
||||||
|
|
|
@ -18,7 +18,7 @@ import (
|
||||||
type Creator struct {
|
type Creator struct {
|
||||||
out io.Writer
|
out io.Writer
|
||||||
newGCPClient func(ctx context.Context, project, zone, region, name string) (gcpclient, error)
|
newGCPClient func(ctx context.Context, project, zone, region, name string) (gcpclient, error)
|
||||||
newAzureClient func(subscriptionID, tenantID, name, location string) (azureclient, error)
|
newAzureClient func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCreator creates a new creator.
|
// NewCreator creates a new creator.
|
||||||
|
@ -28,8 +28,8 @@ func NewCreator(out io.Writer) *Creator {
|
||||||
newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) {
|
newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) {
|
||||||
return gcpcl.NewInitialized(ctx, project, zone, region, name)
|
return gcpcl.NewInitialized(ctx, project, zone, region, name)
|
||||||
},
|
},
|
||||||
newAzureClient: func(subscriptionID, tenantID, name, location string) (azureclient, error) {
|
newAzureClient: func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error) {
|
||||||
return azurecl.NewInitialized(subscriptionID, tenantID, name, location)
|
return azurecl.NewInitialized(subscriptionID, tenantID, name, location, resourceGroup)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -57,6 +57,7 @@ func (c *Creator) Create(ctx context.Context, provider cloudprovider.Provider, c
|
||||||
config.Provider.Azure.TenantID,
|
config.Provider.Azure.TenantID,
|
||||||
name,
|
name,
|
||||||
config.Provider.Azure.Location,
|
config.Provider.Azure.Location,
|
||||||
|
config.Provider.Azure.ResourceGroup,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return state.ConstellationState{}, err
|
return state.ConstellationState{}, err
|
||||||
|
@ -144,9 +145,6 @@ func (c *Creator) createAzure(ctx context.Context, cl azureclient, config *confi
|
||||||
) (stat state.ConstellationState, retErr error) {
|
) (stat state.ConstellationState, retErr error) {
|
||||||
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerAzure{client: cl})
|
defer rollbackOnError(context.Background(), c.out, &retErr, &rollbackerAzure{client: cl})
|
||||||
|
|
||||||
if err := cl.CreateResourceGroup(ctx); err != nil {
|
|
||||||
return state.ConstellationState{}, err
|
|
||||||
}
|
|
||||||
if err := cl.CreateApplicationInsight(ctx); err != nil {
|
if err := cl.CreateApplicationInsight(ctx); err != nil {
|
||||||
return state.ConstellationState{}, err
|
return state.ConstellationState{}, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,7 +51,6 @@ func TestCreator(t *testing.T) {
|
||||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
"id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
"id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
},
|
},
|
||||||
AzureResourceGroup: "resource-group",
|
|
||||||
AzureSubnet: "subnet",
|
AzureSubnet: "subnet",
|
||||||
AzureNetworkSecurityGroup: "network-security-group",
|
AzureNetworkSecurityGroup: "network-security-group",
|
||||||
AzureWorkerScaleSet: "workers-scale-set",
|
AzureWorkerScaleSet: "workers-scale-set",
|
||||||
|
@ -123,13 +122,6 @@ func TestCreator(t *testing.T) {
|
||||||
config: config.Default(),
|
config: config.Default(),
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"azure CreateResourceGroup error": {
|
|
||||||
azureclient: &stubAzureClient{createResourceGroupErr: someErr},
|
|
||||||
provider: cloudprovider.Azure,
|
|
||||||
config: config.Default(),
|
|
||||||
wantErr: true,
|
|
||||||
wantRollback: true,
|
|
||||||
},
|
|
||||||
"azure CreateVirtualNetwork error": {
|
"azure CreateVirtualNetwork error": {
|
||||||
azureclient: &stubAzureClient{createVirtualNetworkErr: someErr},
|
azureclient: &stubAzureClient{createVirtualNetworkErr: someErr},
|
||||||
provider: cloudprovider.Azure,
|
provider: cloudprovider.Azure,
|
||||||
|
@ -167,7 +159,7 @@ func TestCreator(t *testing.T) {
|
||||||
newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) {
|
newGCPClient: func(ctx context.Context, project, zone, region, name string) (gcpclient, error) {
|
||||||
return tc.gcpclient, tc.newGCPClientErr
|
return tc.gcpclient, tc.newGCPClientErr
|
||||||
},
|
},
|
||||||
newAzureClient: func(subscriptionID, tenantID, name, location string) (azureclient, error) {
|
newAzureClient: func(subscriptionID, tenantID, name, location, resourceGroup string) (azureclient, error) {
|
||||||
return tc.azureclient, tc.newAzureClientErr
|
return tc.azureclient, tc.newAzureClientErr
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -186,7 +178,7 @@ func TestCreator(t *testing.T) {
|
||||||
assert.True(cl.closeCalled)
|
assert.True(cl.closeCalled)
|
||||||
case cloudprovider.Azure:
|
case cloudprovider.Azure:
|
||||||
cl := tc.azureclient.(*stubAzureClient)
|
cl := tc.azureclient.(*stubAzureClient)
|
||||||
assert.True(cl.terminateResourceGroupCalled)
|
assert.True(cl.terminateResourceGroupResourcesCalled)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -46,5 +46,5 @@ type rollbackerAzure struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rollbackerAzure) rollback(ctx context.Context) error {
|
func (r *rollbackerAzure) rollback(ctx context.Context) error {
|
||||||
return r.client.TerminateResourceGroup(ctx)
|
return r.client.TerminateResourceGroupResources(ctx)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,68 +0,0 @@
|
||||||
package cloudcmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
azurecl "github.com/edgelesssys/constellation/cli/internal/azure/client"
|
|
||||||
gcpcl "github.com/edgelesssys/constellation/cli/internal/gcp/client"
|
|
||||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
|
||||||
"github.com/edgelesssys/constellation/internal/config"
|
|
||||||
"github.com/edgelesssys/constellation/internal/state"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ServiceAccountCreator creates service accounts.
|
|
||||||
type ServiceAccountCreator struct {
|
|
||||||
newGCPClient func(ctx context.Context) (gcpclient, error)
|
|
||||||
newAzureClient func(subscriptionID, tenantID string) (azureclient, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewServiceAccountCreator() *ServiceAccountCreator {
|
|
||||||
return &ServiceAccountCreator{
|
|
||||||
newGCPClient: func(ctx context.Context) (gcpclient, error) {
|
|
||||||
return gcpcl.NewFromDefault(ctx)
|
|
||||||
},
|
|
||||||
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
|
|
||||||
return azurecl.NewFromDefault(subscriptionID, tenantID)
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create creates a new cloud provider service account with access to the created resources.
|
|
||||||
func (c *ServiceAccountCreator) Create(ctx context.Context, stat state.ConstellationState, config *config.Config,
|
|
||||||
) (string, state.ConstellationState, error) {
|
|
||||||
provider := cloudprovider.FromString(stat.CloudProvider)
|
|
||||||
switch provider {
|
|
||||||
case cloudprovider.GCP:
|
|
||||||
return "", state.ConstellationState{}, fmt.Errorf("creating service account not supported for GCP")
|
|
||||||
case cloudprovider.Azure:
|
|
||||||
cl, err := c.newAzureClient(stat.AzureSubscription, stat.AzureTenant)
|
|
||||||
if err != nil {
|
|
||||||
return "", state.ConstellationState{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
serviceAccount, stat, err := c.createServiceAccountAzure(ctx, cl, stat, config)
|
|
||||||
if err != nil {
|
|
||||||
return "", state.ConstellationState{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return serviceAccount, stat, err
|
|
||||||
case cloudprovider.QEMU:
|
|
||||||
return "unsupported://qemu", stat, nil
|
|
||||||
default:
|
|
||||||
return "", state.ConstellationState{}, fmt.Errorf("unsupported provider: %s", provider)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ServiceAccountCreator) createServiceAccountAzure(ctx context.Context, cl azureclient,
|
|
||||||
stat state.ConstellationState, _ *config.Config,
|
|
||||||
) (string, state.ConstellationState, error) {
|
|
||||||
cl.SetState(stat)
|
|
||||||
|
|
||||||
serviceAccount, err := cl.CreateServicePrincipal(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return "", state.ConstellationState{}, fmt.Errorf("creating service account: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return serviceAccount, cl.GetState(), nil
|
|
||||||
}
|
|
|
@ -1,89 +0,0 @@
|
||||||
package cloudcmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
|
||||||
"github.com/edgelesssys/constellation/internal/config"
|
|
||||||
"github.com/edgelesssys/constellation/internal/state"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestServiceAccountCreator(t *testing.T) {
|
|
||||||
someAzureState := func() state.ConstellationState {
|
|
||||||
return state.ConstellationState{
|
|
||||||
CloudProvider: cloudprovider.Azure.String(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
someErr := errors.New("failed")
|
|
||||||
|
|
||||||
testCases := map[string]struct {
|
|
||||||
newGCPClient func(ctx context.Context) (gcpclient, error)
|
|
||||||
newAzureClient func(subscriptionID, tenantID string) (azureclient, error)
|
|
||||||
state state.ConstellationState
|
|
||||||
config *config.Config
|
|
||||||
wantErr bool
|
|
||||||
wantStateMutator func(*state.ConstellationState)
|
|
||||||
}{
|
|
||||||
"azure": {
|
|
||||||
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
|
|
||||||
return &fakeAzureClient{}, nil
|
|
||||||
},
|
|
||||||
state: someAzureState(),
|
|
||||||
config: config.Default(),
|
|
||||||
wantStateMutator: func(stat *state.ConstellationState) {
|
|
||||||
stat.AzureADAppObjectID = "00000000-0000-0000-0000-000000000001"
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"azure newAzureClient error": {
|
|
||||||
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
|
|
||||||
return nil, someErr
|
|
||||||
},
|
|
||||||
state: someAzureState(),
|
|
||||||
config: config.Default(),
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"azure client createServiceAccount error": {
|
|
||||||
newAzureClient: func(subscriptionID, tenantID string) (azureclient, error) {
|
|
||||||
return &stubAzureClient{createServicePrincipalErr: someErr}, nil
|
|
||||||
},
|
|
||||||
state: someAzureState(),
|
|
||||||
config: config.Default(),
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"qemu": {
|
|
||||||
state: state.ConstellationState{CloudProvider: "qemu"},
|
|
||||||
wantStateMutator: func(cs *state.ConstellationState) {},
|
|
||||||
config: config.Default(),
|
|
||||||
},
|
|
||||||
"unknown cloud provider": {
|
|
||||||
state: state.ConstellationState{},
|
|
||||||
config: config.Default(),
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for name, tc := range testCases {
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
|
|
||||||
creator := &ServiceAccountCreator{
|
|
||||||
newGCPClient: tc.newGCPClient,
|
|
||||||
newAzureClient: tc.newAzureClient,
|
|
||||||
}
|
|
||||||
|
|
||||||
serviceAccount, state, err := creator.Create(context.Background(), tc.state, tc.config)
|
|
||||||
|
|
||||||
if tc.wantErr {
|
|
||||||
assert.Error(err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(err)
|
|
||||||
assert.NotEmpty(serviceAccount)
|
|
||||||
tc.wantStateMutator(&tc.state)
|
|
||||||
assert.Equal(tc.state, state)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -72,8 +72,5 @@ func (t *Terminator) terminateGCP(ctx context.Context, cl gcpclient, state state
|
||||||
func (t *Terminator) terminateAzure(ctx context.Context, cl azureclient, state state.ConstellationState) error {
|
func (t *Terminator) terminateAzure(ctx context.Context, cl azureclient, state state.ConstellationState) error {
|
||||||
cl.SetState(state)
|
cl.SetState(state)
|
||||||
|
|
||||||
if err := cl.TerminateServicePrincipal(ctx); err != nil {
|
return cl.TerminateResourceGroupResources(ctx)
|
||||||
return err
|
|
||||||
}
|
|
||||||
return cl.TerminateResourceGroup(ctx)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,7 +41,6 @@ func TestTerminator(t *testing.T) {
|
||||||
AzureControlPlaneInstances: cloudtypes.Instances{
|
AzureControlPlaneInstances: cloudtypes.Instances{
|
||||||
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
},
|
},
|
||||||
AzureResourceGroup: "group",
|
|
||||||
AzureADAppObjectID: "00000000-0000-0000-0000-000000000001",
|
AzureADAppObjectID: "00000000-0000-0000-0000-000000000001",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -88,13 +87,8 @@ func TestTerminator(t *testing.T) {
|
||||||
state: someAzureState(),
|
state: someAzureState(),
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"azure terminateServicePrincipal error": {
|
"azure terminateResourceGroupResources error": {
|
||||||
azureclient: &stubAzureClient{terminateServicePrincipalErr: someErr},
|
azureclient: &stubAzureClient{terminateResourceGroupResourcesErr: someErr},
|
||||||
state: someAzureState(),
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
"azure terminateResourceGroup error": {
|
|
||||||
azureclient: &stubAzureClient{terminateResourceGroupErr: someErr},
|
|
||||||
state: someAzureState(),
|
state: someAzureState(),
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
|
@ -132,8 +126,7 @@ func TestTerminator(t *testing.T) {
|
||||||
assert.True(cl.closeCalled)
|
assert.True(cl.closeCalled)
|
||||||
case cloudprovider.Azure:
|
case cloudprovider.Azure:
|
||||||
cl := tc.azureclient.(*stubAzureClient)
|
cl := tc.azureclient.(*stubAzureClient)
|
||||||
assert.True(cl.terminateResourceGroupCalled)
|
assert.True(cl.terminateResourceGroupResourcesCalled)
|
||||||
assert.True(cl.terminateServicePrincipalCalled)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
|
@ -21,8 +21,3 @@ type cloudCreator interface {
|
||||||
type cloudTerminator interface {
|
type cloudTerminator interface {
|
||||||
Terminate(context.Context, state.ConstellationState) error
|
Terminate(context.Context, state.ConstellationState) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type serviceAccountCreator interface {
|
|
||||||
Create(ctx context.Context, stat state.ConstellationState, config *config.Config,
|
|
||||||
) (string, state.ConstellationState, error)
|
|
||||||
}
|
|
||||||
|
|
|
@ -47,12 +47,3 @@ func (c *stubCloudTerminator) Terminate(context.Context, state.ConstellationStat
|
||||||
func (c *stubCloudTerminator) Called() bool {
|
func (c *stubCloudTerminator) Called() bool {
|
||||||
return c.called
|
return c.called
|
||||||
}
|
}
|
||||||
|
|
||||||
type stubServiceAccountCreator struct {
|
|
||||||
cloudServiceAccountURI string
|
|
||||||
createErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *stubServiceAccountCreator) Create(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
|
|
||||||
return c.cloudServiceAccountURI, stat, c.createErr
|
|
||||||
}
|
|
||||||
|
|
|
@ -55,17 +55,16 @@ func NewInitCmd() *cobra.Command {
|
||||||
// runInitialize runs the initialize command.
|
// runInitialize runs the initialize command.
|
||||||
func runInitialize(cmd *cobra.Command, args []string) error {
|
func runInitialize(cmd *cobra.Command, args []string) error {
|
||||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||||
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
|
|
||||||
newDialer := func(validator *cloudcmd.Validator) *dialer.Dialer {
|
newDialer := func(validator *cloudcmd.Validator) *dialer.Dialer {
|
||||||
return dialer.New(nil, validator.V(cmd), &net.Dialer{})
|
return dialer.New(nil, validator.V(cmd), &net.Dialer{})
|
||||||
}
|
}
|
||||||
helmLoader := &helm.ChartLoader{}
|
helmLoader := &helm.ChartLoader{}
|
||||||
return initialize(cmd, newDialer, serviceAccountCreator, fileHandler, helmLoader, license.NewClient())
|
return initialize(cmd, newDialer, fileHandler, helmLoader, license.NewClient())
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize initializes a Constellation.
|
// initialize initializes a Constellation.
|
||||||
func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer,
|
func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator) *dialer.Dialer,
|
||||||
serviceAccCreator serviceAccountCreator, fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker,
|
fileHandler file.Handler, helmLoader helmLoader, quotaChecker license.QuotaChecker,
|
||||||
) error {
|
) error {
|
||||||
flags, err := evalFlagArgs(cmd, fileHandler)
|
flags, err := evalFlagArgs(cmd, fileHandler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -105,23 +104,10 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var serviceAccURI string
|
serviceAccURI, err := getMarschaledServiceAccountURI(provider, config, fileHandler)
|
||||||
// Temporary legacy flow for Azure.
|
|
||||||
if provider == cloudprovider.Azure {
|
|
||||||
cmd.Println("Creating service account ...")
|
|
||||||
serviceAccURI, stat, err = serviceAccCreator.Create(cmd.Context(), stat, config)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := fileHandler.WriteJSON(constants.StateFilename, stat, file.OptOverwrite); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
serviceAccURI, err = getMarschaledServiceAccountURI(provider, config, fileHandler)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
workers, err := getScalingGroupsFromState(stat, config)
|
workers, err := getScalingGroupsFromState(stat, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -70,8 +70,7 @@ func TestInitialize(t *testing.T) {
|
||||||
|
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
state *state.ConstellationState
|
state *state.ConstellationState
|
||||||
existingIDFile *clusterIDsFile
|
idFile *clusterIDsFile
|
||||||
serviceAccCreator serviceAccountCreator
|
|
||||||
configMutator func(*config.Config)
|
configMutator func(*config.Config)
|
||||||
serviceAccKey *gcpshared.ServiceAccountKey
|
serviceAccKey *gcpshared.ServiceAccountKey
|
||||||
helmLoader stubHelmLoader
|
helmLoader stubHelmLoader
|
||||||
|
@ -82,25 +81,28 @@ func TestInitialize(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
"initialize some gcp instances": {
|
"initialize some gcp instances": {
|
||||||
state: testGcpState,
|
state: testGcpState,
|
||||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||||
serviceAccKey: gcpServiceAccKey,
|
serviceAccKey: gcpServiceAccKey,
|
||||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||||
},
|
},
|
||||||
"initialize some azure instances": {
|
"initialize some azure instances": {
|
||||||
state: testAzureState,
|
state: testAzureState,
|
||||||
serviceAccCreator: &stubServiceAccountCreator{},
|
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
configMutator: func(c *config.Config) {
|
||||||
|
c.Provider.Azure.ResourceGroup = "resourceGroup"
|
||||||
|
c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity"
|
||||||
|
},
|
||||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||||
},
|
},
|
||||||
"initialize some qemu instances": {
|
"initialize some qemu instances": {
|
||||||
state: testQemuState,
|
state: testQemuState,
|
||||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||||
},
|
},
|
||||||
"initialize gcp with autoscaling": {
|
"initialize gcp with autoscaling": {
|
||||||
state: testGcpState,
|
state: testGcpState,
|
||||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||||
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
configMutator: func(c *config.Config) { c.Provider.GCP.ServiceAccountKeyPath = serviceAccPath },
|
||||||
serviceAccKey: gcpServiceAccKey,
|
serviceAccKey: gcpServiceAccKey,
|
||||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||||
|
@ -108,8 +110,11 @@ func TestInitialize(t *testing.T) {
|
||||||
},
|
},
|
||||||
"initialize azure with autoscaling": {
|
"initialize azure with autoscaling": {
|
||||||
state: testAzureState,
|
state: testAzureState,
|
||||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||||
serviceAccCreator: &stubServiceAccountCreator{},
|
configMutator: func(c *config.Config) {
|
||||||
|
c.Provider.Azure.ResourceGroup = "resourceGroup"
|
||||||
|
c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity"
|
||||||
|
},
|
||||||
initServerAPI: &stubInitServer{initResp: testInitResp},
|
initServerAPI: &stubInitServer{initResp: testInitResp},
|
||||||
setAutoscaleFlag: true,
|
setAutoscaleFlag: true,
|
||||||
},
|
},
|
||||||
|
@ -122,7 +127,7 @@ func TestInitialize(t *testing.T) {
|
||||||
},
|
},
|
||||||
"empty state": {
|
"empty state": {
|
||||||
state: &state.ConstellationState{},
|
state: &state.ConstellationState{},
|
||||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||||
initServerAPI: &stubInitServer{},
|
initServerAPI: &stubInitServer{},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
|
@ -132,15 +137,18 @@ func TestInitialize(t *testing.T) {
|
||||||
},
|
},
|
||||||
"init call fails": {
|
"init call fails": {
|
||||||
state: testGcpState,
|
state: testGcpState,
|
||||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||||
initServerAPI: &stubInitServer{initErr: someErr},
|
initServerAPI: &stubInitServer{initErr: someErr},
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"fail to create service account": {
|
"fail to create service account": {
|
||||||
state: testAzureState,
|
state: testAzureState,
|
||||||
existingIDFile: &clusterIDsFile{IP: "192.0.2.1"},
|
idFile: &clusterIDsFile{IP: "192.0.2.1"},
|
||||||
|
configMutator: func(c *config.Config) {
|
||||||
|
c.Provider.Azure.ResourceGroup = "resourceGroup"
|
||||||
|
c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity"
|
||||||
|
},
|
||||||
initServerAPI: &stubInitServer{},
|
initServerAPI: &stubInitServer{},
|
||||||
serviceAccCreator: &stubServiceAccountCreator{createErr: someErr},
|
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"fail to load helm charts": {
|
"fail to load helm charts": {
|
||||||
|
@ -194,8 +202,8 @@ func TestInitialize(t *testing.T) {
|
||||||
if tc.state != nil {
|
if tc.state != nil {
|
||||||
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.state, file.OptNone))
|
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.state, file.OptNone))
|
||||||
}
|
}
|
||||||
if tc.existingIDFile != nil {
|
if tc.idFile != nil {
|
||||||
require.NoError(fileHandler.WriteJSON(constants.ClusterIDsFileName, tc.existingIDFile, file.OptNone))
|
require.NoError(fileHandler.WriteJSON(constants.ClusterIDsFileName, tc.idFile, file.OptNone))
|
||||||
}
|
}
|
||||||
if tc.serviceAccKey != nil {
|
if tc.serviceAccKey != nil {
|
||||||
require.NoError(fileHandler.WriteJSON(serviceAccPath, tc.serviceAccKey, file.OptNone))
|
require.NoError(fileHandler.WriteJSON(serviceAccPath, tc.serviceAccKey, file.OptNone))
|
||||||
|
@ -206,7 +214,7 @@ func TestInitialize(t *testing.T) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
cmd.SetContext(ctx)
|
cmd.SetContext(ctx)
|
||||||
|
|
||||||
err := initialize(cmd, newDialer, tc.serviceAccCreator, fileHandler, &tc.helmLoader, &stubLicenseClient{})
|
err := initialize(cmd, newDialer, fileHandler, &tc.helmLoader, &stubLicenseClient{})
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
assert.Error(err)
|
assert.Error(err)
|
||||||
|
@ -477,7 +485,7 @@ func TestAttestation(t *testing.T) {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
cmd.SetContext(ctx)
|
cmd.SetContext(ctx)
|
||||||
|
|
||||||
err := initialize(cmd, newDialer, &stubServiceAccountCreator{}, fileHandler, &stubHelmLoader{}, &stubLicenseClient{})
|
err := initialize(cmd, newDialer, fileHandler, &stubHelmLoader{}, &stubLicenseClient{})
|
||||||
assert.Error(err)
|
assert.Error(err)
|
||||||
// make sure the error is actually a TLS handshake error
|
// make sure the error is actually a TLS handshake error
|
||||||
assert.Contains(err.Error(), "transport: authentication handshake failed")
|
assert.Contains(err.Error(), "transport: authentication handshake failed")
|
||||||
|
@ -548,6 +556,7 @@ func defaultConfigWithExpectedMeasurements(t *testing.T, conf *config.Config, cs
|
||||||
conf.Provider.Azure.Location = "test-location"
|
conf.Provider.Azure.Location = "test-location"
|
||||||
conf.Provider.Azure.UserAssignedIdentity = "test-identity"
|
conf.Provider.Azure.UserAssignedIdentity = "test-identity"
|
||||||
conf.Provider.Azure.Image = "some/image/location"
|
conf.Provider.Azure.Image = "some/image/location"
|
||||||
|
conf.Provider.Azure.ResourceGroup = "test-resource-group"
|
||||||
conf.Provider.Azure.Measurements[8] = []byte("00000000000000000000000000000000")
|
conf.Provider.Azure.Measurements[8] = []byte("00000000000000000000000000000000")
|
||||||
conf.Provider.Azure.Measurements[9] = []byte("11111111111111111111111111111111")
|
conf.Provider.Azure.Measurements[9] = []byte("11111111111111111111111111111111")
|
||||||
case cloudprovider.GCP:
|
case cloudprovider.GCP:
|
||||||
|
|
|
@ -158,6 +158,9 @@ type AzureConfig struct {
|
||||||
// Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure
|
// Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure
|
||||||
UserAssignedIdentity string `yaml:"userAssignedIdentity" validate:"required"`
|
UserAssignedIdentity string `yaml:"userAssignedIdentity" validate:"required"`
|
||||||
// description: |
|
// description: |
|
||||||
|
// Resource group to use.
|
||||||
|
ResourceGroup string `yaml:"resourceGroup" validate:"required"`
|
||||||
|
// description: |
|
||||||
// Use VMs with security type Confidential VM. If set to false, Trusted Launch VMs will be used instead. See: https://docs.microsoft.com/en-us/azure/confidential-computing/confidential-vm-overview
|
// Use VMs with security type Confidential VM. If set to false, Trusted Launch VMs will be used instead. See: https://docs.microsoft.com/en-us/azure/confidential-computing/confidential-vm-overview
|
||||||
ConfidentialVM *bool `yaml:"confidentialVM" validate:"required"`
|
ConfidentialVM *bool `yaml:"confidentialVM" validate:"required"`
|
||||||
}
|
}
|
||||||
|
@ -244,6 +247,7 @@ func Default() *Config {
|
||||||
TenantID: "",
|
TenantID: "",
|
||||||
Location: "",
|
Location: "",
|
||||||
UserAssignedIdentity: "",
|
UserAssignedIdentity: "",
|
||||||
|
ResourceGroup: "",
|
||||||
Image: DefaultImageAzure,
|
Image: DefaultImageAzure,
|
||||||
StateDiskType: "Premium_LRS",
|
StateDiskType: "Premium_LRS",
|
||||||
Measurements: copyPCRMap(azurePCRs),
|
Measurements: copyPCRMap(azurePCRs),
|
||||||
|
|
|
@ -199,7 +199,7 @@ func init() {
|
||||||
FieldName: "azure",
|
FieldName: "azure",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
AzureConfigDoc.Fields = make([]encoder.Doc, 9)
|
AzureConfigDoc.Fields = make([]encoder.Doc, 10)
|
||||||
AzureConfigDoc.Fields[0].Name = "subscription"
|
AzureConfigDoc.Fields[0].Name = "subscription"
|
||||||
AzureConfigDoc.Fields[0].Type = "string"
|
AzureConfigDoc.Fields[0].Type = "string"
|
||||||
AzureConfigDoc.Fields[0].Note = ""
|
AzureConfigDoc.Fields[0].Note = ""
|
||||||
|
@ -240,6 +240,11 @@ func init() {
|
||||||
AzureConfigDoc.Fields[7].Note = ""
|
AzureConfigDoc.Fields[7].Note = ""
|
||||||
AzureConfigDoc.Fields[7].Description = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure"
|
AzureConfigDoc.Fields[7].Description = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure"
|
||||||
AzureConfigDoc.Fields[7].Comments[encoder.LineComment] = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure"
|
AzureConfigDoc.Fields[7].Comments[encoder.LineComment] = "Authorize spawned VMs to access Azure API. See: https://docs.edgeless.systems/constellation/latest/#/getting-started/install?id=azure"
|
||||||
|
AzureConfigDoc.Fields[8].Name = "resourceGroup"
|
||||||
|
AzureConfigDoc.Fields[8].Type = "string"
|
||||||
|
AzureConfigDoc.Fields[8].Note = ""
|
||||||
|
AzureConfigDoc.Fields[8].Description = "Resource group to use."
|
||||||
|
AzureConfigDoc.Fields[8].Comments[encoder.LineComment] = "Resource group to use."
|
||||||
AzureConfigDoc.Fields[8].Name = "confidentialVM"
|
AzureConfigDoc.Fields[8].Name = "confidentialVM"
|
||||||
AzureConfigDoc.Fields[8].Type = "bool"
|
AzureConfigDoc.Fields[8].Type = "bool"
|
||||||
AzureConfigDoc.Fields[8].Note = ""
|
AzureConfigDoc.Fields[8].Note = ""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue