Add autoscaling and cluster upgrade support for AWS (#1758)

* aws: autoscaling and upgrades

* docs: update scaling and upgrades for AWS

* deps: pin vuln check against release
This commit is contained in:
3u13r 2023-05-19 13:57:31 +02:00 committed by GitHub
parent 12ccfea543
commit 964775c4c2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 1720 additions and 44 deletions

View file

@ -0,0 +1,46 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//bazel/go:go_test.bzl", "go_test")
go_library(
name = "client",
srcs = [
"api.go",
"autoscaler.go",
"client.go",
"nodeimage.go",
"pendingnode.go",
"scalinggroup.go",
],
importpath = "github.com/edgelesssys/constellation/v2/operators/constellation-node-operator/v2/internal/cloud/aws/client",
visibility = ["//operators/constellation-node-operator:__subpackages__"],
deps = [
"//operators/constellation-node-operator/api/v1alpha1",
"@com_github_aws_aws_sdk_go_v2_config//:config",
"@com_github_aws_aws_sdk_go_v2_feature_ec2_imds//:imds",
"@com_github_aws_aws_sdk_go_v2_service_autoscaling//:autoscaling",
"@com_github_aws_aws_sdk_go_v2_service_autoscaling//types",
"@com_github_aws_aws_sdk_go_v2_service_ec2//:ec2",
"@com_github_aws_aws_sdk_go_v2_service_ec2//types",
"@io_k8s_sigs_controller_runtime//pkg/log",
],
)
go_test(
name = "client_test",
srcs = [
"client_test.go",
"nodeimage_test.go",
"pendingnode_test.go",
"scalinggroup_test.go",
],
embed = [":client"],
deps = [
"//operators/constellation-node-operator/api/v1alpha1",
"@com_github_aws_aws_sdk_go_v2_service_autoscaling//:autoscaling",
"@com_github_aws_aws_sdk_go_v2_service_autoscaling//types",
"@com_github_aws_aws_sdk_go_v2_service_ec2//:ec2",
"@com_github_aws_aws_sdk_go_v2_service_ec2//types",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
],
)

View file

@ -0,0 +1,28 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"github.com/aws/aws-sdk-go-v2/service/autoscaling"
"github.com/aws/aws-sdk-go-v2/service/ec2"
)
type ec2API interface {
DescribeInstances(ctx context.Context, params *ec2.DescribeInstancesInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error)
DescribeInstanceStatus(ctx context.Context, params *ec2.DescribeInstanceStatusInput, optFns ...func(*ec2.Options)) (*ec2.DescribeInstanceStatusOutput, error)
CreateLaunchTemplateVersion(ctx context.Context, params *ec2.CreateLaunchTemplateVersionInput, optFns ...func(*ec2.Options)) (*ec2.CreateLaunchTemplateVersionOutput, error)
ModifyLaunchTemplate(ctx context.Context, params *ec2.ModifyLaunchTemplateInput, optFns ...func(*ec2.Options)) (*ec2.ModifyLaunchTemplateOutput, error)
DescribeLaunchTemplateVersions(ctx context.Context, params *ec2.DescribeLaunchTemplateVersionsInput, optFns ...func(*ec2.Options)) (*ec2.DescribeLaunchTemplateVersionsOutput, error)
}
type scalingAPI interface {
DescribeAutoScalingGroups(ctx context.Context, params *autoscaling.DescribeAutoScalingGroupsInput, optFns ...func(*autoscaling.Options)) (*autoscaling.DescribeAutoScalingGroupsOutput, error)
SetDesiredCapacity(ctx context.Context, params *autoscaling.SetDesiredCapacityInput, optFns ...func(*autoscaling.Options)) (*autoscaling.SetDesiredCapacityOutput, error)
TerminateInstanceInAutoScalingGroup(ctx context.Context, params *autoscaling.TerminateInstanceInAutoScalingGroupInput, optFns ...func(*autoscaling.Options)) (*autoscaling.TerminateInstanceInAutoScalingGroupOutput, error)
}

View file

@ -0,0 +1,12 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
// AutoscalingCloudProvider returns the cloud-provider name as used by k8s cluster-autoscaler.
func (c *Client) AutoscalingCloudProvider() string {
return "aws"
}

View file

@ -0,0 +1,65 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"fmt"
"strings"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/autoscaling"
"github.com/aws/aws-sdk-go-v2/service/ec2"
)
// Client is a client for the AWS Cloud.
type Client struct {
ec2Client ec2API
scalingClient scalingAPI
}
// New creates a client with initialized clients.
func New(ctx context.Context) (*Client, error) {
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return nil, fmt.Errorf("failed to load aws config: %w", err)
}
// get region from ec2metadata
imdsClient := imds.NewFromConfig(cfg)
regionOut, err := imdsClient.GetRegion(ctx, &imds.GetRegionInput{})
if err != nil {
return nil, fmt.Errorf("failed to get region from ec2metadata: %w", err)
}
return NewWithRegion(ctx, regionOut.Region)
}
// NewWithRegion creates a client with initialized clients and a given region.
func NewWithRegion(ctx context.Context, region string) (*Client, error) {
cfg, err := config.LoadDefaultConfig(ctx)
if err != nil {
return nil, fmt.Errorf("failed to load aws config: %w", err)
}
cfg.Region = region
ec2Client := ec2.NewFromConfig(cfg)
scalingClient := autoscaling.NewFromConfig(cfg)
return &Client{
ec2Client: ec2Client,
scalingClient: scalingClient,
}, nil
}
func getInstanceNameFromProviderID(providerID string) (string, error) {
// aws:///us-east-2a/i-06888991e7138ed4e
providerIDParts := strings.Split(providerID, "/")
if len(providerIDParts) != 5 {
return "", fmt.Errorf("invalid providerID: %s", providerID)
}
return providerIDParts[4], nil
}

View file

@ -0,0 +1,50 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetInstanceNameFromProviderID(t *testing.T) {
testCases := map[string]struct {
providerID string
want string
wantErr bool
}{
"valid": {
providerID: "aws:///us-east-2a/i-06888991e7138ed4e",
want: "i-06888991e7138ed4e",
},
"too many parts": {
providerID: "aws:///us-east-2a/i-06888991e7138ed4e/invalid",
wantErr: true,
},
"too few parts": {
providerID: "aws:///us-east-2a",
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
got, err := getInstanceNameFromProviderID(tc.providerID)
if tc.wantErr {
require.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.want, got)
})
}
}

View file

@ -0,0 +1,219 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"fmt"
"time"
"github.com/aws/aws-sdk-go-v2/service/autoscaling"
"github.com/aws/aws-sdk-go-v2/service/autoscaling/types"
"github.com/aws/aws-sdk-go-v2/service/ec2"
)
// GetNodeImage returns the image name of the node.
func (c *Client) GetNodeImage(ctx context.Context, providerID string) (string, error) {
instanceName, err := getInstanceNameFromProviderID(providerID)
if err != nil {
return "", fmt.Errorf("failed to get instance name from providerID: %w", err)
}
params := &ec2.DescribeInstancesInput{
InstanceIds: []string{
instanceName,
},
}
resp, err := c.ec2Client.DescribeInstances(ctx, params)
if err != nil {
return "", fmt.Errorf("failed to describe instances: %w", err)
}
if len(resp.Reservations) == 0 {
return "", fmt.Errorf("no reservations for instance %q", instanceName)
}
if len(resp.Reservations[0].Instances) == 0 {
return "", fmt.Errorf("no instances for instance %q", instanceName)
}
if resp.Reservations[0].Instances[0].ImageId == nil {
return "", fmt.Errorf("no image for instance %q", instanceName)
}
return *resp.Reservations[0].Instances[0].ImageId, nil
}
// GetScalingGroupID returns the scaling group ID of the node.
func (c *Client) GetScalingGroupID(ctx context.Context, providerID string) (string, error) {
instanceName, err := getInstanceNameFromProviderID(providerID)
if err != nil {
return "", fmt.Errorf("failed to get instance name from providerID: %w", err)
}
params := &ec2.DescribeInstancesInput{
InstanceIds: []string{
instanceName,
},
}
resp, err := c.ec2Client.DescribeInstances(ctx, params)
if err != nil {
return "", fmt.Errorf("failed to describe instances: %w", err)
}
if len(resp.Reservations) == 0 {
return "", fmt.Errorf("no reservations for instance %q", instanceName)
}
if len(resp.Reservations[0].Instances) == 0 {
return "", fmt.Errorf("no instances for instance %q", instanceName)
}
if resp.Reservations[0].Instances[0].Tags == nil {
return "", fmt.Errorf("no tags for instance %q", instanceName)
}
for _, tag := range resp.Reservations[0].Instances[0].Tags {
if tag.Key == nil || tag.Value == nil {
continue
}
if *tag.Key == "aws:autoscaling:groupName" {
return *tag.Value, nil
}
}
return "", fmt.Errorf("node %q does not have valid tags", providerID)
}
// CreateNode creates a node in the specified scaling group.
func (c *Client) CreateNode(ctx context.Context, scalingGroupID string) (nodeName, providerID string, err error) {
containsInstance := func(instances []types.Instance, target types.Instance) bool {
for _, i := range instances {
if i.InstanceId == nil || target.InstanceId == nil {
continue
}
if *i.InstanceId == *target.InstanceId {
return true
}
}
return false
}
// get current capacity
groups, err := c.scalingClient.DescribeAutoScalingGroups(
ctx,
&autoscaling.DescribeAutoScalingGroupsInput{
AutoScalingGroupNames: []string{scalingGroupID},
},
)
if err != nil {
return "", "", fmt.Errorf("failed to describe autoscaling group: %w", err)
}
if len(groups.AutoScalingGroups) != 1 {
return "", "", fmt.Errorf("expected exactly one autoscaling group, got %d", len(groups.AutoScalingGroups))
}
if groups.AutoScalingGroups[0].DesiredCapacity == nil {
return "", "", fmt.Errorf("desired capacity is nil")
}
currentCapacity := int(*groups.AutoScalingGroups[0].DesiredCapacity)
// check for int32 overflow
if currentCapacity >= int(^uint32(0)>>1) {
return "", "", fmt.Errorf("current capacity is at maximum")
}
// get current list of instances
previousInstances := groups.AutoScalingGroups[0].Instances
// create new instance by increasing capacity by 1
_, err = c.scalingClient.SetDesiredCapacity(
ctx,
&autoscaling.SetDesiredCapacityInput{
AutoScalingGroupName: &scalingGroupID,
DesiredCapacity: toPtr(int32(currentCapacity + 1)),
},
)
if err != nil {
return "", "", fmt.Errorf("failed to set desired capacity: %w", err)
}
// poll until new instance is created with 30 second timeout
newInstance := types.Instance{}
for i := 0; i < 30; i++ {
groups, err := c.scalingClient.DescribeAutoScalingGroups(
ctx,
&autoscaling.DescribeAutoScalingGroupsInput{
AutoScalingGroupNames: []string{scalingGroupID},
},
)
if err != nil {
return "", "", fmt.Errorf("failed to describe autoscaling group: %w", err)
}
if len(groups.AutoScalingGroups) != 1 {
return "", "", fmt.Errorf("expected exactly one autoscaling group, got %d", len(groups.AutoScalingGroups))
}
for _, instance := range groups.AutoScalingGroups[0].Instances {
if !containsInstance(previousInstances, instance) {
newInstance = instance
break
}
}
// break if new instance is found
if newInstance.InstanceId != nil {
break
}
// wait 1 second
select {
case <-ctx.Done():
return "", "", fmt.Errorf("context cancelled")
case <-time.After(1 * time.Second):
}
}
if newInstance.InstanceId == nil {
return "", "", fmt.Errorf("timed out waiting for new instance")
}
if newInstance.AvailabilityZone == nil {
return "", "", fmt.Errorf("new instance %s does not have availability zone", *newInstance.InstanceId)
}
// return new instance
return *newInstance.InstanceId, fmt.Sprintf("aws:///%s/%s", *newInstance.AvailabilityZone, *newInstance.InstanceId), nil
}
// DeleteNode deletes a node from the specified scaling group.
func (c *Client) DeleteNode(ctx context.Context, providerID string) error {
instanceID, err := getInstanceNameFromProviderID(providerID)
if err != nil {
return fmt.Errorf("failed to get instance name from providerID: %w", err)
}
_, err = c.scalingClient.TerminateInstanceInAutoScalingGroup(
ctx,
&autoscaling.TerminateInstanceInAutoScalingGroupInput{
InstanceId: &instanceID,
ShouldDecrementDesiredCapacity: toPtr(true),
},
)
if err != nil {
return fmt.Errorf("failed to terminate instance: %w", err)
}
return nil
}
func toPtr[T any](v T) *T {
return &v
}

View file

@ -0,0 +1,459 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"testing"
"github.com/aws/aws-sdk-go-v2/service/autoscaling"
autoscalingtypes "github.com/aws/aws-sdk-go-v2/service/autoscaling/types"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetNodeImage(t *testing.T) {
ami := "ami-00000000000000000"
testCases := map[string]struct {
providerID string
describeInstancesErr error
describeInstancesOut *ec2.DescribeInstancesOutput
wantImage string
wantErr bool
}{
"getting node image works": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstancesOut: &ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{
{
Instances: []ec2types.Instance{
{
ImageId: &ami,
},
},
},
},
},
wantImage: ami,
},
"no reservations": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstancesOut: &ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{},
},
wantErr: true,
},
"no instances": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstancesOut: &ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{
{
Instances: []ec2types.Instance{},
},
},
},
wantErr: true,
},
"no image": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstancesOut: &ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{
{
Instances: []ec2types.Instance{
{},
},
},
},
},
wantErr: true,
},
"error describing instances": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstancesErr: assert.AnError,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{
ec2Client: &stubEC2API{
describeInstancesOut: tc.describeInstancesOut,
describeInstancesErr: tc.describeInstancesErr,
},
}
gotImage, err := client.GetNodeImage(context.Background(), tc.providerID)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantImage, gotImage)
})
}
}
func TestGetScalingGroupID(t *testing.T) {
asgName := "my-asg"
testCases := map[string]struct {
providerID string
describeInstancesErr error
describeInstancesOut *ec2.DescribeInstancesOutput
wantASGID string
wantErr bool
}{
"getting node's tag works": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstancesOut: &ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{
{
Instances: []ec2types.Instance{
{
Tags: []ec2types.Tag{
{
Key: toPtr("aws:autoscaling:groupName"),
Value: &asgName,
},
},
},
},
},
},
},
wantASGID: asgName,
},
"no valid tags": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstancesOut: &ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{
{
Instances: []ec2types.Instance{
{
Tags: []ec2types.Tag{
{
Key: toPtr("foo"),
Value: toPtr("bar"),
},
},
},
},
},
},
},
wantErr: true,
},
"no reservations": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstancesOut: &ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{},
},
wantErr: true,
},
"no instances": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstancesOut: &ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{
{
Instances: []ec2types.Instance{},
},
},
},
wantErr: true,
},
"no image": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstancesOut: &ec2.DescribeInstancesOutput{
Reservations: []ec2types.Reservation{
{
Instances: []ec2types.Instance{
{},
},
},
},
},
wantErr: true,
},
"error describing instances": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstancesErr: assert.AnError,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{
ec2Client: &stubEC2API{
describeInstancesOut: tc.describeInstancesOut,
describeInstancesErr: tc.describeInstancesErr,
},
}
gotScalingID, err := client.GetScalingGroupID(context.Background(), tc.providerID)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantASGID, gotScalingID)
})
}
}
func TestCreateNode(t *testing.T) {
testCases := map[string]struct {
providerID string
describeAutoscalingOutFirst *autoscaling.DescribeAutoScalingGroupsOutput
describeAutoscalingFirstErr error
describeAutoscalingOutSecond *autoscaling.DescribeAutoScalingGroupsOutput
describeAutoscalingSecondErr error
setDesiredCapacityErr error
wantNodeName string
wantProviderID string
wantErr bool
}{
"creating a new node works": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoscalingOutFirst: &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: []autoscalingtypes.AutoScalingGroup{
{
AutoScalingGroupName: toPtr("my-asg"),
Instances: []autoscalingtypes.Instance{
{
InstanceId: toPtr("i-00000000000000000"),
},
},
DesiredCapacity: toPtr(int32(1)),
},
},
},
describeAutoscalingOutSecond: &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: []autoscalingtypes.AutoScalingGroup{
{
AutoScalingGroupName: toPtr("my-asg"),
Instances: []autoscalingtypes.Instance{
{
InstanceId: toPtr("i-00000000000000000"),
},
{
InstanceId: toPtr("i-00000000000000001"),
AvailabilityZone: toPtr("us-east-2a"),
},
},
DesiredCapacity: toPtr(int32(2)),
},
},
},
wantNodeName: "i-00000000000000001",
wantProviderID: "aws:///us-east-2a/i-00000000000000001",
},
"creating a new node fails when describing the auto scaling group the first time": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoscalingFirstErr: assert.AnError,
wantErr: true,
},
"creating a new node fails when describing the auto scaling group the second time": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoscalingOutFirst: &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: []autoscalingtypes.AutoScalingGroup{
{
AutoScalingGroupName: toPtr("my-asg"),
Instances: []autoscalingtypes.Instance{
{
InstanceId: toPtr("i-00000000000000000"),
},
},
DesiredCapacity: toPtr(int32(1)),
},
},
},
describeAutoscalingSecondErr: assert.AnError,
wantErr: true,
},
"creating a new node fails when the auto scaling group is not found": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoscalingOutFirst: &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: []autoscalingtypes.AutoScalingGroup{},
},
wantErr: true,
},
"creating a new node fails when set desired capacity fails": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoscalingOutFirst: &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: []autoscalingtypes.AutoScalingGroup{
{
AutoScalingGroupName: toPtr("my-asg"),
Instances: []autoscalingtypes.Instance{
{
InstanceId: toPtr("i-00000000000000000"),
},
},
DesiredCapacity: toPtr(int32(1)),
},
},
},
setDesiredCapacityErr: assert.AnError,
wantErr: true,
},
"creating a new node fails when the found vm does not contain an availability zone": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoscalingOutFirst: &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: []autoscalingtypes.AutoScalingGroup{
{
AutoScalingGroupName: toPtr("my-asg"),
Instances: []autoscalingtypes.Instance{
{
InstanceId: toPtr("i-00000000000000000"),
},
},
DesiredCapacity: toPtr(int32(1)),
},
},
},
describeAutoscalingOutSecond: &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: []autoscalingtypes.AutoScalingGroup{
{
AutoScalingGroupName: toPtr("my-asg"),
Instances: []autoscalingtypes.Instance{
{
InstanceId: toPtr("i-00000000000000000"),
},
{
InstanceId: toPtr("i-00000000000000001"),
},
},
DesiredCapacity: toPtr(int32(2)),
},
},
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{
scalingClient: &stubAutoscalingAPI{
describeAutoScalingGroupsOut: []*autoscaling.DescribeAutoScalingGroupsOutput{
tc.describeAutoscalingOutFirst,
tc.describeAutoscalingOutSecond,
},
describeAutoScalingGroupsErr: []error{
tc.describeAutoscalingFirstErr,
tc.describeAutoscalingSecondErr,
},
setDesiredCapacityErr: tc.setDesiredCapacityErr,
},
}
nodeName, providerID, err := client.CreateNode(context.Background(), tc.providerID)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantNodeName, nodeName)
assert.Equal(tc.wantProviderID, providerID)
})
}
}
func TestDeleteNode(t *testing.T) {
testCases := map[string]struct {
providerID string
terminateInstanceErr error
wantErr bool
}{
"deleting node works": {
providerID: "aws:///us-east-2a/i-00000000000000000",
},
"deleting node fails when terminating the instance fails": {
providerID: "aws:///us-east-2a/i-00000000000000000",
terminateInstanceErr: assert.AnError,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{
scalingClient: &stubAutoscalingAPI{
terminateInstanceErr: tc.terminateInstanceErr,
},
}
err := client.DeleteNode(context.Background(), tc.providerID)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
})
}
}
type stubEC2API struct {
describeInstancesOut *ec2.DescribeInstancesOutput
describeInstancesErr error
describeInstanceStatusOut *ec2.DescribeInstanceStatusOutput
describeInstanceStatusErr error
describeLaunchTemplateVersionsOut *ec2.DescribeLaunchTemplateVersionsOutput
describeLaunchTemplateVersionsErr error
createLaunchTemplateVersionOut *ec2.CreateLaunchTemplateVersionOutput
createLaunchTemplateVersionErr error
modifyLaunchTemplateErr error
}
func (a *stubEC2API) DescribeInstances(_ context.Context, _ *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) {
return a.describeInstancesOut, a.describeInstancesErr
}
func (a *stubEC2API) DescribeInstanceStatus(_ context.Context, _ *ec2.DescribeInstanceStatusInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstanceStatusOutput, error) {
return a.describeInstanceStatusOut, a.describeInstanceStatusErr
}
func (a *stubEC2API) CreateLaunchTemplateVersion(_ context.Context, _ *ec2.CreateLaunchTemplateVersionInput, _ ...func(*ec2.Options)) (*ec2.CreateLaunchTemplateVersionOutput, error) {
return a.createLaunchTemplateVersionOut, a.createLaunchTemplateVersionErr
}
func (a *stubEC2API) ModifyLaunchTemplate(_ context.Context, _ *ec2.ModifyLaunchTemplateInput, _ ...func(*ec2.Options)) (*ec2.ModifyLaunchTemplateOutput, error) {
return nil, a.modifyLaunchTemplateErr
}
func (a *stubEC2API) DescribeLaunchTemplateVersions(_ context.Context, _ *ec2.DescribeLaunchTemplateVersionsInput, _ ...func(*ec2.Options)) (*ec2.DescribeLaunchTemplateVersionsOutput, error) {
return a.describeLaunchTemplateVersionsOut, a.describeLaunchTemplateVersionsErr
}
type stubAutoscalingAPI struct {
describeAutoScalingGroupsOut []*autoscaling.DescribeAutoScalingGroupsOutput
describeAutoScalingGroupsErr []error
describeCounter int
setDesiredCapacityErr error
terminateInstanceErr error
}
func (a *stubAutoscalingAPI) DescribeAutoScalingGroups(_ context.Context, _ *autoscaling.DescribeAutoScalingGroupsInput, _ ...func(*autoscaling.Options)) (*autoscaling.DescribeAutoScalingGroupsOutput, error) {
out := a.describeAutoScalingGroupsOut[a.describeCounter]
err := a.describeAutoScalingGroupsErr[a.describeCounter]
a.describeCounter++
return out, err
}
func (a *stubAutoscalingAPI) SetDesiredCapacity(_ context.Context, _ *autoscaling.SetDesiredCapacityInput, _ ...func(*autoscaling.Options)) (*autoscaling.SetDesiredCapacityOutput, error) {
return nil, a.setDesiredCapacityErr
}
func (a *stubAutoscalingAPI) TerminateInstanceInAutoScalingGroup(_ context.Context, _ *autoscaling.TerminateInstanceInAutoScalingGroupInput, _ ...func(*autoscaling.Options)) (*autoscaling.TerminateInstanceInAutoScalingGroupOutput, error) {
return nil, a.terminateInstanceErr
}

View file

@ -0,0 +1,68 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"errors"
"fmt"
"strings"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
updatev1alpha1 "github.com/edgelesssys/constellation/v2/operators/constellation-node-operator/v2/api/v1alpha1"
"sigs.k8s.io/controller-runtime/pkg/log"
)
// GetNodeState returns the state of the node.
func (c *Client) GetNodeState(ctx context.Context, providerID string) (updatev1alpha1.CSPNodeState, error) {
logr := log.FromContext(ctx)
logr.Info("GetNodeState", "providerID", providerID)
instanceName, err := getInstanceNameFromProviderID(providerID)
if err != nil {
return updatev1alpha1.NodeStateUnknown, fmt.Errorf("failed to get instance name from providerID: %w", err)
}
statusOut, err := c.ec2Client.DescribeInstanceStatus(ctx, &ec2.DescribeInstanceStatusInput{
InstanceIds: []string{instanceName},
IncludeAllInstances: toPtr(true),
})
if err != nil {
if strings.Contains(err.Error(), "InvalidInstanceID.NotFound") {
return updatev1alpha1.NodeStateTerminated, nil
}
return updatev1alpha1.NodeStateUnknown, err
}
if len(statusOut.InstanceStatuses) != 1 {
return updatev1alpha1.NodeStateUnknown, fmt.Errorf("expected 1 instance status, got %d", len(statusOut.InstanceStatuses))
}
if statusOut.InstanceStatuses[0].InstanceState == nil {
return updatev1alpha1.NodeStateUnknown, errors.New("instance state is nil")
}
// Translate AWS instance state to node state.
switch statusOut.InstanceStatuses[0].InstanceState.Name {
case ec2types.InstanceStateNameRunning:
return updatev1alpha1.NodeStateReady, nil
case ec2types.InstanceStateNameTerminated:
return updatev1alpha1.NodeStateTerminated, nil
case ec2types.InstanceStateNameShuttingDown:
return updatev1alpha1.NodeStateTerminating, nil
case ec2types.InstanceStateNameStopped:
return updatev1alpha1.NodeStateStopped, nil
// For "Stopping" we can only know the next state in the state machine
// so we preemptively set it to "Stopped".
case ec2types.InstanceStateNameStopping:
return updatev1alpha1.NodeStateStopped, nil
case ec2types.InstanceStateNamePending:
return updatev1alpha1.NodeStateCreating, nil
default:
return updatev1alpha1.NodeStateUnknown, fmt.Errorf("unknown instance state %q", statusOut.InstanceStatuses[0].InstanceState.Name)
}
}

View file

@ -0,0 +1,173 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"errors"
"testing"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
updatev1alpha1 "github.com/edgelesssys/constellation/v2/operators/constellation-node-operator/v2/api/v1alpha1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetNodeState(t *testing.T) {
testCases := map[string]struct {
providerID string
describeInstanceStatusOut *ec2.DescribeInstanceStatusOutput
describeInstanceStatusErr error
wantState updatev1alpha1.CSPNodeState
wantErr bool
}{
"getting node state works for running VM": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstanceStatusOut: &ec2.DescribeInstanceStatusOutput{
InstanceStatuses: []ec2types.InstanceStatus{
{
InstanceState: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameRunning,
},
},
},
},
wantState: updatev1alpha1.NodeStateReady,
},
"getting node state works for terminated VM": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstanceStatusOut: &ec2.DescribeInstanceStatusOutput{
InstanceStatuses: []ec2types.InstanceStatus{
{
InstanceState: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameTerminated,
},
},
},
},
wantState: updatev1alpha1.NodeStateTerminated,
},
"getting node state works for stopping VM": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstanceStatusOut: &ec2.DescribeInstanceStatusOutput{
InstanceStatuses: []ec2types.InstanceStatus{
{
InstanceState: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameStopping,
},
},
},
},
wantState: updatev1alpha1.NodeStateStopped,
},
"getting node state works for stopped VM": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstanceStatusOut: &ec2.DescribeInstanceStatusOutput{
InstanceStatuses: []ec2types.InstanceStatus{
{
InstanceState: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameStopped,
},
},
},
},
wantState: updatev1alpha1.NodeStateStopped,
},
"getting node state works for pending VM": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstanceStatusOut: &ec2.DescribeInstanceStatusOutput{
InstanceStatuses: []ec2types.InstanceStatus{
{
InstanceState: &ec2types.InstanceState{
Name: ec2types.InstanceStateNamePending,
},
},
},
},
wantState: updatev1alpha1.NodeStateCreating,
},
"getting node state works for shutting-down VM": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstanceStatusOut: &ec2.DescribeInstanceStatusOutput{
InstanceStatuses: []ec2types.InstanceStatus{
{
InstanceState: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameShuttingDown,
},
},
},
},
wantState: updatev1alpha1.NodeStateTerminating,
},
"getting node state fails when the state is unknown": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstanceStatusOut: &ec2.DescribeInstanceStatusOutput{
InstanceStatuses: []ec2types.InstanceStatus{
{
InstanceState: &ec2types.InstanceState{
Name: "unknown",
},
},
},
},
wantState: updatev1alpha1.NodeStateUnknown,
wantErr: true,
},
"cannot find instance": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstanceStatusErr: errors.New("InvalidInstanceID.NotFound"),
wantState: updatev1alpha1.NodeStateTerminated,
},
"unknown error when describing the instance error": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstanceStatusErr: assert.AnError,
wantState: updatev1alpha1.NodeStateUnknown,
wantErr: true,
},
"fails when getting no instances": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstanceStatusOut: &ec2.DescribeInstanceStatusOutput{
InstanceStatuses: []ec2types.InstanceStatus{},
},
wantState: updatev1alpha1.NodeStateUnknown,
wantErr: true,
},
"fails when the instance state is nil": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeInstanceStatusOut: &ec2.DescribeInstanceStatusOutput{
InstanceStatuses: []ec2types.InstanceStatus{
{
InstanceState: nil,
},
},
},
wantState: updatev1alpha1.NodeStateUnknown,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{
ec2Client: &stubEC2API{
describeInstanceStatusOut: tc.describeInstanceStatusOut,
describeInstanceStatusErr: tc.describeInstanceStatusErr,
},
}
nodeState, err := client.GetNodeState(context.Background(), tc.providerID)
assert.Equal(tc.wantState, nodeState)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
})
}
}

View file

@ -0,0 +1,174 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"fmt"
"strings"
"github.com/aws/aws-sdk-go-v2/service/autoscaling"
scalingtypes "github.com/aws/aws-sdk-go-v2/service/autoscaling/types"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
)
// GetScalingGroupImage returns the image URI of the scaling group.
func (c *Client) GetScalingGroupImage(ctx context.Context, scalingGroupID string) (string, error) {
launchTemplate, err := c.getScalingGroupTemplate(ctx, scalingGroupID)
if err != nil {
return "", err
}
if launchTemplate.LaunchTemplateData == nil {
return "", fmt.Errorf("launch template data is nil for scaling group %q", scalingGroupID)
}
if launchTemplate.LaunchTemplateData.ImageId == nil {
return "", fmt.Errorf("image ID is nil for scaling group %q", scalingGroupID)
}
return *launchTemplate.LaunchTemplateData.ImageId, nil
}
// SetScalingGroupImage sets the image URI of the scaling group.
func (c *Client) SetScalingGroupImage(ctx context.Context, scalingGroupID, imageURI string) error {
launchTemplate, err := c.getScalingGroupTemplate(ctx, scalingGroupID)
if err != nil {
return fmt.Errorf("failed to get launch template for scaling group %q: %w", scalingGroupID, err)
}
if launchTemplate.VersionNumber == nil {
return fmt.Errorf("version number is nil for scaling group %q", scalingGroupID)
}
createLaunchTemplateOut, err := c.ec2Client.CreateLaunchTemplateVersion(
ctx,
&ec2.CreateLaunchTemplateVersionInput{
LaunchTemplateData: &ec2types.RequestLaunchTemplateData{
ImageId: &imageURI,
},
LaunchTemplateId: launchTemplate.LaunchTemplateId,
SourceVersion: toPtr(fmt.Sprintf("%d", *launchTemplate.VersionNumber)),
},
)
if err != nil {
return fmt.Errorf("failed to create launch template version: %w", err)
}
if createLaunchTemplateOut == nil {
return fmt.Errorf("create launch template version output is nil")
}
if createLaunchTemplateOut.LaunchTemplateVersion == nil {
return fmt.Errorf("created launch template version is nil")
}
if createLaunchTemplateOut.LaunchTemplateVersion.VersionNumber == nil {
return fmt.Errorf("created launch template version number is nil")
}
// set created version as default
_, err = c.ec2Client.ModifyLaunchTemplate(
ctx,
&ec2.ModifyLaunchTemplateInput{
LaunchTemplateId: launchTemplate.LaunchTemplateId,
DefaultVersion: toPtr(fmt.Sprintf("%d", createLaunchTemplateOut.LaunchTemplateVersion.VersionNumber)),
},
)
if err != nil {
return fmt.Errorf("failed to modify launch template: %w", err)
}
return nil
}
func (c *Client) getScalingGroupTemplate(ctx context.Context, scalingGroupID string) (ec2types.LaunchTemplateVersion, error) {
groupOutput, err := c.scalingClient.DescribeAutoScalingGroups(
ctx,
&autoscaling.DescribeAutoScalingGroupsInput{
AutoScalingGroupNames: []string{scalingGroupID},
},
)
if err != nil {
return ec2types.LaunchTemplateVersion{}, fmt.Errorf("failed to describe scaling group %q: %w", scalingGroupID, err)
}
if len(groupOutput.AutoScalingGroups) != 1 {
return ec2types.LaunchTemplateVersion{}, fmt.Errorf("expected exactly one scaling group, got %d", len(groupOutput.AutoScalingGroups))
}
if groupOutput.AutoScalingGroups[0].LaunchTemplate == nil {
return ec2types.LaunchTemplateVersion{}, fmt.Errorf("launch template is nil for scaling group %q", scalingGroupID)
}
if groupOutput.AutoScalingGroups[0].LaunchTemplate.LaunchTemplateId == nil {
return ec2types.LaunchTemplateVersion{}, fmt.Errorf("launch template ID is nil for scaling group %q", scalingGroupID)
}
launchTemplateID := groupOutput.AutoScalingGroups[0].LaunchTemplate.LaunchTemplateId
launchTemplateOutput, err := c.ec2Client.DescribeLaunchTemplateVersions(
ctx,
&ec2.DescribeLaunchTemplateVersionsInput{
LaunchTemplateId: launchTemplateID,
Versions: []string{"$Latest"},
},
)
if err != nil {
return ec2types.LaunchTemplateVersion{}, fmt.Errorf("failed to describe launch template %q: %w", *launchTemplateID, err)
}
if len(launchTemplateOutput.LaunchTemplateVersions) != 1 {
return ec2types.LaunchTemplateVersion{}, fmt.Errorf("expected exactly one launch template, got %d", len(launchTemplateOutput.LaunchTemplateVersions))
}
return launchTemplateOutput.LaunchTemplateVersions[0], nil
}
// GetScalingGroupName retrieves the name of a scaling group.
// This keeps the casing of the original name, but Kubernetes requires the name to be lowercase,
// so use strings.ToLower() on the result if using the name in a Kubernetes context.
func (c *Client) GetScalingGroupName(scalingGroupID string) (string, error) {
return strings.ToLower(scalingGroupID), nil
}
// GetAutoscalingGroupName retrieves the name of a scaling group as needed by the cluster-autoscaler.
func (c *Client) GetAutoscalingGroupName(scalingGroupID string) (string, error) {
return scalingGroupID, nil
}
// ListScalingGroups retrieves a list of scaling groups for the cluster.
func (c *Client) ListScalingGroups(ctx context.Context, uid string) (controlPlaneGroupIDs []string, workerGroupIDs []string, err error) {
output, err := c.scalingClient.DescribeAutoScalingGroups(
ctx,
&autoscaling.DescribeAutoScalingGroupsInput{
Filters: []scalingtypes.Filter{
{
Name: toPtr("tag:constellation-uid"),
Values: []string{uid},
},
},
},
)
if err != nil {
return nil, nil, fmt.Errorf("failed to describe scaling groups: %w", err)
}
for _, group := range output.AutoScalingGroups {
if group.Tags == nil {
continue
}
for _, tag := range group.Tags {
if *tag.Key == "constellation-role" {
if *tag.Value == "control-plane" {
controlPlaneGroupIDs = append(controlPlaneGroupIDs, *group.AutoScalingGroupName)
} else if *tag.Value == "worker" {
workerGroupIDs = append(workerGroupIDs, *group.AutoScalingGroupName)
}
}
}
}
return controlPlaneGroupIDs, workerGroupIDs, nil
}

View file

@ -0,0 +1,306 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package client
import (
"context"
"testing"
"github.com/aws/aws-sdk-go-v2/service/autoscaling"
scalingtypes "github.com/aws/aws-sdk-go-v2/service/autoscaling/types"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGetScalingGroupImage(t *testing.T) {
testCases := map[string]struct {
providerID string
describeAutoScalingGroupsOut *autoscaling.DescribeAutoScalingGroupsOutput
describeAutoScalingGroupsErr error
describeLaunchTemplateVersionsOut *ec2.DescribeLaunchTemplateVersionsOutput
describeLaunchTemplateVersionsErr error
wantImage string
wantErr bool
}{
"getting scaling group image works": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoScalingGroupsOut: &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: []scalingtypes.AutoScalingGroup{
{
LaunchTemplate: &scalingtypes.LaunchTemplateSpecification{
LaunchTemplateId: toPtr("lt-00000000000000000"),
},
},
},
},
describeLaunchTemplateVersionsOut: &ec2.DescribeLaunchTemplateVersionsOutput{
LaunchTemplateVersions: []ec2types.LaunchTemplateVersion{
{
LaunchTemplateData: &ec2types.ResponseLaunchTemplateData{
ImageId: toPtr("ami-00000000000000000"),
},
},
},
},
wantImage: "ami-00000000000000000",
},
"fails when describing autoscaling group fails": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoScalingGroupsErr: assert.AnError,
wantErr: true,
},
"fails when describing launch template versions fails": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoScalingGroupsOut: &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: []scalingtypes.AutoScalingGroup{
{
LaunchTemplate: &scalingtypes.LaunchTemplateSpecification{
LaunchTemplateId: toPtr("lt-00000000000000000"),
},
},
},
},
describeLaunchTemplateVersionsErr: assert.AnError,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{
ec2Client: &stubEC2API{
describeLaunchTemplateVersionsOut: tc.describeLaunchTemplateVersionsOut,
describeLaunchTemplateVersionsErr: tc.describeLaunchTemplateVersionsErr,
},
scalingClient: &stubAutoscalingAPI{
describeAutoScalingGroupsOut: []*autoscaling.DescribeAutoScalingGroupsOutput{
tc.describeAutoScalingGroupsOut,
},
describeAutoScalingGroupsErr: []error{
tc.describeAutoScalingGroupsErr,
},
},
}
scalingGroupImage, err := client.GetScalingGroupImage(context.Background(), tc.providerID)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantImage, scalingGroupImage)
})
}
}
func TestSetScalingGroupImage(t *testing.T) {
testCases := map[string]struct {
providerID string
describeAutoScalingGroupsOut *autoscaling.DescribeAutoScalingGroupsOutput
describeAutoScalingGroupsErr error
describeLaunchTemplateVersionsOut *ec2.DescribeLaunchTemplateVersionsOutput
describeLaunchTemplateVersionsErr error
createLaunchTemplateVersionOut *ec2.CreateLaunchTemplateVersionOutput
createLaunchTemplateVersionErr error
modifyLaunchTemplateErr error
imageURI string
wantErr bool
}{
"getting scaling group image works": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoScalingGroupsOut: &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: []scalingtypes.AutoScalingGroup{
{
LaunchTemplate: &scalingtypes.LaunchTemplateSpecification{
LaunchTemplateId: toPtr("lt-00000000000000000"),
},
},
},
},
describeLaunchTemplateVersionsOut: &ec2.DescribeLaunchTemplateVersionsOutput{
LaunchTemplateVersions: []ec2types.LaunchTemplateVersion{
{
LaunchTemplateData: &ec2types.ResponseLaunchTemplateData{
ImageId: toPtr("ami-00000000000000000"),
},
VersionNumber: toPtr(int64(1)),
},
},
},
createLaunchTemplateVersionOut: &ec2.CreateLaunchTemplateVersionOutput{
LaunchTemplateVersion: &ec2types.LaunchTemplateVersion{
VersionNumber: toPtr(int64(2)),
},
},
imageURI: "ami-00000000000000000",
},
"fails when creating launch template version fails": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoScalingGroupsOut: &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: []scalingtypes.AutoScalingGroup{
{
LaunchTemplate: &scalingtypes.LaunchTemplateSpecification{
LaunchTemplateId: toPtr("lt-00000000000000000"),
},
},
},
},
describeLaunchTemplateVersionsOut: &ec2.DescribeLaunchTemplateVersionsOutput{
LaunchTemplateVersions: []ec2types.LaunchTemplateVersion{
{
LaunchTemplateData: &ec2types.ResponseLaunchTemplateData{
ImageId: toPtr("ami-00000000000000000"),
},
VersionNumber: toPtr(int64(1)),
},
},
},
imageURI: "ami-00000000000000000",
createLaunchTemplateVersionErr: assert.AnError,
wantErr: true,
},
"fails when modifying launch template fails": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoScalingGroupsOut: &autoscaling.DescribeAutoScalingGroupsOutput{
AutoScalingGroups: []scalingtypes.AutoScalingGroup{
{
LaunchTemplate: &scalingtypes.LaunchTemplateSpecification{
LaunchTemplateId: toPtr("lt-00000000000000000"),
},
},
},
},
describeLaunchTemplateVersionsOut: &ec2.DescribeLaunchTemplateVersionsOutput{
LaunchTemplateVersions: []ec2types.LaunchTemplateVersion{
{
LaunchTemplateData: &ec2types.ResponseLaunchTemplateData{
ImageId: toPtr("ami-00000000000000000"),
},
VersionNumber: toPtr(int64(1)),
},
},
},
imageURI: "ami-00000000000000000",
modifyLaunchTemplateErr: assert.AnError,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{
ec2Client: &stubEC2API{
describeLaunchTemplateVersionsOut: tc.describeLaunchTemplateVersionsOut,
describeLaunchTemplateVersionsErr: tc.describeLaunchTemplateVersionsErr,
createLaunchTemplateVersionOut: tc.createLaunchTemplateVersionOut,
createLaunchTemplateVersionErr: tc.createLaunchTemplateVersionErr,
modifyLaunchTemplateErr: tc.modifyLaunchTemplateErr,
},
scalingClient: &stubAutoscalingAPI{
describeAutoScalingGroupsOut: []*autoscaling.DescribeAutoScalingGroupsOutput{
tc.describeAutoScalingGroupsOut,
},
describeAutoScalingGroupsErr: []error{
tc.describeAutoScalingGroupsErr,
},
},
}
err := client.SetScalingGroupImage(context.Background(), tc.providerID, tc.imageURI)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
})
}
}
func TestListScalingGroups(t *testing.T) {
testCases := map[string]struct {
providerID string
describeAutoScalingGroupsOut []*autoscaling.DescribeAutoScalingGroupsOutput
describeAutoScalingGroupsErr []error
wantControlPlaneGroupIDs []string
wantWorkerGroupIDs []string
wantErr bool
}{
"listing scaling groups work": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoScalingGroupsOut: []*autoscaling.DescribeAutoScalingGroupsOutput{
{
AutoScalingGroups: []scalingtypes.AutoScalingGroup{
{
AutoScalingGroupName: toPtr("control-plane-asg"),
Tags: []scalingtypes.TagDescription{
{
Key: toPtr("constellation-role"),
Value: toPtr("control-plane"),
},
},
},
{
AutoScalingGroupName: toPtr("worker-asg"),
Tags: []scalingtypes.TagDescription{
{
Key: toPtr("constellation-role"),
Value: toPtr("worker"),
},
},
},
{
AutoScalingGroupName: toPtr("worker-asg-2"),
Tags: []scalingtypes.TagDescription{
{
Key: toPtr("constellation-role"),
Value: toPtr("worker"),
},
},
},
{
AutoScalingGroupName: toPtr("other-asg"),
},
},
},
},
describeAutoScalingGroupsErr: []error{nil},
wantControlPlaneGroupIDs: []string{"control-plane-asg"},
wantWorkerGroupIDs: []string{"worker-asg", "worker-asg-2"},
},
"fails when describing scaling groups fails": {
providerID: "aws:///us-east-2a/i-00000000000000000",
describeAutoScalingGroupsOut: []*autoscaling.DescribeAutoScalingGroupsOutput{nil},
describeAutoScalingGroupsErr: []error{assert.AnError},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := Client{
scalingClient: &stubAutoscalingAPI{
describeAutoScalingGroupsOut: tc.describeAutoScalingGroupsOut,
describeAutoScalingGroupsErr: tc.describeAutoScalingGroupsErr,
},
}
controlPlaneGroupIDs, workerGroupIDs, err := client.ListScalingGroups(context.Background(), tc.providerID)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantControlPlaneGroupIDs, controlPlaneGroupIDs)
assert.Equal(tc.wantWorkerGroupIDs, workerGroupIDs)
})
}
}