Fix PCR handling

This commit is contained in:
katexochen 2022-04-13 15:01:02 +02:00 committed by Paul Meyer
parent de52bf14da
commit 4496755c64
11 changed files with 182 additions and 183 deletions

View File

@ -49,38 +49,33 @@ func newInitCmd() *cobra.Command {
func runInitialize(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
vpnHandler := vpn.NewConfigHandler()
devConfigName, err := cmd.Flags().GetString("dev-config")
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, devConfigName)
if err != nil {
return err
}
protoClient := proto.NewClient(*config.Provider.GCP.PCRs)
defer protoClient.Close()
if err != nil {
return err
}
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
waiter := status.NewWaiter()
protoClient := &proto.Client{}
defer protoClient.Close()
// We have to parse the context separately, since cmd.Context()
// returns nil during the tests otherwise.
return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, config, status.NewWaiter(*config.Provider.GCP.PCRs), vpnHandler)
return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler)
}
// initialize initializes a Constellation. Coordinator instances are activated as Coordinators and will
// themself activate the other peers as nodes.
func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator,
fileHandler file.Handler, config *config.Config, waiter statusWaiter, vpnHandler vpnHandler,
fileHandler file.Handler, waiter statusWaiter, vpnHandler vpnHandler,
) error {
flagArgs, err := evalFlagArgs(cmd, fileHandler)
flags, err := evalFlagArgs(cmd, fileHandler)
if err != nil {
return err
}
config, err := config.FromFile(fileHandler, flags.devConfigPath)
if err != nil {
return err
}
waiter.InitializePCRs(*config.Provider.GCP.PCRs, *config.Provider.Azure.PCRs)
var stat state.ConstellationState
err = fileHandler.ReadJSON(constants.StateFilename, &stat)
if errors.Is(err, fs.ErrNotExist) {
@ -121,14 +116,14 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
}
var autoscalingNodeGroups []string
if flagArgs.autoscale {
if flags.autoscale {
autoscalingNodeGroups = append(autoscalingNodeGroups, nodes.GroupID)
}
input := activationInput{
coordinatorPubIP: coordinators.PublicIPs()[0],
pubKey: flagArgs.userPubKey,
masterSecret: flagArgs.masterSecret,
pubKey: flags.userPubKey,
masterSecret: flags.masterSecret,
nodePrivIPs: nodes.PrivateIPs(),
autoscalingNodeGroups: autoscalingNodeGroups,
cloudServiceAccountURI: serviceAccount,
@ -143,7 +138,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
return err
}
vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flagArgs.userPrivKey), result.clientVpnIP, wireguardAdminMTU)
vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flags.userPrivKey), result.clientVpnIP, wireguardAdminMTU)
if err != nil {
return err
}
@ -152,7 +147,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
return fmt.Errorf("write wg-quick file: %w", err)
}
if flagArgs.autoconfigureWG {
if flags.autoconfigureWG {
if err := vpnHandler.Apply(vpnConfig); err != nil {
return err
}
@ -162,7 +157,13 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
}
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config) (activationResult, error) {
if err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort); err != nil {
err := client.Connect(
input.coordinatorPubIP,
*config.CoordinatorPort,
*config.Provider.GCP.PCRs,
*config.Provider.Azure.PCRs,
)
if err != nil {
return activationResult{}, err
}
@ -265,33 +266,38 @@ func writeRow(wr io.Writer, col1 string, col2 string) {
// evalFlagArgs gets the flag values and does preprocessing of these values like
// reading the content from file path flags and deriving other values from flag combinations.
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (flagArgs, error) {
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (initFlags, error) {
userPrivKeyPath, err := cmd.Flags().GetString("privatekey")
if err != nil {
return flagArgs{}, err
return initFlags{}, err
}
userPrivKey, userPubKey, err := readOrGenerateVPNKey(fileHandler, userPrivKeyPath)
if err != nil {
return flagArgs{}, err
return initFlags{}, err
}
autoconfigureWG, err := cmd.Flags().GetBool("wg-autoconfig")
if err != nil {
return flagArgs{}, err
return initFlags{}, err
}
masterSecretPath, err := cmd.Flags().GetString("master-secret")
if err != nil {
return flagArgs{}, err
return initFlags{}, err
}
masterSecret, err := readOrGeneratedMasterSecret(cmd.OutOrStdout(), fileHandler, masterSecretPath)
if err != nil {
return flagArgs{}, err
return initFlags{}, err
}
autoscale, err := cmd.Flags().GetBool("autoscale")
if err != nil {
return flagArgs{}, err
return initFlags{}, err
}
devConfigPath, err := cmd.Flags().GetString("dev-config")
if err != nil {
return initFlags{}, err
}
return flagArgs{
return initFlags{
devConfigPath: devConfigPath,
userPrivKey: userPrivKey,
userPubKey: userPubKey,
autoconfigureWG: autoconfigureWG,
@ -300,8 +306,9 @@ func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (flagArgs, error
}, nil
}
// flagArgs are the resulting values of flag preprocessing.
type flagArgs struct {
// initFlags are the resulting values of flag preprocessing.
type initFlags struct {
devConfigPath string
userPrivKey []byte
userPubKey []byte
masterSecret []byte

View File

@ -14,7 +14,6 @@ import (
"github.com/edgelesssys/constellation/cli/ec2"
"github.com/edgelesssys/constellation/cli/file"
"github.com/edgelesssys/constellation/cli/gcp"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/state"
wgquick "github.com/nmiculinic/wg-quick-go"
@ -35,7 +34,6 @@ func TestInitArgumentValidation(t *testing.T) {
func TestInitialize(t *testing.T) {
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
config := config.Default()
testEc2State := state.ConstellationState{
CloudProvider: "AWS",
EC2Instances: ec2.Instances{
@ -56,39 +54,21 @@ func TestInitialize(t *testing.T) {
}
testGcpState := state.ConstellationState{
GCPNodes: gcp.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
GCPCoordinators: gcp.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
}
testAzureState := state.ConstellationState{
CloudProvider: "Azure",
AzureNodes: azure.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureCoordinators: azure.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureResourceGroup: "test",
}
@ -121,7 +101,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
privKey: testKey,
},
@ -130,7 +110,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
privKey: testKey,
},
@ -139,7 +119,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
privKey: testKey,
},
@ -148,7 +128,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
initVPN: true,
privKey: testKey,
@ -158,7 +138,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{applyErr: someErr},
initVPN: true,
privKey: testKey,
@ -169,7 +149,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{createErr: someErr},
initVPN: true,
privKey: testKey,
@ -180,7 +160,7 @@ func TestInitialize(t *testing.T) {
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{marshalErr: someErr},
initVPN: true,
privKey: testKey,
@ -189,7 +169,7 @@ func TestInitialize(t *testing.T) {
"no state exists": {
existingState: state.ConstellationState{},
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -200,7 +180,7 @@ func TestInitialize(t *testing.T) {
EC2SecurityGroup: "sg-test",
},
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -211,7 +191,7 @@ func TestInitialize(t *testing.T) {
EC2SecurityGroup: "sg-test",
},
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -219,7 +199,7 @@ func TestInitialize(t *testing.T) {
"public key to short": {
existingState: testEc2State,
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -227,7 +207,7 @@ func TestInitialize(t *testing.T) {
"public key to long": {
existingState: testEc2State,
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -235,7 +215,7 @@ func TestInitialize(t *testing.T) {
"public key not base64": {
existingState: testEc2State,
client: &stubProtoClient{},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: "this is not base64 encoded",
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -243,7 +223,7 @@ func TestInitialize(t *testing.T) {
"fail Connect": {
existingState: testEc2State,
client: &stubProtoClient{connectErr: someErr},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -251,7 +231,7 @@ func TestInitialize(t *testing.T) {
"fail Activate": {
existingState: testEc2State,
client: &stubProtoClient{activateErr: someErr},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -259,7 +239,7 @@ func TestInitialize(t *testing.T) {
"fail respClient WriteLogStream": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -267,7 +247,7 @@ func TestInitialize(t *testing.T) {
"fail respClient getKubeconfig": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -275,7 +255,7 @@ func TestInitialize(t *testing.T) {
"fail respClient getCoordinatorVpnKey": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -283,7 +263,7 @@ func TestInitialize(t *testing.T) {
"fail respClient getClientVpnIp": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -291,7 +271,7 @@ func TestInitialize(t *testing.T) {
"fail respClient getOwnerID": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -299,7 +279,7 @@ func TestInitialize(t *testing.T) {
"fail respClient getClusterID": {
existingState: testEc2State,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -307,7 +287,7 @@ func TestInitialize(t *testing.T) {
"fail to wait for required status": {
existingState: testGcpState,
client: &stubProtoClient{},
waiter: stubStatusWaiter{waitForAllErr: someErr},
waiter: &stubStatusWaiter{waitForAllErr: someErr},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -318,7 +298,7 @@ func TestInitialize(t *testing.T) {
serviceAccountCreator: stubServiceAccountCreator{
createErr: someErr,
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
errExpected: true,
@ -335,6 +315,7 @@ func TestInitialize(t *testing.T) {
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
cmd.Flags().String("dev-config", "", "") // register persisten flag manually
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
@ -350,7 +331,7 @@ func TestInitialize(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, config, tc.waiter, tc.vpnHandler)
err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler)
if tc.errExpected {
assert.Error(err)
@ -551,58 +532,30 @@ func TestReadOrGeneratedMasterSecret(t *testing.T) {
func TestAutoscaleFlag(t *testing.T) {
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
config := config.Default()
testEc2State := state.ConstellationState{
EC2Instances: ec2.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-2": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.2",
},
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
"id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
},
EC2SecurityGroup: "sg-test",
}
testGcpState := state.ConstellationState{
GCPNodes: gcp.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
GCPCoordinators: gcp.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
}
testAzureState := state.ConstellationState{
AzureNodes: azure.Instances{
"id-0": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-1": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureCoordinators: azure.Instances{
"id-c": {
PrivateIP: "192.0.2.1",
PublicIP: "192.0.2.1",
},
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureResourceGroup: "test",
}
@ -633,7 +586,7 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some gcp instances without autoscale flag": {
@ -642,7 +595,7 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some azure instances without autoscale flag": {
@ -651,7 +604,7 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some ec2 instances with autoscale flag": {
@ -660,7 +613,7 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some gcp instances with autoscale flag": {
@ -669,7 +622,7 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some azure instances with autoscale flag": {
@ -678,7 +631,7 @@ func TestAutoscaleFlag(t *testing.T) {
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: stubStatusWaiter{},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
}
@ -693,6 +646,7 @@ func TestAutoscaleFlag(t *testing.T) {
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
cmd.Flags().String("dev-config", "", "") // register persisten flag manually
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
vpnHandler := stubVPNHandler{}
@ -705,7 +659,7 @@ func TestAutoscaleFlag(t *testing.T) {
require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
ctx := context.Background()
require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, config, tc.waiter, &vpnHandler))
require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler))
if tc.autoscaleFlag {
assert.Len(tc.client.activateAutoscalingNodeGroups, 1)
} else {

View File

@ -7,7 +7,7 @@ import (
)
type protoClient interface {
Connect(ip string, port string) error
Connect(ip, port string, gcpPCRs, azurePCRs map[uint32][]byte) error
Close() error
Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
}

View File

@ -23,7 +23,7 @@ type stubProtoClient struct {
cloudServiceAccountURI string
}
func (c *stubProtoClient) Connect(ip string, port string) error {
func (c *stubProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error {
c.conn = true
return c.connectErr
}
@ -89,7 +89,7 @@ type fakeProtoClient struct {
respClient proto.ActivationResponseClient
}
func (c *fakeProtoClient) Connect(ip string, port string) error {
func (c *fakeProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error {
c.conn = true
return nil
}

View File

@ -7,5 +7,6 @@ import (
)
type statusWaiter interface {
InitializePCRs(map[uint32][]byte, map[uint32][]byte)
WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
}

View File

@ -2,14 +2,23 @@ package cmd
import (
"context"
"errors"
"github.com/edgelesssys/constellation/coordinator/state"
)
type stubStatusWaiter struct {
initialized bool
waitForAllErr error
}
func (w stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
return w.waitForAllErr
func (s *stubStatusWaiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) {
s.initialized = true
}
func (s *stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
if !s.initialized {
return errors.New("waiter not initialized")
}
return s.waitForAllErr
}

View File

@ -7,7 +7,6 @@ import (
"net"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/aws"
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/coordinator/kms"
@ -23,29 +22,23 @@ import (
type Client struct {
conn *grpc.ClientConn
avpn pubproto.APIClient
validators []atls.Validator
}
// NewClient creates a Client without a connection.
func NewClient(gcpPCRs map[uint32][]byte) *Client {
return &Client{
validators: []atls.Validator{
aws.NewValidator(aws.NaAdGetVerifiedPayloadAsJson),
gcp.NewValidator(gcpPCRs),
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
azure.NewValidator(map[uint32][]byte{}),
},
}
}
// Connect connects the client to a given server.
// Connect connects the client to a given server, using the handed
// Validators for the attestation of the connection.
// The connection must be closed using Close(). If connect is
// called on a client that already has a connection, the old
// connection is closed.
func (c *Client) Connect(ip string, port string) error {
func (c *Client) Connect(ip, port string, gcpPCRs, AzurePCRs map[uint32][]byte) error {
addr := net.JoinHostPort(ip, port)
tlsConfig, err := atls.CreateAttestationClientTLSConfig(c.validators)
validators := []atls.Validator{
gcp.NewValidator(gcpPCRs),
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
azure.NewValidator(map[uint32][]byte{}),
}
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
if err != nil {
return err
}

View File

@ -21,7 +21,7 @@ func TestClose(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
client := NewClient(map[uint32][]byte{})
client := Client{}
// Create a connection.
listener := bufconn.Listen(4)

View File

@ -2,11 +2,11 @@ package status
import (
"context"
"errors"
"io"
"time"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/attestation/aws"
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
@ -17,8 +17,10 @@ import (
grpcstatus "google.golang.org/grpc/status"
)
// Waiter waits for PeerStatusServer to reach a specific state.
// Waiter waits for PeerStatusServer to reach a specific state. The waiter needs
// to be initialized before usage.
type Waiter struct {
initialized bool
interval time.Duration
newConn func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error)
newClient func(cc grpc.ClientConnInterface) pubproto.APIClient
@ -26,17 +28,26 @@ type Waiter struct {
// NewWaiter returns a default Waiter with probing inteval of 10 seconds,
// attested gRPC connection and PeerStatusClient.
func NewWaiter(gcpPCRs map[uint32][]byte) Waiter {
return Waiter{
func NewWaiter() *Waiter {
return &Waiter{
interval: 10 * time.Second,
newConn: newAttestedConnGenerator(gcpPCRs),
newClient: pubproto.NewAPIClient,
}
}
// InitializePCRs initializes the PCRs for the attestation validators.
func (w *Waiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) {
w.newConn = newAttestedConnGenerator(gcpPCRs, azurePCRs)
w.initialized = true
}
// WaitFor waits for a PeerStatusServer, which is reachable under the given endpoint
// to reach the specified state.
func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error {
func (w *Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error {
if !w.initialized {
return errors.New("waiter not initialized")
}
ticker := time.NewTicker(w.interval)
defer ticker.Stop()
@ -71,7 +82,7 @@ func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.St
}
// probe sends a PeerStatusCheck request to a PeerStatusServer and returns the response.
func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateResponse, error) {
func (w *Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateResponse, error) {
conn, err := w.newConn(ctx, endpoint)
if err != nil {
return nil, err
@ -84,7 +95,11 @@ func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateR
// WaitForAll waits for a list of PeerStatusServers, which listen on the handed
// endpoints, to reach the specified state.
func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
func (w *Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
if !w.initialized {
return errors.New("waiter not initialized")
}
for _, endpoint := range endpoints {
if err := w.WaitFor(ctx, endpoint, status...); err != nil {
return err
@ -94,13 +109,12 @@ func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...st
}
// newAttestedConnGenerator creates a function returning a default attested grpc connection.
func newAttestedConnGenerator(gcpPCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
func newAttestedConnGenerator(gcpPCRs map[uint32][]byte, azurePCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
validators := []atls.Validator{
aws.NewValidator(aws.NaAdGetVerifiedPayloadAsJson),
gcp.NewValidator(gcpPCRs),
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
azure.NewValidator(map[uint32][]byte{}),
azure.NewValidator(azurePCRs),
}
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)

View File

@ -12,6 +12,21 @@ import (
"google.golang.org/grpc"
)
func TestInitializeValidators(t *testing.T) {
assert := assert.New(t)
waiter := Waiter{
interval: time.Millisecond,
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
}
// Uninitialized waiter fails.
assert.Error(waiter.WaitFor(context.Background(), "someIP", state.IsNode))
waiter.InitializeValidators(nil)
assert.NoError(waiter.WaitFor(context.Background(), "someIP", state.IsNode))
}
func TestWaitForAndWaitForAll(t *testing.T) {
var noErr error
someErr := errors.New("failed")
@ -23,6 +38,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
}{
"successful wait": {
waiter: Waiter{
initialized: true,
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
@ -31,6 +47,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
},
"successful wait multi states": {
waiter: Waiter{
initialized: true,
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
@ -39,6 +56,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
},
"expect timeout": {
waiter: Waiter{
initialized: true,
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}),
@ -48,6 +66,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
},
"fail to check call": {
waiter: Waiter{
initialized: true,
interval: time.Millisecond,
newConn: stubNewConnFunc(noErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}),
@ -57,6 +76,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
},
"fail to create conn": {
waiter: Waiter{
initialized: true,
interval: time.Millisecond,
newConn: stubNewConnFunc(someErr),
newClient: stubNewClientFunc(&stubPeerStatusClient{}),

View File

@ -41,7 +41,8 @@ func main() {
defer cancel()
// wait for coordinator to come online
waiter := status.NewWaiter(map[uint32][]byte{})
waiter := status.NewWaiter()
waiter.InitializePCRs(map[uint32][]byte{}, map[uint32][]byte{})
if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil {
log.Fatal(err)
}