mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-12-24 15:09:39 -05:00
Move validators to cloudcmd
This commit is contained in:
parent
dad9a97ee2
commit
4e29c38027
112
cli/cloud/cloudcmd/validators.go
Normal file
112
cli/cloud/cloudcmd/validators.go
Normal file
@ -0,0 +1,112 @@
|
||||
package cloudcmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/vtpm"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
)
|
||||
|
||||
type Validators struct {
|
||||
validators []atls.Validator
|
||||
pcrWarnings string
|
||||
pcrWarningsInit string
|
||||
}
|
||||
|
||||
func NewValidators(provider cloudprovider.Provider, config *config.Config) (Validators, error) {
|
||||
v := Validators{}
|
||||
switch provider {
|
||||
case cloudprovider.GCP:
|
||||
gcpPCRs := *config.Provider.GCP.PCRs
|
||||
if err := v.checkPCRs(gcpPCRs); err != nil {
|
||||
return Validators{}, err
|
||||
}
|
||||
v.setPCRWarnings(gcpPCRs)
|
||||
v.validators = []atls.Validator{
|
||||
gcp.NewValidator(gcpPCRs),
|
||||
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non CVMs.
|
||||
}
|
||||
case cloudprovider.Azure:
|
||||
azurePCRs := *config.Provider.Azure.PCRs
|
||||
if err := v.checkPCRs(azurePCRs); err != nil {
|
||||
return Validators{}, err
|
||||
}
|
||||
v.setPCRWarnings(azurePCRs)
|
||||
v.validators = []atls.Validator{
|
||||
azure.NewValidator(azurePCRs),
|
||||
}
|
||||
default:
|
||||
return Validators{}, errors.New("unsupported cloud provider")
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// V returns validators as list of atls.Validator.
|
||||
func (v *Validators) V() []atls.Validator {
|
||||
return v.validators
|
||||
}
|
||||
|
||||
// Warnings returns warnings for the specifc PCR values that are not verified.
|
||||
//
|
||||
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
|
||||
func (v *Validators) Warnings() string {
|
||||
return v.pcrWarnings
|
||||
}
|
||||
|
||||
// WarningsIncludeInit returns warnings for the specifc PCR values that are not verified.
|
||||
// Warnings regarding the initialization are included.
|
||||
//
|
||||
// PCR allocation inspired by https://link.springer.com/chapter/10.1007/978-1-4302-6584-9_12#Tab1
|
||||
func (v *Validators) WarningsIncludeInit() string {
|
||||
return v.pcrWarnings + v.pcrWarningsInit
|
||||
}
|
||||
|
||||
func (v *Validators) checkPCRs(pcrs map[uint32][]byte) error {
|
||||
for k, v := range pcrs {
|
||||
if len(v) != 32 {
|
||||
return fmt.Errorf("bad config: PCR[%d]: expected length: %d, but got: %d", k, 32, len(v))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *Validators) setPCRWarnings(pcrs map[uint32][]byte) {
|
||||
const warningStr = "Warning: not verifying the Constellation's %s measurements\n"
|
||||
sb := &strings.Builder{}
|
||||
|
||||
if pcrs[0] == nil || pcrs[1] == nil {
|
||||
writeFmt(sb, warningStr, "BIOS")
|
||||
}
|
||||
|
||||
if pcrs[2] == nil || pcrs[3] == nil {
|
||||
writeFmt(sb, warningStr, "OPROM")
|
||||
}
|
||||
|
||||
if pcrs[4] == nil || pcrs[5] == nil {
|
||||
writeFmt(sb, warningStr, "MBR")
|
||||
}
|
||||
|
||||
// GRUB measures kernel command line and initrd into pcrs 8 and 9
|
||||
if pcrs[8] == nil {
|
||||
writeFmt(sb, warningStr, "kernel command line")
|
||||
}
|
||||
if pcrs[9] == nil {
|
||||
writeFmt(sb, warningStr, "initrd")
|
||||
}
|
||||
v.pcrWarnings = sb.String()
|
||||
|
||||
// Write init warnings separate.
|
||||
if pcrs[uint32(vtpm.PCRIndexOwnerID)] == nil || pcrs[uint32(vtpm.PCRIndexClusterID)] == nil {
|
||||
v.pcrWarningsInit = fmt.Sprintf(warningStr, "initialization status")
|
||||
}
|
||||
}
|
||||
|
||||
func writeFmt(sb *strings.Builder, fmtStr string, args ...interface{}) {
|
||||
sb.WriteString(fmt.Sprintf(fmtStr, args...))
|
||||
}
|
200
cli/cloud/cloudcmd/validators_test.go
Normal file
200
cli/cloud/cloudcmd/validators_test.go
Normal file
@ -0,0 +1,200 @@
|
||||
package cloudcmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestWarnAboutPCRs(t *testing.T) {
|
||||
zero := []byte("00000000000000000000000000000000")
|
||||
|
||||
testCases := map[string]struct {
|
||||
pcrs map[uint32][]byte
|
||||
wantWarnings []string
|
||||
wantWInclude []string
|
||||
wantErr bool
|
||||
}{
|
||||
"no warnings": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
6: zero,
|
||||
7: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
10: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
},
|
||||
"no warnings for missing non critical values": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
},
|
||||
"warn for BIOS": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
wantWarnings: []string{"BIOS"},
|
||||
},
|
||||
"warn for OPROM": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
wantWarnings: []string{"OPROM"},
|
||||
},
|
||||
"warn for MBR": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
wantWarnings: []string{"MBR"},
|
||||
},
|
||||
"warn for kernel": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
wantWarnings: []string{"kernel"},
|
||||
},
|
||||
"warn for initrd": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
11: zero,
|
||||
12: zero,
|
||||
},
|
||||
wantWarnings: []string{"initrd"},
|
||||
},
|
||||
"warn for initialization": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: zero,
|
||||
1: zero,
|
||||
2: zero,
|
||||
3: zero,
|
||||
4: zero,
|
||||
5: zero,
|
||||
8: zero,
|
||||
9: zero,
|
||||
11: zero,
|
||||
},
|
||||
wantWInclude: []string{"initialization"},
|
||||
},
|
||||
"multi warning": {
|
||||
pcrs: map[uint32][]byte{},
|
||||
wantWarnings: []string{
|
||||
"BIOS",
|
||||
"OPROM",
|
||||
"MBR",
|
||||
"initrd",
|
||||
"kernel",
|
||||
},
|
||||
wantWInclude: []string{"initialization"},
|
||||
},
|
||||
"bad config": {
|
||||
pcrs: map[uint32][]byte{
|
||||
0: []byte("000"),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, provider := range []string{"gcp", "azure", "unknown"} {
|
||||
t.Run(provider, func(t *testing.T) {
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
config := &config.Config{
|
||||
Provider: &config.ProviderConfig{
|
||||
Azure: &config.AzureConfig{PCRs: &tc.pcrs},
|
||||
GCP: &config.GCPConfig{PCRs: &tc.pcrs},
|
||||
},
|
||||
}
|
||||
|
||||
validators, err := NewValidators(cloudprovider.FromString(provider), config)
|
||||
|
||||
v := validators.V()
|
||||
warnings := validators.Warnings()
|
||||
warningsInclueInit := validators.WarningsIncludeInit()
|
||||
|
||||
if tc.wantErr || provider == "unknown" {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
if len(tc.wantWarnings) == 0 {
|
||||
assert.Empty(warnings)
|
||||
}
|
||||
for _, w := range tc.wantWarnings {
|
||||
assert.Contains(warnings, w)
|
||||
}
|
||||
for _, w := range tc.wantWarnings {
|
||||
assert.Contains(warningsInclueInit, w)
|
||||
}
|
||||
if len(tc.wantWInclude) == 0 {
|
||||
assert.Equal(len(warnings), len(warningsInclueInit))
|
||||
} else {
|
||||
assert.Greater(len(warningsInclueInit), len(warnings))
|
||||
}
|
||||
for _, w := range tc.wantWInclude {
|
||||
assert.Contains(warningsInclueInit, w)
|
||||
}
|
||||
assert.NotEmpty(v)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@ -12,11 +12,13 @@ import (
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/azure"
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/cli/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/cli/gcp"
|
||||
"github.com/edgelesssys/constellation/cli/proto"
|
||||
"github.com/edgelesssys/constellation/cli/status"
|
||||
"github.com/edgelesssys/constellation/cli/vpn"
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
coordinatorstate "github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/edgelesssys/constellation/coordinator/util"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
@ -74,8 +76,6 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||
return err
|
||||
}
|
||||
|
||||
waiter.InitializePCRs(*config.Provider.GCP.PCRs, *config.Provider.Azure.PCRs)
|
||||
|
||||
var stat state.ConstellationState
|
||||
err = fileHandler.ReadJSON(constants.StateFilename, &stat)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
@ -84,16 +84,11 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||
return err
|
||||
}
|
||||
|
||||
switch stat.CloudProvider {
|
||||
case "GCP":
|
||||
if err := warnAboutPCRs(cmd, *config.Provider.GCP.PCRs, true); err != nil {
|
||||
return err
|
||||
}
|
||||
case "Azure":
|
||||
if err := warnAboutPCRs(cmd, *config.Provider.Azure.PCRs, true); err != nil {
|
||||
return err
|
||||
}
|
||||
validators, err := cloudcmd.NewValidators(cloudprovider.FromString(stat.CloudProvider), config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(validators.WarningsIncludeInit())
|
||||
|
||||
cmd.Println("Creating service account ...")
|
||||
serviceAccount, stat, err := serviceAccCreator.Create(ctx, stat, config)
|
||||
@ -110,7 +105,9 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||
}
|
||||
|
||||
endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), *config.CoordinatorPort)
|
||||
|
||||
cmd.Println("Waiting for cloud provider to finish resource creation ...")
|
||||
waiter.InitializeValidators(validators.V())
|
||||
if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil {
|
||||
return fmt.Errorf("failed to wait for peer status: %w", err)
|
||||
}
|
||||
@ -128,7 +125,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||
autoscalingNodeGroups: autoscalingNodeGroups,
|
||||
cloudServiceAccountURI: serviceAccount,
|
||||
}
|
||||
result, err := activate(ctx, cmd, protCl, input, config)
|
||||
result, err := activate(ctx, cmd, protCl, input, config, validators.V())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -156,13 +153,10 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||
return nil
|
||||
}
|
||||
|
||||
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config) (activationResult, error) {
|
||||
err := client.Connect(
|
||||
input.coordinatorPubIP,
|
||||
*config.CoordinatorPort,
|
||||
*config.Provider.GCP.PCRs,
|
||||
*config.Provider.Azure.PCRs,
|
||||
)
|
||||
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput,
|
||||
config *config.Config, validators []atls.Validator,
|
||||
) (activationResult, error) {
|
||||
err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort, validators)
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
|
@ -34,25 +34,8 @@ func TestInitArgumentValidation(t *testing.T) {
|
||||
|
||||
func TestInitialize(t *testing.T) {
|
||||
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
|
||||
testEc2State := state.ConstellationState{
|
||||
CloudProvider: "AWS",
|
||||
EC2Instances: ec2.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
"id-2": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
},
|
||||
EC2SecurityGroup: "sg-test",
|
||||
}
|
||||
testGcpState := state.ConstellationState{
|
||||
CloudProvider: "GCP",
|
||||
GCPNodes: gcp.Instances{
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
@ -96,15 +79,6 @@ func TestInitialize(t *testing.T) {
|
||||
initVPN bool
|
||||
errExpected bool
|
||||
}{
|
||||
"initialize some ec2 instances": {
|
||||
existingState: testEc2State,
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some gcp instances": {
|
||||
existingState: testGcpState,
|
||||
client: &fakeProtoClient{
|
||||
@ -185,19 +159,8 @@ func TestInitialize(t *testing.T) {
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
},
|
||||
"only one instance": {
|
||||
existingState: state.ConstellationState{
|
||||
EC2Instances: ec2.Instances{"id-1": {}},
|
||||
EC2SecurityGroup: "sg-test",
|
||||
},
|
||||
client: &stubProtoClient{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
},
|
||||
"public key to short": {
|
||||
existingState: testEc2State,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
|
||||
@ -205,7 +168,7 @@ func TestInitialize(t *testing.T) {
|
||||
errExpected: true,
|
||||
},
|
||||
"public key to long": {
|
||||
existingState: testEc2State,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
|
||||
@ -213,7 +176,7 @@ func TestInitialize(t *testing.T) {
|
||||
errExpected: true,
|
||||
},
|
||||
"public key not base64": {
|
||||
existingState: testEc2State,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: "this is not base64 encoded",
|
||||
@ -221,7 +184,7 @@ func TestInitialize(t *testing.T) {
|
||||
errExpected: true,
|
||||
},
|
||||
"fail Connect": {
|
||||
existingState: testEc2State,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{connectErr: someErr},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
@ -229,7 +192,7 @@ func TestInitialize(t *testing.T) {
|
||||
errExpected: true,
|
||||
},
|
||||
"fail Activate": {
|
||||
existingState: testEc2State,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{activateErr: someErr},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
@ -237,7 +200,7 @@ func TestInitialize(t *testing.T) {
|
||||
errExpected: true,
|
||||
},
|
||||
"fail respClient WriteLogStream": {
|
||||
existingState: testEc2State,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
@ -245,7 +208,7 @@ func TestInitialize(t *testing.T) {
|
||||
errExpected: true,
|
||||
},
|
||||
"fail respClient getKubeconfig": {
|
||||
existingState: testEc2State,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
@ -253,7 +216,7 @@ func TestInitialize(t *testing.T) {
|
||||
errExpected: true,
|
||||
},
|
||||
"fail respClient getCoordinatorVpnKey": {
|
||||
existingState: testEc2State,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
@ -261,7 +224,7 @@ func TestInitialize(t *testing.T) {
|
||||
errExpected: true,
|
||||
},
|
||||
"fail respClient getClientVpnIp": {
|
||||
existingState: testEc2State,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
@ -269,7 +232,7 @@ func TestInitialize(t *testing.T) {
|
||||
errExpected: true,
|
||||
},
|
||||
"fail respClient getOwnerID": {
|
||||
existingState: testEc2State,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
@ -277,7 +240,7 @@ func TestInitialize(t *testing.T) {
|
||||
errExpected: true,
|
||||
},
|
||||
"fail respClient getClusterID": {
|
||||
existingState: testEc2State,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
@ -293,15 +256,13 @@ func TestInitialize(t *testing.T) {
|
||||
errExpected: true,
|
||||
},
|
||||
"fail to create service account": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
serviceAccountCreator: stubServiceAccountCreator{
|
||||
createErr: someErr,
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
serviceAccountCreator: stubServiceAccountCreator{createErr: someErr},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
},
|
||||
}
|
||||
|
||||
@ -532,15 +493,8 @@ func TestReadOrGeneratedMasterSecret(t *testing.T) {
|
||||
|
||||
func TestAutoscaleFlag(t *testing.T) {
|
||||
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
|
||||
testEc2State := state.ConstellationState{
|
||||
EC2Instances: ec2.Instances{
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
|
||||
"id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
|
||||
},
|
||||
EC2SecurityGroup: "sg-test",
|
||||
}
|
||||
testGcpState := state.ConstellationState{
|
||||
CloudProvider: "gcp",
|
||||
GCPNodes: gcp.Instances{
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
@ -550,6 +504,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||
},
|
||||
}
|
||||
testAzureState := state.ConstellationState{
|
||||
CloudProvider: "azure",
|
||||
AzureNodes: azure.Instances{
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
@ -580,15 +535,6 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||
waiter statusWaiter
|
||||
privKey string
|
||||
}{
|
||||
"initialize some ec2 instances without autoscale flag": {
|
||||
autoscaleFlag: false,
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some gcp instances without autoscale flag": {
|
||||
autoscaleFlag: false,
|
||||
existingState: testGcpState,
|
||||
@ -607,15 +553,6 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some ec2 instances with autoscale flag": {
|
||||
autoscaleFlag: true,
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some gcp instances with autoscale flag": {
|
||||
autoscaleFlag: true,
|
||||
existingState: testGcpState,
|
||||
|
@ -4,10 +4,11 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/proto"
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
)
|
||||
|
||||
type protoClient interface {
|
||||
Connect(ip, port string, gcpPCRs, azurePCRs map[uint32][]byte) error
|
||||
Connect(ip, port string, validators []atls.Validator) error
|
||||
Close() error
|
||||
Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
|
||||
}
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/proto"
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
)
|
||||
|
||||
type stubProtoClient struct {
|
||||
@ -23,7 +24,7 @@ type stubProtoClient struct {
|
||||
cloudServiceAccountURI string
|
||||
}
|
||||
|
||||
func (c *stubProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error {
|
||||
func (c *stubProtoClient) Connect(_, _ string, _ []atls.Validator) error {
|
||||
c.conn = true
|
||||
return c.connectErr
|
||||
}
|
||||
@ -89,7 +90,13 @@ type fakeProtoClient struct {
|
||||
respClient proto.ActivationResponseClient
|
||||
}
|
||||
|
||||
func (c *fakeProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error {
|
||||
func (c *fakeProtoClient) Connect(ip, port string, validators []atls.Validator) error {
|
||||
if ip == "" || port == "" {
|
||||
return errors.New("ip or port is empty")
|
||||
}
|
||||
if len(validators) == 0 {
|
||||
return errors.New("validators is empty")
|
||||
}
|
||||
c.conn = true
|
||||
return nil
|
||||
}
|
||||
|
@ -3,10 +3,11 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
)
|
||||
|
||||
type statusWaiter interface {
|
||||
InitializePCRs(map[uint32][]byte, map[uint32][]byte)
|
||||
InitializeValidators([]atls.Validator)
|
||||
WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
)
|
||||
|
||||
@ -12,7 +13,7 @@ type stubStatusWaiter struct {
|
||||
waitForAllErr error
|
||||
}
|
||||
|
||||
func (s *stubStatusWaiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) {
|
||||
func (s *stubStatusWaiter) InitializeValidators([]atls.Validator) {
|
||||
s.initialized = true
|
||||
}
|
||||
|
||||
|
@ -7,8 +7,6 @@ import (
|
||||
"net"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
||||
"github.com/edgelesssys/constellation/coordinator/kms"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
@ -29,15 +27,9 @@ type Client struct {
|
||||
// The connection must be closed using Close(). If connect is
|
||||
// called on a client that already has a connection, the old
|
||||
// connection is closed.
|
||||
func (c *Client) Connect(ip, port string, gcpPCRs, AzurePCRs map[uint32][]byte) error {
|
||||
func (c *Client) Connect(ip, port string, validators []atls.Validator) error {
|
||||
addr := net.JoinHostPort(ip, port)
|
||||
|
||||
validators := []atls.Validator{
|
||||
gcp.NewValidator(gcpPCRs),
|
||||
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
|
||||
azure.NewValidator(map[uint32][]byte{}),
|
||||
}
|
||||
|
||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -7,8 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
"google.golang.org/grpc"
|
||||
@ -35,9 +33,9 @@ func NewWaiter() *Waiter {
|
||||
}
|
||||
}
|
||||
|
||||
// InitializePCRs initializes the PCRs for the attestation validators.
|
||||
func (w *Waiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) {
|
||||
w.newConn = newAttestedConnGenerator(gcpPCRs, azurePCRs)
|
||||
// InitializeValidators initializes the validators for the attestation.
|
||||
func (w *Waiter) InitializeValidators(validators []atls.Validator) {
|
||||
w.newConn = newAttestedConnGenerator(validators)
|
||||
w.initialized = true
|
||||
}
|
||||
|
||||
@ -109,14 +107,8 @@ func (w *Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...s
|
||||
}
|
||||
|
||||
// newAttestedConnGenerator creates a function returning a default attested grpc connection.
|
||||
func newAttestedConnGenerator(gcpPCRs map[uint32][]byte, azurePCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
|
||||
func newAttestedConnGenerator(validators []atls.Validator) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
|
||||
return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
|
||||
validators := []atls.Validator{
|
||||
gcp.NewValidator(gcpPCRs),
|
||||
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
|
||||
azure.NewValidator(azurePCRs),
|
||||
}
|
||||
|
||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -42,7 +42,7 @@ func main() {
|
||||
|
||||
// wait for coordinator to come online
|
||||
waiter := status.NewWaiter()
|
||||
waiter.InitializePCRs(map[uint32][]byte{}, map[uint32][]byte{})
|
||||
waiter.InitializeValidators(nil)
|
||||
if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user