diff --git a/cli/internal/cmd/init.go b/cli/internal/cmd/init.go index d940624d7..8bf635eb1 100644 --- a/cli/internal/cmd/init.go +++ b/cli/internal/cmd/init.go @@ -105,7 +105,7 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator return err } - serviceAccURI, err := getMarschaledServiceAccountURI(provider, config, fileHandler) + serviceAccURI, err := getMarshaledServiceAccountURI(provider, config, fileHandler) if err != nil { return err } @@ -126,11 +126,16 @@ func initialize(cmd *cobra.Command, newDialer func(validator *cloudcmd.Validator return fmt.Errorf("loading Helm charts: %w", err) } + masterSecret, err := readOrGenerateMasterSecret(cmd.OutOrStdout(), fileHandler, flags.masterSecretPath) + if err != nil { + return fmt.Errorf("parsing or generating master secret from file %s: %w", flags.masterSecretPath, err) + } + cmd.Println("Initializing cluster ...") req := &initproto.InitRequest{ AutoscalingNodeGroups: autoscalingNodeGroups, - MasterSecret: flags.masterSecret.Key, - Salt: flags.masterSecret.Salt, + MasterSecret: masterSecret.Key, + Salt: masterSecret.Salt, KmsUri: kms.ClusterKMSURI, StorageUri: kms.NoStoreURI, KeyEncryptionKeyId: "", @@ -253,10 +258,6 @@ func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (initFlags, erro if err != nil { return initFlags{}, fmt.Errorf("parsing master-secret path flag: %w", err) } - masterSecret, err := readOrGenerateMasterSecret(cmd.OutOrStdout(), fileHandler, masterSecretPath) - if err != nil { - return initFlags{}, fmt.Errorf("parsing or generating master mastersecret from file %s: %w", masterSecretPath, err) - } endpoint, err := cmd.Flags().GetString("endpoint") if err != nil { return initFlags{}, fmt.Errorf("parsing endpoint flag: %w", err) @@ -277,19 +278,19 @@ func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (initFlags, erro } return initFlags{ - configPath: configPath, - endpoint: endpoint, - autoscale: autoscale, - masterSecret: masterSecret, + configPath: configPath, + endpoint: endpoint, + autoscale: autoscale, + masterSecretPath: masterSecretPath, }, nil } // initFlags are the resulting values of flag preprocessing. type initFlags struct { - configPath string - masterSecret masterSecret - endpoint string - autoscale bool + configPath string + masterSecretPath string + endpoint string + autoscale bool } // masterSecret holds the master key and salt for deriving keys. @@ -347,7 +348,7 @@ func readIPFromIDFile(fileHandler file.Handler) (string, error) { return idFile.IP, nil } -func getMarschaledServiceAccountURI(provider cloudprovider.Provider, config *config.Config, fileHandler file.Handler) (string, error) { +func getMarshaledServiceAccountURI(provider cloudprovider.Provider, config *config.Config, fileHandler file.Handler) (string, error) { switch provider { case cloudprovider.GCP: path := config.Provider.GCP.ServiceAccountKeyPath diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index 6f791283b..c9797d0bf 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -68,15 +68,16 @@ func TestInitialize(t *testing.T) { someErr := errors.New("failed") testCases := map[string]struct { - state *state.ConstellationState - idFile *clusterIDsFile - configMutator func(*config.Config) - serviceAccKey *gcpshared.ServiceAccountKey - helmLoader stubHelmLoader - initServerAPI *stubInitServer - endpointFlag string - setAutoscaleFlag bool - wantErr bool + state *state.ConstellationState + idFile *clusterIDsFile + configMutator func(*config.Config) + serviceAccKey *gcpshared.ServiceAccountKey + helmLoader stubHelmLoader + initServerAPI *stubInitServer + endpointFlag string + masterSecretShouldExist bool + setAutoscaleFlag bool + wantErr bool }{ "initialize some gcp instances": { state: testGcpState, @@ -139,12 +140,17 @@ func TestInitialize(t *testing.T) { c.Provider.Azure.ResourceGroup = "resourceGroup" c.Provider.Azure.UserAssignedIdentity = "userAssignedIdentity" }, - initServerAPI: &stubInitServer{}, - wantErr: true, + initServerAPI: &stubInitServer{}, + masterSecretShouldExist: true, + wantErr: true, }, - "fail to load helm charts": { - state: testGcpState, - helmLoader: stubHelmLoader{loadErr: someErr}, + "fail missing enforced PCR": { + state: testGcpState, + idFile: &clusterIDsFile{IP: "192.0.2.1"}, + configMutator: func(c *config.Config) { + c.Provider.GCP.EnforcedMeasurements = append(c.Provider.GCP.EnforcedMeasurements, 10) + }, + serviceAccKey: gcpServiceAccKey, initServerAPI: &stubInitServer{initResp: testInitResp}, wantErr: true, }, @@ -209,6 +215,10 @@ func TestInitialize(t *testing.T) { if tc.wantErr { assert.Error(err) + if !tc.masterSecretShouldExist { + _, err = fileHandler.Stat(constants.MasterSecretFilename) + assert.Error(err) + } return } require.NoError(err) @@ -219,6 +229,10 @@ func TestInitialize(t *testing.T) { } else { assert.Len(tc.initServerAPI.activateAutoscalingNodeGroups, 0) } + var secret masterSecret + assert.NoError(fileHandler.ReadJSON(constants.MasterSecretFilename, &secret)) + assert.NotEmpty(secret.Key) + assert.NotEmpty(secret.Salt) }) } } @@ -303,7 +317,7 @@ func TestInitCompletion(t *testing.T) { } } -func TestReadOrGeneratedMasterSecret(t *testing.T) { +func TestReadOrGenerateMasterSecret(t *testing.T) { testCases := map[string]struct { filename string createFileFunc func(handler file.Handler) error