Refactor Azure IMDS client and metadata

This commit is contained in:
katexochen 2022-08-29 11:54:30 +02:00 committed by Paul Meyer
parent f15605cb45
commit 69abe17c96
10 changed files with 312 additions and 250 deletions

View File

@ -4,7 +4,6 @@ import (
"errors"
"fmt"
"regexp"
"strings"
)
var azureVMSSProviderIDRegexp = regexp.MustCompile(`^azure:///subscriptions/([^/]+)/resourceGroups/([^/]+)/providers/Microsoft.Compute/virtualMachineScaleSets/([^/]+)/virtualMachines/([^/]+)$`)
@ -18,19 +17,6 @@ func BasicsFromProviderID(providerID string) (subscriptionID, resourceGroup stri
return "", "", fmt.Errorf("providerID %v is malformatted", providerID)
}
// UIDFromProviderID extracts our own generated unique ID, which is the
// suffix at the resource group, e.g., resource-group-J18dB
// J18dB is the UID.
func UIDFromProviderID(providerID string) (string, error) {
_, resourceGroup, _, _, err := ScaleSetInformationFromProviderID(providerID)
if err != nil {
return "", err
}
parts := strings.Split(resourceGroup, "-")
return parts[len(parts)-1], nil
}
// ScaleSetInformationFromProviderID splits a provider's id belonging to an azure scaleset into core components.
// A providerID for scale set VMs is build after the following schema:
// - 'azure:///subscriptions/<subscription-id>/resourceGroups/<resource-group>/providers/Microsoft.Compute/virtualMachineScaleSets/<scale-set-name>/virtualMachines/<instance-id>'

View File

@ -43,39 +43,6 @@ func TestBasicsFromProviderID(t *testing.T) {
}
}
func TestUIDFromProviderID(t *testing.T) {
testCases := map[string]struct {
providerID string
wantUID string
wantErr bool
}{
"UID from virtual machine works": {
providerID: "azure:///subscriptions/subscription-id/resourceGroups/resource-group-ABC123/providers/Microsoft.Compute/virtualMachineScaleSets/scaleset/virtualMachines/instance-name",
wantUID: "ABC123",
},
"providerID is malformed": {
providerID: "malformed-provider-id",
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
uid, err := UIDFromProviderID(tc.providerID)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantUID, uid)
})
}
}
func TestScaleSetInformationFromProviderID(t *testing.T) {
testCases := map[string]struct {
providerID string

View File

@ -11,15 +11,22 @@ import (
)
type imdsAPI interface {
Retrieve(ctx context.Context) (metadataResponse, error)
ProviderID(ctx context.Context) (string, error)
SubscriptionID(ctx context.Context) (string, error)
ResourceGroup(ctx context.Context) (string, error)
UID(ctx context.Context) (string, error)
}
type virtualNetworksAPI interface {
NewListPager(resourceGroupName string, options *armnetwork.VirtualNetworksClientListOptions) *runtime.Pager[armnetwork.VirtualNetworksClientListResponse]
NewListPager(resourceGroupName string,
options *armnetwork.VirtualNetworksClientListOptions,
) *runtime.Pager[armnetwork.VirtualNetworksClientListResponse]
}
type securityGroupsAPI interface {
NewListPager(resourceGroupName string, options *armnetwork.SecurityGroupsClientListOptions) *runtime.Pager[armnetwork.SecurityGroupsClientListResponse]
NewListPager(resourceGroupName string,
options *armnetwork.SecurityGroupsClientListOptions,
) *runtime.Pager[armnetwork.SecurityGroupsClientListResponse]
}
type networkInterfacesAPI interface {
@ -38,7 +45,8 @@ type publicIPAddressesAPI interface {
options *armnetwork.PublicIPAddressesClientGetVirtualMachineScaleSetPublicIPAddressOptions,
) (armnetwork.PublicIPAddressesClientGetVirtualMachineScaleSetPublicIPAddressResponse, error)
Get(ctx context.Context, resourceGroupName string, publicIPAddressName string,
options *armnetwork.PublicIPAddressesClientGetOptions) (armnetwork.PublicIPAddressesClientGetResponse, error)
options *armnetwork.PublicIPAddressesClientGetOptions,
) (armnetwork.PublicIPAddressesClientGetResponse, error)
}
type virtualMachineScaleSetVMsAPI interface {
@ -61,10 +69,16 @@ type loadBalancerAPI interface {
}
type tagsAPI interface {
CreateOrUpdateAtScope(ctx context.Context, scope string, parameters armresources.TagsResource, options *armresources.TagsClientCreateOrUpdateAtScopeOptions) (armresources.TagsClientCreateOrUpdateAtScopeResponse, error)
UpdateAtScope(ctx context.Context, scope string, parameters armresources.TagsPatchResource, options *armresources.TagsClientUpdateAtScopeOptions) (armresources.TagsClientUpdateAtScopeResponse, error)
CreateOrUpdateAtScope(ctx context.Context, scope string, parameters armresources.TagsResource,
options *armresources.TagsClientCreateOrUpdateAtScopeOptions,
) (armresources.TagsClientCreateOrUpdateAtScopeResponse, error)
UpdateAtScope(ctx context.Context, scope string, parameters armresources.TagsPatchResource,
options *armresources.TagsClientUpdateAtScopeOptions,
) (armresources.TagsClientUpdateAtScopeResponse, error)
}
type applicationInsightsAPI interface {
Get(ctx context.Context, resourceGroupName string, resourceName string, options *armapplicationinsights.ComponentsClientGetOptions) (armapplicationinsights.ComponentsClientGetResponse, error)
NewListByResourceGroupPager(resourceGroupName string,
options *armapplicationinsights.ComponentsClientListByResourceGroupOptions,
) *runtime.Pager[armapplicationinsights.ComponentsClientListByResourceGroupResponse]
}

View File

@ -19,12 +19,30 @@ func TestMain(m *testing.M) {
}
type stubIMDSAPI struct {
res metadataResponse
retrieveErr error
providerIDErr error
providerID string
subscriptionIDErr error
subscriptionID string
resourceGroupErr error
resourceGroup string
uidErr error
uid string
}
func (a *stubIMDSAPI) Retrieve(ctx context.Context) (metadataResponse, error) {
return a.res, a.retrieveErr
func (a *stubIMDSAPI) ProviderID(ctx context.Context) (string, error) {
return a.providerID, a.providerIDErr
}
func (a *stubIMDSAPI) SubscriptionID(ctx context.Context) (string, error) {
return a.subscriptionID, a.subscriptionIDErr
}
func (a *stubIMDSAPI) ResourceGroup(ctx context.Context) (string, error) {
return a.resourceGroup, a.resourceGroupErr
}
func (a *stubIMDSAPI) UID(ctx context.Context) (string, error) {
return a.uid, a.uidErr
}
type stubNetworkInterfacesAPI struct {

View File

@ -3,8 +3,10 @@ package azure
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"time"
)
// subset of azure imds API: https://docs.microsoft.com/en-us/azure/virtual-machines/windows/instance-metadata-service?tabs=linux
@ -13,17 +15,95 @@ import (
const (
imdsURL = "http://169.254.169.254/metadata/instance"
imdsAPIVersion = "2021-02-01"
maxCacheAge = 12 * time.Hour
)
type imdsClient struct {
client *http.Client
cache metadataResponse
cacheTime time.Time
}
// Retrieve retrieves instance metadata from the azure imds API.
func (c *imdsClient) Retrieve(ctx context.Context) (metadataResponse, error) {
req, err := http.NewRequestWithContext(ctx, "GET", imdsURL, http.NoBody)
// ProviderID returns the provider ID of the instance the function is called from.
func (c *imdsClient) ProviderID(ctx context.Context) (string, error) {
if c.timeForUpdate() || c.cache.Compute.ResourceID == "" {
if err := c.update(ctx); err != nil {
return "", err
}
}
if c.cache.Compute.ResourceID == "" {
return "", errors.New("unable to get provider id")
}
return c.cache.Compute.ResourceID, nil
}
// SubscriptionID returns the subscription ID of the instance the function
// is called from.
func (c *imdsClient) SubscriptionID(ctx context.Context) (string, error) {
if c.timeForUpdate() || c.cache.Compute.SubscriptionID == "" {
if err := c.update(ctx); err != nil {
return "", err
}
}
if c.cache.Compute.SubscriptionID == "" {
return "", errors.New("unable to get subscription id")
}
return c.cache.Compute.SubscriptionID, nil
}
// ResourceGroup returns the resource group of the instance the function
// is called from.
func (c *imdsClient) ResourceGroup(ctx context.Context) (string, error) {
if c.timeForUpdate() || c.cache.Compute.ResourceGroup == "" {
if err := c.update(ctx); err != nil {
return "", err
}
}
if c.cache.Compute.ResourceGroup == "" {
return "", errors.New("unable to get resource group")
}
return c.cache.Compute.ResourceGroup, nil
}
// UID returns the UID of the cluster, based on the tags on the instance
// the function is calles from, which are inherited from the scale set.
func (c *imdsClient) UID(ctx context.Context) (string, error) {
if c.timeForUpdate() || len(c.cache.Compute.Tags) == 0 {
if err := c.update(ctx); err != nil {
return "", err
}
}
if len(c.cache.Compute.Tags) == 0 {
return "", errors.New("unable to get uid")
}
for _, tag := range c.cache.Compute.Tags {
if tag.Name == "uid" {
return tag.Value, nil
}
}
return "", errors.New("unable to get uid from metadata tags")
}
// timeForUpdate checks whether an update is needed due to cache age.
func (c *imdsClient) timeForUpdate() bool {
return time.Since(c.cacheTime) > maxCacheAge
}
// update updates instance metadata from the azure imds API.
func (c *imdsClient) update(ctx context.Context) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, imdsURL, http.NoBody)
if err != nil {
return metadataResponse{}, err
return err
}
req.Header.Add("Metadata", "True")
query := req.URL.Query()
@ -32,23 +112,36 @@ func (c *imdsClient) Retrieve(ctx context.Context) (metadataResponse, error) {
req.URL.RawQuery = query.Encode()
resp, err := c.client.Do(req)
if err != nil {
return metadataResponse{}, err
return err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return metadataResponse{}, err
return err
}
var res metadataResponse
if err := json.Unmarshal(body, &res); err != nil {
return metadataResponse{}, err
return err
}
return res, nil
c.cache = res
c.cacheTime = time.Now()
return nil
}
// metadataResponse contains metadataResponse with only the required values.
type metadataResponse struct {
Compute struct {
ResourceID string `json:"resourceId,omitempty"`
} `json:"compute,omitempty"`
Compute metadataResponseCompute `json:"compute,omitempty"`
}
type metadataResponseCompute struct {
ResourceID string `json:"resourceId,omitempty"`
SubscriptionID string `json:"subscriptionId,omitempty"`
ResourceGroup string `json:"resourceGroupName,omitempty"`
Tags []metadataTag `json:"tagsList,omitempty"`
}
type metadataTag struct {
Name string `json:"name,omitempty"`
Value string `json:"value,omitempty"`
}

View File

@ -10,39 +10,83 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/test/bufconn"
)
func TestRetrieve(t *testing.T) {
func TestIMDSClient(t *testing.T) {
uidTags := []metadataTag{{Name: "uid", Value: "uid"}}
response := metadataResponse{
Compute: struct {
ResourceID string `json:"resourceId,omitempty"`
}{
ResourceID: "resource-id",
Compute: metadataResponseCompute{
ResourceID: "resource-id",
ResourceGroup: "resource-group",
Tags: uidTags,
},
}
responseWithoutID := metadataResponse{
Compute: metadataResponseCompute{
ResourceGroup: "resource-group",
Tags: uidTags,
},
}
responseWithoutGroup := metadataResponse{
Compute: metadataResponseCompute{
ResourceID: "resource-id",
Tags: uidTags,
},
}
responseWithoutUID := metadataResponse{
Compute: metadataResponseCompute{
ResourceID: "resource-id",
ResourceGroup: "resource-group",
},
}
testCases := map[string]struct {
server httpBufconnServer
wantErr bool
wantResponse metadataResponse
server httpBufconnServer
wantProviderIDErr bool
wantProviderID string
wantResourceGroupErr bool
wantResourceGroup string
wantUIDErr bool
wantUID string
}{
"metadata response parsed": {
server: newHTTPBufconnServerWithMetadataResponse(response),
wantResponse: response,
server: newHTTPBufconnServerWithMetadataResponse(response),
wantProviderID: "resource-id",
wantResourceGroup: "resource-group",
wantUID: "uid",
},
"metadata response without resource ID": {
server: newHTTPBufconnServerWithMetadataResponse(responseWithoutID),
wantProviderIDErr: true,
wantResourceGroup: "resource-group",
wantUID: "uid",
},
"metadata response without UID tag": {
server: newHTTPBufconnServerWithMetadataResponse(responseWithoutUID),
wantProviderID: "resource-id",
wantResourceGroup: "resource-group",
wantUIDErr: true,
},
"metadata response without resource group": {
server: newHTTPBufconnServerWithMetadataResponse(responseWithoutGroup),
wantProviderID: "resource-id",
wantResourceGroupErr: true,
wantUID: "uid",
},
"invalid imds response detected": {
server: newHTTPBufconnServer(func(writer http.ResponseWriter, request *http.Request) {
fmt.Fprintln(writer, "invalid-result")
}),
wantErr: true,
wantProviderIDErr: true,
wantResourceGroupErr: true,
wantUIDErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
defer tc.server.Close()
@ -54,17 +98,33 @@ func TestRetrieve(t *testing.T) {
DialTLS: tc.server.Dial,
},
}
iClient := imdsClient{
client: &hClient,
}
resp, err := iClient.Retrieve(context.Background())
iClient := imdsClient{client: &hClient}
if tc.wantErr {
ctx := context.Background()
id, err := iClient.ProviderID(ctx)
if tc.wantProviderIDErr {
assert.Error(err)
return
} else {
assert.NoError(err)
assert.Equal(tc.wantProviderID, id)
}
group, err := iClient.ResourceGroup(ctx)
if tc.wantResourceGroupErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.wantResourceGroup, group)
}
uid, err := iClient.UID(ctx)
if tc.wantUIDErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.wantUID, uid)
}
require.NoError(err)
assert.Equal(tc.wantResponse, resp)
})
}
}

View File

@ -3,9 +3,8 @@ package azure
import (
"context"
"errors"
"fmt"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/applicationinsights/armapplicationinsights"
"github.com/edgelesssys/constellation/internal/azureshared"
"github.com/microsoft/ApplicationInsights-Go/appinsights"
)
@ -16,40 +15,24 @@ type Logger struct {
// NewLogger creates a new client to store information in Azure Application Insights
// https://github.com/Microsoft/ApplicationInsights-go
func NewLogger(ctx context.Context, metadata *Metadata) (*Logger, error) {
providerID, err := metadata.providerID(ctx)
component, err := metadata.getAppInsights(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("getting app insights: %w", err)
}
_, resourceGroup, err := azureshared.BasicsFromProviderID(providerID)
if err != nil {
return nil, err
}
uid, err := azureshared.UIDFromProviderID(providerID)
if err != nil {
return nil, err
}
resourceName := "constellation-insights-" + uid
resp, err := metadata.applicationInsightsAPI.Get(ctx, resourceGroup, resourceName, &armapplicationinsights.ComponentsClientGetOptions{})
if err != nil {
return nil, err
}
if resp.Properties == nil || resp.Properties.InstrumentationKey == nil {
if component.Properties == nil || component.Properties.InstrumentationKey == nil {
return nil, errors.New("unable to get instrumentation key")
}
client := appinsights.NewTelemetryClient(*resp.Properties.InstrumentationKey)
instance, err := metadata.GetInstance(ctx, providerID)
client := appinsights.NewTelemetryClient(*component.Properties.InstrumentationKey)
self, err := metadata.Self(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("getting self: %w", err)
}
client.Context().CommonProperties["instance-name"] = instance.Name
client.Context().CommonProperties["instance-name"] = self.Name
return &Logger{
client: client,
}, nil
return &Logger{client: client}, nil
}
// Disclose stores log information in Azure Application Insights!

View File

@ -2,7 +2,6 @@ package azure
import (
"context"
"errors"
"fmt"
"net/http"
"regexp"
@ -12,14 +11,12 @@ import (
armcomputev2 "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v2"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/edgelesssys/constellation/internal/azureshared"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
)
var (
publicIPAddressRegexp = regexp.MustCompile(`/subscriptions/[^/]+/resourceGroups/[^/]+/providers/Microsoft.Network/publicIPAddresses/(?P<IPname>[^/]+)`)
keyPathRegexp = regexp.MustCompile(`^\/home\/([^\/]+)\/\.ssh\/authorized_keys$`)
resourceGroupNameRegexp = regexp.MustCompile(`^(.*)-([^-]+)$`)
publicIPAddressRegexp = regexp.MustCompile(`/subscriptions/[^/]+/resourceGroups/[^/]+/providers/Microsoft.Network/publicIPAddresses/(?P<IPname>[^/]+)`)
keyPathRegexp = regexp.MustCompile(`^\/home\/([^\/]+)\/\.ssh\/authorized_keys$`)
)
// Metadata implements azure metadata APIs.
@ -48,11 +45,7 @@ func NewMetadata(ctx context.Context) (*Metadata, error) {
imdsAPI := imdsClient{
client: &http.Client{Transport: &http.Transport{Proxy: nil}},
}
instanceMetadata, err := imdsAPI.Retrieve(ctx)
if err != nil {
return nil, err
}
subscriptionID, _, err := azureshared.BasicsFromProviderID("azure://" + instanceMetadata.Compute.ResourceID)
subscriptionID, err := imdsAPI.SubscriptionID(ctx)
if err != nil {
return nil, err
}
@ -109,11 +102,7 @@ func NewMetadata(ctx context.Context) (*Metadata, error) {
// List retrieves all instances belonging to the current constellation.
func (m *Metadata) List(ctx context.Context) ([]metadata.InstanceMetadata, error) {
providerID, err := m.providerID(ctx)
if err != nil {
return nil, err
}
_, resourceGroup, err := azureshared.BasicsFromProviderID(providerID)
resourceGroup, err := m.imdsAPI.ResourceGroup(ctx)
if err != nil {
return nil, err
}
@ -144,11 +133,7 @@ func (m *Metadata) GetInstance(ctx context.Context, providerID string) (metadata
// GetNetworkSecurityGroupName returns the security group name of the resource group.
func (m *Metadata) GetNetworkSecurityGroupName(ctx context.Context) (string, error) {
providerID, err := m.providerID(ctx)
if err != nil {
return "", err
}
_, resourceGroup, err := azureshared.BasicsFromProviderID(providerID)
resourceGroup, err := m.imdsAPI.ResourceGroup(ctx)
if err != nil {
return "", err
}
@ -165,11 +150,7 @@ func (m *Metadata) GetNetworkSecurityGroupName(ctx context.Context) (string, err
// GetSubnetworkCIDR retrieves the subnetwork CIDR from cloud provider metadata.
func (m *Metadata) GetSubnetworkCIDR(ctx context.Context) (string, error) {
providerID, err := m.providerID(ctx)
if err != nil {
return "", err
}
_, resourceGroup, err := azureshared.BasicsFromProviderID(providerID)
resourceGroup, err := m.imdsAPI.ResourceGroup(ctx)
if err != nil {
return "", err
}
@ -187,33 +168,16 @@ func (m *Metadata) GetSubnetworkCIDR(ctx context.Context) (string, error) {
// UID retrieves the UID of the constellation.
func (m *Metadata) UID(ctx context.Context) (string, error) {
providerID, err := m.providerID(ctx)
if err != nil {
return "", err
}
_, resourceGroup, err := azureshared.BasicsFromProviderID(providerID)
if err != nil {
return "", err
}
uid, err := getUIDFromResourceGroup(resourceGroup)
if err != nil {
return "", err
}
return uid, nil
return m.imdsAPI.UID(ctx)
}
// getLoadBalancer retrieves the load balancer from cloud provider metadata.
func (m *Metadata) getLoadBalancer(ctx context.Context) (*armnetwork.LoadBalancer, error) {
providerID, err := m.providerID(ctx)
if err != nil {
return nil, err
}
_, resourceGroup, err := azureshared.BasicsFromProviderID(providerID)
resourceGroup, err := m.imdsAPI.ResourceGroup(ctx)
if err != nil {
return nil, err
}
pager := m.loadBalancerAPI.NewListPager(resourceGroup, nil)
for pager.More() {
@ -279,14 +243,11 @@ func (m *Metadata) GetLoadBalancerEndpoint(ctx context.Context) (string, error)
}
pubIPName := matches[1]
providerID, err := m.providerID(ctx)
if err != nil {
return "", err
}
_, resourceGroup, err := azureshared.BasicsFromProviderID(providerID)
resourceGroup, err := m.imdsAPI.ResourceGroup(ctx)
if err != nil {
return "", err
}
resp, err := m.publicIPAddressesAPI.Get(ctx, resourceGroup, pubIPName, nil)
if err != nil {
return "", fmt.Errorf("could not retrieve public IP address: %w", err)
@ -304,11 +265,42 @@ func (m *Metadata) Supported() bool {
// providerID retrieves the current instances providerID.
func (m *Metadata) providerID(ctx context.Context) (string, error) {
instanceMetadata, err := m.imdsAPI.Retrieve(ctx)
providerID, err := m.imdsAPI.ProviderID(ctx)
if err != nil {
return "", err
}
return "azure://" + instanceMetadata.Compute.ResourceID, nil
return "azure://" + providerID, nil
}
func (m *Metadata) getAppInsights(ctx context.Context) (*armapplicationinsights.Component, error) {
resourceGroup, err := m.imdsAPI.ResourceGroup(ctx)
if err != nil {
return nil, err
}
uid, err := m.UID(ctx)
if err != nil {
return nil, err
}
pager := m.applicationInsightsAPI.NewListByResourceGroupPager(resourceGroup, nil)
for pager.More() {
nextResult, err := pager.NextPage(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving application insights page: %w", err)
}
for _, component := range nextResult.Value {
if component == nil || component.Tags == nil {
continue
}
if *component.Tags["uid"] == uid {
return component, nil
}
}
}
return nil, fmt.Errorf("could not find correctly tagged application insights")
}
// extractInstanceTags converts azure tags into metadata key-value pairs.
@ -338,11 +330,3 @@ func extractSSHKeys(sshConfig armcomputev2.SSHConfiguration) map[string][]string
}
return sshKeys
}
func getUIDFromResourceGroup(resourceGroup string) (string, error) {
matches := resourceGroupNameRegexp.FindStringSubmatch(resourceGroup)
if len(matches) != 3 {
return "", errors.New("error splitting resource group name")
}
return matches[2], nil
}

View File

@ -32,23 +32,19 @@ func TestList(t *testing.T) {
wantInstances []metadata.InstanceMetadata
}{
"List works": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
networkInterfacesAPI: newNetworkInterfacesStub(),
scaleSetsAPI: newScaleSetsStub(),
virtualMachineScaleSetVMsAPI: newVirtualMachineScaleSetsVMsStub(),
tagsAPI: newTagsStub(),
wantInstances: wantInstances,
},
"providerID cannot be retrieved": {
imdsAPI: &stubIMDSAPI{retrieveErr: errors.New("imds err")},
wantErr: true,
},
"providerID cannot be parsed": {
imdsAPI: newInvalidIMDSStub(),
"imds resource group fails": {
imdsAPI: &stubIMDSAPI{resourceGroupErr: errors.New("failed")},
wantErr: true,
},
"listScaleSetVMs fails": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
networkInterfacesAPI: newNetworkInterfacesStub(),
scaleSetsAPI: newScaleSetsStub(),
virtualMachineScaleSetVMsAPI: newFailingListsVirtualMachineScaleSetsVMsStub(),
@ -96,17 +92,17 @@ func TestSelf(t *testing.T) {
wantInstance metadata.InstanceMetadata
}{
"self for scale set instance works": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{providerID: "/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id"},
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: newVirtualMachineScaleSetsVMsStub(),
wantInstance: wantScaleSetInstance,
},
"providerID cannot be retrieved": {
imdsAPI: &stubIMDSAPI{retrieveErr: errors.New("imds err")},
imdsAPI: &stubIMDSAPI{providerIDErr: errors.New("failed")},
wantErr: true,
},
"GetInstance fails": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{providerID: wantScaleSetInstance.ProviderID},
virtualMachineScaleSetVMsAPI: &stubVirtualMachineScaleSetVMsAPI{getErr: errors.New("failed")},
wantErr: true,
},
@ -143,7 +139,7 @@ func TestGetNetworkSecurityGroupName(t *testing.T) {
wantErr bool
}{
"GetNetworkSecurityGroupName works": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
securityGroupsAPI: &stubSecurityGroupsAPI{
pager: &stubSecurityGroupsClientListPager{
list: []armnetwork.SecurityGroup{{Name: to.Ptr(name)}},
@ -152,14 +148,14 @@ func TestGetNetworkSecurityGroupName(t *testing.T) {
wantName: name,
},
"no security group": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
securityGroupsAPI: &stubSecurityGroupsAPI{
pager: &stubSecurityGroupsClientListPager{},
},
wantErr: true,
},
"missing name in security group struct": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
securityGroupsAPI: &stubSecurityGroupsAPI{
pager: &stubSecurityGroupsClientListPager{
list: []armnetwork.SecurityGroup{{}},
@ -198,7 +194,7 @@ func TestGetSubnetworkCIDR(t *testing.T) {
wantErr bool
}{
"GetSubnetworkCIDR works": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
virtualNetworksAPI: &stubVirtualNetworksAPI{
pager: &stubVirtualNetworksClientListPager{
list: []armnetwork.VirtualNetwork{{
@ -214,7 +210,7 @@ func TestGetSubnetworkCIDR(t *testing.T) {
wantNetworkCIDR: subnetworkCIDR,
},
"no virtual networks found": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
virtualNetworksAPI: &stubVirtualNetworksAPI{
pager: &stubVirtualNetworksClientListPager{},
},
@ -222,7 +218,7 @@ func TestGetSubnetworkCIDR(t *testing.T) {
wantNetworkCIDR: subnetworkCIDR,
},
"malformed network struct": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
virtualNetworksAPI: &stubVirtualNetworksAPI{
pager: &stubVirtualNetworksClientListPager{list: []armnetwork.VirtualNetwork{{}}},
},
@ -259,7 +255,7 @@ func TestGetLoadBalancerName(t *testing.T) {
wantErr bool
}{
"GetLoadBalancerName works": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{
list: []armnetwork.LoadBalancer{{
@ -271,14 +267,14 @@ func TestGetLoadBalancerName(t *testing.T) {
wantName: loadBalancerName,
},
"invalid load balancer struct": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{list: []armnetwork.LoadBalancer{{}}},
},
wantErr: true,
},
"invalid missing name": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{list: []armnetwork.LoadBalancer{{
Properties: &armnetwork.LoadBalancerPropertiesFormat{},
@ -320,7 +316,7 @@ func TestGetLoadBalancerEndpoint(t *testing.T) {
wantErr bool
}{
"GetLoadBalancerEndpoint works": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{
list: []armnetwork.LoadBalancer{{
@ -347,14 +343,14 @@ func TestGetLoadBalancerEndpoint(t *testing.T) {
wantIP: publicIP,
},
"no load balancer": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{},
},
wantErr: true,
},
"load balancer missing public IP reference": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{
list: []armnetwork.LoadBalancer{{
@ -368,7 +364,7 @@ func TestGetLoadBalancerEndpoint(t *testing.T) {
wantErr: true,
},
"public IP reference has wrong format": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{
list: []armnetwork.LoadBalancer{{
@ -390,7 +386,7 @@ func TestGetLoadBalancerEndpoint(t *testing.T) {
wantErr: true,
},
"no public IP address found": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{
list: []armnetwork.LoadBalancer{{
@ -411,7 +407,7 @@ func TestGetLoadBalancerEndpoint(t *testing.T) {
wantErr: true,
},
"found public IP has no address field": {
imdsAPI: newScaleSetIMDSStub(),
imdsAPI: &stubIMDSAPI{resourceGroup: "resourceGroup"},
loadBalancerAPI: &stubLoadBalancersAPI{
pager: &stubLoadBalancersClientListPager{
list: []armnetwork.LoadBalancer{{
@ -470,11 +466,11 @@ func TestProviderID(t *testing.T) {
wantProviderID string
}{
"providerID for scale set instance works": {
imdsAPI: newScaleSetIMDSStub(),
wantProviderID: "azure:///subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id",
imdsAPI: &stubIMDSAPI{providerID: "provider-id"},
wantProviderID: "azure://provider-id",
},
"imds retrieval fails": {
imdsAPI: &stubIMDSAPI{retrieveErr: errors.New("imds err")},
"imds providerID fails": {
imdsAPI: &stubIMDSAPI{providerIDErr: errors.New("failed")},
wantErr: true,
},
}
@ -505,28 +501,12 @@ func TestUID(t *testing.T) {
wantErr bool
wantUID string
}{
"uid extraction from providerID works": {
imdsAPI: &stubIMDSAPI{
res: metadataResponse{Compute: struct {
ResourceID string `json:"resourceId,omitempty"`
}{"/subscriptions/subscription-id/resourceGroups/basename-uid/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id"}},
},
"success": {
imdsAPI: &stubIMDSAPI{uid: "uid"},
wantUID: "uid",
},
"providerID does not contain uid": {
imdsAPI: &stubIMDSAPI{
res: metadataResponse{Compute: struct {
ResourceID string `json:"resourceId,omitempty"`
}{"/subscriptions/subscription-id/resourceGroups/invalid/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id"}},
},
wantErr: true,
},
"providerID is invalid": {
imdsAPI: newInvalidIMDSStub(),
wantErr: true,
},
"imds retrieval fails": {
imdsAPI: &stubIMDSAPI{retrieveErr: errors.New("imds err")},
"imds uid error": {
imdsAPI: &stubIMDSAPI{uidErr: errors.New("failed")},
wantErr: true,
},
}
@ -641,22 +621,6 @@ func TestExtractSSHKeys(t *testing.T) {
}
}
func newScaleSetIMDSStub() *stubIMDSAPI {
return &stubIMDSAPI{
res: metadataResponse{Compute: struct {
ResourceID string `json:"resourceId,omitempty"`
}{"/subscriptions/subscription-id/resourceGroups/resource-group/providers/Microsoft.Compute/virtualMachineScaleSets/scale-set-name/virtualMachines/instance-id"}},
}
}
func newInvalidIMDSStub() *stubIMDSAPI {
return &stubIMDSAPI{
res: metadataResponse{Compute: struct {
ResourceID string `json:"resourceId,omitempty"`
}{"invalid-resource-id"}},
}
}
func newNetworkInterfacesStub() *stubNetworkInterfacesAPI {
return &stubNetworkInterfacesAPI{
getInterface: armnetwork.Interface{

View File

@ -82,7 +82,6 @@ func TestListScaleSetVMs(t *testing.T) {
},
}
testCases := map[string]struct {
imdsAPI imdsAPI
networkInterfacesAPI networkInterfacesAPI
virtualMachineScaleSetVMsAPI virtualMachineScaleSetVMsAPI
scaleSetsAPI scaleSetsAPI
@ -90,35 +89,30 @@ func TestListScaleSetVMs(t *testing.T) {
wantInstances []metadata.InstanceMetadata
}{
"listVMs works": {
imdsAPI: newScaleSetIMDSStub(),
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: newVirtualMachineScaleSetsVMsStub(),
scaleSetsAPI: newScaleSetsStub(),
wantInstances: wantInstances,
},
"invalid scale sets are skipped": {
imdsAPI: newScaleSetIMDSStub(),
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: newVirtualMachineScaleSetsVMsStub(),
scaleSetsAPI: newListContainingNilScaleSetStub(),
wantInstances: wantInstances,
},
"listVMs can return 0 VMs": {
imdsAPI: newScaleSetIMDSStub(),
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: &stubVirtualMachineScaleSetVMsAPI{pager: &stubVirtualMachineScaleSetVMPager{}},
scaleSetsAPI: newScaleSetsStub(),
wantInstances: []metadata.InstanceMetadata{},
},
"can skip nil in VM list": {
imdsAPI: newScaleSetIMDSStub(),
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: newListContainingNilScaleSetVirtualMachinesStub(),
scaleSetsAPI: newScaleSetsStub(),
wantInstances: wantInstances,
},
"converting instance fails": {
imdsAPI: newScaleSetIMDSStub(),
networkInterfacesAPI: newNetworkInterfacesStub(),
virtualMachineScaleSetVMsAPI: newListContainingInvalidScaleSetVirtualMachinesStub(),
scaleSetsAPI: newScaleSetsStub(),
@ -132,7 +126,6 @@ func TestListScaleSetVMs(t *testing.T) {
require := require.New(t)
metadata := Metadata{
imdsAPI: tc.imdsAPI,
networkInterfacesAPI: tc.networkInterfacesAPI,
virtualMachineScaleSetVMsAPI: tc.virtualMachineScaleSetVMsAPI,
scaleSetsAPI: tc.scaleSetsAPI,