Fix PCR handling

This commit is contained in:
katexochen 2022-04-13 15:01:02 +02:00 committed by Paul Meyer
parent de52bf14da
commit 4496755c64
11 changed files with 182 additions and 183 deletions

View File

@ -49,38 +49,33 @@ func newInitCmd() *cobra.Command {
func runInitialize(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
vpnHandler := vpn.NewConfigHandler()
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
protoClient := proto.NewClient(*config.Provider.GCP.PCRs)
defer protoClient.Close()
if err != nil {
return err
}
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
waiter := status.NewWaiter()
protoClient := &proto.Client{}
defer protoClient.Close()
// We have to parse the context separately, since cmd.Context()
// returns nil during the tests otherwise.
return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, config, status.NewWaiter(*config.Provider.GCP.PCRs), vpnHandler)
return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler)
}
// initialize initializes a Constellation. Coordinator instances are activated as Coordinators and will
// themself activate the other peers as nodes.
func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator,
fileHandler file.Handler, config *config.Config, waiter statusWaiter, vpnHandler vpnHandler,
fileHandler file.Handler, waiter statusWaiter, vpnHandler vpnHandler,
) error {
flagArgs, err := evalFlagArgs(cmd, fileHandler)
flags, err := evalFlagArgs(cmd, fileHandler)
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, flags.devConfigPath)
if err != nil {
return err
}
waiter.InitializePCRs(*config.Provider.GCP.PCRs, *config.Provider.Azure.PCRs)
var stat state.ConstellationState
err = fileHandler.ReadJSON(constants.StateFilename, &stat)
if errors.Is(err, fs.ErrNotExist) {
@ -121,14 +116,14 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
}
var autoscalingNodeGroups []string
if flagArgs.autoscale {
if flags.autoscale {
autoscalingNodeGroups = append(autoscalingNodeGroups, nodes.GroupID)
}
input := activationInput{
coordinatorPubIP: coordinators.PublicIPs()[0],
pubKey: flagArgs.userPubKey,
masterSecret: flagArgs.masterSecret,
pubKey: flags.userPubKey,
masterSecret: flags.masterSecret,
nodePrivIPs: nodes.PrivateIPs(),
autoscalingNodeGroups: autoscalingNodeGroups,
cloudServiceAccountURI: serviceAccount,
@ -143,7 +138,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
return err
}
vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flagArgs.userPrivKey), result.clientVpnIP, wireguardAdminMTU)
vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flags.userPrivKey), result.clientVpnIP, wireguardAdminMTU)
if err != nil {
return err
}
@ -152,7 +147,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
return fmt.Errorf("write wg-quick file: %w", err)
}
if flagArgs.autoconfigureWG {
if flags.autoconfigureWG {
if err := vpnHandler.Apply(vpnConfig); err != nil {
return err
}
@ -162,7 +157,13 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
}
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config) (activationResult, error) {
if err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort); err != nil {
err := client.Connect(
input.coordinatorPubIP,
*config.CoordinatorPort,
*config.Provider.GCP.PCRs,
*config.Provider.Azure.PCRs,
)
if err != nil {
return activationResult{}, err
}
@ -265,33 +266,38 @@ func writeRow(wr io.Writer, col1 string, col2 string) {
// evalFlagArgs gets the flag values and does preprocessing of these values like
// reading the content from file path flags and deriving other values from flag combinations.
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (flagArgs, error) {
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (initFlags, error) {
userPrivKeyPath, err := cmd.Flags().GetString("privatekey")
if err != nil {
return flagArgs{}, err
return initFlags{}, err
}
userPrivKey, userPubKey, err := readOrGenerateVPNKey(fileHandler, userPrivKeyPath)
if err != nil {
return flagArgs{}, err
return initFlags{}, err
}
autoconfigureWG, err := cmd.Flags().GetBool("wg-autoconfig")
if err != nil {
return flagArgs{}, err
return initFlags{}, err
}
masterSecretPath, err := cmd.Flags().GetString("master-secret")
if err != nil {
return flagArgs{}, err
return initFlags{}, err
}
masterSecret, err := readOrGeneratedMasterSecret(cmd.OutOrStdout(), fileHandler, masterSecretPath)
if err != nil {
return flagArgs{}, err
return initFlags{}, err
}
autoscale, err := cmd.Flags().GetBool("autoscale")
if err != nil {
return flagArgs{}, err
return initFlags{}, err
}
devConfigPath, err := cmd.Flags().GetString("dev-config")
if err != nil {
return initFlags{}, err
}
return flagArgs{
return initFlags{
devConfigPath: devConfigPath,
userPrivKey: userPrivKey,
userPubKey: userPubKey,
autoconfigureWG: autoconfigureWG,
@ -300,8 +306,9 @@ func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (flagArgs, error
}, nil
}
// flagArgs are the resulting values of flag preprocessing.
type flagArgs struct {
// initFlags are the resulting values of flag preprocessing.
type initFlags struct {
devConfigPath string
userPrivKey []byte
userPubKey []byte
masterSecret []byte

View File

@ -14,7 +14,6 @@ import (
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/state"
wgquick "github.com/nmiculinic/wg-quick-go"
@ -35,7 +34,6 @@ func TestInitArgumentValidation(t *testing.T) {
func TestInitialize(t *testing.T) {
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
config := config.Default()
testEc2State := state.ConstellationState{
CloudProvider: "AWS",
EC2Instances: ec2.Instances{
@ -56,39 +54,21 @@ func TestInitialize(t *testing.T) {
}
testGcpState := state.ConstellationState{
GCPNodes: gcp.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
GCPCoordinators: gcp.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
}
testAzureState := state.ConstellationState{
CloudProvider: "Azure",
AzureNodes: azure.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureCoordinators: azure.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureResourceGroup: "test",
}
@ -121,7 +101,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
privKey: testKey,
},
@ -130,7 +110,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
privKey: testKey,
},
@ -139,7 +119,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
privKey: testKey,
},
@ -148,7 +128,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
initVPN: true,
privKey: testKey,
@ -158,7 +138,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{applyErr: someErr},
initVPN: true,
privKey: testKey,
@ -169,7 +149,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{createErr: someErr},
initVPN: true,
privKey: testKey,
@ -180,7 +160,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{marshalErr: someErr},
initVPN: true,
privKey: testKey,
@ -189,7 +169,7 @@ func TestInitialize(t *testing.T) {
"no state exists": {
existingState: state.ConstellationState{},
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -200,7 +180,7 @@ func TestInitialize(t *testing.T) {
EC2SecurityGroup: "sg-test",
},
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -211,7 +191,7 @@ func TestInitialize(t *testing.T) {
EC2SecurityGroup: "sg-test",
},
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -219,7 +199,7 @@ func TestInitialize(t *testing.T) {
"public key to short": {
existingState: testEc2State,
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -227,7 +207,7 @@ func TestInitialize(t *testing.T) {
"public key to long": {
existingState: testEc2State,
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -235,7 +215,7 @@ func TestInitialize(t *testing.T) {
"public key not base64": {
existingState: testEc2State,
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: "this is not base64 encoded",
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -243,7 +223,7 @@ func TestInitialize(t *testing.T) {
"fail Connect": {
existingState: testEc2State,
client: &stubProtoClient{connectErr: someErr},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -251,7 +231,7 @@ func TestInitialize(t *testing.T) {
"fail Activate": {
existingState: testEc2State,
client: &stubProtoClient{activateErr: someErr},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -259,7 +239,7 @@ func TestInitialize(t *testing.T) {
"fail respClient WriteLogStream": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -267,7 +247,7 @@ func TestInitialize(t *testing.T) {
"fail respClient getKubeconfig": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -275,7 +255,7 @@ func TestInitialize(t *testing.T) {
"fail respClient getCoordinatorVpnKey": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -283,7 +263,7 @@ func TestInitialize(t *testing.T) {
"fail respClient getClientVpnIp": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -291,7 +271,7 @@ func TestInitialize(t *testing.T) {
"fail respClient getOwnerID": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -299,7 +279,7 @@ func TestInitialize(t *testing.T) {
"fail respClient getClusterID": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -307,7 +287,7 @@ func TestInitialize(t *testing.T) {
"fail to wait for required status": {
existingState: testGcpState,
client: &stubProtoClient{},
waiter: stubStatusWaiter{waitForAllErr: someErr},
waiter: &stubStatusWaiter{waitForAllErr: someErr},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -318,7 +298,7 @@ func TestInitialize(t *testing.T) {
serviceAccountCreator: stubServiceAccountCreator{
createErr: someErr,
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -335,6 +315,7 @@ func TestInitialize(t *testing.T) {
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
cmd.Flags().String("dev-config", "", "") // register persisten flag manually
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
@ -350,7 +331,7 @@ func TestInitialize(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, config, tc.waiter, tc.vpnHandler)
err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler)
if tc.errExpected {
assert.Error(err)
@ -551,58 +532,30 @@ func TestReadOrGeneratedMasterSecret(t *testing.T) {
func TestAutoscaleFlag(t *testing.T) {
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
config := config.Default()
testEc2State := state.ConstellationState{
EC2Instances: ec2.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-2": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
"id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
},
EC2SecurityGroup: "sg-test",
}
testGcpState := state.ConstellationState{
GCPNodes: gcp.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
GCPCoordinators: gcp.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
}
testAzureState := state.ConstellationState{
AzureNodes: azure.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureCoordinators: azure.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureResourceGroup: "test",
}
@ -633,7 +586,7 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some gcp instances without autoscale flag": {
@ -642,7 +595,7 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some azure instances without autoscale flag": {
@ -651,7 +604,7 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some ec2 instances with autoscale flag": {
@ -660,7 +613,7 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some gcp instances with autoscale flag": {
@ -669,7 +622,7 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some azure instances with autoscale flag": {
@ -678,7 +631,7 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
}
@ -693,6 +646,7 @@ func TestAutoscaleFlag(t *testing.T) {
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
cmd.Flags().String("dev-config", "", "") // register persisten flag manually
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
vpnHandler := stubVPNHandler{}
@ -705,7 +659,7 @@ func TestAutoscaleFlag(t *testing.T) {
require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
ctx := context.Background()
require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, config, tc.waiter, &vpnHandler))
require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler))
if tc.autoscaleFlag {
assert.Len(tc.client.activateAutoscalingNodeGroups, 1)
} else {

View File

@ -7,7 +7,7 @@ import (
)
type protoClient interface {
Connect(ip string, port string) error
Connect(ip, port string, gcpPCRs, azurePCRs map[uint32][]byte) error
Close() error
Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
}

View File

@ -23,7 +23,7 @@ type stubProtoClient struct {
cloudServiceAccountURI string
}
func (c *stubProtoClient) Connect(ip string, port string) error {
func (c *stubProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error {
c.conn = true
return c.connectErr
}
@ -89,7 +89,7 @@ type fakeProtoClient struct {
respClient proto.ActivationResponseClient
}
func (c *fakeProtoClient) Connect(ip string, port string) error {
func (c *fakeProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error {
c.conn = true
return nil
}

View File

@ -7,5 +7,6 @@ import (
)
type statusWaiter interface {
InitializePCRs(map[uint32][]byte, map[uint32][]byte)
WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
}

View File

@ -2,14 +2,23 @@ package cmd
import (
"context"
"errors"
"github.com/edgelesssys/constellation/coordinator/state"
)
type stubStatusWaiter struct {
initialized bool
waitForAllErr error
}
func (w stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
return w.waitForAllErr
func (s *stubStatusWaiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) {
s.initialized = true
}
func (s *stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
if !s.initialized {
return errors.New("waiter not initialized")
}
return s.waitForAllErr
}

View File

@ -7,7 +7,6 @@ import (
"net"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/aws"
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/coordinator/kms"
@ -21,31 +20,25 @@ import (
// The client offers a method to activate the connected
// AVPNServer as Coordinator.
type Client struct {
conn *grpc.ClientConn
avpn pubproto.APIClient
validators []atls.Validator
conn *grpc.ClientConn
avpn pubproto.APIClient
}
// NewClient creates a Client without a connection.
func NewClient(gcpPCRs map[uint32][]byte) *Client {
return &Client{
validators: []atls.Validator{
aws.NewValidator(aws.NaAdGetVerifiedPayloadAsJson),
gcp.NewValidator(gcpPCRs),
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
azure.NewValidator(map[uint32][]byte{}),
},
}
}
// Connect connects the client to a given server.
// Connect connects the client to a given server, using the handed
// Validators for the attestation of the connection.
// The connection must be closed using Close(). If connect is
// called on a client that already has a connection, the old
// connection is closed.
func (c *Client) Connect(ip string, port string) error {
func (c *Client) Connect(ip, port string, gcpPCRs, AzurePCRs map[uint32][]byte) error {
addr := net.JoinHostPort(ip, port)
tlsConfig, err := atls.CreateAttestationClientTLSConfig(c.validators)
validators := []atls.Validator{
gcp.NewValidator(gcpPCRs),
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
azure.NewValidator(map[uint32][]byte{}),
}
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
if err != nil {
return err
}

View File

@ -21,7 +21,7 @@ func TestClose(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := NewClient(map[uint32][]byte{})
client := Client{}
// Create a connection.
listener := bufconn.Listen(4)

View File

@ -2,11 +2,11 @@ package status
import (
"context"
"errors"
"io"
"time"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/aws"
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
@ -17,26 +17,37 @@ import (
grpcstatus "google.golang.org/grpc/status"
)
// Waiter waits for PeerStatusServer to reach a specific state.
// Waiter waits for PeerStatusServer to reach a specific state. The waiter needs
// to be initialized before usage.
type Waiter struct {
interval time.Duration
newConn func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error)
newClient func(cc grpc.ClientConnInterface) pubproto.APIClient
initialized bool
interval time.Duration
newConn func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error)
newClient func(cc grpc.ClientConnInterface) pubproto.APIClient
}
// NewWaiter returns a default Waiter with probing inteval of 10 seconds,
// attested gRPC connection and PeerStatusClient.
func NewWaiter(gcpPCRs map[uint32][]byte) Waiter {
return Waiter{
func NewWaiter() *Waiter {
return &Waiter{
interval: 10 * time.Second,
newConn: newAttestedConnGenerator(gcpPCRs),
newClient: pubproto.NewAPIClient,
}
}
// InitializePCRs initializes the PCRs for the attestation validators.
func (w *Waiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) {
w.newConn = newAttestedConnGenerator(gcpPCRs, azurePCRs)
w.initialized = true
}
// WaitFor waits for a PeerStatusServer, which is reachable under the given endpoint
// to reach the specified state.
func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error {
func (w *Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error {
if !w.initialized {
return errors.New("waiter not initialized")
}
ticker := time.NewTicker(w.interval)
defer ticker.Stop()
@ -71,7 +82,7 @@ func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.St
}
// probe sends a PeerStatusCheck request to a PeerStatusServer and returns the response.
func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateResponse, error) {
func (w *Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateResponse, error) {
conn, err := w.newConn(ctx, endpoint)
if err != nil {
return nil, err
@ -84,7 +95,11 @@ func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateR
// WaitForAll waits for a list of PeerStatusServers, which listen on the handed
// endpoints, to reach the specified state.
func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
func (w *Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
if !w.initialized {
return errors.New("waiter not initialized")
}
for _, endpoint := range endpoints {
if err := w.WaitFor(ctx, endpoint, status...); err != nil {
return err
@ -94,13 +109,12 @@ func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...st
}
// newAttestedConnGenerator creates a function returning a default attested grpc connection.
func newAttestedConnGenerator(gcpPCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
func newAttestedConnGenerator(gcpPCRs map[uint32][]byte, azurePCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
validators := []atls.Validator{
aws.NewValidator(aws.NaAdGetVerifiedPayloadAsJson),
gcp.NewValidator(gcpPCRs),
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
azure.NewValidator(map[uint32][]byte{}),
azure.NewValidator(azurePCRs),
}
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)

View File

@ -12,6 +12,21 @@ import (
"google.golang.org/grpc"
)
func TestInitializeValidators(t *testing.T) {
assert := assert.New(t)
waiter := Waiter{
interval: time.Millisecond,
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
}
// Uninitialized waiter fails.
assert.Error(waiter.WaitFor(context.Background(), "someIP", state.IsNode))
waiter.InitializeValidators(nil)
assert.NoError(waiter.WaitFor(context.Background(), "someIP", state.IsNode))
}
func TestWaitForAndWaitForAll(t *testing.T) {
var noErr error
someErr := errors.New("failed")
@ -23,43 +38,48 @@ func TestWaitForAndWaitForAll(t *testing.T) {
}{
"successful wait": {
waiter: Waiter{
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
initialized: true,
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
},
waitForState: []state.State{state.IsNode},
},
"successful wait multi states": {
waiter: Waiter{
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
initialized: true,
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
},
waitForState: []state.State{state.IsNode, state.ActivatingNodes},
},
"expect timeout": {
waiter: Waiter{
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}),
initialized: true,
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}),
},
waitForState: []state.State{state.IsNode},
wantErr: true,
},
"fail to check call": {
waiter: Waiter{
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}),
initialized: true,
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}),
},
waitForState: []state.State{state.IsNode},
wantErr: true,
},
"fail to create conn": {
waiter: Waiter{
interval: time.Millisecond,
newConn: stubNewConnFunc(someErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{}),
initialized: true,
interval: time.Millisecond,
newConn: stubNewConnFunc(someErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{}),
},
waitForState: []state.State{state.IsNode},
wantErr: true,

View File

@ -41,7 +41,8 @@ func main() {
defer cancel()
// wait for coordinator to come online
waiter := status.NewWaiter(map[uint32][]byte{})
waiter := status.NewWaiter()
waiter.InitializePCRs(map[uint32][]byte{}, map[uint32][]byte{})
if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil {
log.Fatal(err)
}