Move validators to cloudcmd

This commit is contained in:
katexochen 2022-04-19 17:02:02 +02:00 committed by Paul Meyer
parent dad9a97ee2
commit 4e29c38027
11 changed files with 367 additions and 130 deletions

View File

@ -0,0 +1,112 @@
package cloudcmd
import (
"errors"
"fmt"
"strings"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
"github.com/edgelesssys/constellation/internal/config"
)
type Validators struct {
validators []atls.Validator
pcrWarnings string
pcrWarningsInit string
}
func NewValidators(provider cloudprovider.Provider, config *config.Config) (Validators, error) {
v := Validators{}
switch provider {
case cloudprovider.GCP:
gcpPCRs := *config.Provider.GCP.PCRs
if err := v.checkPCRs(gcpPCRs); err != nil {
return Validators{}, err
}
v.setPCRWarnings(gcpPCRs)
v.validators = []atls.Validator{
gcp.NewValidator(gcpPCRs),
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non CVMs.
}
case cloudprovider.Azure:
azurePCRs := *config.Provider.Azure.PCRs
if err := v.checkPCRs(azurePCRs); err != nil {
return Validators{}, err
}
v.setPCRWarnings(azurePCRs)
v.validators = []atls.Validator{
azure.NewValidator(azurePCRs),
}
default:
return Validators{}, errors.New("unsupported cloud provider")
}
return v, nil
}
// V returns validators as list of atls.Validator.
func (v *Validators) V() []atls.Validator {
return v.validators
}
// Warnings returns warnings for the specifc PCR values that are not verified.
//
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
func (v *Validators) Warnings() string {
return v.pcrWarnings
}
// WarningsIncludeInit returns warnings for the specifc PCR values that are not verified.
// Warnings regarding the initialization are included.
//
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
func (v *Validators) WarningsIncludeInit() string {
return v.pcrWarnings + v.pcrWarningsInit
}
func (v *Validators) checkPCRs(pcrs map[uint32][]byte) error {
for k, v := range pcrs {
if len(v) != 32 {
return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v))
}
}
return nil
}
func (v *Validators) setPCRWarnings(pcrs map[uint32][]byte) {
const warningStr = "Warning: not verifying the Constellation's %s measurements\n"
sb := &strings.Builder{}
if pcrs[0] == nil || pcrs[1] == nil {
writeFmt(sb, warningStr, "BIOS")
}
if pcrs[2] == nil || pcrs[3] == nil {
writeFmt(sb, warningStr, "OPROM")
}
if pcrs[4] == nil || pcrs[5] == nil {
writeFmt(sb, warningStr, "MBR")
}
// GRUB measures kernel command line and initrd into pcrs 8 and 9
if pcrs[8] == nil {
writeFmt(sb, warningStr, "kernel command line")
}
if pcrs[9] == nil {
writeFmt(sb, warningStr, "initrd")
}
v.pcrWarnings = sb.String()
// Write init warnings separate.
if pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
v.pcrWarningsInit = fmt.Sprintf(warningStr, "initialization status")
}
}
func writeFmt(sb *strings.Builder, fmtStr string, args ...interface{}) {
sb.WriteString(fmt.Sprintf(fmtStr, args...))
}

View File

@ -0,0 +1,200 @@
package cloudcmd
import (
"testing"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/stretchr/testify/assert"
)
func TestWarnAboutPCRs(t *testing.T) {
zero := []byte("00000000000000000000000000000000")
testCases := map[string]struct {
pcrs map[uint32][]byte
wantWarnings []string
wantWInclude []string
wantErr bool
}{
"no warnings": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
6: zero,
7: zero,
8: zero,
9: zero,
10: zero,
11: zero,
12: zero,
},
},
"no warnings for missing non critical values": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
12: zero,
},
},
"warn for BIOS": {
pcrs: map[uint32][]byte{
0: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
12: zero,
},
wantWarnings: []string{"BIOS"},
},
"warn for OPROM": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
12: zero,
},
wantWarnings: []string{"OPROM"},
},
"warn for MBR": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
5: zero,
8: zero,
9: zero,
11: zero,
12: zero,
},
wantWarnings: []string{"MBR"},
},
"warn for kernel": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
9: zero,
11: zero,
12: zero,
},
wantWarnings: []string{"kernel"},
},
"warn for initrd": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
11: zero,
12: zero,
},
wantWarnings: []string{"initrd"},
},
"warn for initialization": {
pcrs: map[uint32][]byte{
0: zero,
1: zero,
2: zero,
3: zero,
4: zero,
5: zero,
8: zero,
9: zero,
11: zero,
},
wantWInclude: []string{"initialization"},
},
"multi warning": {
pcrs: map[uint32][]byte{},
wantWarnings: []string{
"BIOS",
"OPROM",
"MBR",
"initrd",
"kernel",
},
wantWInclude: []string{"initialization"},
},
"bad config": {
pcrs: map[uint32][]byte{
0: []byte("000"),
},
wantErr: true,
},
}
for _, provider := range []string{"gcp", "azure", "unknown"} {
t.Run(provider, func(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
config := &config.Config{
Provider: &config.ProviderConfig{
Azure: &config.AzureConfig{PCRs: &tc.pcrs},
GCP: &config.GCPConfig{PCRs: &tc.pcrs},
},
}
validators, err := NewValidators(cloudprovider.FromString(provider), config)
v := validators.V()
warnings := validators.Warnings()
warningsInclueInit := validators.WarningsIncludeInit()
if tc.wantErr || provider == "unknown" {
assert.Error(err)
} else {
assert.NoError(err)
if len(tc.wantWarnings) == 0 {
assert.Empty(warnings)
}
for _, w := range tc.wantWarnings {
assert.Contains(warnings, w)
}
for _, w := range tc.wantWarnings {
assert.Contains(warningsInclueInit, w)
}
if len(tc.wantWInclude) == 0 {
assert.Equal(len(warnings), len(warningsInclueInit))
} else {
assert.Greater(len(warningsInclueInit), len(warnings))
}
for _, w := range tc.wantWInclude {
assert.Contains(warningsInclueInit, w)
}
assert.NotEmpty(v)
}
})
}
})
}
}

View File

@ -12,11 +12,13 @@ import (
"github.com/edgelesssys/constellation/cli/azure" "github.com/edgelesssys/constellation/cli/azure"
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd" "github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
"github.com/edgelesssys/constellation/cli/cloudprovider"
"github.com/edgelesssys/constellation/cli/file" "github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/cli/gcp" "github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/cli/proto" "github.com/edgelesssys/constellation/cli/proto"
"github.com/edgelesssys/constellation/cli/status" "github.com/edgelesssys/constellation/cli/status"
"github.com/edgelesssys/constellation/cli/vpn" "github.com/edgelesssys/constellation/cli/vpn"
"github.com/edgelesssys/constellation/coordinator/atls"
coordinatorstate "github.com/edgelesssys/constellation/coordinator/state" coordinatorstate "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util" "github.com/edgelesssys/constellation/coordinator/util"
"github.com/edgelesssys/constellation/internal/config" "github.com/edgelesssys/constellation/internal/config"
@ -74,8 +76,6 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
return err return err
} }
waiter.InitializePCRs(*config.Provider.GCP.PCRs, *config.Provider.Azure.PCRs)
var stat state.ConstellationState var stat state.ConstellationState
err = fileHandler.ReadJSON(constants.StateFilename, &stat) err = fileHandler.ReadJSON(constants.StateFilename, &stat)
if errors.Is(err, fs.ErrNotExist) { if errors.Is(err, fs.ErrNotExist) {
@ -84,16 +84,11 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
return err return err
} }
switch stat.CloudProvider { validators, err := cloudcmd.NewValidators(cloudprovider.FromString(stat.CloudProvider), config)
case "GCP": if err != nil {
if err := warnAboutPCRs(cmd, *config.Provider.GCP.PCRs, true); err != nil {
return err return err
} }
case "Azure": cmd.Print(validators.WarningsIncludeInit())
if err := warnAboutPCRs(cmd, *config.Provider.Azure.PCRs, true); err != nil {
return err
}
}
cmd.Println("Creating service account ...") cmd.Println("Creating service account ...")
serviceAccount, stat, err := serviceAccCreator.Create(ctx, stat, config) serviceAccount, stat, err := serviceAccCreator.Create(ctx, stat, config)
@ -110,7 +105,9 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
} }
endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort) endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort)
cmd.Println("Waiting for cloud provider to finish resource creation ...") cmd.Println("Waiting for cloud provider to finish resource creation ...")
waiter.InitializeValidators(validators.V())
if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil { if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil {
return fmt.Errorf("failed to wait for peer status: %w", err) return fmt.Errorf("failed to wait for peer status: %w", err)
} }
@ -128,7 +125,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
autoscalingNodeGroups: autoscalingNodeGroups, autoscalingNodeGroups: autoscalingNodeGroups,
cloudServiceAccountURI: serviceAccount, cloudServiceAccountURI: serviceAccount,
} }
result, err := activate(ctx, cmd, protCl, input, config) result, err := activate(ctx, cmd, protCl, input, config, validators.V())
if err != nil { if err != nil {
return err return err
} }
@ -156,13 +153,10 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
return nil return nil
} }
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config) (activationResult, error) { func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput,
err := client.Connect( config *config.Config, validators []atls.Validator,
input.coordinatorPubIP, ) (activationResult, error) {
*config.CoordinatorPort, err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort, validators)
*config.Provider.GCP.PCRs,
*config.Provider.Azure.PCRs,
)
if err != nil { if err != nil {
return activationResult{}, err return activationResult{}, err
} }

View File

@ -34,25 +34,8 @@ func TestInitArgumentValidation(t *testing.T) {
func TestInitialize(t *testing.T) { func TestInitialize(t *testing.T) {
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")) testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
testEc2State := state.ConstellationState{
CloudProvider: "AWS",
EC2Instances: ec2.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-2": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
},
EC2SecurityGroup: "sg-test",
}
testGcpState := state.ConstellationState{ testGcpState := state.ConstellationState{
CloudProvider: "GCP",
GCPNodes: gcp.Instances{ GCPNodes: gcp.Instances{
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
@ -96,15 +79,6 @@ func TestInitialize(t *testing.T) {
initVPN bool initVPN bool
errExpected bool errExpected bool
}{ }{
"initialize some ec2 instances": {
existingState: testEc2State,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
privKey: testKey,
},
"initialize some gcp instances": { "initialize some gcp instances": {
existingState: testGcpState, existingState: testGcpState,
client: &fakeProtoClient{ client: &fakeProtoClient{
@ -185,19 +159,8 @@ func TestInitialize(t *testing.T) {
vpnHandler: &stubVPNHandler{}, vpnHandler: &stubVPNHandler{},
errExpected: true, errExpected: true,
}, },
"only one instance": {
existingState: state.ConstellationState{
EC2Instances: ec2.Instances{"id-1": {}},
EC2SecurityGroup: "sg-test",
},
client: &stubProtoClient{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
},
"public key to short": { "public key to short": {
existingState: testEc2State, existingState: testGcpState,
client: &stubProtoClient{}, client: &stubProtoClient{},
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")), privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
@ -205,7 +168,7 @@ func TestInitialize(t *testing.T) {
errExpected: true, errExpected: true,
}, },
"public key to long": { "public key to long": {
existingState: testEc2State, existingState: testGcpState,
client: &stubProtoClient{}, client: &stubProtoClient{},
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")), privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
@ -213,7 +176,7 @@ func TestInitialize(t *testing.T) {
errExpected: true, errExpected: true,
}, },
"public key not base64": { "public key not base64": {
existingState: testEc2State, existingState: testGcpState,
client: &stubProtoClient{}, client: &stubProtoClient{},
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: "this is not base64 encoded", privKey: "this is not base64 encoded",
@ -221,7 +184,7 @@ func TestInitialize(t *testing.T) {
errExpected: true, errExpected: true,
}, },
"fail Connect": { "fail Connect": {
existingState: testEc2State, existingState: testGcpState,
client: &stubProtoClient{connectErr: someErr}, client: &stubProtoClient{connectErr: someErr},
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: testKey, privKey: testKey,
@ -229,7 +192,7 @@ func TestInitialize(t *testing.T) {
errExpected: true, errExpected: true,
}, },
"fail Activate": { "fail Activate": {
existingState: testEc2State, existingState: testGcpState,
client: &stubProtoClient{activateErr: someErr}, client: &stubProtoClient{activateErr: someErr},
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: testKey, privKey: testKey,
@ -237,7 +200,7 @@ func TestInitialize(t *testing.T) {
errExpected: true, errExpected: true,
}, },
"fail respClient WriteLogStream": { "fail respClient WriteLogStream": {
existingState: testEc2State, existingState: testGcpState,
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}}, client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: testKey, privKey: testKey,
@ -245,7 +208,7 @@ func TestInitialize(t *testing.T) {
errExpected: true, errExpected: true,
}, },
"fail respClient getKubeconfig": { "fail respClient getKubeconfig": {
existingState: testEc2State, existingState: testGcpState,
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}}, client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: testKey, privKey: testKey,
@ -253,7 +216,7 @@ func TestInitialize(t *testing.T) {
errExpected: true, errExpected: true,
}, },
"fail respClient getCoordinatorVpnKey": { "fail respClient getCoordinatorVpnKey": {
existingState: testEc2State, existingState: testGcpState,
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}}, client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: testKey, privKey: testKey,
@ -261,7 +224,7 @@ func TestInitialize(t *testing.T) {
errExpected: true, errExpected: true,
}, },
"fail respClient getClientVpnIp": { "fail respClient getClientVpnIp": {
existingState: testEc2State, existingState: testGcpState,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}}, client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: testKey, privKey: testKey,
@ -269,7 +232,7 @@ func TestInitialize(t *testing.T) {
errExpected: true, errExpected: true,
}, },
"fail respClient getOwnerID": { "fail respClient getOwnerID": {
existingState: testEc2State, existingState: testGcpState,
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}}, client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: testKey, privKey: testKey,
@ -277,7 +240,7 @@ func TestInitialize(t *testing.T) {
errExpected: true, errExpected: true,
}, },
"fail respClient getClusterID": { "fail respClient getClusterID": {
existingState: testEc2State, existingState: testGcpState,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}}, client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: testKey, privKey: testKey,
@ -295,9 +258,7 @@ func TestInitialize(t *testing.T) {
"fail to create service account": { "fail to create service account": {
existingState: testGcpState, existingState: testGcpState,
client: &stubProtoClient{}, client: &stubProtoClient{},
serviceAccountCreator: stubServiceAccountCreator{ serviceAccountCreator: stubServiceAccountCreator{createErr: someErr},
createErr: someErr,
},
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: testKey, privKey: testKey,
vpnHandler: &stubVPNHandler{}, vpnHandler: &stubVPNHandler{},
@ -532,15 +493,8 @@ func TestReadOrGeneratedMasterSecret(t *testing.T) {
func TestAutoscaleFlag(t *testing.T) { func TestAutoscaleFlag(t *testing.T) {
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")) testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
testEc2State := state.ConstellationState{
EC2Instances: ec2.Instances{
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
"id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
},
EC2SecurityGroup: "sg-test",
}
testGcpState := state.ConstellationState{ testGcpState := state.ConstellationState{
CloudProvider: "gcp",
GCPNodes: gcp.Instances{ GCPNodes: gcp.Instances{
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
@ -550,6 +504,7 @@ func TestAutoscaleFlag(t *testing.T) {
}, },
} }
testAzureState := state.ConstellationState{ testAzureState := state.ConstellationState{
CloudProvider: "azure",
AzureNodes: azure.Instances{ AzureNodes: azure.Instances{
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"}, "id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
@ -580,15 +535,6 @@ func TestAutoscaleFlag(t *testing.T) {
waiter statusWaiter waiter statusWaiter
privKey string privKey string
}{ }{
"initialize some ec2 instances without autoscale flag": {
autoscaleFlag: false,
existingState: testEc2State,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some gcp instances without autoscale flag": { "initialize some gcp instances without autoscale flag": {
autoscaleFlag: false, autoscaleFlag: false,
existingState: testGcpState, existingState: testGcpState,
@ -607,15 +553,6 @@ func TestAutoscaleFlag(t *testing.T) {
waiter: &stubStatusWaiter{}, waiter: &stubStatusWaiter{},
privKey: testKey, privKey: testKey,
}, },
"initialize some ec2 instances with autoscale flag": {
autoscaleFlag: true,
existingState: testEc2State,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some gcp instances with autoscale flag": { "initialize some gcp instances with autoscale flag": {
autoscaleFlag: true, autoscaleFlag: true,
existingState: testGcpState, existingState: testGcpState,

View File

@ -4,10 +4,11 @@ import (
"context" "context"
"github.com/edgelesssys/constellation/cli/proto" "github.com/edgelesssys/constellation/cli/proto"
"github.com/edgelesssys/constellation/coordinator/atls"
) )
type protoClient interface { type protoClient interface {
Connect(ip, port string, gcpPCRs, azurePCRs map[uint32][]byte) error Connect(ip, port string, validators []atls.Validator) error
Close() error Close() error
Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error) Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
} }

View File

@ -7,6 +7,7 @@ import (
"io" "io"
"github.com/edgelesssys/constellation/cli/proto" "github.com/edgelesssys/constellation/cli/proto"
"github.com/edgelesssys/constellation/coordinator/atls"
) )
type stubProtoClient struct { type stubProtoClient struct {
@ -23,7 +24,7 @@ type stubProtoClient struct {
cloudServiceAccountURI string cloudServiceAccountURI string
} }
func (c *stubProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error { func (c *stubProtoClient) Connect(_, _ string, _ []atls.Validator) error {
c.conn = true c.conn = true
return c.connectErr return c.connectErr
} }
@ -89,7 +90,13 @@ type fakeProtoClient struct {
respClient proto.ActivationResponseClient respClient proto.ActivationResponseClient
} }
func (c *fakeProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error { func (c *fakeProtoClient) Connect(ip, port string, validators []atls.Validator) error {
if ip == "" || port == "" {
return errors.New("ip or port is empty")
}
if len(validators) == 0 {
return errors.New("validators is empty")
}
c.conn = true c.conn = true
return nil return nil
} }

View File

@ -3,10 +3,11 @@ package cmd
import ( import (
"context" "context"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
) )
type statusWaiter interface { type statusWaiter interface {
InitializePCRs(map[uint32][]byte, map[uint32][]byte) InitializeValidators([]atls.Validator)
WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
) )
@ -12,7 +13,7 @@ type stubStatusWaiter struct {
waitForAllErr error waitForAllErr error
} }
func (s *stubStatusWaiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) { func (s *stubStatusWaiter) InitializeValidators([]atls.Validator) {
s.initialized = true s.initialized = true
} }

View File

@ -7,8 +7,6 @@ import (
"net" "net"
"github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/coordinator/kms" "github.com/edgelesssys/constellation/coordinator/kms"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@ -29,15 +27,9 @@ type Client struct {
// The connection must be closed using Close(). If connect is // The connection must be closed using Close(). If connect is
// called on a client that already has a connection, the old // called on a client that already has a connection, the old
// connection is closed. // connection is closed.
func (c *Client) Connect(ip, port string, gcpPCRs, AzurePCRs map[uint32][]byte) error { func (c *Client) Connect(ip, port string, validators []atls.Validator) error {
addr := net.JoinHostPort(ip, port) addr := net.JoinHostPort(ip, port)
validators := []atls.Validator{
gcp.NewValidator(gcpPCRs),
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
azure.NewValidator(map[uint32][]byte{}),
}
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
if err != nil { if err != nil {
return err return err

View File

@ -7,8 +7,6 @@ import (
"time" "time"
"github.com/edgelesssys/constellation/coordinator/atls" "github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto" "github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state" "github.com/edgelesssys/constellation/coordinator/state"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -35,9 +33,9 @@ func NewWaiter() *Waiter {
} }
} }
// InitializePCRs initializes the PCRs for the attestation validators. // InitializeValidators initializes the validators for the attestation.
func (w *Waiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) { func (w *Waiter) InitializeValidators(validators []atls.Validator) {
w.newConn = newAttestedConnGenerator(gcpPCRs, azurePCRs) w.newConn = newAttestedConnGenerator(validators)
w.initialized = true w.initialized = true
} }
@ -109,14 +107,8 @@ func (w *Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...s
} }
// newAttestedConnGenerator creates a function returning a default attested grpc connection. // newAttestedConnGenerator creates a function returning a default attested grpc connection.
func newAttestedConnGenerator(gcpPCRs map[uint32][]byte, azurePCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { func newAttestedConnGenerator(validators []atls.Validator) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
validators := []atls.Validator{
gcp.NewValidator(gcpPCRs),
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
azure.NewValidator(azurePCRs),
}
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -42,7 +42,7 @@ func main() {
// wait for coordinator to come online // wait for coordinator to come online
waiter := status.NewWaiter() waiter := status.NewWaiter()
waiter.InitializePCRs(map[uint32][]byte{}, map[uint32][]byte{}) waiter.InitializeValidators(nil)
if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil { if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil {
log.Fatal(err) log.Fatal(err)
} }