Remove passing context seperately to initialize

This commit is contained in:
Nils Hanke 2022-06-28 11:19:03 +02:00 committed by Nils Hanke
parent 0653c20792
commit e3f78a5bff
6 changed files with 18 additions and 27 deletions

View File

@ -1,7 +1,6 @@
package cmd
import (
"context"
"encoding/base64"
"errors"
"fmt"
@ -62,14 +61,12 @@ func runInitialize(cmd *cobra.Command, args []string) error {
protoClient := &proto.Client{}
defer protoClient.Close()
// We have to parse the context separately, since cmd.Context()
// returns nil during the tests otherwise.
return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler)
return initialize(cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler)
}
// initialize initializes a Constellation. Coordinator instances are activated as contole-plane nodes and will
// themself activate the other peers as workers.
func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator,
func initialize(cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator,
fileHandler file.Handler, waiter statusWaiter, vpnHandler vpnHandler,
) error {
flags, err := evalFlagArgs(cmd, fileHandler)
@ -107,7 +104,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
cmd.Print(validators.WarningsIncludeInit())
cmd.Println("Creating service account ...")
serviceAccount, stat, err := serviceAccCreator.Create(ctx, stat, config)
serviceAccount, stat, err := serviceAccCreator.Create(cmd.Context(), stat, config)
if err != nil {
return err
}
@ -126,7 +123,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
if err := waiter.InitializeValidators(validators.V()); err != nil {
return err
}
if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil {
if err := waiter.WaitForAll(cmd.Context(), endpoints, coordinatorstate.AcceptingInit); err != nil {
return fmt.Errorf("waiting for all peers status: %w", err)
}
@ -145,7 +142,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
cloudServiceAccountURI: serviceAccount,
sshUserKeys: ssh.ToProtoSlice(sshUsers),
}
result, err := activate(ctx, cmd, protCl, input, validators.V())
result, err := activate(cmd, protCl, input, validators.V())
if err != nil {
return err
}
@ -173,7 +170,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
return nil
}
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput,
func activate(cmd *cobra.Command, client protoClient, input activationInput,
validators []atls.Validator,
) (activationResult, error) {
err := client.Connect(net.JoinHostPort(input.coordinatorPubIP, strconv.Itoa(constants.CoordinatorPort)), validators)
@ -181,7 +178,7 @@ func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input
return activationResult{}, err
}
respCl, err := client.Activate(ctx, input.pubKey, input.masterSecret, input.nodePrivIPs, input.coordinatorPrivIPs, input.autoscalingNodeGroups, input.cloudServiceAccountURI, input.sshUserKeys)
respCl, err := client.Activate(cmd.Context(), input.pubKey, input.masterSecret, input.nodePrivIPs, input.coordinatorPrivIPs, input.autoscalingNodeGroups, input.cloudServiceAccountURI, input.sshUserKeys)
if err != nil {
return activationResult{}, err
}

View File

@ -305,8 +305,9 @@ func TestInitialize(t *testing.T) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
cmd.SetContext(ctx)
err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler)
err := initialize(cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler)
if tc.wantErr {
assert.Error(err)
@ -608,9 +609,8 @@ func TestAutoscaleFlag(t *testing.T) {
require.NoError(cmd.Flags().Set("privatekey", "privK"))
require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
ctx := context.Background()
require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler))
require.NoError(initialize(cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler))
if tc.autoscaleFlag {
assert.Len(tc.client.activateAutoscalingNodeGroups, 1)
} else {

View File

@ -1,7 +1,6 @@
package cmd
import (
"context"
"encoding/base64"
"errors"
"fmt"
@ -43,10 +42,10 @@ func runRecover(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
recoveryClient := &proto.KeyClient{}
defer recoveryClient.Close()
return recover(cmd.Context(), cmd, fileHandler, recoveryClient)
return recover(cmd, fileHandler, recoveryClient)
}
func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler, recoveryClient recoveryClient) error {
func recover(cmd *cobra.Command, fileHandler file.Handler, recoveryClient recoveryClient) error {
flags, err := parseRecoverFlags(cmd, fileHandler)
if err != nil {
return err
@ -79,7 +78,7 @@ func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler,
return err
}
if err := recoveryClient.PushStateDiskKey(ctx, diskKey); err != nil {
if err := recoveryClient.PushStateDiskKey(cmd.Context(), diskKey); err != nil {
return err
}

View File

@ -2,7 +2,6 @@ package cmd
import (
"bytes"
"context"
"errors"
"testing"
@ -182,8 +181,7 @@ func TestRecover(t *testing.T) {
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
}
ctx := context.Background()
err := recover(ctx, cmd, fileHandler, tc.client)
err := recover(cmd, fileHandler, tc.client)
if tc.wantErr {
assert.Error(err)

View File

@ -1,7 +1,6 @@
package cmd
import (
"context"
"errors"
"fmt"
@ -40,10 +39,10 @@ func runVerify(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
protoClient := &proto.Client{}
defer protoClient.Close()
return verify(cmd.Context(), cmd, provider, fileHandler, protoClient)
return verify(cmd, provider, fileHandler, protoClient)
}
func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Provider, fileHandler file.Handler, protoClient protoClient) error {
func verify(cmd *cobra.Command, provider cloudprovider.Provider, fileHandler file.Handler, protoClient protoClient) error {
flags, err := parseVerifyFlags(cmd)
if err != nil {
return err
@ -69,7 +68,7 @@ func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Prov
if err := protoClient.Connect(flags.endpoint, validators.V()); err != nil {
return err
}
if _, err := protoClient.GetState(ctx); err != nil {
if _, err := protoClient.GetState(cmd.Context()); err != nil {
if err, ok := rpcStatus.FromError(err); ok {
return fmt.Errorf("verifying Constellation cluster: %s", err.Message())
}

View File

@ -2,7 +2,6 @@ package cmd
import (
"bytes"
"context"
"encoding/base64"
"errors"
"testing"
@ -151,8 +150,7 @@ func TestVerify(t *testing.T) {
}
fileHandler := file.NewHandler(tc.setupFs(require))
ctx := context.Background()
err := verify(ctx, cmd, tc.provider, fileHandler, tc.protoClient)
err := verify(cmd, tc.provider, fileHandler, tc.protoClient)
if tc.wantErr {
assert.Error(err)