diff --git a/cli/internal/cmd/BUILD.bazel b/cli/internal/cmd/BUILD.bazel index 7b29a66fa..cf22bc7b6 100644 --- a/cli/internal/cmd/BUILD.bazel +++ b/cli/internal/cmd/BUILD.bazel @@ -145,6 +145,7 @@ go_test( "maapatch_test.go", "recover_test.go", "spinner_test.go", + "ssh_test.go", "status_test.go", "terminate_test.go", "upgradeapply_test.go", diff --git a/cli/internal/cmd/ssh.go b/cli/internal/cmd/ssh.go index c9c78bd0a..079e8ca90 100644 --- a/cli/internal/cmd/ssh.go +++ b/cli/internal/cmd/ssh.go @@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only package cmd import ( + "context" "crypto/ed25519" "crypto/rand" "fmt" @@ -52,7 +53,16 @@ func runSSH(cmd *cobra.Command, _ []string) error { return err } - _, err = fh.Stat(constants.TerraformWorkingDir) + keyPath, err := cmd.Flags().GetString("key") + if err != nil { + return fmt.Errorf("retrieving path to public key from flags: %s", err) + } + + return generateKey(cmd.Context(), keyPath, fh, debugLogger) +} + +func generateKey(ctx context.Context, keyPath string, fh file.Handler, debugLogger debugLog) error { + _, err := fh.Stat(constants.TerraformWorkingDir) if os.IsNotExist(err) { return fmt.Errorf("directory %q does not exist", constants.TerraformWorkingDir) } @@ -67,11 +77,11 @@ func runSSH(cmd *cobra.Command, _ []string) error { } mastersecretURI := uri.MasterSecret{Key: mastersecret.Key, Salt: mastersecret.Salt} - kms, err := setup.KMS(cmd.Context(), uri.NoStoreURI, mastersecretURI.EncodeToURI()) + kms, err := setup.KMS(ctx, uri.NoStoreURI, mastersecretURI.EncodeToURI()) if err != nil { return fmt.Errorf("setting up KMS: %s", err) } - key, err := kms.GetDEK(cmd.Context(), crypto.DEKPrefix+constants.SSHCAKeySuffix, ed25519.SeedSize) + key, err := kms.GetDEK(ctx, crypto.DEKPrefix+constants.SSHCAKeySuffix, ed25519.SeedSize) if err != nil { return fmt.Errorf("retrieving key from KMS: %s", err) } @@ -83,11 +93,6 @@ func runSSH(cmd *cobra.Command, _ []string) error { debugLogger.Debug("SSH CA KEY generated", "public-key", string(ssh.MarshalAuthorizedKey(ca.PublicKey()))) - keyPath, err := cmd.Flags().GetString("key") - if err != nil { - return fmt.Errorf("retrieving path to public key from flags: %s", err) - } - keyBuffer, err := fh.Read(keyPath) if err != nil { return fmt.Errorf("reading public key %q: %s", keyPath, err) diff --git a/cli/internal/cmd/ssh_test.go b/cli/internal/cmd/ssh_test.go new file mode 100644 index 000000000..18a96d8d1 --- /dev/null +++ b/cli/internal/cmd/ssh_test.go @@ -0,0 +1,97 @@ +package cmd + +import ( + "context" + "testing" + + "github.com/edgelesssys/constellation/v2/internal/constants" + "github.com/edgelesssys/constellation/v2/internal/file" + "github.com/edgelesssys/constellation/v2/internal/logger" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSSH(t *testing.T) { + require := require.New(t) + + someSSHPubKey := "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBDA1yYg1PIJNjAGjyuv66r8AJtpfBDFLdp3u9lVwkgbVKv1AzcaeTF/NEw+nhNJOjuCZ61LTPj12LZ8Wy/oSm0A= motte@lolcatghost" + someSSHPubKeyPath := "some-key.pub" + someMasterSecret := ` + { + "key": "MDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAK", + "salt": "MDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAwMDAK" + } + ` + + newFsWithDirectory := func() file.Handler { + fh := file.NewHandler(afero.NewMemMapFs()) + require.NoError(fh.MkdirAll(constants.TerraformWorkingDir)) + return fh + } + newFsNoDirectory := func() file.Handler { + fh := file.NewHandler(afero.NewMemMapFs()) + return fh + } + + testCases := map[string]struct { + fh file.Handler + pubKey string + masterSecret string + wantErr bool + }{ + "everything exists": { + fh: newFsWithDirectory(), + pubKey: someSSHPubKey, + masterSecret: someMasterSecret, + }, + "no public key": { + fh: newFsWithDirectory(), + masterSecret: someMasterSecret, + wantErr: true, + }, + "no master secret": { + fh: newFsWithDirectory(), + pubKey: someSSHPubKey, + wantErr: true, + }, + "malformatted public key": { + fh: newFsWithDirectory(), + pubKey: "asdf", + masterSecret: someMasterSecret, + wantErr: true, + }, + "malformatted master secret": { + fh: newFsWithDirectory(), + masterSecret: "asdf", + pubKey: someSSHPubKey, + wantErr: true, + }, + "directory does not exist": { + fh: newFsNoDirectory(), + pubKey: someSSHPubKey, + masterSecret: someMasterSecret, + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + if tc.pubKey != "" { + tc.fh.Write(someSSHPubKeyPath, []byte(tc.pubKey)) + } + if tc.masterSecret != "" { + tc.fh.Write(constants.MasterSecretFilename, []byte(tc.masterSecret)) + } + + err := generateKey(context.Background(), someSSHPubKeyPath, tc.fh, logger.NewTest(t)) + if tc.wantErr { + assert.Error(err) + } else { + assert.NoError(err) + } + }) + } +}