diff --git a/cli/internal/cmd/init.go b/cli/internal/cmd/init.go index ac2944880..d81f43500 100644 --- a/cli/internal/cmd/init.go +++ b/cli/internal/cmd/init.go @@ -1,7 +1,6 @@ package cmd import ( - "context" "encoding/base64" "errors" "fmt" @@ -62,14 +61,12 @@ func runInitialize(cmd *cobra.Command, args []string) error { protoClient := &proto.Client{} defer protoClient.Close() - // We have to parse the context separately, since cmd.Context() - // returns nil during the tests otherwise. - return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler) + return initialize(cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler) } // initialize initializes a Constellation. Coordinator instances are activated as contole-plane nodes and will // themself activate the other peers as workers. -func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator, +func initialize(cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator, fileHandler file.Handler, waiter statusWaiter, vpnHandler vpnHandler, ) error { flags, err := evalFlagArgs(cmd, fileHandler) @@ -107,7 +104,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser cmd.Print(validators.WarningsIncludeInit()) cmd.Println("Creating service account ...") - serviceAccount, stat, err := serviceAccCreator.Create(ctx, stat, config) + serviceAccount, stat, err := serviceAccCreator.Create(cmd.Context(), stat, config) if err != nil { return err } @@ -126,7 +123,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser if err := waiter.InitializeValidators(validators.V()); err != nil { return err } - if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil { + if err := waiter.WaitForAll(cmd.Context(), endpoints, coordinatorstate.AcceptingInit); err != nil { return fmt.Errorf("waiting for all peers status: %w", err) } @@ -145,7 +142,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser cloudServiceAccountURI: serviceAccount, sshUserKeys: ssh.ToProtoSlice(sshUsers), } - result, err := activate(ctx, cmd, protCl, input, validators.V()) + result, err := activate(cmd, protCl, input, validators.V()) if err != nil { return err } @@ -173,7 +170,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser return nil } -func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, +func activate(cmd *cobra.Command, client protoClient, input activationInput, validators []atls.Validator, ) (activationResult, error) { err := client.Connect(net.JoinHostPort(input.coordinatorPubIP, strconv.Itoa(constants.CoordinatorPort)), validators) @@ -181,7 +178,7 @@ func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input return activationResult{}, err } - respCl, err := client.Activate(ctx, input.pubKey, input.masterSecret, input.nodePrivIPs, input.coordinatorPrivIPs, input.autoscalingNodeGroups, input.cloudServiceAccountURI, input.sshUserKeys) + respCl, err := client.Activate(cmd.Context(), input.pubKey, input.masterSecret, input.nodePrivIPs, input.coordinatorPrivIPs, input.autoscalingNodeGroups, input.cloudServiceAccountURI, input.sshUserKeys) if err != nil { return activationResult{}, err } diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index 0172aac7b..9152324e5 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -305,8 +305,9 @@ func TestInitialize(t *testing.T) { ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, 4*time.Second) defer cancel() + cmd.SetContext(ctx) - err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler) + err := initialize(cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler) if tc.wantErr { assert.Error(err) @@ -608,9 +609,8 @@ func TestAutoscaleFlag(t *testing.T) { require.NoError(cmd.Flags().Set("privatekey", "privK")) require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag))) - ctx := context.Background() - require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler)) + require.NoError(initialize(cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler)) if tc.autoscaleFlag { assert.Len(tc.client.activateAutoscalingNodeGroups, 1) } else { diff --git a/cli/internal/cmd/recover.go b/cli/internal/cmd/recover.go index 8e20192f2..2d7738ab0 100644 --- a/cli/internal/cmd/recover.go +++ b/cli/internal/cmd/recover.go @@ -1,7 +1,6 @@ package cmd import ( - "context" "encoding/base64" "errors" "fmt" @@ -43,10 +42,10 @@ func runRecover(cmd *cobra.Command, args []string) error { fileHandler := file.NewHandler(afero.NewOsFs()) recoveryClient := &proto.KeyClient{} defer recoveryClient.Close() - return recover(cmd.Context(), cmd, fileHandler, recoveryClient) + return recover(cmd, fileHandler, recoveryClient) } -func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler, recoveryClient recoveryClient) error { +func recover(cmd *cobra.Command, fileHandler file.Handler, recoveryClient recoveryClient) error { flags, err := parseRecoverFlags(cmd, fileHandler) if err != nil { return err @@ -79,7 +78,7 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler, return err } - if err := recoveryClient.PushStateDiskKey(ctx, diskKey); err != nil { + if err := recoveryClient.PushStateDiskKey(cmd.Context(), diskKey); err != nil { return err } diff --git a/cli/internal/cmd/recover_test.go b/cli/internal/cmd/recover_test.go index 50ce3706e..29df499a1 100644 --- a/cli/internal/cmd/recover_test.go +++ b/cli/internal/cmd/recover_test.go @@ -2,7 +2,6 @@ package cmd import ( "bytes" - "context" "errors" "testing" @@ -182,8 +181,7 @@ func TestRecover(t *testing.T) { require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone)) } - ctx := context.Background() - err := recover(ctx, cmd, fileHandler, tc.client) + err := recover(cmd, fileHandler, tc.client) if tc.wantErr { assert.Error(err) diff --git a/cli/internal/cmd/verify.go b/cli/internal/cmd/verify.go index 6d5fee217..bc779b5b7 100644 --- a/cli/internal/cmd/verify.go +++ b/cli/internal/cmd/verify.go @@ -1,7 +1,6 @@ package cmd import ( - "context" "errors" "fmt" @@ -40,10 +39,10 @@ func runVerify(cmd *cobra.Command, args []string) error { fileHandler := file.NewHandler(afero.NewOsFs()) protoClient := &proto.Client{} defer protoClient.Close() - return verify(cmd.Context(), cmd, provider, fileHandler, protoClient) + return verify(cmd, provider, fileHandler, protoClient) } -func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Provider, fileHandler file.Handler, protoClient protoClient) error { +func verify(cmd *cobra.Command, provider cloudprovider.Provider, fileHandler file.Handler, protoClient protoClient) error { flags, err := parseVerifyFlags(cmd) if err != nil { return err @@ -69,7 +68,7 @@ func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Prov if err := protoClient.Connect(flags.endpoint, validators.V()); err != nil { return err } - if _, err := protoClient.GetState(ctx); err != nil { + if _, err := protoClient.GetState(cmd.Context()); err != nil { if err, ok := rpcStatus.FromError(err); ok { return fmt.Errorf("verifying Constellation cluster: %s", err.Message()) } diff --git a/cli/internal/cmd/verify_test.go b/cli/internal/cmd/verify_test.go index 418e6640f..02ef083f5 100644 --- a/cli/internal/cmd/verify_test.go +++ b/cli/internal/cmd/verify_test.go @@ -2,7 +2,6 @@ package cmd import ( "bytes" - "context" "encoding/base64" "errors" "testing" @@ -151,8 +150,7 @@ func TestVerify(t *testing.T) { } fileHandler := file.NewHandler(tc.setupFs(require)) - ctx := context.Background() - err := verify(ctx, cmd, tc.provider, fileHandler, tc.protoClient) + err := verify(cmd, tc.provider, fileHandler, tc.protoClient) if tc.wantErr { assert.Error(err)