Remove Azure client from CLI

This commit is contained in:
katexochen 2022-10-06 17:20:02 +02:00 committed by Paul Meyer
parent 38498b1981
commit 0d1fd8fb2a
23 changed files with 3 additions and 2999 deletions

View File

@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
<!-- For soon-to-be removed features. -->
### Removed
<!-- For now removed features. -->
- `endpoint` flag of `constellation init`. IP is now always taken from the `constellation-id.json` file.
### Fixed
### Security

View File

@ -1,14 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import "fmt"
// AutoscalingNodeGroup converts an azure scale set into a node group used by the k8s cluster-autoscaler.
func AutoscalingNodeGroup(scaleSet string, min int, max int) string {
return fmt.Sprintf("%d:%d:%s", min, max, scaleSet)
}

View File

@ -1,25 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestAutoscalingNodeGroup(t *testing.T) {
assert := assert.New(t)
nodeGroups := AutoscalingNodeGroup("scale-set", 0, 100)
wantNodeGroups := "0:100:scale-set"
assert.Equal(wantNodeGroups, nodeGroups)
}

View File

@ -1,107 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
)
type networksAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
*runtime.Poller[armnetwork.VirtualNetworksClientCreateOrUpdateResponse], error)
}
type networkSecurityGroupsAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
*runtime.Poller[armnetwork.SecurityGroupsClientCreateOrUpdateResponse], error)
}
type loadBalancersAPI interface {
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
loadBalancerName string, parameters armnetwork.LoadBalancer,
options *armnetwork.LoadBalancersClientBeginCreateOrUpdateOptions) (
*runtime.Poller[armnetwork.LoadBalancersClientCreateOrUpdateResponse], error,
)
}
type scaleSetsAPI interface {
Get(ctx context.Context, resourceGroupName string, vmScaleSetName string,
options *armcomputev2.VirtualMachineScaleSetsClientGetOptions,
) (armcomputev2.VirtualMachineScaleSetsClientGetResponse, error)
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
vmScaleSetName string, parameters armcomputev2.VirtualMachineScaleSet,
options *armcomputev2.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
*runtime.Poller[armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse], error)
}
type virtualMachineScaleSetVMsAPI interface {
GetInstanceView(ctx context.Context, resourceGroupName string, vmScaleSetName string, instanceID string,
options *armcomputev2.VirtualMachineScaleSetVMsClientGetInstanceViewOptions,
) (armcomputev2.VirtualMachineScaleSetVMsClientGetInstanceViewResponse, error)
}
type publicIPAddressesAPI interface {
NewListVirtualMachineScaleSetVMPublicIPAddressesPager(
resourceGroupName string, virtualMachineScaleSetName string,
virtualmachineIndex string, networkInterfaceName string,
ipConfigurationName string,
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
) *runtime.Pager[armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse]
BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
*runtime.Poller[armnetwork.PublicIPAddressesClientCreateOrUpdateResponse], error)
}
type networkInterfacesAPI interface {
GetVirtualMachineScaleSetNetworkInterface(ctx context.Context, resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string, networkInterfaceName string,
options *armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceOptions,
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error)
}
type resourceAPI interface {
NewListByResourceGroupPager(resourceGroupName string,
options *armresources.ClientListByResourceGroupOptions,
) *runtime.Pager[armresources.ClientListByResourceGroupResponse]
BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string,
options *armresources.ClientBeginDeleteByIDOptions,
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error)
}
type applicationsAPI interface {
Create(ctx context.Context, parameters graphrbac.ApplicationCreateParameters) (graphrbac.Application, error)
Delete(ctx context.Context, applicationObjectID string) (autorest.Response, error)
UpdatePasswordCredentials(ctx context.Context, objectID string, parameters graphrbac.PasswordCredentialsUpdateParameters) (autorest.Response, error)
}
type servicePrincipalsAPI interface {
Create(ctx context.Context, parameters graphrbac.ServicePrincipalCreateParameters) (graphrbac.ServicePrincipal, error)
}
// the newer version "armauthorization.RoleAssignmentsClient" is currently broken: https://github.com/Azure/azure-sdk-for-go/issues/17071
// TODO: switch to "armauthorization.RoleAssignmentsClient" if issue is resolved.
type roleAssignmentsAPI interface {
Create(ctx context.Context, scope string, roleAssignmentName string, parameters authorization.RoleAssignmentCreateParameters) (authorization.RoleAssignment, error)
}
type applicationInsightsAPI interface {
CreateOrUpdate(ctx context.Context, resourceGroupName string, resourceName string, insightProperties armapplicationinsights.Component,
options *armapplicationinsights.ComponentsClientCreateOrUpdateOptions) (armapplicationinsights.ComponentsClientCreateOrUpdateResponse, error)
}

View File

@ -1,275 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"net/http"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
)
type stubNetworksAPI struct {
createErr error
pollErr error
}
func (a stubNetworksAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
virtualNetworkName string, parameters armnetwork.VirtualNetwork,
options *armnetwork.VirtualNetworksClientBeginCreateOrUpdateOptions) (
*runtime.Poller[armnetwork.VirtualNetworksClientCreateOrUpdateResponse], error,
) {
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armnetwork.VirtualNetworksClientCreateOrUpdateResponse]{
Handler: &stubPoller[armnetwork.VirtualNetworksClientCreateOrUpdateResponse]{
result: armnetwork.VirtualNetworksClientCreateOrUpdateResponse{
VirtualNetwork: armnetwork.VirtualNetwork{
Properties: &armnetwork.VirtualNetworkPropertiesFormat{
Subnets: []*armnetwork.Subnet{
{
ID: to.Ptr("subnet-id"),
},
},
},
},
},
resultErr: a.pollErr,
},
})
if err != nil {
panic(err)
}
return poller, a.createErr
}
type stubLoadBalancersAPI struct {
createErr error
stubResponse armnetwork.LoadBalancersClientCreateOrUpdateResponse
pollErr error
}
func (a stubLoadBalancersAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
loadBalancerName string, parameters armnetwork.LoadBalancer,
options *armnetwork.LoadBalancersClientBeginCreateOrUpdateOptions) (
*runtime.Poller[armnetwork.LoadBalancersClientCreateOrUpdateResponse], error,
) {
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armnetwork.LoadBalancersClientCreateOrUpdateResponse]{
Handler: &stubPoller[armnetwork.LoadBalancersClientCreateOrUpdateResponse]{
result: a.stubResponse,
resultErr: a.pollErr,
},
})
if err != nil {
panic(err)
}
return poller, a.createErr
}
type stubNetworkSecurityGroupsAPI struct {
createErr error
pollErr error
}
func (a stubNetworkSecurityGroupsAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
networkSecurityGroupName string, parameters armnetwork.SecurityGroup,
options *armnetwork.SecurityGroupsClientBeginCreateOrUpdateOptions) (
*runtime.Poller[armnetwork.SecurityGroupsClientCreateOrUpdateResponse], error,
) {
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armnetwork.SecurityGroupsClientCreateOrUpdateResponse]{
Handler: &stubPoller[armnetwork.SecurityGroupsClientCreateOrUpdateResponse]{
result: armnetwork.SecurityGroupsClientCreateOrUpdateResponse{
SecurityGroup: armnetwork.SecurityGroup{ID: to.Ptr("network-security-group-id")},
},
resultErr: a.pollErr,
},
})
if err != nil {
panic(err)
}
return poller, a.createErr
}
type stubScaleSetsAPI struct {
createErr error
stubResponse armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse
pollErr error
getResponse armcomputev2.VirtualMachineScaleSet
getErr error
}
func (a stubScaleSetsAPI) Get(ctx context.Context, resourceGroupName string, vmScaleSetName string,
options *armcomputev2.VirtualMachineScaleSetsClientGetOptions,
) (armcomputev2.VirtualMachineScaleSetsClientGetResponse, error) {
return armcomputev2.VirtualMachineScaleSetsClientGetResponse{
VirtualMachineScaleSet: a.getResponse,
}, a.getErr
}
func (a stubScaleSetsAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string,
vmScaleSetName string, parameters armcomputev2.VirtualMachineScaleSet,
options *armcomputev2.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) (
*runtime.Poller[armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse], error,
) {
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse]{
Handler: &stubPoller[armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse]{
result: a.stubResponse,
resultErr: a.pollErr,
},
})
if err != nil {
panic(err)
}
return poller, a.createErr
}
type stubPublicIPAddressesAPI struct {
createErr error
getErr error
pollErr error
}
type stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager struct {
pages int
fetchErr error
more bool
}
func (p *stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager) moreFunc() func(
armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse) bool {
return func(armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse) bool {
return p.more
}
}
func (p *stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager) fetcherFunc() func(
context.Context, *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse) (
armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse, error) {
return func(context.Context, *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse) (
armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse, error,
) {
page := make([]*armnetwork.PublicIPAddress, p.pages)
for i := 0; i < p.pages; i++ {
page[i] = &armnetwork.PublicIPAddress{
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: to.Ptr("192.0.2.1"),
},
}
}
return armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse{
PublicIPAddressListResult: armnetwork.PublicIPAddressListResult{
Value: page,
},
}, p.fetchErr
}
}
func (a stubPublicIPAddressesAPI) NewListVirtualMachineScaleSetVMPublicIPAddressesPager(
resourceGroupName string, virtualMachineScaleSetName string,
virtualmachineIndex string, networkInterfaceName string,
ipConfigurationName string,
options *armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesOptions,
) *runtime.Pager[armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse] {
pager := &stubPublicIPAddressesListVirtualMachineScaleSetVMPublicIPAddressesPager{
pages: 1,
}
return runtime.NewPager(runtime.PagingHandler[armnetwork.PublicIPAddressesClientListVirtualMachineScaleSetVMPublicIPAddressesResponse]{
More: pager.moreFunc(),
Fetcher: pager.fetcherFunc(),
})
}
func (a stubPublicIPAddressesAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, publicIPAddressName string,
parameters armnetwork.PublicIPAddress, options *armnetwork.PublicIPAddressesClientBeginCreateOrUpdateOptions) (
*runtime.Poller[armnetwork.PublicIPAddressesClientCreateOrUpdateResponse], error,
) {
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armnetwork.PublicIPAddressesClientCreateOrUpdateResponse]{
Handler: &stubPoller[armnetwork.PublicIPAddressesClientCreateOrUpdateResponse]{
result: armnetwork.PublicIPAddressesClientCreateOrUpdateResponse{
PublicIPAddress: armnetwork.PublicIPAddress{
ID: to.Ptr("ip-address-id"),
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: to.Ptr("192.0.2.1"),
},
},
},
resultErr: a.pollErr,
},
})
if err != nil {
panic(err)
}
return poller, a.createErr
}
func (a stubPublicIPAddressesAPI) Get(ctx context.Context, resourceGroupName string, publicIPAddressName string, options *armnetwork.PublicIPAddressesClientGetOptions) (
armnetwork.PublicIPAddressesClientGetResponse, error,
) {
return armnetwork.PublicIPAddressesClientGetResponse{
PublicIPAddress: armnetwork.PublicIPAddress{
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: to.Ptr("192.0.2.1"),
},
},
}, a.getErr
}
type stubNetworkInterfacesAPI struct {
getErr error
}
func (a stubNetworkInterfacesAPI) GetVirtualMachineScaleSetNetworkInterface(ctx context.Context, resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string, networkInterfaceName string,
options *armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceOptions,
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error) {
if a.getErr != nil {
return armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse{}, a.getErr
}
return armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse{
Interface: armnetwork.Interface{
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.Ptr("192.0.2.1"),
},
},
},
},
},
}, nil
}
type stubApplicationInsightsAPI struct {
err error
}
func (a *stubApplicationInsightsAPI) CreateOrUpdate(ctx context.Context, resourceGroupName string, resourceName string, insightProperties armapplicationinsights.Component, options *armapplicationinsights.ComponentsClientCreateOrUpdateOptions) (armapplicationinsights.ComponentsClientCreateOrUpdateResponse, error) {
resp := armapplicationinsights.ComponentsClientCreateOrUpdateResponse{}
return resp, a.err
}
type stubPoller[T any] struct {
result T
pollErr error
resultErr error
}
func (p *stubPoller[T]) Done() bool {
return true
}
func (p *stubPoller[T]) Poll(context.Context) (*http.Response, error) {
return nil, p.pollErr
}
func (p *stubPoller[T]) Result(ctx context.Context, out *T) error {
*out = p.result
return p.resultErr
}

View File

@ -1,34 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
)
func (c *Client) CreateApplicationInsight(ctx context.Context) error {
properties := armapplicationinsights.Component{
Kind: to.Ptr("web"),
Location: to.Ptr(c.location),
Properties: &armapplicationinsights.ComponentProperties{
ApplicationType: to.Ptr(armapplicationinsights.ApplicationTypeWeb),
},
Tags: map[string]*string{"uid": to.Ptr(c.uid)},
}
_, err := c.applicationInsightsAPI.CreateOrUpdate(
ctx,
c.resourceGroup,
"constellation-insights-"+c.uid,
properties,
nil,
)
return err
}

View File

@ -1,52 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestCreateApplicationInsight(t *testing.T) {
testCases := map[string]struct {
applicationInsightsAPI applicationInsightsAPI
wantErr bool
}{
"successful create": {
applicationInsightsAPI: &stubApplicationInsightsAPI{
err: nil,
},
},
"failed create": {
applicationInsightsAPI: &stubApplicationInsightsAPI{
err: errors.New("some error"),
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := Client{
applicationInsightsAPI: tc.applicationInsightsAPI,
}
err := client.CreateApplicationInsight(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
})
}
}

View File

@ -1,237 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"crypto/rand"
"fmt"
"math/big"
"time"
"github.com/Azure/azure-sdk-for-go/profiles/latest/authorization/mgmt/authorization"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/Azure/azure-sdk-for-go/services/graphrbac/1.6/graphrbac"
"github.com/Azure/go-autorest/autorest"
"github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudtypes"
"github.com/edgelesssys/constellation/v2/internal/state"
)
const (
graphAPIResource = "https://graph.windows.net"
managementAPIResource = "https://management.azure.com"
)
// Client is a client for Azure.
type Client struct {
networksAPI
networkSecurityGroupsAPI
resourceAPI
scaleSetsAPI
virtualMachineScaleSetVMsAPI
publicIPAddressesAPI
networkInterfacesAPI
loadBalancersAPI
applicationsAPI
servicePrincipalsAPI
roleAssignmentsAPI
applicationInsightsAPI
pollFrequency time.Duration
workers cloudtypes.Instances
controlPlanes cloudtypes.Instances
name string
uid string
resourceGroup string
location string
subscriptionID string
tenantID string
subnetID string
controlPlaneScaleSet string
workerScaleSet string
loadBalancerName string
loadBalancerPubIP string
networkSecurityGroup string
adAppObjectID string
}
// NewFromDefault creates a client with initialized clients.
func NewFromDefault(subscriptionID, tenantID string) (*Client, error) {
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return nil, err
}
graphAuthorizer, err := getAuthorizer(graphAPIResource)
if err != nil {
return nil, err
}
managementAuthorizer, err := getAuthorizer(managementAPIResource)
if err != nil {
return nil, err
}
netAPI, err := armnetwork.NewVirtualNetworksClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
netSecGrpAPI, err := armnetwork.NewSecurityGroupsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
scaleSetAPI, err := armcomputev2.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
virtualMachineScaleSetVMsAPI, err := armcomputev2.NewVirtualMachineScaleSetVMsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
publicIPAddressesAPI, err := armnetwork.NewPublicIPAddressesClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
networkInterfacesAPI, err := armnetwork.NewInterfacesClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
loadBalancersAPI, err := armnetwork.NewLoadBalancersClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
applicationInsightsAPI, err := armapplicationinsights.NewComponentsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
resourceAPI, err := armresources.NewClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
applicationsAPI := graphrbac.NewApplicationsClient(tenantID)
applicationsAPI.Authorizer = graphAuthorizer
servicePrincipalsAPI := graphrbac.NewServicePrincipalsClient(tenantID)
servicePrincipalsAPI.Authorizer = graphAuthorizer
roleAssignmentsAPI := authorization.NewRoleAssignmentsClient(subscriptionID)
roleAssignmentsAPI.Authorizer = managementAuthorizer
return &Client{
networksAPI: netAPI,
networkSecurityGroupsAPI: netSecGrpAPI,
resourceAPI: resourceAPI,
scaleSetsAPI: scaleSetAPI,
virtualMachineScaleSetVMsAPI: virtualMachineScaleSetVMsAPI,
publicIPAddressesAPI: publicIPAddressesAPI,
networkInterfacesAPI: networkInterfacesAPI,
loadBalancersAPI: loadBalancersAPI,
applicationsAPI: applicationsAPI,
servicePrincipalsAPI: servicePrincipalsAPI,
roleAssignmentsAPI: roleAssignmentsAPI,
applicationInsightsAPI: applicationInsightsAPI,
subscriptionID: subscriptionID,
tenantID: tenantID,
workers: cloudtypes.Instances{},
controlPlanes: cloudtypes.Instances{},
pollFrequency: time.Second * 5,
}, nil
}
// NewInitialized creates and initializes client by setting the subscriptionID, location and name
// of the Constellation.
func NewInitialized(subscriptionID, tenantID, name, location, resourceGroup string) (*Client, error) {
client, err := NewFromDefault(subscriptionID, tenantID)
if err != nil {
return nil, err
}
err = client.init(location, name, resourceGroup)
return client, err
}
// init initializes the client.
func (c *Client) init(location, name, resourceGroup string) error {
c.location = location
c.name = name
c.resourceGroup = resourceGroup
uid, err := c.generateUID()
if err != nil {
return err
}
c.uid = uid
return nil
}
// GetState returns the state of the client as ConstellationState.
func (c *Client) GetState() state.ConstellationState {
return state.ConstellationState{
Name: c.name,
UID: c.uid,
CloudProvider: cloudprovider.Azure.String(),
LoadBalancerIP: c.loadBalancerPubIP,
AzureLocation: c.location,
AzureSubscription: c.subscriptionID,
AzureTenant: c.tenantID,
AzureResourceGroup: c.resourceGroup,
AzureNetworkSecurityGroup: c.networkSecurityGroup,
AzureSubnet: c.subnetID,
AzureWorkerScaleSet: c.workerScaleSet,
AzureControlPlaneScaleSet: c.controlPlaneScaleSet,
AzureWorkerInstances: c.workers,
AzureControlPlaneInstances: c.controlPlanes,
AzureADAppObjectID: c.adAppObjectID,
}
}
// SetState sets the state of the client to the handed ConstellationState.
func (c *Client) SetState(stat state.ConstellationState) {
c.resourceGroup = stat.AzureResourceGroup
c.name = stat.Name
c.uid = stat.UID
c.loadBalancerPubIP = stat.LoadBalancerIP
c.location = stat.AzureLocation
c.subscriptionID = stat.AzureSubscription
c.tenantID = stat.AzureTenant
c.subnetID = stat.AzureSubnet
c.networkSecurityGroup = stat.AzureNetworkSecurityGroup
c.workerScaleSet = stat.AzureWorkerScaleSet
c.controlPlaneScaleSet = stat.AzureControlPlaneScaleSet
c.workers = stat.AzureWorkerInstances
c.controlPlanes = stat.AzureControlPlaneInstances
c.adAppObjectID = stat.AzureADAppObjectID
}
func (c *Client) generateUID() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
const uidLen = 5
uid := make([]byte, uidLen)
for i := 0; i < uidLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
uid[i] = letters[n.Int64()]
}
return string(uid), nil
}
// getAuthorizer creates an autorest.Authorizer for different Azure AD APIs using either environment variables or azure cli credentials.
func getAuthorizer(resource string) (autorest.Authorizer, error) {
authorizer, cliErr := auth.NewAuthorizerFromCLIWithResource(resource)
if cliErr == nil {
return authorizer, nil
}
authorizer, envErr := auth.NewAuthorizerFromEnvironmentWithResource(resource)
if envErr == nil {
return authorizer, nil
}
return nil, fmt.Errorf("unable to create authorizer from env or cli: %v %v", envErr, cliErr)
}

View File

@ -1,98 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"testing"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudtypes"
"github.com/edgelesssys/constellation/v2/internal/state"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestSetGetState(t *testing.T) {
state := state.ConstellationState{
CloudProvider: cloudprovider.Azure.String(),
AzureWorkerInstances: cloudtypes.Instances{
"0": {PublicIP: "ip1", PrivateIP: "ip2"},
},
AzureControlPlaneInstances: cloudtypes.Instances{
"0": {PublicIP: "ip3", PrivateIP: "ip4"},
},
Name: "name",
UID: "uid",
LoadBalancerIP: "bootstrapper-host",
AzureResourceGroup: "resource-group",
AzureLocation: "location",
AzureSubscription: "subscription",
AzureTenant: "tenant",
AzureSubnet: "azure-subnet",
AzureNetworkSecurityGroup: "network-security-group",
AzureWorkerScaleSet: "worker-scale-set",
AzureControlPlaneScaleSet: "controlplane-scale-set",
}
t.Run("SetState", func(t *testing.T) {
assert := assert.New(t)
client := Client{}
client.SetState(state)
assert.Equal(state.AzureWorkerInstances, client.workers)
assert.Equal(state.AzureControlPlaneInstances, client.controlPlanes)
assert.Equal(state.Name, client.name)
assert.Equal(state.UID, client.uid)
assert.Equal(state.AzureResourceGroup, client.resourceGroup)
assert.Equal(state.AzureLocation, client.location)
assert.Equal(state.AzureSubscription, client.subscriptionID)
assert.Equal(state.AzureTenant, client.tenantID)
assert.Equal(state.AzureSubnet, client.subnetID)
assert.Equal(state.AzureNetworkSecurityGroup, client.networkSecurityGroup)
assert.Equal(state.AzureWorkerScaleSet, client.workerScaleSet)
assert.Equal(state.AzureControlPlaneScaleSet, client.controlPlaneScaleSet)
})
t.Run("GetState", func(t *testing.T) {
assert := assert.New(t)
client := Client{
workers: state.AzureWorkerInstances,
controlPlanes: state.AzureControlPlaneInstances,
name: state.Name,
uid: state.UID,
loadBalancerPubIP: state.LoadBalancerIP,
resourceGroup: state.AzureResourceGroup,
location: state.AzureLocation,
subscriptionID: state.AzureSubscription,
tenantID: state.AzureTenant,
subnetID: state.AzureSubnet,
networkSecurityGroup: state.AzureNetworkSecurityGroup,
workerScaleSet: state.AzureWorkerScaleSet,
controlPlaneScaleSet: state.AzureControlPlaneScaleSet,
}
stat := client.GetState()
assert.Equal(state, stat)
})
}
func TestInit(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{}
require.NoError(client.init("location", "name", "rGroup"))
assert.Equal("location", client.location)
assert.Equal("name", client.name)
assert.Equal("rGroup", client.resourceGroup)
assert.NotEmpty(client.uid)
}

View File

@ -1,299 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/edgelesssys/constellation/v2/cli/internal/azure"
"github.com/edgelesssys/constellation/v2/cli/internal/azure/internal/poller"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudtypes"
)
const (
// scaleSetCreateTimeout maximum timeout to wait for scale set creation.
scaleSetCreateTimeout = 5 * time.Minute
powerStateStarting = "PowerState/starting"
powerStateRunning = "PowerState/running"
)
func (c *Client) CreateInstances(ctx context.Context, input CreateInstancesInput) error {
// Create worker scale set
createWorkerInput := CreateScaleSetInput{
Name: "constellation-scale-set-workers-" + c.uid,
NamePrefix: c.name + "-worker-" + c.uid + "-",
Count: input.CountWorkers,
InstanceType: input.InstanceType,
StateDiskSizeGB: int32(input.StateDiskSizeGB),
StateDiskType: input.StateDiskType,
Image: input.Image,
UserAssingedIdentity: input.UserAssingedIdentity,
LoadBalancerBackendAddressPool: azure.BackendAddressPoolWorkerName + "-" + c.uid,
ConfidentialVM: input.ConfidentialVM,
}
// Create control plane scale set
createControlPlaneInput := CreateScaleSetInput{
Name: "constellation-scale-set-controlplanes-" + c.uid,
NamePrefix: c.name + "-control-plane-" + c.uid + "-",
Count: input.CountControlPlanes,
InstanceType: input.InstanceType,
StateDiskSizeGB: int32(input.StateDiskSizeGB),
StateDiskType: input.StateDiskType,
Image: input.Image,
UserAssingedIdentity: input.UserAssingedIdentity,
LoadBalancerBackendAddressPool: azure.BackendAddressPoolControlPlaneName + "-" + c.uid,
ConfidentialVM: input.ConfidentialVM,
}
var wg sync.WaitGroup
var controlPlaneErr, workerErr error
wg.Add(1)
go func() {
defer wg.Done()
workerErr = c.createScaleSet(ctx, createWorkerInput)
}()
wg.Add(1)
go func() {
defer wg.Done()
controlPlaneErr = c.createScaleSet(ctx, createControlPlaneInput)
}()
wg.Wait()
if controlPlaneErr != nil {
return fmt.Errorf("creating control-plane scaleset: %w", controlPlaneErr)
}
if workerErr != nil {
return fmt.Errorf("creating worker scaleset: %w", workerErr)
}
// TODO: Remove getInstanceIPs calls after init has been refactored to not use node IPs
// Get worker IPs
c.workerScaleSet = createWorkerInput.Name
instances, err := c.getInstanceIPs(ctx, createWorkerInput.Name, createWorkerInput.Count)
if err != nil {
return err
}
c.workers = instances
// Get control plane IPs
c.controlPlaneScaleSet = createControlPlaneInput.Name
instances, err = c.getInstanceIPs(ctx, createControlPlaneInput.Name, createControlPlaneInput.Count)
if err != nil {
return err
}
c.controlPlanes = instances
return nil
}
// CreateInstancesInput is the input for a CreateInstances operation.
type CreateInstancesInput struct {
CountWorkers int
CountControlPlanes int
InstanceType string
StateDiskSizeGB int
StateDiskType string
Image string
UserAssingedIdentity string
ConfidentialVM bool
}
func (c *Client) createScaleSet(ctx context.Context, input CreateScaleSetInput) error {
// TODO: Generating a random password to be able
// to create the scale set. This is a temporary fix.
// We need to think about azure access at some point.
pw, err := azure.GeneratePassword()
if err != nil {
return err
}
scaleSet := azure.ScaleSet{
Name: input.Name,
NamePrefix: input.NamePrefix,
UID: c.uid,
Location: c.location,
InstanceType: input.InstanceType,
StateDiskSizeGB: input.StateDiskSizeGB,
StateDiskType: input.StateDiskType,
Count: int64(input.Count),
Username: "constellation",
SubnetID: c.subnetID,
NetworkSecurityGroup: c.networkSecurityGroup,
Image: input.Image,
Password: pw,
UserAssignedIdentity: input.UserAssingedIdentity,
Subscription: c.subscriptionID,
ResourceGroup: c.resourceGroup,
LoadBalancerName: c.loadBalancerName,
LoadBalancerBackendAddressPool: input.LoadBalancerBackendAddressPool,
ConfidentialVM: input.ConfidentialVM,
}.Azure()
_, err = c.scaleSetsAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, input.Name,
scaleSet,
nil,
)
if err != nil {
return err
}
// use custom poller to wait for resource creation but skip waiting for OS provisioning.
// OS provisioning does not work reliably without the azure guest agent installed.
poller := poller.New[bool](&scaleSetCreationPollingHandler{
resourceGroup: c.resourceGroup,
scaleSet: input.Name,
scaleSetsAPI: c.scaleSetsAPI,
virtualMachineScaleSetVMsAPI: c.virtualMachineScaleSetVMsAPI,
})
pollCtx, cancel := context.WithTimeout(ctx, scaleSetCreateTimeout)
defer cancel()
_, err = poller.PollUntilDone(pollCtx, nil)
return err
}
func (c *Client) getInstanceIPs(ctx context.Context, scaleSet string, count int) (cloudtypes.Instances, error) {
instances := cloudtypes.Instances{}
for i := 0; i < count; i++ {
// get public ip address
var publicIPAddress string
pager := c.publicIPAddressesAPI.NewListVirtualMachineScaleSetVMPublicIPAddressesPager(
c.resourceGroup, scaleSet, strconv.Itoa(i), scaleSet, scaleSet, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return cloudtypes.Instances{}, err
}
for _, v := range page.Value {
if v.Properties != nil && v.Properties.IPAddress != nil {
publicIPAddress = *v.Properties.IPAddress
break
}
}
}
// get private ip address
var privateIPAddress string
res, err := c.networkInterfacesAPI.GetVirtualMachineScaleSetNetworkInterface(
ctx, c.resourceGroup, scaleSet, strconv.Itoa(i), scaleSet, nil)
if err != nil {
return nil, err
}
configs := res.Interface.Properties.IPConfigurations
for _, config := range configs {
privateIPAddress = *config.Properties.PrivateIPAddress
break
}
instance := cloudtypes.Instance{
PrivateIP: privateIPAddress,
PublicIP: publicIPAddress,
}
instances[strconv.Itoa(i)] = instance
}
return instances, nil
}
// CreateScaleSetInput is the input for a CreateScaleSet operation.
type CreateScaleSetInput struct {
Name string
NamePrefix string
Count int
InstanceType string
StateDiskSizeGB int32
StateDiskType string
Image string
UserAssingedIdentity string
LoadBalancerBackendAddressPool string
ConfidentialVM bool
}
// scaleSetCreationPollingHandler is a custom poller used to check if a scale set was created successfully.
type scaleSetCreationPollingHandler struct {
done bool
instanceIDOffset int
resourceGroup string
scaleSet string
scaleSetsAPI scaleSetsAPI
virtualMachineScaleSetVMsAPI virtualMachineScaleSetVMsAPI
}
// Done returns true if the condition is met.
func (h *scaleSetCreationPollingHandler) Done() bool {
return h.done
}
// Poll checks if the scale set resource was created successfully and every VM is starting or running.
func (h *scaleSetCreationPollingHandler) Poll(ctx context.Context) error {
// check if scale set can be retrieved from API
scaleSet, err := h.scaleSetsAPI.Get(ctx, h.resourceGroup, h.scaleSet, nil)
if err != nil {
return ignoreNotFoundError(err)
}
if scaleSet.SKU == nil || scaleSet.SKU.Capacity == nil {
return errors.New("invalid scale set capacity")
}
// check if every VM in the scale set has power state starting or running
for i := h.instanceIDOffset; i < int(*scaleSet.SKU.Capacity); i++ {
instanceView, err := h.virtualMachineScaleSetVMsAPI.GetInstanceView(ctx, h.resourceGroup, h.scaleSet, strconv.Itoa(i), nil)
if err != nil {
return ignoreNotFoundError(err)
}
if !vmIsStartingOrRunning(instanceView.Statuses) {
return nil
}
h.instanceIDOffset = i + 1 // skip this VM in the next Poll() invocation
}
h.done = true
return nil
}
// Result returns the result of the poller if the condition is met.
// If the condition is not met, an error is returned.
func (h *scaleSetCreationPollingHandler) Result(ctx context.Context, out *bool) error {
if !h.done {
return fmt.Errorf("failed to create scale set")
}
*out = h.done
return nil
}
func ignoreNotFoundError(err error) error {
var respErr *azcore.ResponseError
if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound {
// resource does not exist yet - retry later
return nil
}
return err
}
func vmIsStartingOrRunning(statuses []*armcomputev2.InstanceViewStatus) bool {
for _, status := range statuses {
if status == nil || status.Code == nil {
continue
}
switch *status.Code {
case powerStateStarting:
return true
case powerStateRunning:
return true
}
}
return false
}

View File

@ -1,120 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudtypes"
"github.com/stretchr/testify/assert"
)
func TestCreateInstances(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
networkInterfacesAPI networkInterfacesAPI
scaleSetsAPI scaleSetsAPI
createInstancesInput CreateInstancesInput
wantErr bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{
getResponse: armcomputev2.VirtualMachineScaleSet{
Identity: &armcomputev2.VirtualMachineScaleSetIdentity{PrincipalID: to.Ptr("principal-id")}, SKU: &armcomputev2.SKU{Capacity: to.Ptr[int64](0)},
},
},
createInstancesInput: CreateInstancesInput{
CountControlPlanes: 3,
CountWorkers: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
ConfidentialVM: true,
},
},
"error when creating scale set": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{createErr: someErr},
createInstancesInput: CreateInstancesInput{
CountControlPlanes: 3,
CountWorkers: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
ConfidentialVM: true,
},
wantErr: true,
},
"error when polling create scale set response": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{},
scaleSetsAPI: stubScaleSetsAPI{getErr: someErr},
createInstancesInput: CreateInstancesInput{
CountControlPlanes: 3,
CountWorkers: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
ConfidentialVM: true,
},
wantErr: true,
},
"error when retrieving private IPs": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
networkInterfacesAPI: stubNetworkInterfacesAPI{getErr: someErr},
scaleSetsAPI: stubScaleSetsAPI{},
createInstancesInput: CreateInstancesInput{
CountWorkers: 3,
InstanceType: "type",
Image: "image",
UserAssingedIdentity: "identity",
ConfidentialVM: true,
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
location: "location",
name: "name",
uid: "uid",
resourceGroup: "name",
publicIPAddressesAPI: tc.publicIPAddressesAPI,
networkInterfacesAPI: tc.networkInterfacesAPI,
scaleSetsAPI: tc.scaleSetsAPI,
workers: make(cloudtypes.Instances),
controlPlanes: make(cloudtypes.Instances),
loadBalancerPubIP: "lbip",
}
if tc.wantErr {
assert.Error(client.CreateInstances(ctx, tc.createInstancesInput))
} else {
assert.NoError(client.CreateInstances(ctx, tc.createInstancesInput))
assert.Equal(tc.createInstancesInput.CountControlPlanes, len(client.controlPlanes))
assert.Equal(tc.createInstancesInput.CountWorkers, len(client.workers))
assert.NotEmpty(client.workers["0"].PrivateIP)
assert.NotEmpty(client.workers["0"].PublicIP)
assert.NotEmpty(client.controlPlanes["0"].PrivateIP)
assert.NotEmpty(client.controlPlanes["0"].PublicIP)
}
})
}
}

View File

@ -1,208 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/edgelesssys/constellation/v2/cli/internal/azure"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudtypes"
)
type createNetworkInput struct {
name string
location string
addressSpace string
nodeAddressSpace string
podAddressSpace string
}
const (
nodeNetworkName = "nodeNetwork"
podNetworkName = "podNetwork"
networkAddressSpace = "10.0.0.0/8"
nodeAddressSpace = "10.9.0.0/16"
podAddressSpace = "10.10.0.0/16"
)
// CreateVirtualNetwork creates a virtual network.
func (c *Client) CreateVirtualNetwork(ctx context.Context) error {
createNetworkInput := createNetworkInput{
name: "constellation-" + c.uid,
location: c.location,
addressSpace: networkAddressSpace,
nodeAddressSpace: nodeAddressSpace,
podAddressSpace: podAddressSpace,
}
poller, err := c.networksAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, createNetworkInput.name,
armnetwork.VirtualNetwork{
Name: to.Ptr(createNetworkInput.name), // this is supposed to be read-only
Tags: map[string]*string{"uid": to.Ptr(c.uid)},
Location: to.Ptr(createNetworkInput.location),
Properties: &armnetwork.VirtualNetworkPropertiesFormat{
AddressSpace: &armnetwork.AddressSpace{
AddressPrefixes: []*string{
to.Ptr(createNetworkInput.addressSpace),
},
},
Subnets: []*armnetwork.Subnet{
{
Name: to.Ptr(nodeNetworkName),
Properties: &armnetwork.SubnetPropertiesFormat{
AddressPrefix: to.Ptr(createNetworkInput.nodeAddressSpace),
},
},
{
Name: to.Ptr(podNetworkName),
Properties: &armnetwork.SubnetPropertiesFormat{
AddressPrefix: to.Ptr(createNetworkInput.podAddressSpace),
},
},
},
},
},
nil,
)
if err != nil {
return err
}
resp, err := poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{
Frequency: c.pollFrequency,
})
if err != nil {
return err
}
c.subnetID = *resp.VirtualNetwork.Properties.Subnets[0].ID
return nil
}
type createNetworkSecurityGroupInput struct {
name string
location string
rules []*armnetwork.SecurityRule
}
// CreateSecurityGroup creates a security group containing firewall rules.
func (c *Client) CreateSecurityGroup(ctx context.Context, input NetworkSecurityGroupInput) error {
rules, err := input.Ingress.Azure()
if err != nil {
return err
}
createNetworkSecurityGroupInput := createNetworkSecurityGroupInput{
name: "constellation-security-group-" + c.uid,
location: c.location,
rules: rules,
}
poller, err := c.networkSecurityGroupsAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, createNetworkSecurityGroupInput.name,
armnetwork.SecurityGroup{
Name: to.Ptr(createNetworkSecurityGroupInput.name),
Tags: map[string]*string{"uid": to.Ptr(c.uid)},
Location: to.Ptr(createNetworkSecurityGroupInput.location),
Properties: &armnetwork.SecurityGroupPropertiesFormat{
SecurityRules: createNetworkSecurityGroupInput.rules,
},
},
nil,
)
if err != nil {
return err
}
pollerResp, err := poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{
Frequency: c.pollFrequency,
})
if err != nil {
return err
}
c.networkSecurityGroup = *pollerResp.SecurityGroup.ID
return nil
}
func (c *Client) createPublicIPAddress(ctx context.Context, name string) (*armnetwork.PublicIPAddress, error) {
poller, err := c.publicIPAddressesAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, name,
armnetwork.PublicIPAddress{
Tags: map[string]*string{"uid": to.Ptr(c.uid)},
Location: to.Ptr(c.location),
SKU: &armnetwork.PublicIPAddressSKU{
Name: to.Ptr(armnetwork.PublicIPAddressSKUNameStandard),
},
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
PublicIPAllocationMethod: to.Ptr(armnetwork.IPAllocationMethodStatic),
},
},
nil,
)
if err != nil {
return nil, err
}
pollerResp, err := poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{
Frequency: c.pollFrequency,
})
if err != nil {
return nil, err
}
return &pollerResp.PublicIPAddress, nil
}
// NetworkSecurityGroupInput defines firewall rules to be set.
type NetworkSecurityGroupInput struct {
Ingress cloudtypes.Firewall
Egress cloudtypes.Firewall
}
// CreateExternalLoadBalancer creates an external load balancer.
func (c *Client) CreateExternalLoadBalancer(ctx context.Context, isDebugCluster bool) error {
// First, create a public IP address for the load balancer.
publicIPAddress, err := c.createPublicIPAddress(ctx, "loadbalancer-public-ip-"+c.uid)
if err != nil {
return err
}
// Then, create the load balancer.
loadBalancerName := "constellation-load-balancer-" + c.uid
loadBalancer := azure.LoadBalancer{
Name: loadBalancerName,
Location: c.location,
ResourceGroup: c.resourceGroup,
Subscription: c.subscriptionID,
PublicIPID: *publicIPAddress.ID,
UID: c.uid,
}
azureLoadBalancer := loadBalancer.Azure()
if isDebugCluster {
azureLoadBalancer = loadBalancer.AppendDebugRules(azureLoadBalancer)
}
poller, err := c.loadBalancersAPI.BeginCreateOrUpdate(
ctx, c.resourceGroup, loadBalancerName,
azureLoadBalancer,
nil,
)
if err != nil {
return err
}
_, err = poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{
Frequency: c.pollFrequency,
})
if err != nil {
return err
}
c.loadBalancerName = loadBalancerName
c.loadBalancerPubIP = *publicIPAddress.Properties.IPAddress
return nil
}

View File

@ -1,233 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudtypes"
"github.com/stretchr/testify/assert"
)
func TestCreateVirtualNetwork(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
networksAPI networksAPI
wantErr bool
}{
"successful create": {
networksAPI: stubNetworksAPI{},
},
"failed to get response from successful create": {
networksAPI: stubNetworksAPI{pollErr: someErr},
wantErr: true,
},
"failed create": {
networksAPI: stubNetworksAPI{createErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
networksAPI: tc.networksAPI,
workers: make(cloudtypes.Instances),
controlPlanes: make(cloudtypes.Instances),
}
if tc.wantErr {
assert.Error(client.CreateVirtualNetwork(ctx))
} else {
assert.NoError(client.CreateVirtualNetwork(ctx))
assert.NotEmpty(client.subnetID)
}
})
}
}
func TestCreateSecurityGroup(t *testing.T) {
someErr := errors.New("failed")
testNetworkSecurityGroupInput := NetworkSecurityGroupInput{
Ingress: cloudtypes.Firewall{
{
Name: "test-1",
Description: "test-1 description",
Protocol: "tcp",
IPRange: "192.0.2.0/24",
FromPort: 9000,
},
{
Name: "test-2",
Description: "test-2 description",
Protocol: "udp",
IPRange: "192.0.2.0/24",
FromPort: 51820,
},
},
Egress: cloudtypes.Firewall{},
}
testCases := map[string]struct {
networkSecurityGroupsAPI networkSecurityGroupsAPI
wantErr bool
}{
"successful create": {
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{},
},
"failed to get response from successful create": {
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{pollErr: someErr},
wantErr: true,
},
"failed create": {
networkSecurityGroupsAPI: stubNetworkSecurityGroupsAPI{createErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
workers: make(cloudtypes.Instances),
controlPlanes: make(cloudtypes.Instances),
networkSecurityGroupsAPI: tc.networkSecurityGroupsAPI,
}
if tc.wantErr {
assert.Error(client.CreateSecurityGroup(ctx, testNetworkSecurityGroupInput))
} else {
assert.NoError(client.CreateSecurityGroup(ctx, testNetworkSecurityGroupInput))
assert.Equal("network-security-group-id", client.networkSecurityGroup)
}
})
}
}
func TestCreatePublicIPAddress(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
name string
wantErr bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
name: "nic-name",
},
"failed to get response from successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{pollErr: someErr},
wantErr: true,
},
"failed create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{createErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
workers: make(cloudtypes.Instances),
controlPlanes: make(cloudtypes.Instances),
publicIPAddressesAPI: tc.publicIPAddressesAPI,
}
id, err := client.createPublicIPAddress(ctx, tc.name)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.NotEmpty(id)
}
})
}
}
func TestCreateExternalLoadBalancer(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
publicIPAddressesAPI publicIPAddressesAPI
loadBalancersAPI loadBalancersAPI
isDebugCluster bool
wantErr bool
}{
"successful create": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
loadBalancersAPI: stubLoadBalancersAPI{},
},
"successful create (debug cluster)": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
loadBalancersAPI: stubLoadBalancersAPI{},
isDebugCluster: true,
},
"failed to get response from successful create": {
loadBalancersAPI: stubLoadBalancersAPI{pollErr: someErr},
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
wantErr: true,
},
"failed create": {
loadBalancersAPI: stubLoadBalancersAPI{createErr: someErr},
publicIPAddressesAPI: stubPublicIPAddressesAPI{},
wantErr: true,
},
"cannot create public IP": {
publicIPAddressesAPI: stubPublicIPAddressesAPI{createErr: someErr},
loadBalancersAPI: stubLoadBalancersAPI{},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ctx := context.Background()
client := Client{
resourceGroup: "resource-group",
location: "location",
name: "name",
uid: "uid",
workers: make(cloudtypes.Instances),
controlPlanes: make(cloudtypes.Instances),
loadBalancersAPI: tc.loadBalancersAPI,
publicIPAddressesAPI: tc.publicIPAddressesAPI,
}
err := client.CreateExternalLoadBalancer(ctx, tc.isDebugCluster)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}

View File

@ -1,139 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"fmt"
"regexp"
"strings"
"sync"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
)
// TerminateResourceGroupResources deletes all resources from the resource group.
func (c *Client) TerminateResourceGroupResources(ctx context.Context) error {
const timeOut = 10 * time.Minute
ctx, cancel := context.WithTimeout(ctx, timeOut)
defer cancel()
pollers := make(chan *runtime.Poller[armresources.ClientDeleteByIDResponse], 20)
delete := make(chan struct{}, 1)
wg := &sync.WaitGroup{}
wg.Add(2)
go func() { // This routine lists resources and starts their deletion, where possible.
defer wg.Done()
defer func() {
close(pollers)
for range delete { // drain channel
}
}()
for {
ids, err := c.getResourceIDList(ctx)
if err != nil {
time.Sleep(3 * time.Second)
continue
}
if len(ids) == 0 {
return
}
for _, id := range ids {
poller, err := c.deleteResourceByID(ctx, id)
if err != nil {
continue
}
pollers <- poller
}
select {
case <-ctx.Done():
return
case _, ok := <-delete:
if !ok { // channel was closed
return
}
}
}
}()
go func() { // This routine polls for for the deletions to complete.
defer wg.Done()
defer close(delete)
for poller := range pollers {
_, err := poller.PollUntilDone(ctx, nil)
if err != nil {
continue
}
select {
case delete <- struct{}{}:
default:
}
}
}()
wg.Wait()
return nil
}
func (c *Client) getResourceIDList(ctx context.Context) ([]string, error) {
var ids []string
pager := c.resourceAPI.NewListByResourceGroupPager(c.resourceGroup, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("getting next page of ListByResourceGroup: %w", err)
}
for _, resource := range page.Value {
if resource.ID == nil {
return nil, fmt.Errorf("resource %v has no ID", resource)
}
ids = append(ids, *resource.ID)
}
}
return ids, nil
}
func (c *Client) deleteResourceByID(ctx context.Context, id string,
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) {
apiVersion := "2020-02-02"
// First try, API version unknown, will fail.
poller, err := c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil)
if isVersionWrongErr(err) {
// bad hack, but easiest way to get the right API version
apiVersion = parseAPIVersionFromErr(err)
poller, err = c.resourceAPI.BeginDeleteByID(ctx, id, apiVersion, nil)
}
return poller, err
}
func isVersionWrongErr(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), "NoRegisteredProviderFound") &&
strings.Contains(err.Error(), "The supported api-versions are")
}
var apiVersionRegex = regexp.MustCompile(` (\d\d\d\d-\d\d-\d\d)'`)
func parseAPIVersionFromErr(err error) string {
if err == nil {
return ""
}
matches := apiVersionRegex.FindStringSubmatch(err.Error())
return matches[1]
}

View File

@ -1,145 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"errors"
"fmt"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
)
func TestTerminateResourceGroupResources(t *testing.T) {
someErr := errors.New("failed")
apiVersionErr := errors.New("NoRegisteredProviderFound, The supported api-versions are: 2015-01-01'")
testCases := map[string]struct {
resourceAPI resourceAPI
}{
"no resources": {
resourceAPI: &fakeResourceAPI{},
},
"some resources": {
resourceAPI: &fakeResourceAPI{
resources: map[string]fakeResource{
"id-0": {beginDeleteByIDErr: apiVersionErr, pollErr: someErr},
"id-1": {beginDeleteByIDErr: apiVersionErr},
"id-2": {},
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
client := &Client{
resourceAPI: tc.resourceAPI,
}
ctx := context.Background()
err := client.TerminateResourceGroupResources(ctx)
assert.NoError(err)
})
}
}
type fakeResourceAPI struct {
resources map[string]fakeResource
fetchErr error
}
type fakeResource struct {
beginDeleteByIDErr error
pollErr error
}
func (a fakeResourceAPI) NewListByResourceGroupPager(resourceGroupName string,
options *armresources.ClientListByResourceGroupOptions,
) *runtime.Pager[armresources.ClientListByResourceGroupResponse] {
pager := &stubClientListByResourceGroupResponsePager{
resources: a.resources,
fetchErr: a.fetchErr,
}
return runtime.NewPager(runtime.PagingHandler[armresources.ClientListByResourceGroupResponse]{
More: pager.moreFunc(),
Fetcher: pager.fetcherFunc(),
})
}
func (a fakeResourceAPI) BeginDeleteByID(ctx context.Context, resourceID string, apiVersion string,
options *armresources.ClientBeginDeleteByIDOptions,
) (*runtime.Poller[armresources.ClientDeleteByIDResponse], error) {
res := a.resources[resourceID]
pollErr := res.pollErr
if pollErr != nil {
res.pollErr = nil
}
poller, err := runtime.NewPoller(nil, runtime.NewPipeline("", "", runtime.PipelineOptions{}, nil), &runtime.NewPollerOptions[armresources.ClientDeleteByIDResponse]{
Handler: &stubPoller[armresources.ClientDeleteByIDResponse]{
result: armresources.ClientDeleteByIDResponse{},
resultErr: pollErr,
},
})
if err != nil {
panic(err)
}
beginDeleteByIDErr := res.beginDeleteByIDErr
if beginDeleteByIDErr != nil {
res.beginDeleteByIDErr = nil
}
if res.beginDeleteByIDErr == nil && res.pollErr == nil {
delete(a.resources, resourceID)
fmt.Printf("fake delete %s\n", resourceID)
} else {
a.resources[resourceID] = res
}
return poller, beginDeleteByIDErr
}
type stubClientListByResourceGroupResponsePager struct {
resources map[string]fakeResource
fetchErr error
more bool
}
func (p *stubClientListByResourceGroupResponsePager) moreFunc() func(
armresources.ClientListByResourceGroupResponse) bool {
return func(armresources.ClientListByResourceGroupResponse) bool {
return p.more
}
}
func (p *stubClientListByResourceGroupResponsePager) fetcherFunc() func(
context.Context, *armresources.ClientListByResourceGroupResponse) (
armresources.ClientListByResourceGroupResponse, error) {
return func(context.Context, *armresources.ClientListByResourceGroupResponse) (
armresources.ClientListByResourceGroupResponse, error,
) {
var resources []*armresources.GenericResourceExpanded
for id := range p.resources {
resources = append(resources, &armresources.GenericResourceExpanded{ID: proto.String(id)})
}
p.more = false
return armresources.ClientListByResourceGroupResponse{
ResourceListResult: armresources.ResourceListResult{
Value: resources,
},
}, p.fetchErr
}
}

View File

@ -1,127 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
// Package poller implements a poller that can be used to wait for a condition to be met.
// The poller is designed to be a replacement for the azure-sdk-for-go poller
// with exponential backoff and an injectable clock.
// reference: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azcore@v1.1.1/runtime#Poller .
package poller
import (
"context"
"errors"
"time"
"k8s.io/utils/clock"
)
// PollUntilDoneOptions provides options for the Poller.
// Used to specify backoff and clock options.
type PollUntilDoneOptions struct {
StartingBackoff time.Duration
MaxBackoff time.Duration
clock.Clock
}
// NewPollUntilDoneOptions creates a new PollUntilDoneOptions with the default values and a real clock.
func NewPollUntilDoneOptions() *PollUntilDoneOptions {
return &PollUntilDoneOptions{
Clock: clock.RealClock{},
}
}
// Poller is a poller that can be used to wait for a condition to be met.
// The poller is designed to be a replacement for the azure-sdk-for-go poller
// with exponential backoff and an injectable clock.
// reference: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azcore@v1.1.1/runtime#Poller .
type Poller[T any] struct {
handler PollingHandler[T]
err error
result *T
done bool
}
// New creates a new Poller.
func New[T any](handler PollingHandler[T]) *Poller[T] {
return &Poller[T]{
handler: handler,
result: new(T),
}
}
// PollUntilDone polls the handler until the condition is met or the context is cancelled.
func (p *Poller[T]) PollUntilDone(ctx context.Context, options *PollUntilDoneOptions) (T, error) {
if options == nil {
options = NewPollUntilDoneOptions()
}
if options.MaxBackoff == 0 {
options.MaxBackoff = time.Minute
}
if options.StartingBackoff < time.Second {
options.StartingBackoff = time.Second
}
backoff := options.StartingBackoff
for {
timer := options.Clock.NewTimer(backoff)
err := p.Poll(ctx)
if err != nil {
return *new(T), err
}
if p.Done() {
return p.Result(ctx)
}
select {
case <-ctx.Done():
return *new(T), ctx.Err()
case <-timer.C():
}
if backoff >= options.MaxBackoff/2 {
backoff = options.MaxBackoff
} else {
backoff *= 2
}
}
}
// Poll polls the handler.
func (p *Poller[T]) Poll(ctx context.Context) error {
return p.handler.Poll(ctx)
}
// Done returns true if the condition is met.
func (p *Poller[T]) Done() bool {
return p.handler.Done()
}
// Result returns the result of the poller if the condition is met.
// If the condition is not met, an error is returned.
func (p *Poller[T]) Result(ctx context.Context) (T, error) {
if !p.Done() {
return *new(T), errors.New("poller is in a non-terminal state")
}
if p.done {
// the result has already been retrieved, return the cached value
if p.err != nil {
return *new(T), p.err
}
return *p.result, nil
}
err := p.handler.Result(ctx, p.result)
p.done = true
if err != nil {
p.err = err
return *new(T), p.err
}
return *p.result, nil
}
// PollingHandler is a handler that can be used to poll for a condition to be met.
type PollingHandler[T any] interface {
Done() bool
Poll(ctx context.Context) error
Result(ctx context.Context, out *T) error
}

View File

@ -1,211 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package poller
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/stretchr/testify/assert"
testclock "k8s.io/utils/clock/testing"
)
func TestResult(t *testing.T) {
testCases := map[string]struct {
done bool
pollErr error
resultErr error
wantErr bool
wantResult int
}{
"result called before poller is done": {
wantErr: true,
},
"result returns error": {
done: true,
resultErr: errors.New("result error"),
wantErr: true,
},
"result succeeds": {
done: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
poller := New[int](&stubPoller[int]{
result: &tc.wantResult,
done: tc.done,
pollErr: tc.pollErr,
resultErr: tc.resultErr,
})
_, firstErr := poller.Result(context.Background())
if tc.wantErr {
assert.Error(firstErr)
// calling Result again should return the same error
_, secondErr := poller.Result(context.Background())
assert.Equal(firstErr, secondErr)
return
}
assert.NoError(firstErr)
// calling Result again should still not return an error
_, secondErr := poller.Result(context.Background())
assert.NoError(secondErr)
})
}
}
func TestPollUntilDone(t *testing.T) {
testCases := map[string]struct {
messages []message
maxBackoff time.Duration
resultErr error
wantErr bool
wantResult int
}{
"poll succeeds on first try": {
messages: []message{
{pollErr: to.Ptr[error](nil), done: to.Ptr(true)},
{done: to.Ptr(true)}, // Result() will call Done() after the last poll
},
wantResult: 1,
},
"poll succeeds on fourth try": {
messages: []message{
{pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: time.Second},
{pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: 2 * time.Second},
{pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: 4 * time.Second},
{pollErr: to.Ptr[error](nil), done: to.Ptr(true)},
{done: to.Ptr(true)}, // Result() will call Done() after the last poll
},
wantResult: 1,
},
"max backoff reached": {
messages: []message{
{pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: time.Second},
{pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: time.Second},
{pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: time.Second},
{pollErr: to.Ptr[error](nil), done: to.Ptr(true)},
{done: to.Ptr(true)}, // Result() will call Done() after the last poll
},
maxBackoff: time.Second,
wantResult: 1,
},
"poll errors": {
messages: []message{
{pollErr: to.Ptr(errors.New("poll error"))},
},
wantErr: true,
},
"result errors": {
messages: []message{
{pollErr: to.Ptr[error](nil), done: to.Ptr(true)},
{done: to.Ptr(true)}, // Result() will call Done() after the last poll
},
resultErr: errors.New("result error"),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
doneC := make(chan bool)
pollC := make(chan error)
poller := New[int](&fakePoller[int]{
result: &tc.wantResult,
resultErr: tc.resultErr,
doneC: doneC,
pollC: pollC,
})
clock := testclock.NewFakeClock(time.Now())
var gotResult int
var gotErr error
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
gotResult, gotErr = poller.PollUntilDone(context.Background(), &PollUntilDoneOptions{
MaxBackoff: tc.maxBackoff,
Clock: clock,
})
}()
for _, msg := range tc.messages {
if msg.pollErr != nil {
pollC <- *msg.pollErr
}
if msg.done != nil {
doneC <- *msg.done
}
clock.Step(msg.backoff)
}
wg.Wait()
if tc.wantErr {
assert.Error(gotErr)
return
}
assert.NoError(gotErr)
assert.Equal(tc.wantResult, gotResult)
})
}
}
type stubPoller[T any] struct {
result *T
done bool
pollErr error
resultErr error
}
func (s *stubPoller[T]) Poll(ctx context.Context) error {
return s.pollErr
}
func (s *stubPoller[T]) Done() bool {
return s.done
}
func (s *stubPoller[T]) Result(ctx context.Context, out *T) error {
*out = *s.result
return s.resultErr
}
type message struct {
pollErr *error
done *bool
backoff time.Duration
}
type fakePoller[T any] struct {
result *T
resultErr error
doneC chan bool
pollC chan error
}
func (s *fakePoller[T]) Poll(ctx context.Context) error {
return <-s.pollC
}
func (s *fakePoller[T]) Done() bool {
return <-s.doneC
}
func (s *fakePoller[T]) Result(ctx context.Context, out *T) error {
*out = *s.result
return s.resultErr
}

View File

@ -1,324 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/edgelesssys/constellation/v2/internal/constants"
)
// LoadBalancer defines a Azure load balancer.
type LoadBalancer struct {
Name string
Subscription string
ResourceGroup string
Location string
PublicIPID string
UID string
}
const (
BackendAddressPoolWorkerName = "backendAddressWorkerPool"
BackendAddressPoolControlPlaneName = "backendAddressControlPlanePool"
frontEndIPConfigName = "frontEndIPConfig"
kubeHealthProbeName = "kubeHealthProbe"
verifyHealthProbeName = "verifyHealthProbe"
coordHealthProbeName = "coordHealthProbe"
debugdHealthProbeName = "debugdHealthProbe"
konnectivityHealthProbeName = "konnectivityHealthProbe"
recoveryHealthProbeName = "recoveryHealthProbe"
)
// Azure returns a Azure representation of LoadBalancer.
func (l LoadBalancer) Azure() armnetwork.LoadBalancer {
backEndAddressPoolNodeName := BackendAddressPoolWorkerName + "-" + l.UID
backEndAddressPoolControlPlaneName := BackendAddressPoolControlPlaneName + "-" + l.UID
return armnetwork.LoadBalancer{
Name: to.Ptr(l.Name),
Location: to.Ptr(l.Location),
SKU: &armnetwork.LoadBalancerSKU{Name: to.Ptr(armnetwork.LoadBalancerSKUNameStandard)},
Properties: &armnetwork.LoadBalancerPropertiesFormat{
FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{
{
Name: to.Ptr(frontEndIPConfigName),
Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{
PublicIPAddress: &armnetwork.PublicIPAddress{
ID: to.Ptr(l.PublicIPID),
},
},
},
},
BackendAddressPools: []*armnetwork.BackendAddressPool{
{Name: to.Ptr(backEndAddressPoolNodeName)},
{Name: to.Ptr(backEndAddressPoolControlPlaneName)},
{Name: to.Ptr("all")},
},
Probes: []*armnetwork.Probe{
{
Name: to.Ptr(kubeHealthProbeName),
Properties: &armnetwork.ProbePropertiesFormat{
Protocol: to.Ptr(armnetwork.ProbeProtocolTCP),
Port: to.Ptr(int32(constants.KubernetesPort)),
},
},
{
Name: to.Ptr(verifyHealthProbeName),
Properties: &armnetwork.ProbePropertiesFormat{
Protocol: to.Ptr(armnetwork.ProbeProtocolTCP),
Port: to.Ptr[int32](constants.VerifyServiceNodePortGRPC),
},
},
{
Name: to.Ptr(coordHealthProbeName),
Properties: &armnetwork.ProbePropertiesFormat{
Protocol: to.Ptr(armnetwork.ProbeProtocolTCP),
Port: to.Ptr[int32](constants.BootstrapperPort),
},
},
{
Name: to.Ptr(debugdHealthProbeName),
Properties: &armnetwork.ProbePropertiesFormat{
Protocol: to.Ptr(armnetwork.ProbeProtocolTCP),
Port: to.Ptr[int32](constants.DebugdPort),
},
},
{
Name: to.Ptr(konnectivityHealthProbeName),
Properties: &armnetwork.ProbePropertiesFormat{
Protocol: to.Ptr(armnetwork.ProbeProtocolTCP),
Port: to.Ptr[int32](constants.KonnectivityPort),
},
},
{
Name: to.Ptr(recoveryHealthProbeName),
Properties: &armnetwork.ProbePropertiesFormat{
Protocol: to.Ptr(armnetwork.ProbeProtocolTCP),
Port: to.Ptr[int32](constants.RecoveryPort),
IntervalInSeconds: to.Ptr[int32](5),
},
},
},
LoadBalancingRules: []*armnetwork.LoadBalancingRule{
{
Name: to.Ptr("kubeLoadBalancerRule"),
Properties: &armnetwork.LoadBalancingRulePropertiesFormat{
FrontendIPConfiguration: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/frontendIPConfigurations/" + frontEndIPConfigName),
},
FrontendPort: to.Ptr[int32](constants.KubernetesPort),
BackendPort: to.Ptr[int32](constants.KubernetesPort),
Protocol: to.Ptr(armnetwork.TransportProtocolTCP),
Probe: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/probes/" + kubeHealthProbeName),
},
DisableOutboundSnat: to.Ptr(true),
BackendAddressPools: []*armnetwork.SubResource{
{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/backendAddressPools/" + backEndAddressPoolControlPlaneName),
},
},
},
},
{
Name: to.Ptr("verifyLoadBalancerRule"),
Properties: &armnetwork.LoadBalancingRulePropertiesFormat{
FrontendIPConfiguration: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/frontendIPConfigurations/" + frontEndIPConfigName),
},
FrontendPort: to.Ptr[int32](constants.VerifyServiceNodePortGRPC),
BackendPort: to.Ptr[int32](constants.VerifyServiceNodePortGRPC),
Protocol: to.Ptr(armnetwork.TransportProtocolTCP),
Probe: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/probes/" + verifyHealthProbeName),
},
DisableOutboundSnat: to.Ptr(true),
BackendAddressPools: []*armnetwork.SubResource{
{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/backendAddressPools/" + backEndAddressPoolControlPlaneName),
},
},
},
},
{
Name: to.Ptr("coordLoadBalancerRule"),
Properties: &armnetwork.LoadBalancingRulePropertiesFormat{
FrontendIPConfiguration: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/frontendIPConfigurations/" + frontEndIPConfigName),
},
FrontendPort: to.Ptr[int32](constants.BootstrapperPort),
BackendPort: to.Ptr[int32](constants.BootstrapperPort),
Protocol: to.Ptr(armnetwork.TransportProtocolTCP),
Probe: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/probes/" + coordHealthProbeName),
},
DisableOutboundSnat: to.Ptr(true),
BackendAddressPools: []*armnetwork.SubResource{
{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/backendAddressPools/" + backEndAddressPoolControlPlaneName),
},
},
},
},
{
Name: to.Ptr("konnectivityLoadBalancerRule"),
Properties: &armnetwork.LoadBalancingRulePropertiesFormat{
FrontendIPConfiguration: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/frontendIPConfigurations/" + frontEndIPConfigName),
},
FrontendPort: to.Ptr[int32](constants.KonnectivityPort),
BackendPort: to.Ptr[int32](constants.KonnectivityPort),
Protocol: to.Ptr(armnetwork.TransportProtocolTCP),
Probe: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/probes/" + konnectivityHealthProbeName),
},
DisableOutboundSnat: to.Ptr(true),
BackendAddressPools: []*armnetwork.SubResource{
{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/backendAddressPools/" + backEndAddressPoolControlPlaneName),
},
},
},
},
{
Name: to.Ptr("recoveryLoadBalancerRule"),
Properties: &armnetwork.LoadBalancingRulePropertiesFormat{
FrontendIPConfiguration: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/frontendIPConfigurations/" + frontEndIPConfigName),
},
FrontendPort: to.Ptr[int32](constants.RecoveryPort),
BackendPort: to.Ptr[int32](constants.RecoveryPort),
Protocol: to.Ptr(armnetwork.TransportProtocolTCP),
Probe: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/probes/" + recoveryHealthProbeName),
},
DisableOutboundSnat: to.Ptr(true),
BackendAddressPools: []*armnetwork.SubResource{
{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/backendAddressPools/" + backEndAddressPoolControlPlaneName),
},
},
},
},
},
OutboundRules: []*armnetwork.OutboundRule{
{
Name: to.Ptr("outboundRuleControlPlane"),
Properties: &armnetwork.OutboundRulePropertiesFormat{
FrontendIPConfigurations: []*armnetwork.SubResource{
{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/frontendIPConfigurations/" + frontEndIPConfigName),
},
},
BackendAddressPool: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/backendAddressPools/all"),
},
Protocol: to.Ptr(armnetwork.LoadBalancerOutboundRuleProtocolAll),
},
},
},
},
}
}
func (l *LoadBalancer) AppendDebugRules(armLoadBalancer armnetwork.LoadBalancer) armnetwork.LoadBalancer {
backEndAddressPoolControlPlaneName := BackendAddressPoolControlPlaneName + "-" + l.UID
if armLoadBalancer.Properties == nil {
armLoadBalancer.Properties = &armnetwork.LoadBalancerPropertiesFormat{}
}
if armLoadBalancer.Properties.LoadBalancingRules == nil {
armLoadBalancer.Properties.LoadBalancingRules = []*armnetwork.LoadBalancingRule{}
}
debugdRule := armnetwork.LoadBalancingRule{
Name: to.Ptr("debugdLoadBalancerRule"),
Properties: &armnetwork.LoadBalancingRulePropertiesFormat{
FrontendIPConfiguration: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/frontendIPConfigurations/" + frontEndIPConfigName),
},
FrontendPort: to.Ptr[int32](constants.DebugdPort),
BackendPort: to.Ptr[int32](constants.DebugdPort),
Protocol: to.Ptr(armnetwork.TransportProtocolTCP),
Probe: &armnetwork.SubResource{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/probes/" + debugdHealthProbeName),
},
DisableOutboundSnat: to.Ptr(true),
BackendAddressPools: []*armnetwork.SubResource{
{
ID: to.Ptr("/subscriptions/" + l.Subscription +
"/resourceGroups/" + l.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + l.Name +
"/backendAddressPools/" + backEndAddressPoolControlPlaneName),
},
},
},
}
armLoadBalancer.Properties.LoadBalancingRules = append(armLoadBalancer.Properties.LoadBalancingRules, &debugdRule)
return armLoadBalancer
}

View File

@ -1,43 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/stretchr/testify/assert"
)
func TestAppendDebugRules(t *testing.T) {
assert := assert.New(t)
// Test with empty rules
emptyAzureLoadBalancer := armnetwork.LoadBalancer{}
someLoadBalancer := LoadBalancer{
Name: "test",
Subscription: "00000000-0000-0000-0000-000000000000",
Location: "westeurope",
ResourceGroup: "test-resource-group",
PublicIPID: "some-public-ip-id",
UID: "test-uid",
}
appendedEmptyAzureLoadBalancer := someLoadBalancer.AppendDebugRules(emptyAzureLoadBalancer)
assert.Equal("debugdLoadBalancerRule", *(appendedEmptyAzureLoadBalancer.Properties.LoadBalancingRules[0]).Name, "Debug load balancer rule not found at index 0")
// Test with existing rules
defaultAzureLoadBalancer := someLoadBalancer.Azure()
appendedDefaultAzureLoadBalancer := someLoadBalancer.AppendDebugRules(defaultAzureLoadBalancer)
var foundDebugLoadBalancer bool
for _, rule := range appendedDefaultAzureLoadBalancer.Properties.LoadBalancingRules {
if *(rule).Name == "debugdLoadBalancerRule" {
foundDebugLoadBalancer = true
}
}
assert.True(foundDebugLoadBalancer, "Debug load balancer rule not found")
}

View File

@ -1,171 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"crypto/rand"
"math/big"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/edgelesssys/constellation/v2/internal/cloud/azure"
)
// ScaleSet defines a Azure scale set.
type ScaleSet struct {
Name string
NamePrefix string
UID string
Subscription string
ResourceGroup string
Location string
InstanceType string
StateDiskSizeGB int32
StateDiskType string
Count int64
Username string
SubnetID string
NetworkSecurityGroup string
Password string
Image string
UserAssignedIdentity string
LoadBalancerName string
LoadBalancerBackendAddressPool string
ConfidentialVM bool
}
// Azure returns the Azure representation of ScaleSet.
func (s ScaleSet) Azure() armcomputev2.VirtualMachineScaleSet {
securityType := armcomputev2.SecurityTypesTrustedLaunch
var diskSecurityProfile *armcomputev2.VMDiskSecurityProfile
if s.ConfidentialVM {
securityType = armcomputev2.SecurityTypesConfidentialVM
diskSecurityProfile = &armcomputev2.VMDiskSecurityProfile{
SecurityEncryptionType: to.Ptr(armcomputev2.SecurityEncryptionTypesVMGuestStateOnly),
}
}
return armcomputev2.VirtualMachineScaleSet{
Name: to.Ptr(s.Name),
Location: to.Ptr(s.Location),
SKU: &armcomputev2.SKU{
Name: to.Ptr(s.InstanceType),
Capacity: to.Ptr(s.Count),
},
Properties: &armcomputev2.VirtualMachineScaleSetProperties{
Overprovision: to.Ptr(false),
UpgradePolicy: &armcomputev2.UpgradePolicy{
Mode: to.Ptr(armcomputev2.UpgradeModeManual),
AutomaticOSUpgradePolicy: &armcomputev2.AutomaticOSUpgradePolicy{
EnableAutomaticOSUpgrade: to.Ptr(false),
DisableAutomaticRollback: to.Ptr(false),
},
},
VirtualMachineProfile: &armcomputev2.VirtualMachineScaleSetVMProfile{
OSProfile: &armcomputev2.VirtualMachineScaleSetOSProfile{
ComputerNamePrefix: to.Ptr(s.NamePrefix),
AdminUsername: to.Ptr(s.Username),
AdminPassword: to.Ptr(s.Password),
LinuxConfiguration: &armcomputev2.LinuxConfiguration{},
},
StorageProfile: &armcomputev2.VirtualMachineScaleSetStorageProfile{
ImageReference: azure.ImageReferenceFromImage(s.Image),
DataDisks: []*armcomputev2.VirtualMachineScaleSetDataDisk{
{
CreateOption: to.Ptr(armcomputev2.DiskCreateOptionTypesEmpty),
DiskSizeGB: to.Ptr(s.StateDiskSizeGB),
Lun: to.Ptr[int32](0),
ManagedDisk: &armcomputev2.VirtualMachineScaleSetManagedDiskParameters{
StorageAccountType: (*armcomputev2.StorageAccountTypes)(to.Ptr(s.StateDiskType)),
},
},
},
OSDisk: &armcomputev2.VirtualMachineScaleSetOSDisk{
ManagedDisk: &armcomputev2.VirtualMachineScaleSetManagedDiskParameters{
SecurityProfile: diskSecurityProfile,
},
CreateOption: to.Ptr(armcomputev2.DiskCreateOptionTypesFromImage),
},
},
NetworkProfile: &armcomputev2.VirtualMachineScaleSetNetworkProfile{
NetworkInterfaceConfigurations: []*armcomputev2.VirtualMachineScaleSetNetworkConfiguration{
{
Name: to.Ptr(s.Name),
Properties: &armcomputev2.VirtualMachineScaleSetNetworkConfigurationProperties{
Primary: to.Ptr(true),
EnableIPForwarding: to.Ptr(true),
IPConfigurations: []*armcomputev2.VirtualMachineScaleSetIPConfiguration{
{
Name: to.Ptr(s.Name),
Properties: &armcomputev2.VirtualMachineScaleSetIPConfigurationProperties{
Primary: to.Ptr(true),
Subnet: &armcomputev2.APIEntityReference{
ID: to.Ptr(s.SubnetID),
},
LoadBalancerBackendAddressPools: []*armcomputev2.SubResource{
{
ID: to.Ptr("/subscriptions/" + s.Subscription +
"/resourcegroups/" + s.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + s.LoadBalancerName +
"/backendAddressPools/" + s.LoadBalancerBackendAddressPool),
},
{
ID: to.Ptr("/subscriptions/" + s.Subscription +
"/resourcegroups/" + s.ResourceGroup +
"/providers/Microsoft.Network/loadBalancers/" + s.LoadBalancerName +
"/backendAddressPools/all"),
},
},
},
},
},
NetworkSecurityGroup: &armcomputev2.SubResource{
ID: to.Ptr(s.NetworkSecurityGroup),
},
},
},
},
},
SecurityProfile: &armcomputev2.SecurityProfile{
SecurityType: to.Ptr(securityType),
UefiSettings: &armcomputev2.UefiSettings{VTpmEnabled: to.Ptr(true), SecureBootEnabled: to.Ptr(true)},
},
DiagnosticsProfile: &armcomputev2.DiagnosticsProfile{
BootDiagnostics: &armcomputev2.BootDiagnostics{
Enabled: to.Ptr(true),
},
},
},
},
Identity: &armcomputev2.VirtualMachineScaleSetIdentity{
Type: to.Ptr(armcomputev2.ResourceIdentityTypeUserAssigned),
UserAssignedIdentities: map[string]*armcomputev2.UserAssignedIdentitiesValue{
s.UserAssignedIdentity: {},
},
},
Tags: map[string]*string{"uid": to.Ptr(s.UID)},
}
}
// GeneratePassword is a helper function to generate a random password
// for Azure's scale set.
func GeneratePassword() (string, error) {
letters := []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
pwLen := 16
pw := make([]byte, 0, pwLen)
for i := 0; i < pwLen; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
if err != nil {
return "", err
}
pw = append(pw, letters[n.Int64()])
}
// bypass password rules
pw = append(pw, []byte("Aa1!")...)
return string(pw), nil
}

View File

@ -1,118 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"testing"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFirewallPermissions(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
scaleSet := ScaleSet{
Name: "name",
NamePrefix: "constellation-",
Location: "UK South",
InstanceType: "Standard_DC4as_v5",
Count: 3,
Username: "constellation",
SubnetID: "subnet-id",
NetworkSecurityGroup: "network-security-group",
Password: "password",
Image: "image",
UserAssignedIdentity: "user-identity",
ConfidentialVM: true,
}
scaleSetAzure := scaleSet.Azure()
require.NotNil(scaleSetAzure.Name)
assert.Equal(scaleSet.Name, *scaleSetAzure.Name)
require.NotNil(scaleSetAzure.Location)
assert.Equal(scaleSet.Location, *scaleSetAzure.Location)
require.NotNil(scaleSetAzure.SKU)
require.NotNil(scaleSetAzure.SKU.Name)
assert.Equal(scaleSet.InstanceType, *scaleSetAzure.SKU.Name)
require.NotNil(scaleSetAzure.SKU.Capacity)
assert.Equal(scaleSet.Count, *scaleSetAzure.SKU.Capacity)
require.NotNil(scaleSetAzure.Properties)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.ComputerNamePrefix)
assert.Equal(scaleSet.NamePrefix, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.ComputerNamePrefix)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminUsername)
assert.Equal(scaleSet.Username, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminUsername)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminPassword)
assert.Equal(scaleSet.Password, *scaleSetAzure.Properties.VirtualMachineProfile.OSProfile.AdminPassword)
// Verify image
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference.ID)
assert.Equal(scaleSet.Image, *scaleSetAzure.Properties.VirtualMachineProfile.StorageProfile.ImageReference.ID)
// Verify network
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile)
require.Len(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations, 1)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations[0])
networkConfig := scaleSetAzure.Properties.VirtualMachineProfile.NetworkProfile.NetworkInterfaceConfigurations[0]
require.NotNil(networkConfig.Name)
assert.Equal(scaleSet.Name, *networkConfig.Name)
require.NotNil(networkConfig.Properties)
require.Len(networkConfig.Properties.IPConfigurations, 1)
require.NotNil(networkConfig.Properties.IPConfigurations[0])
ipConfig := networkConfig.Properties.IPConfigurations[0]
require.NotNil(ipConfig.Name)
assert.Equal(scaleSet.Name, *ipConfig.Name)
require.NotNil(ipConfig.Properties)
require.NotNil(ipConfig.Properties.Subnet)
require.NotNil(ipConfig.Properties.Subnet.ID)
assert.Equal(scaleSet.SubnetID, *ipConfig.Properties.Subnet.ID)
require.NotNil(networkConfig.Properties.NetworkSecurityGroup)
assert.Equal(scaleSet.NetworkSecurityGroup, *networkConfig.Properties.NetworkSecurityGroup.ID)
// Verify vTPM
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.SecurityType)
assert.Equal(armcomputev2.SecurityTypesConfidentialVM, *scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.SecurityType)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings)
require.NotNil(scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings.VTpmEnabled)
assert.True(*scaleSetAzure.Properties.VirtualMachineProfile.SecurityProfile.UefiSettings.VTpmEnabled)
// Verify UserAssignedIdentity
require.NotNil(scaleSetAzure.Identity)
require.NotNil(scaleSetAzure.Identity.Type)
assert.Equal(armcomputev2.ResourceIdentityTypeUserAssigned, *scaleSetAzure.Identity.Type)
require.Len(scaleSetAzure.Identity.UserAssignedIdentities, 1)
assert.Contains(scaleSetAzure.Identity.UserAssignedIdentities, scaleSet.UserAssignedIdentity)
}
func TestGeneratePassword(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
pw, err := GeneratePassword()
require.NoError(err)
assert.Len(pw, 20)
}

View File

@ -55,9 +55,7 @@ func TestInitialize(t *testing.T) {
Type: "service_account",
}
testAzureState := &state.ConstellationState{
CloudProvider: "Azure",
AzureWorkerInstances: cloudtypes.Instances{"id-0": {}, "id-1": {}},
AzureResourceGroup: "test",
CloudProvider: "Azure",
}
testQemuState := &state.ConstellationState{
CloudProvider: "QEMU",

View File

@ -6,26 +6,10 @@ SPDX-License-Identifier: AGPL-3.0-only
package state
import (
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudtypes"
)
// ConstellationState is the state of a Constellation.
type ConstellationState struct {
Name string `json:"name,omitempty"`
UID string `json:"uid,omitempty"`
CloudProvider string `json:"cloudprovider,omitempty"`
LoadBalancerIP string `json:"bootstrapperhost,omitempty"`
AzureWorkerInstances cloudtypes.Instances `json:"azureworkers,omitempty"`
AzureControlPlaneInstances cloudtypes.Instances `json:"azurecontrolplanes,omitempty"`
AzureResourceGroup string `json:"azureresourcegroup,omitempty"`
AzureLocation string `json:"azurelocation,omitempty"`
AzureSubscription string `json:"azuresubscription,omitempty"`
AzureTenant string `json:"azuretenant,omitempty"`
AzureSubnet string `json:"azuresubnet,omitempty"`
AzureNetworkSecurityGroup string `json:"azurenetworksecuritygroup,omitempty"`
AzureWorkerScaleSet string `json:"azureworkersscaleset,omitempty"`
AzureControlPlaneScaleSet string `json:"azurecontrolplanesscaleset,omitempty"`
AzureADAppObjectID string `json:"azureadappobjectid,omitempty"`
}