mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-06-15 17:59:50 -04:00
Fix PCR handling
This commit is contained in:
parent
de52bf14da
commit
4496755c64
11 changed files with 182 additions and 183 deletions
|
@ -49,38 +49,33 @@ func newInitCmd() *cobra.Command {
|
||||||
func runInitialize(cmd *cobra.Command, args []string) error {
|
func runInitialize(cmd *cobra.Command, args []string) error {
|
||||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||||
vpnHandler := vpn.NewConfigHandler()
|
vpnHandler := vpn.NewConfigHandler()
|
||||||
devConfigName, err := cmd.Flags().GetString("dev-config")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
config, err := config.FromFile(fileHandler, devConfigName)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
protoClient := proto.NewClient(*config.Provider.GCP.PCRs)
|
|
||||||
defer protoClient.Close()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
|
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
|
||||||
|
waiter := status.NewWaiter()
|
||||||
|
protoClient := &proto.Client{}
|
||||||
|
defer protoClient.Close()
|
||||||
|
|
||||||
// We have to parse the context separately, since cmd.Context()
|
// We have to parse the context separately, since cmd.Context()
|
||||||
// returns nil during the tests otherwise.
|
// returns nil during the tests otherwise.
|
||||||
return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, config, status.NewWaiter(*config.Provider.GCP.PCRs), vpnHandler)
|
return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize initializes a Constellation. Coordinator instances are activated as Coordinators and will
|
// initialize initializes a Constellation. Coordinator instances are activated as Coordinators and will
|
||||||
// themself activate the other peers as nodes.
|
// themself activate the other peers as nodes.
|
||||||
func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator,
|
func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator,
|
||||||
fileHandler file.Handler, config *config.Config, waiter statusWaiter, vpnHandler vpnHandler,
|
fileHandler file.Handler, waiter statusWaiter, vpnHandler vpnHandler,
|
||||||
) error {
|
) error {
|
||||||
flagArgs, err := evalFlagArgs(cmd, fileHandler)
|
flags, err := evalFlagArgs(cmd, fileHandler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
config, err := config.FromFile(fileHandler, flags.devConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
waiter.InitializePCRs(*config.Provider.GCP.PCRs, *config.Provider.Azure.PCRs)
|
||||||
|
|
||||||
var stat state.ConstellationState
|
var stat state.ConstellationState
|
||||||
err = fileHandler.ReadJSON(constants.StateFilename, &stat)
|
err = fileHandler.ReadJSON(constants.StateFilename, &stat)
|
||||||
if errors.Is(err, fs.ErrNotExist) {
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
@ -121,14 +116,14 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||||
}
|
}
|
||||||
|
|
||||||
var autoscalingNodeGroups []string
|
var autoscalingNodeGroups []string
|
||||||
if flagArgs.autoscale {
|
if flags.autoscale {
|
||||||
autoscalingNodeGroups = append(autoscalingNodeGroups, nodes.GroupID)
|
autoscalingNodeGroups = append(autoscalingNodeGroups, nodes.GroupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
input := activationInput{
|
input := activationInput{
|
||||||
coordinatorPubIP: coordinators.PublicIPs()[0],
|
coordinatorPubIP: coordinators.PublicIPs()[0],
|
||||||
pubKey: flagArgs.userPubKey,
|
pubKey: flags.userPubKey,
|
||||||
masterSecret: flagArgs.masterSecret,
|
masterSecret: flags.masterSecret,
|
||||||
nodePrivIPs: nodes.PrivateIPs(),
|
nodePrivIPs: nodes.PrivateIPs(),
|
||||||
autoscalingNodeGroups: autoscalingNodeGroups,
|
autoscalingNodeGroups: autoscalingNodeGroups,
|
||||||
cloudServiceAccountURI: serviceAccount,
|
cloudServiceAccountURI: serviceAccount,
|
||||||
|
@ -143,7 +138,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flagArgs.userPrivKey), result.clientVpnIP, wireguardAdminMTU)
|
vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flags.userPrivKey), result.clientVpnIP, wireguardAdminMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -152,7 +147,7 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||||
return fmt.Errorf("write wg-quick file: %w", err)
|
return fmt.Errorf("write wg-quick file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if flagArgs.autoconfigureWG {
|
if flags.autoconfigureWG {
|
||||||
if err := vpnHandler.Apply(vpnConfig); err != nil {
|
if err := vpnHandler.Apply(vpnConfig); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -162,7 +157,13 @@ func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, ser
|
||||||
}
|
}
|
||||||
|
|
||||||
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config) (activationResult, error) {
|
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput, config *config.Config) (activationResult, error) {
|
||||||
if err := client.Connect(input.coordinatorPubIP, *config.CoordinatorPort); err != nil {
|
err := client.Connect(
|
||||||
|
input.coordinatorPubIP,
|
||||||
|
*config.CoordinatorPort,
|
||||||
|
*config.Provider.GCP.PCRs,
|
||||||
|
*config.Provider.Azure.PCRs,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
return activationResult{}, err
|
return activationResult{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -265,33 +266,38 @@ func writeRow(wr io.Writer, col1 string, col2 string) {
|
||||||
|
|
||||||
// evalFlagArgs gets the flag values and does preprocessing of these values like
|
// evalFlagArgs gets the flag values and does preprocessing of these values like
|
||||||
// reading the content from file path flags and deriving other values from flag combinations.
|
// reading the content from file path flags and deriving other values from flag combinations.
|
||||||
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (flagArgs, error) {
|
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (initFlags, error) {
|
||||||
userPrivKeyPath, err := cmd.Flags().GetString("privatekey")
|
userPrivKeyPath, err := cmd.Flags().GetString("privatekey")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return flagArgs{}, err
|
return initFlags{}, err
|
||||||
}
|
}
|
||||||
userPrivKey, userPubKey, err := readOrGenerateVPNKey(fileHandler, userPrivKeyPath)
|
userPrivKey, userPubKey, err := readOrGenerateVPNKey(fileHandler, userPrivKeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return flagArgs{}, err
|
return initFlags{}, err
|
||||||
}
|
}
|
||||||
autoconfigureWG, err := cmd.Flags().GetBool("wg-autoconfig")
|
autoconfigureWG, err := cmd.Flags().GetBool("wg-autoconfig")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return flagArgs{}, err
|
return initFlags{}, err
|
||||||
}
|
}
|
||||||
masterSecretPath, err := cmd.Flags().GetString("master-secret")
|
masterSecretPath, err := cmd.Flags().GetString("master-secret")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return flagArgs{}, err
|
return initFlags{}, err
|
||||||
}
|
}
|
||||||
masterSecret, err := readOrGeneratedMasterSecret(cmd.OutOrStdout(), fileHandler, masterSecretPath)
|
masterSecret, err := readOrGeneratedMasterSecret(cmd.OutOrStdout(), fileHandler, masterSecretPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return flagArgs{}, err
|
return initFlags{}, err
|
||||||
}
|
}
|
||||||
autoscale, err := cmd.Flags().GetBool("autoscale")
|
autoscale, err := cmd.Flags().GetBool("autoscale")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return flagArgs{}, err
|
return initFlags{}, err
|
||||||
|
}
|
||||||
|
devConfigPath, err := cmd.Flags().GetString("dev-config")
|
||||||
|
if err != nil {
|
||||||
|
return initFlags{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return flagArgs{
|
return initFlags{
|
||||||
|
devConfigPath: devConfigPath,
|
||||||
userPrivKey: userPrivKey,
|
userPrivKey: userPrivKey,
|
||||||
userPubKey: userPubKey,
|
userPubKey: userPubKey,
|
||||||
autoconfigureWG: autoconfigureWG,
|
autoconfigureWG: autoconfigureWG,
|
||||||
|
@ -300,8 +306,9 @@ func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (flagArgs, error
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// flagArgs are the resulting values of flag preprocessing.
|
// initFlags are the resulting values of flag preprocessing.
|
||||||
type flagArgs struct {
|
type initFlags struct {
|
||||||
|
devConfigPath string
|
||||||
userPrivKey []byte
|
userPrivKey []byte
|
||||||
userPubKey []byte
|
userPubKey []byte
|
||||||
masterSecret []byte
|
masterSecret []byte
|
||||||
|
|
|
@ -14,7 +14,6 @@ import (
|
||||||
"github.com/edgelesssys/constellation/cli/ec2"
|
"github.com/edgelesssys/constellation/cli/ec2"
|
||||||
"github.com/edgelesssys/constellation/cli/file"
|
"github.com/edgelesssys/constellation/cli/file"
|
||||||
"github.com/edgelesssys/constellation/cli/gcp"
|
"github.com/edgelesssys/constellation/cli/gcp"
|
||||||
"github.com/edgelesssys/constellation/internal/config"
|
|
||||||
"github.com/edgelesssys/constellation/internal/constants"
|
"github.com/edgelesssys/constellation/internal/constants"
|
||||||
"github.com/edgelesssys/constellation/internal/state"
|
"github.com/edgelesssys/constellation/internal/state"
|
||||||
wgquick "github.com/nmiculinic/wg-quick-go"
|
wgquick "github.com/nmiculinic/wg-quick-go"
|
||||||
|
@ -35,7 +34,6 @@ func TestInitArgumentValidation(t *testing.T) {
|
||||||
|
|
||||||
func TestInitialize(t *testing.T) {
|
func TestInitialize(t *testing.T) {
|
||||||
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
|
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
|
||||||
config := config.Default()
|
|
||||||
testEc2State := state.ConstellationState{
|
testEc2State := state.ConstellationState{
|
||||||
CloudProvider: "AWS",
|
CloudProvider: "AWS",
|
||||||
EC2Instances: ec2.Instances{
|
EC2Instances: ec2.Instances{
|
||||||
|
@ -56,39 +54,21 @@ func TestInitialize(t *testing.T) {
|
||||||
}
|
}
|
||||||
testGcpState := state.ConstellationState{
|
testGcpState := state.ConstellationState{
|
||||||
GCPNodes: gcp.Instances{
|
GCPNodes: gcp.Instances{
|
||||||
"id-0": {
|
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
PrivateIP: "192.0.2.1",
|
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
PublicIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
"id-1": {
|
|
||||||
PrivateIP: "192.0.2.1",
|
|
||||||
PublicIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
GCPCoordinators: gcp.Instances{
|
GCPCoordinators: gcp.Instances{
|
||||||
"id-c": {
|
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
PrivateIP: "192.0.2.1",
|
|
||||||
PublicIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
testAzureState := state.ConstellationState{
|
testAzureState := state.ConstellationState{
|
||||||
CloudProvider: "Azure",
|
CloudProvider: "Azure",
|
||||||
AzureNodes: azure.Instances{
|
AzureNodes: azure.Instances{
|
||||||
"id-0": {
|
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
PrivateIP: "192.0.2.1",
|
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
PublicIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
"id-1": {
|
|
||||||
PrivateIP: "192.0.2.1",
|
|
||||||
PublicIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
AzureCoordinators: azure.Instances{
|
AzureCoordinators: azure.Instances{
|
||||||
"id-c": {
|
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
PrivateIP: "192.0.2.1",
|
|
||||||
PublicIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
AzureResourceGroup: "test",
|
AzureResourceGroup: "test",
|
||||||
}
|
}
|
||||||
|
@ -121,7 +101,7 @@ func TestInitialize(t *testing.T) {
|
||||||
client: &fakeProtoClient{
|
client: &fakeProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
},
|
},
|
||||||
|
@ -130,7 +110,7 @@ func TestInitialize(t *testing.T) {
|
||||||
client: &fakeProtoClient{
|
client: &fakeProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
},
|
},
|
||||||
|
@ -139,7 +119,7 @@ func TestInitialize(t *testing.T) {
|
||||||
client: &fakeProtoClient{
|
client: &fakeProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
},
|
},
|
||||||
|
@ -148,7 +128,7 @@ func TestInitialize(t *testing.T) {
|
||||||
client: &fakeProtoClient{
|
client: &fakeProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
initVPN: true,
|
initVPN: true,
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
|
@ -158,7 +138,7 @@ func TestInitialize(t *testing.T) {
|
||||||
client: &fakeProtoClient{
|
client: &fakeProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
vpnHandler: &stubVPNHandler{applyErr: someErr},
|
vpnHandler: &stubVPNHandler{applyErr: someErr},
|
||||||
initVPN: true,
|
initVPN: true,
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
|
@ -169,7 +149,7 @@ func TestInitialize(t *testing.T) {
|
||||||
client: &fakeProtoClient{
|
client: &fakeProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
vpnHandler: &stubVPNHandler{createErr: someErr},
|
vpnHandler: &stubVPNHandler{createErr: someErr},
|
||||||
initVPN: true,
|
initVPN: true,
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
|
@ -180,7 +160,7 @@ func TestInitialize(t *testing.T) {
|
||||||
client: &fakeProtoClient{
|
client: &fakeProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
vpnHandler: &stubVPNHandler{marshalErr: someErr},
|
vpnHandler: &stubVPNHandler{marshalErr: someErr},
|
||||||
initVPN: true,
|
initVPN: true,
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
|
@ -189,7 +169,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"no state exists": {
|
"no state exists": {
|
||||||
existingState: state.ConstellationState{},
|
existingState: state.ConstellationState{},
|
||||||
client: &stubProtoClient{},
|
client: &stubProtoClient{},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -200,7 +180,7 @@ func TestInitialize(t *testing.T) {
|
||||||
EC2SecurityGroup: "sg-test",
|
EC2SecurityGroup: "sg-test",
|
||||||
},
|
},
|
||||||
client: &stubProtoClient{},
|
client: &stubProtoClient{},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -211,7 +191,7 @@ func TestInitialize(t *testing.T) {
|
||||||
EC2SecurityGroup: "sg-test",
|
EC2SecurityGroup: "sg-test",
|
||||||
},
|
},
|
||||||
client: &stubProtoClient{},
|
client: &stubProtoClient{},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -219,7 +199,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"public key to short": {
|
"public key to short": {
|
||||||
existingState: testEc2State,
|
existingState: testEc2State,
|
||||||
client: &stubProtoClient{},
|
client: &stubProtoClient{},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
|
privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -227,7 +207,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"public key to long": {
|
"public key to long": {
|
||||||
existingState: testEc2State,
|
existingState: testEc2State,
|
||||||
client: &stubProtoClient{},
|
client: &stubProtoClient{},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
|
privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -235,7 +215,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"public key not base64": {
|
"public key not base64": {
|
||||||
existingState: testEc2State,
|
existingState: testEc2State,
|
||||||
client: &stubProtoClient{},
|
client: &stubProtoClient{},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: "this is not base64 encoded",
|
privKey: "this is not base64 encoded",
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -243,7 +223,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"fail Connect": {
|
"fail Connect": {
|
||||||
existingState: testEc2State,
|
existingState: testEc2State,
|
||||||
client: &stubProtoClient{connectErr: someErr},
|
client: &stubProtoClient{connectErr: someErr},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -251,7 +231,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"fail Activate": {
|
"fail Activate": {
|
||||||
existingState: testEc2State,
|
existingState: testEc2State,
|
||||||
client: &stubProtoClient{activateErr: someErr},
|
client: &stubProtoClient{activateErr: someErr},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -259,7 +239,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"fail respClient WriteLogStream": {
|
"fail respClient WriteLogStream": {
|
||||||
existingState: testEc2State,
|
existingState: testEc2State,
|
||||||
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
|
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -267,7 +247,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"fail respClient getKubeconfig": {
|
"fail respClient getKubeconfig": {
|
||||||
existingState: testEc2State,
|
existingState: testEc2State,
|
||||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
|
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -275,7 +255,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"fail respClient getCoordinatorVpnKey": {
|
"fail respClient getCoordinatorVpnKey": {
|
||||||
existingState: testEc2State,
|
existingState: testEc2State,
|
||||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
|
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -283,7 +263,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"fail respClient getClientVpnIp": {
|
"fail respClient getClientVpnIp": {
|
||||||
existingState: testEc2State,
|
existingState: testEc2State,
|
||||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
|
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -291,7 +271,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"fail respClient getOwnerID": {
|
"fail respClient getOwnerID": {
|
||||||
existingState: testEc2State,
|
existingState: testEc2State,
|
||||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
|
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -299,7 +279,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"fail respClient getClusterID": {
|
"fail respClient getClusterID": {
|
||||||
existingState: testEc2State,
|
existingState: testEc2State,
|
||||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
|
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -307,7 +287,7 @@ func TestInitialize(t *testing.T) {
|
||||||
"fail to wait for required status": {
|
"fail to wait for required status": {
|
||||||
existingState: testGcpState,
|
existingState: testGcpState,
|
||||||
client: &stubProtoClient{},
|
client: &stubProtoClient{},
|
||||||
waiter: stubStatusWaiter{waitForAllErr: someErr},
|
waiter: &stubStatusWaiter{waitForAllErr: someErr},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -318,7 +298,7 @@ func TestInitialize(t *testing.T) {
|
||||||
serviceAccountCreator: stubServiceAccountCreator{
|
serviceAccountCreator: stubServiceAccountCreator{
|
||||||
createErr: someErr,
|
createErr: someErr,
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
vpnHandler: &stubVPNHandler{},
|
vpnHandler: &stubVPNHandler{},
|
||||||
errExpected: true,
|
errExpected: true,
|
||||||
|
@ -335,6 +315,7 @@ func TestInitialize(t *testing.T) {
|
||||||
cmd.SetOut(&out)
|
cmd.SetOut(&out)
|
||||||
var errOut bytes.Buffer
|
var errOut bytes.Buffer
|
||||||
cmd.SetErr(&errOut)
|
cmd.SetErr(&errOut)
|
||||||
|
cmd.Flags().String("dev-config", "", "") // register persisten flag manually
|
||||||
fs := afero.NewMemMapFs()
|
fs := afero.NewMemMapFs()
|
||||||
fileHandler := file.NewHandler(fs)
|
fileHandler := file.NewHandler(fs)
|
||||||
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
|
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
|
||||||
|
@ -350,7 +331,7 @@ func TestInitialize(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, config, tc.waiter, tc.vpnHandler)
|
err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler)
|
||||||
|
|
||||||
if tc.errExpected {
|
if tc.errExpected {
|
||||||
assert.Error(err)
|
assert.Error(err)
|
||||||
|
@ -551,58 +532,30 @@ func TestReadOrGeneratedMasterSecret(t *testing.T) {
|
||||||
|
|
||||||
func TestAutoscaleFlag(t *testing.T) {
|
func TestAutoscaleFlag(t *testing.T) {
|
||||||
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
|
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
|
||||||
config := config.Default()
|
|
||||||
testEc2State := state.ConstellationState{
|
testEc2State := state.ConstellationState{
|
||||||
EC2Instances: ec2.Instances{
|
EC2Instances: ec2.Instances{
|
||||||
"id-0": {
|
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
|
||||||
PrivateIP: "192.0.2.1",
|
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
|
||||||
PublicIP: "192.0.2.2",
|
"id-2": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.2"},
|
||||||
},
|
|
||||||
"id-1": {
|
|
||||||
PrivateIP: "192.0.2.1",
|
|
||||||
PublicIP: "192.0.2.2",
|
|
||||||
},
|
|
||||||
"id-2": {
|
|
||||||
PrivateIP: "192.0.2.1",
|
|
||||||
PublicIP: "192.0.2.2",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
EC2SecurityGroup: "sg-test",
|
EC2SecurityGroup: "sg-test",
|
||||||
}
|
}
|
||||||
testGcpState := state.ConstellationState{
|
testGcpState := state.ConstellationState{
|
||||||
GCPNodes: gcp.Instances{
|
GCPNodes: gcp.Instances{
|
||||||
"id-0": {
|
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
PrivateIP: "192.0.2.1",
|
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
PublicIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
"id-1": {
|
|
||||||
PrivateIP: "192.0.2.1",
|
|
||||||
PublicIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
GCPCoordinators: gcp.Instances{
|
GCPCoordinators: gcp.Instances{
|
||||||
"id-c": {
|
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
PrivateIP: "192.0.2.1",
|
|
||||||
PublicIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
testAzureState := state.ConstellationState{
|
testAzureState := state.ConstellationState{
|
||||||
AzureNodes: azure.Instances{
|
AzureNodes: azure.Instances{
|
||||||
"id-0": {
|
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
PrivateIP: "192.0.2.1",
|
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
PublicIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
"id-1": {
|
|
||||||
PrivateIP: "192.0.2.1",
|
|
||||||
PublicIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
AzureCoordinators: azure.Instances{
|
AzureCoordinators: azure.Instances{
|
||||||
"id-c": {
|
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||||
PrivateIP: "192.0.2.1",
|
|
||||||
PublicIP: "192.0.2.1",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
AzureResourceGroup: "test",
|
AzureResourceGroup: "test",
|
||||||
}
|
}
|
||||||
|
@ -633,7 +586,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||||
client: &stubProtoClient{
|
client: &stubProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
},
|
},
|
||||||
"initialize some gcp instances without autoscale flag": {
|
"initialize some gcp instances without autoscale flag": {
|
||||||
|
@ -642,7 +595,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||||
client: &stubProtoClient{
|
client: &stubProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
},
|
},
|
||||||
"initialize some azure instances without autoscale flag": {
|
"initialize some azure instances without autoscale flag": {
|
||||||
|
@ -651,7 +604,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||||
client: &stubProtoClient{
|
client: &stubProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
},
|
},
|
||||||
"initialize some ec2 instances with autoscale flag": {
|
"initialize some ec2 instances with autoscale flag": {
|
||||||
|
@ -660,7 +613,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||||
client: &stubProtoClient{
|
client: &stubProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
},
|
},
|
||||||
"initialize some gcp instances with autoscale flag": {
|
"initialize some gcp instances with autoscale flag": {
|
||||||
|
@ -669,7 +622,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||||
client: &stubProtoClient{
|
client: &stubProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
},
|
},
|
||||||
"initialize some azure instances with autoscale flag": {
|
"initialize some azure instances with autoscale flag": {
|
||||||
|
@ -678,7 +631,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||||
client: &stubProtoClient{
|
client: &stubProtoClient{
|
||||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||||
},
|
},
|
||||||
waiter: stubStatusWaiter{},
|
waiter: &stubStatusWaiter{},
|
||||||
privKey: testKey,
|
privKey: testKey,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -693,6 +646,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||||
cmd.SetOut(&out)
|
cmd.SetOut(&out)
|
||||||
var errOut bytes.Buffer
|
var errOut bytes.Buffer
|
||||||
cmd.SetErr(&errOut)
|
cmd.SetErr(&errOut)
|
||||||
|
cmd.Flags().String("dev-config", "", "") // register persisten flag manually
|
||||||
fs := afero.NewMemMapFs()
|
fs := afero.NewMemMapFs()
|
||||||
fileHandler := file.NewHandler(fs)
|
fileHandler := file.NewHandler(fs)
|
||||||
vpnHandler := stubVPNHandler{}
|
vpnHandler := stubVPNHandler{}
|
||||||
|
@ -705,7 +659,7 @@ func TestAutoscaleFlag(t *testing.T) {
|
||||||
require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
|
require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, config, tc.waiter, &vpnHandler))
|
require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler))
|
||||||
if tc.autoscaleFlag {
|
if tc.autoscaleFlag {
|
||||||
assert.Len(tc.client.activateAutoscalingNodeGroups, 1)
|
assert.Len(tc.client.activateAutoscalingNodeGroups, 1)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type protoClient interface {
|
type protoClient interface {
|
||||||
Connect(ip string, port string) error
|
Connect(ip, port string, gcpPCRs, azurePCRs map[uint32][]byte) error
|
||||||
Close() error
|
Close() error
|
||||||
Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
|
Activate(ctx context.Context, userPublicKey, masterSecret []byte, endpoints, autoscalingNodeGroups []string, cloudServiceAccountURI string) (proto.ActivationResponseClient, error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,7 +23,7 @@ type stubProtoClient struct {
|
||||||
cloudServiceAccountURI string
|
cloudServiceAccountURI string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *stubProtoClient) Connect(ip string, port string) error {
|
func (c *stubProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error {
|
||||||
c.conn = true
|
c.conn = true
|
||||||
return c.connectErr
|
return c.connectErr
|
||||||
}
|
}
|
||||||
|
@ -89,7 +89,7 @@ type fakeProtoClient struct {
|
||||||
respClient proto.ActivationResponseClient
|
respClient proto.ActivationResponseClient
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *fakeProtoClient) Connect(ip string, port string) error {
|
func (c *fakeProtoClient) Connect(_, _ string, _, _ map[uint32][]byte) error {
|
||||||
c.conn = true
|
c.conn = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,5 +7,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type statusWaiter interface {
|
type statusWaiter interface {
|
||||||
|
InitializePCRs(map[uint32][]byte, map[uint32][]byte)
|
||||||
WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
|
WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,14 +2,23 @@ package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/state"
|
"github.com/edgelesssys/constellation/coordinator/state"
|
||||||
)
|
)
|
||||||
|
|
||||||
type stubStatusWaiter struct {
|
type stubStatusWaiter struct {
|
||||||
|
initialized bool
|
||||||
waitForAllErr error
|
waitForAllErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
func (s *stubStatusWaiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) {
|
||||||
return w.waitForAllErr
|
s.initialized = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
||||||
|
if !s.initialized {
|
||||||
|
return errors.New("waiter not initialized")
|
||||||
|
}
|
||||||
|
return s.waitForAllErr
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/aws"
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
||||||
"github.com/edgelesssys/constellation/coordinator/kms"
|
"github.com/edgelesssys/constellation/coordinator/kms"
|
||||||
|
@ -23,29 +22,23 @@ import (
|
||||||
type Client struct {
|
type Client struct {
|
||||||
conn *grpc.ClientConn
|
conn *grpc.ClientConn
|
||||||
avpn pubproto.APIClient
|
avpn pubproto.APIClient
|
||||||
validators []atls.Validator
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a Client without a connection.
|
// Connect connects the client to a given server, using the handed
|
||||||
func NewClient(gcpPCRs map[uint32][]byte) *Client {
|
// Validators for the attestation of the connection.
|
||||||
return &Client{
|
|
||||||
validators: []atls.Validator{
|
|
||||||
aws.NewValidator(aws.NaAdGetVerifiedPayloadAsJson),
|
|
||||||
gcp.NewValidator(gcpPCRs),
|
|
||||||
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
|
|
||||||
azure.NewValidator(map[uint32][]byte{}),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Connect connects the client to a given server.
|
|
||||||
// The connection must be closed using Close(). If connect is
|
// The connection must be closed using Close(). If connect is
|
||||||
// called on a client that already has a connection, the old
|
// called on a client that already has a connection, the old
|
||||||
// connection is closed.
|
// connection is closed.
|
||||||
func (c *Client) Connect(ip string, port string) error {
|
func (c *Client) Connect(ip, port string, gcpPCRs, AzurePCRs map[uint32][]byte) error {
|
||||||
addr := net.JoinHostPort(ip, port)
|
addr := net.JoinHostPort(ip, port)
|
||||||
|
|
||||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(c.validators)
|
validators := []atls.Validator{
|
||||||
|
gcp.NewValidator(gcpPCRs),
|
||||||
|
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
|
||||||
|
azure.NewValidator(map[uint32][]byte{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,7 @@ func TestClose(t *testing.T) {
|
||||||
assert := assert.New(t)
|
assert := assert.New(t)
|
||||||
require := require.New(t)
|
require := require.New(t)
|
||||||
|
|
||||||
client := NewClient(map[uint32][]byte{})
|
client := Client{}
|
||||||
|
|
||||||
// Create a connection.
|
// Create a connection.
|
||||||
listener := bufconn.Listen(4)
|
listener := bufconn.Listen(4)
|
||||||
|
|
|
@ -2,11 +2,11 @@ package status
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/aws"
|
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
"github.com/edgelesssys/constellation/coordinator/attestation/azure"
|
||||||
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
"github.com/edgelesssys/constellation/coordinator/attestation/gcp"
|
||||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||||
|
@ -17,8 +17,10 @@ import (
|
||||||
grpcstatus "google.golang.org/grpc/status"
|
grpcstatus "google.golang.org/grpc/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Waiter waits for PeerStatusServer to reach a specific state.
|
// Waiter waits for PeerStatusServer to reach a specific state. The waiter needs
|
||||||
|
// to be initialized before usage.
|
||||||
type Waiter struct {
|
type Waiter struct {
|
||||||
|
initialized bool
|
||||||
interval time.Duration
|
interval time.Duration
|
||||||
newConn func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error)
|
newConn func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error)
|
||||||
newClient func(cc grpc.ClientConnInterface) pubproto.APIClient
|
newClient func(cc grpc.ClientConnInterface) pubproto.APIClient
|
||||||
|
@ -26,17 +28,26 @@ type Waiter struct {
|
||||||
|
|
||||||
// NewWaiter returns a default Waiter with probing inteval of 10 seconds,
|
// NewWaiter returns a default Waiter with probing inteval of 10 seconds,
|
||||||
// attested gRPC connection and PeerStatusClient.
|
// attested gRPC connection and PeerStatusClient.
|
||||||
func NewWaiter(gcpPCRs map[uint32][]byte) Waiter {
|
func NewWaiter() *Waiter {
|
||||||
return Waiter{
|
return &Waiter{
|
||||||
interval: 10 * time.Second,
|
interval: 10 * time.Second,
|
||||||
newConn: newAttestedConnGenerator(gcpPCRs),
|
|
||||||
newClient: pubproto.NewAPIClient,
|
newClient: pubproto.NewAPIClient,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InitializePCRs initializes the PCRs for the attestation validators.
|
||||||
|
func (w *Waiter) InitializePCRs(gcpPCRs, azurePCRs map[uint32][]byte) {
|
||||||
|
w.newConn = newAttestedConnGenerator(gcpPCRs, azurePCRs)
|
||||||
|
w.initialized = true
|
||||||
|
}
|
||||||
|
|
||||||
// WaitFor waits for a PeerStatusServer, which is reachable under the given endpoint
|
// WaitFor waits for a PeerStatusServer, which is reachable under the given endpoint
|
||||||
// to reach the specified state.
|
// to reach the specified state.
|
||||||
func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error {
|
func (w *Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.State) error {
|
||||||
|
if !w.initialized {
|
||||||
|
return errors.New("waiter not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
ticker := time.NewTicker(w.interval)
|
ticker := time.NewTicker(w.interval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
@ -71,7 +82,7 @@ func (w Waiter) WaitFor(ctx context.Context, endpoint string, status ...state.St
|
||||||
}
|
}
|
||||||
|
|
||||||
// probe sends a PeerStatusCheck request to a PeerStatusServer and returns the response.
|
// probe sends a PeerStatusCheck request to a PeerStatusServer and returns the response.
|
||||||
func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateResponse, error) {
|
func (w *Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateResponse, error) {
|
||||||
conn, err := w.newConn(ctx, endpoint)
|
conn, err := w.newConn(ctx, endpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -84,7 +95,11 @@ func (w Waiter) probe(ctx context.Context, endpoint string) (*pubproto.GetStateR
|
||||||
|
|
||||||
// WaitForAll waits for a list of PeerStatusServers, which listen on the handed
|
// WaitForAll waits for a list of PeerStatusServers, which listen on the handed
|
||||||
// endpoints, to reach the specified state.
|
// endpoints, to reach the specified state.
|
||||||
func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
func (w *Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
||||||
|
if !w.initialized {
|
||||||
|
return errors.New("waiter not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
for _, endpoint := range endpoints {
|
for _, endpoint := range endpoints {
|
||||||
if err := w.WaitFor(ctx, endpoint, status...); err != nil {
|
if err := w.WaitFor(ctx, endpoint, status...); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -94,13 +109,12 @@ func (w Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...st
|
||||||
}
|
}
|
||||||
|
|
||||||
// newAttestedConnGenerator creates a function returning a default attested grpc connection.
|
// newAttestedConnGenerator creates a function returning a default attested grpc connection.
|
||||||
func newAttestedConnGenerator(gcpPCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
|
func newAttestedConnGenerator(gcpPCRs map[uint32][]byte, azurePCRs map[uint32][]byte) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
|
||||||
return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
|
return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) {
|
||||||
validators := []atls.Validator{
|
validators := []atls.Validator{
|
||||||
aws.NewValidator(aws.NaAdGetVerifiedPayloadAsJson),
|
|
||||||
gcp.NewValidator(gcpPCRs),
|
gcp.NewValidator(gcpPCRs),
|
||||||
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
|
gcp.NewNonCVMValidator(map[uint32][]byte{}), // TODO: Remove once we no longer use non cvms
|
||||||
azure.NewValidator(map[uint32][]byte{}),
|
azure.NewValidator(azurePCRs),
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators)
|
||||||
|
|
|
@ -12,6 +12,21 @@ import (
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestInitializeValidators(t *testing.T) {
|
||||||
|
assert := assert.New(t)
|
||||||
|
|
||||||
|
waiter := Waiter{
|
||||||
|
interval: time.Millisecond,
|
||||||
|
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uninitialized waiter fails.
|
||||||
|
assert.Error(waiter.WaitFor(context.Background(), "someIP", state.IsNode))
|
||||||
|
|
||||||
|
waiter.InitializeValidators(nil)
|
||||||
|
assert.NoError(waiter.WaitFor(context.Background(), "someIP", state.IsNode))
|
||||||
|
}
|
||||||
|
|
||||||
func TestWaitForAndWaitForAll(t *testing.T) {
|
func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
var noErr error
|
var noErr error
|
||||||
someErr := errors.New("failed")
|
someErr := errors.New("failed")
|
||||||
|
@ -23,6 +38,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
"successful wait": {
|
"successful wait": {
|
||||||
waiter: Waiter{
|
waiter: Waiter{
|
||||||
|
initialized: true,
|
||||||
interval: time.Millisecond,
|
interval: time.Millisecond,
|
||||||
newConn: stubNewConnFunc(noErr),
|
newConn: stubNewConnFunc(noErr),
|
||||||
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
||||||
|
@ -31,6 +47,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
},
|
},
|
||||||
"successful wait multi states": {
|
"successful wait multi states": {
|
||||||
waiter: Waiter{
|
waiter: Waiter{
|
||||||
|
initialized: true,
|
||||||
interval: time.Millisecond,
|
interval: time.Millisecond,
|
||||||
newConn: stubNewConnFunc(noErr),
|
newConn: stubNewConnFunc(noErr),
|
||||||
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.IsNode}),
|
||||||
|
@ -39,6 +56,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
},
|
},
|
||||||
"expect timeout": {
|
"expect timeout": {
|
||||||
waiter: Waiter{
|
waiter: Waiter{
|
||||||
|
initialized: true,
|
||||||
interval: time.Millisecond,
|
interval: time.Millisecond,
|
||||||
newConn: stubNewConnFunc(noErr),
|
newConn: stubNewConnFunc(noErr),
|
||||||
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}),
|
newClient: stubNewClientFunc(&stubPeerStatusClient{state: state.AcceptingInit}),
|
||||||
|
@ -48,6 +66,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail to check call": {
|
"fail to check call": {
|
||||||
waiter: Waiter{
|
waiter: Waiter{
|
||||||
|
initialized: true,
|
||||||
interval: time.Millisecond,
|
interval: time.Millisecond,
|
||||||
newConn: stubNewConnFunc(noErr),
|
newConn: stubNewConnFunc(noErr),
|
||||||
newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}),
|
newClient: stubNewClientFunc(&stubPeerStatusClient{checkErr: someErr}),
|
||||||
|
@ -57,6 +76,7 @@ func TestWaitForAndWaitForAll(t *testing.T) {
|
||||||
},
|
},
|
||||||
"fail to create conn": {
|
"fail to create conn": {
|
||||||
waiter: Waiter{
|
waiter: Waiter{
|
||||||
|
initialized: true,
|
||||||
interval: time.Millisecond,
|
interval: time.Millisecond,
|
||||||
newConn: stubNewConnFunc(someErr),
|
newConn: stubNewConnFunc(someErr),
|
||||||
newClient: stubNewClientFunc(&stubPeerStatusClient{}),
|
newClient: stubNewClientFunc(&stubPeerStatusClient{}),
|
||||||
|
|
|
@ -41,7 +41,8 @@ func main() {
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// wait for coordinator to come online
|
// wait for coordinator to come online
|
||||||
waiter := status.NewWaiter(map[uint32][]byte{})
|
waiter := status.NewWaiter()
|
||||||
|
waiter.InitializePCRs(map[uint32][]byte{}, map[uint32][]byte{})
|
||||||
if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil {
|
if err := waiter.WaitFor(ctx, addr, state.AcceptingInit, state.ActivatingNodes, state.IsNode, state.NodeWaitingForClusterJoin); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue