mirror of
https://github.com/edgelesssys/constellation.git
synced 2024-10-01 01:36:09 -04:00
Fix PCR handling
This commit is contained in:
parent
de52bf14da
commit
4496755c64
@ -49,38 +49,33 @@ func newInitCmd() *cobra.Command {
|
||||
func runInitialize(cmd *cobra.Command, args []string) error {
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
vpnHandler := vpn.NewConfigHandler()
|
||||
devConfigName, err := cmd.Flags().GetString("dev-config")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
config, err := config.FromFile(fileHandler, devConfigName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
protoClient := proto.NewClient(*config.Provider.GCP.PCRs)
|
||||
defer protoClient.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
|
||||
waiter := status.NewWaiter()
|
||||
protoClient := &proto.Client{}
|
||||
defer protoClient.Close()
|
||||
|
||||
// We have to parse the context separately, since cmd.Context()
|
||||
// returns nil during the tests otherwise.
|
||||
return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, config, status.NewWaiter(*config.Provider.GCP.PCRs), vpnHandler)
|
||||
return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler)
|
||||
}
|
||||
|
||||
// initialize initializes a Constellation. Coordinator instances are activated as Coordinators and will
|
||||
// themself activate the other peers as nodes.
|
||||
func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator,
|
||||
fileHandler file.Handler, config *config.Config, waiter statusWaiter, vpnHandler vpnHandler,
|
||||
fileHandler file.Handler, waiter statusWaiter, vpnHandler vpnHandler,
|
||||
) error {
|
||||
flagArgs, err := evalFlagArgs(cmd, fileHandler)
|
||||
flags, err := evalFlagArgs(cmd, fileHandler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := config.FromFile(fileHandler, flags.devConfigPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
waiter.InitializePCRs(*config.Provider.GCP.PCRs, *config.Provider.Azure.PCRs)
|
||||
|
||||
var stat state.ConstellationState
|
||||
err = fileHandler.ReadJSON(constants.StateFilename, &stat)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
@ -121,14 +116,14 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||
}
|
||||
|
||||
var autoscalingNodeGroups []string
|
||||
if flagArgs.autoscale {
|
||||
if flags.autoscale {
|
||||
autoscalingNodeGroups = append(autoscalingNodeGroups, nodes.GroupID)
|
||||
}
|
||||
|
||||
input := activationInput{
|
||||
coordinatorPubIP: coordinators.PublicIPs()[0],
|
||||
pubKey: flagArgs.userPubKey,
|
||||
masterSecret: flagArgs.masterSecret,
|
||||
pubKey: flags.userPubKey,
|
||||
masterSecret: flags.masterSecret,
|
||||
nodePrivIPs: nodes.PrivateIPs(),
|
||||
autoscalingNodeGroups: autoscalingNodeGroups,
|
||||
cloudServiceAccountURI: serviceAccount,
|
||||
@ -143,7 +138,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||
return err
|
||||
}
|
||||
|
||||
vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flagArgs.userPrivKey), result.clientVpnIP, wireguardAdminMTU)
|
||||
vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flags.userPrivKey), result.clientVpnIP, wireguardAdminMTU)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -152,7 +147,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||
return fmt.Errorf("write wg-quick file: %w", err)
|
||||
}
|
||||
|
||||
if flagArgs.autoconfigureWG {
|
||||
if flags.autoconfigureWG {
|
||||
if err := vpnHandler.Apply(vpnConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -162,7 +157,13 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||
}
|
||||
|
||||
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config) (activationResult, error) {
|
||||
if err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort); err != nil {
|
||||
err := client.Connect(
|
||||
input.coordinatorPubIP,
|
||||
*config.CoordinatorPort,
|
||||
*config.Provider.GCP.PCRs,
|
||||
*config.Provider.Azure.PCRs,
|
||||
)
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
|
||||
@ -265,33 +266,38 @@ func writeRow(wr io.Writer, col1 string, col2 string) {
|
||||
|
||||
// evalFlagArgs gets the flag values and does preprocessing of these values like
|
||||
// reading the content from file path flags and deriving other values from flag combinations.
|
||||
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (flagArgs, error) {
|
||||
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (initFlags, error) {
|
||||
userPrivKeyPath, err := cmd.Flags().GetString("privatekey")
|
||||
if err != nil {
|
||||
return flagArgs{}, err
|
||||
return initFlags{}, err
|
||||
}
|
||||
userPrivKey, userPubKey, err := readOrGenerateVPNKey(fileHandler, userPrivKeyPath)
|
||||
if err != nil {
|
||||
return flagArgs{}, err
|
||||
return initFlags{}, err
|
||||
}
|
||||
autoconfigureWG, err := cmd.Flags().GetBool("wg-autoconfig")
|
||||
if err != nil {
|
||||
return flagArgs{}, err
|
||||
return initFlags{}, err
|
||||
}
|
||||
masterSecretPath, err := cmd.Flags().GetString("master-secret")
|
||||
if err != nil {
|
||||
return flagArgs{}, err
|
||||
return initFlags{}, err
|
||||
}
|
||||
masterSecret, err := readOrGeneratedMasterSecret(cmd.OutOrStdout(), fileHandler, masterSecretPath)
|
||||
if err != nil {
|
||||
return flagArgs{}, err
|
||||
return initFlags{}, err
|
||||
}
|
||||
autoscale, err := cmd.Flags().GetBool("autoscale")
|
||||
if err != nil {
|
||||
return flagArgs{}, err
|
||||
return initFlags{}, err
|
||||
}
|
||||
devConfigPath, err := cmd.Flags().GetString("dev-config")
|
||||
if err != nil {
|
||||
return initFlags{}, err
|
||||
}
|
||||
|
||||
return flagArgs{
|
||||
return initFlags{
|
||||
devConfigPath: devConfigPath,
|
||||
userPrivKey: userPrivKey,
|
||||
userPubKey: userPubKey,
|
||||
autoconfigureWG: autoconfigureWG,
|
||||
@ -300,8 +306,9 @@ func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (flagArgs, error
|
||||
}, nil
|
||||
}
|
||||
|
||||
// flagArgs are the resulting values of flag preprocessing.
|
||||
type flagArgs struct {
|
||||
// initFlags are the resulting values of flag preprocessing.
|
||||
type initFlags struct {
|
||||
devConfigPath string
|
||||
userPrivKey []byte
|
||||
userPubKey []byte
|
||||
masterSecret []byte
|
||||
|
@ -14,7 +14,6 @@ import (
|
||||
"github.com/edgelesssys/constellation/cli/ec2"
|
||||
"github.com/edgelesssys/constellation/cli/file"
|
||||
"github.com/edgelesssys/constellation/cli/gcp"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
wgquick "github.com/nmiculinic/wg-quick-go"
|
||||
@ -35,7 +34,6 @@ func TestInitArgumentValidation(t *testing.T) {
|
||||
|
||||
func TestInitialize(t *testing.T) {
|
||||
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
|
||||
config := config.Default()
|
||||
testEc2State := state.ConstellationState{
|
||||
CloudProvider: "AWS",
|
||||
EC2Instances: ec2.Instances{
|
||||
@ -56,39 +54,21 @@ func TestInitialize(t *testing.T) {
|
||||
}
|
||||
testGcpState := state.ConstellationState{
|
||||
GCPNodes: gcp.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
GCPCoordinators: gcp.Instances{
|
||||
"id-c": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
}
|
||||
testAzureState := state.ConstellationState{
|
||||
CloudProvider: "Azure",
|
||||
AzureNodes: azure.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"id-c": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
AzureResourceGroup: "test",
|
||||
}
|
||||
@ -121,7 +101,7 @@ func TestInitialize(t *testing.T) {
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
privKey: testKey,
|
||||
},
|
||||
@ -130,7 +110,7 @@ func TestInitialize(t *testing.T) {
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
privKey: testKey,
|
||||
},
|
||||
@ -139,7 +119,7 @@ func TestInitialize(t *testing.T) {
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
privKey: testKey,
|
||||
},
|
||||
@ -148,7 +128,7 @@ func TestInitialize(t *testing.T) {
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
initVPN: true,
|
||||
privKey: testKey,
|
||||
@ -158,7 +138,7 @@ func TestInitialize(t *testing.T) {
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{applyErr: someErr},
|
||||
initVPN: true,
|
||||
privKey: testKey,
|
||||
@ -169,7 +149,7 @@ func TestInitialize(t *testing.T) {
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{createErr: someErr},
|
||||
initVPN: true,
|
||||
privKey: testKey,
|
||||
@ -180,7 +160,7 @@ func TestInitialize(t *testing.T) {
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{marshalErr: someErr},
|
||||
initVPN: true,
|
||||
privKey: testKey,
|
||||
@ -189,7 +169,7 @@ func TestInitialize(t *testing.T) {
|
||||
"no state exists": {
|
||||
existingState: state.ConstellationState{},
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -200,7 +180,7 @@ func TestInitialize(t *testing.T) {
|
||||
EC2SecurityGroup: "sg-test",
|
||||
},
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -211,7 +191,7 @@ func TestInitialize(t *testing.T) {
|
||||
EC2SecurityGroup: "sg-test",
|
||||
},
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -219,7 +199,7 @@ func TestInitialize(t *testing.T) {
|
||||
"public key to short": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -227,7 +207,7 @@ func TestInitialize(t *testing.T) {
|
||||
"public key to long": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -235,7 +215,7 @@ func TestInitialize(t *testing.T) {
|
||||
"public key not base64": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: "this is not base64 encoded",
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -243,7 +223,7 @@ func TestInitialize(t *testing.T) {
|
||||
"fail Connect": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{connectErr: someErr},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -251,7 +231,7 @@ func TestInitialize(t *testing.T) {
|
||||
"fail Activate": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{activateErr: someErr},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -259,7 +239,7 @@ func TestInitialize(t *testing.T) {
|
||||
"fail respClient WriteLogStream": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -267,7 +247,7 @@ func TestInitialize(t *testing.T) {
|
||||
"fail respClient getKubeconfig": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -275,7 +255,7 @@ func TestInitialize(t *testing.T) {
|
||||
"fail respClient getCoordinatorVpnKey": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -283,7 +263,7 @@ func TestInitialize(t *testing.T) {
|
||||
"fail respClient getClientVpnIp": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -291,7 +271,7 @@ func TestInitialize(t *testing.T) {
|
||||
"fail respClient getOwnerID": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -299,7 +279,7 @@ func TestInitialize(t *testing.T) {
|
||||
"fail respClient getClusterID": {
|
||||
existingState: testEc2State,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -307,7 +287,7 @@ func TestInitialize(t *testing.T) {
|
||||
"fail to wait for required status": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
waiter: stubStatusWaiter{waitForAllErr: someErr},
|
||||
waiter: &stubStatusWaiter{waitForAllErr: someErr},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -318,7 +298,7 @@ func TestInitialize(t *testing.T) {
|
||||
serviceAccountCreator: stubServiceAccountCreator{
|
||||
createErr: someErr,
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
errExpected: true,
|
||||
@ -335,6 +315,7 @@ func TestInitialize(t *testing.T) {
|
||||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
cmd.Flags().String("dev-config", "", "") // register persisten flag manually
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
|
||||
@ -350,7 +331,7 @@ func TestInitialize(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, config, tc.waiter, tc.vpnHandler)
|
||||
err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler)
|
||||
|
||||
if tc.errExpected {
|
||||
assert.Error(err)
|
||||
@ -551,58 +532,30 @@ func TestReadOrGeneratedMasterSecret(t *testing.T) {
|
||||
|
||||
func TestAutoscaleFlag(t *testing.T) {
|
||||
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
|
||||
config := config.Default()
|
||||
testEc2State := state.ConstellationState{
|
||||
EC2Instances: ec2.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
"id-2": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.2",
|
||||
},
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
|
||||
"id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
|
||||
},
|
||||
EC2SecurityGroup: "sg-test",
|
||||
}
|
||||
testGcpState := state.ConstellationState{
|
||||
GCPNodes: gcp.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
GCPCoordinators: gcp.Instances{
|
||||
"id-c": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
}
|
||||
testAzureState := state.ConstellationState{
|
||||
AzureNodes: azure.Instances{
|
||||
"id-0": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-1": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
AzureCoordinators: azure.Instances{
|
||||
"id-c": {
|
||||
PrivateIP: "192.0.2.1",
|
||||
PublicIP: "192.0.2.1",
|
||||
},
|
||||
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
AzureResourceGroup: "test",
|
||||
}
|
||||
@ -633,7 +586,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some gcp instances without autoscale flag": {
|
||||
@ -642,7 +595,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some azure instances without autoscale flag": {
|
||||
@ -651,7 +604,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some ec2 instances with autoscale flag": {
|
||||
@ -660,7 +613,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some gcp instances with autoscale flag": {
|
||||
@ -669,7 +622,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some azure instances with autoscale flag": {
|
||||
@ -678,7 +631,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: stubStatusWaiter{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
}
|
||||
@ -693,6 +646,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
cmd.Flags().String("dev-config", "", "") // register persisten flag manually
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
vpnHandler := stubVPNHandler{}
|
||||
@ -705,7 +659,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||
require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
|
||||
ctx := context.Background()
|
||||
|
||||
require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, config, tc.waiter, &vpnHandler))
|
||||
require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler))
|
||||
if tc.autoscaleFlag {
|
||||
assert.Len(tc.client.activateAutoscalingNodeGroups, 1)
|
||||
} else {
|
||||
|
@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
type protoClient interface {
|
||||
Connect(ip string, port string) error
|
||||
Connect(ip, port string, gcpPCRs, azurePCRs map[uint32][]byte) error
|
||||
Close() error
|
||||
Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ type stubProtoClient struct {
|
||||
cloudServiceAccountURI string
|
||||
}
|
||||
|
||||
func (c *stubProtoClient) Connect(ip string, port string) error {
|
||||
func (c *stubProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error {
|
||||
c.conn = true
|
||||
return c.connectErr
|
||||
}
|
||||
@ -89,7 +89,7 @@ type fakeProtoClient struct {
|
||||
respClient proto.ActivationResponseClient
|
||||
}
|
||||
|
||||
func (c *fakeProtoClient) Connect(ip string, port string) error {
|
||||
func (c *fakeProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error {
|
||||
c.conn = true
|
||||
return nil
|
||||
}
|
||||
|
@ -7,5 +7,6 @@ import (
|
||||
)
|
||||
|
||||
type statusWaiter interface {
|
||||
InitializePCRs(map[uint32][]byte, map[uint32][]byte)
|
||||
WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
|
||||
}
|
||||
|
@ -2,14 +2,23 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
)
|
||||
|
||||
type stubStatusWaiter struct {
|
||||
initialized bool
|
||||
waitForAllErr error
|
||||
}
|
||||
|
||||
func (w stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
||||
return w.waitForAllErr
|
||||
func (s *stubStatusWaiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) {
|
||||
s.initialized = true
|
||||
}
|
||||
|
||||
func (s *stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
||||
if !s.initialized {
|
||||
return errors.New("waiter not initialized")
|
||||
}
|
||||
return s.waitForAllErr
|
||||
}
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"net"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/aws"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
||||
"github.com/edgelesssys/constellation/coordinator/kms"
|
||||
@ -23,29 +22,23 @@ import (
|
||||
type Client struct {
|
||||
conn *grpc.ClientConn
|
||||
avpn pubproto.APIClient
|
||||
validators []atls.Validator
|
||||
}
|
||||
|
||||
// NewClient creates a Client without a connection.
|
||||
func NewClient(gcpPCRs map[uint32][]byte) *Client {
|
||||
return &Client{
|
||||
validators: []atls.Validator{
|
||||
aws.NewValidator(aws.NaAdGetVerifiedPayloadAsJson),
|
||||
gcp.NewValidator(gcpPCRs),
|
||||
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
|
||||
azure.NewValidator(map[uint32][]byte{}),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Connect connects the client to a given server.
|
||||
// Connect connects the client to a given server, using the handed
|
||||
// Validators for the attestation of the connection.
|
||||
// The connection must be closed using Close(). If connect is
|
||||
// called on a client that already has a connection, the old
|
||||
// connection is closed.
|
||||
func (c *Client) Connect(ip string, port string) error {
|
||||
func (c *Client) Connect(ip, port string, gcpPCRs, AzurePCRs map[uint32][]byte) error {
|
||||
addr := net.JoinHostPort(ip, port)
|
||||
|
||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(c.validators)
|
||||
validators := []atls.Validator{
|
||||
gcp.NewValidator(gcpPCRs),
|
||||
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
|
||||
azure.NewValidator(map[uint32][]byte{}),
|
||||
}
|
||||
|
||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ func TestClose(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
client := NewClient(map[uint32][]byte{})
|
||||
client := Client{}
|
||||
|
||||
// Create a connection.
|
||||
listener := bufconn.Listen(4)
|
||||
|
@ -2,11 +2,11 @@ package status
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/aws"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
@ -17,8 +17,10 @@ import (
|
||||
grpcstatus "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// Waiter waits for PeerStatusServer to reach a specific state.
|
||||
// Waiter waits for PeerStatusServer to reach a specific state. The waiter needs
|
||||
// to be initialized before usage.
|
||||
type Waiter struct {
|
||||
initialized bool
|
||||
interval time.Duration
|
||||
newConn func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error)
|
||||
newClient func(cc grpc.ClientConnInterface) pubproto.APIClient
|
||||
@ -26,17 +28,26 @@ type Waiter struct {
|
||||
|
||||
// NewWaiter returns a default Waiter with probing inteval of 10 seconds,
|
||||
// attested gRPC connection and PeerStatusClient.
|
||||
func NewWaiter(gcpPCRs map[uint32][]byte) Waiter {
|
||||
return Waiter{
|
||||
func NewWaiter() *Waiter {
|
||||
return &Waiter{
|
||||
interval: 10 * time.Second,
|
||||
newConn: newAttestedConnGenerator(gcpPCRs),
|
||||
newClient: pubproto.NewAPIClient,
|
||||
}
|
||||
}
|
||||
|
||||
// InitializePCRs initializes the PCRs for the attestation validators.
|
||||
func (w *Waiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) {
|
||||
w.newConn = newAttestedConnGenerator(gcpPCRs, azurePCRs)
|
||||
w.initialized = true
|
||||
}
|
||||
|
||||
// WaitFor waits for a PeerStatusServer, which is reachable under the given endpoint
|
||||
// to reach the specified state.
|
||||
func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error {
|
||||
func (w *Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error {
|
||||
if !w.initialized {
|
||||
return errors.New("waiter not initialized")
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(w.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
@ -71,7 +82,7 @@ func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.St
|
||||
}
|
||||
|
||||
// probe sends a PeerStatusCheck request to a PeerStatusServer and returns the response.
|
||||
func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateResponse, error) {
|
||||
func (w *Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateResponse, error) {
|
||||
conn, err := w.newConn(ctx, endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -84,7 +95,11 @@ func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateR
|
||||
|
||||
// WaitForAll waits for a list of PeerStatusServers, which listen on the handed
|
||||
// endpoints, to reach the specified state.
|
||||
func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
||||
func (w *Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
||||
if !w.initialized {
|
||||
return errors.New("waiter not initialized")
|
||||
}
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
if err := w.WaitFor(ctx, endpoint, status...); err != nil {
|
||||
return err
|
||||
@ -94,13 +109,12 @@ func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...st
|
||||
}
|
||||
|
||||
// newAttestedConnGenerator creates a function returning a default attested grpc connection.
|
||||
func newAttestedConnGenerator(gcpPCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
|
||||
func newAttestedConnGenerator(gcpPCRs map[uint32][]byte, azurePCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
|
||||
return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
|
||||
validators := []atls.Validator{
|
||||
aws.NewValidator(aws.NaAdGetVerifiedPayloadAsJson),
|
||||
gcp.NewValidator(gcpPCRs),
|
||||
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
|
||||
azure.NewValidator(map[uint32][]byte{}),
|
||||
azure.NewValidator(azurePCRs),
|
||||
}
|
||||
|
||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
||||
|
@ -12,6 +12,21 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestInitializeValidators(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
waiter := Waiter{
|
||||
interval: time.Millisecond,
|
||||
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
||||
}
|
||||
|
||||
// Uninitialized waiter fails.
|
||||
assert.Error(waiter.WaitFor(context.Background(), "someIP", state.IsNode))
|
||||
|
||||
waiter.InitializeValidators(nil)
|
||||
assert.NoError(waiter.WaitFor(context.Background(), "someIP", state.IsNode))
|
||||
}
|
||||
|
||||
func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
var noErr error
|
||||
someErr := errors.New("failed")
|
||||
@ -23,6 +38,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
}{
|
||||
"successful wait": {
|
||||
waiter: Waiter{
|
||||
initialized: true,
|
||||
interval: time.Millisecond,
|
||||
newConn: stubNewConnFunc(noErr),
|
||||
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
||||
@ -31,6 +47,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
},
|
||||
"successful wait multi states": {
|
||||
waiter: Waiter{
|
||||
initialized: true,
|
||||
interval: time.Millisecond,
|
||||
newConn: stubNewConnFunc(noErr),
|
||||
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
||||
@ -39,6 +56,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
},
|
||||
"expect timeout": {
|
||||
waiter: Waiter{
|
||||
initialized: true,
|
||||
interval: time.Millisecond,
|
||||
newConn: stubNewConnFunc(noErr),
|
||||
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}),
|
||||
@ -48,6 +66,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
},
|
||||
"fail to check call": {
|
||||
waiter: Waiter{
|
||||
initialized: true,
|
||||
interval: time.Millisecond,
|
||||
newConn: stubNewConnFunc(noErr),
|
||||
newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}),
|
||||
@ -57,6 +76,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||
},
|
||||
"fail to create conn": {
|
||||
waiter: Waiter{
|
||||
initialized: true,
|
||||
interval: time.Millisecond,
|
||||
newConn: stubNewConnFunc(someErr),
|
||||
newClient: stubNewClientFunc(&stubPeerStatusClient{}),
|
||||
|
@ -41,7 +41,8 @@ func main() {
|
||||
defer cancel()
|
||||
|
||||
// wait for coordinator to come online
|
||||
waiter := status.NewWaiter(map[uint32][]byte{})
|
||||
waiter := status.NewWaiter()
|
||||
waiter.InitializePCRs(map[uint32][]byte{}, map[uint32][]byte{})
|
||||
if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user