diff --git a/cli/internal/azure/client/api.go b/cli/internal/azure/client/api.go index f5fdf6de4..97b99154a 100644 --- a/cli/internal/azure/client/api.go +++ b/cli/internal/azure/client/api.go @@ -36,6 +36,9 @@ type loadBalancersAPI interface { } type scaleSetsAPI interface { + Get(ctx context.Context, resourceGroupName string, vmScaleSetName string, + options *armcomputev2.VirtualMachineScaleSetsClientGetOptions, + ) (armcomputev2.VirtualMachineScaleSetsClientGetResponse, error) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, vmScaleSetName string, parameters armcomputev2.VirtualMachineScaleSet, options *armcomputev2.VirtualMachineScaleSetsClientBeginCreateOrUpdateOptions) ( diff --git a/cli/internal/azure/client/api_test.go b/cli/internal/azure/client/api_test.go index 88f5c4374..9ed03c7fa 100644 --- a/cli/internal/azure/client/api_test.go +++ b/cli/internal/azure/client/api_test.go @@ -136,6 +136,16 @@ type stubScaleSetsAPI struct { createErr error stubResponse armcomputev2.VirtualMachineScaleSetsClientCreateOrUpdateResponse pollErr error + getResponse armcomputev2.VirtualMachineScaleSet + getErr error +} + +func (a stubScaleSetsAPI) Get(ctx context.Context, resourceGroupName string, vmScaleSetName string, + options *armcomputev2.VirtualMachineScaleSetsClientGetOptions, +) (armcomputev2.VirtualMachineScaleSetsClientGetResponse, error) { + return armcomputev2.VirtualMachineScaleSetsClientGetResponse{ + VirtualMachineScaleSet: a.getResponse, + }, a.getErr } func (a stubScaleSetsAPI) BeginCreateOrUpdate(ctx context.Context, resourceGroupName string, diff --git a/cli/internal/azure/client/compute.go b/cli/internal/azure/client/compute.go index 428dd58bb..91cd67536 100644 --- a/cli/internal/azure/client/compute.go +++ b/cli/internal/azure/client/compute.go @@ -4,15 +4,22 @@ import ( "context" "errors" "fmt" + "net/http" "strconv" "sync" + "time" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources" "github.com/edgelesssys/constellation/cli/internal/azure" + "github.com/edgelesssys/constellation/cli/internal/azure/internal/poller" "github.com/edgelesssys/constellation/internal/cloud/cloudtypes" ) +// scaleSetCreateTimeout maximum timeout to wait for scale set creation. +const scaleSetCreateTimeout = 5 * time.Minute + func (c *Client) CreateInstances(ctx context.Context, input CreateInstancesInput) error { // Create worker scale set createWorkerInput := CreateScaleSetInput{ @@ -211,7 +218,7 @@ func (c *Client) createScaleSet(ctx context.Context, input CreateScaleSetInput) LoadBalancerBackendAddressPool: input.LoadBalancerBackendAddressPool, }.Azure() - poller, err := c.scaleSetsAPI.BeginCreateOrUpdate( + _, err = c.scaleSetsAPI.BeginCreateOrUpdate( ctx, c.resourceGroup, input.Name, scaleSet, nil, @@ -220,14 +227,18 @@ func (c *Client) createScaleSet(ctx context.Context, input CreateScaleSetInput) return err } - _, err = poller.PollUntilDone(ctx, &runtime.PollUntilDoneOptions{ - Frequency: c.pollFrequency, + // use custom poller to wait for resource creation but skip waiting for OS provisioning. + // OS provisioning does not work reliably without the azure guest agent installed. + poller := poller.New[bool](&scaleSetCreationPollingHandler{ + resourceGroup: c.resourceGroup, + scaleSet: input.Name, + scaleSetsAPI: c.scaleSetsAPI, }) - if err != nil { - return err - } - return nil + pollCtx, cancel := context.WithTimeout(ctx, scaleSetCreateTimeout) + defer cancel() + _, err = poller.PollUntilDone(pollCtx, nil) + return err } func (c *Client) getInstanceIPs(ctx context.Context, scaleSet string, count int) (cloudtypes.Instances, error) { @@ -324,3 +335,41 @@ func (c *Client) TerminateResourceGroup(ctx context.Context) error { c.controlPlaneScaleSet = "" return nil } + +// scaleSetCreationPollingHandler is a custom poller used to check if a scale set was created successfully. +type scaleSetCreationPollingHandler struct { + done bool + resourceGroup string + scaleSet string + scaleSetsAPI scaleSetsAPI +} + +// Done returns true if the condition is met. +func (h *scaleSetCreationPollingHandler) Done() bool { + return h.done +} + +// Poll checks if the scale set resource was created successfully. +func (h *scaleSetCreationPollingHandler) Poll(ctx context.Context) error { + _, err := h.scaleSetsAPI.Get(ctx, h.resourceGroup, h.scaleSet, nil) + if err == nil { + h.done = true + return nil + } + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + // resource does not exist yet - retry later + return nil + } + return err +} + +// Result returns the result of the poller if the condition is met. +// If the condition is not met, an error is returned. +func (h *scaleSetCreationPollingHandler) Result(ctx context.Context, out *bool) error { + if !h.done { + return fmt.Errorf("failed to create scale set") + } + *out = h.done + return nil +} diff --git a/cli/internal/azure/client/compute_test.go b/cli/internal/azure/client/compute_test.go index 3cfcbdf33..bfee72753 100644 --- a/cli/internal/azure/client/compute_test.go +++ b/cli/internal/azure/client/compute_test.go @@ -167,7 +167,7 @@ func TestCreateInstances(t *testing.T) { "error when polling create scale set response": { publicIPAddressesAPI: stubPublicIPAddressesAPI{}, networkInterfacesAPI: stubNetworkInterfacesAPI{}, - scaleSetsAPI: stubScaleSetsAPI{pollErr: someErr}, + scaleSetsAPI: stubScaleSetsAPI{getErr: someErr}, resourceGroupAPI: newSuccessfulResourceGroupStub(), roleAssignmentsAPI: &stubRoleAssignmentsAPI{}, createInstancesInput: CreateInstancesInput{ diff --git a/cli/internal/azure/internal/poller/poller.go b/cli/internal/azure/internal/poller/poller.go new file mode 100644 index 000000000..87ddcae0d --- /dev/null +++ b/cli/internal/azure/internal/poller/poller.go @@ -0,0 +1,121 @@ +// Package poller implements a poller that can be used to wait for a condition to be met. +// The poller is designed to be a replacement for the azure-sdk-for-go poller +// with exponential backoff and an injectable clock. +// reference: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azcore@v1.1.1/runtime#Poller . +package poller + +import ( + "context" + "errors" + "time" + + "k8s.io/utils/clock" +) + +// PollUntilDoneOptions provides options for the Poller. +// Used to specify backoff and clock options. +type PollUntilDoneOptions struct { + StartingBackoff time.Duration + MaxBackoff time.Duration + clock.Clock +} + +// NewPollUntilDoneOptions creates a new PollUntilDoneOptions with the default values and a real clock. +func NewPollUntilDoneOptions() *PollUntilDoneOptions { + return &PollUntilDoneOptions{ + Clock: clock.RealClock{}, + } +} + +// Poller is a poller that can be used to wait for a condition to be met. +// The poller is designed to be a replacement for the azure-sdk-for-go poller +// with exponential backoff and an injectable clock. +// reference: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azcore@v1.1.1/runtime#Poller . +type Poller[T any] struct { + handler PollingHandler[T] + err error + result *T + done bool +} + +// New creates a new Poller. +func New[T any](handler PollingHandler[T]) *Poller[T] { + return &Poller[T]{ + handler: handler, + result: new(T), + } +} + +// PollUntilDone polls the handler until the condition is met or the context is cancelled. +func (p *Poller[T]) PollUntilDone(ctx context.Context, options *PollUntilDoneOptions) (T, error) { + if options == nil { + options = NewPollUntilDoneOptions() + } + if options.MaxBackoff == 0 { + options.MaxBackoff = time.Minute + } + if options.StartingBackoff < time.Second { + options.StartingBackoff = time.Second + } + backoff := options.StartingBackoff + for { + timer := options.Clock.NewTimer(backoff) + err := p.Poll(ctx) + if err != nil { + return *new(T), err + } + if p.Done() { + return p.Result(ctx) + } + select { + case <-ctx.Done(): + return *new(T), ctx.Err() + case <-timer.C(): + } + if backoff >= options.MaxBackoff/2 { + backoff = options.MaxBackoff + } else { + backoff *= 2 + } + } +} + +// Poll polls the handler. +func (p *Poller[T]) Poll(ctx context.Context) error { + return p.handler.Poll(ctx) +} + +// Done returns true if the condition is met. +func (p *Poller[T]) Done() bool { + return p.handler.Done() +} + +// Result returns the result of the poller if the condition is met. +// If the condition is not met, an error is returned. +func (p *Poller[T]) Result(ctx context.Context) (T, error) { + if !p.Done() { + return *new(T), errors.New("poller is in a non-terminal state") + } + if p.done { + // the result has already been retrieved, return the cached value + if p.err != nil { + return *new(T), p.err + } + return *p.result, nil + } + err := p.handler.Result(ctx, p.result) + p.done = true + if err != nil { + p.err = err + return *new(T), p.err + } + + return *p.result, nil +} + +// PollingHandler is a handler that can be used to poll for a condition to be met. +type PollingHandler[T any] interface { + Done() bool + Poll(ctx context.Context) error + Result(ctx context.Context, out *T) error +} diff --git a/cli/internal/azure/internal/poller/poller_test.go b/cli/internal/azure/internal/poller/poller_test.go new file mode 100644 index 000000000..96f686688 --- /dev/null +++ b/cli/internal/azure/internal/poller/poller_test.go @@ -0,0 +1,205 @@ +package poller + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/stretchr/testify/assert" + testclock "k8s.io/utils/clock/testing" +) + +func TestResult(t *testing.T) { + testCases := map[string]struct { + done bool + pollErr error + resultErr error + wantErr bool + wantResult int + }{ + "result called before poller is done": { + wantErr: true, + }, + "result returns error": { + done: true, + resultErr: errors.New("result error"), + wantErr: true, + }, + "result succeeds": { + done: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + poller := New[int](&stubPoller[int]{ + result: &tc.wantResult, + done: tc.done, + pollErr: tc.pollErr, + resultErr: tc.resultErr, + }) + _, firstErr := poller.Result(context.Background()) + if tc.wantErr { + assert.Error(firstErr) + // calling Result again should return the same error + _, secondErr := poller.Result(context.Background()) + assert.Equal(firstErr, secondErr) + return + } + assert.NoError(firstErr) + // calling Result again should still not return an error + _, secondErr := poller.Result(context.Background()) + assert.NoError(secondErr) + }) + } +} + +func TestPollUntilDone(t *testing.T) { + testCases := map[string]struct { + messages []message + maxBackoff time.Duration + resultErr error + wantErr bool + wantResult int + }{ + "poll succeeds on first try": { + messages: []message{ + {pollErr: to.Ptr[error](nil), done: to.Ptr(true)}, + {done: to.Ptr(true)}, // Result() will call Done() after the last poll + }, + wantResult: 1, + }, + "poll succeeds on fourth try": { + messages: []message{ + {pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: time.Second}, + {pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: 2 * time.Second}, + {pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: 4 * time.Second}, + {pollErr: to.Ptr[error](nil), done: to.Ptr(true)}, + {done: to.Ptr(true)}, // Result() will call Done() after the last poll + }, + wantResult: 1, + }, + "max backoff reached": { + messages: []message{ + {pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: time.Second}, + {pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: time.Second}, + {pollErr: to.Ptr[error](nil), done: to.Ptr(false), backoff: time.Second}, + {pollErr: to.Ptr[error](nil), done: to.Ptr(true)}, + {done: to.Ptr(true)}, // Result() will call Done() after the last poll + }, + maxBackoff: time.Second, + wantResult: 1, + }, + "poll errors": { + messages: []message{ + {pollErr: to.Ptr(errors.New("poll error"))}, + }, + wantErr: true, + }, + "result errors": { + messages: []message{ + {pollErr: to.Ptr[error](nil), done: to.Ptr(true)}, + {done: to.Ptr(true)}, // Result() will call Done() after the last poll + }, + resultErr: errors.New("result error"), + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + doneC := make(chan bool) + pollC := make(chan error) + poller := New[int](&fakePoller[int]{ + result: &tc.wantResult, + resultErr: tc.resultErr, + doneC: doneC, + pollC: pollC, + }) + clock := testclock.NewFakeClock(time.Now()) + var gotResult int + var gotErr error + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + gotResult, gotErr = poller.PollUntilDone(context.Background(), &PollUntilDoneOptions{ + MaxBackoff: tc.maxBackoff, + Clock: clock, + }) + }() + + for _, msg := range tc.messages { + if msg.pollErr != nil { + pollC <- *msg.pollErr + } + if msg.done != nil { + doneC <- *msg.done + } + clock.Step(msg.backoff) + } + + wg.Wait() + if tc.wantErr { + assert.Error(gotErr) + return + } + assert.NoError(gotErr) + assert.Equal(tc.wantResult, gotResult) + }) + } +} + +type stubPoller[T any] struct { + result *T + done bool + pollErr error + resultErr error +} + +func (s *stubPoller[T]) Poll(ctx context.Context) error { + return s.pollErr +} + +func (s *stubPoller[T]) Done() bool { + return s.done +} + +func (s *stubPoller[T]) Result(ctx context.Context, out *T) error { + *out = *s.result + return s.resultErr +} + +type message struct { + pollErr *error + done *bool + backoff time.Duration +} + +type fakePoller[T any] struct { + result *T + resultErr error + + doneC chan bool + pollC chan error +} + +func (s *fakePoller[T]) Poll(ctx context.Context) error { + return <-s.pollC +} + +func (s *fakePoller[T]) Done() bool { + return <-s.doneC +} + +func (s *fakePoller[T]) Result(ctx context.Context, out *T) error { + *out = *s.result + return s.resultErr +}