AB#2524 Refactor Azure metadata/cloud API (#477)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2022-11-15 09:08:18 +01:00 committed by GitHub
parent 74a7a80153
commit f41c54e837
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 2127 additions and 2492 deletions

View File

@ -154,11 +154,11 @@ func main() {
issuer = initserver.NewIssuerWrapper(trustedlaunch.NewIssuer(), vmtype.AzureTrustedLaunch, idkeydigest) issuer = initserver.NewIssuerWrapper(trustedlaunch.NewIssuer(), vmtype.AzureTrustedLaunch, idkeydigest)
} }
metadata, err := azurecloud.NewMetadata(ctx) metadata, err := azurecloud.New(ctx)
if err != nil { if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to create Azure metadata client") log.With(zap.Error(err)).Fatalf("Failed to create Azure metadata client")
} }
cloudLogger, err = azurecloud.NewLogger(ctx, metadata) cloudLogger, err = azurecloud.NewLogger(ctx)
if err != nil { if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to set up cloud logger") log.With(zap.Error(err)).Fatalf("Failed to set up cloud logger")
} }

View File

@ -65,7 +65,7 @@ func main() {
fetcher = cloudprovider.New(meta) fetcher = cloudprovider.New(meta)
case platform.Azure: case platform.Azure:
meta, err := azurecloud.NewMetadata(ctx) meta, err := azurecloud.New(ctx)
if err != nil { if err != nil {
log.With(zap.Error(err)).Fatalf("Failed to initialize Azure metadata") log.With(zap.Error(err)).Fatalf("Failed to initialize Azure metadata")
} }

View File

@ -85,7 +85,7 @@ func main() {
_ = exportPCRs() _ = exportPCRs()
log.With(zap.Error(err)).Fatalf("Unable to resolve Azure state disk path") log.With(zap.Error(err)).Fatalf("Unable to resolve Azure state disk path")
} }
metadataAPI, err = azurecloud.NewMetadata(context.Background()) metadataAPI, err = azurecloud.New(context.Background())
if err != nil { if err != nil {
log.With(zap.Error).Fatalf("Failed to set up Azure metadata API") log.With(zap.Error).Fatalf("Failed to set up Azure metadata API")
} }

3
go.mod
View File

@ -45,7 +45,6 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights v1.0.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights v1.0.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2 v2.0.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2 v2.0.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.1.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.1.0
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.0.0
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.5.1 github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.5.1
github.com/Azure/go-autorest/autorest/to v0.4.0 github.com/Azure/go-autorest/autorest/to v0.4.0
github.com/aws/aws-sdk-go-v2 v1.17.1 github.com/aws/aws-sdk-go-v2 v1.17.1
@ -115,6 +114,7 @@ require (
require ( require (
cloud.google.com/go/longrunning v0.3.0 // indirect cloud.google.com/go/longrunning v0.3.0 // indirect
github.com/Azure/go-autorest v14.2.0+incompatible // indirect
github.com/golang-jwt/jwt/v4 v4.4.2 // indirect github.com/golang-jwt/jwt/v4 v4.4.2 // indirect
github.com/google/logger v1.1.1 // indirect github.com/google/logger v1.1.1 // indirect
github.com/hashicorp/go-retryablehttp v0.7.1 // indirect github.com/hashicorp/go-retryablehttp v0.7.1 // indirect
@ -128,7 +128,6 @@ require (
github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.1 // indirect github.com/Azure/azure-sdk-for-go/sdk/internal v1.0.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.0 // indirect github.com/Azure/azure-sdk-for-go/sdk/keyvault/internal v0.7.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Azure/go-autorest v14.2.0+incompatible // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v0.7.0 // indirect
github.com/BurntSushi/toml v1.1.0 // indirect github.com/BurntSushi/toml v1.1.0 // indirect
github.com/MakeNowJust/heredoc v1.0.0 // indirect github.com/MakeNowJust/heredoc v1.0.0 // indirect

1
go.sum
View File

@ -121,7 +121,6 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal v1.0.0 h1:lMW1lD/
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.1.0 h1:QM6sE5k2ZT/vI5BEe0r7mqjsUSnhVBFbOsVkEuaEfiA= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.1.0 h1:QM6sE5k2ZT/vI5BEe0r7mqjsUSnhVBFbOsVkEuaEfiA=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.1.0/go.mod h1:243D9iHbcQXoFUtgHJwL7gl2zx1aDuDMjvBZVGr2uW0= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork v1.1.0/go.mod h1:243D9iHbcQXoFUtgHJwL7gl2zx1aDuDMjvBZVGr2uW0=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.0.0 h1:ECsQtyERDVz3NP3kvDOTLvbQhqWp/x9EsGKtb4ogUr8= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.0.0 h1:ECsQtyERDVz3NP3kvDOTLvbQhqWp/x9EsGKtb4ogUr8=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources v1.0.0/go.mod h1:s1tW/At+xHqjNFvWU4G0c0Qv33KOhvbGNj0RCTQDV8s=
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.5.1 h1:BMTdr+ib5ljLa9MxTJK8x/Ds0MbBb4MfuW5BL0zMJnI= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.5.1 h1:BMTdr+ib5ljLa9MxTJK8x/Ds0MbBb4MfuW5BL0zMJnI=
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.5.1/go.mod h1:c6WvOhtmjNUWbLfOG1qxM/q0SPvQNSVJvolm+C52dIU= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v0.5.1/go.mod h1:c6WvOhtmjNUWbLfOG1qxM/q0SPvQNSVJvolm+C52dIU=
github.com/Azure/azure-service-bus-go v0.9.1/go.mod h1:yzBx6/BUGfjfeqbRZny9AQIbIe3AcV9WZbAdpkoXOa0= github.com/Azure/azure-service-bus-go v0.9.1/go.mod h1:yzBx6/BUGfjfeqbRZny9AQIbIe3AcV9WZbAdpkoXOa0=

View File

@ -1,300 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"context"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
}
type stubIMDSAPI struct {
providerIDErr error
providerID string
subscriptionIDErr error
subscriptionID string
resourceGroupErr error
resourceGroup string
uidErr error
uid string
}
func (a *stubIMDSAPI) ProviderID(ctx context.Context) (string, error) {
return a.providerID, a.providerIDErr
}
func (a *stubIMDSAPI) SubscriptionID(ctx context.Context) (string, error) {
return a.subscriptionID, a.subscriptionIDErr
}
func (a *stubIMDSAPI) ResourceGroup(ctx context.Context) (string, error) {
return a.resourceGroup, a.resourceGroupErr
}
func (a *stubIMDSAPI) UID(ctx context.Context) (string, error) {
return a.uid, a.uidErr
}
type stubNetworkInterfacesAPI struct {
getInterface armnetwork.Interface
getErr error
}
func (a *stubNetworkInterfacesAPI) GetVirtualMachineScaleSetNetworkInterface(ctx context.Context, resourceGroupName string,
virtualMachineScaleSetName string, virtualmachineIndex string, networkInterfaceName string,
options *armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceOptions,
) (armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse, error) {
return armnetwork.InterfacesClientGetVirtualMachineScaleSetNetworkInterfaceResponse{
Interface: a.getInterface,
}, a.getErr
}
func (a *stubNetworkInterfacesAPI) Get(ctx context.Context, resourceGroupName string, networkInterfaceName string,
options *armnetwork.InterfacesClientGetOptions,
) (armnetwork.InterfacesClientGetResponse, error) {
return armnetwork.InterfacesClientGetResponse{
Interface: a.getInterface,
}, a.getErr
}
type stubVirtualMachineScaleSetVMsAPI struct {
getVM armcomputev2.VirtualMachineScaleSetVM
getErr error
pager *stubVirtualMachineScaleSetVMPager
}
func (a *stubVirtualMachineScaleSetVMsAPI) Get(ctx context.Context, resourceGroupName string, vmScaleSetName string, instanceID string, options *armcomputev2.VirtualMachineScaleSetVMsClientGetOptions) (armcomputev2.VirtualMachineScaleSetVMsClientGetResponse, error) {
return armcomputev2.VirtualMachineScaleSetVMsClientGetResponse{
VirtualMachineScaleSetVM: a.getVM,
}, a.getErr
}
func (a *stubVirtualMachineScaleSetVMsAPI) NewListPager(resourceGroupName string, virtualMachineScaleSetName string, options *armcomputev2.VirtualMachineScaleSetVMsClientListOptions) *runtime.Pager[armcomputev2.VirtualMachineScaleSetVMsClientListResponse] {
return runtime.NewPager(runtime.PagingHandler[armcomputev2.VirtualMachineScaleSetVMsClientListResponse]{
More: a.pager.moreFunc(),
Fetcher: a.pager.fetcherFunc(),
})
}
type stubVirtualMachineScaleSetsClientListPager struct {
list []armcomputev2.VirtualMachineScaleSet
fetchErr error
more bool
}
func (p *stubVirtualMachineScaleSetsClientListPager) moreFunc() func(armcomputev2.VirtualMachineScaleSetsClientListResponse) bool {
return func(armcomputev2.VirtualMachineScaleSetsClientListResponse) bool {
return p.more
}
}
func (p *stubVirtualMachineScaleSetsClientListPager) fetcherFunc() func(context.Context, *armcomputev2.VirtualMachineScaleSetsClientListResponse) (armcomputev2.VirtualMachineScaleSetsClientListResponse, error) {
return func(context.Context, *armcomputev2.VirtualMachineScaleSetsClientListResponse) (armcomputev2.VirtualMachineScaleSetsClientListResponse, error) {
page := make([]*armcomputev2.VirtualMachineScaleSet, len(p.list))
for i := range p.list {
page[i] = &p.list[i]
}
return armcomputev2.VirtualMachineScaleSetsClientListResponse{
VirtualMachineScaleSetListResult: armcomputev2.VirtualMachineScaleSetListResult{
Value: page,
},
}, p.fetchErr
}
}
type stubScaleSetsAPI struct {
pager *stubVirtualMachineScaleSetsClientListPager
}
func (a *stubScaleSetsAPI) NewListPager(resourceGroupName string, options *armcomputev2.VirtualMachineScaleSetsClientListOptions) *runtime.Pager[armcomputev2.VirtualMachineScaleSetsClientListResponse] {
return runtime.NewPager(runtime.PagingHandler[armcomputev2.VirtualMachineScaleSetsClientListResponse]{
More: a.pager.moreFunc(),
Fetcher: a.pager.fetcherFunc(),
})
}
type stubTagsAPI struct {
createOrUpdateAtScopeErr error
updateAtScopeErr error
}
func (a *stubTagsAPI) CreateOrUpdateAtScope(ctx context.Context, scope string, parameters armresources.TagsResource, options *armresources.TagsClientCreateOrUpdateAtScopeOptions) (armresources.TagsClientCreateOrUpdateAtScopeResponse, error) {
return armresources.TagsClientCreateOrUpdateAtScopeResponse{}, a.createOrUpdateAtScopeErr
}
func (a *stubTagsAPI) UpdateAtScope(ctx context.Context, scope string, parameters armresources.TagsPatchResource, options *armresources.TagsClientUpdateAtScopeOptions) (armresources.TagsClientUpdateAtScopeResponse, error) {
return armresources.TagsClientUpdateAtScopeResponse{}, a.updateAtScopeErr
}
type stubSecurityGroupsClientListPager struct {
list []armnetwork.SecurityGroup
fetchErr error
more bool
}
func (p *stubSecurityGroupsClientListPager) moreFunc() func(armnetwork.SecurityGroupsClientListResponse) bool {
return func(armnetwork.SecurityGroupsClientListResponse) bool {
return p.more
}
}
func (p *stubSecurityGroupsClientListPager) fetcherFunc() func(context.Context, *armnetwork.SecurityGroupsClientListResponse) (armnetwork.SecurityGroupsClientListResponse, error) {
return func(context.Context, *armnetwork.SecurityGroupsClientListResponse) (armnetwork.SecurityGroupsClientListResponse, error) {
page := make([]*armnetwork.SecurityGroup, len(p.list))
for i := range p.list {
page[i] = &p.list[i]
}
return armnetwork.SecurityGroupsClientListResponse{
SecurityGroupListResult: armnetwork.SecurityGroupListResult{
Value: page,
},
}, p.fetchErr
}
}
type stubSecurityGroupsAPI struct {
pager *stubSecurityGroupsClientListPager
}
func (a *stubSecurityGroupsAPI) NewListPager(resourceGroupName string, options *armnetwork.SecurityGroupsClientListOptions) *runtime.Pager[armnetwork.SecurityGroupsClientListResponse] {
return runtime.NewPager(runtime.PagingHandler[armnetwork.SecurityGroupsClientListResponse]{
More: a.pager.moreFunc(),
Fetcher: a.pager.fetcherFunc(),
})
}
type stubVirtualNetworksClientListPager struct {
list []armnetwork.VirtualNetwork
fetchErr error
more bool
}
func (p *stubVirtualNetworksClientListPager) moreFunc() func(armnetwork.VirtualNetworksClientListResponse) bool {
return func(armnetwork.VirtualNetworksClientListResponse) bool {
return p.more
}
}
func (p *stubVirtualNetworksClientListPager) fetcherFunc() func(context.Context, *armnetwork.VirtualNetworksClientListResponse) (armnetwork.VirtualNetworksClientListResponse, error) {
return func(context.Context, *armnetwork.VirtualNetworksClientListResponse) (armnetwork.VirtualNetworksClientListResponse, error) {
page := make([]*armnetwork.VirtualNetwork, len(p.list))
for i := range p.list {
page[i] = &p.list[i]
}
return armnetwork.VirtualNetworksClientListResponse{
VirtualNetworkListResult: armnetwork.VirtualNetworkListResult{
Value: page,
},
}, p.fetchErr
}
}
type stubVirtualNetworksAPI struct {
pager *stubVirtualNetworksClientListPager
}
func (a *stubVirtualNetworksAPI) NewListPager(resourceGroupName string, options *armnetwork.VirtualNetworksClientListOptions) *runtime.Pager[armnetwork.VirtualNetworksClientListResponse] {
return runtime.NewPager(runtime.PagingHandler[armnetwork.VirtualNetworksClientListResponse]{
More: a.pager.moreFunc(),
Fetcher: a.pager.fetcherFunc(),
})
}
type stubLoadBalancersAPI struct {
pager *stubLoadBalancersClientListPager
}
func (a *stubLoadBalancersAPI) NewListPager(resourceGroupName string, options *armnetwork.LoadBalancersClientListOptions,
) *runtime.Pager[armnetwork.LoadBalancersClientListResponse] {
return runtime.NewPager(runtime.PagingHandler[armnetwork.LoadBalancersClientListResponse]{
More: a.pager.moreFunc(),
Fetcher: a.pager.fetcherFunc(),
})
}
type stubPublicIPAddressesAPI struct {
getResponse armnetwork.PublicIPAddressesClientGetResponse
getVirtualMachineScaleSetPublicIPAddressResponse armnetwork.PublicIPAddressesClientGetVirtualMachineScaleSetPublicIPAddressResponse
getErr error
}
func (a *stubPublicIPAddressesAPI) Get(ctx context.Context, resourceGroupName string, publicIPAddressName string,
options *armnetwork.PublicIPAddressesClientGetOptions,
) (armnetwork.PublicIPAddressesClientGetResponse, error) {
return a.getResponse, a.getErr
}
func (a *stubPublicIPAddressesAPI) GetVirtualMachineScaleSetPublicIPAddress(ctx context.Context, resourceGroupName string, virtualMachineScaleSetName string,
virtualmachineIndex string, networkInterfaceName string, IPConfigurationName string, publicIPAddressName string,
options *armnetwork.PublicIPAddressesClientGetVirtualMachineScaleSetPublicIPAddressOptions,
) (armnetwork.PublicIPAddressesClientGetVirtualMachineScaleSetPublicIPAddressResponse, error) {
return a.getVirtualMachineScaleSetPublicIPAddressResponse, a.getErr
}
type stubVirtualMachineScaleSetVMPager struct {
list []armcomputev2.VirtualMachineScaleSetVM
fetchErr error
more bool
}
func (p *stubVirtualMachineScaleSetVMPager) moreFunc() func(armcomputev2.VirtualMachineScaleSetVMsClientListResponse) bool {
return func(armcomputev2.VirtualMachineScaleSetVMsClientListResponse) bool {
return p.more
}
}
func (p *stubVirtualMachineScaleSetVMPager) fetcherFunc() func(context.Context, *armcomputev2.VirtualMachineScaleSetVMsClientListResponse) (armcomputev2.VirtualMachineScaleSetVMsClientListResponse, error) {
return func(context.Context, *armcomputev2.VirtualMachineScaleSetVMsClientListResponse) (armcomputev2.VirtualMachineScaleSetVMsClientListResponse, error) {
page := make([]*armcomputev2.VirtualMachineScaleSetVM, len(p.list))
for i := range p.list {
page[i] = &p.list[i]
}
return armcomputev2.VirtualMachineScaleSetVMsClientListResponse{
VirtualMachineScaleSetVMListResult: armcomputev2.VirtualMachineScaleSetVMListResult{
Value: page,
},
}, p.fetchErr
}
}
type stubLoadBalancersClientListPager struct {
list []armnetwork.LoadBalancer
fetchErr error
more bool
}
func (p *stubLoadBalancersClientListPager) moreFunc() func(armnetwork.LoadBalancersClientListResponse) bool {
return func(armnetwork.LoadBalancersClientListResponse) bool {
return p.more
}
}
func (p *stubLoadBalancersClientListPager) fetcherFunc() func(context.Context, *armnetwork.LoadBalancersClientListResponse) (armnetwork.LoadBalancersClientListResponse, error) {
return func(context.Context, *armnetwork.LoadBalancersClientListResponse) (armnetwork.LoadBalancersClientListResponse, error) {
page := make([]*armnetwork.LoadBalancer, len(p.list))
for i := range p.list {
page[i] = &p.list[i]
}
return armnetwork.LoadBalancersClientListResponse{
LoadBalancerListResult: armnetwork.LoadBalancerListResult{
Value: page,
},
}, p.fetchErr
}
}

View File

@ -0,0 +1,428 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"path"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/edgelesssys/constellation/v2/internal/cloud"
"github.com/edgelesssys/constellation/v2/internal/cloud/azureshared"
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
"github.com/edgelesssys/constellation/v2/internal/role"
)
// Cloud provides Azure metadata and API access.
type Cloud struct {
imds imdsAPI
virtNetAPI virtualNetworksAPI
secGroupAPI securityGroupsAPI
netIfacAPI networkInterfacesAPI
pubIPAPI publicIPAddressesAPI
scaleSetsAPI scaleSetsAPI
loadBalancerAPI loadBalancerAPI
scaleSetsVMAPI virtualMachineScaleSetVMsAPI
}
// New initializes Cloud with the needed API clients.
// Default credentials are used for authentication.
func New(ctx context.Context) (*Cloud, error) {
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return nil, fmt.Errorf("loading credentials: %w", err)
}
// The default http client may use a system-wide proxy and it is recommended to disable the proxy explicitly:
// https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#proxies
// See also: https://github.com/microsoft/azureimds/blob/master/imdssample.go#L10
imdsAPI := imdsClient{
client: &http.Client{Transport: &http.Transport{Proxy: nil}},
}
subscriptionID, err := imdsAPI.subscriptionID(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving subscription ID: %w", err)
}
virtualNetworksAPI, err := armnetwork.NewVirtualNetworksClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
networkInterfacesAPI, err := armnetwork.NewInterfacesClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
publicIPAddressesAPI, err := armnetwork.NewPublicIPAddressesClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
securityGroupsAPI, err := armnetwork.NewSecurityGroupsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
scaleSetsAPI, err := armcomputev2.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
loadBalancerAPI, err := armnetwork.NewLoadBalancersClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
virtualMachineScaleSetVMsAPI, err := armcomputev2.NewVirtualMachineScaleSetVMsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
return &Cloud{
imds: &imdsAPI,
netIfacAPI: networkInterfacesAPI,
virtNetAPI: virtualNetworksAPI,
secGroupAPI: securityGroupsAPI,
pubIPAPI: publicIPAddressesAPI,
loadBalancerAPI: loadBalancerAPI,
scaleSetsAPI: scaleSetsAPI,
scaleSetsVMAPI: virtualMachineScaleSetVMsAPI,
}, nil
}
// GetInstance retrieves an instance using its providerID.
func (c *Cloud) GetInstance(ctx context.Context, providerID string) (metadata.InstanceMetadata, error) {
return c.getInstance(ctx, providerID)
}
// GetCCMConfig returns the configuration needed for the Kubernetes Cloud Controller Manager on Azure.
func (c *Cloud) GetCCMConfig(ctx context.Context, providerID string, cloudServiceAccountURI string) ([]byte, error) {
subscriptionID, resourceGroup, err := azureshared.BasicsFromProviderID(providerID)
if err != nil {
return nil, fmt.Errorf("parsing provider ID: %w", err)
}
creds, err := azureshared.ApplicationCredentialsFromURI(cloudServiceAccountURI)
if err != nil {
return nil, fmt.Errorf("parsing service account URI: %w", err)
}
uid, err := c.imds.uid(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving instance UID: %w", err)
}
securityGroupName, err := c.getNetworkSecurityGroupName(ctx, resourceGroup, uid)
if err != nil {
return nil, fmt.Errorf("retrieving network security group name: %w", err)
}
loadBalancer, err := c.getLoadBalancer(ctx, resourceGroup, uid)
if err != nil {
return nil, fmt.Errorf("retrieving load balancer: %w", err)
}
if loadBalancer == nil || loadBalancer.Name == nil {
return nil, fmt.Errorf("could not dereference load balancer name")
}
config := cloudConfig{
Cloud: "AzurePublicCloud",
TenantID: creds.TenantID,
SubscriptionID: subscriptionID,
ResourceGroup: resourceGroup,
LoadBalancerSku: "standard",
SecurityGroupName: securityGroupName,
LoadBalancerName: *loadBalancer.Name,
UseInstanceMetadata: true,
VMType: "vmss",
Location: creds.Location,
AADClientID: creds.AppClientID,
AADClientSecret: creds.ClientSecretValue,
}
return json.Marshal(config)
}
// GetLoadBalancerEndpoint retrieves the first load balancer IP from cloud provider metadata.
//
// The returned string is an IP address without a port, but the method name needs to satisfy the
// metadata interface.
func (c *Cloud) GetLoadBalancerEndpoint(ctx context.Context) (string, error) {
resourceGroup, err := c.imds.resourceGroup(ctx)
if err != nil {
return "", fmt.Errorf("retrieving resource group: %w", err)
}
uid, err := c.imds.uid(ctx)
if err != nil {
return "", fmt.Errorf("retrieving instance UID: %w", err)
}
lb, err := c.getLoadBalancer(ctx, resourceGroup, uid)
if err != nil {
return "", fmt.Errorf("retrieving load balancer: %w", err)
}
if lb == nil || lb.Properties == nil {
return "", errors.New("could not dereference load balancer IP configuration")
}
var pubIP string
for _, fipConf := range lb.Properties.FrontendIPConfigurations {
if fipConf == nil || fipConf.Properties == nil || fipConf.Properties.PublicIPAddress == nil || fipConf.Properties.PublicIPAddress.ID == nil {
continue
}
pubIP = path.Base(*fipConf.Properties.PublicIPAddress.ID)
break
}
resp, err := c.pubIPAPI.Get(ctx, resourceGroup, pubIP, nil)
if err != nil {
return "", fmt.Errorf("retrieving load balancer public IP address: %w", err)
}
if resp.Properties == nil || resp.Properties.IPAddress == nil {
return "", fmt.Errorf("could not resolve public IP address reference for load balancer")
}
return *resp.Properties.IPAddress, nil
}
// List retrieves all instances belonging to the current constellation.
func (c *Cloud) List(ctx context.Context) ([]metadata.InstanceMetadata, error) {
resourceGroup, err := c.imds.resourceGroup(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving resource group: %w", err)
}
uid, err := c.imds.uid(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving instance UID: %w", err)
}
instances := []metadata.InstanceMetadata{}
pager := c.scaleSetsAPI.NewListPager(resourceGroup, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving scale sets: %w", err)
}
for _, scaleSet := range page.Value {
if scaleSet == nil || scaleSet.Name == nil || scaleSet.Tags == nil ||
scaleSet.Tags[cloud.TagUID] == nil || *scaleSet.Tags[cloud.TagUID] != uid {
continue
}
vmPager := c.scaleSetsVMAPI.NewListPager(resourceGroup, *scaleSet.Name, nil)
for vmPager.More() {
vmPage, err := vmPager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving vms: %w", err)
}
for _, vm := range vmPage.Value {
if vm == nil || vm.InstanceID == nil {
continue
}
interfaces, err := c.getVMInterfaces(ctx, *vm, resourceGroup, *scaleSet.Name, *vm.InstanceID)
if err != nil {
return nil, fmt.Errorf("retrieving VM network interfaces: %w", err)
}
instance, err := convertToInstanceMetadata(*vm, interfaces)
if err != nil {
return nil, fmt.Errorf("converting VM to instance metadata: %w", err)
}
instances = append(instances, instance)
}
}
}
}
return instances, nil
}
// Self retrieves the current instance.
func (c *Cloud) Self(ctx context.Context) (metadata.InstanceMetadata, error) {
providerID, err := c.imds.providerID(ctx)
if err != nil {
return metadata.InstanceMetadata{}, fmt.Errorf("retrieving provider ID: %w", err)
}
return c.getInstance(ctx, "azure://"+providerID)
}
// UID retrieves the UID of the constellation.
func (c *Cloud) UID(ctx context.Context) (string, error) {
uid, err := c.imds.uid(ctx)
if err != nil {
return "", fmt.Errorf("retrieving instance UID: %w", err)
}
return uid, nil
}
// getLoadBalancer retrieves a load balancer from cloud provider metadata.
func (c *Cloud) getLoadBalancer(ctx context.Context, resourceGroup, uid string) (*armnetwork.LoadBalancer, error) {
pager := c.loadBalancerAPI.NewListPager(resourceGroup, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving available load balancers: %w", err)
}
for _, lb := range page.Value {
if lb == nil || lb.Tags == nil ||
lb.Tags[cloud.TagUID] == nil || *lb.Tags[cloud.TagUID] != uid {
continue
}
return lb, nil
}
}
return nil, fmt.Errorf("load balancer with UID %s not found", uid)
}
// getInstance returns an Azure instance given a providerID.
func (c *Cloud) getInstance(ctx context.Context, providerID string) (metadata.InstanceMetadata, error) {
_, resourceGroup, scaleSet, instanceID, err := azureshared.ScaleSetInformationFromProviderID(providerID)
if err != nil {
return metadata.InstanceMetadata{}, fmt.Errorf("invalid provider ID: %w", err)
}
vmResp, err := c.scaleSetsVMAPI.Get(ctx, resourceGroup, scaleSet, instanceID, nil)
if err != nil {
return metadata.InstanceMetadata{}, fmt.Errorf("retrieving instance: %w", err)
}
networkInterfaces, err := c.getVMInterfaces(ctx, vmResp.VirtualMachineScaleSetVM, resourceGroup, scaleSet, instanceID)
if err != nil {
return metadata.InstanceMetadata{}, fmt.Errorf("retrieving VM network interfaces: %w", err)
}
return convertToInstanceMetadata(vmResp.VirtualMachineScaleSetVM, networkInterfaces)
}
// getNetworkSecurityGroupName returns the security group name of the resource group.
func (c *Cloud) getNetworkSecurityGroupName(ctx context.Context, resourceGroup, uid string) (string, error) {
pager := c.secGroupAPI.NewListPager(resourceGroup, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return "", fmt.Errorf("retrieving security groups: %w", err)
}
for _, secGroup := range page.Value {
if secGroup == nil || secGroup.Name == nil || secGroup.Tags == nil ||
secGroup.Tags[cloud.TagUID] == nil || *secGroup.Tags[cloud.TagUID] != uid {
continue
}
return *secGroup.Name, nil
}
}
return "", fmt.Errorf("network security group with UID %s not found in resource group %s", uid, resourceGroup)
}
// getSubnetworkCIDR retrieves the subnetwork CIDR from cloud provider metadata.
func (c *Cloud) getSubnetworkCIDR(ctx context.Context) (string, error) {
resourceGroup, err := c.imds.resourceGroup(ctx)
if err != nil {
return "", fmt.Errorf("retrieving resource group: %w", err)
}
uid, err := c.imds.uid(ctx)
if err != nil {
return "", fmt.Errorf("retrieving instance UID: %w", err)
}
pager := c.virtNetAPI.NewListPager(resourceGroup, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return "", fmt.Errorf("retrieving virtual networks: %w", err)
}
for _, network := range page.Value {
if network == nil || network.Properties == nil || len(network.Properties.Subnets) == 0 ||
network.Properties.Subnets[0] == nil || network.Properties.Subnets[0].Properties == nil ||
network.Properties.Subnets[0].Properties.AddressPrefix == nil ||
network.Tags == nil || network.Tags[cloud.TagUID] == nil || *network.Tags[cloud.TagUID] != uid {
continue
}
return *network.Properties.Subnets[0].Properties.AddressPrefix, nil
}
}
return "", fmt.Errorf("no virtual network found matching UID %s in resource group %s", uid, resourceGroup)
}
// getVMInterfaces retrieves all network interfaces referenced by a scale set virtual machine.
func (c *Cloud) getVMInterfaces(ctx context.Context, vm armcomputev2.VirtualMachineScaleSetVM, resourceGroup, scaleSet, instanceID string) ([]armnetwork.Interface, error) {
if vm.Properties == nil || vm.Properties.NetworkProfile == nil {
return []armnetwork.Interface{}, errors.New("no network profile found")
}
var interfaceNames []string
for _, iface := range vm.Properties.NetworkProfile.NetworkInterfaces {
if iface == nil || iface.ID == nil {
continue
}
interfaceNames = append(interfaceNames, path.Base(*iface.ID))
}
networkInterfaces := []armnetwork.Interface{}
for _, interfaceName := range interfaceNames {
networkInterfacesResp, err := c.netIfacAPI.GetVirtualMachineScaleSetNetworkInterface(ctx, resourceGroup, scaleSet, instanceID, interfaceName, nil)
if err != nil {
return nil, fmt.Errorf("retrieving network interface %v: %w", interfaceName, err)
}
networkInterfaces = append(networkInterfaces, networkInterfacesResp.Interface)
}
return networkInterfaces, nil
}
type cloudConfig struct {
Cloud string `json:"cloud,omitempty"`
TenantID string `json:"tenantId,omitempty"`
SubscriptionID string `json:"subscriptionId,omitempty"`
ResourceGroup string `json:"resourceGroup,omitempty"`
Location string `json:"location,omitempty"`
SubnetName string `json:"subnetName,omitempty"`
SecurityGroupName string `json:"securityGroupName,omitempty"`
SecurityGroupResourceGroup string `json:"securityGroupResourceGroup,omitempty"`
LoadBalancerName string `json:"loadBalancerName,omitempty"`
LoadBalancerSku string `json:"loadBalancerSku,omitempty"`
VNetName string `json:"vnetName,omitempty"`
VNetResourceGroup string `json:"vnetResourceGroup,omitempty"`
CloudProviderBackoff bool `json:"cloudProviderBackoff,omitempty"`
UseInstanceMetadata bool `json:"useInstanceMetadata,omitempty"`
VMType string `json:"vmType,omitempty"`
AADClientID string `json:"aadClientId,omitempty"`
AADClientSecret string `json:"aadClientSecret,omitempty"`
}
// convertToInstanceMetadata converts a armcomputev2.VirtualMachineScaleSetVM to a metadata.InstanceMetadata.
func convertToInstanceMetadata(vm armcomputev2.VirtualMachineScaleSetVM, networkInterfaces []armnetwork.Interface,
) (metadata.InstanceMetadata, error) {
if vm.ID == nil {
return metadata.InstanceMetadata{}, errors.New("missing instance ID")
}
if vm.Properties == nil || vm.Properties.OSProfile == nil || vm.Properties.OSProfile.ComputerName == nil {
return metadata.InstanceMetadata{}, errors.New("missing computer name")
}
var instanceRole string
if vm.Tags != nil || vm.Tags[cloud.TagRole] != nil {
instanceRole = *vm.Tags[cloud.TagRole]
}
var privateIP string
for _, networkInterface := range networkInterfaces {
if networkInterface.Properties == nil {
continue
}
for _, config := range networkInterface.Properties.IPConfigurations {
if config == nil || config.Properties == nil || config.Properties.PrivateIPAddress == nil || config.Properties.Primary == nil {
continue
}
if *config.Properties.Primary {
privateIP = *config.Properties.PrivateIPAddress
}
}
}
return metadata.InstanceMetadata{
Name: *vm.Properties.OSProfile.ComputerName,
ProviderID: "azure://" + *vm.ID,
Role: role.FromString(instanceRole),
VPCIP: privateIP,
}, nil
}

File diff suppressed because it is too large Load Diff

View File

@ -35,8 +35,8 @@ type imdsClient struct {
cacheTime time.Time cacheTime time.Time
} }
// ProviderID returns the provider ID of the instance the function is called from. // providerID returns the provider ID of the instance the function is called from.
func (c *imdsClient) ProviderID(ctx context.Context) (string, error) { func (c *imdsClient) providerID(ctx context.Context) (string, error) {
if c.timeForUpdate() || c.cache.Compute.ResourceID == "" { if c.timeForUpdate() || c.cache.Compute.ResourceID == "" {
if err := c.update(ctx); err != nil { if err := c.update(ctx); err != nil {
return "", err return "", err
@ -50,9 +50,23 @@ func (c *imdsClient) ProviderID(ctx context.Context) (string, error) {
return c.cache.Compute.ResourceID, nil return c.cache.Compute.ResourceID, nil
} }
// SubscriptionID returns the subscription ID of the instance the function func (c *imdsClient) name(ctx context.Context) (string, error) {
if c.timeForUpdate() || c.cache.Compute.OSProfile.ComputerName == "" {
if err := c.update(ctx); err != nil {
return "", err
}
}
if c.cache.Compute.OSProfile.ComputerName == "" {
return "", errors.New("unable to get name")
}
return c.cache.Compute.OSProfile.ComputerName, nil
}
// subscriptionID returns the subscription ID of the instance the function
// is called from. // is called from.
func (c *imdsClient) SubscriptionID(ctx context.Context) (string, error) { func (c *imdsClient) subscriptionID(ctx context.Context) (string, error) {
if c.timeForUpdate() || c.cache.Compute.SubscriptionID == "" { if c.timeForUpdate() || c.cache.Compute.SubscriptionID == "" {
if err := c.update(ctx); err != nil { if err := c.update(ctx); err != nil {
return "", err return "", err
@ -66,9 +80,9 @@ func (c *imdsClient) SubscriptionID(ctx context.Context) (string, error) {
return c.cache.Compute.SubscriptionID, nil return c.cache.Compute.SubscriptionID, nil
} }
// ResourceGroup returns the resource group of the instance the function // resourceGroup returns the resource group of the instance the function
// is called from. // is called from.
func (c *imdsClient) ResourceGroup(ctx context.Context) (string, error) { func (c *imdsClient) resourceGroup(ctx context.Context) (string, error) {
if c.timeForUpdate() || c.cache.Compute.ResourceGroup == "" { if c.timeForUpdate() || c.cache.Compute.ResourceGroup == "" {
if err := c.update(ctx); err != nil { if err := c.update(ctx); err != nil {
return "", err return "", err
@ -82,9 +96,9 @@ func (c *imdsClient) ResourceGroup(ctx context.Context) (string, error) {
return c.cache.Compute.ResourceGroup, nil return c.cache.Compute.ResourceGroup, nil
} }
// UID returns the UID of the cluster, based on the tags on the instance // uid returns the UID of the cluster, based on the tags on the instance
// the function is calles from, which are inherited from the scale set. // the function is called from, which are inherited from the scale set.
func (c *imdsClient) UID(ctx context.Context) (string, error) { func (c *imdsClient) uid(ctx context.Context) (string, error) {
if c.timeForUpdate() || len(c.cache.Compute.Tags) == 0 { if c.timeForUpdate() || len(c.cache.Compute.Tags) == 0 {
if err := c.update(ctx); err != nil { if err := c.update(ctx); err != nil {
return "", err return "", err
@ -100,7 +114,8 @@ func (c *imdsClient) UID(ctx context.Context) (string, error) {
return "", fmt.Errorf("unable to get uid from metadata tags %v", c.cache.Compute.Tags) return "", fmt.Errorf("unable to get uid from metadata tags %v", c.cache.Compute.Tags)
} }
func (c *imdsClient) Role(ctx context.Context) (role.Role, error) { // role returns the role of the instance the function is called from.
func (c *imdsClient) role(ctx context.Context) (role.Role, error) {
if c.timeForUpdate() || len(c.cache.Compute.Tags) == 0 { if c.timeForUpdate() || len(c.cache.Compute.Tags) == 0 {
if err := c.update(ctx); err != nil { if err := c.update(ctx); err != nil {
return role.Unknown, err return role.Unknown, err
@ -161,6 +176,9 @@ type metadataResponseCompute struct {
SubscriptionID string `json:"subscriptionId,omitempty"` SubscriptionID string `json:"subscriptionId,omitempty"`
ResourceGroup string `json:"resourceGroupName,omitempty"` ResourceGroup string `json:"resourceGroupName,omitempty"`
Tags []metadataTag `json:"tagsList,omitempty"` Tags []metadataTag `json:"tagsList,omitempty"`
OSProfile struct {
ComputerName string `json:"computerName,omitempty"`
} `json:"osProfile,omitempty"`
} }
type metadataTag struct { type metadataTag struct {

View File

@ -26,37 +26,69 @@ func TestIMDSClient(t *testing.T) {
{Name: cloud.TagUID, Value: "uid"}, {Name: cloud.TagUID, Value: "uid"},
{Name: cloud.TagRole, Value: "worker"}, {Name: cloud.TagRole, Value: "worker"},
} }
osProfile := struct {
ComputerName string `json:"computerName,omitempty"`
}{
ComputerName: "computer-name",
}
response := metadataResponse{ response := metadataResponse{
Compute: metadataResponseCompute{ Compute: metadataResponseCompute{
ResourceID: "resource-id", ResourceID: "resource-id",
SubscriptionID: "subscription-id",
ResourceGroup: "resource-group", ResourceGroup: "resource-group",
Tags: uidTags, Tags: uidTags,
OSProfile: osProfile,
}, },
} }
responseWithoutID := metadataResponse{ responseWithoutID := metadataResponse{
Compute: metadataResponseCompute{ Compute: metadataResponseCompute{
ResourceGroup: "resource-group", ResourceGroup: "resource-group",
SubscriptionID: "subscription-id",
Tags: uidTags, Tags: uidTags,
OSProfile: osProfile,
}, },
} }
responseWithoutGroup := metadataResponse{ responseWithoutGroup := metadataResponse{
Compute: metadataResponseCompute{ Compute: metadataResponseCompute{
ResourceID: "resource-id", ResourceID: "resource-id",
SubscriptionID: "subscription-id",
Tags: uidTags, Tags: uidTags,
OSProfile: osProfile,
}, },
} }
responseWithoutUID := metadataResponse{ responseWithoutUID := metadataResponse{
Compute: metadataResponseCompute{ Compute: metadataResponseCompute{
ResourceID: "resource-id", ResourceID: "resource-id",
SubscriptionID: "subscription-id",
ResourceGroup: "resource-group", ResourceGroup: "resource-group",
Tags: []metadataTag{{Name: cloud.TagRole, Value: "worker"}}, Tags: []metadataTag{{Name: cloud.TagRole, Value: "worker"}},
OSProfile: osProfile,
}, },
} }
responseWithoutRole := metadataResponse{ responseWithoutRole := metadataResponse{
Compute: metadataResponseCompute{ Compute: metadataResponseCompute{
ResourceID: "resource-id", ResourceID: "resource-id",
SubscriptionID: "subscription-id",
ResourceGroup: "resource-group", ResourceGroup: "resource-group",
Tags: []metadataTag{{Name: cloud.TagUID, Value: "uid"}}, Tags: []metadataTag{{Name: cloud.TagUID, Value: "uid"}},
OSProfile: osProfile,
},
}
responseWithoutName := metadataResponse{
Compute: metadataResponseCompute{
ResourceID: "resource-id",
SubscriptionID: "subscription-id",
ResourceGroup: "resource-group",
Tags: uidTags,
},
}
responseWithoutSubscriptionID := metadataResponse{
Compute: metadataResponseCompute{
ResourceID: "resource-id",
ResourceGroup: "resource-group",
Tags: uidTags,
OSProfile: osProfile,
}, },
} }
@ -70,6 +102,10 @@ func TestIMDSClient(t *testing.T) {
wantUID string wantUID string
wantRoleErr bool wantRoleErr bool
wantRole role.Role wantRole role.Role
wantNameErr bool
wantName string
wantSubscriptionErr bool
wantSubscriptionID string
}{ }{
"metadata response parsed": { "metadata response parsed": {
server: newHTTPBufconnServerWithMetadataResponse(response), server: newHTTPBufconnServerWithMetadataResponse(response),
@ -77,6 +113,8 @@ func TestIMDSClient(t *testing.T) {
wantResourceGroup: "resource-group", wantResourceGroup: "resource-group",
wantUID: "uid", wantUID: "uid",
wantRole: role.Worker, wantRole: role.Worker,
wantName: "computer-name",
wantSubscriptionID: "subscription-id",
}, },
"metadata response without resource ID": { "metadata response without resource ID": {
server: newHTTPBufconnServerWithMetadataResponse(responseWithoutID), server: newHTTPBufconnServerWithMetadataResponse(responseWithoutID),
@ -84,6 +122,8 @@ func TestIMDSClient(t *testing.T) {
wantResourceGroup: "resource-group", wantResourceGroup: "resource-group",
wantUID: "uid", wantUID: "uid",
wantRole: role.Worker, wantRole: role.Worker,
wantName: "computer-name",
wantSubscriptionID: "subscription-id",
}, },
"metadata response without UID tag": { "metadata response without UID tag": {
server: newHTTPBufconnServerWithMetadataResponse(responseWithoutUID), server: newHTTPBufconnServerWithMetadataResponse(responseWithoutUID),
@ -91,6 +131,8 @@ func TestIMDSClient(t *testing.T) {
wantResourceGroup: "resource-group", wantResourceGroup: "resource-group",
wantUIDErr: true, wantUIDErr: true,
wantRole: role.Worker, wantRole: role.Worker,
wantName: "computer-name",
wantSubscriptionID: "subscription-id",
}, },
"metadata response without role tag": { "metadata response without role tag": {
server: newHTTPBufconnServerWithMetadataResponse(responseWithoutRole), server: newHTTPBufconnServerWithMetadataResponse(responseWithoutRole),
@ -98,6 +140,8 @@ func TestIMDSClient(t *testing.T) {
wantResourceGroup: "resource-group", wantResourceGroup: "resource-group",
wantUID: "uid", wantUID: "uid",
wantRoleErr: true, wantRoleErr: true,
wantName: "computer-name",
wantSubscriptionID: "subscription-id",
}, },
"metadata response without resource group": { "metadata response without resource group": {
server: newHTTPBufconnServerWithMetadataResponse(responseWithoutGroup), server: newHTTPBufconnServerWithMetadataResponse(responseWithoutGroup),
@ -105,6 +149,26 @@ func TestIMDSClient(t *testing.T) {
wantResourceGroupErr: true, wantResourceGroupErr: true,
wantUID: "uid", wantUID: "uid",
wantRole: role.Worker, wantRole: role.Worker,
wantName: "computer-name",
wantSubscriptionID: "subscription-id",
},
"metadata response without name": {
server: newHTTPBufconnServerWithMetadataResponse(responseWithoutName),
wantProviderID: "resource-id",
wantResourceGroup: "resource-group",
wantUID: "uid",
wantRole: role.Worker,
wantNameErr: true,
wantSubscriptionID: "subscription-id",
},
"metadata response without subscription ID": {
server: newHTTPBufconnServerWithMetadataResponse(responseWithoutSubscriptionID),
wantProviderID: "resource-id",
wantResourceGroup: "resource-group",
wantUID: "uid",
wantRole: role.Worker,
wantName: "computer-name",
wantSubscriptionErr: true,
}, },
"invalid imds response detected": { "invalid imds response detected": {
server: newHTTPBufconnServer(func(writer http.ResponseWriter, request *http.Request) { server: newHTTPBufconnServer(func(writer http.ResponseWriter, request *http.Request) {
@ -114,6 +178,8 @@ func TestIMDSClient(t *testing.T) {
wantResourceGroupErr: true, wantResourceGroupErr: true,
wantUIDErr: true, wantUIDErr: true,
wantRoleErr: true, wantRoleErr: true,
wantNameErr: true,
wantSubscriptionErr: true,
}, },
} }
@ -135,7 +201,7 @@ func TestIMDSClient(t *testing.T) {
ctx := context.Background() ctx := context.Background()
id, err := iClient.ProviderID(ctx) id, err := iClient.providerID(ctx)
if tc.wantProviderIDErr { if tc.wantProviderIDErr {
assert.Error(err) assert.Error(err)
} else { } else {
@ -143,7 +209,7 @@ func TestIMDSClient(t *testing.T) {
assert.Equal(tc.wantProviderID, id) assert.Equal(tc.wantProviderID, id)
} }
group, err := iClient.ResourceGroup(ctx) group, err := iClient.resourceGroup(ctx)
if tc.wantResourceGroupErr { if tc.wantResourceGroupErr {
assert.Error(err) assert.Error(err)
} else { } else {
@ -151,7 +217,7 @@ func TestIMDSClient(t *testing.T) {
assert.Equal(tc.wantResourceGroup, group) assert.Equal(tc.wantResourceGroup, group)
} }
uid, err := iClient.UID(ctx) uid, err := iClient.uid(ctx)
if tc.wantUIDErr { if tc.wantUIDErr {
assert.Error(err) assert.Error(err)
} else { } else {
@ -159,13 +225,29 @@ func TestIMDSClient(t *testing.T) {
assert.Equal(tc.wantUID, uid) assert.Equal(tc.wantUID, uid)
} }
role, err := iClient.Role(ctx) role, err := iClient.role(ctx)
if tc.wantRoleErr { if tc.wantRoleErr {
assert.Error(err) assert.Error(err)
} else { } else {
assert.NoError(err) assert.NoError(err)
assert.Equal(tc.wantRole, role) assert.Equal(tc.wantRole, role)
} }
name, err := iClient.name(ctx)
if tc.wantNameErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.wantName, name)
}
subscriptionID, err := iClient.subscriptionID(ctx)
if tc.wantSubscriptionErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.wantSubscriptionID, subscriptionID)
}
}) })
} }
} }

View File

@ -13,14 +13,14 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2" armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
) )
type imdsAPI interface { type imdsAPI interface {
ProviderID(ctx context.Context) (string, error) providerID(ctx context.Context) (string, error)
SubscriptionID(ctx context.Context) (string, error) name(ctx context.Context) (string, error)
ResourceGroup(ctx context.Context) (string, error) resourceGroup(ctx context.Context) (string, error)
UID(ctx context.Context) (string, error) subscriptionID(ctx context.Context) (string, error)
uid(ctx context.Context) (string, error)
} }
type virtualNetworksAPI interface { type virtualNetworksAPI interface {
@ -74,15 +74,6 @@ type loadBalancerAPI interface {
) *runtime.Pager[armnetwork.LoadBalancersClientListResponse] ) *runtime.Pager[armnetwork.LoadBalancersClientListResponse]
} }
type tagsAPI interface {
CreateOrUpdateAtScope(ctx context.Context, scope string, parameters armresources.TagsResource,
options *armresources.TagsClientCreateOrUpdateAtScopeOptions,
) (armresources.TagsClientCreateOrUpdateAtScopeResponse, error)
UpdateAtScope(ctx context.Context, scope string, parameters armresources.TagsPatchResource,
options *armresources.TagsClientUpdateAtScopeOptions,
) (armresources.TagsClientUpdateAtScopeResponse, error)
}
type applicationInsightsAPI interface { type applicationInsightsAPI interface {
NewListByResourceGroupPager(resourceGroupName string, NewListByResourceGroupPager(resourceGroupName string,
options *armapplicationinsights.ComponentsClientListByResourceGroupOptions, options *armapplicationinsights.ComponentsClientListByResourceGroupOptions,

View File

@ -10,7 +10,11 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net/http"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
"github.com/edgelesssys/constellation/v2/internal/cloud"
"github.com/microsoft/ApplicationInsights-Go/appinsights" "github.com/microsoft/ApplicationInsights-Go/appinsights"
) )
@ -22,23 +26,35 @@ type Logger struct {
// NewLogger creates a new client to store information in Azure Application Insights // NewLogger creates a new client to store information in Azure Application Insights
// https://github.com/Microsoft/ApplicationInsights-go // https://github.com/Microsoft/ApplicationInsights-go
func NewLogger(ctx context.Context, metadata *Metadata) (*Logger, error) { func NewLogger(ctx context.Context) (*Logger, error) {
component, err := metadata.getAppInsights(ctx) cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting app insights: %w", err) return nil, fmt.Errorf("loading credentials: %w", err)
} }
imdsAPI := &imdsClient{
if component.Properties == nil || component.Properties.InstrumentationKey == nil { client: &http.Client{Transport: &http.Transport{Proxy: nil}},
return nil, errors.New("unable to get instrumentation key")
} }
subscriptionID, err := imdsAPI.subscriptionID(ctx)
client := appinsights.NewTelemetryClient(*component.Properties.InstrumentationKey)
self, err := metadata.Self(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting self: %w", err) return nil, fmt.Errorf("retrieving subscription ID: %w", err)
} }
client.Context().CommonProperties["instance-name"] = self.Name appInsightAPI, err := armapplicationinsights.NewComponentsClient(subscriptionID, cred, nil)
if err != nil {
return nil, fmt.Errorf("setting up insights API client. %w", err)
}
instrumentationKey, err := getAppInsightsKey(ctx, imdsAPI, appInsightAPI)
if err != nil {
return nil, fmt.Errorf("getting app insights instrumentation key: %w", err)
}
client := appinsights.NewTelemetryClient(instrumentationKey)
name, err := imdsAPI.name(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving instance name: %w", err)
}
client.Context().CommonProperties["instance-name"] = name
return &Logger{client: client}, nil return &Logger{client: client}, nil
} }
@ -54,3 +70,37 @@ func (l *Logger) Close() error {
<-l.client.Channel().Close() <-l.client.Channel().Close()
return nil return nil
} }
// getAppInsightsKey returns a instrumentation key needed to set up cloud logging on Azure.
// The key is retrieved from the resource group of the instance the function is called from.
func getAppInsightsKey(ctx context.Context, imdsAPI imdsAPI, appInsightAPI applicationInsightsAPI) (string, error) {
resourceGroup, err := imdsAPI.resourceGroup(ctx)
if err != nil {
return "", err
}
uid, err := imdsAPI.uid(ctx)
if err != nil {
return "", err
}
pager := appInsightAPI.NewListByResourceGroupPager(resourceGroup, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return "", fmt.Errorf("retrieving application insights: %w", err)
}
for _, component := range page.Value {
if component == nil || component.Tags == nil ||
component.Tags[cloud.TagUID] == nil || *component.Tags[cloud.TagUID] != uid {
continue
}
if component.Properties == nil || component.Properties.InstrumentationKey == nil {
return "", errors.New("unable to get instrumentation key")
}
return *component.Properties.InstrumentationKey, nil
}
}
return "", errors.New("could not find correctly tagged application insights")
}

View File

@ -0,0 +1,185 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"context"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
"github.com/Azure/go-autorest/autorest/to"
"github.com/edgelesssys/constellation/v2/internal/cloud"
"github.com/stretchr/testify/assert"
)
func TestGetAppInsightsKey(t *testing.T) {
someErr := errors.New("failed")
goodAppInsights := armapplicationinsights.Component{
Tags: map[string]*string{
cloud.TagUID: to.StringPtr("uid"),
},
Properties: &armapplicationinsights.ComponentProperties{
InstrumentationKey: to.StringPtr("key"),
},
}
testCases := map[string]struct {
imds *stubIMDSAPI
appInsights *stubApplicationsInsightsAPI
wantKey string
wantErr bool
}{
"success": {
imds: &stubIMDSAPI{
resourceGroupVal: "resource-group",
uidVal: "uid",
},
appInsights: &stubApplicationsInsightsAPI{
pager: &stubApplicationKeyPager{list: []armapplicationinsights.Component{goodAppInsights}},
},
wantKey: "key",
},
"multiple apps": {
imds: &stubIMDSAPI{
resourceGroupVal: "resource-group",
uidVal: "uid",
},
appInsights: &stubApplicationsInsightsAPI{
pager: &stubApplicationKeyPager{list: []armapplicationinsights.Component{
{
Tags: map[string]*string{
cloud.TagUID: to.StringPtr("different-uid"),
},
Properties: &armapplicationinsights.ComponentProperties{
InstrumentationKey: to.StringPtr("different-key"),
},
},
goodAppInsights,
}},
},
wantKey: "key",
},
"missing properties": {
imds: &stubIMDSAPI{
resourceGroupVal: "resource-group",
uidVal: "uid",
},
appInsights: &stubApplicationsInsightsAPI{
pager: &stubApplicationKeyPager{list: []armapplicationinsights.Component{
{
Tags: map[string]*string{
cloud.TagUID: to.StringPtr("uid"),
},
},
}},
},
wantErr: true,
},
"no app with matching uid": {
imds: &stubIMDSAPI{
resourceGroupVal: "resource-group",
uidVal: "uid",
},
appInsights: &stubApplicationsInsightsAPI{
pager: &stubApplicationKeyPager{list: []armapplicationinsights.Component{
{
Tags: map[string]*string{
cloud.TagUID: to.StringPtr("different-uid"),
},
Properties: &armapplicationinsights.ComponentProperties{
InstrumentationKey: to.StringPtr("different-key"),
},
},
}},
},
wantErr: true,
},
"imds resource group error": {
imds: &stubIMDSAPI{
resourceGroupErr: someErr,
uidVal: "uid",
},
appInsights: &stubApplicationsInsightsAPI{
pager: &stubApplicationKeyPager{list: []armapplicationinsights.Component{goodAppInsights}},
},
wantErr: true,
},
"imds uid error": {
imds: &stubIMDSAPI{
resourceGroupVal: "resource-group",
uidErr: someErr,
},
appInsights: &stubApplicationsInsightsAPI{
pager: &stubApplicationKeyPager{list: []armapplicationinsights.Component{goodAppInsights}},
},
wantErr: true,
},
"app insights list error": {
imds: &stubIMDSAPI{
resourceGroupVal: "resource-group",
uidVal: "uid",
},
appInsights: &stubApplicationsInsightsAPI{
pager: &stubApplicationKeyPager{fetchErr: someErr},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
key, err := getAppInsightsKey(context.Background(), tc.imds, tc.appInsights)
if tc.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tc.wantKey, key)
}
})
}
}
type stubApplicationKeyPager struct {
list []armapplicationinsights.Component
fetchErr error
more bool
}
func (p *stubApplicationKeyPager) moreFunc() func(armapplicationinsights.ComponentsClientListByResourceGroupResponse) bool {
return func(armapplicationinsights.ComponentsClientListByResourceGroupResponse) bool {
return p.more
}
}
func (p *stubApplicationKeyPager) fetcherFunc() func(context.Context, *armapplicationinsights.ComponentsClientListByResourceGroupResponse,
) (armapplicationinsights.ComponentsClientListByResourceGroupResponse, error) {
return func(context.Context, *armapplicationinsights.ComponentsClientListByResourceGroupResponse) (armapplicationinsights.ComponentsClientListByResourceGroupResponse, error) {
page := make([]*armapplicationinsights.Component, len(p.list))
for i := range p.list {
page[i] = &p.list[i]
}
return armapplicationinsights.ComponentsClientListByResourceGroupResponse{
ComponentListResult: armapplicationinsights.ComponentListResult{
Value: page,
},
}, p.fetchErr
}
}
type stubApplicationsInsightsAPI struct {
pager *stubApplicationKeyPager
}
func (a *stubApplicationsInsightsAPI) NewListByResourceGroupPager(resourceGroupName string, options *armapplicationinsights.ComponentsClientListByResourceGroupOptions,
) *runtime.Pager[armapplicationinsights.ComponentsClientListByResourceGroupResponse] {
return runtime.NewPager(runtime.PagingHandler[armapplicationinsights.ComponentsClientListByResourceGroupResponse]{
More: a.pager.moreFunc(),
Fetcher: a.pager.fetcherFunc(),
})
}

View File

@ -1,382 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"context"
"encoding/json"
"fmt"
"net/http"
"regexp"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/v2/internal/cloud"
"github.com/edgelesssys/constellation/v2/internal/cloud/azureshared"
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
)
var publicIPAddressRegexp = regexp.MustCompile(`/subscriptions/[^/]+/resourceGroups/[^/]+/providers/Microsoft.Network/publicIPAddresses/(?P<IPname>[^/]+)`)
// Metadata implements azure metadata APIs.
type Metadata struct {
imdsAPI
virtualNetworksAPI
securityGroupsAPI
networkInterfacesAPI
publicIPAddressesAPI
scaleSetsAPI
loadBalancerAPI
virtualMachineScaleSetVMsAPI
tagsAPI
applicationInsightsAPI
}
// NewMetadata creates a new Metadata.
func NewMetadata(ctx context.Context) (*Metadata, error) {
cred, err := azidentity.NewDefaultAzureCredential(nil)
if err != nil {
return nil, err
}
// The default http client may use a system-wide proxy and it is recommended to disable the proxy explicitly:
// https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux#proxies
// See also: https://github.com/microsoft/azureimds/blob/master/imdssample.go#L10
imdsAPI := imdsClient{
client: &http.Client{Transport: &http.Transport{Proxy: nil}},
}
subscriptionID, err := imdsAPI.SubscriptionID(ctx)
if err != nil {
return nil, err
}
virtualNetworksAPI, err := armnetwork.NewVirtualNetworksClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
networkInterfacesAPI, err := armnetwork.NewInterfacesClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
publicIPAddressesAPI, err := armnetwork.NewPublicIPAddressesClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
securityGroupsAPI, err := armnetwork.NewSecurityGroupsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
scaleSetsAPI, err := armcomputev2.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
loadBalancerAPI, err := armnetwork.NewLoadBalancersClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
virtualMachineScaleSetVMsAPI, err := armcomputev2.NewVirtualMachineScaleSetVMsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
tagsAPI, err := armresources.NewTagsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
applicationInsightsAPI, err := armapplicationinsights.NewComponentsClient(subscriptionID, cred, nil)
if err != nil {
return nil, err
}
return &Metadata{
imdsAPI: &imdsAPI,
networkInterfacesAPI: networkInterfacesAPI,
virtualNetworksAPI: virtualNetworksAPI,
securityGroupsAPI: securityGroupsAPI,
publicIPAddressesAPI: publicIPAddressesAPI,
loadBalancerAPI: loadBalancerAPI,
scaleSetsAPI: scaleSetsAPI,
virtualMachineScaleSetVMsAPI: virtualMachineScaleSetVMsAPI,
tagsAPI: tagsAPI,
applicationInsightsAPI: applicationInsightsAPI,
}, nil
}
// List retrieves all instances belonging to the current constellation.
func (m *Metadata) List(ctx context.Context) ([]metadata.InstanceMetadata, error) {
resourceGroup, err := m.imdsAPI.ResourceGroup(ctx)
if err != nil {
return nil, err
}
scaleSetInstances, err := m.listScaleSetVMs(ctx, resourceGroup)
if err != nil {
return nil, err
}
return scaleSetInstances, nil
}
// Self retrieves the current instance.
func (m *Metadata) Self(ctx context.Context) (metadata.InstanceMetadata, error) {
providerID, err := m.providerID(ctx)
if err != nil {
return metadata.InstanceMetadata{}, err
}
return m.GetInstance(ctx, providerID)
}
// GetInstance retrieves an instance using its providerID.
func (m *Metadata) GetInstance(ctx context.Context, providerID string) (metadata.InstanceMetadata, error) {
instance, scaleSetErr := m.getScaleSetVM(ctx, providerID)
if scaleSetErr == nil {
return instance, nil
}
return metadata.InstanceMetadata{}, fmt.Errorf("retrieving instance given providerID %v: %w", providerID, scaleSetErr)
}
// GetNetworkSecurityGroupName returns the security group name of the resource group.
func (m *Metadata) GetNetworkSecurityGroupName(ctx context.Context) (string, error) {
resourceGroup, err := m.imdsAPI.ResourceGroup(ctx)
if err != nil {
return "", err
}
nsg, err := m.getNetworkSecurityGroup(ctx, resourceGroup)
if err != nil {
return "", err
}
if nsg == nil || nsg.Name == nil {
return "", fmt.Errorf("could not dereference network security group name")
}
return *nsg.Name, nil
}
// getSubnetworkCIDR retrieves the subnetwork CIDR from cloud provider metadata.
func (m *Metadata) getSubnetworkCIDR(ctx context.Context) (string, error) {
resourceGroup, err := m.imdsAPI.ResourceGroup(ctx)
if err != nil {
return "", err
}
virtualNetwork, err := m.getVirtualNetwork(ctx, resourceGroup)
if err != nil {
return "", err
}
if virtualNetwork == nil || virtualNetwork.Properties == nil || len(virtualNetwork.Properties.Subnets) == 0 ||
virtualNetwork.Properties.Subnets[0].Properties == nil || virtualNetwork.Properties.Subnets[0].Properties.AddressPrefix == nil {
return "", fmt.Errorf("could not retrieve subnetwork CIDR from virtual network %v", virtualNetwork)
}
return *virtualNetwork.Properties.Subnets[0].Properties.AddressPrefix, nil
}
// UID retrieves the UID of the constellation.
func (m *Metadata) UID(ctx context.Context) (string, error) {
return m.imdsAPI.UID(ctx)
}
// getLoadBalancer retrieves the load balancer from cloud provider metadata.
func (m *Metadata) getLoadBalancer(ctx context.Context) (*armnetwork.LoadBalancer, error) {
resourceGroup, err := m.imdsAPI.ResourceGroup(ctx)
if err != nil {
return nil, err
}
pager := m.loadBalancerAPI.NewListPager(resourceGroup, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving loadbalancer page: %w", err)
}
for _, lb := range page.Value {
if lb != nil && lb.Properties != nil {
return lb, nil
}
}
}
return nil, fmt.Errorf("could not get any load balancer")
}
// GetLoadBalancerName returns the load balancer name of the resource group.
func (m *Metadata) GetLoadBalancerName(ctx context.Context) (string, error) {
lb, err := m.getLoadBalancer(ctx)
if err != nil {
return "", err
}
if lb == nil || lb.Name == nil {
return "", fmt.Errorf("could not dereference load balancer name")
}
return *lb.Name, nil
}
// GetLoadBalancerEndpoint retrieves the first load balancer IP from cloud provider metadata.
//
// The returned string is an IP address without a port, but the method name needs to satisfy the
// metadata interface.
func (m *Metadata) GetLoadBalancerEndpoint(ctx context.Context) (string, error) {
lb, err := m.getLoadBalancer(ctx)
if err != nil {
return "", err
}
if lb == nil || lb.Properties == nil {
return "", fmt.Errorf("could not dereference load balancer IP configuration")
}
var pubIPID string
for _, fipConf := range lb.Properties.FrontendIPConfigurations {
if fipConf == nil || fipConf.Properties == nil || fipConf.Properties.PublicIPAddress == nil || fipConf.Properties.PublicIPAddress.ID == nil {
continue
}
pubIPID = *fipConf.Properties.PublicIPAddress.ID
break
}
if pubIPID == "" {
return "", fmt.Errorf("could not find public IP address reference in load balancer")
}
matches := publicIPAddressRegexp.FindStringSubmatch(pubIPID)
if len(matches) != 2 {
return "", fmt.Errorf("could not find public IP address name in load balancer: %v", pubIPID)
}
pubIPName := matches[1]
resourceGroup, err := m.imdsAPI.ResourceGroup(ctx)
if err != nil {
return "", err
}
resp, err := m.publicIPAddressesAPI.Get(ctx, resourceGroup, pubIPName, nil)
if err != nil {
return "", fmt.Errorf("could not retrieve public IP address: %w", err)
}
if resp.Properties == nil || resp.Properties.IPAddress == nil {
return "", fmt.Errorf("could not resolve public IP address reference for load balancer")
}
return *resp.Properties.IPAddress, nil
}
// GetCCMConfig returns the configuration needed for the CCM on Azure.
func (m *Metadata) GetCCMConfig(ctx context.Context, providerID string, cloudServiceAccountURI string) ([]byte, error) {
subscriptionID, resourceGroup, err := azureshared.BasicsFromProviderID(providerID)
if err != nil {
return nil, err
}
creds, err := azureshared.ApplicationCredentialsFromURI(cloudServiceAccountURI)
if err != nil {
return nil, err
}
vmType := "standard"
if _, _, _, _, err := azureshared.ScaleSetInformationFromProviderID(providerID); err == nil {
vmType = "vmss"
}
securityGroupName, err := m.GetNetworkSecurityGroupName(ctx)
if err != nil {
return nil, err
}
loadBalancerName, err := m.GetLoadBalancerName(ctx)
if err != nil {
return nil, err
}
config := cloudConfig{
Cloud: "AzurePublicCloud",
TenantID: creds.TenantID,
SubscriptionID: subscriptionID,
ResourceGroup: resourceGroup,
LoadBalancerSku: "standard",
SecurityGroupName: securityGroupName,
LoadBalancerName: loadBalancerName,
UseInstanceMetadata: true,
VMType: vmType,
Location: creds.Location,
AADClientID: creds.AppClientID,
AADClientSecret: creds.ClientSecretValue,
}
return json.Marshal(config)
}
// providerID retrieves the current instances providerID.
func (m *Metadata) providerID(ctx context.Context) (string, error) {
providerID, err := m.imdsAPI.ProviderID(ctx)
if err != nil {
return "", err
}
return "azure://" + providerID, nil
}
func (m *Metadata) getAppInsights(ctx context.Context) (*armapplicationinsights.Component, error) {
resourceGroup, err := m.imdsAPI.ResourceGroup(ctx)
if err != nil {
return nil, err
}
uid, err := m.UID(ctx)
if err != nil {
return nil, err
}
pager := m.applicationInsightsAPI.NewListByResourceGroupPager(resourceGroup, nil)
for pager.More() {
nextResult, err := pager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving application insights page: %w", err)
}
for _, component := range nextResult.Value {
if component == nil || component.Tags == nil {
continue
}
tag, ok := component.Tags[cloud.TagUID]
if !ok || tag == nil {
continue
}
if *tag == uid {
return component, nil
}
}
}
return nil, fmt.Errorf("could not find correctly tagged application insights")
}
// extractInstanceTags converts azure tags into metadata key-value pairs.
func extractInstanceTags(tags map[string]*string) map[string]string {
metadataMap := map[string]string{}
for key, value := range tags {
if value == nil {
continue
}
metadataMap[key] = *value
}
return metadataMap
}
type cloudConfig struct {
Cloud string `json:"cloud,omitempty"`
TenantID string `json:"tenantId,omitempty"`
SubscriptionID string `json:"subscriptionId,omitempty"`
ResourceGroup string `json:"resourceGroup,omitempty"`
Location string `json:"location,omitempty"`
SubnetName string `json:"subnetName,omitempty"`
SecurityGroupName string `json:"securityGroupName,omitempty"`
SecurityGroupResourceGroup string `json:"securityGroupResourceGroup,omitempty"`
LoadBalancerName string `json:"loadBalancerName,omitempty"`
LoadBalancerSku string `json:"loadBalancerSku,omitempty"`
VNetName string `json:"vnetName,omitempty"`
VNetResourceGroup string `json:"vnetResourceGroup,omitempty"`
CloudProviderBackoff bool `json:"cloudProviderBackoff,omitempty"`
UseInstanceMetadata bool `json:"useInstanceMetadata,omitempty"`
VMType string `json:"vmType,omitempty"`
AADClientID string `json:"aadClientId,omitempty"`
AADClientSecret string `json:"aadClientSecret,omitempty"`
}

View File

@ -1,657 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"context"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/edgelesssys/constellation/v2/internal/cloud"
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
"github.com/edgelesssys/constellation/v2/internal/role"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestList(t *testing.T) {
wantInstances := []metadata.InstanceMetadata{
{
Name: "scale-set-name-instance-id",
ProviderID: "azure:///subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id",
Role: role.Worker,
VPCIP: "192.0.2.0",
},
}
testCases := map[string]struct {
imdsAPI imdsAPI
networkInterfacesAPI networkInterfacesAPI
scaleSetsAPI scaleSetsAPI
virtualMachineScaleSetVMsAPI virtualMachineScaleSetVMsAPI
tagsAPI tagsAPI
wantErr bool
wantInstances []metadata.InstanceMetadata
}{
"List works": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
networkInterfacesAPI: newNetworkInterfacesStub(),
scaleSetsAPI: newScaleSetsStub(),
virtualMachineScaleSetVMsAPI: newVirtualMachineScaleSetsVMsStub(),
tagsAPI: newTagsStub(),
wantInstances: wantInstances,
},
"imds resource group fails": {
imdsAPI: &stubIMDSAPI{resourceGroupErr: errors.New("failed")},
wantErr: true,
},
"listScaleSetVMs fails": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
networkInterfacesAPI: newNetworkInterfacesStub(),
scaleSetsAPI: newScaleSetsStub(),
virtualMachineScaleSetVMsAPI: newFailingListsVirtualMachineScaleSetsVMsStub(),
tagsAPI: newTagsStub(),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
azureMetadata := Metadata{
imdsAPI: tc.imdsAPI,
networkInterfacesAPI: tc.networkInterfacesAPI,
scaleSetsAPI: tc.scaleSetsAPI,
virtualMachineScaleSetVMsAPI: tc.virtualMachineScaleSetVMsAPI,
tagsAPI: tc.tagsAPI,
}
instances, err := azureMetadata.List(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.ElementsMatch(tc.wantInstances, instances)
})
}
}
func TestSelf(t *testing.T) {
wantScaleSetInstance := metadata.InstanceMetadata{
Name: "scale-set-name-instance-id",
ProviderID: "azure:///subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id",
Role: role.Worker,
VPCIP: "192.0.2.0",
}
testCases := map[string]struct {
imdsAPI imdsAPI
networkInterfacesAPI networkInterfacesAPI
virtualMachineScaleSetVMsAPI virtualMachineScaleSetVMsAPI
wantErr bool
wantInstance metadata.InstanceMetadata
}{
"self for scale set instance works": {
imdsAPI: &stubIMDSAPI{providerID: "/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id"},
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: newVirtualMachineScaleSetsVMsStub(),
wantInstance: wantScaleSetInstance,
},
"providerID cannot be retrieved": {
imdsAPI: &stubIMDSAPI{providerIDErr: errors.New("failed")},
wantErr: true,
},
"GetInstance fails": {
imdsAPI: &stubIMDSAPI{providerID: wantScaleSetInstance.ProviderID},
virtualMachineScaleSetVMsAPI: &stubVirtualMachineScaleSetVMsAPI{getErr: errors.New("failed")},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
metadata := Metadata{
imdsAPI: tc.imdsAPI,
networkInterfacesAPI: tc.networkInterfacesAPI,
virtualMachineScaleSetVMsAPI: tc.virtualMachineScaleSetVMsAPI,
}
instance, err := metadata.Self(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantInstance, instance)
})
}
}
func TestGetNetworkSecurityGroupName(t *testing.T) {
name := "network-security-group-name"
testCases := map[string]struct {
securityGroupsAPI securityGroupsAPI
imdsAPI imdsAPI
wantName string
wantErr bool
}{
"GetNetworkSecurityGroupName works": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
securityGroupsAPI: &stubSecurityGroupsAPI{
pager: &stubSecurityGroupsClientListPager{
list: []armnetwork.SecurityGroup{{Name: to.Ptr(name)}},
},
},
wantName: name,
},
"no security group": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
securityGroupsAPI: &stubSecurityGroupsAPI{
pager: &stubSecurityGroupsClientListPager{},
},
wantErr: true,
},
"missing name in security group struct": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
securityGroupsAPI: &stubSecurityGroupsAPI{
pager: &stubSecurityGroupsClientListPager{
list: []armnetwork.SecurityGroup{{}},
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
metadata := Metadata{
imdsAPI: tc.imdsAPI,
securityGroupsAPI: tc.securityGroupsAPI,
}
name, err := metadata.GetNetworkSecurityGroupName(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantName, name)
})
}
}
func TestGetSubnetworkCIDR(t *testing.T) {
subnetworkCIDR := "192.0.2.0/24"
name := "name"
testCases := map[string]struct {
virtualNetworksAPI virtualNetworksAPI
imdsAPI imdsAPI
wantNetworkCIDR string
wantErr bool
}{
"GetSubnetworkCIDR works": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
virtualNetworksAPI: &stubVirtualNetworksAPI{
pager: &stubVirtualNetworksClientListPager{
list: []armnetwork.VirtualNetwork{{
Name: to.Ptr(name),
Properties: &armnetwork.VirtualNetworkPropertiesFormat{
Subnets: []*armnetwork.Subnet{
{Properties: &armnetwork.SubnetPropertiesFormat{AddressPrefix: to.Ptr(subnetworkCIDR)}},
},
},
}},
},
},
wantNetworkCIDR: subnetworkCIDR,
},
"no virtual networks found": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
virtualNetworksAPI: &stubVirtualNetworksAPI{
pager: &stubVirtualNetworksClientListPager{},
},
wantErr: true,
wantNetworkCIDR: subnetworkCIDR,
},
"malformed network struct": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
virtualNetworksAPI: &stubVirtualNetworksAPI{
pager: &stubVirtualNetworksClientListPager{list: []armnetwork.VirtualNetwork{{}}},
},
wantErr: true,
wantNetworkCIDR: subnetworkCIDR,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
metadata := Metadata{
imdsAPI: tc.imdsAPI,
virtualNetworksAPI: tc.virtualNetworksAPI,
}
subnetworkCIDR, err := metadata.getSubnetworkCIDR(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantNetworkCIDR, subnetworkCIDR)
})
}
}
func TestGetLoadBalancerName(t *testing.T) {
loadBalancerName := "load-balancer-name"
testCases := map[string]struct {
loadBalancerAPI loadBalancerAPI
imdsAPI imdsAPI
wantName string
wantErr bool
}{
"GetLoadBalancerName works": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{
list: []armnetwork.LoadBalancer{{
Name: to.Ptr(loadBalancerName),
Properties: &armnetwork.LoadBalancerPropertiesFormat{},
}},
},
},
wantName: loadBalancerName,
},
"invalid load balancer struct": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{list: []armnetwork.LoadBalancer{{}}},
},
wantErr: true,
},
"invalid missing name": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{list: []armnetwork.LoadBalancer{{
Properties: &armnetwork.LoadBalancerPropertiesFormat{},
}}},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
metadata := Metadata{
imdsAPI: tc.imdsAPI,
loadBalancerAPI: tc.loadBalancerAPI,
}
loadbalancerName, err := metadata.GetLoadBalancerName(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantName, loadbalancerName)
})
}
}
func TestGetLoadBalancerEndpoint(t *testing.T) {
loadBalancerName := "load-balancer-name"
publicIP := "192.0.2.1"
correctPublicIPID := "/subscriptions/subscription/resourceGroups/resourceGroup/providers/Microsoft.Network/publicIPAddresses/pubIPName"
someErr := errors.New("some error")
testCases := map[string]struct {
loadBalancerAPI loadBalancerAPI
publicIPAddressesAPI publicIPAddressesAPI
imdsAPI imdsAPI
wantIP string
wantErr bool
}{
"GetLoadBalancerEndpoint works": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{
list: []armnetwork.LoadBalancer{{
Name: to.Ptr(loadBalancerName),
Properties: &armnetwork.LoadBalancerPropertiesFormat{
FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{
{
Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{
PublicIPAddress: &armnetwork.PublicIPAddress{ID: &correctPublicIPID},
},
},
},
},
}},
},
},
publicIPAddressesAPI: &stubPublicIPAddressesAPI{getResponse: armnetwork.PublicIPAddressesClientGetResponse{
PublicIPAddress: armnetwork.PublicIPAddress{
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: &publicIP,
},
},
}},
wantIP: publicIP,
},
"no load balancer": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{},
},
wantErr: true,
},
"load balancer missing public IP reference": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{
list: []armnetwork.LoadBalancer{{
Name: to.Ptr(loadBalancerName),
Properties: &armnetwork.LoadBalancerPropertiesFormat{
FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{},
},
}},
},
},
wantErr: true,
},
"public IP reference has wrong format": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{
list: []armnetwork.LoadBalancer{{
Name: to.Ptr(loadBalancerName),
Properties: &armnetwork.LoadBalancerPropertiesFormat{
FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{
{
Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{
PublicIPAddress: &armnetwork.PublicIPAddress{
ID: to.Ptr("wrong-format"),
},
},
},
},
},
}},
},
},
wantErr: true,
},
"no public IP address found": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{
list: []armnetwork.LoadBalancer{{
Name: to.Ptr(loadBalancerName),
Properties: &armnetwork.LoadBalancerPropertiesFormat{
FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{
{
Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{
PublicIPAddress: &armnetwork.PublicIPAddress{ID: &correctPublicIPID},
},
},
},
},
}},
},
},
publicIPAddressesAPI: &stubPublicIPAddressesAPI{getErr: someErr},
wantErr: true,
},
"found public IP has no address field": {
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{
list: []armnetwork.LoadBalancer{{
Name: to.Ptr(loadBalancerName),
Properties: &armnetwork.LoadBalancerPropertiesFormat{
FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{
{
Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{
PublicIPAddress: &armnetwork.PublicIPAddress{ID: &correctPublicIPID},
},
},
},
},
}},
},
},
publicIPAddressesAPI: &stubPublicIPAddressesAPI{getResponse: armnetwork.PublicIPAddressesClientGetResponse{
PublicIPAddress: armnetwork.PublicIPAddress{
Properties: &armnetwork.PublicIPAddressPropertiesFormat{},
},
}},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
metadata := Metadata{
imdsAPI: tc.imdsAPI,
loadBalancerAPI: tc.loadBalancerAPI,
publicIPAddressesAPI: tc.publicIPAddressesAPI,
}
loadbalancerName, err := metadata.GetLoadBalancerEndpoint(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantIP, loadbalancerName)
})
}
}
func TestProviderID(t *testing.T) {
testCases := map[string]struct {
imdsAPI imdsAPI
wantErr bool
wantProviderID string
}{
"providerID for scale set instance works": {
imdsAPI: &stubIMDSAPI{providerID: "provider-id"},
wantProviderID: "azure://provider-id",
},
"imds providerID fails": {
imdsAPI: &stubIMDSAPI{providerIDErr: errors.New("failed")},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
metadata := Metadata{
imdsAPI: tc.imdsAPI,
}
providerID, err := metadata.providerID(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantProviderID, providerID)
})
}
}
func TestUID(t *testing.T) {
testCases := map[string]struct {
imdsAPI imdsAPI
wantErr bool
wantUID string
}{
"success": {
imdsAPI: &stubIMDSAPI{uid: "uid"},
wantUID: "uid",
},
"imds uid error": {
imdsAPI: &stubIMDSAPI{uidErr: errors.New("failed")},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
metadata := Metadata{
imdsAPI: tc.imdsAPI,
}
uid, err := metadata.UID(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantUID, uid)
})
}
}
func TestExtractInstanceTags(t *testing.T) {
testCases := map[string]struct {
in map[string]*string
wantTags map[string]string
}{
"tags are extracted": {
in: map[string]*string{"key": to.Ptr("value")},
wantTags: map[string]string{"key": "value"},
},
"nil values are skipped": {
in: map[string]*string{"key": nil},
wantTags: map[string]string{},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
tags := extractInstanceTags(tc.in)
assert.Equal(tc.wantTags, tags)
})
}
}
func newNetworkInterfacesStub() *stubNetworkInterfacesAPI {
return &stubNetworkInterfacesAPI{
getInterface: armnetwork.Interface{
Name: to.Ptr("interface-name"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.Ptr("192.0.2.0"),
Primary: to.Ptr(true),
},
},
},
},
},
}
}
func newScaleSetsStub() *stubScaleSetsAPI {
return &stubScaleSetsAPI{
pager: &stubVirtualMachineScaleSetsClientListPager{
list: []armcomputev2.VirtualMachineScaleSet{{
Name: to.Ptr("scale-set-name"),
Tags: map[string]*string{
cloud.TagUID: to.Ptr("uid"),
cloud.TagRole: to.Ptr("worker"),
},
}},
},
}
}
func newVirtualMachineScaleSetsVMsStub() *stubVirtualMachineScaleSetVMsAPI {
return &stubVirtualMachineScaleSetVMsAPI{
getVM: armcomputev2.VirtualMachineScaleSetVM{
Name: to.Ptr("scale-set-name_instance-id"),
InstanceID: to.Ptr("instance-id"),
ID: to.Ptr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id"),
Properties: &armcomputev2.VirtualMachineScaleSetVMProperties{
NetworkProfile: &armcomputev2.NetworkProfile{
NetworkInterfaces: []*armcomputev2.NetworkInterfaceReference{
{
ID: to.Ptr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id/networkInterfaces/interface-name"),
},
},
},
OSProfile: &armcomputev2.OSProfile{
ComputerName: to.Ptr("scale-set-name-instance-id"),
},
},
Tags: map[string]*string{
cloud.TagUID: to.Ptr("uid"),
cloud.TagRole: to.Ptr("worker"),
},
},
pager: &stubVirtualMachineScaleSetVMPager{
list: []armcomputev2.VirtualMachineScaleSetVM{
{
Name: to.Ptr("scale-set-name_instance-id"),
InstanceID: to.Ptr("instance-id"),
ID: to.Ptr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id"),
Properties: &armcomputev2.VirtualMachineScaleSetVMProperties{
NetworkProfile: &armcomputev2.NetworkProfile{
NetworkInterfaces: []*armcomputev2.NetworkInterfaceReference{
{
ID: to.Ptr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id/networkInterfaces/interface-name"),
},
},
},
OSProfile: &armcomputev2.OSProfile{
ComputerName: to.Ptr("scale-set-name-instance-id"),
},
},
Tags: map[string]*string{
cloud.TagUID: to.Ptr("uid"),
cloud.TagRole: to.Ptr("worker"),
},
},
},
},
}
}
func newFailingListsVirtualMachineScaleSetsVMsStub() *stubVirtualMachineScaleSetVMsAPI {
return &stubVirtualMachineScaleSetVMsAPI{
pager: &stubVirtualMachineScaleSetVMPager{
list: []armcomputev2.VirtualMachineScaleSetVM{{
InstanceID: to.Ptr("invalid-instance-id"),
}},
},
}
}
func newTagsStub() *stubTagsAPI {
return &stubTagsAPI{}
}

View File

@ -1,122 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"context"
"errors"
"fmt"
"strings"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
)
// getVMInterfaces retrieves all network interfaces referenced by a virtual machine.
func (m *Metadata) getVMInterfaces(ctx context.Context, vm armcomputev2.VirtualMachine, resourceGroup string) ([]armnetwork.Interface, error) {
if vm.Properties == nil || vm.Properties.NetworkProfile == nil {
return []armnetwork.Interface{}, nil
}
interfaceNames := extractInterfaceNamesFromInterfaceReferences(vm.Properties.NetworkProfile.NetworkInterfaces)
networkInterfaces := []armnetwork.Interface{}
for _, interfaceName := range interfaceNames {
networkInterfacesResp, err := m.networkInterfacesAPI.Get(ctx, resourceGroup, interfaceName, nil)
if err != nil {
return nil, fmt.Errorf("retrieving network interface %v: %w", interfaceName, err)
}
networkInterfaces = append(networkInterfaces, networkInterfacesResp.Interface)
}
return networkInterfaces, nil
}
// getScaleSetVMInterfaces retrieves all network interfaces referenced by a scale set virtual machine.
func (m *Metadata) getScaleSetVMInterfaces(ctx context.Context, vm armcomputev2.VirtualMachineScaleSetVM, resourceGroup, scaleSet, instanceID string) ([]armnetwork.Interface, error) {
if vm.Properties == nil || vm.Properties.NetworkProfile == nil {
return []armnetwork.Interface{}, nil
}
interfaceNames := extractInterfaceNamesFromInterfaceReferences(vm.Properties.NetworkProfile.NetworkInterfaces)
networkInterfaces := []armnetwork.Interface{}
for _, interfaceName := range interfaceNames {
networkInterfacesResp, err := m.networkInterfacesAPI.GetVirtualMachineScaleSetNetworkInterface(ctx, resourceGroup, scaleSet, instanceID, interfaceName, nil)
if err != nil {
return nil, fmt.Errorf("retrieving network interface %v: %w", interfaceName, err)
}
networkInterfaces = append(networkInterfaces, networkInterfacesResp.Interface)
}
return networkInterfaces, nil
}
// getScaleSetVMPublicIPAddress retrieves the primary public IP address from a network interface which is referenced by a scale set virtual machine.
func (m *Metadata) getScaleSetVMPublicIPAddress(ctx context.Context, resourceGroup, scaleSet, instanceID string,
networkInterfaces []armnetwork.Interface,
) (string, error) {
for _, networkInterface := range networkInterfaces {
if networkInterface.Properties == nil || networkInterface.Name == nil {
continue
}
for _, config := range networkInterface.Properties.IPConfigurations {
if config == nil || config.Name == nil || config.Properties == nil || config.Properties.PublicIPAddress == nil ||
config.Properties.Primary == nil || !*config.Properties.Primary {
continue
}
publicIPAddressName := *config.Properties.PublicIPAddress.ID
publicIPAddressNameParts := strings.Split(publicIPAddressName, "/")
publicIPAddressName = publicIPAddressNameParts[len(publicIPAddressNameParts)-1]
publicIPAddress, err := m.publicIPAddressesAPI.GetVirtualMachineScaleSetPublicIPAddress(ctx, resourceGroup, scaleSet, instanceID, *networkInterface.Name, *config.Name, publicIPAddressName, nil)
if err != nil {
return "", fmt.Errorf("failed to retrieve public ip address %v: %w", publicIPAddressName, err)
}
if publicIPAddress.Properties == nil || publicIPAddress.Properties.IPAddress == nil {
return "", errors.New("retrieved public ip address has invalid ip address")
}
return *publicIPAddress.Properties.IPAddress, nil
}
}
// instances may have no public IP, in that case we don't return an error.
return "", nil
}
// extractVPCIP extracts the primary VPC IP from a list of network interface IP configurations.
func extractVPCIP(networkInterfaces []armnetwork.Interface) string {
for _, networkInterface := range networkInterfaces {
if networkInterface.Properties == nil || len(networkInterface.Properties.IPConfigurations) == 0 {
continue
}
for _, config := range networkInterface.Properties.IPConfigurations {
if config == nil || config.Properties == nil || config.Properties.PrivateIPAddress == nil || config.Properties.Primary == nil {
continue
}
if *config.Properties.Primary {
return *config.Properties.PrivateIPAddress
}
}
}
return ""
}
// extractInterfaceNamesFromInterfaceReferences extracts the name of a network interface from a reference id.
// Format:
// - "/subscriptions/<subscription>/resourceGroups/<resource-group>/providers/Microsoft.Network/networkInterfaces/<interface-name>"
// - "/subscriptions/<subscription>/resourceGroups/<resource-group>/providers/Microsoft.Compute/virtualMachineScaleSets/<scale-set-name>/virtualMachines/<instanceID>/networkInterfaces/<interface-name>".
func extractInterfaceNamesFromInterfaceReferences(references []*armcomputev2.NetworkInterfaceReference) []string {
interfaceNames := []string{}
for _, interfaceReference := range references {
if interfaceReference == nil || interfaceReference.ID == nil {
continue
}
interfaceIDParts := strings.Split(*interfaceReference.ID, "/")
if len(interfaceIDParts) < 1 {
continue
}
interfaceName := interfaceIDParts[len(interfaceIDParts)-1]
interfaceNames = append(interfaceNames, interfaceName)
}
return interfaceNames
}

View File

@ -1,405 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"context"
"errors"
"testing"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/go-autorest/autorest/to"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetVMInterfaces(t *testing.T) {
wantNetworkInterfaces := []armnetwork.Interface{
{
Name: to.StringPtr("interface-name"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.0"),
},
},
},
},
},
}
vm := armcomputev2.VirtualMachine{
Properties: &armcomputev2.VirtualMachineProperties{
NetworkProfile: &armcomputev2.NetworkProfile{
NetworkInterfaces: []*armcomputev2.NetworkInterfaceReference{
{
ID: to.StringPtr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Network/networkInterfaces/interface-name"),
},
},
},
},
}
testCases := map[string]struct {
vm armcomputev2.VirtualMachine
networkInterfacesAPI networkInterfacesAPI
wantErr bool
wantNetworkInterfaces []armnetwork.Interface
}{
"retrieval works": {
vm: vm,
networkInterfacesAPI: &stubNetworkInterfacesAPI{
getInterface: armnetwork.Interface{
Name: to.StringPtr("interface-name"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.0"),
},
},
},
},
},
},
wantNetworkInterfaces: wantNetworkInterfaces,
},
"vm can have 0 interfaces": {
vm: armcomputev2.VirtualMachine{},
networkInterfacesAPI: &stubNetworkInterfacesAPI{
getInterface: armnetwork.Interface{
Name: to.StringPtr("interface-name"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.0"),
},
},
},
},
},
},
wantNetworkInterfaces: []armnetwork.Interface{},
},
"interface retrieval fails": {
vm: vm,
networkInterfacesAPI: &stubNetworkInterfacesAPI{
getErr: errors.New("get err"),
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
metadata := Metadata{
networkInterfacesAPI: tc.networkInterfacesAPI,
}
vmNetworkInteraces, err := metadata.getVMInterfaces(context.Background(), tc.vm, "resource-group")
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantNetworkInterfaces, vmNetworkInteraces)
})
}
}
func TestGetScaleSetVMInterfaces(t *testing.T) {
wantNetworkInterfaces := []armnetwork.Interface{
{
Name: to.StringPtr("interface-name"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.0"),
},
},
},
},
},
}
vm := armcomputev2.VirtualMachineScaleSetVM{
Properties: &armcomputev2.VirtualMachineScaleSetVMProperties{
NetworkProfile: &armcomputev2.NetworkProfile{
NetworkInterfaces: []*armcomputev2.NetworkInterfaceReference{
{
ID: to.StringPtr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id/networkInterfaces/interface-name"),
},
},
},
},
}
testCases := map[string]struct {
vm armcomputev2.VirtualMachineScaleSetVM
networkInterfacesAPI networkInterfacesAPI
wantErr bool
wantNetworkInterfaces []armnetwork.Interface
}{
"retrieval works": {
vm: vm,
networkInterfacesAPI: &stubNetworkInterfacesAPI{
getInterface: armnetwork.Interface{
Name: to.StringPtr("interface-name"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.0"),
},
},
},
},
},
},
wantNetworkInterfaces: wantNetworkInterfaces,
},
"vm can have 0 interfaces": {
vm: armcomputev2.VirtualMachineScaleSetVM{},
networkInterfacesAPI: &stubNetworkInterfacesAPI{
getInterface: armnetwork.Interface{
Name: to.StringPtr("interface-name"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.0"),
},
},
},
},
},
},
wantNetworkInterfaces: []armnetwork.Interface{},
},
"interface retrieval fails": {
vm: vm,
networkInterfacesAPI: &stubNetworkInterfacesAPI{
getErr: errors.New("get err"),
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
metadata := Metadata{
networkInterfacesAPI: tc.networkInterfacesAPI,
}
configs, err := metadata.getScaleSetVMInterfaces(context.Background(), tc.vm, "resource-group", "scale-set-name", "instance-id")
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantNetworkInterfaces, configs)
})
}
}
func TestGetScaleSetVMPublicIPAddresses(t *testing.T) {
someErr := errors.New("some err")
newNetworkInterfaces := func() []armnetwork.Interface {
return []armnetwork.Interface{{
Name: to.StringPtr("interface-name"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Name: to.StringPtr("ip-config-name"),
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
Primary: to.BoolPtr(true),
PublicIPAddress: &armnetwork.PublicIPAddress{
ID: to.StringPtr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Network/publicIPAddresses/public-ip-name"),
},
},
},
},
},
}, {
Name: to.StringPtr("interface-name2"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Name: to.StringPtr("ip-config-name2"),
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PublicIPAddress: &armnetwork.PublicIPAddress{
ID: to.StringPtr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Network/publicIPAddresses/public-ip-name2"),
},
},
},
},
},
}}
}
testCases := map[string]struct {
networkInterfacesMutator func(*[]armnetwork.Interface)
networkInterfaces []armnetwork.Interface
publicIPAddressesAPI publicIPAddressesAPI
wantIP string
wantErr bool
}{
"retrieval works": {
publicIPAddressesAPI: &stubPublicIPAddressesAPI{getVirtualMachineScaleSetPublicIPAddressResponse: armnetwork.PublicIPAddressesClientGetVirtualMachineScaleSetPublicIPAddressResponse{
PublicIPAddress: armnetwork.PublicIPAddress{
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: to.StringPtr("192.0.2.1"),
},
},
}},
networkInterfaces: newNetworkInterfaces(),
wantIP: "192.0.2.1",
},
"retrieval works for no valid interfaces": {
publicIPAddressesAPI: &stubPublicIPAddressesAPI{getVirtualMachineScaleSetPublicIPAddressResponse: armnetwork.PublicIPAddressesClientGetVirtualMachineScaleSetPublicIPAddressResponse{
PublicIPAddress: armnetwork.PublicIPAddress{
Properties: &armnetwork.PublicIPAddressPropertiesFormat{
IPAddress: to.StringPtr("192.0.2.1"),
},
},
}},
networkInterfaces: newNetworkInterfaces(),
networkInterfacesMutator: func(nets *[]armnetwork.Interface) {
(*nets)[0].Properties.IPConfigurations = []*armnetwork.InterfaceIPConfiguration{nil}
(*nets)[1] = armnetwork.Interface{Name: nil}
},
},
"fail to get public IP": {
publicIPAddressesAPI: &stubPublicIPAddressesAPI{getErr: someErr},
networkInterfaces: newNetworkInterfaces(),
wantErr: true,
},
"fail to parse IPv4 address of public IP": {
publicIPAddressesAPI: &stubPublicIPAddressesAPI{getVirtualMachineScaleSetPublicIPAddressResponse: armnetwork.PublicIPAddressesClientGetVirtualMachineScaleSetPublicIPAddressResponse{
PublicIPAddress: armnetwork.PublicIPAddress{},
}},
networkInterfaces: newNetworkInterfaces(),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
if tc.networkInterfacesMutator != nil {
tc.networkInterfacesMutator(&tc.networkInterfaces)
}
metadata := Metadata{
publicIPAddressesAPI: tc.publicIPAddressesAPI,
}
ips, err := metadata.getScaleSetVMPublicIPAddress(context.Background(), "resource-group", "scale-set-name", "instance-id", tc.networkInterfaces)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantIP, ips)
})
}
}
func TestExtractPrivateIPs(t *testing.T) {
testCases := map[string]struct {
networkInterfaces []armnetwork.Interface
wantIP string
}{
"extraction works": {
networkInterfaces: []armnetwork.Interface{
{
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
Primary: to.BoolPtr(true),
PrivateIPAddress: to.StringPtr("192.0.2.0"),
},
},
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
PrivateIPAddress: to.StringPtr("192.0.2.1"),
},
},
},
},
},
},
wantIP: "192.0.2.0",
},
"can be empty": {
networkInterfaces: []armnetwork.Interface{},
},
"invalid interface is skipped": {
networkInterfaces: []armnetwork.Interface{{}},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ip := extractVPCIP(tc.networkInterfaces)
assert.Equal(tc.wantIP, ip)
})
}
}
func TestExtractInterfaceNamesFromInterfaceReferences(t *testing.T) {
testCases := map[string]struct {
references []*armcomputev2.NetworkInterfaceReference
wantNames []string
}{
"extraction with individual interface reference works": {
references: []*armcomputev2.NetworkInterfaceReference{
{
ID: to.StringPtr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Network/networkInterfaces/interface-name"),
},
},
wantNames: []string{"interface-name"},
},
"extraction with scale set interface reference works": {
references: []*armcomputev2.NetworkInterfaceReference{
{
ID: to.StringPtr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id/networkInterfaces/interface-name"),
},
},
wantNames: []string{"interface-name"},
},
"can be empty": {
references: []*armcomputev2.NetworkInterfaceReference{},
},
"interface reference containing nil fields is skipped": {
references: []*armcomputev2.NetworkInterfaceReference{
{},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
names := extractInterfaceNamesFromInterfaceReferences(tc.references)
assert.ElementsMatch(tc.wantNames, names)
})
}
}

View File

@ -1,130 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"context"
"errors"
"fmt"
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/edgelesssys/constellation/v2/internal/cloud"
"github.com/edgelesssys/constellation/v2/internal/cloud/azureshared"
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
"github.com/edgelesssys/constellation/v2/internal/role"
)
// getScaleSetVM tries to get an azure vm belonging to a scale set.
func (m *Metadata) getScaleSetVM(ctx context.Context, providerID string) (metadata.InstanceMetadata, error) {
_, resourceGroup, scaleSet, instanceID, err := azureshared.ScaleSetInformationFromProviderID(providerID)
if err != nil {
return metadata.InstanceMetadata{}, err
}
vmResp, err := m.virtualMachineScaleSetVMsAPI.Get(ctx, resourceGroup, scaleSet, instanceID, nil)
if err != nil {
return metadata.InstanceMetadata{}, err
}
networkInterfaces, err := m.getScaleSetVMInterfaces(ctx, vmResp.VirtualMachineScaleSetVM, resourceGroup, scaleSet, instanceID)
if err != nil {
return metadata.InstanceMetadata{}, err
}
return convertScaleSetVMToCoreInstance(vmResp.VirtualMachineScaleSetVM, networkInterfaces)
}
// listScaleSetVMs lists all scale set VMs in the current resource group.
func (m *Metadata) listScaleSetVMs(ctx context.Context, resourceGroup string) ([]metadata.InstanceMetadata, error) {
instances := []metadata.InstanceMetadata{}
scaleSetPager := m.scaleSetsAPI.NewListPager(resourceGroup, nil)
for scaleSetPager.More() {
page, err := scaleSetPager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving scale sets: %w", err)
}
for _, scaleSet := range page.Value {
if scaleSet == nil || scaleSet.Name == nil {
continue
}
vmPager := m.virtualMachineScaleSetVMsAPI.NewListPager(resourceGroup, *scaleSet.Name, nil)
for vmPager.More() {
vmPage, err := vmPager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving vms: %w", err)
}
for _, vm := range vmPage.Value {
if vm == nil || vm.InstanceID == nil {
continue
}
interfaces, err := m.getScaleSetVMInterfaces(ctx, *vm, resourceGroup, *scaleSet.Name, *vm.InstanceID)
if err != nil {
return nil, err
}
instance, err := convertScaleSetVMToCoreInstance(*vm, interfaces)
if err != nil {
return nil, err
}
instances = append(instances, instance)
}
}
}
}
return instances, nil
}
// convertScaleSetVMToCoreInstance converts an azure scale set virtual machine with interface configurations into a core.Instance.
func convertScaleSetVMToCoreInstance(vm armcomputev2.VirtualMachineScaleSetVM, networkInterfaces []armnetwork.Interface,
) (metadata.InstanceMetadata, error) {
if vm.ID == nil {
return metadata.InstanceMetadata{}, errors.New("retrieving instance from armcompute API client returned no instance ID")
}
if vm.Properties == nil || vm.Properties.OSProfile == nil || vm.Properties.OSProfile.ComputerName == nil {
return metadata.InstanceMetadata{}, errors.New("retrieving instance from armcompute API client returned no computer name")
}
if vm.Tags == nil {
return metadata.InstanceMetadata{}, errors.New("retrieving instance from armcompute API client returned no tags")
}
return metadata.InstanceMetadata{
Name: *vm.Properties.OSProfile.ComputerName,
ProviderID: "azure://" + *vm.ID,
Role: extractScaleSetVMRole(vm.Tags),
VPCIP: extractVPCIP(networkInterfaces),
}, nil
}
// extractScaleSetVMRole extracts the constellation role of a scale set using its name.
func extractScaleSetVMRole(tags map[string]*string) role.Role {
if tags == nil {
return role.Unknown
}
roleStr, ok := tags[cloud.TagRole]
if !ok {
return role.Unknown
}
if roleStr == nil {
return role.Unknown
}
return role.FromString(*roleStr)
}
// ImageReferenceFromImage sets the `ID` or `CommunityGalleryImageID` field
// of `ImageReference` depending on the provided `img`.
func ImageReferenceFromImage(img string) *armcomputev2.ImageReference {
ref := &armcomputev2.ImageReference{}
if strings.HasPrefix(img, "/CommunityGalleries") {
ref.CommunityGalleryImageID = to.Ptr(img)
} else {
ref.ID = to.Ptr(img)
}
return ref
}

View File

@ -1,356 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"context"
"errors"
"testing"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/edgelesssys/constellation/v2/internal/cloud"
"github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
"github.com/edgelesssys/constellation/v2/internal/role"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetScaleSetVM(t *testing.T) {
wantInstance := metadata.InstanceMetadata{
Name: "scale-set-name-instance-id",
ProviderID: "azure:///subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id",
Role: role.Worker,
VPCIP: "192.0.2.0",
}
testCases := map[string]struct {
providerID string
networkInterfacesAPI networkInterfacesAPI
virtualMachineScaleSetVMsAPI virtualMachineScaleSetVMsAPI
wantErr bool
wantInstance metadata.InstanceMetadata
}{
"getVM for scale set instance works": {
providerID: "azure:///subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id",
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: newVirtualMachineScaleSetsVMsStub(),
wantInstance: wantInstance,
},
"getVM for individual instance must fail": {
providerID: "azure:///subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachines/instance-name",
wantErr: true,
},
"Get fails": {
providerID: "azure:///subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id",
virtualMachineScaleSetVMsAPI: newFailingGetScaleSetVirtualMachinesStub(),
wantErr: true,
},
"conversion fails": {
providerID: "azure:///subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id",
virtualMachineScaleSetVMsAPI: newGetInvalidScaleSetVirtualMachinesStub(),
networkInterfacesAPI: newNetworkInterfacesStub(),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
metadata := Metadata{
networkInterfacesAPI: tc.networkInterfacesAPI,
virtualMachineScaleSetVMsAPI: tc.virtualMachineScaleSetVMsAPI,
}
instance, err := metadata.getScaleSetVM(context.Background(), tc.providerID)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantInstance, instance)
})
}
}
func TestListScaleSetVMs(t *testing.T) {
wantInstances := []metadata.InstanceMetadata{
{
Name: "scale-set-name-instance-id",
ProviderID: "azure:///subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id",
Role: role.Worker,
VPCIP: "192.0.2.0",
},
}
testCases := map[string]struct {
networkInterfacesAPI networkInterfacesAPI
virtualMachineScaleSetVMsAPI virtualMachineScaleSetVMsAPI
scaleSetsAPI scaleSetsAPI
wantErr bool
wantInstances []metadata.InstanceMetadata
}{
"listVMs works": {
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: newVirtualMachineScaleSetsVMsStub(),
scaleSetsAPI: newScaleSetsStub(),
wantInstances: wantInstances,
},
"invalid scale sets are skipped": {
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: newVirtualMachineScaleSetsVMsStub(),
scaleSetsAPI: newListContainingNilScaleSetStub(),
wantInstances: wantInstances,
},
"listVMs can return 0 VMs": {
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: &stubVirtualMachineScaleSetVMsAPI{pager: &stubVirtualMachineScaleSetVMPager{}},
scaleSetsAPI: newScaleSetsStub(),
wantInstances: []metadata.InstanceMetadata{},
},
"can skip nil in VM list": {
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: newListContainingNilScaleSetVirtualMachinesStub(),
scaleSetsAPI: newScaleSetsStub(),
wantInstances: wantInstances,
},
"converting instance fails": {
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: newListContainingInvalidScaleSetVirtualMachinesStub(),
scaleSetsAPI: newScaleSetsStub(),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
metadata := Metadata{
networkInterfacesAPI: tc.networkInterfacesAPI,
virtualMachineScaleSetVMsAPI: tc.virtualMachineScaleSetVMsAPI,
scaleSetsAPI: tc.scaleSetsAPI,
}
instances, err := metadata.listScaleSetVMs(context.Background(), "resource-group")
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.ElementsMatch(tc.wantInstances, instances)
})
}
}
func TestConvertScaleSetVMToCoreInstance(t *testing.T) {
testCases := map[string]struct {
inVM armcomputev2.VirtualMachineScaleSetVM
inInterface []armnetwork.Interface
wantErr bool
wantInstance metadata.InstanceMetadata
}{
"conversion works": {
inVM: armcomputev2.VirtualMachineScaleSetVM{
Name: to.Ptr("scale-set-name_instance-id"),
ID: to.Ptr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id"),
Tags: map[string]*string{"tag-key": to.Ptr("tag-value")},
Properties: &armcomputev2.VirtualMachineScaleSetVMProperties{
OSProfile: &armcomputev2.OSProfile{
ComputerName: to.Ptr("scale-set-name-instance-id"),
},
},
},
inInterface: []armnetwork.Interface{
{
Name: to.Ptr("scale-set-name_instance-id"),
ID: to.Ptr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Network/networkInterfaces/interface-name"),
Properties: &armnetwork.InterfacePropertiesFormat{
IPConfigurations: []*armnetwork.InterfaceIPConfiguration{
{
Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{
Primary: to.Ptr(true),
PrivateIPAddress: to.Ptr("192.0.2.0"),
},
},
},
},
},
},
wantInstance: metadata.InstanceMetadata{
Name: "scale-set-name-instance-id",
ProviderID: "azure:///subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id",
VPCIP: "192.0.2.0",
},
},
"invalid instance": {
inVM: armcomputev2.VirtualMachineScaleSetVM{},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
instance, err := convertScaleSetVMToCoreInstance(tc.inVM, tc.inInterface)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantInstance, instance)
})
}
}
func TestExtractScaleSetVMRole(t *testing.T) {
testCases := map[string]struct {
tags map[string]*string
wantRole role.Role
}{
"control-plane role": {
tags: map[string]*string{cloud.TagRole: to.Ptr("control-plane")},
wantRole: role.ControlPlane,
},
"worker role": {
tags: map[string]*string{cloud.TagRole: to.Ptr("worker")},
wantRole: role.Worker,
},
"unknown role": {
tags: map[string]*string{cloud.TagRole: to.Ptr("foo")},
wantRole: role.Unknown,
},
"no role": {
tags: map[string]*string{},
wantRole: role.Unknown,
},
"nil role": {
tags: map[string]*string{cloud.TagRole: nil},
wantRole: role.Unknown,
},
"nil tags": {
tags: nil,
wantRole: role.Unknown,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
role := extractScaleSetVMRole(tc.tags)
assert.Equal(tc.wantRole, role)
})
}
}
func newFailingGetScaleSetVirtualMachinesStub() *stubVirtualMachineScaleSetVMsAPI {
return &stubVirtualMachineScaleSetVMsAPI{
getErr: errors.New("get err"),
}
}
func newGetInvalidScaleSetVirtualMachinesStub() *stubVirtualMachineScaleSetVMsAPI {
return &stubVirtualMachineScaleSetVMsAPI{
getVM: armcomputev2.VirtualMachineScaleSetVM{},
}
}
func newListContainingNilScaleSetVirtualMachinesStub() *stubVirtualMachineScaleSetVMsAPI {
return &stubVirtualMachineScaleSetVMsAPI{
pager: &stubVirtualMachineScaleSetVMPager{
list: []armcomputev2.VirtualMachineScaleSetVM{
{
Name: to.Ptr("scale-set-name_instance-id"),
ID: to.Ptr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id"),
InstanceID: to.Ptr("instance-id"),
Tags: map[string]*string{
cloud.TagRole: to.Ptr("worker"),
},
Properties: &armcomputev2.VirtualMachineScaleSetVMProperties{
NetworkProfile: &armcomputev2.NetworkProfile{
NetworkInterfaces: []*armcomputev2.NetworkInterfaceReference{
{
ID: to.Ptr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id/networkInterfaces/interface-name"),
},
},
},
OSProfile: &armcomputev2.OSProfile{
ComputerName: to.Ptr("scale-set-name-instance-id"),
},
},
},
},
},
}
}
func newListContainingInvalidScaleSetVirtualMachinesStub() *stubVirtualMachineScaleSetVMsAPI {
return &stubVirtualMachineScaleSetVMsAPI{
pager: &stubVirtualMachineScaleSetVMPager{
list: []armcomputev2.VirtualMachineScaleSetVM{
{
InstanceID: to.Ptr("instance-id"),
Properties: &armcomputev2.VirtualMachineScaleSetVMProperties{
OSProfile: &armcomputev2.OSProfile{},
NetworkProfile: &armcomputev2.NetworkProfile{
NetworkInterfaces: []*armcomputev2.NetworkInterfaceReference{
{
ID: to.Ptr("/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id/networkInterfaces/interface-name"),
},
},
},
},
},
},
},
}
}
func newListContainingNilScaleSetStub() *stubScaleSetsAPI {
return &stubScaleSetsAPI{
pager: &stubVirtualMachineScaleSetsClientListPager{
list: []armcomputev2.VirtualMachineScaleSet{{Name: to.Ptr("scale-set-name")}},
},
}
}
func TestImageReferenceFromImage(t *testing.T) {
testCases := map[string]struct {
img string
wantID *string
wantCommunityID *string
}{
"ID": {
img: "/subscriptions/0d202bbb-4fa7-4af8-8125-58c269a05435/resourceGroups/constellation-images/providers/Microsoft.Compute/galleries/Constellation/images/constellation/versions/1.5.0",
wantID: to.Ptr("/subscriptions/0d202bbb-4fa7-4af8-8125-58c269a05435/resourceGroups/constellation-images/providers/Microsoft.Compute/galleries/Constellation/images/constellation/versions/1.5.0"),
wantCommunityID: nil,
},
"Community": {
img: "/CommunityGalleries/ConstellationCVM-728bd310-e898-4450-a1ed-21cf2fb0d735/Images/feat-azure-cvm-sharing/Versions/2022.0826.084922",
wantID: nil,
wantCommunityID: to.Ptr("/CommunityGalleries/ConstellationCVM-728bd310-e898-4450-a1ed-21cf2fb0d735/Images/feat-azure-cvm-sharing/Versions/2022.0826.084922"),
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ref := ImageReferenceFromImage(tc.img)
assert.Equal(tc.wantID, ref.ID)
assert.Equal(tc.wantCommunityID, ref.CommunityGalleryImageID)
})
}
}

View File

@ -1,29 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"context"
"fmt"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
)
// getNetworkSecurityGroup retrieves the list of security groups for the given resource group.
func (m *Metadata) getNetworkSecurityGroup(ctx context.Context, resourceGroup string) (*armnetwork.SecurityGroup, error) {
pager := m.securityGroupsAPI.NewListPager(resourceGroup, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving security groups: %w", err)
}
for _, securityGroup := range page.Value {
return securityGroup, nil
}
}
return nil, fmt.Errorf("no security group found for resource group %q", resourceGroup)
}

View File

@ -1,31 +0,0 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package azure
import (
"context"
"fmt"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
)
// getVirtualNetwork return the first virtual network found in the resource group.
func (m *Metadata) getVirtualNetwork(ctx context.Context, resourceGroup string) (*armnetwork.VirtualNetwork, error) {
pager := m.virtualNetworksAPI.NewListPager(resourceGroup, nil)
for pager.More() {
page, err := pager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving virtual networks: %w", err)
}
for _, network := range page.Value {
if network != nil {
return network, nil
}
}
}
return nil, fmt.Errorf("no virtual network found in resource group %s", resourceGroup)
}

View File

@ -27,7 +27,6 @@ type InstanceMetadata struct {
// SecondaryIPRange is the VPC wide CIDR from which subnets are attached to VMs as AliasIPRanges. // SecondaryIPRange is the VPC wide CIDR from which subnets are attached to VMs as AliasIPRanges.
// May be empty on certain CSPs. // May be empty on certain CSPs.
SecondaryIPRange string SecondaryIPRange string
// AliasIPRanges is a list of IP ranges that are attached. // AliasIPRanges is a list of IP ranges that are attached.
// May be empty on certain CSPs. // May be empty on certain CSPs.
AliasIPRanges []string AliasIPRanges []string

View File

@ -123,7 +123,7 @@ func getVPCIP(ctx context.Context, provider string) (string, error) {
return "", err return "", err
} }
case cloudprovider.Azure: case cloudprovider.Azure:
metadata, err = azurecloud.NewMetadata(ctx) metadata, err = azurecloud.New(ctx)
if err != nil { if err != nil {
return "", err return "", err
} }