From 4e29c38027ad44fae340c9fa70e6b53edc759000 Mon Sep 17 00:00:00 2001 From: katexochen <49727155+katexochen@users.noreply.github.com> Date: Tue, 19 Apr 2022 17:02:02 +0200 Subject: [PATCH] Move validators to cloudcmd --- cli/cloud/cloudcmd/validators.go | 112 +++++++++++++++ cli/cloud/cloudcmd/validators_test.go | 200 ++++++++++++++++++++++++++ cli/cmd/init.go | 32 ++--- cli/cmd/init_test.go | 105 +++----------- cli/cmd/protoclient.go | 3 +- cli/cmd/protoclient_test.go | 11 +- cli/cmd/statuswaiter.go | 3 +- cli/cmd/statuswaiter_test.go | 3 +- cli/proto/client.go | 10 +- cli/status/status.go | 16 +-- util/pcr-reader/main.go | 2 +- 11 files changed, 367 insertions(+), 130 deletions(-) create mode 100644 cli/cloud/cloudcmd/validators.go create mode 100644 cli/cloud/cloudcmd/validators_test.go diff --git a/cli/cloud/cloudcmd/validators.go b/cli/cloud/cloudcmd/validators.go new file mode 100644 index 000000000..4bffc5794 --- /dev/null +++ b/cli/cloud/cloudcmd/validators.go @@ -0,0 +1,112 @@ +package cloudcmd + +import ( + "errors" + "fmt" + "strings" + + "github.com/edgelesssys/constellation/cli/cloudprovider" + "github.com/edgelesssys/constellation/coordinator/atls" + "github.com/edgelesssys/constellation/coordinator/attestation/azure" + "github.com/edgelesssys/constellation/coordinator/attestation/gcp" + "github.com/edgelesssys/constellation/coordinator/attestation/vtpm" + "github.com/edgelesssys/constellation/internal/config" +) + +type Validators struct { + validators []atls.Validator + pcrWarnings string + pcrWarningsInit string +} + +func NewValidators(provider cloudprovider.Provider, config *config.Config) (Validators, error) { + v := Validators{} + switch provider { + case cloudprovider.GCP: + gcpPCRs := *config.Provider.GCP.PCRs + if err := v.checkPCRs(gcpPCRs); err != nil { + return Validators{}, err + } + v.setPCRWarnings(gcpPCRs) + v.validators = []atls.Validator{ + gcp.NewValidator(gcpPCRs), + gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non CVMs. + } + case cloudprovider.Azure: + azurePCRs := *config.Provider.Azure.PCRs + if err := v.checkPCRs(azurePCRs); err != nil { + return Validators{}, err + } + v.setPCRWarnings(azurePCRs) + v.validators = []atls.Validator{ + azure.NewValidator(azurePCRs), + } + default: + return Validators{}, errors.New("unsupported cloud provider") + } + return v, nil +} + +// V returns validators as list of atls.Validator. +func (v *Validators) V() []atls.Validator { + return v.validators +} + +// Warnings returns warnings for the specifc PCR values that are not verified. +// +// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1 +func (v *Validators) Warnings() string { + return v.pcrWarnings +} + +// WarningsIncludeInit returns warnings for the specifc PCR values that are not verified. +// Warnings regarding the initialization are included. +// +// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1 +func (v *Validators) WarningsIncludeInit() string { + return v.pcrWarnings + v.pcrWarningsInit +} + +func (v *Validators) checkPCRs(pcrs map[uint32][]byte) error { + for k, v := range pcrs { + if len(v) != 32 { + return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v)) + } + } + return nil +} + +func (v *Validators) setPCRWarnings(pcrs map[uint32][]byte) { + const warningStr = "Warning: not verifying the Constellation's %s measurements\n" + sb := &strings.Builder{} + + if pcrs[0] == nil || pcrs[1] == nil { + writeFmt(sb, warningStr, "BIOS") + } + + if pcrs[2] == nil || pcrs[3] == nil { + writeFmt(sb, warningStr, "OPROM") + } + + if pcrs[4] == nil || pcrs[5] == nil { + writeFmt(sb, warningStr, "MBR") + } + + // GRUB measures kernel command line and initrd into pcrs 8 and 9 + if pcrs[8] == nil { + writeFmt(sb, warningStr, "kernel command line") + } + if pcrs[9] == nil { + writeFmt(sb, warningStr, "initrd") + } + v.pcrWarnings = sb.String() + + // Write init warnings separate. + if pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || pcrs[uint32(vtpm.PCRIndexClusterID)] == nil { + v.pcrWarningsInit = fmt.Sprintf(warningStr, "initialization status") + } +} + +func writeFmt(sb *strings.Builder, fmtStr string, args ...interface{}) { + sb.WriteString(fmt.Sprintf(fmtStr, args...)) +} diff --git a/cli/cloud/cloudcmd/validators_test.go b/cli/cloud/cloudcmd/validators_test.go new file mode 100644 index 000000000..962de66d9 --- /dev/null +++ b/cli/cloud/cloudcmd/validators_test.go @@ -0,0 +1,200 @@ +package cloudcmd + +import ( + "testing" + + "github.com/edgelesssys/constellation/cli/cloudprovider" + "github.com/edgelesssys/constellation/internal/config" + "github.com/stretchr/testify/assert" +) + +func TestWarnAboutPCRs(t *testing.T) { + zero := []byte("00000000000000000000000000000000") + + testCases := map[string]struct { + pcrs map[uint32][]byte + wantWarnings []string + wantWInclude []string + wantErr bool + }{ + "no warnings": { + pcrs: map[uint32][]byte{ + 0: zero, + 1: zero, + 2: zero, + 3: zero, + 4: zero, + 5: zero, + 6: zero, + 7: zero, + 8: zero, + 9: zero, + 10: zero, + 11: zero, + 12: zero, + }, + }, + "no warnings for missing non critical values": { + pcrs: map[uint32][]byte{ + 0: zero, + 1: zero, + 2: zero, + 3: zero, + 4: zero, + 5: zero, + 8: zero, + 9: zero, + 11: zero, + 12: zero, + }, + }, + "warn for BIOS": { + pcrs: map[uint32][]byte{ + 0: zero, + 2: zero, + 3: zero, + 4: zero, + 5: zero, + 8: zero, + 9: zero, + 11: zero, + 12: zero, + }, + wantWarnings: []string{"BIOS"}, + }, + "warn for OPROM": { + pcrs: map[uint32][]byte{ + 0: zero, + 1: zero, + 3: zero, + 4: zero, + 5: zero, + 8: zero, + 9: zero, + 11: zero, + 12: zero, + }, + wantWarnings: []string{"OPROM"}, + }, + "warn for MBR": { + pcrs: map[uint32][]byte{ + 0: zero, + 1: zero, + 2: zero, + 3: zero, + 5: zero, + 8: zero, + 9: zero, + 11: zero, + 12: zero, + }, + wantWarnings: []string{"MBR"}, + }, + "warn for kernel": { + pcrs: map[uint32][]byte{ + 0: zero, + 1: zero, + 2: zero, + 3: zero, + 4: zero, + 5: zero, + 9: zero, + 11: zero, + 12: zero, + }, + wantWarnings: []string{"kernel"}, + }, + "warn for initrd": { + pcrs: map[uint32][]byte{ + 0: zero, + 1: zero, + 2: zero, + 3: zero, + 4: zero, + 5: zero, + 8: zero, + 11: zero, + 12: zero, + }, + wantWarnings: []string{"initrd"}, + }, + "warn for initialization": { + pcrs: map[uint32][]byte{ + 0: zero, + 1: zero, + 2: zero, + 3: zero, + 4: zero, + 5: zero, + 8: zero, + 9: zero, + 11: zero, + }, + wantWInclude: []string{"initialization"}, + }, + "multi warning": { + pcrs: map[uint32][]byte{}, + wantWarnings: []string{ + "BIOS", + "OPROM", + "MBR", + "initrd", + "kernel", + }, + wantWInclude: []string{"initialization"}, + }, + "bad config": { + pcrs: map[uint32][]byte{ + 0: []byte("000"), + }, + wantErr: true, + }, + } + + for _, provider := range []string{"gcp", "azure", "unknown"} { + t.Run(provider, func(t *testing.T) { + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + config := &config.Config{ + Provider: &config.ProviderConfig{ + Azure: &config.AzureConfig{PCRs: &tc.pcrs}, + GCP: &config.GCPConfig{PCRs: &tc.pcrs}, + }, + } + + validators, err := NewValidators(cloudprovider.FromString(provider), config) + + v := validators.V() + warnings := validators.Warnings() + warningsInclueInit := validators.WarningsIncludeInit() + + if tc.wantErr || provider == "unknown" { + assert.Error(err) + } else { + assert.NoError(err) + if len(tc.wantWarnings) == 0 { + assert.Empty(warnings) + } + for _, w := range tc.wantWarnings { + assert.Contains(warnings, w) + } + for _, w := range tc.wantWarnings { + assert.Contains(warningsInclueInit, w) + } + if len(tc.wantWInclude) == 0 { + assert.Equal(len(warnings), len(warningsInclueInit)) + } else { + assert.Greater(len(warningsInclueInit), len(warnings)) + } + for _, w := range tc.wantWInclude { + assert.Contains(warningsInclueInit, w) + } + assert.NotEmpty(v) + } + }) + } + }) + } +} diff --git a/cli/cmd/init.go b/cli/cmd/init.go index b04af7c0a..d69e4705c 100644 --- a/cli/cmd/init.go +++ b/cli/cmd/init.go @@ -12,11 +12,13 @@ import ( "github.com/edgelesssys/constellation/cli/azure" "github.com/edgelesssys/constellation/cli/cloud/cloudcmd" + "github.com/edgelesssys/constellation/cli/cloudprovider" "github.com/edgelesssys/constellation/cli/file" "github.com/edgelesssys/constellation/cli/gcp" "github.com/edgelesssys/constellation/cli/proto" "github.com/edgelesssys/constellation/cli/status" "github.com/edgelesssys/constellation/cli/vpn" + "github.com/edgelesssys/constellation/coordinator/atls" coordinatorstate "github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/util" "github.com/edgelesssys/constellation/internal/config" @@ -74,8 +76,6 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser return err } - waiter.InitializePCRs(*config.Provider.GCP.PCRs, *config.Provider.Azure.PCRs) - var stat state.ConstellationState err = fileHandler.ReadJSON(constants.StateFilename, &stat) if errors.Is(err, fs.ErrNotExist) { @@ -84,16 +84,11 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser return err } - switch stat.CloudProvider { - case "GCP": - if err := warnAboutPCRs(cmd, *config.Provider.GCP.PCRs, true); err != nil { - return err - } - case "Azure": - if err := warnAboutPCRs(cmd, *config.Provider.Azure.PCRs, true); err != nil { - return err - } + validators, err := cloudcmd.NewValidators(cloudprovider.FromString(stat.CloudProvider), config) + if err != nil { + return err } + cmd.Print(validators.WarningsIncludeInit()) cmd.Println("Creating service account ...") serviceAccount, stat, err := serviceAccCreator.Create(ctx, stat, config) @@ -110,7 +105,9 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser } endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort) + cmd.Println("Waiting for cloud provider to finish resource creation ...") + waiter.InitializeValidators(validators.V()) if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil { return fmt.Errorf("failed to wait for peer status: %w", err) } @@ -128,7 +125,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser autoscalingNodeGroups: autoscalingNodeGroups, cloudServiceAccountURI: serviceAccount, } - result, err := activate(ctx, cmd, protCl, input, config) + result, err := activate(ctx, cmd, protCl, input, config, validators.V()) if err != nil { return err } @@ -156,13 +153,10 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser return nil } -func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config) (activationResult, error) { - err := client.Connect( - input.coordinatorPubIP, - *config.CoordinatorPort, - *config.Provider.GCP.PCRs, - *config.Provider.Azure.PCRs, - ) +func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, + config *config.Config, validators []atls.Validator, +) (activationResult, error) { + err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort, validators) if err != nil { return activationResult{}, err } diff --git a/cli/cmd/init_test.go b/cli/cmd/init_test.go index 7490a5d9c..b1a36d584 100644 --- a/cli/cmd/init_test.go +++ b/cli/cmd/init_test.go @@ -34,25 +34,8 @@ func TestInitArgumentValidation(t *testing.T) { func TestInitialize(t *testing.T) { testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")) - testEc2State := state.ConstellationState{ - CloudProvider: "AWS", - EC2Instances: ec2.Instances{ - "id-0": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.2", - }, - "id-1": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.2", - }, - "id-2": { - PrivateIP: "192.0.2.1", - PublicIP: "192.0.2.2", - }, - }, - EC2SecurityGroup: "sg-test", - } testGcpState := state.ConstellationState{ + CloudProvider: "GCP", GCPNodes: gcp.Instances{ "id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, @@ -96,15 +79,6 @@ func TestInitialize(t *testing.T) { initVPN bool errExpected bool }{ - "initialize some ec2 instances": { - existingState: testEc2State, - client: &fakeProtoClient{ - respClient: &fakeActivationRespClient{responses: testActivationResps}, - }, - waiter: &stubStatusWaiter{}, - vpnHandler: &stubVPNHandler{}, - privKey: testKey, - }, "initialize some gcp instances": { existingState: testGcpState, client: &fakeProtoClient{ @@ -185,19 +159,8 @@ func TestInitialize(t *testing.T) { vpnHandler: &stubVPNHandler{}, errExpected: true, }, - "only one instance": { - existingState: state.ConstellationState{ - EC2Instances: ec2.Instances{"id-1": {}}, - EC2SecurityGroup: "sg-test", - }, - client: &stubProtoClient{}, - waiter: &stubStatusWaiter{}, - privKey: testKey, - vpnHandler: &stubVPNHandler{}, - errExpected: true, - }, "public key to short": { - existingState: testEc2State, + existingState: testGcpState, client: &stubProtoClient{}, waiter: &stubStatusWaiter{}, privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")), @@ -205,7 +168,7 @@ func TestInitialize(t *testing.T) { errExpected: true, }, "public key to long": { - existingState: testEc2State, + existingState: testGcpState, client: &stubProtoClient{}, waiter: &stubStatusWaiter{}, privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")), @@ -213,7 +176,7 @@ func TestInitialize(t *testing.T) { errExpected: true, }, "public key not base64": { - existingState: testEc2State, + existingState: testGcpState, client: &stubProtoClient{}, waiter: &stubStatusWaiter{}, privKey: "this is not base64 encoded", @@ -221,7 +184,7 @@ func TestInitialize(t *testing.T) { errExpected: true, }, "fail Connect": { - existingState: testEc2State, + existingState: testGcpState, client: &stubProtoClient{connectErr: someErr}, waiter: &stubStatusWaiter{}, privKey: testKey, @@ -229,7 +192,7 @@ func TestInitialize(t *testing.T) { errExpected: true, }, "fail Activate": { - existingState: testEc2State, + existingState: testGcpState, client: &stubProtoClient{activateErr: someErr}, waiter: &stubStatusWaiter{}, privKey: testKey, @@ -237,7 +200,7 @@ func TestInitialize(t *testing.T) { errExpected: true, }, "fail respClient WriteLogStream": { - existingState: testEc2State, + existingState: testGcpState, client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}}, waiter: &stubStatusWaiter{}, privKey: testKey, @@ -245,7 +208,7 @@ func TestInitialize(t *testing.T) { errExpected: true, }, "fail respClient getKubeconfig": { - existingState: testEc2State, + existingState: testGcpState, client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}}, waiter: &stubStatusWaiter{}, privKey: testKey, @@ -253,7 +216,7 @@ func TestInitialize(t *testing.T) { errExpected: true, }, "fail respClient getCoordinatorVpnKey": { - existingState: testEc2State, + existingState: testGcpState, client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}}, waiter: &stubStatusWaiter{}, privKey: testKey, @@ -261,7 +224,7 @@ func TestInitialize(t *testing.T) { errExpected: true, }, "fail respClient getClientVpnIp": { - existingState: testEc2State, + existingState: testGcpState, client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}}, waiter: &stubStatusWaiter{}, privKey: testKey, @@ -269,7 +232,7 @@ func TestInitialize(t *testing.T) { errExpected: true, }, "fail respClient getOwnerID": { - existingState: testEc2State, + existingState: testGcpState, client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}}, waiter: &stubStatusWaiter{}, privKey: testKey, @@ -277,7 +240,7 @@ func TestInitialize(t *testing.T) { errExpected: true, }, "fail respClient getClusterID": { - existingState: testEc2State, + existingState: testGcpState, client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}}, waiter: &stubStatusWaiter{}, privKey: testKey, @@ -293,15 +256,13 @@ func TestInitialize(t *testing.T) { errExpected: true, }, "fail to create service account": { - existingState: testGcpState, - client: &stubProtoClient{}, - serviceAccountCreator: stubServiceAccountCreator{ - createErr: someErr, - }, - waiter: &stubStatusWaiter{}, - privKey: testKey, - vpnHandler: &stubVPNHandler{}, - errExpected: true, + existingState: testGcpState, + client: &stubProtoClient{}, + serviceAccountCreator: stubServiceAccountCreator{createErr: someErr}, + waiter: &stubStatusWaiter{}, + privKey: testKey, + vpnHandler: &stubVPNHandler{}, + errExpected: true, }, } @@ -532,15 +493,8 @@ func TestReadOrGeneratedMasterSecret(t *testing.T) { func TestAutoscaleFlag(t *testing.T) { testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")) - testEc2State := state.ConstellationState{ - EC2Instances: ec2.Instances{ - "id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"}, - "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"}, - "id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"}, - }, - EC2SecurityGroup: "sg-test", - } testGcpState := state.ConstellationState{ + CloudProvider: "gcp", GCPNodes: gcp.Instances{ "id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, @@ -550,6 +504,7 @@ func TestAutoscaleFlag(t *testing.T) { }, } testAzureState := state.ConstellationState{ + CloudProvider: "azure", AzureNodes: azure.Instances{ "id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, @@ -580,15 +535,6 @@ func TestAutoscaleFlag(t *testing.T) { waiter statusWaiter privKey string }{ - "initialize some ec2 instances without autoscale flag": { - autoscaleFlag: false, - existingState: testEc2State, - client: &stubProtoClient{ - respClient: &fakeActivationRespClient{responses: testActivationResps}, - }, - waiter: &stubStatusWaiter{}, - privKey: testKey, - }, "initialize some gcp instances without autoscale flag": { autoscaleFlag: false, existingState: testGcpState, @@ -607,15 +553,6 @@ func TestAutoscaleFlag(t *testing.T) { waiter: &stubStatusWaiter{}, privKey: testKey, }, - "initialize some ec2 instances with autoscale flag": { - autoscaleFlag: true, - existingState: testEc2State, - client: &stubProtoClient{ - respClient: &fakeActivationRespClient{responses: testActivationResps}, - }, - waiter: &stubStatusWaiter{}, - privKey: testKey, - }, "initialize some gcp instances with autoscale flag": { autoscaleFlag: true, existingState: testGcpState, diff --git a/cli/cmd/protoclient.go b/cli/cmd/protoclient.go index a91ed1332..d01a1295a 100644 --- a/cli/cmd/protoclient.go +++ b/cli/cmd/protoclient.go @@ -4,10 +4,11 @@ import ( "context" "github.com/edgelesssys/constellation/cli/proto" + "github.com/edgelesssys/constellation/coordinator/atls" ) type protoClient interface { - Connect(ip, port string, gcpPCRs, azurePCRs map[uint32][]byte) error + Connect(ip, port string, validators []atls.Validator) error Close() error Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error) } diff --git a/cli/cmd/protoclient_test.go b/cli/cmd/protoclient_test.go index 3678e990e..70669ba31 100644 --- a/cli/cmd/protoclient_test.go +++ b/cli/cmd/protoclient_test.go @@ -7,6 +7,7 @@ import ( "io" "github.com/edgelesssys/constellation/cli/proto" + "github.com/edgelesssys/constellation/coordinator/atls" ) type stubProtoClient struct { @@ -23,7 +24,7 @@ type stubProtoClient struct { cloudServiceAccountURI string } -func (c *stubProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error { +func (c *stubProtoClient) Connect(_, _ string, _ []atls.Validator) error { c.conn = true return c.connectErr } @@ -89,7 +90,13 @@ type fakeProtoClient struct { respClient proto.ActivationResponseClient } -func (c *fakeProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error { +func (c *fakeProtoClient) Connect(ip, port string, validators []atls.Validator) error { + if ip == "" || port == "" { + return errors.New("ip or port is empty") + } + if len(validators) == 0 { + return errors.New("validators is empty") + } c.conn = true return nil } diff --git a/cli/cmd/statuswaiter.go b/cli/cmd/statuswaiter.go index 3a63de53a..d87724102 100644 --- a/cli/cmd/statuswaiter.go +++ b/cli/cmd/statuswaiter.go @@ -3,10 +3,11 @@ package cmd import ( "context" + "github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/state" ) type statusWaiter interface { - InitializePCRs(map[uint32][]byte, map[uint32][]byte) + InitializeValidators([]atls.Validator) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error } diff --git a/cli/cmd/statuswaiter_test.go b/cli/cmd/statuswaiter_test.go index bcba9da17..17361e8c0 100644 --- a/cli/cmd/statuswaiter_test.go +++ b/cli/cmd/statuswaiter_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" + "github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/state" ) @@ -12,7 +13,7 @@ type stubStatusWaiter struct { waitForAllErr error } -func (s *stubStatusWaiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) { +func (s *stubStatusWaiter) InitializeValidators([]atls.Validator) { s.initialized = true } diff --git a/cli/proto/client.go b/cli/proto/client.go index 8132ed21a..eeb12bcbd 100644 --- a/cli/proto/client.go +++ b/cli/proto/client.go @@ -7,8 +7,6 @@ import ( "net" "github.com/edgelesssys/constellation/coordinator/atls" - "github.com/edgelesssys/constellation/coordinator/attestation/azure" - "github.com/edgelesssys/constellation/coordinator/attestation/gcp" "github.com/edgelesssys/constellation/coordinator/kms" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -29,15 +27,9 @@ type Client struct { // The connection must be closed using Close(). If connect is // called on a client that already has a connection, the old // connection is closed. -func (c *Client) Connect(ip, port string, gcpPCRs, AzurePCRs map[uint32][]byte) error { +func (c *Client) Connect(ip, port string, validators []atls.Validator) error { addr := net.JoinHostPort(ip, port) - validators := []atls.Validator{ - gcp.NewValidator(gcpPCRs), - gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms - azure.NewValidator(map[uint32][]byte{}), - } - tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) if err != nil { return err diff --git a/cli/status/status.go b/cli/status/status.go index 4cac2cd42..99ad157d4 100644 --- a/cli/status/status.go +++ b/cli/status/status.go @@ -7,8 +7,6 @@ import ( "time" "github.com/edgelesssys/constellation/coordinator/atls" - "github.com/edgelesssys/constellation/coordinator/attestation/azure" - "github.com/edgelesssys/constellation/coordinator/attestation/gcp" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/state" "google.golang.org/grpc" @@ -35,9 +33,9 @@ func NewWaiter() *Waiter { } } -// InitializePCRs initializes the PCRs for the attestation validators. -func (w *Waiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) { - w.newConn = newAttestedConnGenerator(gcpPCRs, azurePCRs) +// InitializeValidators initializes the validators for the attestation. +func (w *Waiter) InitializeValidators(validators []atls.Validator) { + w.newConn = newAttestedConnGenerator(validators) w.initialized = true } @@ -109,14 +107,8 @@ func (w *Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...s } // newAttestedConnGenerator creates a function returning a default attested grpc connection. -func newAttestedConnGenerator(gcpPCRs map[uint32][]byte, azurePCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { +func newAttestedConnGenerator(validators []atls.Validator) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { - validators := []atls.Validator{ - gcp.NewValidator(gcpPCRs), - gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms - azure.NewValidator(azurePCRs), - } - tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) if err != nil { return nil, err diff --git a/util/pcr-reader/main.go b/util/pcr-reader/main.go index cbbb48230..b6137881c 100644 --- a/util/pcr-reader/main.go +++ b/util/pcr-reader/main.go @@ -42,7 +42,7 @@ func main() { // wait for coordinator to come online waiter := status.NewWaiter() - waiter.InitializePCRs(map[uint32][]byte{}, map[uint32][]byte{}) + waiter.InitializeValidators(nil) if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil { log.Fatal(err) }