From 4496755c644c899104c8e2a7570278f4080876f7 Mon Sep 17 00:00:00 2001 From: katexochen <49727155+katexochen@users.noreply.github.com> Date: Wed, 13 Apr 2022 15:01:02 +0200 Subject: [PATCH] Fix PCR handling --- cli/cmd/init.go | 75 +++++++++--------- cli/cmd/init_test.go | 142 ++++++++++++----------------------- cli/cmd/protoclient.go | 2 +- cli/cmd/protoclient_test.go | 4 +- cli/cmd/statuswaiter.go | 1 + cli/cmd/statuswaiter_test.go | 13 +++- cli/proto/client.go | 31 +++----- cli/proto/client_test.go | 2 +- cli/status/status.go | 42 +++++++---- cli/status/status_test.go | 50 ++++++++---- util/pcr-reader/main.go | 3 +- 11 files changed, 182 insertions(+), 183 deletions(-) diff --git a/cli/cmd/init.go b/cli/cmd/init.go index cbd8a4d28..b04af7c0a 100644 --- a/cli/cmd/init.go +++ b/cli/cmd/init.go @@ -49,38 +49,33 @@ func newInitCmd() *cobra.Command { func runInitialize(cmd *cobra.Command, args []string) error { fileHandler := file.NewHandler(afero.NewOsFs()) vpnHandler := vpn.NewConfigHandler() - devConfigName, err := cmd.Flags().GetString("dev-config") - if err != nil { - return err - } - config, err := config.FromFile(fileHandler, devConfigName) - if err != nil { - return err - } - - protoClient := proto.NewClient(*config.Provider.GCP.PCRs) - defer protoClient.Close() - if err != nil { - return err - } - serviceAccountCreator := cloudcmd.NewServiceAccountCreator() + waiter := status.NewWaiter() + protoClient := &proto.Client{} + defer protoClient.Close() // We have to parse the context separately, since cmd.Context() // returns nil during the tests otherwise. - return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, config, status.NewWaiter(*config.Provider.GCP.PCRs), vpnHandler) + return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler) } // initialize initializes a Constellation. Coordinator instances are activated as Coordinators and will // themself activate the other peers as nodes. func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator, - fileHandler file.Handler, config *config.Config, waiter statusWaiter, vpnHandler vpnHandler, + fileHandler file.Handler, waiter statusWaiter, vpnHandler vpnHandler, ) error { - flagArgs, err := evalFlagArgs(cmd, fileHandler) + flags, err := evalFlagArgs(cmd, fileHandler) if err != nil { return err } + config, err := config.FromFile(fileHandler, flags.devConfigPath) + if err != nil { + return err + } + + waiter.InitializePCRs(*config.Provider.GCP.PCRs, *config.Provider.Azure.PCRs) + var stat state.ConstellationState err = fileHandler.ReadJSON(constants.StateFilename, &stat) if errors.Is(err, fs.ErrNotExist) { @@ -121,14 +116,14 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser } var autoscalingNodeGroups []string - if flagArgs.autoscale { + if flags.autoscale { autoscalingNodeGroups = append(autoscalingNodeGroups, nodes.GroupID) } input := activationInput{ coordinatorPubIP: coordinators.PublicIPs()[0], - pubKey: flagArgs.userPubKey, - masterSecret: flagArgs.masterSecret, + pubKey: flags.userPubKey, + masterSecret: flags.masterSecret, nodePrivIPs: nodes.PrivateIPs(), autoscalingNodeGroups: autoscalingNodeGroups, cloudServiceAccountURI: serviceAccount, @@ -143,7 +138,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser return err } - vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flagArgs.userPrivKey), result.clientVpnIP, wireguardAdminMTU) + vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flags.userPrivKey), result.clientVpnIP, wireguardAdminMTU) if err != nil { return err } @@ -152,7 +147,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser return fmt.Errorf("write wg-quick file: %w", err) } - if flagArgs.autoconfigureWG { + if flags.autoconfigureWG { if err := vpnHandler.Apply(vpnConfig); err != nil { return err } @@ -162,7 +157,13 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser } func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config) (activationResult, error) { - if err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort); err != nil { + err := client.Connect( + input.coordinatorPubIP, + *config.CoordinatorPort, + *config.Provider.GCP.PCRs, + *config.Provider.Azure.PCRs, + ) + if err != nil { return activationResult{}, err } @@ -265,33 +266,38 @@ func writeRow(wr io.Writer, col1 string, col2 string) { // evalFlagArgs gets the flag values and does preprocessing of these values like // reading the content from file path flags and deriving other values from flag combinations. -func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (flagArgs, error) { +func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (initFlags, error) { userPrivKeyPath, err := cmd.Flags().GetString("privatekey") if err != nil { - return flagArgs{}, err + return initFlags{}, err } userPrivKey, userPubKey, err := readOrGenerateVPNKey(fileHandler, userPrivKeyPath) if err != nil { - return flagArgs{}, err + return initFlags{}, err } autoconfigureWG, err := cmd.Flags().GetBool("wg-autoconfig") if err != nil { - return flagArgs{}, err + return initFlags{}, err } masterSecretPath, err := cmd.Flags().GetString("master-secret") if err != nil { - return flagArgs{}, err + return initFlags{}, err } masterSecret, err := readOrGeneratedMasterSecret(cmd.OutOrStdout(), fileHandler, masterSecretPath) if err != nil { - return flagArgs{}, err + return initFlags{}, err } autoscale, err := cmd.Flags().GetBool("autoscale") if err != nil { - return flagArgs{}, err + return initFlags{}, err + } + devConfigPath, err := cmd.Flags().GetString("dev-config") + if err != nil { + return initFlags{}, err } - return flagArgs{ + return initFlags{ + devConfigPath: devConfigPath, userPrivKey: userPrivKey, userPubKey: userPubKey, autoconfigureWG: autoconfigureWG, @@ -300,8 +306,9 @@ func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (flagArgs, error }, nil } -// flagArgs are the resulting values of flag preprocessing. -type flagArgs struct { +// initFlags are the resulting values of flag preprocessing. +type initFlags struct { + devConfigPath string userPrivKey []byte userPubKey []byte masterSecret []byte diff --git a/cli/cmd/init_test.go b/cli/cmd/init_test.go index 7fe991410..7490a5d9c 100644 --- a/cli/cmd/init_test.go +++ b/cli/cmd/init_test.go @@ -14,7 +14,6 @@ import ( "github.com/edgelesssys/constellation/cli/ec2" "github.com/edgelesssys/constellation/cli/file" "github.com/edgelesssys/constellation/cli/gcp" - "github.com/edgelesssys/constellation/internal/config" "github.com/edgelesssys/constellation/internal/constants" "github.com/edgelesssys/constellation/internal/state" wgquick "github.com/nmiculinic/wg-quick-go" @@ -35,7 +34,6 @@ func TestInitArgumentValidation(t *testing.T) { func TestInitialize(t *testing.T) { testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")) - config := config.Default() testEc2State := state.ConstellationState{ CloudProvider: "AWS", EC2Instances: ec2.Instances{ @@ -56,39 +54,21 @@ func TestInitialize(t *testing.T) { } testGcpState := state.ConstellationState{ GCPNodes: gcp.Instances{ - "id-0": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.1", - }, - "id-1": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.1", - }, + "id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, + "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, }, GCPCoordinators: gcp.Instances{ - "id-c": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.1", - }, + "id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, }, } testAzureState := state.ConstellationState{ CloudProvider: "Azure", AzureNodes: azure.Instances{ - "id-0": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.1", - }, - "id-1": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.1", - }, + "id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, + "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, }, AzureCoordinators: azure.Instances{ - "id-c": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.1", - }, + "id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, }, AzureResourceGroup: "test", } @@ -121,7 +101,7 @@ func TestInitialize(t *testing.T) { client: &fakeProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, vpnHandler: &stubVPNHandler{}, privKey: testKey, }, @@ -130,7 +110,7 @@ func TestInitialize(t *testing.T) { client: &fakeProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, vpnHandler: &stubVPNHandler{}, privKey: testKey, }, @@ -139,7 +119,7 @@ func TestInitialize(t *testing.T) { client: &fakeProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, vpnHandler: &stubVPNHandler{}, privKey: testKey, }, @@ -148,7 +128,7 @@ func TestInitialize(t *testing.T) { client: &fakeProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, vpnHandler: &stubVPNHandler{}, initVPN: true, privKey: testKey, @@ -158,7 +138,7 @@ func TestInitialize(t *testing.T) { client: &fakeProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, vpnHandler: &stubVPNHandler{applyErr: someErr}, initVPN: true, privKey: testKey, @@ -169,7 +149,7 @@ func TestInitialize(t *testing.T) { client: &fakeProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, vpnHandler: &stubVPNHandler{createErr: someErr}, initVPN: true, privKey: testKey, @@ -180,7 +160,7 @@ func TestInitialize(t *testing.T) { client: &fakeProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, vpnHandler: &stubVPNHandler{marshalErr: someErr}, initVPN: true, privKey: testKey, @@ -189,7 +169,7 @@ func TestInitialize(t *testing.T) { "no state exists": { existingState: state.ConstellationState{}, client: &stubProtoClient{}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -200,7 +180,7 @@ func TestInitialize(t *testing.T) { EC2SecurityGroup: "sg-test", }, client: &stubProtoClient{}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -211,7 +191,7 @@ func TestInitialize(t *testing.T) { EC2SecurityGroup: "sg-test", }, client: &stubProtoClient{}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -219,7 +199,7 @@ func TestInitialize(t *testing.T) { "public key to short": { existingState: testEc2State, client: &stubProtoClient{}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")), vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -227,7 +207,7 @@ func TestInitialize(t *testing.T) { "public key to long": { existingState: testEc2State, client: &stubProtoClient{}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")), vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -235,7 +215,7 @@ func TestInitialize(t *testing.T) { "public key not base64": { existingState: testEc2State, client: &stubProtoClient{}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: "this is not base64 encoded", vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -243,7 +223,7 @@ func TestInitialize(t *testing.T) { "fail Connect": { existingState: testEc2State, client: &stubProtoClient{connectErr: someErr}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -251,7 +231,7 @@ func TestInitialize(t *testing.T) { "fail Activate": { existingState: testEc2State, client: &stubProtoClient{activateErr: someErr}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -259,7 +239,7 @@ func TestInitialize(t *testing.T) { "fail respClient WriteLogStream": { existingState: testEc2State, client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -267,7 +247,7 @@ func TestInitialize(t *testing.T) { "fail respClient getKubeconfig": { existingState: testEc2State, client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -275,7 +255,7 @@ func TestInitialize(t *testing.T) { "fail respClient getCoordinatorVpnKey": { existingState: testEc2State, client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -283,7 +263,7 @@ func TestInitialize(t *testing.T) { "fail respClient getClientVpnIp": { existingState: testEc2State, client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -291,7 +271,7 @@ func TestInitialize(t *testing.T) { "fail respClient getOwnerID": { existingState: testEc2State, client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -299,7 +279,7 @@ func TestInitialize(t *testing.T) { "fail respClient getClusterID": { existingState: testEc2State, client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}}, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -307,7 +287,7 @@ func TestInitialize(t *testing.T) { "fail to wait for required status": { existingState: testGcpState, client: &stubProtoClient{}, - waiter: stubStatusWaiter{waitForAllErr: someErr}, + waiter: &stubStatusWaiter{waitForAllErr: someErr}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -318,7 +298,7 @@ func TestInitialize(t *testing.T) { serviceAccountCreator: stubServiceAccountCreator{ createErr: someErr, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, vpnHandler: &stubVPNHandler{}, errExpected: true, @@ -335,6 +315,7 @@ func TestInitialize(t *testing.T) { cmd.SetOut(&out) var errOut bytes.Buffer cmd.SetErr(&errOut) + cmd.Flags().String("dev-config", "", "") // register persisten flag manually fs := afero.NewMemMapFs() fileHandler := file.NewHandler(fs) require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone)) @@ -350,7 +331,7 @@ func TestInitialize(t *testing.T) { ctx, cancel := context.WithTimeout(ctx, 4*time.Second) defer cancel() - err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, config, tc.waiter, tc.vpnHandler) + err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler) if tc.errExpected { assert.Error(err) @@ -551,58 +532,30 @@ func TestReadOrGeneratedMasterSecret(t *testing.T) { func TestAutoscaleFlag(t *testing.T) { testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")) - config := config.Default() testEc2State := state.ConstellationState{ EC2Instances: ec2.Instances{ - "id-0": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.2", - }, - "id-1": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.2", - }, - "id-2": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.2", - }, + "id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"}, + "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"}, + "id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"}, }, EC2SecurityGroup: "sg-test", } testGcpState := state.ConstellationState{ GCPNodes: gcp.Instances{ - "id-0": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.1", - }, - "id-1": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.1", - }, + "id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, + "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, }, GCPCoordinators: gcp.Instances{ - "id-c": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.1", - }, + "id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, }, } testAzureState := state.ConstellationState{ AzureNodes: azure.Instances{ - "id-0": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.1", - }, - "id-1": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.1", - }, + "id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, + "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, }, AzureCoordinators: azure.Instances{ - "id-c": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.1", - }, + "id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, }, AzureResourceGroup: "test", } @@ -633,7 +586,7 @@ func TestAutoscaleFlag(t *testing.T) { client: &stubProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, }, "initialize some gcp instances without autoscale flag": { @@ -642,7 +595,7 @@ func TestAutoscaleFlag(t *testing.T) { client: &stubProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, }, "initialize some azure instances without autoscale flag": { @@ -651,7 +604,7 @@ func TestAutoscaleFlag(t *testing.T) { client: &stubProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, }, "initialize some ec2 instances with autoscale flag": { @@ -660,7 +613,7 @@ func TestAutoscaleFlag(t *testing.T) { client: &stubProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, }, "initialize some gcp instances with autoscale flag": { @@ -669,7 +622,7 @@ func TestAutoscaleFlag(t *testing.T) { client: &stubProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, }, "initialize some azure instances with autoscale flag": { @@ -678,7 +631,7 @@ func TestAutoscaleFlag(t *testing.T) { client: &stubProtoClient{ respClient: &fakeActivationRespClient{responses: testActivationResps}, }, - waiter: stubStatusWaiter{}, + waiter: &stubStatusWaiter{}, privKey: testKey, }, } @@ -693,6 +646,7 @@ func TestAutoscaleFlag(t *testing.T) { cmd.SetOut(&out) var errOut bytes.Buffer cmd.SetErr(&errOut) + cmd.Flags().String("dev-config", "", "") // register persisten flag manually fs := afero.NewMemMapFs() fileHandler := file.NewHandler(fs) vpnHandler := stubVPNHandler{} @@ -705,7 +659,7 @@ func TestAutoscaleFlag(t *testing.T) { require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag))) ctx := context.Background() - require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, config, tc.waiter, &vpnHandler)) + require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler)) if tc.autoscaleFlag { assert.Len(tc.client.activateAutoscalingNodeGroups, 1) } else { diff --git a/cli/cmd/protoclient.go b/cli/cmd/protoclient.go index 30c72f140..a91ed1332 100644 --- a/cli/cmd/protoclient.go +++ b/cli/cmd/protoclient.go @@ -7,7 +7,7 @@ import ( ) type protoClient interface { - Connect(ip string, port string) error + Connect(ip, port string, gcpPCRs, azurePCRs map[uint32][]byte) error Close() error Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error) } diff --git a/cli/cmd/protoclient_test.go b/cli/cmd/protoclient_test.go index 23741dd41..3678e990e 100644 --- a/cli/cmd/protoclient_test.go +++ b/cli/cmd/protoclient_test.go @@ -23,7 +23,7 @@ type stubProtoClient struct { cloudServiceAccountURI string } -func (c *stubProtoClient) Connect(ip string, port string) error { +func (c *stubProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error { c.conn = true return c.connectErr } @@ -89,7 +89,7 @@ type fakeProtoClient struct { respClient proto.ActivationResponseClient } -func (c *fakeProtoClient) Connect(ip string, port string) error { +func (c *fakeProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error { c.conn = true return nil } diff --git a/cli/cmd/statuswaiter.go b/cli/cmd/statuswaiter.go index 0e5fc8f01..3a63de53a 100644 --- a/cli/cmd/statuswaiter.go +++ b/cli/cmd/statuswaiter.go @@ -7,5 +7,6 @@ import ( ) type statusWaiter interface { + InitializePCRs(map[uint32][]byte, map[uint32][]byte) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error } diff --git a/cli/cmd/statuswaiter_test.go b/cli/cmd/statuswaiter_test.go index 43afb0ac0..bcba9da17 100644 --- a/cli/cmd/statuswaiter_test.go +++ b/cli/cmd/statuswaiter_test.go @@ -2,14 +2,23 @@ package cmd import ( "context" + "errors" "github.com/edgelesssys/constellation/coordinator/state" ) type stubStatusWaiter struct { + initialized bool waitForAllErr error } -func (w stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error { - return w.waitForAllErr +func (s *stubStatusWaiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) { + s.initialized = true +} + +func (s *stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error { + if !s.initialized { + return errors.New("waiter not initialized") + } + return s.waitForAllErr } diff --git a/cli/proto/client.go b/cli/proto/client.go index eb1f3f704..8132ed21a 100644 --- a/cli/proto/client.go +++ b/cli/proto/client.go @@ -7,7 +7,6 @@ import ( "net" "github.com/edgelesssys/constellation/coordinator/atls" - "github.com/edgelesssys/constellation/coordinator/attestation/aws" "github.com/edgelesssys/constellation/coordinator/attestation/azure" "github.com/edgelesssys/constellation/coordinator/attestation/gcp" "github.com/edgelesssys/constellation/coordinator/kms" @@ -21,31 +20,25 @@ import ( // The client offers a method to activate the connected // AVPNServer as Coordinator. type Client struct { - conn *grpc.ClientConn - avpn pubproto.APIClient - validators []atls.Validator + conn *grpc.ClientConn + avpn pubproto.APIClient } -// NewClient creates a Client without a connection. -func NewClient(gcpPCRs map[uint32][]byte) *Client { - return &Client{ - validators: []atls.Validator{ - aws.NewValidator(aws.NaAdGetVerifiedPayloadAsJson), - gcp.NewValidator(gcpPCRs), - gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms - azure.NewValidator(map[uint32][]byte{}), - }, - } -} - -// Connect connects the client to a given server. +// Connect connects the client to a given server, using the handed +// Validators for the attestation of the connection. // The connection must be closed using Close(). If connect is // called on a client that already has a connection, the old // connection is closed. -func (c *Client) Connect(ip string, port string) error { +func (c *Client) Connect(ip, port string, gcpPCRs, AzurePCRs map[uint32][]byte) error { addr := net.JoinHostPort(ip, port) - tlsConfig, err := atls.CreateAttestationClientTLSConfig(c.validators) + validators := []atls.Validator{ + gcp.NewValidator(gcpPCRs), + gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms + azure.NewValidator(map[uint32][]byte{}), + } + + tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) if err != nil { return err } diff --git a/cli/proto/client_test.go b/cli/proto/client_test.go index 47032079e..a5621c3aa 100644 --- a/cli/proto/client_test.go +++ b/cli/proto/client_test.go @@ -21,7 +21,7 @@ func TestClose(t *testing.T) { assert := assert.New(t) require := require.New(t) - client := NewClient(map[uint32][]byte{}) + client := Client{} // Create a connection. listener := bufconn.Listen(4) diff --git a/cli/status/status.go b/cli/status/status.go index 7b460e61b..4cac2cd42 100644 --- a/cli/status/status.go +++ b/cli/status/status.go @@ -2,11 +2,11 @@ package status import ( "context" + "errors" "io" "time" "github.com/edgelesssys/constellation/coordinator/atls" - "github.com/edgelesssys/constellation/coordinator/attestation/aws" "github.com/edgelesssys/constellation/coordinator/attestation/azure" "github.com/edgelesssys/constellation/coordinator/attestation/gcp" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" @@ -17,26 +17,37 @@ import ( grpcstatus "google.golang.org/grpc/status" ) -// Waiter waits for PeerStatusServer to reach a specific state. +// Waiter waits for PeerStatusServer to reach a specific state. The waiter needs +// to be initialized before usage. type Waiter struct { - interval time.Duration - newConn func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) - newClient func(cc grpc.ClientConnInterface) pubproto.APIClient + initialized bool + interval time.Duration + newConn func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) + newClient func(cc grpc.ClientConnInterface) pubproto.APIClient } // NewWaiter returns a default Waiter with probing inteval of 10 seconds, // attested gRPC connection and PeerStatusClient. -func NewWaiter(gcpPCRs map[uint32][]byte) Waiter { - return Waiter{ +func NewWaiter() *Waiter { + return &Waiter{ interval: 10 * time.Second, - newConn: newAttestedConnGenerator(gcpPCRs), newClient: pubproto.NewAPIClient, } } +// InitializePCRs initializes the PCRs for the attestation validators. +func (w *Waiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) { + w.newConn = newAttestedConnGenerator(gcpPCRs, azurePCRs) + w.initialized = true +} + // WaitFor waits for a PeerStatusServer, which is reachable under the given endpoint // to reach the specified state. -func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error { +func (w *Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error { + if !w.initialized { + return errors.New("waiter not initialized") + } + ticker := time.NewTicker(w.interval) defer ticker.Stop() @@ -71,7 +82,7 @@ func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.St } // probe sends a PeerStatusCheck request to a PeerStatusServer and returns the response. -func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateResponse, error) { +func (w *Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateResponse, error) { conn, err := w.newConn(ctx, endpoint) if err != nil { return nil, err @@ -84,7 +95,11 @@ func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateR // WaitForAll waits for a list of PeerStatusServers, which listen on the handed // endpoints, to reach the specified state. -func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error { +func (w *Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error { + if !w.initialized { + return errors.New("waiter not initialized") + } + for _, endpoint := range endpoints { if err := w.WaitFor(ctx, endpoint, status...); err != nil { return err @@ -94,13 +109,12 @@ func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...st } // newAttestedConnGenerator creates a function returning a default attested grpc connection. -func newAttestedConnGenerator(gcpPCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { +func newAttestedConnGenerator(gcpPCRs map[uint32][]byte, azurePCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { validators := []atls.Validator{ - aws.NewValidator(aws.NaAdGetVerifiedPayloadAsJson), gcp.NewValidator(gcpPCRs), gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms - azure.NewValidator(map[uint32][]byte{}), + azure.NewValidator(azurePCRs), } tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) diff --git a/cli/status/status_test.go b/cli/status/status_test.go index 5179e9924..08ed3b049 100644 --- a/cli/status/status_test.go +++ b/cli/status/status_test.go @@ -12,6 +12,21 @@ import ( "google.golang.org/grpc" ) +func TestInitializeValidators(t *testing.T) { + assert := assert.New(t) + + waiter := Waiter{ + interval: time.Millisecond, + newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}), + } + + // Uninitialized waiter fails. + assert.Error(waiter.WaitFor(context.Background(), "someIP", state.IsNode)) + + waiter.InitializeValidators(nil) + assert.NoError(waiter.WaitFor(context.Background(), "someIP", state.IsNode)) +} + func TestWaitForAndWaitForAll(t *testing.T) { var noErr error someErr := errors.New("failed") @@ -23,43 +38,48 @@ func TestWaitForAndWaitForAll(t *testing.T) { }{ "successful wait": { waiter: Waiter{ - interval: time.Millisecond, - newConn: stubNewConnFunc(noErr), - newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}), + initialized: true, + interval: time.Millisecond, + newConn: stubNewConnFunc(noErr), + newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}), }, waitForState: []state.State{state.IsNode}, }, "successful wait multi states": { waiter: Waiter{ - interval: time.Millisecond, - newConn: stubNewConnFunc(noErr), - newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}), + initialized: true, + interval: time.Millisecond, + newConn: stubNewConnFunc(noErr), + newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}), }, waitForState: []state.State{state.IsNode, state.ActivatingNodes}, }, "expect timeout": { waiter: Waiter{ - interval: time.Millisecond, - newConn: stubNewConnFunc(noErr), - newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}), + initialized: true, + interval: time.Millisecond, + newConn: stubNewConnFunc(noErr), + newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}), }, waitForState: []state.State{state.IsNode}, wantErr: true, }, "fail to check call": { waiter: Waiter{ - interval: time.Millisecond, - newConn: stubNewConnFunc(noErr), - newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}), + initialized: true, + interval: time.Millisecond, + newConn: stubNewConnFunc(noErr), + newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}), }, waitForState: []state.State{state.IsNode}, wantErr: true, }, "fail to create conn": { waiter: Waiter{ - interval: time.Millisecond, - newConn: stubNewConnFunc(someErr), - newClient: stubNewClientFunc(&stubPeerStatusClient{}), + initialized: true, + interval: time.Millisecond, + newConn: stubNewConnFunc(someErr), + newClient: stubNewClientFunc(&stubPeerStatusClient{}), }, waitForState: []state.State{state.IsNode}, wantErr: true, diff --git a/util/pcr-reader/main.go b/util/pcr-reader/main.go index 85c888a21..cbbb48230 100644 --- a/util/pcr-reader/main.go +++ b/util/pcr-reader/main.go @@ -41,7 +41,8 @@ func main() { defer cancel() // wait for coordinator to come online - waiter := status.NewWaiter(map[uint32][]byte{}) + waiter := status.NewWaiter() + waiter.InitializePCRs(map[uint32][]byte{}, map[uint32][]byte{}) if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil { log.Fatal(err) }