constellation/cli/ec2/client/instances.go
Leonard Cohnen 2d8fcd9bf4 monorepo
Co-authored-by: Malte Poll <mp@edgeless.systems>
Co-authored-by: katexochen <katexochen@users.noreply.github.com>
Co-authored-by: Daniel Weiße <dw@edgeless.systems>
Co-authored-by: Thomas Tendyck <tt@edgeless.systems>
Co-authored-by: Benedict Schlueter <bs@edgeless.systems>
Co-authored-by: leongross <leon.gross@rub.de>
Co-authored-by: Moritz Eckert <m1gh7ym0@gmail.com>
2022-03-22 16:09:39 +01:00

200 lines
5.8 KiB
Go

package client
import (
"context"
"errors"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/edgelesssys/constellation/cli/ec2"
)
// CreateInstances launches the instances described by input and records them
// in the client.
//
// The client must already have a security group set; it is used as the only
// security group of the launched instances. After the launch the call waits
// until the client's instances are running, tags them with input.Tags and
// stores their IP addresses.
func (c *Client) CreateInstances(ctx context.Context, input CreateInput) error {
	if c.securityGroup == "" {
		return errors.New("no security group set")
	}
	input.securityGroupIds = []string{c.securityGroup}

	// Check the privileges with a dry run before actually launching anything.
	if err := c.createDryRun(ctx, input); err != nil {
		return err
	}

	out, err := c.api.RunInstances(ctx, input.AWS())
	if err != nil {
		return fmt.Errorf("failed to create instances: %w", err)
	}

	// Register every launched instance; the IPs are filled in further below.
	for i := range out.Instances {
		instanceID := out.Instances[i].InstanceId
		if instanceID == nil {
			return errors.New("instanceId is nil pointer")
		}
		c.instances[*instanceID] = ec2.Instance{}
	}

	if err := c.waitStateRunning(ctx); err != nil {
		return err
	}
	if err := c.tagInstances(ctx, input.Tags); err != nil {
		return err
	}
	return c.getInstanceIPs(ctx)
}
// TerminateInstances terminates all instances tracked by the client.
//
// It returns nil without doing anything when the client has no instances.
// Otherwise it performs a dry run to check the privileges, requests the
// termination, waits until the instances reached the terminated state and
// finally resets the client's instance set. If any step fails, the error is
// returned and the instance set is left unchanged.
func (c *Client) TerminateInstances(ctx context.Context) error {
	if len(c.instances) == 0 {
		return nil
	}

	input := &awsec2.TerminateInstancesInput{
		InstanceIds: c.instances.IDs(),
	}
	// The dry run gets a copy of the input, so input itself stays a real request.
	if err := c.terminateDryRun(ctx, *input); err != nil {
		return err
	}
	if _, err := c.api.TerminateInstances(ctx, input); err != nil {
		// Add context to the API error, as the create and tag paths do.
		return fmt.Errorf("failed to terminate instances: %w", err)
	}
	if err := c.waitStateTerminated(ctx); err != nil {
		return err
	}

	c.instances = ec2.Instances{}
	return nil
}
// waitStateRunning blocks until all instances of the client reached the
// running state. The client's timeout is passed to the waiter as the wait
// duration limit.
//
// A set of instances also counts as running when at least one instance is in
// state 'running' and all remaining instances have a nil state.
//
// An error is returned if the client has no instances.
func (c *Client) waitStateRunning(ctx context.Context) error {
	if len(c.instances) == 0 {
		return errors.New("client has no instances")
	}

	instanceIDs := c.instances.IDs()
	return awsec2.NewInstanceRunningWaiter(c.api).Wait(
		ctx,
		&awsec2.DescribeInstancesInput{InstanceIds: instanceIDs},
		c.timeout,
	)
}
// waitStateTerminated blocks until all instances of the client reached the
// terminated state. The client's timeout is passed to the waiter as the wait
// duration limit.
//
// A set of instances also counts as terminated when at least one instance is
// in state 'terminated' and all remaining instances have a nil state.
//
// An error is returned if the client has no instances.
func (c *Client) waitStateTerminated(ctx context.Context) error {
	if len(c.instances) == 0 {
		return errors.New("client has no instances")
	}

	instanceIDs := c.instances.IDs()
	return awsec2.NewInstanceTerminatedWaiter(c.api).Wait(
		ctx,
		&awsec2.DescribeInstancesInput{InstanceIds: instanceIDs},
		c.timeout,
	)
}
// tagInstances applies the given tags to every instance of the client.
//
// An error is returned if the client has no instances or if the CreateTags
// call fails.
func (c *Client) tagInstances(ctx context.Context, tags ec2.Tags) error {
	if len(c.instances) == 0 {
		return errors.New("client has no instances")
	}

	_, err := c.api.CreateTags(ctx, &awsec2.CreateTagsInput{
		Resources: c.instances.IDs(),
		Tags:      tags.AWS(),
	})
	if err != nil {
		return fmt.Errorf("failed to tag instances: %w", err)
	}
	return nil
}
// createDryRun checks whether the user is privileged to create the instances
// defined in input.
//
// It sends the same request CreateInstances would send, but with the DryRun
// flag set, and passes the resulting error to checkDryRunError.
func (c *Client) createDryRun(ctx context.Context, input CreateInput) error {
	request := input.AWS()
	request.DryRun = aws.Bool(true)

	_, dryRunErr := c.api.RunInstances(ctx, request)
	return checkDryRunError(dryRunErr)
}
// terminateDryRun checks whether the user is privileged to terminate the
// instances defined in input.
//
// input is received by value, so setting the DryRun flag here does not alter
// the caller's request. The resulting error is passed to checkDryRunError.
func (c *Client) terminateDryRun(ctx context.Context, input awsec2.TerminateInstancesInput) error {
	request := &input
	request.DryRun = aws.Bool(true)

	_, dryRunErr := c.api.TerminateInstances(ctx, request)
	return checkDryRunError(dryRunErr)
}
// getInstanceIPs queries the private and public IP addresses of the client's
// instances and stores them in the client's instance set.
//
// The instances must be in 'running' state.
//
// An error is returned if the client has no instances, if a page of the
// description cannot be fetched, if a described instance lacks an ID, a
// public IP or a private IP, or if a described instance is not known to the
// client.
func (c *Client) getInstanceIPs(ctx context.Context) error {
	// Without instance IDs the describe request is not restricted to the
	// client's instances, so refuse to query, like the sibling helpers do.
	if len(c.instances) == 0 {
		return errors.New("client has no instances")
	}

	describeInput := &awsec2.DescribeInstancesInput{
		InstanceIds: c.instances.IDs(),
	}
	paginator := awsec2.NewDescribeInstancesPaginator(c.api, describeInput)
	for paginator.HasMorePages() {
		output, err := paginator.NextPage(ctx)
		if err != nil {
			return err
		}
		for _, reservation := range output.Reservations {
			for _, instanceDescription := range reservation.Instances {
				if instanceDescription.InstanceId == nil {
					return errors.New("instanceId is nil pointer")
				}
				if instanceDescription.PublicIpAddress == nil {
					return errors.New("publicIpAddress is nil pointer")
				}
				if instanceDescription.PrivateIpAddress == nil {
					return errors.New("privateIpAddress is nil pointer")
				}

				id := *instanceDescription.InstanceId
				instance, ok := c.instances[id]
				if !ok {
					return errors.New("got an instance description to an unknown instanceId")
				}
				// The map stores values, so write the updated copy back.
				instance.PublicIP = *instanceDescription.PublicIpAddress
				instance.PrivateIP = *instanceDescription.PrivateIpAddress
				c.instances[id] = instance
			}
		}
	}
	return nil
}
// CreateInput defines the properties of the instances to create.
type CreateInput struct {
// ImageId is sent as the image ID of the run-instances request.
ImageId string
// InstanceType is the key used to look up the instance type in ec2.InstanceTypes.
InstanceType string
// Count is used as both the minimum and the maximum number of instances to launch.
Count int
// Tags are applied to the instances by Client.CreateInstances after the launch.
Tags ec2.Tags
// securityGroupIds is set by Client.CreateInstances to the client's security group.
securityGroupIds []string
}
// AWS converts the CreateInput into an AWS ec2.RunInstancesInput.
//
// Count is used for both MinCount and MaxCount, the instance type is looked up
// in ec2.InstanceTypes, and enclave support is always enabled.
func (ci *CreateInput) AWS() *awsec2.RunInstancesInput {
	count := int32(ci.Count)

	runInput := &awsec2.RunInstancesInput{}
	runInput.ImageId = aws.String(ci.ImageId)
	runInput.InstanceType = ec2.InstanceTypes[ci.InstanceType]
	runInput.MaxCount = aws.Int32(count)
	runInput.MinCount = aws.Int32(count)
	runInput.EnclaveOptions = &types.EnclaveOptionsRequest{Enabled: aws.Bool(true)}
	runInput.SecurityGroupIds = ci.securityGroupIds
	return runInput
}