Remove passing context separately to initialize

Nils Hanke 2022-06-28 11:19:03 +02:00 committed by Nils Hanke
parent 0653c20792
commit e3f78a5bff
6 changed files with 18 additions and 27 deletions
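The change drops the explicit context.Context parameter from the CLI command implementations and reads the context from the cobra command instead: production code gets it via cmd.Context() (populated by ExecuteContext), and tests seed it with cmd.SetContext. Below is a minimal, self-contained sketch of that pattern, not Constellation code; doWork and ping are made-up stand-ins for functions like initialize/recover/verify and their context-aware calls, while ExecuteContext, Context, SetContext, and RunE are real cobra API.

package main

import (
	"context"
	"fmt"

	"github.com/spf13/cobra"
)

// doWork stands in for a command implementation: instead of taking a ctx
// parameter, it reads the context from the command itself.
func doWork(cmd *cobra.Command) error {
	ctx := cmd.Context() // set by ExecuteContext in production, SetContext in tests
	return ping(ctx)
}

// ping is a placeholder for a context-aware call such as client.Activate
// or protoClient.GetState.
func ping(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
		fmt.Println("pong")
		return nil
	}
}

func main() {
	cmd := &cobra.Command{
		Use: "demo",
		RunE: func(cmd *cobra.Command, _ []string) error {
			return doWork(cmd)
		},
	}
	// ExecuteContext stores the context on the command, so cmd.Context()
	// is non-nil inside RunE and everything it calls.
	if err := cmd.ExecuteContext(context.Background()); err != nil {
		fmt.Println(err)
	}
}

Note that Command.SetContext is a fairly recent cobra addition (v1.5.0, to the best of my knowledge), so the test changes below rely on a cobra version that provides it.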

View File

@@ -1,7 +1,6 @@
 package cmd
 
 import (
-	"context"
 	"encoding/base64"
 	"errors"
 	"fmt"
@@ -62,14 +61,12 @@ func runInitialize(cmd *cobra.Command, args []string) error {
 	protoClient := &proto.Client{}
 	defer protoClient.Close()
 
-	// We have to parse the context separately, since cmd.Context()
-	// returns nil during the tests otherwise.
-	return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler)
+	return initialize(cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler)
 }
 
 // initialize initializes a Constellation. Coordinator instances are activated as contole-plane nodes and will
 // themself activate the other peers as workers.
-func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator,
+func initialize(cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator,
 	fileHandler file.Handler, waiter statusWaiter, vpnHandler vpnHandler,
 ) error {
 	flags, err := evalFlagArgs(cmd, fileHandler)
@@ -107,7 +104,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
 	cmd.Print(validators.WarningsIncludeInit())
 
 	cmd.Println("Creating service account ...")
-	serviceAccount, stat, err := serviceAccCreator.Create(ctx, stat, config)
+	serviceAccount, stat, err := serviceAccCreator.Create(cmd.Context(), stat, config)
 	if err != nil {
 		return err
 	}
@@ -126,7 +123,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
 	if err := waiter.InitializeValidators(validators.V()); err != nil {
 		return err
 	}
 
-	if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil {
+	if err := waiter.WaitForAll(cmd.Context(), endpoints, coordinatorstate.AcceptingInit); err != nil {
 		return fmt.Errorf("waiting for all peers status: %w", err)
 	}
@@ -145,7 +142,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
 		cloudServiceAccountURI: serviceAccount,
 		sshUserKeys:            ssh.ToProtoSlice(sshUsers),
 	}
-	result, err := activate(ctx, cmd, protCl, input, validators.V())
+	result, err := activate(cmd, protCl, input, validators.V())
 	if err != nil {
 		return err
 	}
@@ -173,7 +170,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
 	return nil
 }
 
-func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput,
+func activate(cmd *cobra.Command, client protoClient, input activationInput,
 	validators []atls.Validator,
 ) (activationResult, error) {
 	err := client.Connect(net.JoinHostPort(input.coordinatorPubIP, strconv.Itoa(constants.CoordinatorPort)), validators)
@@ -181,7 +178,7 @@ func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input
 		return activationResult{}, err
 	}
 
-	respCl, err := client.Activate(ctx, input.pubKey, input.masterSecret, input.nodePrivIPs, input.coordinatorPrivIPs, input.autoscalingNodeGroups, input.cloudServiceAccountURI, input.sshUserKeys)
+	respCl, err := client.Activate(cmd.Context(), input.pubKey, input.masterSecret, input.nodePrivIPs, input.coordinatorPrivIPs, input.autoscalingNodeGroups, input.cloudServiceAccountURI, input.sshUserKeys)
 	if err != nil {
 		return activationResult{}, err
 	}

View File

@@ -305,8 +305,9 @@ func TestInitialize(t *testing.T) {
 			ctx := context.Background()
 			ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
 			defer cancel()
 
-			err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler)
+			cmd.SetContext(ctx)
+			err := initialize(cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler)
 
 			if tc.wantErr {
 				assert.Error(err)
@@ -608,9 +609,8 @@ func TestAutoscaleFlag(t *testing.T) {
 			require.NoError(cmd.Flags().Set("privatekey", "privK"))
 			require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
 
-			ctx := context.Background()
-			require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler))
+			require.NoError(initialize(cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler))
 
 			if tc.autoscaleFlag {
 				assert.Len(tc.client.activateAutoscalingNodeGroups, 1)
 			} else {
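The test-side counterpart of the pattern: since the run functions are called directly rather than through Execute()/ExecuteContext(), cmd.Context() would be nil unless the test seeds it, which is exactly what the added cmd.SetContext(ctx) call does. A rough sketch reusing the illustrative doWork from the example above (assumes the same imports plus "testing" and "time"; the test name is made up):

func TestDoWorkWithContext(t *testing.T) {
	cmd := &cobra.Command{Use: "demo"}

	ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second)
	defer cancel()

	// Without this, cmd.Context() returns nil because the command is never
	// dispatched through Execute()/ExecuteContext().
	cmd.SetContext(ctx)

	if err := doWork(cmd); err != nil {
		t.Fatalf("doWork: %v", err)
	}
}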

View File

@@ -1,7 +1,6 @@
 package cmd
 
 import (
-	"context"
 	"encoding/base64"
 	"errors"
 	"fmt"
@@ -43,10 +42,10 @@ func runRecover(cmd *cobra.Command, args []string) error {
 	fileHandler := file.NewHandler(afero.NewOsFs())
 	recoveryClient := &proto.KeyClient{}
 	defer recoveryClient.Close()
-	return recover(cmd.Context(), cmd, fileHandler, recoveryClient)
+	return recover(cmd, fileHandler, recoveryClient)
 }
 
-func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler, recoveryClient recoveryClient) error {
+func recover(cmd *cobra.Command, fileHandler file.Handler, recoveryClient recoveryClient) error {
 	flags, err := parseRecoverFlags(cmd, fileHandler)
 	if err != nil {
 		return err
@@ -79,7 +78,7 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler,
 		return err
 	}
 
-	if err := recoveryClient.PushStateDiskKey(ctx, diskKey); err != nil {
+	if err := recoveryClient.PushStateDiskKey(cmd.Context(), diskKey); err != nil {
 		return err
 	}

View File

@@ -2,7 +2,6 @@ package cmd
 
 import (
 	"bytes"
-	"context"
 	"errors"
 	"testing"
@@ -182,8 +181,7 @@ func TestRecover(t *testing.T) {
 				require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
 			}
 
-			ctx := context.Background()
-			err := recover(ctx, cmd, fileHandler, tc.client)
+			err := recover(cmd, fileHandler, tc.client)
 
 			if tc.wantErr {
 				assert.Error(err)

View File

@@ -1,7 +1,6 @@
 package cmd
 
 import (
-	"context"
 	"errors"
 	"fmt"
@@ -40,10 +39,10 @@ func runVerify(cmd *cobra.Command, args []string) error {
 	fileHandler := file.NewHandler(afero.NewOsFs())
 	protoClient := &proto.Client{}
 	defer protoClient.Close()
-	return verify(cmd.Context(), cmd, provider, fileHandler, protoClient)
+	return verify(cmd, provider, fileHandler, protoClient)
 }
 
-func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Provider, fileHandler file.Handler, protoClient protoClient) error {
+func verify(cmd *cobra.Command, provider cloudprovider.Provider, fileHandler file.Handler, protoClient protoClient) error {
 	flags, err := parseVerifyFlags(cmd)
 	if err != nil {
 		return err
@@ -69,7 +68,7 @@ func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Prov
 	if err := protoClient.Connect(flags.endpoint, validators.V()); err != nil {
 		return err
 	}
-	if _, err := protoClient.GetState(ctx); err != nil {
+	if _, err := protoClient.GetState(cmd.Context()); err != nil {
 		if err, ok := rpcStatus.FromError(err); ok {
 			return fmt.Errorf("verifying Constellation cluster: %s", err.Message())
 		}

View File

@@ -2,7 +2,6 @@ package cmd
 
 import (
 	"bytes"
-	"context"
 	"encoding/base64"
 	"errors"
 	"testing"
@@ -151,8 +150,7 @@ func TestVerify(t *testing.T) {
 			}
 			fileHandler := file.NewHandler(tc.setupFs(require))
 
-			ctx := context.Background()
-			err := verify(ctx, cmd, tc.provider, fileHandler, tc.protoClient)
+			err := verify(cmd, tc.provider, fileHandler, tc.protoClient)
 
 			if tc.wantErr {
 				assert.Error(err)