mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-07-19 21:38:44 -04:00
Move cli/cmd into cli/internal
This commit is contained in:
parent
d71e97a940
commit
c3ebd3d3cd
34 changed files with 45 additions and 32 deletions
|
@ -1,28 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
type cloudCreator interface {
|
||||
Create(
|
||||
ctx context.Context,
|
||||
provider cloudprovider.Provider,
|
||||
config *config.Config,
|
||||
name, insType string,
|
||||
coordCount, nodeCount int,
|
||||
) (state.ConstellationState, error)
|
||||
}
|
||||
|
||||
type cloudTerminator interface {
|
||||
Terminate(context.Context, state.ConstellationState) error
|
||||
}
|
||||
|
||||
type serviceAccountCreator interface {
|
||||
Create(ctx context.Context, stat state.ConstellationState, config *config.Config,
|
||||
) (string, state.ConstellationState, error)
|
||||
}
|
|
@ -1,49 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
type stubCloudCreator struct {
|
||||
createCalled bool
|
||||
state state.ConstellationState
|
||||
createErr error
|
||||
}
|
||||
|
||||
func (c *stubCloudCreator) Create(
|
||||
ctx context.Context,
|
||||
provider cloudprovider.Provider,
|
||||
config *config.Config,
|
||||
name, insType string,
|
||||
coordCount, nodeCount int,
|
||||
) (state.ConstellationState, error) {
|
||||
c.createCalled = true
|
||||
return c.state, c.createErr
|
||||
}
|
||||
|
||||
type stubCloudTerminator struct {
|
||||
called bool
|
||||
terminateErr error
|
||||
}
|
||||
|
||||
func (c *stubCloudTerminator) Terminate(context.Context, state.ConstellationState) error {
|
||||
c.called = true
|
||||
return c.terminateErr
|
||||
}
|
||||
|
||||
func (c *stubCloudTerminator) Called() bool {
|
||||
return c.called
|
||||
}
|
||||
|
||||
type stubServiceAccountCreator struct {
|
||||
cloudServiceAccountURI string
|
||||
createErr error
|
||||
}
|
||||
|
||||
func (c *stubServiceAccountCreator) Create(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
|
||||
return c.cloudServiceAccountURI, stat, c.createErr
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newConfigCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "config",
|
||||
Short: "Work with the Constellation configuration file",
|
||||
Long: "Generate a configuration file for Constellation.",
|
||||
Args: cobra.ExactArgs(0),
|
||||
}
|
||||
|
||||
cmd.AddCommand(newConfigGenerateCmd())
|
||||
|
||||
return cmd
|
||||
}
|
|
@ -1,81 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/talos-systems/talos/pkg/machinery/config/encoder"
|
||||
)
|
||||
|
||||
func newConfigGenerateCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "generate {aws|azure|gcp}",
|
||||
Short: "Generate a default configuration file",
|
||||
Long: "Generate a default configuration file for your selected cloud provider.",
|
||||
Args: cobra.MatchAll(
|
||||
cobra.ExactArgs(1),
|
||||
isCloudProvider(0),
|
||||
warnAWS(0),
|
||||
),
|
||||
ValidArgsFunction: generateCompletion,
|
||||
RunE: runConfigGenerate,
|
||||
}
|
||||
cmd.Flags().StringP("file", "f", constants.ConfigFilename, "path to output file, or '-' for stdout")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
type generateFlags struct {
|
||||
file string
|
||||
}
|
||||
|
||||
func runConfigGenerate(cmd *cobra.Command, args []string) error {
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
provider := cloudprovider.FromString(args[0])
|
||||
return configGenerate(cmd, fileHandler, provider)
|
||||
}
|
||||
|
||||
func configGenerate(cmd *cobra.Command, fileHandler file.Handler, provider cloudprovider.Provider) error {
|
||||
flags, err := parseGenerateFlags(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conf := config.Default()
|
||||
conf.RemoveProviderExcept(provider)
|
||||
|
||||
if flags.file == "-" {
|
||||
content, err := encoder.NewEncoder(conf).Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = cmd.OutOrStdout().Write(content)
|
||||
return err
|
||||
}
|
||||
|
||||
return fileHandler.WriteYAML(flags.file, conf, 0o644)
|
||||
}
|
||||
|
||||
func parseGenerateFlags(cmd *cobra.Command) (generateFlags, error) {
|
||||
file, err := cmd.Flags().GetString("file")
|
||||
if err != nil {
|
||||
return generateFlags{}, err
|
||||
}
|
||||
return generateFlags{
|
||||
file: file,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createCompletion handles the completion of the create command. It is frequently called
|
||||
// while the user types arguments of the command to suggest completion.
|
||||
func generateCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
switch len(args) {
|
||||
case 0:
|
||||
return []string{"aws", "gcp", "azure"}, cobra.ShellCompDirectiveNoFileComp
|
||||
default:
|
||||
return []string{}, cobra.ShellCompDirectiveError
|
||||
}
|
||||
}
|
|
@ -1,87 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
func TestConfigGenerateDefault(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fileHandler := file.NewHandler(afero.NewMemMapFs())
|
||||
cmd := newConfigGenerateCmd()
|
||||
|
||||
require.NoError(configGenerate(cmd, fileHandler, cloudprovider.Unknown))
|
||||
|
||||
var readConfig config.Config
|
||||
err := fileHandler.ReadYAML(constants.ConfigFilename, &readConfig)
|
||||
assert.NoError(err)
|
||||
assert.Equal(*config.Default(), readConfig)
|
||||
}
|
||||
|
||||
func TestConfigGenerateDefaultGCPSpecific(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
wantConf := config.Default()
|
||||
wantConf.RemoveProviderExcept(cloudprovider.GCP)
|
||||
|
||||
fileHandler := file.NewHandler(afero.NewMemMapFs())
|
||||
cmd := newConfigGenerateCmd()
|
||||
|
||||
require.NoError(configGenerate(cmd, fileHandler, cloudprovider.GCP))
|
||||
|
||||
var readConfig config.Config
|
||||
err := fileHandler.ReadYAML(constants.ConfigFilename, &readConfig)
|
||||
assert.NoError(err)
|
||||
assert.Equal(*wantConf, readConfig)
|
||||
}
|
||||
|
||||
func TestConfigGenerateDefaultExists(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
fileHandler := file.NewHandler(afero.NewMemMapFs())
|
||||
require.NoError(fileHandler.Write(constants.ConfigFilename, []byte("foobar"), file.OptNone))
|
||||
cmd := newConfigGenerateCmd()
|
||||
|
||||
require.Error(configGenerate(cmd, fileHandler, cloudprovider.Unknown))
|
||||
}
|
||||
|
||||
func TestConfigGenerateFileFlagRemoved(t *testing.T) {
|
||||
require := require.New(t)
|
||||
|
||||
fileHandler := file.NewHandler(afero.NewMemMapFs())
|
||||
cmd := newConfigGenerateCmd()
|
||||
cmd.ResetFlags()
|
||||
|
||||
require.Error(configGenerate(cmd, fileHandler, cloudprovider.Unknown))
|
||||
}
|
||||
|
||||
func TestConfigGenerateStdOut(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fileHandler := file.NewHandler(afero.NewMemMapFs())
|
||||
|
||||
var outBuffer bytes.Buffer
|
||||
cmd := newConfigGenerateCmd()
|
||||
cmd.SetOut(&outBuffer)
|
||||
require.NoError(cmd.Flags().Set("file", "-"))
|
||||
|
||||
require.NoError(configGenerate(cmd, fileHandler, cloudprovider.Unknown))
|
||||
|
||||
var readConfig config.Config
|
||||
require.NoError(yaml.NewDecoder(&outBuffer).Decode(&readConfig))
|
||||
|
||||
assert.Equal(*config.Default(), readConfig)
|
||||
}
|
|
@ -1,223 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/cli/internal/azure"
|
||||
"github.com/edgelesssys/constellation/cli/internal/gcp"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newCreateCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "create {aws|azure|gcp}",
|
||||
Short: "Create instances on a cloud platform for your Constellation cluster",
|
||||
Long: "Create instances on a cloud platform for your Constellation cluster.",
|
||||
Args: cobra.MatchAll(
|
||||
cobra.ExactArgs(1),
|
||||
isCloudProvider(0),
|
||||
warnAWS(0),
|
||||
),
|
||||
ValidArgsFunction: createCompletion,
|
||||
RunE: runCreate,
|
||||
}
|
||||
cmd.Flags().String("name", "constell", "create the cluster with the specified name")
|
||||
cmd.Flags().BoolP("yes", "y", false, "create the cluster without further confirmation")
|
||||
cmd.Flags().IntP("control-plane-nodes", "c", 0, "number of control-plane nodes (required)")
|
||||
must(cobra.MarkFlagRequired(cmd.Flags(), "control-plane-nodes"))
|
||||
cmd.Flags().IntP("worker-nodes", "w", 0, "number of worker nodes (required)")
|
||||
must(cobra.MarkFlagRequired(cmd.Flags(), "worker-nodes"))
|
||||
cmd.Flags().StringP("instance-type", "t", "", "instance type of cluster nodes")
|
||||
must(cmd.RegisterFlagCompletionFunc("instance-type", instanceTypeCompletion))
|
||||
|
||||
cmd.SetHelpTemplate(cmd.HelpTemplate() + fmt.Sprintf(`
|
||||
Azure instance types:
|
||||
%v
|
||||
|
||||
GCP instance types:
|
||||
%v
|
||||
`, formatInstanceTypes(azure.InstanceTypes), formatInstanceTypes(gcp.InstanceTypes)))
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runCreate(cmd *cobra.Command, args []string) error {
|
||||
provider := cloudprovider.FromString(args[0])
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
creator := cloudcmd.NewCreator(cmd.OutOrStdout())
|
||||
|
||||
return create(cmd, creator, fileHandler, provider)
|
||||
}
|
||||
|
||||
func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler, provider cloudprovider.Provider,
|
||||
) (retErr error) {
|
||||
flags, err := parseCreateFlags(cmd, provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := checkDirClean(fileHandler); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := readConfig(cmd.OutOrStdout(), fileHandler, flags.configPath, provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !flags.yes {
|
||||
// Ask user to confirm action.
|
||||
cmd.Printf("The following Constellation cluster will be created:\n")
|
||||
cmd.Printf("%d control-planes nodes of type %s will be created.\n", flags.controllerCount, flags.insType)
|
||||
cmd.Printf("%d worker nodes of type %s will be created.\n", flags.workerCount, flags.insType)
|
||||
ok, err := askToConfirm(cmd, "Do you want to create this cluster?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
cmd.Println("The creation of the cluster was aborted.")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
state, err := creator.Create(cmd.Context(), provider, config, flags.name, flags.insType, flags.controllerCount, flags.workerCount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := fileHandler.WriteJSON(constants.StateFilename, state, file.OptNone); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Your Constellation cluster was created successfully.")
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseCreateFlags parses the flags of the create command.
|
||||
func parseCreateFlags(cmd *cobra.Command, provider cloudprovider.Provider) (createFlags, error) {
|
||||
controllerCount, err := cmd.Flags().GetInt("control-plane-nodes")
|
||||
if err != nil {
|
||||
return createFlags{}, err
|
||||
}
|
||||
if controllerCount < constants.MinControllerCount {
|
||||
return createFlags{}, fmt.Errorf("number of control-plane nodes must be at least %d", constants.MinControllerCount)
|
||||
}
|
||||
|
||||
workerCount, err := cmd.Flags().GetInt("worker-nodes")
|
||||
if err != nil {
|
||||
return createFlags{}, err
|
||||
}
|
||||
if workerCount < constants.MinWorkerCount {
|
||||
return createFlags{}, fmt.Errorf("number of worker nodes must be at least %d", constants.MinWorkerCount)
|
||||
}
|
||||
|
||||
insType, err := cmd.Flags().GetString("instance-type")
|
||||
if err != nil {
|
||||
return createFlags{}, err
|
||||
}
|
||||
if insType == "" {
|
||||
insType = defaultInstanceType(provider)
|
||||
}
|
||||
if err := validInstanceTypeForProvider(cmd, insType, provider); err != nil {
|
||||
return createFlags{}, err
|
||||
}
|
||||
|
||||
name, err := cmd.Flags().GetString("name")
|
||||
if err != nil {
|
||||
return createFlags{}, err
|
||||
}
|
||||
if len(name) > constants.ConstellationNameLength {
|
||||
return createFlags{}, fmt.Errorf(
|
||||
"name for Constellation cluster too long, maximum length is %d, got %d: %s",
|
||||
constants.ConstellationNameLength, len(name), name,
|
||||
)
|
||||
}
|
||||
|
||||
yes, err := cmd.Flags().GetBool("yes")
|
||||
if err != nil {
|
||||
return createFlags{}, err
|
||||
}
|
||||
|
||||
configPath, err := cmd.Flags().GetString("config")
|
||||
if err != nil {
|
||||
return createFlags{}, err
|
||||
}
|
||||
|
||||
return createFlags{
|
||||
controllerCount: controllerCount,
|
||||
workerCount: workerCount,
|
||||
insType: insType,
|
||||
name: name,
|
||||
configPath: configPath,
|
||||
yes: yes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createFlags contains the parsed flags of the create command.
|
||||
type createFlags struct {
|
||||
controllerCount int
|
||||
workerCount int
|
||||
insType string
|
||||
name string
|
||||
configPath string
|
||||
yes bool
|
||||
}
|
||||
|
||||
// defaultInstanceType returns the default instance type for the given provider.
|
||||
func defaultInstanceType(provider cloudprovider.Provider) string {
|
||||
switch provider {
|
||||
case cloudprovider.GCP:
|
||||
return gcp.InstanceTypes[0]
|
||||
case cloudprovider.Azure:
|
||||
return azure.InstanceTypes[0]
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// checkDirClean checks if files of a previous Constellation are left in the current working dir.
|
||||
func checkDirClean(fileHandler file.Handler) error {
|
||||
if _, err := fileHandler.Stat(constants.StateFilename); !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("file '%s' already exists in working directory, run 'constellation terminate' before creating a new one", constants.StateFilename)
|
||||
}
|
||||
if _, err := fileHandler.Stat(constants.AdminConfFilename); !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("file '%s' already exists in working directory, run 'constellation terminate' before creating a new one", constants.AdminConfFilename)
|
||||
}
|
||||
if _, err := fileHandler.Stat(constants.MasterSecretFilename); !errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("file '%s' already exists in working directory, clean it up first", constants.MasterSecretFilename)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createCompletion handles the completion of the create command. It is frequently called
|
||||
// while the user types arguments of the command to suggest completion.
|
||||
func createCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
switch len(args) {
|
||||
case 0:
|
||||
return []string{"aws", "gcp", "azure"}, cobra.ShellCompDirectiveNoFileComp
|
||||
default:
|
||||
return []string{}, cobra.ShellCompDirectiveError
|
||||
}
|
||||
}
|
||||
|
||||
func instanceTypeCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
if len(args) != 1 {
|
||||
return []string{}, cobra.ShellCompDirectiveError
|
||||
}
|
||||
switch args[0] {
|
||||
case "gcp":
|
||||
return gcp.InstanceTypes, cobra.ShellCompDirectiveNoFileComp
|
||||
case "azure":
|
||||
return azure.InstanceTypes, cobra.ShellCompDirectiveNoFileComp
|
||||
default:
|
||||
return []string{}, cobra.ShellCompDirectiveError
|
||||
}
|
||||
}
|
|
@ -1,397 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/internal/azure"
|
||||
"github.com/edgelesssys/constellation/cli/internal/gcp"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateArgumentValidation(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
wantErr bool
|
||||
}{
|
||||
"gcp": {[]string{"gcp"}, false},
|
||||
"azure": {[]string{"azure"}, false},
|
||||
"aws waring": {[]string{"aws"}, true},
|
||||
"too many args": {[]string{"gcp", "1", "2"}, true},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
err := newCreateCmd().ValidateArgs(tc.args)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreate(t *testing.T) {
|
||||
testState := state.ConstellationState{Name: "test"}
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
setupFs func(*require.Assertions) afero.Fs
|
||||
creator *stubCloudCreator
|
||||
provider cloudprovider.Provider
|
||||
yesFlag bool
|
||||
controllerCountFlag *int
|
||||
workerCountFlag *int
|
||||
insTypeFlag string
|
||||
configFlag string
|
||||
nameFlag string
|
||||
stdin string
|
||||
wantErr bool
|
||||
wantAbbort bool
|
||||
}{
|
||||
"create": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
creator: &stubCloudCreator{state: testState},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(1),
|
||||
workerCountFlag: intPtr(2),
|
||||
yesFlag: true,
|
||||
},
|
||||
"interactive": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
creator: &stubCloudCreator{state: testState},
|
||||
provider: cloudprovider.Azure,
|
||||
controllerCountFlag: intPtr(2),
|
||||
workerCountFlag: intPtr(1),
|
||||
stdin: "yes\n",
|
||||
},
|
||||
"interactive abort": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(1),
|
||||
workerCountFlag: intPtr(1),
|
||||
stdin: "no\n",
|
||||
wantAbbort: true,
|
||||
},
|
||||
"interactive error": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(1),
|
||||
workerCountFlag: intPtr(1),
|
||||
stdin: "foo\nfoo\nfoo\n",
|
||||
wantErr: true,
|
||||
},
|
||||
"flag name to long": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(1),
|
||||
workerCountFlag: intPtr(1),
|
||||
nameFlag: strings.Repeat("a", constants.ConstellationNameLength+1),
|
||||
wantErr: true,
|
||||
},
|
||||
"flag control-plane-count invalid": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(0),
|
||||
workerCountFlag: intPtr(3),
|
||||
wantErr: true,
|
||||
},
|
||||
"flag worker-count invalid": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(3),
|
||||
workerCountFlag: intPtr(0),
|
||||
wantErr: true,
|
||||
},
|
||||
"flag control-plane-count missing": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
workerCountFlag: intPtr(3),
|
||||
wantErr: true,
|
||||
},
|
||||
"flag worker-count missing": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(3),
|
||||
wantErr: true,
|
||||
},
|
||||
"flag invalid instance-type": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(1),
|
||||
workerCountFlag: intPtr(1),
|
||||
insTypeFlag: "invalid",
|
||||
wantErr: true,
|
||||
},
|
||||
"old state in directory": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
require.NoError(fileHandler.Write(constants.StateFilename, []byte{1}, file.OptNone))
|
||||
return fs
|
||||
},
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(1),
|
||||
workerCountFlag: intPtr(1),
|
||||
yesFlag: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"old adminConf in directory": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
require.NoError(fileHandler.Write(constants.AdminConfFilename, []byte{1}, file.OptNone))
|
||||
return fs
|
||||
},
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(1),
|
||||
workerCountFlag: intPtr(1),
|
||||
yesFlag: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"old masterSecret in directory": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
require.NoError(fileHandler.Write(constants.MasterSecretFilename, []byte{1}, file.OptNone))
|
||||
return fs
|
||||
},
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(1),
|
||||
workerCountFlag: intPtr(1),
|
||||
yesFlag: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"config does not exist": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(1),
|
||||
workerCountFlag: intPtr(1),
|
||||
yesFlag: true,
|
||||
configFlag: constants.ConfigFilename,
|
||||
wantErr: true,
|
||||
},
|
||||
"create error": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
creator: &stubCloudCreator{createErr: someErr},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(1),
|
||||
workerCountFlag: intPtr(1),
|
||||
yesFlag: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"write state error": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
return afero.NewReadOnlyFs(fs)
|
||||
},
|
||||
creator: &stubCloudCreator{},
|
||||
provider: cloudprovider.GCP,
|
||||
controllerCountFlag: intPtr(1),
|
||||
workerCountFlag: intPtr(1),
|
||||
yesFlag: true,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newCreateCmd()
|
||||
cmd.SetOut(&bytes.Buffer{})
|
||||
cmd.SetErr(&bytes.Buffer{})
|
||||
cmd.SetIn(bytes.NewBufferString(tc.stdin))
|
||||
cmd.Flags().String("config", "", "") // register persisten flag manually
|
||||
if tc.yesFlag {
|
||||
require.NoError(cmd.Flags().Set("yes", "true"))
|
||||
}
|
||||
if tc.nameFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("name", tc.nameFlag))
|
||||
}
|
||||
if tc.configFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("config", tc.configFlag))
|
||||
}
|
||||
if tc.controllerCountFlag != nil {
|
||||
require.NoError(cmd.Flags().Set("control-plane-nodes", strconv.Itoa(*tc.controllerCountFlag)))
|
||||
}
|
||||
if tc.workerCountFlag != nil {
|
||||
require.NoError(cmd.Flags().Set("worker-nodes", strconv.Itoa(*tc.workerCountFlag)))
|
||||
}
|
||||
if tc.insTypeFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("instance-type", tc.insTypeFlag))
|
||||
}
|
||||
|
||||
fileHandler := file.NewHandler(tc.setupFs(require))
|
||||
|
||||
err := create(cmd, tc.creator, fileHandler, tc.provider)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
if tc.wantAbbort {
|
||||
assert.False(tc.creator.createCalled)
|
||||
} else {
|
||||
assert.True(tc.creator.createCalled)
|
||||
var state state.ConstellationState
|
||||
require.NoError(fileHandler.ReadJSON(constants.StateFilename, &state))
|
||||
assert.Equal(state, testState)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckDirClean(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
fileHandler file.Handler
|
||||
existingFiles []string
|
||||
wantErr bool
|
||||
}{
|
||||
"no file exists": {
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
},
|
||||
"adminconf exists": {
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
existingFiles: []string{constants.AdminConfFilename},
|
||||
wantErr: true,
|
||||
},
|
||||
"master secret exists": {
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
existingFiles: []string{constants.MasterSecretFilename},
|
||||
wantErr: true,
|
||||
},
|
||||
"state file exists": {
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
existingFiles: []string{constants.StateFilename},
|
||||
wantErr: true,
|
||||
},
|
||||
"multiple exist": {
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
existingFiles: []string{constants.AdminConfFilename, constants.MasterSecretFilename, constants.StateFilename},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
for _, f := range tc.existingFiles {
|
||||
require.NoError(tc.fileHandler.Write(f, []byte{1, 2, 3}, file.OptNone))
|
||||
}
|
||||
|
||||
err := checkDirClean(tc.fileHandler)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateCompletion(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
wantResult []string
|
||||
wantShellCD cobra.ShellCompDirective
|
||||
}{
|
||||
"first arg": {
|
||||
args: []string{},
|
||||
wantResult: []string{"aws", "gcp", "azure"},
|
||||
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
|
||||
},
|
||||
"second arg": {
|
||||
args: []string{"gcp", "foo"},
|
||||
wantResult: []string{},
|
||||
wantShellCD: cobra.ShellCompDirectiveError,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
result, shellCD := createCompletion(cmd, tc.args, "")
|
||||
assert.Equal(tc.wantResult, result)
|
||||
assert.Equal(tc.wantShellCD, shellCD)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstanceTypeCompletion(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
wantResult []string
|
||||
wantShellCD cobra.ShellCompDirective
|
||||
}{
|
||||
"azure": {
|
||||
args: []string{"azure"},
|
||||
wantResult: azure.InstanceTypes,
|
||||
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
|
||||
},
|
||||
"gcp": {
|
||||
args: []string{"gcp"},
|
||||
wantResult: gcp.InstanceTypes,
|
||||
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
|
||||
},
|
||||
"empty args": {
|
||||
args: []string{},
|
||||
wantResult: []string{},
|
||||
wantShellCD: cobra.ShellCompDirectiveError,
|
||||
},
|
||||
"unknown provider": {
|
||||
args: []string{"foo"},
|
||||
wantResult: []string{},
|
||||
wantShellCD: cobra.ShellCompDirectiveError,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
result, shellCD := instanceTypeCompletion(cmd, tc.args, "")
|
||||
|
||||
assert.Equal(tc.wantResult, result)
|
||||
assert.Equal(tc.wantShellCD, shellCD)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func intPtr(i int) *int {
|
||||
return &i
|
||||
}
|
490
cli/cmd/init.go
490
cli/cmd/init.go
|
@ -1,490 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"strconv"
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
|
||||
"github.com/edgelesssys/constellation/cli/internal/azure"
|
||||
"github.com/edgelesssys/constellation/cli/internal/gcp"
|
||||
"github.com/edgelesssys/constellation/cli/internal/proto"
|
||||
"github.com/edgelesssys/constellation/cli/internal/status"
|
||||
"github.com/edgelesssys/constellation/cli/internal/vpn"
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
coordinatorstate "github.com/edgelesssys/constellation/coordinator/state"
|
||||
"github.com/edgelesssys/constellation/coordinator/util"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/deploy/ssh"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/kr/text"
|
||||
wgquick "github.com/nmiculinic/wg-quick-go"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
|
||||
func newInitCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "init",
|
||||
Short: "Initialize the Constellation cluster",
|
||||
Long: "Initialize the Constellation cluster. Start your confidential Kubernetes.",
|
||||
ValidArgsFunction: initCompletion,
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: runInitialize,
|
||||
}
|
||||
cmd.Flags().String("privatekey", "", "path to your private key")
|
||||
cmd.Flags().String("master-secret", "", "path to base64-encoded master secret")
|
||||
cmd.Flags().Bool("wg-autoconfig", false, "enable automatic configuration of WireGuard interface")
|
||||
must(cmd.Flags().MarkHidden("wg-autoconfig"))
|
||||
cmd.Flags().Bool("autoscale", false, "enable Kubernetes cluster-autoscaler")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// runInitialize runs the initialize command.
|
||||
func runInitialize(cmd *cobra.Command, args []string) error {
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
vpnHandler := vpn.NewConfigHandler()
|
||||
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
|
||||
waiter := status.NewWaiter()
|
||||
protoClient := &proto.Client{}
|
||||
defer protoClient.Close()
|
||||
|
||||
// We have to parse the context separately, since cmd.Context()
|
||||
// returns nil during the tests otherwise.
|
||||
return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler)
|
||||
}
|
||||
|
||||
// initialize initializes a Constellation. Coordinator instances are activated as contole-plane nodes and will
|
||||
// themself activate the other peers as workers.
|
||||
func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator,
|
||||
fileHandler file.Handler, waiter statusWaiter, vpnHandler vpnHandler,
|
||||
) error {
|
||||
flags, err := evalFlagArgs(cmd, fileHandler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var stat state.ConstellationState
|
||||
err = fileHandler.ReadJSON(constants.StateFilename, &stat)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return fmt.Errorf("nothing to initialize: %w", err)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
provider := cloudprovider.FromString(stat.CloudProvider)
|
||||
|
||||
config, err := readConfig(cmd.OutOrStdout(), fileHandler, flags.configPath, provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var sshUsers []*ssh.UserKey
|
||||
for _, user := range config.SSHUsers {
|
||||
sshUsers = append(sshUsers, &ssh.UserKey{
|
||||
Username: user.Username,
|
||||
PublicKey: user.PublicKey,
|
||||
})
|
||||
}
|
||||
|
||||
validators, err := cloudcmd.NewValidators(provider, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(validators.WarningsIncludeInit())
|
||||
|
||||
cmd.Println("Creating service account ...")
|
||||
serviceAccount, stat, err := serviceAccCreator.Create(ctx, stat, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := fileHandler.WriteJSON(constants.StateFilename, stat, file.OptOverwrite); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
coordinators, nodes, err := getScalingGroupsFromConfig(stat, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), strconv.Itoa(constants.CoordinatorPort))
|
||||
|
||||
cmd.Println("Waiting for cloud provider resource creation and boot ...")
|
||||
if err := waiter.InitializeValidators(validators.V()); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil {
|
||||
return fmt.Errorf("failed to wait for peer status: %w", err)
|
||||
}
|
||||
|
||||
var autoscalingNodeGroups []string
|
||||
if flags.autoscale {
|
||||
autoscalingNodeGroups = append(autoscalingNodeGroups, nodes.GroupID)
|
||||
}
|
||||
|
||||
input := activationInput{
|
||||
coordinatorPubIP: coordinators.PublicIPs()[0],
|
||||
pubKey: flags.userPubKey,
|
||||
masterSecret: flags.masterSecret,
|
||||
nodePrivIPs: nodes.PrivateIPs(),
|
||||
coordinatorPrivIPs: coordinators.PrivateIPs()[1:],
|
||||
autoscalingNodeGroups: autoscalingNodeGroups,
|
||||
cloudServiceAccountURI: serviceAccount,
|
||||
sshUserKeys: ssh.ToProtoSlice(sshUsers),
|
||||
}
|
||||
result, err := activate(ctx, cmd, protCl, input, validators.V())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = result.writeOutput(cmd.OutOrStdout(), fileHandler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flags.userPrivKey), result.clientVpnIP, constants.WireguardAdminMTU)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := writeWGQuickFile(fileHandler, vpnHandler, vpnConfig); err != nil {
|
||||
return fmt.Errorf("write wg-quick file: %w", err)
|
||||
}
|
||||
|
||||
if flags.autoconfigureWG {
|
||||
if err := vpnHandler.Apply(vpnConfig); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput,
|
||||
validators []atls.Validator,
|
||||
) (activationResult, error) {
|
||||
err := client.Connect(net.JoinHostPort(input.coordinatorPubIP, strconv.Itoa(constants.CoordinatorPort)), validators)
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
|
||||
respCl, err := client.Activate(ctx, input.pubKey, input.masterSecret, input.nodePrivIPs, input.coordinatorPrivIPs, input.autoscalingNodeGroups, input.cloudServiceAccountURI, input.sshUserKeys)
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
|
||||
indentOut := text.NewIndentWriter(cmd.OutOrStdout(), []byte{'\t'})
|
||||
cmd.Println("Activating the cluster ...")
|
||||
if err := respCl.WriteLogStream(indentOut); err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
|
||||
clientVpnIp, err := respCl.GetClientVpnIp()
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
coordinatorPubKey, err := respCl.GetCoordinatorVpnKey()
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
kubeconfig, err := respCl.GetKubeconfig()
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
ownerID, err := respCl.GetOwnerID()
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
clusterID, err := respCl.GetClusterID()
|
||||
if err != nil {
|
||||
return activationResult{}, err
|
||||
}
|
||||
|
||||
return activationResult{
|
||||
clientVpnIP: clientVpnIp,
|
||||
coordinatorPubKey: coordinatorPubKey,
|
||||
coordinatorPubIP: input.coordinatorPubIP,
|
||||
kubeconfig: kubeconfig,
|
||||
ownerID: ownerID,
|
||||
clusterID: clusterID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type activationInput struct {
|
||||
coordinatorPubIP string
|
||||
pubKey []byte
|
||||
masterSecret []byte
|
||||
nodePrivIPs []string
|
||||
coordinatorPrivIPs []string
|
||||
autoscalingNodeGroups []string
|
||||
cloudServiceAccountURI string
|
||||
sshUserKeys []*pubproto.SSHUserKey
|
||||
}
|
||||
|
||||
type activationResult struct {
|
||||
clientVpnIP string
|
||||
coordinatorPubKey string
|
||||
coordinatorPubIP string
|
||||
kubeconfig string
|
||||
ownerID string
|
||||
clusterID string
|
||||
}
|
||||
|
||||
// writeWGQuickFile writes the wg-quick file to the default path.
|
||||
func writeWGQuickFile(fileHandler file.Handler, vpnHandler vpnHandler, vpnConfig *wgquick.Config) error {
|
||||
data, err := vpnHandler.Marshal(vpnConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return fileHandler.Write(constants.WGQuickConfigFilename, data, file.OptNone)
|
||||
}
|
||||
|
||||
func (r activationResult) writeOutput(wr io.Writer, fileHandler file.Handler) error {
|
||||
fmt.Fprint(wr, "Your Constellation cluster was successfully initialized.\n\n")
|
||||
|
||||
tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0)
|
||||
writeRow(tw, "Your WireGuard IP", r.clientVpnIP)
|
||||
writeRow(tw, "Control plane's public IP", r.coordinatorPubIP)
|
||||
writeRow(tw, "Control plane's public key", r.coordinatorPubKey)
|
||||
writeRow(tw, "Constellation cluster's owner identifier", r.ownerID)
|
||||
writeRow(tw, "Constellation cluster's unique identifier", r.clusterID)
|
||||
writeRow(tw, "WireGuard configuration file", constants.WGQuickConfigFilename)
|
||||
writeRow(tw, "Kubernetes configuration", constants.AdminConfFilename)
|
||||
tw.Flush()
|
||||
fmt.Fprintln(wr)
|
||||
|
||||
if err := fileHandler.Write(constants.AdminConfFilename, []byte(r.kubeconfig), file.OptNone); err != nil {
|
||||
return fmt.Errorf("write kubeconfig: %w", err)
|
||||
}
|
||||
|
||||
fmt.Fprintln(wr, "You can now connect to your cluster by executing:")
|
||||
fmt.Fprintf(wr, "\twg-quick up ./%s\n", constants.WGQuickConfigFilename)
|
||||
fmt.Fprintf(wr, "\texport KUBECONFIG=\"$PWD/%s\"\n", constants.AdminConfFilename)
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeRow(wr io.Writer, col1 string, col2 string) {
|
||||
fmt.Fprint(wr, col1, "\t", col2, "\n")
|
||||
}
|
||||
|
||||
// evalFlagArgs gets the flag values and does preprocessing of these values like
|
||||
// reading the content from file path flags and deriving other values from flag combinations.
|
||||
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (initFlags, error) {
|
||||
userPrivKeyPath, err := cmd.Flags().GetString("privatekey")
|
||||
if err != nil {
|
||||
return initFlags{}, err
|
||||
}
|
||||
userPrivKey, userPubKey, err := readOrGenerateVPNKey(fileHandler, userPrivKeyPath)
|
||||
if err != nil {
|
||||
return initFlags{}, err
|
||||
}
|
||||
autoconfigureWG, err := cmd.Flags().GetBool("wg-autoconfig")
|
||||
if err != nil {
|
||||
return initFlags{}, err
|
||||
}
|
||||
masterSecretPath, err := cmd.Flags().GetString("master-secret")
|
||||
if err != nil {
|
||||
return initFlags{}, err
|
||||
}
|
||||
masterSecret, err := readOrGeneratedMasterSecret(cmd.OutOrStdout(), fileHandler, masterSecretPath)
|
||||
if err != nil {
|
||||
return initFlags{}, err
|
||||
}
|
||||
autoscale, err := cmd.Flags().GetBool("autoscale")
|
||||
if err != nil {
|
||||
return initFlags{}, err
|
||||
}
|
||||
configPath, err := cmd.Flags().GetString("config")
|
||||
if err != nil {
|
||||
return initFlags{}, err
|
||||
}
|
||||
|
||||
return initFlags{
|
||||
configPath: configPath,
|
||||
userPrivKey: userPrivKey,
|
||||
userPubKey: userPubKey,
|
||||
autoconfigureWG: autoconfigureWG,
|
||||
autoscale: autoscale,
|
||||
masterSecret: masterSecret,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// initFlags are the resulting values of flag preprocessing.
|
||||
type initFlags struct {
|
||||
configPath string
|
||||
userPrivKey []byte
|
||||
userPubKey []byte
|
||||
masterSecret []byte
|
||||
autoconfigureWG bool
|
||||
autoscale bool
|
||||
}
|
||||
|
||||
func readOrGenerateVPNKey(fileHandler file.Handler, privKeyPath string) (privKey, pubKey []byte, err error) {
|
||||
var privKeyParsed wgtypes.Key
|
||||
if privKeyPath == "" {
|
||||
privKeyParsed, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
privKey = []byte(privKeyParsed.String())
|
||||
} else {
|
||||
privKey, err = fileHandler.Read(privKeyPath)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
privKeyParsed, err = wgtypes.ParseKey(string(privKey))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
pubKey = []byte(privKeyParsed.PublicKey().String())
|
||||
|
||||
return privKey, pubKey, nil
|
||||
}
|
||||
|
||||
func ipsToEndpoints(ips []string, port string) []string {
|
||||
var endpoints []string
|
||||
for _, ip := range ips {
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
endpoints = append(endpoints, net.JoinHostPort(ip, port))
|
||||
}
|
||||
return endpoints
|
||||
}
|
||||
|
||||
// readOrGeneratedMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
|
||||
func readOrGeneratedMasterSecret(w io.Writer, fileHandler file.Handler, filename string) ([]byte, error) {
|
||||
if filename != "" {
|
||||
// Try to read the base64 secret from file
|
||||
encodedSecret, err := fileHandler.Read(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(string(encodedSecret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(decoded) < constants.MasterSecretLengthMin {
|
||||
return nil, errors.New("provided master secret is smaller than the required minimum of 16 Bytes")
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// No file given, generate a new secret, and save it to disk
|
||||
masterSecret, err := util.GenerateRandomBytes(constants.MasterSecretLengthDefault)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := fileHandler.Write(constants.MasterSecretFilename, []byte(base64.StdEncoding.EncodeToString(masterSecret)), file.OptNone); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Fprintf(w, "Your Constellation master secret was successfully written to ./%s\n", constants.MasterSecretFilename)
|
||||
return masterSecret, nil
|
||||
}
|
||||
|
||||
func getScalingGroupsFromConfig(stat state.ConstellationState, config *config.Config) (coordinators, nodes cloudtypes.ScalingGroup, err error) {
|
||||
switch {
|
||||
case len(stat.GCPCoordinators) != 0:
|
||||
return getGCPInstances(stat, config)
|
||||
case len(stat.AzureCoordinators) != 0:
|
||||
return getAzureInstances(stat, config)
|
||||
case len(stat.QEMUCoordinators) != 0:
|
||||
return getQEMUInstances(stat, config)
|
||||
default:
|
||||
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no instances to initialize")
|
||||
}
|
||||
}
|
||||
|
||||
func getGCPInstances(stat state.ConstellationState, config *config.Config) (coordinators, nodes cloudtypes.ScalingGroup, err error) {
|
||||
if len(stat.GCPCoordinators) == 0 {
|
||||
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no control-plane nodes available, can't create Constellation without any instance")
|
||||
}
|
||||
|
||||
// GroupID of coordinators is empty, since they currently do not scale.
|
||||
coordinators = cloudtypes.ScalingGroup{
|
||||
Instances: stat.GCPCoordinators,
|
||||
GroupID: "",
|
||||
}
|
||||
|
||||
if len(stat.GCPNodes) == 0 {
|
||||
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no worker nodes available, can't create Constellation with one instance")
|
||||
}
|
||||
|
||||
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
|
||||
nodes = cloudtypes.ScalingGroup{
|
||||
Instances: stat.GCPNodes,
|
||||
GroupID: gcp.AutoscalingNodeGroup(stat.GCPProject, stat.GCPZone, stat.GCPNodeInstanceGroup, config.AutoscalingNodeGroupMin, config.AutoscalingNodeGroupMax),
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getAzureInstances(stat state.ConstellationState, config *config.Config) (coordinators, nodes cloudtypes.ScalingGroup, err error) {
|
||||
if len(stat.AzureCoordinators) == 0 {
|
||||
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no control-plane nodes available, can't create Constellation cluster without any instance")
|
||||
}
|
||||
|
||||
// GroupID of coordinators is empty, since they currently do not scale.
|
||||
coordinators = cloudtypes.ScalingGroup{
|
||||
Instances: stat.AzureCoordinators,
|
||||
GroupID: "",
|
||||
}
|
||||
|
||||
if len(stat.AzureNodes) == 0 {
|
||||
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no worker nodes available, can't create Constellation cluster with one instance")
|
||||
}
|
||||
|
||||
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
|
||||
nodes = cloudtypes.ScalingGroup{
|
||||
Instances: stat.AzureNodes,
|
||||
GroupID: azure.AutoscalingNodeGroup(stat.AzureNodesScaleSet, config.AutoscalingNodeGroupMin, config.AutoscalingNodeGroupMax),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getQEMUInstances(stat state.ConstellationState, config *config.Config) (coordinators, nodes cloudtypes.ScalingGroup, err error) {
|
||||
coordinatorMap := stat.QEMUCoordinators
|
||||
if len(coordinatorMap) == 0 {
|
||||
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no coordinators available, can't create Constellation without any instance")
|
||||
}
|
||||
|
||||
// QEMU does not support autoscaling
|
||||
coordinators = cloudtypes.ScalingGroup{
|
||||
Instances: stat.QEMUCoordinators,
|
||||
GroupID: "",
|
||||
}
|
||||
|
||||
if len(stat.QEMUNodes) == 0 {
|
||||
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no nodes available, can't create Constellation with one instance")
|
||||
}
|
||||
|
||||
// QEMU does not support autoscaling
|
||||
nodes = cloudtypes.ScalingGroup{
|
||||
Instances: stat.QEMUNodes,
|
||||
GroupID: "",
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// initCompletion handels the completion of CLI arguments. It is frequently called
|
||||
// while the user types arguments of the command to suggest completion.
|
||||
func initCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
if len(args) != 0 {
|
||||
return []string{}, cobra.ShellCompDirectiveError
|
||||
}
|
||||
return []string{}, cobra.ShellCompDirectiveDefault
|
||||
}
|
|
@ -1,662 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
wgquick "github.com/nmiculinic/wg-quick-go"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestInitArgumentValidation(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := newInitCmd()
|
||||
assert.NoError(cmd.ValidateArgs(nil))
|
||||
assert.Error(cmd.ValidateArgs([]string{"something"}))
|
||||
assert.Error(cmd.ValidateArgs([]string{"sth", "sth"}))
|
||||
}
|
||||
|
||||
func TestInitialize(t *testing.T) {
|
||||
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
|
||||
testGcpState := state.ConstellationState{
|
||||
CloudProvider: "GCP",
|
||||
GCPNodes: cloudtypes.Instances{
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
GCPCoordinators: cloudtypes.Instances{
|
||||
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
}
|
||||
testAzureState := state.ConstellationState{
|
||||
CloudProvider: "Azure",
|
||||
AzureNodes: cloudtypes.Instances{
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
AzureCoordinators: cloudtypes.Instances{
|
||||
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
AzureResourceGroup: "test",
|
||||
}
|
||||
testQemuState := state.ConstellationState{
|
||||
CloudProvider: "QEMU",
|
||||
QEMUNodes: cloudtypes.Instances{
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
QEMUCoordinators: cloudtypes.Instances{
|
||||
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
}
|
||||
testActivationResps := []fakeActivationRespMessage{
|
||||
{log: "testlog1"},
|
||||
{log: "testlog2"},
|
||||
{
|
||||
kubeconfig: "kubeconfig",
|
||||
clientVpnIp: "192.0.2.2",
|
||||
coordinatorVpnKey: testKey,
|
||||
ownerID: "ownerID",
|
||||
clusterID: "clusterID",
|
||||
},
|
||||
{log: "testlog3"},
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
existingState state.ConstellationState
|
||||
client protoClient
|
||||
serviceAccountCreator stubServiceAccountCreator
|
||||
waiter statusWaiter
|
||||
privKey string
|
||||
vpnHandler vpnHandler
|
||||
initVPN bool
|
||||
wantErr bool
|
||||
}{
|
||||
"initialize some gcp instances": {
|
||||
existingState: testGcpState,
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some azure instances": {
|
||||
existingState: testAzureState,
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some qemu instances": {
|
||||
existingState: testQemuState,
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize vpn": {
|
||||
existingState: testAzureState,
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
initVPN: true,
|
||||
privKey: testKey,
|
||||
},
|
||||
"invalid initialize vpn": {
|
||||
existingState: testAzureState,
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{applyErr: someErr},
|
||||
initVPN: true,
|
||||
privKey: testKey,
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid create vpn config": {
|
||||
existingState: testAzureState,
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{createErr: someErr},
|
||||
initVPN: true,
|
||||
privKey: testKey,
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid write vpn config": {
|
||||
existingState: testAzureState,
|
||||
client: &fakeProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
vpnHandler: &stubVPNHandler{marshalErr: someErr},
|
||||
initVPN: true,
|
||||
privKey: testKey,
|
||||
wantErr: true,
|
||||
},
|
||||
"no state exists": {
|
||||
existingState: state.ConstellationState{},
|
||||
client: &stubProtoClient{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"no instances to pick one": {
|
||||
existingState: state.ConstellationState{GCPNodes: cloudtypes.Instances{}},
|
||||
client: &stubProtoClient{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"public key to short": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"public key to long": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"public key not base64": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: "this is not base64 encoded",
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail Connect": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{connectErr: someErr},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail Activate": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{activateErr: someErr},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail respClient WriteLogStream": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail respClient getKubeconfig": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail respClient getCoordinatorVpnKey": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail respClient getClientVpnIp": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail respClient getOwnerID": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail respClient getClusterID": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail to wait for required status": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
waiter: &stubStatusWaiter{waitForAllErr: someErr},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
"fail to create service account": {
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{},
|
||||
serviceAccountCreator: stubServiceAccountCreator{createErr: someErr},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
vpnHandler: &stubVPNHandler{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newInitCmd()
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
cmd.Flags().String("config", "", "") // register persisten flag manually
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
|
||||
|
||||
// Write key file to filesystem and set path in flag.
|
||||
require.NoError(afero.Afero{Fs: fs}.WriteFile("privK", []byte(tc.privKey), 0o600))
|
||||
require.NoError(cmd.Flags().Set("privatekey", "privK"))
|
||||
if tc.initVPN {
|
||||
require.NoError(cmd.Flags().Set("wg-autoconfig", "true"))
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.initVPN, tc.vpnHandler.(*stubVPNHandler).configured)
|
||||
assert.Contains(out.String(), "192.0.2.2")
|
||||
assert.Contains(out.String(), "ownerID")
|
||||
assert.Contains(out.String(), "clusterID")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteOutput(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
result := activationResult{
|
||||
clientVpnIP: "foo-qq",
|
||||
coordinatorPubKey: "bar-qq",
|
||||
coordinatorPubIP: "baz-qq",
|
||||
kubeconfig: "foo-bar-baz-qq",
|
||||
}
|
||||
var out bytes.Buffer
|
||||
testFs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(testFs)
|
||||
|
||||
err := result.writeOutput(&out, fileHandler)
|
||||
assert.NoError(err)
|
||||
assert.Contains(out.String(), result.clientVpnIP)
|
||||
assert.Contains(out.String(), result.coordinatorPubIP)
|
||||
assert.Contains(out.String(), result.coordinatorPubKey)
|
||||
|
||||
afs := afero.Afero{Fs: testFs}
|
||||
adminConf, err := afs.ReadFile(constants.AdminConfFilename)
|
||||
assert.NoError(err)
|
||||
assert.Equal(result.kubeconfig, string(adminConf))
|
||||
}
|
||||
|
||||
func TestIpsToEndpoints(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ips := []string{"192.0.2.1", "192.0.2.2", "", "192.0.2.3"}
|
||||
port := "8080"
|
||||
endpoints := ipsToEndpoints(ips, port)
|
||||
assert.Equal([]string{"192.0.2.1:8080", "192.0.2.2:8080", "192.0.2.3:8080"}, endpoints)
|
||||
}
|
||||
|
||||
func TestInitCompletion(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
toComplete string
|
||||
wantResult []string
|
||||
wantShellCD cobra.ShellCompDirective
|
||||
}{
|
||||
"first arg": {
|
||||
args: []string{},
|
||||
toComplete: "hello",
|
||||
wantResult: []string{},
|
||||
wantShellCD: cobra.ShellCompDirectiveDefault,
|
||||
},
|
||||
"secnod arg": {
|
||||
args: []string{"23"},
|
||||
toComplete: "/test/h",
|
||||
wantResult: []string{},
|
||||
wantShellCD: cobra.ShellCompDirectiveError,
|
||||
},
|
||||
"third arg": {
|
||||
args: []string{"./file", "sth"},
|
||||
toComplete: "./file",
|
||||
wantResult: []string{},
|
||||
wantShellCD: cobra.ShellCompDirectiveError,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
result, shellCD := initCompletion(cmd, tc.args, tc.toComplete)
|
||||
assert.Equal(tc.wantResult, result)
|
||||
assert.Equal(tc.wantShellCD, shellCD)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadOrGenerateVPNKey(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
testKey := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
|
||||
fileHandler := file.NewHandler(afero.NewMemMapFs())
|
||||
require.NoError(fileHandler.Write("testKey", testKey, file.OptNone))
|
||||
|
||||
privK, pubK, err := readOrGenerateVPNKey(fileHandler, "testKey")
|
||||
assert.NoError(err)
|
||||
assert.Equal(testKey, privK)
|
||||
assert.NotEmpty(pubK)
|
||||
|
||||
// no path provided
|
||||
privK, pubK, err = readOrGenerateVPNKey(fileHandler, "")
|
||||
assert.NoError(err)
|
||||
assert.NotEmpty(privK)
|
||||
assert.NotEmpty(pubK)
|
||||
}
|
||||
|
||||
func TestReadOrGeneratedMasterSecret(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
filename string
|
||||
filecontent string
|
||||
createFile bool
|
||||
fs func() afero.Fs
|
||||
wantErr bool
|
||||
}{
|
||||
"file with secret exists": {
|
||||
filename: "someSecret",
|
||||
filecontent: base64.StdEncoding.EncodeToString([]byte("ConstellationSecret")),
|
||||
createFile: true,
|
||||
fs: afero.NewMemMapFs,
|
||||
wantErr: false,
|
||||
},
|
||||
"no file given": {
|
||||
filename: "",
|
||||
filecontent: "",
|
||||
fs: afero.NewMemMapFs,
|
||||
wantErr: false,
|
||||
},
|
||||
"file does not exist": {
|
||||
filename: "nonExistingSecret",
|
||||
filecontent: "",
|
||||
createFile: false,
|
||||
fs: afero.NewMemMapFs,
|
||||
wantErr: true,
|
||||
},
|
||||
"file is empty": {
|
||||
filename: "emptySecret",
|
||||
filecontent: "",
|
||||
createFile: true,
|
||||
fs: afero.NewMemMapFs,
|
||||
wantErr: true,
|
||||
},
|
||||
"secret too short": {
|
||||
filename: "shortSecret",
|
||||
filecontent: base64.StdEncoding.EncodeToString([]byte("short")),
|
||||
createFile: true,
|
||||
fs: afero.NewMemMapFs,
|
||||
wantErr: true,
|
||||
},
|
||||
"secret not encoded": {
|
||||
filename: "unencodedSecret",
|
||||
filecontent: "Constellation",
|
||||
createFile: true,
|
||||
fs: afero.NewMemMapFs,
|
||||
wantErr: true,
|
||||
},
|
||||
"file not writeable": {
|
||||
filename: "",
|
||||
filecontent: "",
|
||||
createFile: false,
|
||||
fs: func() afero.Fs { return afero.NewReadOnlyFs(afero.NewMemMapFs()) },
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fileHandler := file.NewHandler(tc.fs())
|
||||
|
||||
if tc.createFile {
|
||||
require.NoError(fileHandler.Write(tc.filename, []byte(tc.filecontent), file.OptNone))
|
||||
}
|
||||
|
||||
var out bytes.Buffer
|
||||
secret, err := readOrGeneratedMasterSecret(&out, fileHandler, tc.filename)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
|
||||
if tc.filename == "" {
|
||||
require.Contains(out.String(), constants.MasterSecretFilename)
|
||||
filename := strings.Split(out.String(), "./")
|
||||
tc.filename = strings.Trim(filename[1], "\n")
|
||||
}
|
||||
|
||||
content, err := fileHandler.Read(tc.filename)
|
||||
require.NoError(err)
|
||||
assert.Equal(content, []byte(base64.StdEncoding.EncodeToString(secret)))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAutoscaleFlag(t *testing.T) {
|
||||
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
|
||||
testGcpState := state.ConstellationState{
|
||||
CloudProvider: "gcp",
|
||||
GCPNodes: cloudtypes.Instances{
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
GCPCoordinators: cloudtypes.Instances{
|
||||
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
}
|
||||
testAzureState := state.ConstellationState{
|
||||
CloudProvider: "azure",
|
||||
AzureNodes: cloudtypes.Instances{
|
||||
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
AzureCoordinators: cloudtypes.Instances{
|
||||
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
|
||||
},
|
||||
AzureResourceGroup: "test",
|
||||
}
|
||||
testActivationResps := []fakeActivationRespMessage{
|
||||
{log: "testlog1"},
|
||||
{log: "testlog2"},
|
||||
{
|
||||
kubeconfig: "kubeconfig",
|
||||
clientVpnIp: "192.0.2.2",
|
||||
coordinatorVpnKey: testKey,
|
||||
ownerID: "ownerID",
|
||||
clusterID: "clusterID",
|
||||
},
|
||||
{log: "testlog3"},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
autoscaleFlag bool
|
||||
existingState state.ConstellationState
|
||||
client *stubProtoClient
|
||||
serviceAccountCreator stubServiceAccountCreator
|
||||
waiter statusWaiter
|
||||
privKey string
|
||||
}{
|
||||
"initialize some gcp instances without autoscale flag": {
|
||||
autoscaleFlag: false,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some azure instances without autoscale flag": {
|
||||
autoscaleFlag: false,
|
||||
existingState: testAzureState,
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some gcp instances with autoscale flag": {
|
||||
autoscaleFlag: true,
|
||||
existingState: testGcpState,
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
"initialize some azure instances with autoscale flag": {
|
||||
autoscaleFlag: true,
|
||||
existingState: testAzureState,
|
||||
client: &stubProtoClient{
|
||||
respClient: &fakeActivationRespClient{responses: testActivationResps},
|
||||
},
|
||||
waiter: &stubStatusWaiter{},
|
||||
privKey: testKey,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newInitCmd()
|
||||
var out bytes.Buffer
|
||||
cmd.SetOut(&out)
|
||||
var errOut bytes.Buffer
|
||||
cmd.SetErr(&errOut)
|
||||
cmd.Flags().String("config", "", "") // register persisten flag manually
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
vpnHandler := stubVPNHandler{}
|
||||
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
|
||||
|
||||
// Write key file to filesystem and set path in flag.
|
||||
require.NoError(afero.Afero{Fs: fs}.WriteFile("privK", []byte(tc.privKey), 0o600))
|
||||
require.NoError(cmd.Flags().Set("privatekey", "privK"))
|
||||
|
||||
require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
|
||||
ctx := context.Background()
|
||||
|
||||
require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler))
|
||||
if tc.autoscaleFlag {
|
||||
assert.Len(tc.client.activateAutoscalingNodeGroups, 1)
|
||||
} else {
|
||||
assert.Len(tc.client.activateAutoscalingNodeGroups, 0)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteWGQuickFile(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
fileHandler file.Handler
|
||||
vpnHandler *stubVPNHandler
|
||||
vpnConfig *wgquick.Config
|
||||
wantErr bool
|
||||
}{
|
||||
"write wg quick file": {
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
vpnHandler: &stubVPNHandler{marshalRes: "config"},
|
||||
},
|
||||
"marshal failed": {
|
||||
fileHandler: file.NewHandler(afero.NewMemMapFs()),
|
||||
vpnHandler: &stubVPNHandler{marshalErr: errors.New("some err")},
|
||||
wantErr: true,
|
||||
},
|
||||
"write fails": {
|
||||
fileHandler: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())),
|
||||
vpnHandler: &stubVPNHandler{marshalRes: "config"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
err := writeWGQuickFile(tc.fileHandler, tc.vpnHandler, tc.vpnConfig)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
file, err := tc.fileHandler.Read(constants.WGQuickConfigFilename)
|
||||
assert.NoError(err)
|
||||
assert.Contains(string(file), tc.vpnHandler.marshalRes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import "strings"
|
||||
|
||||
func formatInstanceTypes(types []string) string {
|
||||
return " " + strings.Join(types, "\n ")
|
||||
}
|
|
@ -1,17 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/internal/proto"
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
)
|
||||
|
||||
type protoClient interface {
|
||||
Connect(endpoint string, validators []atls.Validator) error
|
||||
Close() error
|
||||
GetState(ctx context.Context) (state.State, error)
|
||||
Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs, autoscalingNodeGroups []string, cloudServiceAccountURI string, sshUsers []*pubproto.SSHUserKey) (proto.ActivationResponseClient, error)
|
||||
}
|
|
@ -1,225 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/internal/proto"
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
)
|
||||
|
||||
type stubProtoClient struct {
|
||||
conn bool
|
||||
respClient proto.ActivationResponseClient
|
||||
connectErr error
|
||||
closeErr error
|
||||
getStateErr error
|
||||
activateErr error
|
||||
|
||||
getStateState state.State
|
||||
activateUserPublicKey []byte
|
||||
activateMasterSecret []byte
|
||||
activateNodeIPs []string
|
||||
activateCoordinatorIPs []string
|
||||
activateAutoscalingNodeGroups []string
|
||||
cloudServiceAccountURI string
|
||||
sshUserKeys []*pubproto.SSHUserKey
|
||||
}
|
||||
|
||||
func (c *stubProtoClient) Connect(_ string, _ []atls.Validator) error {
|
||||
c.conn = true
|
||||
return c.connectErr
|
||||
}
|
||||
|
||||
func (c *stubProtoClient) Close() error {
|
||||
c.conn = false
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
func (c *stubProtoClient) GetState(_ context.Context) (state.State, error) {
|
||||
return c.getStateState, c.getStateErr
|
||||
}
|
||||
|
||||
func (c *stubProtoClient) Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs []string, autoscalingNodeGroups []string, cloudServiceAccountURI string, sshUserKeys []*pubproto.SSHUserKey) (proto.ActivationResponseClient, error) {
|
||||
c.activateUserPublicKey = userPublicKey
|
||||
c.activateMasterSecret = masterSecret
|
||||
c.activateNodeIPs = nodeIPs
|
||||
c.activateCoordinatorIPs = coordinatorIPs
|
||||
c.activateAutoscalingNodeGroups = autoscalingNodeGroups
|
||||
c.cloudServiceAccountURI = cloudServiceAccountURI
|
||||
c.sshUserKeys = sshUserKeys
|
||||
|
||||
return c.respClient, c.activateErr
|
||||
}
|
||||
|
||||
func (c *stubProtoClient) ActivateAdditionalCoordinators(ctx context.Context, ips []string) error {
|
||||
return c.activateErr
|
||||
}
|
||||
|
||||
type stubActivationRespClient struct {
|
||||
nextLogErr *error
|
||||
getKubeconfigErr error
|
||||
getCoordinatorVpnKeyErr error
|
||||
getClientVpnIpErr error
|
||||
getOwnerIDErr error
|
||||
getClusterIDErr error
|
||||
writeLogStreamErr error
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) NextLog() (string, error) {
|
||||
if s.nextLogErr == nil {
|
||||
return "", io.EOF
|
||||
}
|
||||
return "", *s.nextLogErr
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) WriteLogStream(io.Writer) error {
|
||||
return s.writeLogStreamErr
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) GetKubeconfig() (string, error) {
|
||||
return "", s.getKubeconfigErr
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) GetCoordinatorVpnKey() (string, error) {
|
||||
return "", s.getCoordinatorVpnKeyErr
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) GetClientVpnIp() (string, error) {
|
||||
return "", s.getClientVpnIpErr
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) GetOwnerID() (string, error) {
|
||||
return "", s.getOwnerIDErr
|
||||
}
|
||||
|
||||
func (s *stubActivationRespClient) GetClusterID() (string, error) {
|
||||
return "", s.getClusterIDErr
|
||||
}
|
||||
|
||||
type fakeProtoClient struct {
|
||||
conn bool
|
||||
respClient proto.ActivationResponseClient
|
||||
}
|
||||
|
||||
func (c *fakeProtoClient) Connect(endpoint string, validators []atls.Validator) error {
|
||||
if endpoint == "" {
|
||||
return errors.New("endpoint is empty")
|
||||
}
|
||||
if len(validators) == 0 {
|
||||
return errors.New("validators is empty")
|
||||
}
|
||||
c.conn = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeProtoClient) Close() error {
|
||||
c.conn = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeProtoClient) GetState(_ context.Context) (state.State, error) {
|
||||
if !c.conn {
|
||||
return state.Uninitialized, errors.New("client is not connected")
|
||||
}
|
||||
return state.IsNode, nil
|
||||
}
|
||||
|
||||
func (c *fakeProtoClient) Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs, autoscalingNodeGroups []string, cloudServiceAccountURI string, sshUserKeys []*pubproto.SSHUserKey) (proto.ActivationResponseClient, error) {
|
||||
if !c.conn {
|
||||
return nil, errors.New("client is not connected")
|
||||
}
|
||||
return c.respClient, nil
|
||||
}
|
||||
|
||||
func (c *fakeProtoClient) ActivateAdditionalCoordinators(ctx context.Context, ips []string) error {
|
||||
if !c.conn {
|
||||
return errors.New("client is not connected")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeActivationRespClient struct {
|
||||
responses []fakeActivationRespMessage
|
||||
kubeconfig string
|
||||
coordinatorVpnKey string
|
||||
clientVpnIp string
|
||||
ownerID string
|
||||
clusterID string
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) NextLog() (string, error) {
|
||||
for len(c.responses) > 0 {
|
||||
resp := c.responses[0]
|
||||
c.responses = c.responses[1:]
|
||||
if len(resp.log) > 0 {
|
||||
return resp.log, nil
|
||||
}
|
||||
c.kubeconfig = resp.kubeconfig
|
||||
c.coordinatorVpnKey = resp.coordinatorVpnKey
|
||||
c.clientVpnIp = resp.clientVpnIp
|
||||
c.ownerID = resp.ownerID
|
||||
c.clusterID = resp.clusterID
|
||||
}
|
||||
return "", io.EOF
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) WriteLogStream(w io.Writer) error {
|
||||
log, err := c.NextLog()
|
||||
for err == nil {
|
||||
fmt.Fprint(w, log)
|
||||
log, err = c.NextLog()
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) GetKubeconfig() (string, error) {
|
||||
if c.kubeconfig == "" {
|
||||
return "", errors.New("kubeconfig is empty")
|
||||
}
|
||||
return c.kubeconfig, nil
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) GetCoordinatorVpnKey() (string, error) {
|
||||
if c.coordinatorVpnKey == "" {
|
||||
return "", errors.New("control-plane public VPN key is empty")
|
||||
}
|
||||
return c.coordinatorVpnKey, nil
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) GetClientVpnIp() (string, error) {
|
||||
if c.clientVpnIp == "" {
|
||||
return "", errors.New("client VPN IP is empty")
|
||||
}
|
||||
return c.clientVpnIp, nil
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) GetOwnerID() (string, error) {
|
||||
if c.ownerID == "" {
|
||||
return "", errors.New("init secret is empty")
|
||||
}
|
||||
return c.ownerID, nil
|
||||
}
|
||||
|
||||
func (c *fakeActivationRespClient) GetClusterID() (string, error) {
|
||||
if c.clusterID == "" {
|
||||
return "", errors.New("cluster identifier is empty")
|
||||
}
|
||||
return c.clusterID, nil
|
||||
}
|
||||
|
||||
type fakeActivationRespMessage struct {
|
||||
log string
|
||||
kubeconfig string
|
||||
coordinatorVpnKey string
|
||||
clientVpnIp string
|
||||
ownerID string
|
||||
clusterID string
|
||||
}
|
|
@ -1,43 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
)
|
||||
|
||||
func readConfig(out io.Writer, fileHandler file.Handler, name string, provider cloudprovider.Provider) (*config.Config, error) {
|
||||
cnf, err := config.FromFile(fileHandler, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateConfig(out, cnf, provider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cnf, nil
|
||||
}
|
||||
|
||||
func validateConfig(out io.Writer, cnf *config.Config, provider cloudprovider.Provider) error {
|
||||
msgs, err := cnf.Validate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(msgs) > 0 {
|
||||
fmt.Fprintln(out, "Invalid fields in config file:")
|
||||
for _, m := range msgs {
|
||||
fmt.Fprintln(out, "\t"+m)
|
||||
}
|
||||
return errors.New("invalid configuration")
|
||||
}
|
||||
|
||||
if provider != cloudprovider.Unknown && !cnf.HasProvider(provider) {
|
||||
return fmt.Errorf("configuration doesn't contain provider: %v", provider)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,78 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/config"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateConfig(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
cnf *config.Config
|
||||
provider cloudprovider.Provider
|
||||
wantOutput bool
|
||||
wantErr bool
|
||||
}{
|
||||
"default config is valid": {
|
||||
cnf: config.Default(),
|
||||
},
|
||||
"config with an error": {
|
||||
cnf: func() *config.Config {
|
||||
cnf := config.Default()
|
||||
cnf.Version = "v0"
|
||||
return cnf
|
||||
}(),
|
||||
wantOutput: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"config without provider is ok if no provider required": {
|
||||
cnf: func() *config.Config {
|
||||
cnf := config.Default()
|
||||
cnf.Provider = config.ProviderConfig{}
|
||||
return cnf
|
||||
}(),
|
||||
},
|
||||
"config with only required provider": {
|
||||
cnf: func() *config.Config {
|
||||
cnf := config.Default()
|
||||
az := cnf.Provider.Azure
|
||||
cnf.Provider = config.ProviderConfig{}
|
||||
cnf.Provider.Azure = az
|
||||
return cnf
|
||||
}(),
|
||||
provider: cloudprovider.Azure,
|
||||
},
|
||||
"config without required provider": {
|
||||
cnf: func() *config.Config {
|
||||
cnf := config.Default()
|
||||
cnf.Provider.Azure = nil
|
||||
return cnf
|
||||
}(),
|
||||
provider: cloudprovider.Azure,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
out := &bytes.Buffer{}
|
||||
|
||||
err := validateConfig(out, tc.cnf, tc.provider)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
assert.Equal(tc.wantOutput, out.Len() > 0)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,153 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/cli/internal/proto"
|
||||
"github.com/edgelesssys/constellation/coordinator/util"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var diskUUIDRegexp = regexp.MustCompile("^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
|
||||
|
||||
func newRecoverCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "recover",
|
||||
Short: "Recover a completely stopped Constellation cluster",
|
||||
Long: "Recover a Constellation cluster by sending a recovery key to an instance in the boot stage." +
|
||||
"\nThis is only required if instances restart without other instances available for bootstrapping.",
|
||||
Args: cobra.ExactArgs(0),
|
||||
RunE: runRecover,
|
||||
}
|
||||
cmd.Flags().StringP("endpoint", "e", "", "endpoint of the instance, passed as HOST[:PORT] (required)")
|
||||
must(cmd.MarkFlagRequired("endpoint"))
|
||||
cmd.Flags().String("disk-uuid", "", "disk UUID of the encrypted state disk (required)")
|
||||
must(cmd.MarkFlagRequired("disk-uuid"))
|
||||
cmd.Flags().String("master-secret", constants.MasterSecretFilename, "path to base64-encoded master secret")
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runRecover(cmd *cobra.Command, args []string) error {
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
recoveryClient := &proto.KeyClient{}
|
||||
defer recoveryClient.Close()
|
||||
return recover(cmd.Context(), cmd, fileHandler, recoveryClient)
|
||||
}
|
||||
|
||||
func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler, recoveryClient recoveryClient) error {
|
||||
flags, err := parseRecoverFlags(cmd, fileHandler)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var stat state.ConstellationState
|
||||
if err := fileHandler.ReadJSON(constants.StateFilename, &stat); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
provider := cloudprovider.FromString(stat.CloudProvider)
|
||||
|
||||
config, err := readConfig(cmd.OutOrStdout(), fileHandler, flags.configPath, provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
validators, err := cloudcmd.NewValidators(provider, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Print(validators.WarningsIncludeInit())
|
||||
|
||||
if err := recoveryClient.Connect(flags.endpoint, validators.V()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
diskKey, err := deriveStateDiskKey(flags.masterSecret, flags.diskUUID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := recoveryClient.PushStateDiskKey(ctx, diskKey); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Pushed recovery key.")
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) {
|
||||
endpoint, err := cmd.Flags().GetString("endpoint")
|
||||
if err != nil {
|
||||
return recoverFlags{}, err
|
||||
}
|
||||
endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort)
|
||||
if err != nil {
|
||||
return recoverFlags{}, err
|
||||
}
|
||||
|
||||
diskUUID, err := cmd.Flags().GetString("disk-uuid")
|
||||
if err != nil {
|
||||
return recoverFlags{}, err
|
||||
}
|
||||
if match := diskUUIDRegexp.MatchString(diskUUID); !match {
|
||||
return recoverFlags{}, errors.New("flag '--disk-uuid' isn't a valid LUKS UUID")
|
||||
}
|
||||
diskUUID = strings.ToLower(diskUUID)
|
||||
|
||||
masterSecretPath, err := cmd.Flags().GetString("master-secret")
|
||||
if err != nil {
|
||||
return recoverFlags{}, err
|
||||
}
|
||||
masterSecret, err := readMasterSecret(fileHandler, masterSecretPath)
|
||||
if err != nil {
|
||||
return recoverFlags{}, err
|
||||
}
|
||||
|
||||
configPath, err := cmd.Flags().GetString("config")
|
||||
if err != nil {
|
||||
return recoverFlags{}, err
|
||||
}
|
||||
|
||||
return recoverFlags{
|
||||
endpoint: endpoint,
|
||||
diskUUID: diskUUID,
|
||||
masterSecret: masterSecret,
|
||||
configPath: configPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type recoverFlags struct {
|
||||
endpoint string
|
||||
diskUUID string
|
||||
masterSecret []byte
|
||||
configPath string
|
||||
}
|
||||
|
||||
// readMasterSecret reads a base64 encoded master secret from file.
|
||||
func readMasterSecret(fileHandler file.Handler, filename string) ([]byte, error) {
|
||||
// Try to read the base64 secret from file
|
||||
encodedSecret, err := fileHandler.Read(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(string(encodedSecret))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decoded, nil
|
||||
}
|
||||
|
||||
// deriveStateDiskKey derives a state disk key from a master secret and a disk UUID.
|
||||
func deriveStateDiskKey(masterKey []byte, diskUUID string) ([]byte, error) {
|
||||
return util.DeriveKey(masterKey, []byte("Constellation"), []byte("key"+diskUUID), constants.StateDiskKeyLength)
|
||||
}
|
|
@ -1,356 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRecoverCmdArgumentValidation(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
wantErr bool
|
||||
}{
|
||||
"no args": {[]string{}, false},
|
||||
"too many arguments": {[]string{"abc"}, true},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := newRecoverCmd()
|
||||
err := cmd.ValidateArgs(tc.args)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecover(t *testing.T) {
|
||||
validState := state.ConstellationState{CloudProvider: "GCP"}
|
||||
invalidCSPState := state.ConstellationState{CloudProvider: "invalid"}
|
||||
|
||||
testCases := map[string]struct {
|
||||
setupFs func(*require.Assertions) afero.Fs
|
||||
existingState state.ConstellationState
|
||||
client *stubRecoveryClient
|
||||
endpointFlag string
|
||||
diskUUIDFlag string
|
||||
masterSecretFlag string
|
||||
configFlag string
|
||||
stateless bool
|
||||
wantErr bool
|
||||
wantKey []byte
|
||||
}{
|
||||
"works": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
|
||||
return fs
|
||||
},
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{},
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||
wantKey: []byte{0x2e, 0x4d, 0x40, 0x3a, 0x90, 0x96, 0x6e, 0xd, 0x42, 0x3, 0x98, 0xd, 0xce, 0xc5, 0x73, 0x26, 0xf4, 0x87, 0xcf, 0x85, 0x73, 0xe1, 0xb7, 0xd6, 0xb2, 0x82, 0x4c, 0xd9, 0xbc, 0xa5, 0x7c, 0x32},
|
||||
},
|
||||
"uppercase disk uuid works": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
|
||||
return fs
|
||||
},
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{},
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF",
|
||||
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
|
||||
},
|
||||
"lowercase disk uuid results in same key": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
|
||||
return fs
|
||||
},
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{},
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
|
||||
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
|
||||
},
|
||||
"missing flags": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
wantErr: true,
|
||||
},
|
||||
"missing config": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
|
||||
return fs
|
||||
},
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||
configFlag: "nonexistent-config",
|
||||
wantErr: true,
|
||||
},
|
||||
"missing state": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
|
||||
return fs
|
||||
},
|
||||
existingState: validState,
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||
stateless: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid cloud provider": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
|
||||
return fs
|
||||
},
|
||||
existingState: invalidCSPState,
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||
wantErr: true,
|
||||
},
|
||||
"connect fails": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
|
||||
return fs
|
||||
},
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{connectErr: errors.New("connect failed")},
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||
wantErr: true,
|
||||
},
|
||||
"pushing state key fails": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
|
||||
return fs
|
||||
},
|
||||
existingState: validState,
|
||||
client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")},
|
||||
endpointFlag: "192.0.2.1",
|
||||
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newRecoverCmd()
|
||||
cmd.Flags().String("config", "", "") // register persisten flag manually
|
||||
out := &bytes.Buffer{}
|
||||
cmd.SetOut(out)
|
||||
cmd.SetErr(&bytes.Buffer{})
|
||||
if tc.endpointFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("endpoint", tc.endpointFlag))
|
||||
}
|
||||
if tc.diskUUIDFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("disk-uuid", tc.diskUUIDFlag))
|
||||
}
|
||||
if tc.masterSecretFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("master-secret", tc.masterSecretFlag))
|
||||
}
|
||||
if tc.configFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("config", tc.configFlag))
|
||||
}
|
||||
fileHandler := file.NewHandler(tc.setupFs(require))
|
||||
if !tc.stateless {
|
||||
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
err := recover(ctx, cmd, fileHandler, tc.client)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(err)
|
||||
assert.Contains(out.String(), "Pushed recovery key.")
|
||||
assert.Equal(tc.wantKey, tc.client.pushStateDiskKeyKey)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRecoverFlags(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
wantFlags recoverFlags
|
||||
wantErr bool
|
||||
}{
|
||||
"no flags": {
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid ip": {
|
||||
args: []string{"-e", "192.0.2.1:2:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid disk uuid": {
|
||||
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "invalid"},
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid master secret path": {
|
||||
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "invalid"},
|
||||
wantErr: true,
|
||||
},
|
||||
"minimal args set": {
|
||||
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
|
||||
wantFlags: recoverFlags{
|
||||
endpoint: "192.0.2.1:2",
|
||||
diskUUID: "12345678-1234-1234-1234-123456789012",
|
||||
masterSecret: []byte("constellation-master-secret-leng"),
|
||||
},
|
||||
},
|
||||
"all args set": {
|
||||
args: []string{
|
||||
"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012",
|
||||
"--master-secret", "constellation-mastersecret.base64", "--config", "config-path",
|
||||
},
|
||||
wantFlags: recoverFlags{
|
||||
endpoint: "192.0.2.1:2",
|
||||
diskUUID: "12345678-1234-1234-1234-123456789012",
|
||||
masterSecret: []byte("constellation-master-secret-leng"),
|
||||
configPath: "config-path",
|
||||
},
|
||||
},
|
||||
"uppercase disk-uuid is converted to lowercase": {
|
||||
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"},
|
||||
wantFlags: recoverFlags{
|
||||
endpoint: "192.0.2.1:2",
|
||||
diskUUID: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
|
||||
masterSecret: []byte("constellation-master-secret-leng"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
|
||||
cmd := newRecoverCmd()
|
||||
cmd.Flags().String("config", "", "") // register persistent flag manually
|
||||
require.NoError(cmd.ParseFlags(tc.args))
|
||||
flags, err := parseRecoverFlags(cmd, file.NewHandler(fs))
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantFlags, flags)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadMasterSecret(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
fileContents []byte
|
||||
filename string
|
||||
wantMasterSecret []byte
|
||||
wantErr bool
|
||||
}{
|
||||
"invalid base64": {
|
||||
fileContents: []byte("invalid"),
|
||||
filename: "constellation-mastersecret.base64",
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid filename": {
|
||||
fileContents: []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="),
|
||||
filename: "invalid",
|
||||
wantErr: true,
|
||||
},
|
||||
"correct master secret": {
|
||||
fileContents: []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="),
|
||||
filename: "constellation-mastersecret.base64",
|
||||
wantMasterSecret: []byte("constellation-master-secret-leng"),
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", tc.fileContents, 0o777))
|
||||
masterSecret, err := readMasterSecret(file.NewHandler(fs), tc.filename)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantMasterSecret, masterSecret)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveStateDiskKey(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
masterKey []byte
|
||||
diskUUID string
|
||||
wantStateDiskKey []byte
|
||||
}{
|
||||
"all zero": {
|
||||
masterKey: []byte{
|
||||
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
|
||||
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
|
||||
},
|
||||
diskUUID: "00000000-0000-0000-0000-000000000000",
|
||||
wantStateDiskKey: []byte{
|
||||
0xa8, 0xb0, 0x86, 0x83, 0x6f, 0x0b, 0x26, 0x04, 0x86, 0x22, 0x27, 0xcc, 0xa1, 0x1c, 0xaf, 0x6c,
|
||||
0x30, 0x4d, 0x90, 0x89, 0x82, 0x68, 0x53, 0x7f, 0x4f, 0x46, 0x7a, 0x65, 0xa2, 0x5d, 0x5e, 0x43,
|
||||
},
|
||||
},
|
||||
"all 0xff": {
|
||||
masterKey: []byte{
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
|
||||
},
|
||||
diskUUID: "ffffffff-ffff-ffff-ffff-ffffffffffff",
|
||||
wantStateDiskKey: []byte{
|
||||
0x24, 0x18, 0x84, 0x7f, 0xca, 0x86, 0x55, 0xb5, 0x45, 0xa6, 0xb3, 0xc4, 0x45, 0xbb, 0x08, 0x10,
|
||||
0x16, 0xb3, 0xde, 0x30, 0x30, 0x74, 0x0b, 0xd4, 0x1e, 0x22, 0x55, 0x45, 0x51, 0x91, 0xfb, 0xa9,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
stateDiskKey, err := deriveStateDiskKey(tc.masterKey, tc.diskUUID)
|
||||
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantStateDiskKey, stateDiskKey)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,14 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
)
|
||||
|
||||
type recoveryClient interface {
|
||||
Connect(endpoint string, validators []atls.Validator) error
|
||||
PushStateDiskKey(ctx context.Context, stateDiskKey []byte) error
|
||||
io.Closer
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
)
|
||||
|
||||
type stubRecoveryClient struct {
|
||||
conn bool
|
||||
connectErr error
|
||||
closeErr error
|
||||
pushStateDiskKeyErr error
|
||||
|
||||
pushStateDiskKeyKey []byte
|
||||
}
|
||||
|
||||
func (c *stubRecoveryClient) Connect(_ string, _ []atls.Validator) error {
|
||||
c.conn = true
|
||||
return c.connectErr
|
||||
}
|
||||
|
||||
func (c *stubRecoveryClient) Close() error {
|
||||
c.conn = false
|
||||
return c.closeErr
|
||||
}
|
||||
|
||||
func (c *stubRecoveryClient) PushStateDiskKey(_ context.Context, stateDiskKey []byte) error {
|
||||
c.pushStateDiskKeyKey = stateDiskKey
|
||||
return c.pushStateDiskKeyErr
|
||||
}
|
|
@ -6,6 +6,7 @@ import (
|
|||
"os"
|
||||
"os/signal"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/internal/cmd"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
@ -34,13 +35,13 @@ func NewRootCmd() *cobra.Command {
|
|||
rootCmd.PersistentFlags().String("config", constants.ConfigFilename, "path to the configuration file")
|
||||
must(rootCmd.MarkPersistentFlagFilename("config", "json"))
|
||||
|
||||
rootCmd.AddCommand(newConfigCmd())
|
||||
rootCmd.AddCommand(newCreateCmd())
|
||||
rootCmd.AddCommand(newInitCmd())
|
||||
rootCmd.AddCommand(newVerifyCmd())
|
||||
rootCmd.AddCommand(newRecoverCmd())
|
||||
rootCmd.AddCommand(newTerminateCmd())
|
||||
rootCmd.AddCommand(newVersionCmd())
|
||||
rootCmd.AddCommand(cmd.NewConfigCmd())
|
||||
rootCmd.AddCommand(cmd.NewCreateCmd())
|
||||
rootCmd.AddCommand(cmd.NewInitCmd())
|
||||
rootCmd.AddCommand(cmd.NewVerifyCmd())
|
||||
rootCmd.AddCommand(cmd.NewRecoverCmd())
|
||||
rootCmd.AddCommand(cmd.NewTerminateCmd())
|
||||
rootCmd.AddCommand(cmd.NewVersionCmd())
|
||||
|
||||
return rootCmd
|
||||
}
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
)
|
||||
|
||||
type statusWaiter interface {
|
||||
InitializeValidators([]atls.Validator) error
|
||||
WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
|
||||
}
|
|
@ -1,27 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/edgelesssys/constellation/coordinator/atls"
|
||||
"github.com/edgelesssys/constellation/coordinator/state"
|
||||
)
|
||||
|
||||
type stubStatusWaiter struct {
|
||||
initialized bool
|
||||
initializeErr error
|
||||
waitForAllErr error
|
||||
}
|
||||
|
||||
func (s *stubStatusWaiter) InitializeValidators([]atls.Validator) error {
|
||||
s.initialized = true
|
||||
return s.initializeErr
|
||||
}
|
||||
|
||||
func (s *stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
|
||||
if !s.initialized {
|
||||
return errors.New("waiter not initialized")
|
||||
}
|
||||
return s.waitForAllErr
|
||||
}
|
|
@ -1,65 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
)
|
||||
|
||||
func newTerminateCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "terminate",
|
||||
Short: "Terminate a Constellation cluster",
|
||||
Long: "Terminate a Constellation cluster. The cluster can't be started again, and all persistent storage will be lost.",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: runTerminate,
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
// runTerminate runs the terminate command.
|
||||
func runTerminate(cmd *cobra.Command, args []string) error {
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
terminator := cloudcmd.NewTerminator()
|
||||
|
||||
return terminate(cmd, terminator, fileHandler)
|
||||
}
|
||||
|
||||
func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.Handler) error {
|
||||
var stat state.ConstellationState
|
||||
if err := fileHandler.ReadJSON(constants.StateFilename, &stat); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Terminating ...")
|
||||
|
||||
if err := terminator.Terminate(cmd.Context(), stat); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("Your Constellation cluster was terminated successfully.")
|
||||
|
||||
var retErr error
|
||||
if err := fileHandler.Remove(constants.StateFilename); err != nil {
|
||||
retErr = multierr.Append(err, fmt.Errorf("failed to remove file '%s', please remove manually", constants.StateFilename))
|
||||
}
|
||||
|
||||
if err := fileHandler.Remove(constants.AdminConfFilename); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
retErr = multierr.Append(err, fmt.Errorf("failed to remove file '%s', please remove manually", constants.AdminConfFilename))
|
||||
}
|
||||
|
||||
if err := fileHandler.Remove(constants.WGQuickConfigFilename); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
retErr = multierr.Append(err, fmt.Errorf("failed to remove file '%s', please remove manually", constants.WGQuickConfigFilename))
|
||||
}
|
||||
|
||||
return retErr
|
||||
}
|
|
@ -1,136 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/edgelesssys/constellation/internal/state"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTerminateCmdArgumentValidation(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
wantErr bool
|
||||
}{
|
||||
"no args": {[]string{}, false},
|
||||
"some args": {[]string{"hello", "test"}, true},
|
||||
"some other args": {[]string{"12", "2"}, true},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := newTerminateCmd()
|
||||
err := cmd.ValidateArgs(tc.args)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTerminate(t *testing.T) {
|
||||
setupFs := func(require *require.Assertions, state state.ConstellationState) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
require.NoError(fileHandler.Write(constants.AdminConfFilename, []byte{1, 2}, file.OptNone))
|
||||
require.NoError(fileHandler.Write(constants.WGQuickConfigFilename, []byte{1, 2}, file.OptNone))
|
||||
require.NoError(fileHandler.WriteJSON(constants.StateFilename, state, file.OptNone))
|
||||
return fs
|
||||
}
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
state state.ConstellationState
|
||||
setupFs func(*require.Assertions, state.ConstellationState) afero.Fs
|
||||
terminator spyCloudTerminator
|
||||
wantErr bool
|
||||
}{
|
||||
"success": {
|
||||
state: state.ConstellationState{CloudProvider: "gcp"},
|
||||
setupFs: setupFs,
|
||||
terminator: &stubCloudTerminator{},
|
||||
},
|
||||
"files to remove do not exist": {
|
||||
state: state.ConstellationState{CloudProvider: "gcp"},
|
||||
setupFs: func(require *require.Assertions, state state.ConstellationState) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
require.NoError(fileHandler.WriteJSON(constants.StateFilename, state, file.OptNone))
|
||||
return fs
|
||||
},
|
||||
terminator: &stubCloudTerminator{},
|
||||
},
|
||||
"terminate error": {
|
||||
state: state.ConstellationState{CloudProvider: "gcp"},
|
||||
setupFs: setupFs,
|
||||
terminator: &stubCloudTerminator{terminateErr: someErr},
|
||||
wantErr: true,
|
||||
},
|
||||
"missing state file": {
|
||||
state: state.ConstellationState{CloudProvider: "gcp"},
|
||||
setupFs: func(require *require.Assertions, state state.ConstellationState) afero.Fs {
|
||||
fs := afero.NewMemMapFs()
|
||||
fileHandler := file.NewHandler(fs)
|
||||
require.NoError(fileHandler.Write(constants.AdminConfFilename, []byte{1, 2}, file.OptNone))
|
||||
require.NoError(fileHandler.Write(constants.WGQuickConfigFilename, []byte{1, 2}, file.OptNone))
|
||||
return fs
|
||||
},
|
||||
terminator: &stubCloudTerminator{},
|
||||
wantErr: true,
|
||||
},
|
||||
"remove file fails": {
|
||||
state: state.ConstellationState{CloudProvider: "gcp"},
|
||||
setupFs: func(require *require.Assertions, state state.ConstellationState) afero.Fs {
|
||||
fs := setupFs(require, state)
|
||||
return afero.NewReadOnlyFs(fs)
|
||||
},
|
||||
terminator: &stubCloudTerminator{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newTerminateCmd()
|
||||
cmd.SetOut(&bytes.Buffer{})
|
||||
cmd.SetErr(&bytes.Buffer{})
|
||||
|
||||
require.NotNil(tc.setupFs)
|
||||
fileHandler := file.NewHandler(tc.setupFs(require, tc.state))
|
||||
|
||||
err := terminate(cmd, tc.terminator, fileHandler)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.True(tc.terminator.Called())
|
||||
_, err := fileHandler.Stat(constants.StateFilename)
|
||||
assert.Error(err)
|
||||
_, err = fileHandler.Stat(constants.AdminConfFilename)
|
||||
assert.Error(err)
|
||||
_, err = fileHandler.Stat(constants.WGQuickConfigFilename)
|
||||
assert.Error(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type spyCloudTerminator interface {
|
||||
cloudTerminator
|
||||
Called() bool
|
||||
}
|
|
@ -1,35 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// ErrInvalidInput is an error where user entered invalid input.
|
||||
var ErrInvalidInput = errors.New("user made invalid input")
|
||||
|
||||
// askToConfirm asks user to confirm an action.
|
||||
// The user will be asked the handed question and can answer with
|
||||
// yes or no.
|
||||
func askToConfirm(cmd *cobra.Command, question string) (bool, error) {
|
||||
reader := bufio.NewReader(cmd.InOrStdin())
|
||||
cmd.Printf("%s [y/n]: ", question)
|
||||
for i := 0; i < 3; i++ {
|
||||
resp, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
resp = strings.ToLower(strings.TrimSpace(resp))
|
||||
if resp == "n" || resp == "no" {
|
||||
return false, nil
|
||||
}
|
||||
if resp == "y" || resp == "yes" {
|
||||
return true, nil
|
||||
}
|
||||
cmd.Printf("Type 'y' or 'yes' to confirm, or abort action with 'n' or 'no': ")
|
||||
}
|
||||
return false, ErrInvalidInput
|
||||
}
|
|
@ -1,63 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestAskToConfirm(t *testing.T) {
|
||||
// errAborted is an error where the user aborted the action.
|
||||
errAborted := errors.New("user aborted")
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Args: cobra.NoArgs,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
ok, err := askToConfirm(cmd, "777")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return errAborted
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
testCases := map[string]struct {
|
||||
input string
|
||||
wantErr error
|
||||
}{
|
||||
"user confirms": {"y\n", nil},
|
||||
"user confirms long": {"yes\n", nil},
|
||||
"user disagrees": {"n\n", errAborted},
|
||||
"user disagrees long": {"no\n", errAborted},
|
||||
"user is first unsure, but agrees": {"what?\ny\n", nil},
|
||||
"user is first unsure, but disagrees": {"wait.\nn\n", errAborted},
|
||||
"repeated invalid input": {"h\nb\nq\n", ErrInvalidInput},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
out := &bytes.Buffer{}
|
||||
cmd.SetOut(out)
|
||||
cmd.SetErr(&bytes.Buffer{})
|
||||
in := bytes.NewBufferString(tc.input)
|
||||
cmd.SetIn(in)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.ErrorIs(err, tc.wantErr)
|
||||
|
||||
output, err := io.ReadAll(out)
|
||||
assert.NoError(err)
|
||||
assert.Contains(string(output), "777")
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,71 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/internal/azure"
|
||||
"github.com/edgelesssys/constellation/cli/internal/gcp"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// warnAWS warns that AWS isn't supported.
|
||||
func warnAWS(providerPos int) cobra.PositionalArgs {
|
||||
return func(cmd *cobra.Command, args []string) error {
|
||||
if cloudprovider.FromString(args[providerPos]) == cloudprovider.AWS {
|
||||
return errors.New("AWS isn't supported by this version of Constellation")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func isCloudProvider(arg int) cobra.PositionalArgs {
|
||||
return func(cmd *cobra.Command, args []string) error {
|
||||
if provider := cloudprovider.FromString(args[arg]); provider == cloudprovider.Unknown {
|
||||
return fmt.Errorf("argument %s isn't a valid cloud provider", args[arg])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func validInstanceTypeForProvider(cmd *cobra.Command, insType string, provider cloudprovider.Provider) error {
|
||||
switch provider {
|
||||
case cloudprovider.GCP:
|
||||
for _, instanceType := range gcp.InstanceTypes {
|
||||
if insType == instanceType {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
cmd.SetUsageTemplate("GCP instance types:\n" + formatInstanceTypes(gcp.InstanceTypes))
|
||||
cmd.SilenceUsage = false
|
||||
return fmt.Errorf("%s isn't a valid GCP instance type", insType)
|
||||
case cloudprovider.Azure:
|
||||
for _, instanceType := range azure.InstanceTypes {
|
||||
if insType == instanceType {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
cmd.SetUsageTemplate("Azure instance types:\n" + formatInstanceTypes(azure.InstanceTypes))
|
||||
cmd.SilenceUsage = false
|
||||
return fmt.Errorf("%s isn't a valid Azure instance type", insType)
|
||||
default:
|
||||
return fmt.Errorf("%s isn't a valid cloud platform", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func validateEndpoint(endpoint string, defaultPort int) (string, error) {
|
||||
_, _, err := net.SplitHostPort(endpoint)
|
||||
if err == nil {
|
||||
return endpoint, nil
|
||||
}
|
||||
|
||||
if strings.Contains(err.Error(), "missing port in address") {
|
||||
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
|
||||
}
|
||||
|
||||
return "", err
|
||||
}
|
|
@ -1,90 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsCloudProvider(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
pos int
|
||||
args []string
|
||||
wantErr bool
|
||||
}{
|
||||
"gcp": {0, []string{"gcp"}, false},
|
||||
"azure": {1, []string{"foo", "azure"}, false},
|
||||
"foo": {0, []string{"foo"}, true},
|
||||
"empty": {0, []string{""}, true},
|
||||
"unknown": {0, []string{"unknown"}, true},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
testCmd := &cobra.Command{Args: isCloudProvider(tc.pos)}
|
||||
|
||||
err := testCmd.ValidateArgs(tc.args)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateEndpoint(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
endpoint string
|
||||
defaultPort int
|
||||
wantResult string
|
||||
wantErr bool
|
||||
}{
|
||||
"ip and port": {
|
||||
endpoint: "192.0.2.1:2",
|
||||
defaultPort: 3,
|
||||
wantResult: "192.0.2.1:2",
|
||||
},
|
||||
"hostname and port": {
|
||||
endpoint: "foo:2",
|
||||
defaultPort: 3,
|
||||
wantResult: "foo:2",
|
||||
},
|
||||
"ip": {
|
||||
endpoint: "192.0.2.1",
|
||||
defaultPort: 3,
|
||||
wantResult: "192.0.2.1:3",
|
||||
},
|
||||
"hostname": {
|
||||
endpoint: "foo",
|
||||
defaultPort: 3,
|
||||
wantResult: "foo:3",
|
||||
},
|
||||
"invalid endpoint": {
|
||||
endpoint: "foo:2:2",
|
||||
defaultPort: 3,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
res, err := validateEndpoint(tc.endpoint, tc.defaultPort)
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantResult, res)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,133 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
|
||||
"github.com/edgelesssys/constellation/cli/internal/proto"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
rpcStatus "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func newVerifyCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "verify {aws|azure|gcp}",
|
||||
Short: "Verify the confidential properties of a Constellation cluster",
|
||||
Long: "Verify the confidential properties of a Constellation cluster.",
|
||||
Args: cobra.MatchAll(
|
||||
cobra.ExactArgs(1),
|
||||
isCloudProvider(0),
|
||||
warnAWS(0),
|
||||
),
|
||||
RunE: runVerify,
|
||||
}
|
||||
cmd.Flags().String("owner-id", "", "verify using the owner identity derived from the master secret")
|
||||
cmd.Flags().String("unique-id", "", "verify using the unique cluster identity")
|
||||
cmd.Flags().StringP("node-endpoint", "e", "", "endpoint of the node to verify, passed as HOST[:PORT] (required)")
|
||||
must(cmd.MarkFlagRequired("node-endpoint"))
|
||||
return cmd
|
||||
}
|
||||
|
||||
func runVerify(cmd *cobra.Command, args []string) error {
|
||||
provider := cloudprovider.FromString(args[0])
|
||||
fileHandler := file.NewHandler(afero.NewOsFs())
|
||||
protoClient := &proto.Client{}
|
||||
defer protoClient.Close()
|
||||
return verify(cmd.Context(), cmd, provider, fileHandler, protoClient)
|
||||
}
|
||||
|
||||
func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Provider, fileHandler file.Handler, protoClient protoClient) error {
|
||||
flags, err := parseVerifyFlags(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config, err := readConfig(cmd.OutOrStdout(), fileHandler, flags.configPath, provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
validators, err := cloudcmd.NewValidators(provider, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validators.UpdateInitPCRs(flags.ownerID, flags.clusterID); err != nil {
|
||||
return err
|
||||
}
|
||||
if validators.Warnings() != "" {
|
||||
cmd.Print(validators.Warnings())
|
||||
}
|
||||
|
||||
if err := protoClient.Connect(flags.endpoint, validators.V()); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := protoClient.GetState(ctx); err != nil {
|
||||
if err, ok := rpcStatus.FromError(err); ok {
|
||||
return fmt.Errorf("unable to verify Constellation cluster: %s", err.Message())
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Println("OK")
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
|
||||
ownerID, err := cmd.Flags().GetString("owner-id")
|
||||
if err != nil {
|
||||
return verifyFlags{}, err
|
||||
}
|
||||
clusterID, err := cmd.Flags().GetString("unique-id")
|
||||
if err != nil {
|
||||
return verifyFlags{}, err
|
||||
}
|
||||
if ownerID == "" && clusterID == "" {
|
||||
return verifyFlags{}, errors.New("neither owner ID nor unique ID provided to verify the cluster")
|
||||
}
|
||||
|
||||
endpoint, err := cmd.Flags().GetString("node-endpoint")
|
||||
if err != nil {
|
||||
return verifyFlags{}, err
|
||||
}
|
||||
endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort)
|
||||
if err != nil {
|
||||
return verifyFlags{}, err
|
||||
}
|
||||
|
||||
configPath, err := cmd.Flags().GetString("config")
|
||||
if err != nil {
|
||||
return verifyFlags{}, err
|
||||
}
|
||||
|
||||
return verifyFlags{
|
||||
endpoint: endpoint,
|
||||
configPath: configPath,
|
||||
ownerID: ownerID,
|
||||
clusterID: clusterID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type verifyFlags struct {
|
||||
endpoint string
|
||||
ownerID string
|
||||
clusterID string
|
||||
configPath string
|
||||
}
|
||||
|
||||
// verifyCompletion handels the completion of CLI arguments. It is frequently called
|
||||
// while the user types arguments of the command to suggest completion.
|
||||
func verifyCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
switch len(args) {
|
||||
case 0:
|
||||
return []string{"gcp", "azure"}, cobra.ShellCompDirectiveNoFileComp
|
||||
default:
|
||||
return []string{}, cobra.ShellCompDirectiveError
|
||||
}
|
||||
}
|
|
@ -1,197 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
|
||||
"github.com/edgelesssys/constellation/internal/file"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/codes"
|
||||
rpcStatus "google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func TestVerifyCmdArgumentValidation(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
wantErr bool
|
||||
}{
|
||||
"no args": {[]string{}, true},
|
||||
"valid azure": {[]string{"azure"}, false},
|
||||
"valid gcp": {[]string{"gcp"}, false},
|
||||
"invalid provider": {[]string{"invalid", "192.0.2.1", "1234"}, true},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := newVerifyCmd()
|
||||
err := cmd.ValidateArgs(tc.args)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerify(t *testing.T) {
|
||||
zeroBase64 := base64.StdEncoding.EncodeToString([]byte("00000000000000000000000000000000"))
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
setupFs func(*require.Assertions) afero.Fs
|
||||
provider cloudprovider.Provider
|
||||
protoClient protoClient
|
||||
nodeEndpointFlag string
|
||||
configFlag string
|
||||
ownerIDFlag string
|
||||
clusterIDFlag string
|
||||
wantErr bool
|
||||
}{
|
||||
"gcp": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
provider: cloudprovider.GCP,
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
ownerIDFlag: zeroBase64,
|
||||
protoClient: &stubProtoClient{},
|
||||
},
|
||||
"azure": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
provider: cloudprovider.Azure,
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
ownerIDFlag: zeroBase64,
|
||||
protoClient: &stubProtoClient{},
|
||||
},
|
||||
"default port": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
provider: cloudprovider.GCP,
|
||||
nodeEndpointFlag: "192.0.2.1",
|
||||
ownerIDFlag: zeroBase64,
|
||||
protoClient: &stubProtoClient{},
|
||||
},
|
||||
"invalid endpoint": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
provider: cloudprovider.GCP,
|
||||
nodeEndpointFlag: ":::::",
|
||||
ownerIDFlag: zeroBase64,
|
||||
protoClient: &stubProtoClient{},
|
||||
wantErr: true,
|
||||
},
|
||||
"neither owner id nor cluster id set": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
provider: cloudprovider.GCP,
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
wantErr: true,
|
||||
},
|
||||
"config file not existing": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
provider: cloudprovider.GCP,
|
||||
ownerIDFlag: zeroBase64,
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
configFlag: "./file",
|
||||
wantErr: true,
|
||||
},
|
||||
"error protoClient Connect": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
provider: cloudprovider.Azure,
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
ownerIDFlag: zeroBase64,
|
||||
protoClient: &stubProtoClient{connectErr: someErr},
|
||||
wantErr: true,
|
||||
},
|
||||
"error protoClient GetState": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
provider: cloudprovider.Azure,
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
ownerIDFlag: zeroBase64,
|
||||
protoClient: &stubProtoClient{getStateErr: rpcStatus.Error(codes.Internal, "failed")},
|
||||
wantErr: true,
|
||||
},
|
||||
"error protoClient GetState not rpc": {
|
||||
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
|
||||
provider: cloudprovider.Azure,
|
||||
nodeEndpointFlag: "192.0.2.1:1234",
|
||||
ownerIDFlag: zeroBase64,
|
||||
protoClient: &stubProtoClient{getStateErr: someErr},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
cmd := newVerifyCmd()
|
||||
cmd.Flags().String("config", "", "") // register persisten flag manually
|
||||
out := &bytes.Buffer{}
|
||||
cmd.SetOut(out)
|
||||
cmd.SetErr(&bytes.Buffer{})
|
||||
if tc.configFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("config", tc.configFlag))
|
||||
}
|
||||
if tc.ownerIDFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("owner-id", tc.ownerIDFlag))
|
||||
}
|
||||
if tc.clusterIDFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("cluster-id", tc.clusterIDFlag))
|
||||
}
|
||||
if tc.nodeEndpointFlag != "" {
|
||||
require.NoError(cmd.Flags().Set("node-endpoint", tc.nodeEndpointFlag))
|
||||
}
|
||||
fileHandler := file.NewHandler(tc.setupFs(require))
|
||||
|
||||
ctx := context.Background()
|
||||
err := verify(ctx, cmd, tc.provider, fileHandler, tc.protoClient)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Contains(out.String(), "OK")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestVerifyCompletion(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
args []string
|
||||
toComplete string
|
||||
wantResult []string
|
||||
wantShellCD cobra.ShellCompDirective
|
||||
}{
|
||||
"first arg": {
|
||||
args: []string{},
|
||||
toComplete: "az",
|
||||
wantResult: []string{"gcp", "azure"},
|
||||
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
|
||||
},
|
||||
"additional arg": {
|
||||
args: []string{"gcp", "foo"},
|
||||
wantResult: []string{},
|
||||
wantShellCD: cobra.ShellCompDirectiveError,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := &cobra.Command{}
|
||||
result, shellCD := verifyCompletion(cmd, tc.args, tc.toComplete)
|
||||
assert.Equal(tc.wantResult, result)
|
||||
assert.Equal(tc.wantShellCD, shellCD)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newVersionCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "version",
|
||||
Short: "Display version of this CLI",
|
||||
Long: "Display version of this CLI.",
|
||||
Args: cobra.NoArgs,
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
cmd.Printf("CLI Version: v%s \n", constants.CliVersion)
|
||||
},
|
||||
}
|
||||
return cmd
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestVersionCmd(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
cmd := newVersionCmd()
|
||||
b := &bytes.Buffer{}
|
||||
cmd.SetOut(b)
|
||||
|
||||
err := cmd.Execute()
|
||||
assert.NoError(err)
|
||||
|
||||
s, err := io.ReadAll(b)
|
||||
assert.NoError(err)
|
||||
assert.Contains(string(s), constants.CliVersion)
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import wgquick "github.com/nmiculinic/wg-quick-go"
|
||||
|
||||
type vpnHandler interface {
|
||||
Create(coordinatorPubKey string, coordinatorPubIP string, clientPrivKey string, clientVPNIP string, mtu int) (*wgquick.Config, error)
|
||||
Apply(*wgquick.Config) error
|
||||
Marshal(*wgquick.Config) ([]byte, error)
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package cmd
|
||||
|
||||
import wgquick "github.com/nmiculinic/wg-quick-go"
|
||||
|
||||
type stubVPNHandler struct {
|
||||
configured bool
|
||||
marshalRes string
|
||||
|
||||
createErr error
|
||||
applyErr error
|
||||
marshalErr error
|
||||
}
|
||||
|
||||
func (c *stubVPNHandler) Create(coordinatorPubKey string, coordinatorPubIP string, clientPrivKey string, clientVPNIP string, mtu int) (*wgquick.Config, error) {
|
||||
return &wgquick.Config{}, c.createErr
|
||||
}
|
||||
|
||||
func (c *stubVPNHandler) Apply(*wgquick.Config) error {
|
||||
c.configured = true
|
||||
return c.applyErr
|
||||
}
|
||||
|
||||
func (c *stubVPNHandler) Marshal(*wgquick.Config) ([]byte, error) {
|
||||
return []byte(c.marshalRes), c.marshalErr
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue