Move cli/azure to cli/internal/azure

This commit is contained in:
katexochen 2022-06-07 16:30:41 +02:00
parent 180d7872dd
commit 064151a956
27 changed files with 11 additions and 11 deletions

View file

@ -0,0 +1,8 @@
package azure
import "fmt"
// AutoscalingNodeGroup converts an azure scale set into a node group used by the k8s cluster-autoscaler.
func AutoscalingNodeGroup(scaleSet string, min int, max int) string {
return fmt.Sprintf("%d:%d:%s", min, max, scaleSet)
}

View file

@ -0,0 +1,14 @@
package azure
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestAutoscalingNodeGroup(t *testing.T) {
assert := assert.New(t)
nodeGroups := AutoscalingNodeGroup("scale-set", 0, 100)
wantNodeGroups := "0:100:scale-set"
assert.Equal(wantNodeGroups, nodeGroups)
}

View file

@ -0,0 +1,185 @@
package client
import (
"context"
"crypto/rand"
"errors"
"fmt"
"math/big"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/date"
"github.com/Azure/go-autorest/autorest/to"
"github.com/edgelesssys/constellation/internal/azureshared"
"github.com/google/uuid"
)
const (
adAppCredentialValidity = time.Hour * 24 * 365 * 5 // ~5 years
adReplicationLagCheckInterval = time.Second * 5 // 5 seconds
adReplicationLagCheckMaxRetries = int((15 * time.Minute) / adReplicationLagCheckInterval) // wait for up to 15 minutes for AD replication
ownerRoleDefinitionID = "8e3af657-a8ff-443c-a75c-2fe8c4bcb635" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#owner
virtualMachineContributorRoleDefinitionID = "9980e02c-c2be-4d73-94e8-173b1dc7cf3c" // https://docs.microsoft.com/en-us/azure/role-based-access-control/built-in-roles#virtual-machine-contributor
)
// CreateServicePrincipal creates an Azure AD app with a service principal, gives it "Owner" role on the resource group and creates new credentials.
func (c *Client) CreateServicePrincipal(ctx context.Context) (string, error) {
createAppRes, err := c.createADApplication(ctx)
if err != nil {
return "", err
}
c.adAppObjectID = createAppRes.ObjectID
servicePrincipalObjectID, err := c.createAppServicePrincipal(ctx, createAppRes.AppID)
if err != nil {
return "", err
}
if err := c.assignResourceGroupRole(ctx, servicePrincipalObjectID, ownerRoleDefinitionID); err != nil {
return "", err
}
clientSecret, err := c.updateAppCredentials(ctx, createAppRes.ObjectID)
if err != nil {
return "", err
}
return azureshared.ApplicationCredentials{
TenantID: c.tenantID,
ClientID: createAppRes.AppID,
ClientSecret: clientSecret,
Location: c.location,
}.ToCloudServiceAccountURI(), nil
}
// TerminateServicePrincipal terminates an Azure AD app together with the service principal.
func (c *Client) TerminateServicePrincipal(ctx context.Context) error {
if c.adAppObjectID == "" {
return nil
}
if _, err := c.applicationsAPI.Delete(ctx, c.adAppObjectID); err != nil {
return err
}
c.adAppObjectID = ""
return nil
}
// createADApplication creates a new azure AD app.
func (c *Client) createADApplication(ctx context.Context) (createADApplicationOutput, error) {
createParameters := graphrbac.ApplicationCreateParameters{
AvailableToOtherTenants: to.BoolPtr(false),
DisplayName: to.StringPtr("constellation-app-" + c.name + "-" + c.uid),
}
app, err := c.applicationsAPI.Create(ctx, createParameters)
if err != nil {
return createADApplicationOutput{}, err
}
if app.AppID == nil || app.ObjectID == nil {
return createADApplicationOutput{}, errors.New("creating AD application did not result in valid app id and object id")
}
return createADApplicationOutput{
AppID: *app.AppID,
ObjectID: *app.ObjectID,
}, nil
}
// createAppServicePrincipal creates a new service principal for an azure AD app.
func (c *Client) createAppServicePrincipal(ctx context.Context, appID string) (string, error) {
createParameters := graphrbac.ServicePrincipalCreateParameters{
AppID: &appID,
AccountEnabled: to.BoolPtr(true),
}
servicePrincipal, err := c.servicePrincipalsAPI.Create(ctx, createParameters)
if err != nil {
return "", err
}
if servicePrincipal.ObjectID == nil {
return "", errors.New("creating AD service principal did not result in a valid object id")
}
return *servicePrincipal.ObjectID, nil
}
// updateAppCredentials sets app client-secret for authentication.
func (c *Client) updateAppCredentials(ctx context.Context, objectID string) (string, error) {
keyID := uuid.New().String()
clientSecret, err := generateClientSecret()
if err != nil {
return "", fmt.Errorf("generating client secret failed: %w", err)
}
updateParameters := graphrbac.PasswordCredentialsUpdateParameters{
Value: &[]graphrbac.PasswordCredential{
{
StartDate: &date.Time{Time: time.Now()},
EndDate: &date.Time{Time: time.Now().Add(adAppCredentialValidity)},
Value: to.StringPtr(clientSecret),
KeyID: to.StringPtr(keyID),
},
},
}
_, err = c.applicationsAPI.UpdatePasswordCredentials(ctx, objectID, updateParameters)
if err != nil {
return "", err
}
return clientSecret, nil
}
// assignResourceGroupRole assigns the service principal a role at resource group scope.
func (c *Client) assignResourceGroupRole(ctx context.Context, principalID, roleDefinitionID string) error {
resourceGroup, err := c.resourceGroupAPI.Get(ctx, c.resourceGroup, nil)
if err != nil || resourceGroup.ID == nil {
return fmt.Errorf("unable to retrieve resource group id for group %v: %w", c.resourceGroup, err)
}
roleAssignmentID := uuid.New().String()
createParameters := authorization.RoleAssignmentCreateParameters{
Properties: &authorization.RoleAssignmentProperties{
PrincipalID: to.StringPtr(principalID),
RoleDefinitionID: to.StringPtr(fmt.Sprintf("/subscriptions/%s/providers/Microsoft.Authorization/roleDefinitions/%s", c.subscriptionID, roleDefinitionID)),
},
}
// due to an azure AD replication lag, retry role assignment if principal does not exist yet
// reference: https://docs.microsoft.com/en-us/azure/role-based-access-control/role-assignments-rest#new-service-principal
// proper fix: use API version 2018-09-01-preview or later
// azure go sdk currently uses version 2015-07-01: https://github.com/Azure/azure-sdk-for-go/blob/v62.0.0/services/authorization/mgmt/2015-07-01/authorization/roleassignments.go#L95
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
for i := 0; i < c.adReplicationLagCheckMaxRetries; i++ {
_, err = c.roleAssignmentsAPI.Create(ctx, *resourceGroup.ID, roleAssignmentID, createParameters)
var detailedErr autorest.DetailedError
var ok bool
if detailedErr, ok = err.(autorest.DetailedError); !ok {
return err
}
var requestErr *azure.RequestError
if requestErr, ok = detailedErr.Original.(*azure.RequestError); !ok || requestErr.ServiceError == nil {
return err
}
if requestErr.ServiceError.Code != "PrincipalNotFound" {
return err
}
time.Sleep(c.adReplicationLagCheckInterval)
}
return err
}
type createADApplicationOutput struct {
AppID string
ObjectID string
}
func generateClientSecret() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
pwLen := 64
pw := make([]byte, 0, pwLen)
for i := 0; i < pwLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
pw = append(pw, letters[n.Int64()])
}
return string(pw), nil
}

View file

@ -0,0 +1,358 @@
package client
import (
"context"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
func TestCreateServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
applicationsAPI applicationsAPI
servicePrincipalsAPI servicePrincipalsAPI
roleAssignmentsAPI roleAssignmentsAPI
resourceGroupAPI resourceGroupAPI
wantErr bool
}{
"successful create": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
},
"failed app create": {
applicationsAPI: stubApplicationsAPI{
createErr: someErr,
},
wantErr: true,
},
"failed service principal create": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{
createErr: someErr,
},
wantErr: true,
},
"failed role assignment": {
applicationsAPI: stubApplicationsAPI{},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{someErr},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
wantErr: true,
},
"failed update creds": {
applicationsAPI: stubApplicationsAPI{
updateCredentialsErr: someErr,
},
servicePrincipalsAPI: stubServicePrincipalsAPI{},
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
applicationsAPI: tc.applicationsAPI,
servicePrincipalsAPI: tc.servicePrincipalsAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
adReplicationLagCheckMaxRetries: 2,
}
_, err := client.CreateServicePrincipal(ctx)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestTerminateServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
appObjectID string
applicationsAPI applicationsAPI
wantErr bool
}{
"successful terminate": {
appObjectID: "object-id",
applicationsAPI: stubApplicationsAPI{},
},
"nothing to terminate": {
applicationsAPI: stubApplicationsAPI{},
},
"failed delete": {
appObjectID: "object-id",
applicationsAPI: stubApplicationsAPI{
deleteErr: someErr,
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
adAppObjectID: tc.appObjectID,
applicationsAPI: tc.applicationsAPI,
}
err := client.TerminateServicePrincipal(ctx)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestCreateADApplication(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
applicationsAPI applicationsAPI
wantErr bool
}{
"successful create": {
applicationsAPI: stubApplicationsAPI{},
},
"failed app create": {
applicationsAPI: stubApplicationsAPI{
createErr: someErr,
},
wantErr: true,
},
"app create returns invalid appid": {
applicationsAPI: stubApplicationsAPI{
createApplication: &graphrbac.Application{
ObjectID: proto.String("00000000-0000-0000-0000-000000000001"),
},
},
wantErr: true,
},
"app create returns invalid objectid": {
applicationsAPI: stubApplicationsAPI{
createApplication: &graphrbac.Application{
AppID: proto.String("00000000-0000-0000-0000-000000000000"),
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
applicationsAPI: tc.applicationsAPI,
}
appCredentials, err := client.createADApplication(ctx)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.NotNil(appCredentials)
})
}
}
func TestCreateAppServicePrincipal(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
servicePrincipalsAPI servicePrincipalsAPI
wantErr bool
}{
"successful create": {
servicePrincipalsAPI: stubServicePrincipalsAPI{},
},
"failed service principal create": {
servicePrincipalsAPI: stubServicePrincipalsAPI{
createErr: someErr,
},
wantErr: true,
},
"service principal create returns invalid objectid": {
servicePrincipalsAPI: stubServicePrincipalsAPI{
createServicePrincipal: &graphrbac.ServicePrincipal{},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
servicePrincipalsAPI: tc.servicePrincipalsAPI,
}
_, err := client.createAppServicePrincipal(ctx, "app-id")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}
func TestAssignOwnerOfResourceGroup(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
roleAssignmentsAPI roleAssignmentsAPI
resourceGroupAPI resourceGroupAPI
wantErr bool
}{
"successful assign": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
},
"failed role assignment": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{someErr},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
wantErr: true,
},
"failed resource group get": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getErr: someErr,
},
wantErr: true,
},
"resource group get returns invalid id": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{},
},
wantErr: true,
},
"create returns PrincipalNotFound the first time": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{
autorest.DetailedError{Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
Code: "PrincipalNotFound",
},
}},
nil,
},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
},
"create does not return request error": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{autorest.DetailedError{Original: someErr}},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
wantErr: true,
},
"create service error code is unknown": {
roleAssignmentsAPI: &stubRoleAssignmentsAPI{
createErrors: []error{
autorest.DetailedError{Original: &azure.RequestError{
ServiceError: &azure.ServiceError{
Code: "some-unknown-error-code",
},
}},
nil,
},
},
resourceGroupAPI: stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
name: "name",
uid: "uid",
resourceGroup: "resource-group",
roleAssignmentsAPI: tc.roleAssignmentsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
adReplicationLagCheckMaxRetries: 2,
}
err := client.assignResourceGroupRole(ctx, "principal-id", "role-definition-id")
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}

View file

@ -0,0 +1,141 @@
package client
import (
"context"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
)
type virtualNetworksCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.VirtualNetworksClientCreateOrUpdateResponse, error)
}
type networksAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
virtualNetworksCreateOrUpdatePollerResponse, error)
}
type networkSecurityGroupsCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.SecurityGroupsClientCreateOrUpdateResponse, error)
}
type networkSecurityGroupsAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
networkSecurityGroupsCreateOrUpdatePollerResponse, error)
}
type loadBalancersClientCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.LoadBalancersClientCreateOrUpdateResponse, error)
}
type loadBalancersAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
loadBalancerName string, parameters armnetwork.LoadBalancer,
options *armnetwork.LoadBalancersClientBeginCreateOrUpdateOptions) (
loadBalancersClientCreateOrUpdatePollerResponse, error,
)
}
type virtualMachineScaleSetsCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse, error)
}
type scaleSetsAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
vmScaleSetName string, parameters armcompute.VirtualMachineScaleSet,
options *armcompute.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
virtualMachineScaleSetsCreateOrUpdatePollerResponse, error)
}
type publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager interface {
NextPage(ctx context.Context) bool
PageResponse() armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse
}
// TODO: deprecate as soon as scale sets are available.
type publicIPAddressesClientCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.PublicIPAddressesClientCreateOrUpdateResponse, error)
}
type publicIPAddressesAPI interface {
ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string,
networkInterfaceName string, ipConfigurationName string,
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
) publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager
// TODO: deprecate as soon as scale sets are available.
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
publicIPAddressesClientCreateOrUpdatePollerResponse, error)
// TODO: deprecate as soon as scale sets are available.
Get(ctx context.Context, resourceGroupName string, publicIPAddressName string, options *armnetwork.PublicIPAddressesClientGetOptions) (
armnetwork.PublicIPAddressesClientGetResponse, error)
}
// TODO: deprecate as soon as scale sets are available.
type interfacesClientCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armnetwork.InterfacesClientCreateOrUpdateResponse, error)
}
type networkInterfacesAPI interface {
GetVirtualMachineScaleSetNetworkInterface(ctx context.Context, resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string, networkInterfaceName string,
options *armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceOptions,
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error)
// TODO: deprecate as soon as scale sets are available
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, networkInterfaceName string,
parameters armnetwork.Interface, options *armnetwork.InterfacesClientBeginCreateOrUpdateOptions) (
interfacesClientCreateOrUpdatePollerResponse, error)
}
type resourceGroupsDeletePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armresources.ResourceGroupsClientDeleteResponse, error)
}
type resourceGroupAPI interface {
CreateOrUpdate(ctx context.Context, resourceGroupName string,
parameters armresources.ResourceGroup,
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
armresources.ResourceGroupsClientCreateOrUpdateResponse, error)
BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
resourceGroupsDeletePollerResponse, error)
Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error)
}
type applicationsAPI interface {
Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error)
Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error)
UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error)
}
type servicePrincipalsAPI interface {
Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error)
}
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
// TODO: switch to "armauthorization.RoleAssignmentsClient" if issue is resolved.
type roleAssignmentsAPI interface {
Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error)
}
// TODO: deprecate as soon as scale sets are available.
type virtualMachinesClientCreateOrUpdatePollerResponse interface {
PollUntilDone(ctx context.Context, freq time.Duration) (armcompute.VirtualMachinesClientCreateOrUpdateResponse, error)
}
// TODO: deprecate as soon as scale sets are available.
type virtualMachinesAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine,
options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions) (virtualMachinesClientCreateOrUpdatePollerResponse, error)
}

View file

@ -0,0 +1,404 @@
package client
import (
"context"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
)
type stubNetworksAPI struct {
createErr error
stubResponse stubVirtualNetworksCreateOrUpdatePollerResponse
}
type stubVirtualNetworksCreateOrUpdatePollerResponse struct {
armnetwork.VirtualNetworksClientCreateOrUpdatePollerResponse
pollerErr error
}
func (r stubVirtualNetworksCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration,
) (armnetwork.VirtualNetworksClientCreateOrUpdateResponse, error) {
return armnetwork.VirtualNetworksClientCreateOrUpdateResponse{
VirtualNetworksClientCreateOrUpdateResult: armnetwork.VirtualNetworksClientCreateOrUpdateResult{
VirtualNetwork: armnetwork.VirtualNetwork{
Properties: &armnetwork.VirtualNetworkPropertiesFormat{
Subnets: []*armnetwork.Subnet{
{
ID: to.StringPtr("virtual-network-subnet-id"),
},
},
},
},
},
}, r.pollerErr
}
func (a stubNetworksAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
virtualNetworksCreateOrUpdatePollerResponse, error,
) {
return a.stubResponse, a.createErr
}
type stubLoadBalancersAPI struct {
createErr error
stubResponse stubLoadBalancersClientCreateOrUpdatePollerResponse
}
type stubLoadBalancersClientCreateOrUpdatePollerResponse struct {
pollResponse armnetwork.LoadBalancersClientCreateOrUpdateResponse
pollErr error
}
func (r stubLoadBalancersClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration,
) (armnetwork.LoadBalancersClientCreateOrUpdateResponse, error) {
return r.pollResponse, r.pollErr
}
func (a stubLoadBalancersAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
loadBalancerName string, parameters armnetwork.LoadBalancer,
options *armnetwork.LoadBalancersClientBeginCreateOrUpdateOptions) (
loadBalancersClientCreateOrUpdatePollerResponse, error,
) {
return a.stubResponse, a.createErr
}
type stubNetworkSecurityGroupsCreateOrUpdatePollerResponse struct {
armnetwork.SecurityGroupsClientCreateOrUpdatePollerResponse
pollerErr error
}
func (r stubNetworkSecurityGroupsCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration,
) (armnetwork.SecurityGroupsClientCreateOrUpdateResponse, error) {
return armnetwork.SecurityGroupsClientCreateOrUpdateResponse{
SecurityGroupsClientCreateOrUpdateResult: armnetwork.SecurityGroupsClientCreateOrUpdateResult{
SecurityGroup: armnetwork.SecurityGroup{
ID: to.StringPtr("network-security-group-id"),
},
},
}, r.pollerErr
}
type stubNetworkSecurityGroupsAPI struct {
createErr error
stubPoller stubNetworkSecurityGroupsCreateOrUpdatePollerResponse
}
func (a stubNetworkSecurityGroupsAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
networkSecurityGroupsCreateOrUpdatePollerResponse, error,
) {
return a.stubPoller, a.createErr
}
type stubResourceGroupAPI struct {
terminateErr error
createErr error
getErr error
getResourceGroup armresources.ResourceGroup
stubResponse stubResourceGroupsDeletePollerResponse
}
func (a stubResourceGroupAPI) CreateOrUpdate(ctx context.Context, resourceGroupName string,
parameters armresources.ResourceGroup,
options *armresources.ResourceGroupsClientCreateOrUpdateOptions) (
armresources.ResourceGroupsClientCreateOrUpdateResponse, error,
) {
return armresources.ResourceGroupsClientCreateOrUpdateResponse{}, a.createErr
}
func (a stubResourceGroupAPI) Get(ctx context.Context, resourceGroupName string, options *armresources.ResourceGroupsClientGetOptions) (armresources.ResourceGroupsClientGetResponse, error) {
return armresources.ResourceGroupsClientGetResponse{
ResourceGroupsClientGetResult: armresources.ResourceGroupsClientGetResult{
ResourceGroup: a.getResourceGroup,
},
}, a.getErr
}
type stubResourceGroupsDeletePollerResponse struct {
armresources.ResourceGroupsClientDeletePollerResponse
pollerErr error
}
func (r stubResourceGroupsDeletePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armresources.ResourceGroupsClientDeleteResponse, error,
) {
return armresources.ResourceGroupsClientDeleteResponse{}, r.pollerErr
}
func (a stubResourceGroupAPI) BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
resourceGroupsDeletePollerResponse, error,
) {
return a.stubResponse, a.terminateErr
}
type stubScaleSetsAPI struct {
createErr error
stubResponse stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse
}
type stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse struct {
pollResponse armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse
pollErr error
}
func (r stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse, error,
) {
return r.pollResponse, r.pollErr
}
func (a stubScaleSetsAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
vmScaleSetName string, parameters armcompute.VirtualMachineScaleSet,
options *armcompute.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
virtualMachineScaleSetsCreateOrUpdatePollerResponse, error,
) {
return a.stubResponse, a.createErr
}
type stubPublicIPAddressesAPI struct {
createErr error
getErr error
stubCreateResponse stubPublicIPAddressesClientCreateOrUpdatePollerResponse
}
type stubPublicIPAddressesClientCreateOrUpdatePollerResponse struct {
armnetwork.PublicIPAddressesClientCreateOrUpdatePollerResponse
pollErr error
}
func (r stubPublicIPAddressesClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armnetwork.PublicIPAddressesClientCreateOrUpdateResponse, error,
) {
return armnetwork.PublicIPAddressesClientCreateOrUpdateResponse{
PublicIPAddressesClientCreateOrUpdateResult: armnetwork.PublicIPAddressesClientCreateOrUpdateResult{
PublicIPAddress: armnetwork.PublicIPAddress{
ID: to.StringPtr("pubIP-id"),
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: to.StringPtr("192.0.2.1"),
},
},
},
}, r.pollErr
}
type stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager struct {
pagesCounter int
PagesMax int
}
func (p *stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager) NextPage(ctx context.Context) bool {
p.pagesCounter++
return p.pagesCounter <= p.PagesMax
}
func (p *stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager) PageResponse() armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse {
return armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse{
PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResult: armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResult{
PublicIPAddressListResult: armnetwork.PublicIPAddressListResult{
Value: []*armnetwork.PublicIPAddress{
{
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: to.StringPtr("192.0.2.1"),
},
},
},
},
},
}
}
func (a stubPublicIPAddressesAPI) ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string,
networkInterfaceName string, ipConfigurationName string,
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
) publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager {
return &stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager{pagesCounter: 0, PagesMax: 1}
}
func (a stubPublicIPAddressesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
publicIPAddressesClientCreateOrUpdatePollerResponse, error,
) {
return a.stubCreateResponse, a.createErr
}
func (a stubPublicIPAddressesAPI) Get(ctx context.Context, resourceGroupName string, publicIPAddressName string, options *armnetwork.PublicIPAddressesClientGetOptions) (
armnetwork.PublicIPAddressesClientGetResponse, error,
) {
return armnetwork.PublicIPAddressesClientGetResponse{
PublicIPAddressesClientGetResult: armnetwork.PublicIPAddressesClientGetResult{
PublicIPAddress: armnetwork.PublicIPAddress{
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: to.StringPtr("192.0.2.1"),
},
},
},
}, a.getErr
}
type stubNetworkInterfacesAPI struct {
getErr error
createErr error
stubResp stubInterfacesClientCreateOrUpdatePollerResponse
}
func (a stubNetworkInterfacesAPI) GetVirtualMachineScaleSetNetworkInterface(ctx context.Context, resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string, networkInterfaceName string,
options *armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceOptions,
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error) {
if a.getErr != nil {
return armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse{}, a.getErr
}
return armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse{
InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResult: armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResult{
Interface: armnetwork.Interface{
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.1"),
},
},
},
},
},
},
}, nil
}
// TODO: deprecate as soon as scale sets are available.
type stubInterfacesClientCreateOrUpdatePollerResponse struct {
pollErr error
}
// TODO: deprecate as soon as scale sets are available.
func (r stubInterfacesClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armnetwork.InterfacesClientCreateOrUpdateResponse, error,
) {
return armnetwork.InterfacesClientCreateOrUpdateResponse{
InterfacesClientCreateOrUpdateResult: armnetwork.InterfacesClientCreateOrUpdateResult{
Interface: armnetwork.Interface{
ID: to.StringPtr("interface-id"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.1"),
},
},
},
},
},
},
}, r.pollErr
}
// TODO: deprecate as soon as scale sets are available.
func (a stubNetworkInterfacesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, networkInterfaceName string,
parameters armnetwork.Interface, options *armnetwork.InterfacesClientBeginCreateOrUpdateOptions) (
interfacesClientCreateOrUpdatePollerResponse, error,
) {
return a.stubResp, a.createErr
}
// TODO: deprecate as soon as scale sets are available.
type stubVirtualMachinesAPI struct {
stubResponse stubVirtualMachinesClientCreateOrUpdatePollerResponse
createErr error
}
// TODO: deprecate as soon as scale sets are available.
type stubVirtualMachinesClientCreateOrUpdatePollerResponse struct {
pollResponse armcompute.VirtualMachinesClientCreateOrUpdateResponse
pollErr error
}
// TODO: deprecate as soon as scale sets are available.
func (r stubVirtualMachinesClientCreateOrUpdatePollerResponse) PollUntilDone(ctx context.Context, freq time.Duration) (
armcompute.VirtualMachinesClientCreateOrUpdateResponse, error,
) {
return r.pollResponse, r.pollErr
}
// TODO: deprecate as soon as scale sets are available.
func (a stubVirtualMachinesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine,
options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions,
) (virtualMachinesClientCreateOrUpdatePollerResponse, error) {
return a.stubResponse, a.createErr
}
type stubApplicationsAPI struct {
createErr error
deleteErr error
updateCredentialsErr error
createApplication *graphrbac.Application
}
func (a stubApplicationsAPI) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
if a.createErr != nil {
return graphrbac.Application{}, a.createErr
}
if a.createApplication != nil {
return *a.createApplication, nil
}
return graphrbac.Application{
AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"),
ObjectID: to.StringPtr("00000000-0000-0000-0000-000000000001"),
}, nil
}
func (a stubApplicationsAPI) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
if a.deleteErr != nil {
return autorest.Response{}, a.deleteErr
}
return autorest.Response{}, nil
}
func (a stubApplicationsAPI) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
if a.updateCredentialsErr != nil {
return autorest.Response{}, a.updateCredentialsErr
}
return autorest.Response{}, nil
}
type stubServicePrincipalsAPI struct {
createErr error
createServicePrincipal *graphrbac.ServicePrincipal
}
func (a stubServicePrincipalsAPI) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
if a.createErr != nil {
return graphrbac.ServicePrincipal{}, a.createErr
}
if a.createServicePrincipal != nil {
return *a.createServicePrincipal, nil
}
return graphrbac.ServicePrincipal{
AppID: to.StringPtr("00000000-0000-0000-0000-000000000000"),
ObjectID: to.StringPtr("00000000-0000-0000-0000-000000000002"),
}, nil
}
type stubRoleAssignmentsAPI struct {
createCounter int
createErrors []error
}
func (a *stubRoleAssignmentsAPI) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
a.createCounter++
if len(a.createErrors) == 0 {
return authorization.RoleAssignment{}, nil
}
return authorization.RoleAssignment{}, a.createErrors[(a.createCounter-1)%len(a.createErrors)]
}

View file

@ -0,0 +1,147 @@
package client
import (
"context"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
)
type networksClient struct {
*armnetwork.VirtualNetworksClient
}
func (c *networksClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
virtualNetworksCreateOrUpdatePollerResponse, error,
) {
return c.VirtualNetworksClient.BeginCreateOrUpdate(ctx, resourceGroupName, virtualNetworkName, parameters, options)
}
// TODO: deprecate as soon as scale sets are available.
type networkInterfacesClient struct {
*armnetwork.InterfacesClient
}
// TODO: deprecate as soon as scale sets are available.
func (c *networkInterfacesClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, networkInterfaceName string,
parameters armnetwork.Interface, options *armnetwork.InterfacesClientBeginCreateOrUpdateOptions,
) (interfacesClientCreateOrUpdatePollerResponse, error) {
return c.InterfacesClient.BeginCreateOrUpdate(ctx, resourceGroupName, networkInterfaceName, parameters, options)
}
type loadBalancersClient struct {
*armnetwork.LoadBalancersClient
}
func (c *loadBalancersClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, loadBalancerName string,
parameters armnetwork.LoadBalancer, options *armnetwork.LoadBalancersClientBeginCreateOrUpdateOptions) (
loadBalancersClientCreateOrUpdatePollerResponse, error,
) {
return c.LoadBalancersClient.BeginCreateOrUpdate(ctx, resourceGroupName, loadBalancerName, parameters, options)
}
type networkSecurityGroupsClient struct {
*armnetwork.SecurityGroupsClient
}
func (c *networkSecurityGroupsClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
networkSecurityGroupsCreateOrUpdatePollerResponse, error,
) {
return c.SecurityGroupsClient.BeginCreateOrUpdate(ctx, resourceGroupName, networkSecurityGroupName, parameters, options)
}
type publicIPAddressesClient struct {
*armnetwork.PublicIPAddressesClient
}
func (c *publicIPAddressesClient) ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string,
networkInterfaceName string, ipConfigurationName string,
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
) publicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager {
return c.PublicIPAddressesClient.ListVirtualMachineScaleSetVMPublicIPAddresses(resourceGroupName, virtualMachineScaleSetName,
virtualmachineIndex, networkInterfaceName, ipConfigurationName, options)
}
// TODO: deprecate as soon as scale sets are available.
func (c *publicIPAddressesClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
publicIPAddressesClientCreateOrUpdatePollerResponse, error,
) {
return c.PublicIPAddressesClient.BeginCreateOrUpdate(ctx, resourceGroupName, publicIPAddressName, parameters, options)
}
type virtualMachineScaleSetsClient struct {
*armcompute.VirtualMachineScaleSetsClient
}
func (c *virtualMachineScaleSetsClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
vmScaleSetName string, parameters armcompute.VirtualMachineScaleSet,
options *armcompute.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
virtualMachineScaleSetsCreateOrUpdatePollerResponse, error,
) {
return c.VirtualMachineScaleSetsClient.BeginCreateOrUpdate(ctx, resourceGroupName, vmScaleSetName, parameters, options)
}
type resourceGroupsClient struct {
*armresources.ResourceGroupsClient
}
func (c *resourceGroupsClient) BeginDelete(ctx context.Context, resourceGroupName string,
options *armresources.ResourceGroupsClientBeginDeleteOptions) (
resourceGroupsDeletePollerResponse, error,
) {
return c.ResourceGroupsClient.BeginDelete(ctx, resourceGroupName, options)
}
// TODO: deprecate as soon as scale sets are available.
type virtualMachinesClient struct {
*armcompute.VirtualMachinesClient
}
// TODO: deprecate as soon as scale sets are available.
func (c *virtualMachinesClient) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmName string, parameters armcompute.VirtualMachine,
options *armcompute.VirtualMachinesClientBeginCreateOrUpdateOptions,
) (virtualMachinesClientCreateOrUpdatePollerResponse, error) {
return c.VirtualMachinesClient.BeginCreateOrUpdate(ctx, resourceGroupName, vmName, parameters, options)
}
type applicationsClient struct {
*graphrbac.ApplicationsClient
}
func (c *applicationsClient) Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error) {
return c.ApplicationsClient.Create(ctx, parameters)
}
func (c *applicationsClient) Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error) {
return c.ApplicationsClient.Delete(ctx, applicationObjectID)
}
func (c *applicationsClient) UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error) {
return c.ApplicationsClient.UpdatePasswordCredentials(ctx, objectID, parameters)
}
type servicePrincipalsClient struct {
*graphrbac.ServicePrincipalsClient
}
func (c *servicePrincipalsClient) Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error) {
return c.ServicePrincipalsClient.Create(ctx, parameters)
}
type roleAssignmentsClient struct {
*authorization.RoleAssignmentsClient
}
func (c *roleAssignmentsClient) Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error) {
return c.RoleAssignmentsClient.Create(ctx, scope, roleAssignmentName, parameters)
}

View file

@ -0,0 +1,280 @@
package client
import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/state"
)
const (
graphAPIResource = "https://graph.windows.net"
managementAPIResource = "https://management.azure.com"
)
// Client is a client for Azure.
type Client struct {
networksAPI
networkSecurityGroupsAPI
resourceGroupAPI
scaleSetsAPI
publicIPAddressesAPI
networkInterfacesAPI
loadBalancersAPI
virtualMachinesAPI
applicationsAPI
servicePrincipalsAPI
roleAssignmentsAPI
adReplicationLagCheckInterval time.Duration
adReplicationLagCheckMaxRetries int
nodes cloudtypes.Instances
coordinators cloudtypes.Instances
name string
uid string
resourceGroup string
location string
subscriptionID string
tenantID string
subnetID string
coordinatorsScaleSet string
nodesScaleSet string
loadBalancerName string
loadBalancerPubIP string
networkSecurityGroup string
adAppObjectID string
}
// NewFromDefault creates a client with initialized clients.
func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return nil, err
}
graphAuthorizer, err := getAuthorizer(graphAPIResource)
if err != nil {
return nil, err
}
managementAuthorizer, err := getAuthorizer(managementAPIResource)
if err != nil {
return nil, err
}
netAPI := armnetwork.NewVirtualNetworksClient(subscriptionID, cred, nil)
netSecGrpAPI := armnetwork.NewSecurityGroupsClient(subscriptionID, cred, nil)
resGroupAPI := armresources.NewResourceGroupsClient(subscriptionID, cred, nil)
scaleSetAPI := armcompute.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
publicIPAddressesAPI := armnetwork.NewPublicIPAddressesClient(subscriptionID, cred, nil)
networkInterfacesAPI := armnetwork.NewInterfacesClient(subscriptionID, cred, nil)
loadBalancersAPI := armnetwork.NewLoadBalancersClient(subscriptionID, cred, nil)
virtualMachinesAPI := armcompute.NewVirtualMachinesClient(subscriptionID, cred, nil)
applicationsAPI := graphrbac.NewApplicationsClient(tenantID)
applicationsAPI.Authorizer = graphAuthorizer
servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID)
servicePrincipalsAPI.Authorizer = graphAuthorizer
roleAssignmentsAPI := authorization.NewRoleAssignmentsClient(subscriptionID)
roleAssignmentsAPI.Authorizer = managementAuthorizer
return &Client{
networksAPI: &networksClient{netAPI},
networkSecurityGroupsAPI: &networkSecurityGroupsClient{netSecGrpAPI},
resourceGroupAPI: &resourceGroupsClient{resGroupAPI},
scaleSetsAPI: &virtualMachineScaleSetsClient{scaleSetAPI},
publicIPAddressesAPI: &publicIPAddressesClient{publicIPAddressesAPI},
networkInterfacesAPI: &networkInterfacesClient{networkInterfacesAPI},
loadBalancersAPI: &loadBalancersClient{loadBalancersAPI},
applicationsAPI: &applicationsClient{&applicationsAPI},
servicePrincipalsAPI: &servicePrincipalsClient{&servicePrincipalsAPI},
roleAssignmentsAPI: &roleAssignmentsClient{&roleAssignmentsAPI},
virtualMachinesAPI: &virtualMachinesClient{virtualMachinesAPI},
subscriptionID: subscriptionID,
tenantID: tenantID,
nodes: cloudtypes.Instances{},
coordinators: cloudtypes.Instances{},
adReplicationLagCheckInterval: adReplicationLagCheckInterval,
adReplicationLagCheckMaxRetries: adReplicationLagCheckMaxRetries,
}, nil
}
// NewInitialized creates and initializes client by setting the subscriptionID, location and name
// of the Constellation.
func NewInitialized(subscriptionID, tenantID, name, location string) (*Client, error) {
client, err := NewFromDefault(subscriptionID, tenantID)
if err != nil {
return nil, err
}
err = client.init(location, name)
return client, err
}
// init initializes the client.
func (c *Client) init(location, name string) error {
c.location = location
c.name = name
uid, err := c.generateUID()
if err != nil {
return err
}
c.uid = uid
return nil
}
// GetState returns the state of the client as ConstellationState.
func (c *Client) GetState() (state.ConstellationState, error) {
var stat state.ConstellationState
stat.CloudProvider = cloudprovider.Azure.String()
if len(c.resourceGroup) == 0 {
return state.ConstellationState{}, errors.New("client has no resource group")
}
stat.AzureResourceGroup = c.resourceGroup
if c.name == "" {
return state.ConstellationState{}, errors.New("client has no name")
}
stat.Name = c.name
if len(c.uid) == 0 {
return state.ConstellationState{}, errors.New("client has no uid")
}
stat.UID = c.uid
if len(c.location) == 0 {
return state.ConstellationState{}, errors.New("client has no location")
}
stat.AzureLocation = c.location
if len(c.subscriptionID) == 0 {
return state.ConstellationState{}, errors.New("client has no subscription")
}
stat.AzureSubscription = c.subscriptionID
if len(c.tenantID) == 0 {
return state.ConstellationState{}, errors.New("client has no tenant")
}
stat.AzureTenant = c.tenantID
if len(c.subnetID) == 0 {
return state.ConstellationState{}, errors.New("client has no subnet")
}
stat.AzureSubnet = c.subnetID
if len(c.networkSecurityGroup) == 0 {
return state.ConstellationState{}, errors.New("client has no network security group")
}
stat.AzureNetworkSecurityGroup = c.networkSecurityGroup
if len(c.nodesScaleSet) == 0 {
return state.ConstellationState{}, errors.New("client has no nodes scale set")
}
stat.AzureNodesScaleSet = c.nodesScaleSet
if len(c.coordinatorsScaleSet) == 0 {
return state.ConstellationState{}, errors.New("client has no coordinators scale set")
}
stat.AzureCoordinatorsScaleSet = c.coordinatorsScaleSet
if len(c.nodes) == 0 {
return state.ConstellationState{}, errors.New("client has no nodes")
}
stat.AzureNodes = c.nodes
if len(c.coordinators) == 0 {
return state.ConstellationState{}, errors.New("client has no coordinators")
}
stat.AzureCoordinators = c.coordinators
// AD App Object ID does not have to be set at all times
stat.AzureADAppObjectID = c.adAppObjectID
return stat, nil
}
// SetState sets the state of the client to the handed ConstellationState.
func (c *Client) SetState(stat state.ConstellationState) error {
if stat.CloudProvider != cloudprovider.Azure.String() {
return errors.New("state is not azure state")
}
if len(stat.AzureResourceGroup) == 0 {
return errors.New("state has no resource group")
}
c.resourceGroup = stat.AzureResourceGroup
if stat.Name == "" {
return errors.New("state has no name")
}
c.name = stat.Name
if len(stat.UID) == 0 {
return errors.New("state has no uuid")
}
c.uid = stat.UID
if len(stat.AzureLocation) == 0 {
return errors.New("state has no location")
}
c.location = stat.AzureLocation
if len(stat.AzureSubscription) == 0 {
return errors.New("state has no subscription")
}
c.subscriptionID = stat.AzureSubscription
if len(stat.AzureTenant) == 0 {
return errors.New("state has no tenant")
}
c.tenantID = stat.AzureTenant
if len(stat.AzureSubnet) == 0 {
return errors.New("state has no subnet")
}
c.subnetID = stat.AzureSubnet
if len(stat.AzureNetworkSecurityGroup) == 0 {
return errors.New("state has no subnet")
}
c.networkSecurityGroup = stat.AzureNetworkSecurityGroup
if len(stat.AzureNodesScaleSet) == 0 {
return errors.New("state has no nodes scale set")
}
c.nodesScaleSet = stat.AzureNodesScaleSet
if len(stat.AzureCoordinatorsScaleSet) == 0 {
return errors.New("state has no nodes scale set")
}
c.coordinatorsScaleSet = stat.AzureCoordinatorsScaleSet
if len(stat.AzureNodes) == 0 {
return errors.New("state has no nodes")
}
c.nodes = stat.AzureNodes
if len(stat.AzureCoordinators) == 0 {
return errors.New("state has no coordinators")
}
c.coordinators = stat.AzureCoordinators
// AD App Object ID does not have to be set at all times
c.adAppObjectID = stat.AzureADAppObjectID
return nil
}
func (c *Client) generateUID() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
const uidLen = 5
uid := make([]byte, uidLen)
for i := 0; i < uidLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
uid[i] = letters[n.Int64()]
}
return string(uid), nil
}
// getAuthorizer creates an autorest.Authorizer for different Azure AD APIs using either environment variables or azure cli credentials.
func getAuthorizer(resource string) (autorest.Authorizer, error) {
authorizer, cliErr := auth.NewAuthorizerFromCLIWithResource(resource)
if cliErr == nil {
return authorizer, nil
}
authorizer, envErr := auth.NewAuthorizerFromEnvironmentWithResource(resource)
if envErr == nil {
return authorizer, nil
}
return nil, fmt.Errorf("unable to create authorizer from env or cli: %v %v", envErr, cliErr)
}

View file

@ -0,0 +1,484 @@
package client
import (
"testing"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/state"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSetGetState(t *testing.T) {
testCases := map[string]struct {
state state.ConstellationState
wantErr bool
}{
"valid state": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
},
"missing nodes": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
wantErr: true,
},
"missing coordinator": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
wantErr: true,
},
"missing name": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
wantErr: true,
},
"missing uid": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
wantErr: true,
},
"missing resource group": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
wantErr: true,
},
"missing location": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
wantErr: true,
},
"missing subscription": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureTenant: "tenant",
AzureLocation: "location",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
wantErr: true,
},
"missing tenant": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureSubscription: "subscription",
AzureLocation: "location",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
wantErr: true,
},
"missing subnet": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
wantErr: true,
},
"missing network security group": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
wantErr: true,
},
"missing node scale set": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
},
wantErr: true,
},
"missing coordinator scale set": {
state: state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
},
wantErr: true,
},
}
t.Run("SetState", func(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := Client{}
if tc.wantErr {
assert.Error(client.SetState(tc.state))
} else {
assert.NoError(client.SetState(tc.state))
assert.Equal(tc.state.AzureNodes, client.nodes)
assert.Equal(tc.state.AzureCoordinators, client.coordinators)
assert.Equal(tc.state.Name, client.name)
assert.Equal(tc.state.UID, client.uid)
assert.Equal(tc.state.AzureResourceGroup, client.resourceGroup)
assert.Equal(tc.state.AzureLocation, client.location)
assert.Equal(tc.state.AzureSubscription, client.subscriptionID)
assert.Equal(tc.state.AzureTenant, client.tenantID)
assert.Equal(tc.state.AzureSubnet, client.subnetID)
assert.Equal(tc.state.AzureNetworkSecurityGroup, client.networkSecurityGroup)
assert.Equal(tc.state.AzureNodesScaleSet, client.nodesScaleSet)
assert.Equal(tc.state.AzureCoordinatorsScaleSet, client.coordinatorsScaleSet)
}
})
}
})
t.Run("GetState", func(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := Client{
nodes: tc.state.AzureNodes,
coordinators: tc.state.AzureCoordinators,
name: tc.state.Name,
uid: tc.state.UID,
resourceGroup: tc.state.AzureResourceGroup,
location: tc.state.AzureLocation,
subscriptionID: tc.state.AzureSubscription,
tenantID: tc.state.AzureTenant,
subnetID: tc.state.AzureSubnet,
networkSecurityGroup: tc.state.AzureNetworkSecurityGroup,
nodesScaleSet: tc.state.AzureNodesScaleSet,
coordinatorsScaleSet: tc.state.AzureCoordinatorsScaleSet,
}
if tc.wantErr {
_, err := client.GetState()
assert.Error(err)
} else {
state, err := client.GetState()
assert.NoError(err)
assert.Equal(tc.state, state)
}
})
}
})
}
func TestSetStateCloudProvider(t *testing.T) {
assert := assert.New(t)
client := Client{}
stateMissingCloudProvider := state.ConstellationState{
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
}
assert.Error(client.SetState(stateMissingCloudProvider))
stateIncorrectCloudProvider := state.ConstellationState{
CloudProvider: "incorrect",
AzureNodes: cloudtypes.Instances{
"0": {
PublicIP: "ip1",
PrivateIP: "ip2",
},
},
AzureCoordinators: cloudtypes.Instances{
"0": {
PublicIP: "ip3",
PrivateIP: "ip4",
},
},
Name: "name",
UID: "uid",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureNodesScaleSet: "node-scale-set",
AzureCoordinatorsScaleSet: "coordinator-scale-set",
}
assert.Error(client.SetState(stateIncorrectCloudProvider))
}
func TestInit(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{}
require.NoError(client.init("location", "name"))
assert.Equal("location", client.location)
assert.Equal("name", client.name)
assert.NotEmpty(client.uid)
}

View file

@ -0,0 +1,305 @@
package client
import (
"context"
"errors"
"strconv"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/edgelesssys/constellation/cli/internal/azure"
)
func (c *Client) CreateInstances(ctx context.Context, input CreateInstancesInput) error {
// Create nodes scale set
createNodesInput := CreateScaleSetInput{
Name: "constellation-scale-set-nodes-" + c.uid,
NamePrefix: c.name + "-worker-" + c.uid + "-",
Count: input.CountNodes,
InstanceType: input.InstanceType,
StateDiskSizeGB: int32(input.StateDiskSizeGB),
Image: input.Image,
UserAssingedIdentity: input.UserAssingedIdentity,
LoadBalancerBackendAddressPool: azure.BackendAddressPoolWorkerName + "-" + c.uid,
}
if err := c.createScaleSet(ctx, createNodesInput); err != nil {
return err
}
c.nodesScaleSet = createNodesInput.Name
// Create coordinator scale set
createCoordinatorsInput := CreateScaleSetInput{
Name: "constellation-scale-set-coordinators-" + c.uid,
NamePrefix: c.name + "-control-plane-" + c.uid + "-",
Count: input.CountCoordinators,
InstanceType: input.InstanceType,
StateDiskSizeGB: int32(input.StateDiskSizeGB),
Image: input.Image,
UserAssingedIdentity: input.UserAssingedIdentity,
LoadBalancerBackendAddressPool: azure.BackendAddressPoolControlPlaneName + "-" + c.uid,
}
if err := c.createScaleSet(ctx, createCoordinatorsInput); err != nil {
return err
}
// Get nodes IPs
instances, err := c.getInstanceIPs(ctx, createNodesInput.Name, createNodesInput.Count)
if err != nil {
return err
}
c.nodes = instances
// Get coordinators IPs
c.coordinatorsScaleSet = createCoordinatorsInput.Name
instances, err = c.getInstanceIPs(ctx, createCoordinatorsInput.Name, createCoordinatorsInput.Count)
if err != nil {
return err
}
c.coordinators = instances
// Set the load balancer public IP in the first coordinator
coord, ok := c.coordinators["0"]
if !ok {
return errors.New("coordinator 0 not found")
}
coord.PublicIP = c.loadBalancerPubIP
c.coordinators["0"] = coord
return nil
}
// CreateInstancesInput is the input for a CreateInstances operation.
type CreateInstancesInput struct {
CountNodes int
CountCoordinators int
InstanceType string
StateDiskSizeGB int
Image string
UserAssingedIdentity string
}
// CreateInstancesVMs creates instances based on standalone VMs.
// TODO: deprecate as soon as scale sets are available.
func (c *Client) CreateInstancesVMs(ctx context.Context, input CreateInstancesInput) error {
pw, err := azure.GeneratePassword()
if err != nil {
return err
}
for i := 0; i < input.CountCoordinators; i++ {
vm := azure.VMInstance{
Name: c.name + "-control-plane-" + c.uid + "-" + strconv.Itoa(i),
Username: "constell",
Password: pw,
Location: c.location,
InstanceType: input.InstanceType,
Image: input.Image,
}
instance, err := c.createInstanceVM(ctx, vm)
if err != nil {
return err
}
c.coordinators[strconv.Itoa(i)] = instance
}
for i := 0; i < input.CountNodes; i++ {
vm := azure.VMInstance{
Name: c.name + "-node-" + c.uid + "-" + strconv.Itoa(i),
Username: "constell",
Password: pw,
Location: c.location,
InstanceType: input.InstanceType,
Image: input.Image,
}
instance, err := c.createInstanceVM(ctx, vm)
if err != nil {
return err
}
c.nodes[strconv.Itoa(i)] = instance
}
return nil
}
// createInstanceVM creates a single VM with a public IP address
// and a network interface.
// TODO: deprecate as soon as scale sets are available.
func (c *Client) createInstanceVM(ctx context.Context, input azure.VMInstance) (cloudtypes.Instance, error) {
pubIPName := input.Name + "-pubIP"
pubIP, err := c.createPublicIPAddress(ctx, pubIPName)
if err != nil {
return cloudtypes.Instance{}, err
}
nicName := input.Name + "-NIC"
privIP, nicID, err := c.createNIC(ctx, nicName, *pubIP.ID)
if err != nil {
return cloudtypes.Instance{}, err
}
input.NIC = nicID
poller, err := c.virtualMachinesAPI.BeginCreateOrUpdate(ctx, c.resourceGroup, input.Name, input.Azure(), nil)
if err != nil {
return cloudtypes.Instance{}, err
}
vm, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return cloudtypes.Instance{}, err
}
if vm.Identity == nil || vm.Identity.PrincipalID == nil {
return cloudtypes.Instance{}, errors.New("virtual machine was created without system managed identity")
}
if err := c.assignResourceGroupRole(ctx, *vm.Identity.PrincipalID, virtualMachineContributorRoleDefinitionID); err != nil {
return cloudtypes.Instance{}, err
}
res, err := c.publicIPAddressesAPI.Get(ctx, c.resourceGroup, pubIPName, nil)
if err != nil {
return cloudtypes.Instance{}, err
}
return cloudtypes.Instance{PublicIP: *res.PublicIPAddressesClientGetResult.PublicIPAddress.Properties.IPAddress, PrivateIP: privIP}, nil
}
func (c *Client) createScaleSet(ctx context.Context, input CreateScaleSetInput) error {
// TODO: Generating a random password to be able
// to create the scale set. This is a temporary fix.
// We need to think about azure access at some point.
pw, err := azure.GeneratePassword()
if err != nil {
return err
}
scaleSet := azure.ScaleSet{
Name: input.Name,
NamePrefix: input.NamePrefix,
Location: c.location,
InstanceType: input.InstanceType,
StateDiskSizeGB: input.StateDiskSizeGB,
Count: int64(input.Count),
Username: "constellation",
SubnetID: c.subnetID,
NetworkSecurityGroup: c.networkSecurityGroup,
Image: input.Image,
Password: pw,
UserAssignedIdentity: input.UserAssingedIdentity,
Subscription: c.subscriptionID,
ResourceGroup: c.resourceGroup,
LoadBalancerName: c.loadBalancerName,
LoadBalancerBackendAddressPool: input.LoadBalancerBackendAddressPool,
}.Azure()
poller, err := c.scaleSetsAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, input.Name,
scaleSet,
nil,
)
if err != nil {
return err
}
_, err = poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return err
}
return nil
}
func (c *Client) getInstanceIPs(ctx context.Context, scaleSet string, count int) (cloudtypes.Instances, error) {
instances := cloudtypes.Instances{}
for i := 0; i < count; i++ {
// get public ip address
var publicIPAddress string
pager := c.publicIPAddressesAPI.ListVirtualMachineScaleSetVMPublicIPAddresses(
c.resourceGroup, scaleSet, strconv.Itoa(i), scaleSet, scaleSet, nil)
// We always need one pager.NextPage, since calling
// pager.PageResponse() directly return no result.
// We expect to get one page with one entry for each VM.
for pager.NextPage(ctx) {
for _, v := range pager.PageResponse().Value {
if v.Properties != nil && v.Properties.IPAddress != nil {
publicIPAddress = *v.Properties.IPAddress
break
}
}
}
// get private ip address
var privateIPAddress string
res, err := c.networkInterfacesAPI.GetVirtualMachineScaleSetNetworkInterface(
ctx, c.resourceGroup, scaleSet, strconv.Itoa(i), scaleSet, nil)
if err != nil {
return nil, err
}
configs := res.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResult.Interface.Properties.IPConfigurations
for _, config := range configs {
privateIPAddress = *config.Properties.PrivateIPAddress
break
}
instance := cloudtypes.Instance{
PrivateIP: privateIPAddress,
PublicIP: publicIPAddress,
}
instances[strconv.Itoa(i)] = instance
}
return instances, nil
}
// CreateScaleSetInput is the input for a CreateScaleSet operation.
type CreateScaleSetInput struct {
Name string
NamePrefix string
Count int
InstanceType string
StateDiskSizeGB int32
Image string
UserAssingedIdentity string
LoadBalancerBackendAddressPool string
}
// CreateResourceGroup creates a resource group.
func (c *Client) CreateResourceGroup(ctx context.Context) error {
_, err := c.resourceGroupAPI.CreateOrUpdate(ctx, c.name+"-"+c.uid,
armresources.ResourceGroup{
Location: &c.location,
}, nil)
if err != nil {
return err
}
c.resourceGroup = c.name + "-" + c.uid
return nil
}
// TerminateResourceGroup terminates a resource group.
func (c *Client) TerminateResourceGroup(ctx context.Context) error {
if c.resourceGroup == "" {
return nil
}
poller, err := c.resourceGroupAPI.BeginDelete(ctx, c.resourceGroup, nil)
if err != nil {
return err
}
if _, err = poller.PollUntilDone(ctx, 30*time.Second); err != nil {
return err
}
c.nodes = nil
c.coordinators = nil
c.resourceGroup = ""
c.subnetID = ""
c.networkSecurityGroup = ""
c.nodesScaleSet = ""
c.coordinatorsScaleSet = ""
return nil
}

View file

@ -0,0 +1,384 @@
package client
import (
"context"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateResourceGroup(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
resourceGroupAPI resourceGroupAPI
wantErr bool
}{
"successful create": {
resourceGroupAPI: stubResourceGroupAPI{},
},
"failed create": {
resourceGroupAPI: stubResourceGroupAPI{createErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroupAPI: tc.resourceGroupAPI,
nodes: make(cloudtypes.Instances),
coordinators: make(cloudtypes.Instances),
}
if tc.wantErr {
assert.Error(client.CreateResourceGroup(ctx))
} else {
assert.NoError(client.CreateResourceGroup(ctx))
assert.Equal(client.name+"-"+client.uid, client.resourceGroup)
}
})
}
}
func TestTerminateResourceGroup(t *testing.T) {
someErr := errors.New("failed")
clientWithResourceGroup := Client{
resourceGroup: "name",
location: "location",
name: "name",
uid: "uid",
subnetID: "subnet",
nodesScaleSet: "node-scale-set",
coordinatorsScaleSet: "coordinator-scale-set",
nodes: cloudtypes.Instances{
"0": {
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
},
},
coordinators: cloudtypes.Instances{
"0": {
PublicIP: "192.0.2.1", PrivateIP: "192.0.2.1",
},
},
}
testCases := map[string]struct {
resourceGroup string
resourceGroupAPI resourceGroupAPI
client Client
wantErr bool
}{
"successful terminate": {
resourceGroupAPI: stubResourceGroupAPI{},
client: clientWithResourceGroup,
},
"no resource group to terminate": {
resourceGroupAPI: stubResourceGroupAPI{},
client: Client{},
resourceGroup: "",
},
"failed terminate": {
resourceGroupAPI: stubResourceGroupAPI{terminateErr: someErr},
client: clientWithResourceGroup,
wantErr: true,
},
"failed to poll terminate response": {
resourceGroupAPI: stubResourceGroupAPI{stubResponse: stubResourceGroupsDeletePollerResponse{pollerErr: someErr}},
client: clientWithResourceGroup,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
tc.client.resourceGroupAPI = tc.resourceGroupAPI
ctx := context.Background()
if tc.wantErr {
assert.Error(tc.client.TerminateResourceGroup(ctx))
return
}
assert.NoError(tc.client.TerminateResourceGroup(ctx))
assert.Empty(tc.client.resourceGroup)
assert.Empty(tc.client.subnetID)
assert.Empty(tc.client.nodes)
assert.Empty(tc.client.coordinators)
assert.Empty(tc.client.nodesScaleSet)
assert.Empty(tc.client.coordinatorsScaleSet)
})
}
}
func TestCreateInstances(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
networkInterfacesAPI networkInterfacesAPI
scaleSetsAPI scaleSetsAPI
resourceGroupAPI resourceGroupAPI
roleAssignmentsAPI roleAssignmentsAPI
createInstancesInput CreateInstancesInput
wantErr bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{
stubResponse: stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse{
pollResponse: armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResponse{
VirtualMachineScaleSetsClientCreateOrUpdateResult: armcompute.VirtualMachineScaleSetsClientCreateOrUpdateResult{
VirtualMachineScaleSet: armcompute.VirtualMachineScaleSet{Identity: &armcompute.VirtualMachineScaleSetIdentity{PrincipalID: to.StringPtr("principal-id")}},
},
},
},
},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountCoordinators: 3,
CountNodes: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
},
},
"error when creating scale set": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{createErr: someErr},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountCoordinators: 3,
CountNodes: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
},
wantErr: true,
},
"error when polling create scale set response": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{stubResponse: stubVirtualMachineScaleSetsCreateOrUpdatePollerResponse{pollErr: someErr}},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountCoordinators: 3,
CountNodes: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
},
wantErr: true,
},
"error when retrieving private IPs": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr},
scaleSetsAPI: stubScaleSetsAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountNodes: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroup: "name",
publicIPAddressesAPI: tc.publicIPAddressesAPI,
networkInterfacesAPI: tc.networkInterfacesAPI,
scaleSetsAPI: tc.scaleSetsAPI,
resourceGroupAPI: tc.resourceGroupAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
nodes: make(cloudtypes.Instances),
coordinators: make(cloudtypes.Instances),
loadBalancerPubIP: "lbip",
}
if tc.wantErr {
assert.Error(client.CreateInstances(ctx, tc.createInstancesInput))
} else {
assert.NoError(client.CreateInstances(ctx, tc.createInstancesInput))
assert.Equal(tc.createInstancesInput.CountCoordinators, len(client.coordinators))
assert.Equal(tc.createInstancesInput.CountNodes, len(client.nodes))
assert.NotEmpty(client.nodes["0"].PrivateIP)
assert.NotEmpty(client.nodes["0"].PublicIP)
assert.NotEmpty(client.coordinators["0"].PrivateIP)
assert.Equal("lbip", client.coordinators["0"].PublicIP)
}
})
}
}
// TODO: deprecate as soon as scale sets are available.
func TestCreateInstancesVMs(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
networkInterfacesAPI networkInterfacesAPI
virtualMachinesAPI virtualMachinesAPI
resourceGroupAPI resourceGroupAPI
roleAssignmentsAPI roleAssignmentsAPI
createInstancesInput CreateInstancesInput
wantErr bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{
stubResponse: stubVirtualMachinesClientCreateOrUpdatePollerResponse{
pollResponse: armcompute.VirtualMachinesClientCreateOrUpdateResponse{VirtualMachinesClientCreateOrUpdateResult: armcompute.VirtualMachinesClientCreateOrUpdateResult{
VirtualMachine: armcompute.VirtualMachine{
Identity: &armcompute.VirtualMachineIdentity{PrincipalID: to.StringPtr("principal-id")},
},
}},
},
},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountCoordinators: 3,
CountNodes: 3,
InstanceType: "type",
Image: "image",
},
},
"error when creating scale set": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{createErr: someErr},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountCoordinators: 3,
CountNodes: 3,
InstanceType: "type",
Image: "image",
},
wantErr: true,
},
"error when polling create scale set response": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{stubResponse: stubVirtualMachinesClientCreateOrUpdatePollerResponse{pollErr: someErr}},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountCoordinators: 3,
CountNodes: 3,
InstanceType: "type",
Image: "image",
},
wantErr: true,
},
"error when creating NIC": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{createErr: someErr},
virtualMachinesAPI: stubVirtualMachinesAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountCoordinators: 3,
CountNodes: 3,
InstanceType: "type",
Image: "image",
},
wantErr: true,
},
"error when creating public IP": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{createErr: someErr},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountCoordinators: 3,
CountNodes: 3,
InstanceType: "type",
Image: "image",
},
wantErr: true,
},
"error when retrieving public IP": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{getErr: someErr},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
virtualMachinesAPI: stubVirtualMachinesAPI{},
resourceGroupAPI: newSuccessfulResourceGroupStub(),
roleAssignmentsAPI: &stubRoleAssignmentsAPI{},
createInstancesInput: CreateInstancesInput{
CountCoordinators: 3,
CountNodes: 3,
InstanceType: "type",
Image: "image",
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroup: "name",
publicIPAddressesAPI: tc.publicIPAddressesAPI,
networkInterfacesAPI: tc.networkInterfacesAPI,
virtualMachinesAPI: tc.virtualMachinesAPI,
resourceGroupAPI: tc.resourceGroupAPI,
roleAssignmentsAPI: tc.roleAssignmentsAPI,
nodes: make(cloudtypes.Instances),
coordinators: make(cloudtypes.Instances),
}
if tc.wantErr {
assert.Error(client.CreateInstancesVMs(ctx, tc.createInstancesInput))
return
}
require.NoError(client.CreateInstancesVMs(ctx, tc.createInstancesInput))
assert.Equal(tc.createInstancesInput.CountCoordinators, len(client.coordinators))
assert.Equal(tc.createInstancesInput.CountNodes, len(client.nodes))
assert.NotEmpty(client.nodes["0"].PrivateIP)
assert.NotEmpty(client.nodes["0"].PublicIP)
assert.NotEmpty(client.coordinators["0"].PrivateIP)
assert.NotEmpty(client.coordinators["0"].PublicIP)
})
}
}
func newSuccessfulResourceGroupStub() *stubResourceGroupAPI {
return &stubResourceGroupAPI{
getResourceGroup: armresources.ResourceGroup{
ID: to.StringPtr("resource-group-id"),
},
}
}

View file

@ -0,0 +1,231 @@
package client
import (
"context"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/edgelesssys/constellation/cli/internal/azure"
)
type createNetworkInput struct {
name string
location string
addressSpace string
nodeAddressSpace string
podAddressSpace string
}
const (
nodeNetworkName = "nodeNetwork"
podNetworkName = "podNetwork"
networkAddressSpace = "10.0.0.0/8"
nodeAddressSpace = "10.9.0.0/16"
podAddressSpace = "10.10.0.0/16"
)
// CreateVirtualNetwork creates a virtual network.
func (c *Client) CreateVirtualNetwork(ctx context.Context) error {
createNetworkInput := createNetworkInput{
name: "constellation-" + c.uid,
location: c.location,
addressSpace: networkAddressSpace,
nodeAddressSpace: nodeAddressSpace,
podAddressSpace: podAddressSpace,
}
poller, err := c.networksAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, createNetworkInput.name,
armnetwork.VirtualNetwork{
Name: to.StringPtr(createNetworkInput.name), // this is supposed to be read-only
Location: to.StringPtr(createNetworkInput.location),
Properties: &armnetwork.VirtualNetworkPropertiesFormat{
AddressSpace: &armnetwork.AddressSpace{
AddressPrefixes: []*string{
to.StringPtr(createNetworkInput.addressSpace),
},
},
Subnets: []*armnetwork.Subnet{
{
Name: to.StringPtr(nodeNetworkName),
Properties: &armnetwork.SubnetPropertiesFormat{
AddressPrefix: to.StringPtr(createNetworkInput.nodeAddressSpace),
},
},
{
Name: to.StringPtr(podNetworkName),
Properties: &armnetwork.SubnetPropertiesFormat{
AddressPrefix: to.StringPtr(createNetworkInput.podAddressSpace),
},
},
},
},
},
nil,
)
if err != nil {
return err
}
resp, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return err
}
c.subnetID = *resp.VirtualNetworksClientCreateOrUpdateResult.VirtualNetwork.Properties.Subnets[0].ID
return nil
}
type createNetworkSecurityGroupInput struct {
name string
location string
rules []*armnetwork.SecurityRule
}
// CreateSecurityGroup creates a security group containing firewall rules.
func (c *Client) CreateSecurityGroup(ctx context.Context, input NetworkSecurityGroupInput) error {
rules, err := input.Ingress.Azure()
if err != nil {
return err
}
createNetworkSecurityGroupInput := createNetworkSecurityGroupInput{
name: "constellation-security-group-" + c.uid,
location: c.location,
rules: rules,
}
poller, err := c.networkSecurityGroupsAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, createNetworkSecurityGroupInput.name,
armnetwork.SecurityGroup{
Name: to.StringPtr(createNetworkSecurityGroupInput.name),
Location: to.StringPtr(createNetworkSecurityGroupInput.location),
Properties: &armnetwork.SecurityGroupPropertiesFormat{
SecurityRules: createNetworkSecurityGroupInput.rules,
},
},
nil,
)
if err != nil {
return err
}
pollerResp, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return err
}
c.networkSecurityGroup = *pollerResp.SecurityGroupsClientCreateOrUpdateResult.SecurityGroup.ID
return nil
}
// createNIC creates a network interface that references a public IP address.
// TODO: deprecate as soon as scale sets are available.
func (c *Client) createNIC(ctx context.Context, name, publicIPAddressID string) (ip string, id string, err error) {
poller, err := c.networkInterfacesAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, name,
armnetwork.Interface{
Location: to.StringPtr(c.location),
Properties: &armnetwork.InterfacePropertiesFormat{
NetworkSecurityGroup: &armnetwork.SecurityGroup{
ID: to.StringPtr(c.networkSecurityGroup),
},
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Name: to.StringPtr(name),
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
Subnet: &armnetwork.Subnet{
ID: to.StringPtr(c.subnetID),
},
PublicIPAddress: &armnetwork.PublicIPAddress{
ID: to.StringPtr(publicIPAddressID),
},
},
},
},
},
},
nil,
)
if err != nil {
return "", "", err
}
pollerResp, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return "", "", err
}
netInterface := pollerResp.InterfacesClientCreateOrUpdateResult.Interface
return *netInterface.Properties.IPConfigurations[0].Properties.PrivateIPAddress,
*netInterface.ID,
nil
}
func (c *Client) createPublicIPAddress(ctx context.Context, name string) (*armnetwork.PublicIPAddress, error) {
poller, err := c.publicIPAddressesAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, name,
armnetwork.PublicIPAddress{
Location: to.StringPtr(c.location),
SKU: &armnetwork.PublicIPAddressSKU{
Name: armnetwork.PublicIPAddressSKUNameStandard.ToPtr(),
},
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
PublicIPAllocationMethod: armnetwork.IPAllocationMethodStatic.ToPtr(),
},
},
nil,
)
if err != nil {
return nil, err
}
pollerResp, err := poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return nil, err
}
return &pollerResp.PublicIPAddressesClientCreateOrUpdateResult.PublicIPAddress, nil
}
// NetworkSecurityGroupInput defines firewall rules to be set.
type NetworkSecurityGroupInput struct {
Ingress cloudtypes.Firewall
Egress cloudtypes.Firewall
}
// CreateExternalLoadBalancer creates an external load balancer.
func (c *Client) CreateExternalLoadBalancer(ctx context.Context) error {
// First, create a public IP address for the load balancer.
publicIPAddress, err := c.createPublicIPAddress(ctx, "loadbalancer-public-ip-"+c.uid)
if err != nil {
return err
}
// Then, create the load balancer.
loadBalancerName := "constellation-load-balancer-" + c.uid
loadBalancer := azure.LoadBalancer{
Name: loadBalancerName,
Location: c.location,
ResourceGroup: c.resourceGroup,
Subscription: c.subscriptionID,
PublicIPID: *publicIPAddress.ID,
UID: c.uid,
}
azureLoadBalancer := loadBalancer.Azure()
poller, err := c.loadBalancersAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, loadBalancerName,
azureLoadBalancer,
nil,
)
if err != nil {
return err
}
_, err = poller.PollUntilDone(ctx, 30*time.Second)
if err != nil {
return err
}
c.loadBalancerName = loadBalancerName
c.loadBalancerPubIP = *publicIPAddress.Properties.IPAddress
return nil
}

View file

@ -0,0 +1,273 @@
package client
import (
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/stretchr/testify/assert"
)
func TestCreateVirtualNetwork(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
networksAPI networksAPI
wantErr bool
}{
"successful create": {
networksAPI: stubNetworksAPI{},
},
"failed to get response from successful create": {
networksAPI: stubNetworksAPI{stubResponse: stubVirtualNetworksCreateOrUpdatePollerResponse{pollerErr: someErr}},
wantErr: true,
},
"failed create": {
networksAPI: stubNetworksAPI{createErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
networksAPI: tc.networksAPI,
nodes: make(cloudtypes.Instances),
coordinators: make(cloudtypes.Instances),
}
if tc.wantErr {
assert.Error(client.CreateVirtualNetwork(ctx))
} else {
assert.NoError(client.CreateVirtualNetwork(ctx))
assert.NotEmpty(client.subnetID)
}
})
}
}
func TestCreateSecurityGroup(t *testing.T) {
someErr := errors.New("failed")
testNetworkSecurityGroupInput := NetworkSecurityGroupInput{
Ingress: cloudtypes.Firewall{
{
Name: "test-1",
Description: "test-1 description",
Protocol: "tcp",
IPRange: "192.0.2.0/24",
FromPort: 9000,
},
{
Name: "test-2",
Description: "test-2 description",
Protocol: "udp",
IPRange: "192.0.2.0/24",
FromPort: 51820,
},
},
Egress: cloudtypes.Firewall{},
}
testCases := map[string]struct {
networkSecurityGroupsAPI networkSecurityGroupsAPI
wantErr bool
}{
"successful create": {
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{},
},
"failed to get response from successful create": {
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{stubPoller: stubNetworkSecurityGroupsCreateOrUpdatePollerResponse{pollerErr: someErr}},
wantErr: true,
},
"failed create": {
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{createErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
nodes: make(cloudtypes.Instances),
coordinators: make(cloudtypes.Instances),
networkSecurityGroupsAPI: tc.networkSecurityGroupsAPI,
}
if tc.wantErr {
assert.Error(client.CreateSecurityGroup(ctx, testNetworkSecurityGroupInput))
} else {
assert.NoError(client.CreateSecurityGroup(ctx, testNetworkSecurityGroupInput))
assert.Equal("network-security-group-id", client.networkSecurityGroup)
}
})
}
}
// TODO: deprecate as soon as scale sets are available.
func TestCreateNIC(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
networkInterfacesAPI networkInterfacesAPI
name string
publicIPAddressID string
wantErr bool
}{
"successful create": {
networkInterfacesAPI: stubNetworkInterfacesAPI{},
name: "nic-name",
publicIPAddressID: "pubIP-id",
},
"failed to get response from successful create": {
networkInterfacesAPI: stubNetworkInterfacesAPI{stubResp: stubInterfacesClientCreateOrUpdatePollerResponse{pollErr: someErr}},
wantErr: true,
},
"failed create": {
networkInterfacesAPI: stubNetworkInterfacesAPI{createErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
nodes: make(cloudtypes.Instances),
coordinators: make(cloudtypes.Instances),
networkInterfacesAPI: tc.networkInterfacesAPI,
}
ip, id, err := client.createNIC(ctx, tc.name, tc.publicIPAddressID)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotEmpty(ip)
assert.NotEmpty(id)
}
})
}
}
func TestCreatePublicIPAddress(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
name string
wantErr bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
name: "nic-name",
},
"failed to get response from successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{stubCreateResponse: stubPublicIPAddressesClientCreateOrUpdatePollerResponse{pollErr: someErr}},
wantErr: true,
},
"failed create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{createErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
nodes: make(cloudtypes.Instances),
coordinators: make(cloudtypes.Instances),
publicIPAddressesAPI: tc.publicIPAddressesAPI,
}
id, err := client.createPublicIPAddress(ctx, tc.name)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotEmpty(id)
}
})
}
}
func TestCreateExternalLoadBalancer(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
loadBalancersAPI loadBalancersAPI
wantErr bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{stubCreateResponse: stubPublicIPAddressesClientCreateOrUpdatePollerResponse{}},
loadBalancersAPI: stubLoadBalancersAPI{},
},
"failed to get response from successful create": {
loadBalancersAPI: stubLoadBalancersAPI{stubResponse: stubLoadBalancersClientCreateOrUpdatePollerResponse{pollErr: someErr}},
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
wantErr: true,
},
"failed create": {
loadBalancersAPI: stubLoadBalancersAPI{createErr: someErr},
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
wantErr: true,
},
"cannot create public IP": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{createErr: someErr},
loadBalancersAPI: stubLoadBalancersAPI{},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
nodes: make(cloudtypes.Instances),
coordinators: make(cloudtypes.Instances),
loadBalancersAPI: tc.loadBalancersAPI,
publicIPAddressesAPI: tc.publicIPAddressesAPI,
}
err := client.CreateExternalLoadBalancer(ctx)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

View file

@ -0,0 +1,78 @@
package azure
// copy of ec2/instances.go
// TODO(katexochen): refactor into mulitcloud package.
import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
)
// TODO: deprecate as soon as scale sets are available.
type VMInstance struct {
Name string
Location string
InstanceType string
Username string
Password string
NIC string
Image string
}
// TODO: deprecate as soon as scale sets are available.
func (i VMInstance) Azure() armcompute.VirtualMachine {
return armcompute.VirtualMachine{
Name: to.StringPtr(i.Name),
Location: to.StringPtr(i.Location),
Properties: &armcompute.VirtualMachineProperties{
HardwareProfile: &armcompute.HardwareProfile{
VMSize: (*armcompute.VirtualMachineSizeTypes)(to.StringPtr(i.InstanceType)),
},
OSProfile: &armcompute.OSProfile{
ComputerName: to.StringPtr(i.Name),
AdminPassword: to.StringPtr(i.Password),
AdminUsername: to.StringPtr(i.Username),
},
SecurityProfile: &armcompute.SecurityProfile{
UefiSettings: &armcompute.UefiSettings{
SecureBootEnabled: to.BoolPtr(true),
VTpmEnabled: to.BoolPtr(true),
},
SecurityType: armcompute.SecurityTypesConfidentialVM.ToPtr(),
},
NetworkProfile: &armcompute.NetworkProfile{
NetworkInterfaces: []*armcompute.NetworkInterfaceReference{
{
ID: to.StringPtr(i.NIC),
},
},
},
StorageProfile: &armcompute.StorageProfile{
OSDisk: &armcompute.OSDisk{
CreateOption: armcompute.DiskCreateOptionTypesFromImage.ToPtr(),
ManagedDisk: &armcompute.ManagedDiskParameters{
StorageAccountType: armcompute.StorageAccountTypesPremiumLRS.ToPtr(),
SecurityProfile: &armcompute.VMDiskSecurityProfile{
SecurityEncryptionType: armcompute.SecurityEncryptionTypesVMGuestStateOnly.ToPtr(),
},
},
},
ImageReference: &armcompute.ImageReference{
Publisher: to.StringPtr("0001-com-ubuntu-confidential-vm-focal"),
Offer: to.StringPtr("canonical"),
SKU: to.StringPtr("20_04-lts-gen2"),
Version: to.StringPtr("latest"),
},
},
DiagnosticsProfile: &armcompute.DiagnosticsProfile{
BootDiagnostics: &armcompute.BootDiagnostics{
Enabled: to.BoolPtr(true),
},
},
},
Identity: &armcompute.VirtualMachineIdentity{
Type: armcompute.ResourceIdentityTypeSystemAssigned.ToPtr(),
},
}
}

View file

@ -0,0 +1,13 @@
package azure
import "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
// InstanceTypes are valid Azure instance types.
// Normally, this would be string(armcompute.VirtualMachineSizeTypesStandardD4SV3),
// but currently needed instances are not in SDK.
var InstanceTypes = []string{
string(armcompute.VirtualMachineSizeTypesStandardD4SV3),
"Standard_DC2as_v5",
"Standard_DC4as_v5",
"Standard_DC8as_v5",
}

View file

@ -0,0 +1,162 @@
package azure
import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/edgelesssys/constellation/internal/constants"
)
// LoadBalancer defines a Azure load balancer.
type LoadBalancer struct {
Name string
Subscription string
ResourceGroup string
Location string
PublicIPID string
UID string
}
const (
BackendAddressPoolWorkerName = "backendAddressWorkerPool"
BackendAddressPoolControlPlaneName = "backendAddressControlPlanePool"
)
// Azure returns a Azure representation of LoadBalancer.
func (l LoadBalancer) Azure() armnetwork.LoadBalancer {
frontEndIPConfigName := "frontEndIPConfig"
kubeHealthProbeName := "kubeHealthProbe"
coordHealthProbeName := "coordHealthProbe"
debugdHealthProbeName := "debugdHealthProbe"
backEndAddressPoolNodeName := BackendAddressPoolWorkerName + "-" + l.UID
backEndAddressPoolControlPlaneName := BackendAddressPoolControlPlaneName + "-" + l.UID
return armnetwork.LoadBalancer{
Name: to.StringPtr(l.Name),
Location: to.StringPtr(l.Location),
SKU: &armnetwork.LoadBalancerSKU{Name: armnetwork.LoadBalancerSKUNameStandard.ToPtr()},
Properties: &armnetwork.LoadBalancerPropertiesFormat{
FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{
{
Name: to.StringPtr(frontEndIPConfigName),
Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{
PublicIPAddress: &armnetwork.PublicIPAddress{
ID: to.StringPtr(l.PublicIPID),
},
},
},
},
BackendAddressPools: []*armnetwork.BackendAddressPool{
{
Name: to.StringPtr(backEndAddressPoolNodeName),
},
{
Name: to.StringPtr(backEndAddressPoolControlPlaneName),
},
{
Name: to.StringPtr("all"),
},
},
Probes: []*armnetwork.Probe{
{
Name: to.StringPtr(kubeHealthProbeName),
Properties: &armnetwork.ProbePropertiesFormat{
Protocol: armnetwork.ProbeProtocolTCP.ToPtr(),
Port: to.Int32Ptr(int32(6443)),
},
},
{
Name: to.StringPtr(coordHealthProbeName),
Properties: &armnetwork.ProbePropertiesFormat{
Protocol: armnetwork.ProbeProtocolTCP.ToPtr(),
Port: to.Int32Ptr(int32(constants.CoordinatorPort)),
},
},
{
Name: to.StringPtr(debugdHealthProbeName),
Properties: &armnetwork.ProbePropertiesFormat{
Protocol: armnetwork.ProbeProtocolTCP.ToPtr(),
Port: to.Int32Ptr(int32(4000)),
},
},
},
LoadBalancingRules: []*armnetwork.LoadBalancingRule{
{
Name: to.StringPtr("kubeLoadBalancerRule"),
Properties: &armnetwork.LoadBalancingRulePropertiesFormat{
FrontendIPConfiguration: &armnetwork.SubResource{
ID: to.StringPtr("/subscriptions/" + l.Subscription + "/resourceGroups/" + l.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + l.Name + "/frontendIPConfigurations/" + frontEndIPConfigName),
},
FrontendPort: to.Int32Ptr(int32(6443)),
BackendPort: to.Int32Ptr(int32(6443)),
Protocol: armnetwork.TransportProtocolTCP.ToPtr(),
Probe: &armnetwork.SubResource{
ID: to.StringPtr("/subscriptions/" + l.Subscription + "/resourceGroups/" + l.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + l.Name + "/probes/" + kubeHealthProbeName),
},
DisableOutboundSnat: to.BoolPtr(true),
BackendAddressPools: []*armnetwork.SubResource{
{
ID: to.StringPtr("/subscriptions/" + l.Subscription + "/resourceGroups/" + l.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + l.Name + "/backendAddressPools/" + backEndAddressPoolControlPlaneName),
},
},
},
},
{
Name: to.StringPtr("coordLoadBalancerRule"),
Properties: &armnetwork.LoadBalancingRulePropertiesFormat{
FrontendIPConfiguration: &armnetwork.SubResource{
ID: to.StringPtr("/subscriptions/" + l.Subscription + "/resourceGroups/" + l.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + l.Name + "/frontendIPConfigurations/" + frontEndIPConfigName),
},
FrontendPort: to.Int32Ptr(int32(constants.CoordinatorPort)),
BackendPort: to.Int32Ptr(int32(constants.CoordinatorPort)),
Protocol: armnetwork.TransportProtocolTCP.ToPtr(),
Probe: &armnetwork.SubResource{
ID: to.StringPtr("/subscriptions/" + l.Subscription + "/resourceGroups/" + l.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + l.Name + "/probes/" + coordHealthProbeName),
},
DisableOutboundSnat: to.BoolPtr(true),
BackendAddressPools: []*armnetwork.SubResource{
{
ID: to.StringPtr("/subscriptions/" + l.Subscription + "/resourceGroups/" + l.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + l.Name + "/backendAddressPools/" + backEndAddressPoolControlPlaneName),
},
},
},
},
{
Name: to.StringPtr("debudLoadBalancerRule"),
Properties: &armnetwork.LoadBalancingRulePropertiesFormat{
FrontendIPConfiguration: &armnetwork.SubResource{
ID: to.StringPtr("/subscriptions/" + l.Subscription + "/resourceGroups/" + l.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + l.Name + "/frontendIPConfigurations/" + frontEndIPConfigName),
},
FrontendPort: to.Int32Ptr(int32(4000)),
BackendPort: to.Int32Ptr(int32(4000)),
Protocol: armnetwork.TransportProtocolTCP.ToPtr(),
Probe: &armnetwork.SubResource{
ID: to.StringPtr("/subscriptions/" + l.Subscription + "/resourceGroups/" + l.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + l.Name + "/probes/" + debugdHealthProbeName),
},
DisableOutboundSnat: to.BoolPtr(true),
BackendAddressPools: []*armnetwork.SubResource{
{
ID: to.StringPtr("/subscriptions/" + l.Subscription + "/resourceGroups/" + l.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + l.Name + "/backendAddressPools/" + backEndAddressPoolControlPlaneName),
},
},
},
},
},
OutboundRules: []*armnetwork.OutboundRule{
{
Name: to.StringPtr("outboundRuleControlPlane"),
Properties: &armnetwork.OutboundRulePropertiesFormat{
FrontendIPConfigurations: []*armnetwork.SubResource{
{
ID: to.StringPtr("/subscriptions/" + l.Subscription + "/resourceGroups/" + l.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + l.Name + "/frontendIPConfigurations/" + frontEndIPConfigName),
},
},
BackendAddressPool: &armnetwork.SubResource{
ID: to.StringPtr("/subscriptions/" + l.Subscription + "/resourceGroups/" + l.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + l.Name + "/backendAddressPools/all"),
},
Protocol: armnetwork.LoadBalancerOutboundRuleProtocolAll.ToPtr(),
},
},
},
},
}
}

View file

@ -0,0 +1,138 @@
package azure
import (
"crypto/rand"
"math/big"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
)
// ScaleSet defines a Azure scale set.
type ScaleSet struct {
Name string
NamePrefix string
Subscription string
ResourceGroup string
Location string
InstanceType string
StateDiskSizeGB int32
Count int64
Username string
SubnetID string
NetworkSecurityGroup string
Password string
Image string
UserAssignedIdentity string
LoadBalancerName string
LoadBalancerBackendAddressPool string
}
// Azure returns the Azure representation of ScaleSet.
func (s ScaleSet) Azure() armcompute.VirtualMachineScaleSet {
return armcompute.VirtualMachineScaleSet{
Name: to.StringPtr(s.Name),
Location: to.StringPtr(s.Location),
SKU: &armcompute.SKU{
Name: to.StringPtr(s.InstanceType),
Capacity: to.Int64Ptr(s.Count),
},
Properties: &armcompute.VirtualMachineScaleSetProperties{
Overprovision: to.BoolPtr(false),
UpgradePolicy: &armcompute.UpgradePolicy{
Mode: armcompute.UpgradeModeManual.ToPtr(),
AutomaticOSUpgradePolicy: &armcompute.AutomaticOSUpgradePolicy{
EnableAutomaticOSUpgrade: to.BoolPtr(false),
DisableAutomaticRollback: to.BoolPtr(false),
},
},
VirtualMachineProfile: &armcompute.VirtualMachineScaleSetVMProfile{
OSProfile: &armcompute.VirtualMachineScaleSetOSProfile{
ComputerNamePrefix: to.StringPtr(s.NamePrefix),
AdminUsername: to.StringPtr(s.Username),
AdminPassword: to.StringPtr(s.Password),
LinuxConfiguration: &armcompute.LinuxConfiguration{},
},
StorageProfile: &armcompute.VirtualMachineScaleSetStorageProfile{
ImageReference: &armcompute.ImageReference{
ID: to.StringPtr(s.Image),
},
DataDisks: []*armcompute.VirtualMachineScaleSetDataDisk{
{
CreateOption: armcompute.DiskCreateOptionTypesEmpty.ToPtr(),
DiskSizeGB: to.Int32Ptr(s.StateDiskSizeGB),
Lun: to.Int32Ptr(0),
},
},
},
NetworkProfile: &armcompute.VirtualMachineScaleSetNetworkProfile{
NetworkInterfaceConfigurations: []*armcompute.VirtualMachineScaleSetNetworkConfiguration{
{
Name: to.StringPtr(s.Name),
Properties: &armcompute.VirtualMachineScaleSetNetworkConfigurationProperties{
Primary: to.BoolPtr(true),
EnableIPForwarding: to.BoolPtr(true),
IPConfigurations: []*armcompute.VirtualMachineScaleSetIPConfiguration{
{
Name: to.StringPtr(s.Name),
Properties: &armcompute.VirtualMachineScaleSetIPConfigurationProperties{
Primary: to.BoolPtr(true),
Subnet: &armcompute.APIEntityReference{
ID: to.StringPtr(s.SubnetID),
},
LoadBalancerBackendAddressPools: []*armcompute.SubResource{
{
ID: to.StringPtr("/subscriptions/" + s.Subscription + "/resourcegroups/" + s.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + s.LoadBalancerName + "/backendAddressPools/" + s.LoadBalancerBackendAddressPool),
},
{
ID: to.StringPtr("/subscriptions/" + s.Subscription + "/resourcegroups/" + s.ResourceGroup + "/providers/Microsoft.Network/loadBalancers/" + s.LoadBalancerName + "/backendAddressPools/all"),
},
},
},
},
},
NetworkSecurityGroup: &armcompute.SubResource{
ID: to.StringPtr(s.NetworkSecurityGroup),
},
},
},
},
},
SecurityProfile: &armcompute.SecurityProfile{
SecurityType: armcompute.SecurityTypesTrustedLaunch.ToPtr(),
UefiSettings: &armcompute.UefiSettings{VTpmEnabled: to.BoolPtr(true)},
},
DiagnosticsProfile: &armcompute.DiagnosticsProfile{
BootDiagnostics: &armcompute.BootDiagnostics{
Enabled: to.BoolPtr(true),
},
},
},
},
Identity: &armcompute.VirtualMachineScaleSetIdentity{
Type: armcompute.ResourceIdentityTypeUserAssigned.ToPtr(),
UserAssignedIdentities: map[string]*armcompute.VirtualMachineScaleSetIdentityUserAssignedIdentitiesValue{
s.UserAssignedIdentity: {},
},
},
}
}
// GeneratePassword is a helper function to generate a random password
// for Azure's scale set.
func GeneratePassword() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
pwLen := 16
pw := make([]byte, 0, pwLen)
for i := 0; i < pwLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
pw = append(pw, letters[n.Int64()])
}
// bypass password rules
pw = append(pw, []byte("Aa1!")...)
return string(pw), nil
}

View file

@ -0,0 +1,111 @@
package azure
import (
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFirewallPermissions(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
scaleSet := ScaleSet{
Name: "name",
NamePrefix: "constellation-",
Location: "UK South",
InstanceType: "Standard_D2s_v3",
Count: 3,
Username: "constellation",
SubnetID: "subnet-id",
NetworkSecurityGroup: "network-security-group",
Password: "password",
Image: "image",
UserAssignedIdentity: "user-identity",
}
scaleSetAzure := scaleSet.Azure()
require.NotNil(scaleSetAzure.Name)
assert.Equal(scaleSet.Name, *scaleSetAzure.Name)
require.NotNil(scaleSetAzure.Location)
assert.Equal(scaleSet.Location, *scaleSetAzure.Location)
require.NotNil(scaleSetAzure.SKU)
require.NotNil(scaleSetAzure.SKU.Name)
assert.Equal(scaleSet.InstanceType, *scaleSetAzure.SKU.Name)
require.NotNil(scaleSetAzure.SKU.Capacity)
assert.Equal(scaleSet.Count, *scaleSetAzure.SKU.Capacity)
require.NotNil(scaleSetAzure.Properties)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.ComputerNamePrefix)
assert.Equal(scaleSet.NamePrefix, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.ComputerNamePrefix)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminUsername)
assert.Equal(scaleSet.Username, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminUsername)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminPassword)
assert.Equal(scaleSet.Password, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminPassword)
// Verify image
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference.ID)
assert.Equal(scaleSet.Image, *scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference.ID)
// Verify network
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile)
require.Len(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations, 1)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations[0])
networkConfig := scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations[0]
require.NotNil(networkConfig.Name)
assert.Equal(scaleSet.Name, *networkConfig.Name)
require.NotNil(networkConfig.Properties)
require.Len(networkConfig.Properties.IPConfigurations, 1)
require.NotNil(networkConfig.Properties.IPConfigurations[0])
ipConfig := networkConfig.Properties.IPConfigurations[0]
require.NotNil(ipConfig.Name)
assert.Equal(scaleSet.Name, *ipConfig.Name)
require.NotNil(ipConfig.Properties)
require.NotNil(ipConfig.Properties.Subnet)
require.NotNil(ipConfig.Properties.Subnet.ID)
assert.Equal(scaleSet.SubnetID, *ipConfig.Properties.Subnet.ID)
require.NotNil(networkConfig.Properties.NetworkSecurityGroup)
assert.Equal(scaleSet.NetworkSecurityGroup, *networkConfig.Properties.NetworkSecurityGroup.ID)
// Verify vTPM
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.SecurityType)
assert.Equal(armcompute.SecurityTypesTrustedLaunch, *scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.SecurityType)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings.VTpmEnabled)
assert.True(*scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings.VTpmEnabled)
// Verify UserAssignedIdentity
require.NotNil(scaleSetAzure.Identity)
require.NotNil(scaleSetAzure.Identity.Type)
assert.Equal(armcompute.ResourceIdentityTypeUserAssigned, *scaleSetAzure.Identity.Type)
require.Len(scaleSetAzure.Identity.UserAssignedIdentities, 1)
assert.Contains(scaleSetAzure.Identity.UserAssignedIdentities, scaleSet.UserAssignedIdentity)
}
func TestGeneratePassword(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
pw, err := GeneratePassword()
require.NoError(err)
assert.Len(pw, 20)
}