Move cli/cmd into cli/internal

This commit is contained in:
katexochen 2022-06-08 08:14:28 +02:00
parent d71e97a940
commit c3ebd3d3cd
34 changed files with 45 additions and 32 deletions

View file

@ -1,28 +0,0 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
)
type cloudCreator interface {
Create(
ctx context.Context,
provider cloudprovider.Provider,
config *config.Config,
name, insType string,
coordCount, nodeCount int,
) (state.ConstellationState, error)
}
type cloudTerminator interface {
Terminate(context.Context, state.ConstellationState) error
}
type serviceAccountCreator interface {
Create(ctx context.Context, stat state.ConstellationState, config *config.Config,
) (string, state.ConstellationState, error)
}

View file

@ -1,49 +0,0 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
)
type stubCloudCreator struct {
createCalled bool
state state.ConstellationState
createErr error
}
func (c *stubCloudCreator) Create(
ctx context.Context,
provider cloudprovider.Provider,
config *config.Config,
name, insType string,
coordCount, nodeCount int,
) (state.ConstellationState, error) {
c.createCalled = true
return c.state, c.createErr
}
type stubCloudTerminator struct {
called bool
terminateErr error
}
func (c *stubCloudTerminator) Terminate(context.Context, state.ConstellationState) error {
c.called = true
return c.terminateErr
}
func (c *stubCloudTerminator) Called() bool {
return c.called
}
type stubServiceAccountCreator struct {
cloudServiceAccountURI string
createErr error
}
func (c *stubServiceAccountCreator) Create(ctx context.Context, stat state.ConstellationState, config *config.Config) (string, state.ConstellationState, error) {
return c.cloudServiceAccountURI, stat, c.createErr
}

View file

@ -1,18 +0,0 @@
package cmd
import (
"github.com/spf13/cobra"
)
func newConfigCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "config",
Short: "Work with the Constellation configuration file",
Long: "Generate a configuration file for Constellation.",
Args: cobra.ExactArgs(0),
}
cmd.AddCommand(newConfigGenerateCmd())
return cmd
}

View file

@ -1,81 +0,0 @@
package cmd
import (
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/talos-systems/talos/pkg/machinery/config/encoder"
)
func newConfigGenerateCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "generate {aws|azure|gcp}",
Short: "Generate a default configuration file",
Long: "Generate a default configuration file for your selected cloud provider.",
Args: cobra.MatchAll(
cobra.ExactArgs(1),
isCloudProvider(0),
warnAWS(0),
),
ValidArgsFunction: generateCompletion,
RunE: runConfigGenerate,
}
cmd.Flags().StringP("file", "f", constants.ConfigFilename, "path to output file, or '-' for stdout")
return cmd
}
type generateFlags struct {
file string
}
func runConfigGenerate(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
provider := cloudprovider.FromString(args[0])
return configGenerate(cmd, fileHandler, provider)
}
func configGenerate(cmd *cobra.Command, fileHandler file.Handler, provider cloudprovider.Provider) error {
flags, err := parseGenerateFlags(cmd)
if err != nil {
return err
}
conf := config.Default()
conf.RemoveProviderExcept(provider)
if flags.file == "-" {
content, err := encoder.NewEncoder(conf).Encode()
if err != nil {
return err
}
_, err = cmd.OutOrStdout().Write(content)
return err
}
return fileHandler.WriteYAML(flags.file, conf, 0o644)
}
func parseGenerateFlags(cmd *cobra.Command) (generateFlags, error) {
file, err := cmd.Flags().GetString("file")
if err != nil {
return generateFlags{}, err
}
return generateFlags{
file: file,
}, nil
}
// createCompletion handles the completion of the create command. It is frequently called
// while the user types arguments of the command to suggest completion.
func generateCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0:
return []string{"aws", "gcp", "azure"}, cobra.ShellCompDirectiveNoFileComp
default:
return []string{}, cobra.ShellCompDirectiveError
}
}

View file

@ -1,87 +0,0 @@
package cmd
import (
"bytes"
"testing"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
)
func TestConfigGenerateDefault(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
cmd := newConfigGenerateCmd()
require.NoError(configGenerate(cmd, fileHandler, cloudprovider.Unknown))
var readConfig config.Config
err := fileHandler.ReadYAML(constants.ConfigFilename, &readConfig)
assert.NoError(err)
assert.Equal(*config.Default(), readConfig)
}
func TestConfigGenerateDefaultGCPSpecific(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
wantConf := config.Default()
wantConf.RemoveProviderExcept(cloudprovider.GCP)
fileHandler := file.NewHandler(afero.NewMemMapFs())
cmd := newConfigGenerateCmd()
require.NoError(configGenerate(cmd, fileHandler, cloudprovider.GCP))
var readConfig config.Config
err := fileHandler.ReadYAML(constants.ConfigFilename, &readConfig)
assert.NoError(err)
assert.Equal(*wantConf, readConfig)
}
func TestConfigGenerateDefaultExists(t *testing.T) {
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
require.NoError(fileHandler.Write(constants.ConfigFilename, []byte("foobar"), file.OptNone))
cmd := newConfigGenerateCmd()
require.Error(configGenerate(cmd, fileHandler, cloudprovider.Unknown))
}
func TestConfigGenerateFileFlagRemoved(t *testing.T) {
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
cmd := newConfigGenerateCmd()
cmd.ResetFlags()
require.Error(configGenerate(cmd, fileHandler, cloudprovider.Unknown))
}
func TestConfigGenerateStdOut(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(afero.NewMemMapFs())
var outBuffer bytes.Buffer
cmd := newConfigGenerateCmd()
cmd.SetOut(&outBuffer)
require.NoError(cmd.Flags().Set("file", "-"))
require.NoError(configGenerate(cmd, fileHandler, cloudprovider.Unknown))
var readConfig config.Config
require.NoError(yaml.NewDecoder(&outBuffer).Decode(&readConfig))
assert.Equal(*config.Default(), readConfig)
}

View file

@ -1,223 +0,0 @@
package cmd
import (
"errors"
"fmt"
"io/fs"
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
"github.com/edgelesssys/constellation/cli/internal/azure"
"github.com/edgelesssys/constellation/cli/internal/gcp"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
func newCreateCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "create {aws|azure|gcp}",
Short: "Create instances on a cloud platform for your Constellation cluster",
Long: "Create instances on a cloud platform for your Constellation cluster.",
Args: cobra.MatchAll(
cobra.ExactArgs(1),
isCloudProvider(0),
warnAWS(0),
),
ValidArgsFunction: createCompletion,
RunE: runCreate,
}
cmd.Flags().String("name", "constell", "create the cluster with the specified name")
cmd.Flags().BoolP("yes", "y", false, "create the cluster without further confirmation")
cmd.Flags().IntP("control-plane-nodes", "c", 0, "number of control-plane nodes (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "control-plane-nodes"))
cmd.Flags().IntP("worker-nodes", "w", 0, "number of worker nodes (required)")
must(cobra.MarkFlagRequired(cmd.Flags(), "worker-nodes"))
cmd.Flags().StringP("instance-type", "t", "", "instance type of cluster nodes")
must(cmd.RegisterFlagCompletionFunc("instance-type", instanceTypeCompletion))
cmd.SetHelpTemplate(cmd.HelpTemplate() + fmt.Sprintf(`
Azure instance types:
%v
GCP instance types:
%v
`, formatInstanceTypes(azure.InstanceTypes), formatInstanceTypes(gcp.InstanceTypes)))
return cmd
}
func runCreate(cmd *cobra.Command, args []string) error {
provider := cloudprovider.FromString(args[0])
fileHandler := file.NewHandler(afero.NewOsFs())
creator := cloudcmd.NewCreator(cmd.OutOrStdout())
return create(cmd, creator, fileHandler, provider)
}
func create(cmd *cobra.Command, creator cloudCreator, fileHandler file.Handler, provider cloudprovider.Provider,
) (retErr error) {
flags, err := parseCreateFlags(cmd, provider)
if err != nil {
return err
}
if err := checkDirClean(fileHandler); err != nil {
return err
}
config, err := readConfig(cmd.OutOrStdout(), fileHandler, flags.configPath, provider)
if err != nil {
return err
}
if !flags.yes {
// Ask user to confirm action.
cmd.Printf("The following Constellation cluster will be created:\n")
cmd.Printf("%d control-planes nodes of type %s will be created.\n", flags.controllerCount, flags.insType)
cmd.Printf("%d worker nodes of type %s will be created.\n", flags.workerCount, flags.insType)
ok, err := askToConfirm(cmd, "Do you want to create this cluster?")
if err != nil {
return err
}
if !ok {
cmd.Println("The creation of the cluster was aborted.")
return nil
}
}
state, err := creator.Create(cmd.Context(), provider, config, flags.name, flags.insType, flags.controllerCount, flags.workerCount)
if err != nil {
return err
}
if err := fileHandler.WriteJSON(constants.StateFilename, state, file.OptNone); err != nil {
return err
}
cmd.Println("Your Constellation cluster was created successfully.")
return nil
}
// parseCreateFlags parses the flags of the create command.
func parseCreateFlags(cmd *cobra.Command, provider cloudprovider.Provider) (createFlags, error) {
controllerCount, err := cmd.Flags().GetInt("control-plane-nodes")
if err != nil {
return createFlags{}, err
}
if controllerCount < constants.MinControllerCount {
return createFlags{}, fmt.Errorf("number of control-plane nodes must be at least %d", constants.MinControllerCount)
}
workerCount, err := cmd.Flags().GetInt("worker-nodes")
if err != nil {
return createFlags{}, err
}
if workerCount < constants.MinWorkerCount {
return createFlags{}, fmt.Errorf("number of worker nodes must be at least %d", constants.MinWorkerCount)
}
insType, err := cmd.Flags().GetString("instance-type")
if err != nil {
return createFlags{}, err
}
if insType == "" {
insType = defaultInstanceType(provider)
}
if err := validInstanceTypeForProvider(cmd, insType, provider); err != nil {
return createFlags{}, err
}
name, err := cmd.Flags().GetString("name")
if err != nil {
return createFlags{}, err
}
if len(name) > constants.ConstellationNameLength {
return createFlags{}, fmt.Errorf(
"name for Constellation cluster too long, maximum length is %d, got %d: %s",
constants.ConstellationNameLength, len(name), name,
)
}
yes, err := cmd.Flags().GetBool("yes")
if err != nil {
return createFlags{}, err
}
configPath, err := cmd.Flags().GetString("config")
if err != nil {
return createFlags{}, err
}
return createFlags{
controllerCount: controllerCount,
workerCount: workerCount,
insType: insType,
name: name,
configPath: configPath,
yes: yes,
}, nil
}
// createFlags contains the parsed flags of the create command.
type createFlags struct {
controllerCount int
workerCount int
insType string
name string
configPath string
yes bool
}
// defaultInstanceType returns the default instance type for the given provider.
func defaultInstanceType(provider cloudprovider.Provider) string {
switch provider {
case cloudprovider.GCP:
return gcp.InstanceTypes[0]
case cloudprovider.Azure:
return azure.InstanceTypes[0]
default:
return ""
}
}
// checkDirClean checks if files of a previous Constellation are left in the current working dir.
func checkDirClean(fileHandler file.Handler) error {
if _, err := fileHandler.Stat(constants.StateFilename); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory, run 'constellation terminate' before creating a new one", constants.StateFilename)
}
if _, err := fileHandler.Stat(constants.AdminConfFilename); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory, run 'constellation terminate' before creating a new one", constants.AdminConfFilename)
}
if _, err := fileHandler.Stat(constants.MasterSecretFilename); !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("file '%s' already exists in working directory, clean it up first", constants.MasterSecretFilename)
}
return nil
}
// createCompletion handles the completion of the create command. It is frequently called
// while the user types arguments of the command to suggest completion.
func createCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0:
return []string{"aws", "gcp", "azure"}, cobra.ShellCompDirectiveNoFileComp
default:
return []string{}, cobra.ShellCompDirectiveError
}
}
func instanceTypeCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if len(args) != 1 {
return []string{}, cobra.ShellCompDirectiveError
}
switch args[0] {
case "gcp":
return gcp.InstanceTypes, cobra.ShellCompDirectiveNoFileComp
case "azure":
return azure.InstanceTypes, cobra.ShellCompDirectiveNoFileComp
default:
return []string{}, cobra.ShellCompDirectiveError
}
}

View file

@ -1,397 +0,0 @@
package cmd
import (
"bytes"
"errors"
"strconv"
"strings"
"testing"
"github.com/edgelesssys/constellation/cli/internal/azure"
"github.com/edgelesssys/constellation/cli/internal/gcp"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCreateArgumentValidation(t *testing.T) {
testCases := map[string]struct {
args []string
wantErr bool
}{
"gcp": {[]string{"gcp"}, false},
"azure": {[]string{"azure"}, false},
"aws waring": {[]string{"aws"}, true},
"too many args": {[]string{"gcp", "1", "2"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := newCreateCmd().ValidateArgs(tc.args)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestCreate(t *testing.T) {
testState := state.ConstellationState{Name: "test"}
someErr := errors.New("failed")
testCases := map[string]struct {
setupFs func(*require.Assertions) afero.Fs
creator *stubCloudCreator
provider cloudprovider.Provider
yesFlag bool
controllerCountFlag *int
workerCountFlag *int
insTypeFlag string
configFlag string
nameFlag string
stdin string
wantErr bool
wantAbbort bool
}{
"create": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{state: testState},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(1),
workerCountFlag: intPtr(2),
yesFlag: true,
},
"interactive": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{state: testState},
provider: cloudprovider.Azure,
controllerCountFlag: intPtr(2),
workerCountFlag: intPtr(1),
stdin: "yes\n",
},
"interactive abort": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(1),
workerCountFlag: intPtr(1),
stdin: "no\n",
wantAbbort: true,
},
"interactive error": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(1),
workerCountFlag: intPtr(1),
stdin: "foo\nfoo\nfoo\n",
wantErr: true,
},
"flag name to long": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(1),
workerCountFlag: intPtr(1),
nameFlag: strings.Repeat("a", constants.ConstellationNameLength+1),
wantErr: true,
},
"flag control-plane-count invalid": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(0),
workerCountFlag: intPtr(3),
wantErr: true,
},
"flag worker-count invalid": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(3),
workerCountFlag: intPtr(0),
wantErr: true,
},
"flag control-plane-count missing": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
workerCountFlag: intPtr(3),
wantErr: true,
},
"flag worker-count missing": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(3),
wantErr: true,
},
"flag invalid instance-type": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(1),
workerCountFlag: intPtr(1),
insTypeFlag: "invalid",
wantErr: true,
},
"old state in directory": {
setupFs: func(require *require.Assertions) afero.Fs {
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.Write(constants.StateFilename, []byte{1}, file.OptNone))
return fs
},
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(1),
workerCountFlag: intPtr(1),
yesFlag: true,
wantErr: true,
},
"old adminConf in directory": {
setupFs: func(require *require.Assertions) afero.Fs {
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.Write(constants.AdminConfFilename, []byte{1}, file.OptNone))
return fs
},
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(1),
workerCountFlag: intPtr(1),
yesFlag: true,
wantErr: true,
},
"old masterSecret in directory": {
setupFs: func(require *require.Assertions) afero.Fs {
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.Write(constants.MasterSecretFilename, []byte{1}, file.OptNone))
return fs
},
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(1),
workerCountFlag: intPtr(1),
yesFlag: true,
wantErr: true,
},
"config does not exist": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(1),
workerCountFlag: intPtr(1),
yesFlag: true,
configFlag: constants.ConfigFilename,
wantErr: true,
},
"create error": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
creator: &stubCloudCreator{createErr: someErr},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(1),
workerCountFlag: intPtr(1),
yesFlag: true,
wantErr: true,
},
"write state error": {
setupFs: func(require *require.Assertions) afero.Fs {
fs := afero.NewMemMapFs()
return afero.NewReadOnlyFs(fs)
},
creator: &stubCloudCreator{},
provider: cloudprovider.GCP,
controllerCountFlag: intPtr(1),
workerCountFlag: intPtr(1),
yesFlag: true,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newCreateCmd()
cmd.SetOut(&bytes.Buffer{})
cmd.SetErr(&bytes.Buffer{})
cmd.SetIn(bytes.NewBufferString(tc.stdin))
cmd.Flags().String("config", "", "") // register persisten flag manually
if tc.yesFlag {
require.NoError(cmd.Flags().Set("yes", "true"))
}
if tc.nameFlag != "" {
require.NoError(cmd.Flags().Set("name", tc.nameFlag))
}
if tc.configFlag != "" {
require.NoError(cmd.Flags().Set("config", tc.configFlag))
}
if tc.controllerCountFlag != nil {
require.NoError(cmd.Flags().Set("control-plane-nodes", strconv.Itoa(*tc.controllerCountFlag)))
}
if tc.workerCountFlag != nil {
require.NoError(cmd.Flags().Set("worker-nodes", strconv.Itoa(*tc.workerCountFlag)))
}
if tc.insTypeFlag != "" {
require.NoError(cmd.Flags().Set("instance-type", tc.insTypeFlag))
}
fileHandler := file.NewHandler(tc.setupFs(require))
err := create(cmd, tc.creator, fileHandler, tc.provider)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
if tc.wantAbbort {
assert.False(tc.creator.createCalled)
} else {
assert.True(tc.creator.createCalled)
var state state.ConstellationState
require.NoError(fileHandler.ReadJSON(constants.StateFilename, &state))
assert.Equal(state, testState)
}
}
})
}
}
func TestCheckDirClean(t *testing.T) {
testCases := map[string]struct {
fileHandler file.Handler
existingFiles []string
wantErr bool
}{
"no file exists": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
},
"adminconf exists": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
existingFiles: []string{constants.AdminConfFilename},
wantErr: true,
},
"master secret exists": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
existingFiles: []string{constants.MasterSecretFilename},
wantErr: true,
},
"state file exists": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
existingFiles: []string{constants.StateFilename},
wantErr: true,
},
"multiple exist": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
existingFiles: []string{constants.AdminConfFilename, constants.MasterSecretFilename, constants.StateFilename},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
for _, f := range tc.existingFiles {
require.NoError(tc.fileHandler.Write(f, []byte{1, 2, 3}, file.OptNone))
}
err := checkDirClean(tc.fileHandler)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestCreateCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
wantResult []string
wantShellCD cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
wantResult: []string{"aws", "gcp", "azure"},
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
},
"second arg": {
args: []string{"gcp", "foo"},
wantResult: []string{},
wantShellCD: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := createCompletion(cmd, tc.args, "")
assert.Equal(tc.wantResult, result)
assert.Equal(tc.wantShellCD, shellCD)
})
}
}
func TestInstanceTypeCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
wantResult []string
wantShellCD cobra.ShellCompDirective
}{
"azure": {
args: []string{"azure"},
wantResult: azure.InstanceTypes,
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
},
"gcp": {
args: []string{"gcp"},
wantResult: gcp.InstanceTypes,
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
},
"empty args": {
args: []string{},
wantResult: []string{},
wantShellCD: cobra.ShellCompDirectiveError,
},
"unknown provider": {
args: []string{"foo"},
wantResult: []string{},
wantShellCD: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := instanceTypeCompletion(cmd, tc.args, "")
assert.Equal(tc.wantResult, result)
assert.Equal(tc.wantShellCD, shellCD)
})
}
}
func intPtr(i int) *int {
return &i
}

View file

@ -1,490 +0,0 @@
package cmd
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"io/fs"
"net"
"strconv"
"text/tabwriter"
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/edgelesssys/constellation/cli/internal/azure"
"github.com/edgelesssys/constellation/cli/internal/gcp"
"github.com/edgelesssys/constellation/cli/internal/proto"
"github.com/edgelesssys/constellation/cli/internal/status"
"github.com/edgelesssys/constellation/cli/internal/vpn"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
coordinatorstate "github.com/edgelesssys/constellation/coordinator/state"
"github.com/edgelesssys/constellation/coordinator/util"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/state"
"github.com/kr/text"
wgquick "github.com/nmiculinic/wg-quick-go"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
func newInitCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "init",
Short: "Initialize the Constellation cluster",
Long: "Initialize the Constellation cluster. Start your confidential Kubernetes.",
ValidArgsFunction: initCompletion,
Args: cobra.ExactArgs(0),
RunE: runInitialize,
}
cmd.Flags().String("privatekey", "", "path to your private key")
cmd.Flags().String("master-secret", "", "path to base64-encoded master secret")
cmd.Flags().Bool("wg-autoconfig", false, "enable automatic configuration of WireGuard interface")
must(cmd.Flags().MarkHidden("wg-autoconfig"))
cmd.Flags().Bool("autoscale", false, "enable Kubernetes cluster-autoscaler")
return cmd
}
// runInitialize runs the initialize command.
func runInitialize(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
vpnHandler := vpn.NewConfigHandler()
serviceAccountCreator := cloudcmd.NewServiceAccountCreator()
waiter := status.NewWaiter()
protoClient := &proto.Client{}
defer protoClient.Close()
// We have to parse the context separately, since cmd.Context()
// returns nil during the tests otherwise.
return initialize(cmd.Context(), cmd, protoClient, serviceAccountCreator, fileHandler, waiter, vpnHandler)
}
// initialize initializes a Constellation. Coordinator instances are activated as contole-plane nodes and will
// themself activate the other peers as workers.
func initialize(ctx context.Context, cmd *cobra.Command, protCl protoClient, serviceAccCreator serviceAccountCreator,
fileHandler file.Handler, waiter statusWaiter, vpnHandler vpnHandler,
) error {
flags, err := evalFlagArgs(cmd, fileHandler)
if err != nil {
return err
}
var stat state.ConstellationState
err = fileHandler.ReadJSON(constants.StateFilename, &stat)
if errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("nothing to initialize: %w", err)
} else if err != nil {
return err
}
provider := cloudprovider.FromString(stat.CloudProvider)
config, err := readConfig(cmd.OutOrStdout(), fileHandler, flags.configPath, provider)
if err != nil {
return err
}
var sshUsers []*ssh.UserKey
for _, user := range config.SSHUsers {
sshUsers = append(sshUsers, &ssh.UserKey{
Username: user.Username,
PublicKey: user.PublicKey,
})
}
validators, err := cloudcmd.NewValidators(provider, config)
if err != nil {
return err
}
cmd.Print(validators.WarningsIncludeInit())
cmd.Println("Creating service account ...")
serviceAccount, stat, err := serviceAccCreator.Create(ctx, stat, config)
if err != nil {
return err
}
if err := fileHandler.WriteJSON(constants.StateFilename, stat, file.OptOverwrite); err != nil {
return err
}
coordinators, nodes, err := getScalingGroupsFromConfig(stat, config)
if err != nil {
return err
}
endpoints := ipsToEndpoints(append(coordinators.PublicIPs(), nodes.PublicIPs()...), strconv.Itoa(constants.CoordinatorPort))
cmd.Println("Waiting for cloud provider resource creation and boot ...")
if err := waiter.InitializeValidators(validators.V()); err != nil {
return err
}
if err := waiter.WaitForAll(ctx, endpoints, coordinatorstate.AcceptingInit); err != nil {
return fmt.Errorf("failed to wait for peer status: %w", err)
}
var autoscalingNodeGroups []string
if flags.autoscale {
autoscalingNodeGroups = append(autoscalingNodeGroups, nodes.GroupID)
}
input := activationInput{
coordinatorPubIP: coordinators.PublicIPs()[0],
pubKey: flags.userPubKey,
masterSecret: flags.masterSecret,
nodePrivIPs: nodes.PrivateIPs(),
coordinatorPrivIPs: coordinators.PrivateIPs()[1:],
autoscalingNodeGroups: autoscalingNodeGroups,
cloudServiceAccountURI: serviceAccount,
sshUserKeys: ssh.ToProtoSlice(sshUsers),
}
result, err := activate(ctx, cmd, protCl, input, validators.V())
if err != nil {
return err
}
err = result.writeOutput(cmd.OutOrStdout(), fileHandler)
if err != nil {
return err
}
vpnConfig, err := vpnHandler.Create(result.coordinatorPubKey, result.coordinatorPubIP, string(flags.userPrivKey), result.clientVpnIP, constants.WireguardAdminMTU)
if err != nil {
return err
}
if err := writeWGQuickFile(fileHandler, vpnHandler, vpnConfig); err != nil {
return fmt.Errorf("write wg-quick file: %w", err)
}
if flags.autoconfigureWG {
if err := vpnHandler.Apply(vpnConfig); err != nil {
return err
}
}
return nil
}
func activate(ctx context.Context, cmd *cobra.Command, client protoClient, input activationInput,
validators []atls.Validator,
) (activationResult, error) {
err := client.Connect(net.JoinHostPort(input.coordinatorPubIP, strconv.Itoa(constants.CoordinatorPort)), validators)
if err != nil {
return activationResult{}, err
}
respCl, err := client.Activate(ctx, input.pubKey, input.masterSecret, input.nodePrivIPs, input.coordinatorPrivIPs, input.autoscalingNodeGroups, input.cloudServiceAccountURI, input.sshUserKeys)
if err != nil {
return activationResult{}, err
}
indentOut := text.NewIndentWriter(cmd.OutOrStdout(), []byte{'\t'})
cmd.Println("Activating the cluster ...")
if err := respCl.WriteLogStream(indentOut); err != nil {
return activationResult{}, err
}
clientVpnIp, err := respCl.GetClientVpnIp()
if err != nil {
return activationResult{}, err
}
coordinatorPubKey, err := respCl.GetCoordinatorVpnKey()
if err != nil {
return activationResult{}, err
}
kubeconfig, err := respCl.GetKubeconfig()
if err != nil {
return activationResult{}, err
}
ownerID, err := respCl.GetOwnerID()
if err != nil {
return activationResult{}, err
}
clusterID, err := respCl.GetClusterID()
if err != nil {
return activationResult{}, err
}
return activationResult{
clientVpnIP: clientVpnIp,
coordinatorPubKey: coordinatorPubKey,
coordinatorPubIP: input.coordinatorPubIP,
kubeconfig: kubeconfig,
ownerID: ownerID,
clusterID: clusterID,
}, nil
}
type activationInput struct {
coordinatorPubIP string
pubKey []byte
masterSecret []byte
nodePrivIPs []string
coordinatorPrivIPs []string
autoscalingNodeGroups []string
cloudServiceAccountURI string
sshUserKeys []*pubproto.SSHUserKey
}
type activationResult struct {
clientVpnIP string
coordinatorPubKey string
coordinatorPubIP string
kubeconfig string
ownerID string
clusterID string
}
// writeWGQuickFile writes the wg-quick file to the default path.
func writeWGQuickFile(fileHandler file.Handler, vpnHandler vpnHandler, vpnConfig *wgquick.Config) error {
data, err := vpnHandler.Marshal(vpnConfig)
if err != nil {
return err
}
return fileHandler.Write(constants.WGQuickConfigFilename, data, file.OptNone)
}
func (r activationResult) writeOutput(wr io.Writer, fileHandler file.Handler) error {
fmt.Fprint(wr, "Your Constellation cluster was successfully initialized.\n\n")
tw := tabwriter.NewWriter(wr, 0, 0, 2, ' ', 0)
writeRow(tw, "Your WireGuard IP", r.clientVpnIP)
writeRow(tw, "Control plane's public IP", r.coordinatorPubIP)
writeRow(tw, "Control plane's public key", r.coordinatorPubKey)
writeRow(tw, "Constellation cluster's owner identifier", r.ownerID)
writeRow(tw, "Constellation cluster's unique identifier", r.clusterID)
writeRow(tw, "WireGuard configuration file", constants.WGQuickConfigFilename)
writeRow(tw, "Kubernetes configuration", constants.AdminConfFilename)
tw.Flush()
fmt.Fprintln(wr)
if err := fileHandler.Write(constants.AdminConfFilename, []byte(r.kubeconfig), file.OptNone); err != nil {
return fmt.Errorf("write kubeconfig: %w", err)
}
fmt.Fprintln(wr, "You can now connect to your cluster by executing:")
fmt.Fprintf(wr, "\twg-quick up ./%s\n", constants.WGQuickConfigFilename)
fmt.Fprintf(wr, "\texport KUBECONFIG=\"$PWD/%s\"\n", constants.AdminConfFilename)
return nil
}
func writeRow(wr io.Writer, col1 string, col2 string) {
fmt.Fprint(wr, col1, "\t", col2, "\n")
}
// evalFlagArgs gets the flag values and does preprocessing of these values like
// reading the content from file path flags and deriving other values from flag combinations.
func evalFlagArgs(cmd *cobra.Command, fileHandler file.Handler) (initFlags, error) {
userPrivKeyPath, err := cmd.Flags().GetString("privatekey")
if err != nil {
return initFlags{}, err
}
userPrivKey, userPubKey, err := readOrGenerateVPNKey(fileHandler, userPrivKeyPath)
if err != nil {
return initFlags{}, err
}
autoconfigureWG, err := cmd.Flags().GetBool("wg-autoconfig")
if err != nil {
return initFlags{}, err
}
masterSecretPath, err := cmd.Flags().GetString("master-secret")
if err != nil {
return initFlags{}, err
}
masterSecret, err := readOrGeneratedMasterSecret(cmd.OutOrStdout(), fileHandler, masterSecretPath)
if err != nil {
return initFlags{}, err
}
autoscale, err := cmd.Flags().GetBool("autoscale")
if err != nil {
return initFlags{}, err
}
configPath, err := cmd.Flags().GetString("config")
if err != nil {
return initFlags{}, err
}
return initFlags{
configPath: configPath,
userPrivKey: userPrivKey,
userPubKey: userPubKey,
autoconfigureWG: autoconfigureWG,
autoscale: autoscale,
masterSecret: masterSecret,
}, nil
}
// initFlags are the resulting values of flag preprocessing.
type initFlags struct {
configPath string
userPrivKey []byte
userPubKey []byte
masterSecret []byte
autoconfigureWG bool
autoscale bool
}
func readOrGenerateVPNKey(fileHandler file.Handler, privKeyPath string) (privKey, pubKey []byte, err error) {
var privKeyParsed wgtypes.Key
if privKeyPath == "" {
privKeyParsed, err = wgtypes.GeneratePrivateKey()
if err != nil {
return nil, nil, err
}
privKey = []byte(privKeyParsed.String())
} else {
privKey, err = fileHandler.Read(privKeyPath)
if err != nil {
return nil, nil, err
}
privKeyParsed, err = wgtypes.ParseKey(string(privKey))
if err != nil {
return nil, nil, err
}
}
pubKey = []byte(privKeyParsed.PublicKey().String())
return privKey, pubKey, nil
}
func ipsToEndpoints(ips []string, port string) []string {
var endpoints []string
for _, ip := range ips {
if ip == "" {
continue
}
endpoints = append(endpoints, net.JoinHostPort(ip, port))
}
return endpoints
}
// readOrGeneratedMasterSecret reads a base64 encoded master secret from file or generates a new 32 byte secret.
func readOrGeneratedMasterSecret(w io.Writer, fileHandler file.Handler, filename string) ([]byte, error) {
if filename != "" {
// Try to read the base64 secret from file
encodedSecret, err := fileHandler.Read(filename)
if err != nil {
return nil, err
}
decoded, err := base64.StdEncoding.DecodeString(string(encodedSecret))
if err != nil {
return nil, err
}
if len(decoded) < constants.MasterSecretLengthMin {
return nil, errors.New("provided master secret is smaller than the required minimum of 16 Bytes")
}
return decoded, nil
}
// No file given, generate a new secret, and save it to disk
masterSecret, err := util.GenerateRandomBytes(constants.MasterSecretLengthDefault)
if err != nil {
return nil, err
}
if err := fileHandler.Write(constants.MasterSecretFilename, []byte(base64.StdEncoding.EncodeToString(masterSecret)), file.OptNone); err != nil {
return nil, err
}
fmt.Fprintf(w, "Your Constellation master secret was successfully written to ./%s\n", constants.MasterSecretFilename)
return masterSecret, nil
}
func getScalingGroupsFromConfig(stat state.ConstellationState, config *config.Config) (coordinators, nodes cloudtypes.ScalingGroup, err error) {
switch {
case len(stat.GCPCoordinators) != 0:
return getGCPInstances(stat, config)
case len(stat.AzureCoordinators) != 0:
return getAzureInstances(stat, config)
case len(stat.QEMUCoordinators) != 0:
return getQEMUInstances(stat, config)
default:
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no instances to initialize")
}
}
func getGCPInstances(stat state.ConstellationState, config *config.Config) (coordinators, nodes cloudtypes.ScalingGroup, err error) {
if len(stat.GCPCoordinators) == 0 {
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no control-plane nodes available, can't create Constellation without any instance")
}
// GroupID of coordinators is empty, since they currently do not scale.
coordinators = cloudtypes.ScalingGroup{
Instances: stat.GCPCoordinators,
GroupID: "",
}
if len(stat.GCPNodes) == 0 {
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no worker nodes available, can't create Constellation with one instance")
}
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
nodes = cloudtypes.ScalingGroup{
Instances: stat.GCPNodes,
GroupID: gcp.AutoscalingNodeGroup(stat.GCPProject, stat.GCPZone, stat.GCPNodeInstanceGroup, config.AutoscalingNodeGroupMin, config.AutoscalingNodeGroupMax),
}
return
}
func getAzureInstances(stat state.ConstellationState, config *config.Config) (coordinators, nodes cloudtypes.ScalingGroup, err error) {
if len(stat.AzureCoordinators) == 0 {
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no control-plane nodes available, can't create Constellation cluster without any instance")
}
// GroupID of coordinators is empty, since they currently do not scale.
coordinators = cloudtypes.ScalingGroup{
Instances: stat.AzureCoordinators,
GroupID: "",
}
if len(stat.AzureNodes) == 0 {
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no worker nodes available, can't create Constellation cluster with one instance")
}
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
nodes = cloudtypes.ScalingGroup{
Instances: stat.AzureNodes,
GroupID: azure.AutoscalingNodeGroup(stat.AzureNodesScaleSet, config.AutoscalingNodeGroupMin, config.AutoscalingNodeGroupMax),
}
return
}
func getQEMUInstances(stat state.ConstellationState, config *config.Config) (coordinators, nodes cloudtypes.ScalingGroup, err error) {
coordinatorMap := stat.QEMUCoordinators
if len(coordinatorMap) == 0 {
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no coordinators available, can't create Constellation without any instance")
}
// QEMU does not support autoscaling
coordinators = cloudtypes.ScalingGroup{
Instances: stat.QEMUCoordinators,
GroupID: "",
}
if len(stat.QEMUNodes) == 0 {
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no nodes available, can't create Constellation with one instance")
}
// QEMU does not support autoscaling
nodes = cloudtypes.ScalingGroup{
Instances: stat.QEMUNodes,
GroupID: "",
}
return
}
// initCompletion handels the completion of CLI arguments. It is frequently called
// while the user types arguments of the command to suggest completion.
func initCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
if len(args) != 0 {
return []string{}, cobra.ShellCompDirectiveError
}
return []string{}, cobra.ShellCompDirectiveDefault
}

View file

@ -1,662 +0,0 @@
package cmd
import (
"bytes"
"context"
"encoding/base64"
"errors"
"strconv"
"strings"
"testing"
"time"
"github.com/edgelesssys/constellation/cli/cloud/cloudtypes"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/state"
wgquick "github.com/nmiculinic/wg-quick-go"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestInitArgumentValidation(t *testing.T) {
assert := assert.New(t)
cmd := newInitCmd()
assert.NoError(cmd.ValidateArgs(nil))
assert.Error(cmd.ValidateArgs([]string{"something"}))
assert.Error(cmd.ValidateArgs([]string{"sth", "sth"}))
}
func TestInitialize(t *testing.T) {
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
testGcpState := state.ConstellationState{
CloudProvider: "GCP",
GCPNodes: cloudtypes.Instances{
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
GCPCoordinators: cloudtypes.Instances{
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
}
testAzureState := state.ConstellationState{
CloudProvider: "Azure",
AzureNodes: cloudtypes.Instances{
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureCoordinators: cloudtypes.Instances{
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureResourceGroup: "test",
}
testQemuState := state.ConstellationState{
CloudProvider: "QEMU",
QEMUNodes: cloudtypes.Instances{
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
QEMUCoordinators: cloudtypes.Instances{
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
}
testActivationResps := []fakeActivationRespMessage{
{log: "testlog1"},
{log: "testlog2"},
{
kubeconfig: "kubeconfig",
clientVpnIp: "192.0.2.2",
coordinatorVpnKey: testKey,
ownerID: "ownerID",
clusterID: "clusterID",
},
{log: "testlog3"},
}
someErr := errors.New("failed")
testCases := map[string]struct {
existingState state.ConstellationState
client protoClient
serviceAccountCreator stubServiceAccountCreator
waiter statusWaiter
privKey string
vpnHandler vpnHandler
initVPN bool
wantErr bool
}{
"initialize some gcp instances": {
existingState: testGcpState,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
privKey: testKey,
},
"initialize some azure instances": {
existingState: testAzureState,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
privKey: testKey,
},
"initialize some qemu instances": {
existingState: testQemuState,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
privKey: testKey,
},
"initialize vpn": {
existingState: testAzureState,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{},
initVPN: true,
privKey: testKey,
},
"invalid initialize vpn": {
existingState: testAzureState,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{applyErr: someErr},
initVPN: true,
privKey: testKey,
wantErr: true,
},
"invalid create vpn config": {
existingState: testAzureState,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{createErr: someErr},
initVPN: true,
privKey: testKey,
wantErr: true,
},
"invalid write vpn config": {
existingState: testAzureState,
client: &fakeProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
vpnHandler: &stubVPNHandler{marshalErr: someErr},
initVPN: true,
privKey: testKey,
wantErr: true,
},
"no state exists": {
existingState: state.ConstellationState{},
client: &stubProtoClient{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"no instances to pick one": {
existingState: state.ConstellationState{GCPNodes: cloudtypes.Instances{}},
client: &stubProtoClient{},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"public key to short": {
existingState: testGcpState,
client: &stubProtoClient{},
waiter: &stubStatusWaiter{},
privKey: base64.StdEncoding.EncodeToString([]byte("tooShortKey")),
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"public key to long": {
existingState: testGcpState,
client: &stubProtoClient{},
waiter: &stubStatusWaiter{},
privKey: base64.StdEncoding.EncodeToString([]byte("thisWireguardKeyIsToLongAndHasTooManyBytes")),
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"public key not base64": {
existingState: testGcpState,
client: &stubProtoClient{},
waiter: &stubStatusWaiter{},
privKey: "this is not base64 encoded",
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"fail Connect": {
existingState: testGcpState,
client: &stubProtoClient{connectErr: someErr},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"fail Activate": {
existingState: testGcpState,
client: &stubProtoClient{activateErr: someErr},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"fail respClient WriteLogStream": {
existingState: testGcpState,
client: &stubProtoClient{respClient: &stubActivationRespClient{writeLogStreamErr: someErr}},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"fail respClient getKubeconfig": {
existingState: testGcpState,
client: &stubProtoClient{respClient: &stubActivationRespClient{getKubeconfigErr: someErr}},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"fail respClient getCoordinatorVpnKey": {
existingState: testGcpState,
client: &stubProtoClient{respClient: &stubActivationRespClient{getCoordinatorVpnKeyErr: someErr}},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"fail respClient getClientVpnIp": {
existingState: testGcpState,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClientVpnIpErr: someErr}},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"fail respClient getOwnerID": {
existingState: testGcpState,
client: &stubProtoClient{respClient: &stubActivationRespClient{getOwnerIDErr: someErr}},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"fail respClient getClusterID": {
existingState: testGcpState,
client: &stubProtoClient{respClient: &stubActivationRespClient{getClusterIDErr: someErr}},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"fail to wait for required status": {
existingState: testGcpState,
client: &stubProtoClient{},
waiter: &stubStatusWaiter{waitForAllErr: someErr},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
"fail to create service account": {
existingState: testGcpState,
client: &stubProtoClient{},
serviceAccountCreator: stubServiceAccountCreator{createErr: someErr},
waiter: &stubStatusWaiter{},
privKey: testKey,
vpnHandler: &stubVPNHandler{},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newInitCmd()
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
cmd.Flags().String("config", "", "") // register persisten flag manually
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
// Write key file to filesystem and set path in flag.
require.NoError(afero.Afero{Fs: fs}.WriteFile("privK", []byte(tc.privKey), 0o600))
require.NoError(cmd.Flags().Set("privatekey", "privK"))
if tc.initVPN {
require.NoError(cmd.Flags().Set("wg-autoconfig", "true"))
}
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, 4*time.Second)
defer cancel()
err := initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, tc.vpnHandler)
if tc.wantErr {
assert.Error(err)
} else {
require.NoError(err)
assert.Equal(tc.initVPN, tc.vpnHandler.(*stubVPNHandler).configured)
assert.Contains(out.String(), "192.0.2.2")
assert.Contains(out.String(), "ownerID")
assert.Contains(out.String(), "clusterID")
}
})
}
}
func TestWriteOutput(t *testing.T) {
assert := assert.New(t)
result := activationResult{
clientVpnIP: "foo-qq",
coordinatorPubKey: "bar-qq",
coordinatorPubIP: "baz-qq",
kubeconfig: "foo-bar-baz-qq",
}
var out bytes.Buffer
testFs := afero.NewMemMapFs()
fileHandler := file.NewHandler(testFs)
err := result.writeOutput(&out, fileHandler)
assert.NoError(err)
assert.Contains(out.String(), result.clientVpnIP)
assert.Contains(out.String(), result.coordinatorPubIP)
assert.Contains(out.String(), result.coordinatorPubKey)
afs := afero.Afero{Fs: testFs}
adminConf, err := afs.ReadFile(constants.AdminConfFilename)
assert.NoError(err)
assert.Equal(result.kubeconfig, string(adminConf))
}
func TestIpsToEndpoints(t *testing.T) {
assert := assert.New(t)
ips := []string{"192.0.2.1", "192.0.2.2", "", "192.0.2.3"}
port := "8080"
endpoints := ipsToEndpoints(ips, port)
assert.Equal([]string{"192.0.2.1:8080", "192.0.2.2:8080", "192.0.2.3:8080"}, endpoints)
}
func TestInitCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
toComplete string
wantResult []string
wantShellCD cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
toComplete: "hello",
wantResult: []string{},
wantShellCD: cobra.ShellCompDirectiveDefault,
},
"secnod arg": {
args: []string{"23"},
toComplete: "/test/h",
wantResult: []string{},
wantShellCD: cobra.ShellCompDirectiveError,
},
"third arg": {
args: []string{"./file", "sth"},
toComplete: "./file",
wantResult: []string{},
wantShellCD: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := initCompletion(cmd, tc.args, tc.toComplete)
assert.Equal(tc.wantResult, result)
assert.Equal(tc.wantShellCD, shellCD)
})
}
}
func TestReadOrGenerateVPNKey(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
testKey := []byte(base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting")))
fileHandler := file.NewHandler(afero.NewMemMapFs())
require.NoError(fileHandler.Write("testKey", testKey, file.OptNone))
privK, pubK, err := readOrGenerateVPNKey(fileHandler, "testKey")
assert.NoError(err)
assert.Equal(testKey, privK)
assert.NotEmpty(pubK)
// no path provided
privK, pubK, err = readOrGenerateVPNKey(fileHandler, "")
assert.NoError(err)
assert.NotEmpty(privK)
assert.NotEmpty(pubK)
}
func TestReadOrGeneratedMasterSecret(t *testing.T) {
testCases := map[string]struct {
filename string
filecontent string
createFile bool
fs func() afero.Fs
wantErr bool
}{
"file with secret exists": {
filename: "someSecret",
filecontent: base64.StdEncoding.EncodeToString([]byte("ConstellationSecret")),
createFile: true,
fs: afero.NewMemMapFs,
wantErr: false,
},
"no file given": {
filename: "",
filecontent: "",
fs: afero.NewMemMapFs,
wantErr: false,
},
"file does not exist": {
filename: "nonExistingSecret",
filecontent: "",
createFile: false,
fs: afero.NewMemMapFs,
wantErr: true,
},
"file is empty": {
filename: "emptySecret",
filecontent: "",
createFile: true,
fs: afero.NewMemMapFs,
wantErr: true,
},
"secret too short": {
filename: "shortSecret",
filecontent: base64.StdEncoding.EncodeToString([]byte("short")),
createFile: true,
fs: afero.NewMemMapFs,
wantErr: true,
},
"secret not encoded": {
filename: "unencodedSecret",
filecontent: "Constellation",
createFile: true,
fs: afero.NewMemMapFs,
wantErr: true,
},
"file not writeable": {
filename: "",
filecontent: "",
createFile: false,
fs: func() afero.Fs { return afero.NewReadOnlyFs(afero.NewMemMapFs()) },
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fileHandler := file.NewHandler(tc.fs())
if tc.createFile {
require.NoError(fileHandler.Write(tc.filename, []byte(tc.filecontent), file.OptNone))
}
var out bytes.Buffer
secret, err := readOrGeneratedMasterSecret(&out, fileHandler, tc.filename)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
if tc.filename == "" {
require.Contains(out.String(), constants.MasterSecretFilename)
filename := strings.Split(out.String(), "./")
tc.filename = strings.Trim(filename[1], "\n")
}
content, err := fileHandler.Read(tc.filename)
require.NoError(err)
assert.Equal(content, []byte(base64.StdEncoding.EncodeToString(secret)))
}
})
}
}
func TestAutoscaleFlag(t *testing.T) {
testKey := base64.StdEncoding.EncodeToString([]byte("32bytesWireGuardKeyForTheTesting"))
testGcpState := state.ConstellationState{
CloudProvider: "gcp",
GCPNodes: cloudtypes.Instances{
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
GCPCoordinators: cloudtypes.Instances{
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
}
testAzureState := state.ConstellationState{
CloudProvider: "azure",
AzureNodes: cloudtypes.Instances{
"id-0": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
"id-1": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureCoordinators: cloudtypes.Instances{
"id-c": {PrivateIP: "192.0.2.1", PublicIP: "192.0.2.1"},
},
AzureResourceGroup: "test",
}
testActivationResps := []fakeActivationRespMessage{
{log: "testlog1"},
{log: "testlog2"},
{
kubeconfig: "kubeconfig",
clientVpnIp: "192.0.2.2",
coordinatorVpnKey: testKey,
ownerID: "ownerID",
clusterID: "clusterID",
},
{log: "testlog3"},
}
testCases := map[string]struct {
autoscaleFlag bool
existingState state.ConstellationState
client *stubProtoClient
serviceAccountCreator stubServiceAccountCreator
waiter statusWaiter
privKey string
}{
"initialize some gcp instances without autoscale flag": {
autoscaleFlag: false,
existingState: testGcpState,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some azure instances without autoscale flag": {
autoscaleFlag: false,
existingState: testAzureState,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some gcp instances with autoscale flag": {
autoscaleFlag: true,
existingState: testGcpState,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
"initialize some azure instances with autoscale flag": {
autoscaleFlag: true,
existingState: testAzureState,
client: &stubProtoClient{
respClient: &fakeActivationRespClient{responses: testActivationResps},
},
waiter: &stubStatusWaiter{},
privKey: testKey,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newInitCmd()
var out bytes.Buffer
cmd.SetOut(&out)
var errOut bytes.Buffer
cmd.SetErr(&errOut)
cmd.Flags().String("config", "", "") // register persisten flag manually
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
vpnHandler := stubVPNHandler{}
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
// Write key file to filesystem and set path in flag.
require.NoError(afero.Afero{Fs: fs}.WriteFile("privK", []byte(tc.privKey), 0o600))
require.NoError(cmd.Flags().Set("privatekey", "privK"))
require.NoError(cmd.Flags().Set("autoscale", strconv.FormatBool(tc.autoscaleFlag)))
ctx := context.Background()
require.NoError(initialize(ctx, cmd, tc.client, &tc.serviceAccountCreator, fileHandler, tc.waiter, &vpnHandler))
if tc.autoscaleFlag {
assert.Len(tc.client.activateAutoscalingNodeGroups, 1)
} else {
assert.Len(tc.client.activateAutoscalingNodeGroups, 0)
}
})
}
}
func TestWriteWGQuickFile(t *testing.T) {
testCases := map[string]struct {
fileHandler file.Handler
vpnHandler *stubVPNHandler
vpnConfig *wgquick.Config
wantErr bool
}{
"write wg quick file": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
vpnHandler: &stubVPNHandler{marshalRes: "config"},
},
"marshal failed": {
fileHandler: file.NewHandler(afero.NewMemMapFs()),
vpnHandler: &stubVPNHandler{marshalErr: errors.New("some err")},
wantErr: true,
},
"write fails": {
fileHandler: file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())),
vpnHandler: &stubVPNHandler{marshalRes: "config"},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
err := writeWGQuickFile(tc.fileHandler, tc.vpnHandler, tc.vpnConfig)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
file, err := tc.fileHandler.Read(constants.WGQuickConfigFilename)
assert.NoError(err)
assert.Contains(string(file), tc.vpnHandler.marshalRes)
}
})
}
}

View file

@ -1,7 +0,0 @@
package cmd
import "strings"
func formatInstanceTypes(types []string) string {
return " " + strings.Join(types, "\n ")
}

View file

@ -1,17 +0,0 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/cli/internal/proto"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state"
)
type protoClient interface {
Connect(endpoint string, validators []atls.Validator) error
Close() error
GetState(ctx context.Context) (state.State, error)
Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs, autoscalingNodeGroups []string, cloudServiceAccountURI string, sshUsers []*pubproto.SSHUserKey) (proto.ActivationResponseClient, error)
}

View file

@ -1,225 +0,0 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"github.com/edgelesssys/constellation/cli/internal/proto"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/pubapi/pubproto"
"github.com/edgelesssys/constellation/coordinator/state"
)
type stubProtoClient struct {
conn bool
respClient proto.ActivationResponseClient
connectErr error
closeErr error
getStateErr error
activateErr error
getStateState state.State
activateUserPublicKey []byte
activateMasterSecret []byte
activateNodeIPs []string
activateCoordinatorIPs []string
activateAutoscalingNodeGroups []string
cloudServiceAccountURI string
sshUserKeys []*pubproto.SSHUserKey
}
func (c *stubProtoClient) Connect(_ string, _ []atls.Validator) error {
c.conn = true
return c.connectErr
}
func (c *stubProtoClient) Close() error {
c.conn = false
return c.closeErr
}
func (c *stubProtoClient) GetState(_ context.Context) (state.State, error) {
return c.getStateState, c.getStateErr
}
func (c *stubProtoClient) Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs []string, autoscalingNodeGroups []string, cloudServiceAccountURI string, sshUserKeys []*pubproto.SSHUserKey) (proto.ActivationResponseClient, error) {
c.activateUserPublicKey = userPublicKey
c.activateMasterSecret = masterSecret
c.activateNodeIPs = nodeIPs
c.activateCoordinatorIPs = coordinatorIPs
c.activateAutoscalingNodeGroups = autoscalingNodeGroups
c.cloudServiceAccountURI = cloudServiceAccountURI
c.sshUserKeys = sshUserKeys
return c.respClient, c.activateErr
}
func (c *stubProtoClient) ActivateAdditionalCoordinators(ctx context.Context, ips []string) error {
return c.activateErr
}
type stubActivationRespClient struct {
nextLogErr *error
getKubeconfigErr error
getCoordinatorVpnKeyErr error
getClientVpnIpErr error
getOwnerIDErr error
getClusterIDErr error
writeLogStreamErr error
}
func (s *stubActivationRespClient) NextLog() (string, error) {
if s.nextLogErr == nil {
return "", io.EOF
}
return "", *s.nextLogErr
}
func (s *stubActivationRespClient) WriteLogStream(io.Writer) error {
return s.writeLogStreamErr
}
func (s *stubActivationRespClient) GetKubeconfig() (string, error) {
return "", s.getKubeconfigErr
}
func (s *stubActivationRespClient) GetCoordinatorVpnKey() (string, error) {
return "", s.getCoordinatorVpnKeyErr
}
func (s *stubActivationRespClient) GetClientVpnIp() (string, error) {
return "", s.getClientVpnIpErr
}
func (s *stubActivationRespClient) GetOwnerID() (string, error) {
return "", s.getOwnerIDErr
}
func (s *stubActivationRespClient) GetClusterID() (string, error) {
return "", s.getClusterIDErr
}
type fakeProtoClient struct {
conn bool
respClient proto.ActivationResponseClient
}
func (c *fakeProtoClient) Connect(endpoint string, validators []atls.Validator) error {
if endpoint == "" {
return errors.New("endpoint is empty")
}
if len(validators) == 0 {
return errors.New("validators is empty")
}
c.conn = true
return nil
}
func (c *fakeProtoClient) Close() error {
c.conn = false
return nil
}
func (c *fakeProtoClient) GetState(_ context.Context) (state.State, error) {
if !c.conn {
return state.Uninitialized, errors.New("client is not connected")
}
return state.IsNode, nil
}
func (c *fakeProtoClient) Activate(ctx context.Context, userPublicKey, masterSecret []byte, nodeIPs, coordinatorIPs, autoscalingNodeGroups []string, cloudServiceAccountURI string, sshUserKeys []*pubproto.SSHUserKey) (proto.ActivationResponseClient, error) {
if !c.conn {
return nil, errors.New("client is not connected")
}
return c.respClient, nil
}
func (c *fakeProtoClient) ActivateAdditionalCoordinators(ctx context.Context, ips []string) error {
if !c.conn {
return errors.New("client is not connected")
}
return nil
}
type fakeActivationRespClient struct {
responses []fakeActivationRespMessage
kubeconfig string
coordinatorVpnKey string
clientVpnIp string
ownerID string
clusterID string
}
func (c *fakeActivationRespClient) NextLog() (string, error) {
for len(c.responses) > 0 {
resp := c.responses[0]
c.responses = c.responses[1:]
if len(resp.log) > 0 {
return resp.log, nil
}
c.kubeconfig = resp.kubeconfig
c.coordinatorVpnKey = resp.coordinatorVpnKey
c.clientVpnIp = resp.clientVpnIp
c.ownerID = resp.ownerID
c.clusterID = resp.clusterID
}
return "", io.EOF
}
func (c *fakeActivationRespClient) WriteLogStream(w io.Writer) error {
log, err := c.NextLog()
for err == nil {
fmt.Fprint(w, log)
log, err = c.NextLog()
}
if !errors.Is(err, io.EOF) {
return err
}
return nil
}
func (c *fakeActivationRespClient) GetKubeconfig() (string, error) {
if c.kubeconfig == "" {
return "", errors.New("kubeconfig is empty")
}
return c.kubeconfig, nil
}
func (c *fakeActivationRespClient) GetCoordinatorVpnKey() (string, error) {
if c.coordinatorVpnKey == "" {
return "", errors.New("control-plane public VPN key is empty")
}
return c.coordinatorVpnKey, nil
}
func (c *fakeActivationRespClient) GetClientVpnIp() (string, error) {
if c.clientVpnIp == "" {
return "", errors.New("client VPN IP is empty")
}
return c.clientVpnIp, nil
}
func (c *fakeActivationRespClient) GetOwnerID() (string, error) {
if c.ownerID == "" {
return "", errors.New("init secret is empty")
}
return c.ownerID, nil
}
func (c *fakeActivationRespClient) GetClusterID() (string, error) {
if c.clusterID == "" {
return "", errors.New("cluster identifier is empty")
}
return c.clusterID, nil
}
type fakeActivationRespMessage struct {
log string
kubeconfig string
coordinatorVpnKey string
clientVpnIp string
ownerID string
clusterID string
}

View file

@ -1,43 +0,0 @@
package cmd
import (
"errors"
"fmt"
"io"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/file"
)
func readConfig(out io.Writer, fileHandler file.Handler, name string, provider cloudprovider.Provider) (*config.Config, error) {
cnf, err := config.FromFile(fileHandler, name)
if err != nil {
return nil, err
}
if err := validateConfig(out, cnf, provider); err != nil {
return nil, err
}
return cnf, nil
}
func validateConfig(out io.Writer, cnf *config.Config, provider cloudprovider.Provider) error {
msgs, err := cnf.Validate()
if err != nil {
return err
}
if len(msgs) > 0 {
fmt.Fprintln(out, "Invalid fields in config file:")
for _, m := range msgs {
fmt.Fprintln(out, "\t"+m)
}
return errors.New("invalid configuration")
}
if provider != cloudprovider.Unknown && !cnf.HasProvider(provider) {
return fmt.Errorf("configuration doesn't contain provider: %v", provider)
}
return nil
}

View file

@ -1,78 +0,0 @@
package cmd
import (
"bytes"
"testing"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/config"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestValidateConfig(t *testing.T) {
testCases := map[string]struct {
cnf *config.Config
provider cloudprovider.Provider
wantOutput bool
wantErr bool
}{
"default config is valid": {
cnf: config.Default(),
},
"config with an error": {
cnf: func() *config.Config {
cnf := config.Default()
cnf.Version = "v0"
return cnf
}(),
wantOutput: true,
wantErr: true,
},
"config without provider is ok if no provider required": {
cnf: func() *config.Config {
cnf := config.Default()
cnf.Provider = config.ProviderConfig{}
return cnf
}(),
},
"config with only required provider": {
cnf: func() *config.Config {
cnf := config.Default()
az := cnf.Provider.Azure
cnf.Provider = config.ProviderConfig{}
cnf.Provider.Azure = az
return cnf
}(),
provider: cloudprovider.Azure,
},
"config without required provider": {
cnf: func() *config.Config {
cnf := config.Default()
cnf.Provider.Azure = nil
return cnf
}(),
provider: cloudprovider.Azure,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
out := &bytes.Buffer{}
err := validateConfig(out, tc.cnf, tc.provider)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantOutput, out.Len() > 0)
})
}
}

View file

@ -1,153 +0,0 @@
package cmd
import (
"context"
"encoding/base64"
"errors"
"regexp"
"strings"
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
"github.com/edgelesssys/constellation/cli/internal/proto"
"github.com/edgelesssys/constellation/coordinator/util"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
var diskUUIDRegexp = regexp.MustCompile("^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
func newRecoverCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "recover",
Short: "Recover a completely stopped Constellation cluster",
Long: "Recover a Constellation cluster by sending a recovery key to an instance in the boot stage." +
"\nThis is only required if instances restart without other instances available for bootstrapping.",
Args: cobra.ExactArgs(0),
RunE: runRecover,
}
cmd.Flags().StringP("endpoint", "e", "", "endpoint of the instance, passed as HOST[:PORT] (required)")
must(cmd.MarkFlagRequired("endpoint"))
cmd.Flags().String("disk-uuid", "", "disk UUID of the encrypted state disk (required)")
must(cmd.MarkFlagRequired("disk-uuid"))
cmd.Flags().String("master-secret", constants.MasterSecretFilename, "path to base64-encoded master secret")
return cmd
}
func runRecover(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
recoveryClient := &proto.KeyClient{}
defer recoveryClient.Close()
return recover(cmd.Context(), cmd, fileHandler, recoveryClient)
}
func recover(ctx context.Context, cmd *cobra.Command, fileHandler file.Handler, recoveryClient recoveryClient) error {
flags, err := parseRecoverFlags(cmd, fileHandler)
if err != nil {
return err
}
var stat state.ConstellationState
if err := fileHandler.ReadJSON(constants.StateFilename, &stat); err != nil {
return err
}
provider := cloudprovider.FromString(stat.CloudProvider)
config, err := readConfig(cmd.OutOrStdout(), fileHandler, flags.configPath, provider)
if err != nil {
return err
}
validators, err := cloudcmd.NewValidators(provider, config)
if err != nil {
return err
}
cmd.Print(validators.WarningsIncludeInit())
if err := recoveryClient.Connect(flags.endpoint, validators.V()); err != nil {
return err
}
diskKey, err := deriveStateDiskKey(flags.masterSecret, flags.diskUUID)
if err != nil {
return err
}
if err := recoveryClient.PushStateDiskKey(ctx, diskKey); err != nil {
return err
}
cmd.Println("Pushed recovery key.")
return nil
}
func parseRecoverFlags(cmd *cobra.Command, fileHandler file.Handler) (recoverFlags, error) {
endpoint, err := cmd.Flags().GetString("endpoint")
if err != nil {
return recoverFlags{}, err
}
endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort)
if err != nil {
return recoverFlags{}, err
}
diskUUID, err := cmd.Flags().GetString("disk-uuid")
if err != nil {
return recoverFlags{}, err
}
if match := diskUUIDRegexp.MatchString(diskUUID); !match {
return recoverFlags{}, errors.New("flag '--disk-uuid' isn't a valid LUKS UUID")
}
diskUUID = strings.ToLower(diskUUID)
masterSecretPath, err := cmd.Flags().GetString("master-secret")
if err != nil {
return recoverFlags{}, err
}
masterSecret, err := readMasterSecret(fileHandler, masterSecretPath)
if err != nil {
return recoverFlags{}, err
}
configPath, err := cmd.Flags().GetString("config")
if err != nil {
return recoverFlags{}, err
}
return recoverFlags{
endpoint: endpoint,
diskUUID: diskUUID,
masterSecret: masterSecret,
configPath: configPath,
}, nil
}
type recoverFlags struct {
endpoint string
diskUUID string
masterSecret []byte
configPath string
}
// readMasterSecret reads a base64 encoded master secret from file.
func readMasterSecret(fileHandler file.Handler, filename string) ([]byte, error) {
// Try to read the base64 secret from file
encodedSecret, err := fileHandler.Read(filename)
if err != nil {
return nil, err
}
decoded, err := base64.StdEncoding.DecodeString(string(encodedSecret))
if err != nil {
return nil, err
}
return decoded, nil
}
// deriveStateDiskKey derives a state disk key from a master secret and a disk UUID.
func deriveStateDiskKey(masterKey []byte, diskUUID string) ([]byte, error) {
return util.DeriveKey(masterKey, []byte("Constellation"), []byte("key"+diskUUID), constants.StateDiskKeyLength)
}

View file

@ -1,356 +0,0 @@
package cmd
import (
"bytes"
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRecoverCmdArgumentValidation(t *testing.T) {
testCases := map[string]struct {
args []string
wantErr bool
}{
"no args": {[]string{}, false},
"too many arguments": {[]string{"abc"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := newRecoverCmd()
err := cmd.ValidateArgs(tc.args)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestRecover(t *testing.T) {
validState := state.ConstellationState{CloudProvider: "GCP"}
invalidCSPState := state.ConstellationState{CloudProvider: "invalid"}
testCases := map[string]struct {
setupFs func(*require.Assertions) afero.Fs
existingState state.ConstellationState
client *stubRecoveryClient
endpointFlag string
diskUUIDFlag string
masterSecretFlag string
configFlag string
stateless bool
wantErr bool
wantKey []byte
}{
"works": {
setupFs: func(require *require.Assertions) afero.Fs {
fs := afero.NewMemMapFs()
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
return fs
},
existingState: validState,
client: &stubRecoveryClient{},
endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
wantKey: []byte{0x2e, 0x4d, 0x40, 0x3a, 0x90, 0x96, 0x6e, 0xd, 0x42, 0x3, 0x98, 0xd, 0xce, 0xc5, 0x73, 0x26, 0xf4, 0x87, 0xcf, 0x85, 0x73, 0xe1, 0xb7, 0xd6, 0xb2, 0x82, 0x4c, 0xd9, 0xbc, 0xa5, 0x7c, 0x32},
},
"uppercase disk uuid works": {
setupFs: func(require *require.Assertions) afero.Fs {
fs := afero.NewMemMapFs()
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
return fs
},
existingState: validState,
client: &stubRecoveryClient{},
endpointFlag: "192.0.2.1",
diskUUIDFlag: "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF",
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
},
"lowercase disk uuid results in same key": {
setupFs: func(require *require.Assertions) afero.Fs {
fs := afero.NewMemMapFs()
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
return fs
},
existingState: validState,
client: &stubRecoveryClient{},
endpointFlag: "192.0.2.1",
diskUUIDFlag: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
wantKey: []byte{0xa9, 0x4, 0x3a, 0x74, 0x53, 0xeb, 0x23, 0xb2, 0xbc, 0x88, 0xce, 0xa7, 0x4e, 0xa9, 0xda, 0x9f, 0x11, 0x85, 0xc4, 0x2f, 0x1f, 0x25, 0x10, 0xc9, 0xec, 0xfe, 0xa, 0x6c, 0xa2, 0x6f, 0x53, 0x34},
},
"missing flags": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
wantErr: true,
},
"missing config": {
setupFs: func(require *require.Assertions) afero.Fs {
fs := afero.NewMemMapFs()
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
return fs
},
endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
configFlag: "nonexistent-config",
wantErr: true,
},
"missing state": {
setupFs: func(require *require.Assertions) afero.Fs {
fs := afero.NewMemMapFs()
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
return fs
},
existingState: validState,
endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
stateless: true,
wantErr: true,
},
"invalid cloud provider": {
setupFs: func(require *require.Assertions) afero.Fs {
fs := afero.NewMemMapFs()
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
return fs
},
existingState: invalidCSPState,
endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
wantErr: true,
},
"connect fails": {
setupFs: func(require *require.Assertions) afero.Fs {
fs := afero.NewMemMapFs()
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
return fs
},
existingState: validState,
client: &stubRecoveryClient{connectErr: errors.New("connect failed")},
endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
wantErr: true,
},
"pushing state key fails": {
setupFs: func(require *require.Assertions) afero.Fs {
fs := afero.NewMemMapFs()
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
return fs
},
existingState: validState,
client: &stubRecoveryClient{pushStateDiskKeyErr: errors.New("pushing key failed")},
endpointFlag: "192.0.2.1",
diskUUIDFlag: "00000000-0000-0000-0000-000000000000",
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newRecoverCmd()
cmd.Flags().String("config", "", "") // register persisten flag manually
out := &bytes.Buffer{}
cmd.SetOut(out)
cmd.SetErr(&bytes.Buffer{})
if tc.endpointFlag != "" {
require.NoError(cmd.Flags().Set("endpoint", tc.endpointFlag))
}
if tc.diskUUIDFlag != "" {
require.NoError(cmd.Flags().Set("disk-uuid", tc.diskUUIDFlag))
}
if tc.masterSecretFlag != "" {
require.NoError(cmd.Flags().Set("master-secret", tc.masterSecretFlag))
}
if tc.configFlag != "" {
require.NoError(cmd.Flags().Set("config", tc.configFlag))
}
fileHandler := file.NewHandler(tc.setupFs(require))
if !tc.stateless {
require.NoError(fileHandler.WriteJSON(constants.StateFilename, tc.existingState, file.OptNone))
}
ctx := context.Background()
err := recover(ctx, cmd, fileHandler, tc.client)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Contains(out.String(), "Pushed recovery key.")
assert.Equal(tc.wantKey, tc.client.pushStateDiskKeyKey)
})
}
}
func TestParseRecoverFlags(t *testing.T) {
testCases := map[string]struct {
args []string
wantFlags recoverFlags
wantErr bool
}{
"no flags": {
wantErr: true,
},
"invalid ip": {
args: []string{"-e", "192.0.2.1:2:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
wantErr: true,
},
"invalid disk uuid": {
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "invalid"},
wantErr: true,
},
"invalid master secret path": {
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012", "--master-secret", "invalid"},
wantErr: true,
},
"minimal args set": {
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012"},
wantFlags: recoverFlags{
endpoint: "192.0.2.1:2",
diskUUID: "12345678-1234-1234-1234-123456789012",
masterSecret: []byte("constellation-master-secret-leng"),
},
},
"all args set": {
args: []string{
"-e", "192.0.2.1:2", "--disk-uuid", "12345678-1234-1234-1234-123456789012",
"--master-secret", "constellation-mastersecret.base64", "--config", "config-path",
},
wantFlags: recoverFlags{
endpoint: "192.0.2.1:2",
diskUUID: "12345678-1234-1234-1234-123456789012",
masterSecret: []byte("constellation-master-secret-leng"),
configPath: "config-path",
},
},
"uppercase disk-uuid is converted to lowercase": {
args: []string{"-e", "192.0.2.1:2", "--disk-uuid", "ABCDEFAB-CDEF-ABCD-ABCD-ABCDEFABCDEF"},
wantFlags: recoverFlags{
endpoint: "192.0.2.1:2",
diskUUID: "abcdefab-cdef-abcd-abcd-abcdefabcdef",
masterSecret: []byte("constellation-master-secret-leng"),
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fs := afero.NewMemMapFs()
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="), 0o777))
cmd := newRecoverCmd()
cmd.Flags().String("config", "", "") // register persistent flag manually
require.NoError(cmd.ParseFlags(tc.args))
flags, err := parseRecoverFlags(cmd, file.NewHandler(fs))
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantFlags, flags)
})
}
}
func TestReadMasterSecret(t *testing.T) {
testCases := map[string]struct {
fileContents []byte
filename string
wantMasterSecret []byte
wantErr bool
}{
"invalid base64": {
fileContents: []byte("invalid"),
filename: "constellation-mastersecret.base64",
wantErr: true,
},
"invalid filename": {
fileContents: []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="),
filename: "invalid",
wantErr: true,
},
"correct master secret": {
fileContents: []byte("Y29uc3RlbGxhdGlvbi1tYXN0ZXItc2VjcmV0LWxlbmc="),
filename: "constellation-mastersecret.base64",
wantMasterSecret: []byte("constellation-master-secret-leng"),
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fs := afero.NewMemMapFs()
require.NoError(afero.WriteFile(fs, "constellation-mastersecret.base64", tc.fileContents, 0o777))
masterSecret, err := readMasterSecret(file.NewHandler(fs), tc.filename)
if tc.wantErr {
assert.Error(err)
return
}
assert.NoError(err)
assert.Equal(tc.wantMasterSecret, masterSecret)
})
}
}
func TestDeriveStateDiskKey(t *testing.T) {
testCases := map[string]struct {
masterKey []byte
diskUUID string
wantStateDiskKey []byte
}{
"all zero": {
masterKey: []byte{
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0,
},
diskUUID: "00000000-0000-0000-0000-000000000000",
wantStateDiskKey: []byte{
0xa8, 0xb0, 0x86, 0x83, 0x6f, 0x0b, 0x26, 0x04, 0x86, 0x22, 0x27, 0xcc, 0xa1, 0x1c, 0xaf, 0x6c,
0x30, 0x4d, 0x90, 0x89, 0x82, 0x68, 0x53, 0x7f, 0x4f, 0x46, 0x7a, 0x65, 0xa2, 0x5d, 0x5e, 0x43,
},
},
"all 0xff": {
masterKey: []byte{
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
},
diskUUID: "ffffffff-ffff-ffff-ffff-ffffffffffff",
wantStateDiskKey: []byte{
0x24, 0x18, 0x84, 0x7f, 0xca, 0x86, 0x55, 0xb5, 0x45, 0xa6, 0xb3, 0xc4, 0x45, 0xbb, 0x08, 0x10,
0x16, 0xb3, 0xde, 0x30, 0x30, 0x74, 0x0b, 0xd4, 0x1e, 0x22, 0x55, 0x45, 0x51, 0x91, 0xfb, 0xa9,
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
stateDiskKey, err := deriveStateDiskKey(tc.masterKey, tc.diskUUID)
assert.NoError(err)
assert.Equal(tc.wantStateDiskKey, stateDiskKey)
})
}
}

View file

@ -1,14 +0,0 @@
package cmd
import (
"context"
"io"
"github.com/edgelesssys/constellation/coordinator/atls"
)
type recoveryClient interface {
Connect(endpoint string, validators []atls.Validator) error
PushStateDiskKey(ctx context.Context, stateDiskKey []byte) error
io.Closer
}

View file

@ -1,31 +0,0 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/coordinator/atls"
)
type stubRecoveryClient struct {
conn bool
connectErr error
closeErr error
pushStateDiskKeyErr error
pushStateDiskKeyKey []byte
}
func (c *stubRecoveryClient) Connect(_ string, _ []atls.Validator) error {
c.conn = true
return c.connectErr
}
func (c *stubRecoveryClient) Close() error {
c.conn = false
return c.closeErr
}
func (c *stubRecoveryClient) PushStateDiskKey(_ context.Context, stateDiskKey []byte) error {
c.pushStateDiskKeyKey = stateDiskKey
return c.pushStateDiskKeyErr
}

View file

@ -6,6 +6,7 @@ import (
"os"
"os/signal"
"github.com/edgelesssys/constellation/cli/internal/cmd"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/spf13/cobra"
)
@ -34,13 +35,13 @@ func NewRootCmd() *cobra.Command {
rootCmd.PersistentFlags().String("config", constants.ConfigFilename, "path to the configuration file")
must(rootCmd.MarkPersistentFlagFilename("config", "json"))
rootCmd.AddCommand(newConfigCmd())
rootCmd.AddCommand(newCreateCmd())
rootCmd.AddCommand(newInitCmd())
rootCmd.AddCommand(newVerifyCmd())
rootCmd.AddCommand(newRecoverCmd())
rootCmd.AddCommand(newTerminateCmd())
rootCmd.AddCommand(newVersionCmd())
rootCmd.AddCommand(cmd.NewConfigCmd())
rootCmd.AddCommand(cmd.NewCreateCmd())
rootCmd.AddCommand(cmd.NewInitCmd())
rootCmd.AddCommand(cmd.NewVerifyCmd())
rootCmd.AddCommand(cmd.NewRecoverCmd())
rootCmd.AddCommand(cmd.NewTerminateCmd())
rootCmd.AddCommand(cmd.NewVersionCmd())
return rootCmd
}

View file

@ -1,13 +0,0 @@
package cmd
import (
"context"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/state"
)
type statusWaiter interface {
InitializeValidators([]atls.Validator) error
WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error
}

View file

@ -1,27 +0,0 @@
package cmd
import (
"context"
"errors"
"github.com/edgelesssys/constellation/coordinator/atls"
"github.com/edgelesssys/constellation/coordinator/state"
)
type stubStatusWaiter struct {
initialized bool
initializeErr error
waitForAllErr error
}
func (s *stubStatusWaiter) InitializeValidators([]atls.Validator) error {
s.initialized = true
return s.initializeErr
}
func (s *stubStatusWaiter) WaitForAll(ctx context.Context, endpoints []string, status ...state.State) error {
if !s.initialized {
return errors.New("waiter not initialized")
}
return s.waitForAllErr
}

View file

@ -1,65 +0,0 @@
package cmd
import (
"errors"
"fmt"
"io/fs"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"go.uber.org/multierr"
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/state"
)
func newTerminateCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "terminate",
Short: "Terminate a Constellation cluster",
Long: "Terminate a Constellation cluster. The cluster can't be started again, and all persistent storage will be lost.",
Args: cobra.NoArgs,
RunE: runTerminate,
}
return cmd
}
// runTerminate runs the terminate command.
func runTerminate(cmd *cobra.Command, args []string) error {
fileHandler := file.NewHandler(afero.NewOsFs())
terminator := cloudcmd.NewTerminator()
return terminate(cmd, terminator, fileHandler)
}
func terminate(cmd *cobra.Command, terminator cloudTerminator, fileHandler file.Handler) error {
var stat state.ConstellationState
if err := fileHandler.ReadJSON(constants.StateFilename, &stat); err != nil {
return err
}
cmd.Println("Terminating ...")
if err := terminator.Terminate(cmd.Context(), stat); err != nil {
return err
}
cmd.Println("Your Constellation cluster was terminated successfully.")
var retErr error
if err := fileHandler.Remove(constants.StateFilename); err != nil {
retErr = multierr.Append(err, fmt.Errorf("failed to remove file '%s', please remove manually", constants.StateFilename))
}
if err := fileHandler.Remove(constants.AdminConfFilename); err != nil && !errors.Is(err, fs.ErrNotExist) {
retErr = multierr.Append(err, fmt.Errorf("failed to remove file '%s', please remove manually", constants.AdminConfFilename))
}
if err := fileHandler.Remove(constants.WGQuickConfigFilename); err != nil && !errors.Is(err, fs.ErrNotExist) {
retErr = multierr.Append(err, fmt.Errorf("failed to remove file '%s', please remove manually", constants.WGQuickConfigFilename))
}
return retErr
}

View file

@ -1,136 +0,0 @@
package cmd
import (
"bytes"
"errors"
"testing"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/edgelesssys/constellation/internal/state"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTerminateCmdArgumentValidation(t *testing.T) {
testCases := map[string]struct {
args []string
wantErr bool
}{
"no args": {[]string{}, false},
"some args": {[]string{"hello", "test"}, true},
"some other args": {[]string{"12", "2"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := newTerminateCmd()
err := cmd.ValidateArgs(tc.args)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestTerminate(t *testing.T) {
setupFs := func(require *require.Assertions, state state.ConstellationState) afero.Fs {
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.Write(constants.AdminConfFilename, []byte{1, 2}, file.OptNone))
require.NoError(fileHandler.Write(constants.WGQuickConfigFilename, []byte{1, 2}, file.OptNone))
require.NoError(fileHandler.WriteJSON(constants.StateFilename, state, file.OptNone))
return fs
}
someErr := errors.New("failed")
testCases := map[string]struct {
state state.ConstellationState
setupFs func(*require.Assertions, state.ConstellationState) afero.Fs
terminator spyCloudTerminator
wantErr bool
}{
"success": {
state: state.ConstellationState{CloudProvider: "gcp"},
setupFs: setupFs,
terminator: &stubCloudTerminator{},
},
"files to remove do not exist": {
state: state.ConstellationState{CloudProvider: "gcp"},
setupFs: func(require *require.Assertions, state state.ConstellationState) afero.Fs {
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.WriteJSON(constants.StateFilename, state, file.OptNone))
return fs
},
terminator: &stubCloudTerminator{},
},
"terminate error": {
state: state.ConstellationState{CloudProvider: "gcp"},
setupFs: setupFs,
terminator: &stubCloudTerminator{terminateErr: someErr},
wantErr: true,
},
"missing state file": {
state: state.ConstellationState{CloudProvider: "gcp"},
setupFs: func(require *require.Assertions, state state.ConstellationState) afero.Fs {
fs := afero.NewMemMapFs()
fileHandler := file.NewHandler(fs)
require.NoError(fileHandler.Write(constants.AdminConfFilename, []byte{1, 2}, file.OptNone))
require.NoError(fileHandler.Write(constants.WGQuickConfigFilename, []byte{1, 2}, file.OptNone))
return fs
},
terminator: &stubCloudTerminator{},
wantErr: true,
},
"remove file fails": {
state: state.ConstellationState{CloudProvider: "gcp"},
setupFs: func(require *require.Assertions, state state.ConstellationState) afero.Fs {
fs := setupFs(require, state)
return afero.NewReadOnlyFs(fs)
},
terminator: &stubCloudTerminator{},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newTerminateCmd()
cmd.SetOut(&bytes.Buffer{})
cmd.SetErr(&bytes.Buffer{})
require.NotNil(tc.setupFs)
fileHandler := file.NewHandler(tc.setupFs(require, tc.state))
err := terminate(cmd, tc.terminator, fileHandler)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.True(tc.terminator.Called())
_, err := fileHandler.Stat(constants.StateFilename)
assert.Error(err)
_, err = fileHandler.Stat(constants.AdminConfFilename)
assert.Error(err)
_, err = fileHandler.Stat(constants.WGQuickConfigFilename)
assert.Error(err)
}
})
}
}
type spyCloudTerminator interface {
cloudTerminator
Called() bool
}

View file

@ -1,35 +0,0 @@
package cmd
import (
"bufio"
"errors"
"strings"
"github.com/spf13/cobra"
)
// ErrInvalidInput is an error where user entered invalid input.
var ErrInvalidInput = errors.New("user made invalid input")
// askToConfirm asks user to confirm an action.
// The user will be asked the handed question and can answer with
// yes or no.
func askToConfirm(cmd *cobra.Command, question string) (bool, error) {
reader := bufio.NewReader(cmd.InOrStdin())
cmd.Printf("%s [y/n]: ", question)
for i := 0; i < 3; i++ {
resp, err := reader.ReadString('\n')
if err != nil {
return false, err
}
resp = strings.ToLower(strings.TrimSpace(resp))
if resp == "n" || resp == "no" {
return false, nil
}
if resp == "y" || resp == "yes" {
return true, nil
}
cmd.Printf("Type 'y' or 'yes' to confirm, or abort action with 'n' or 'no': ")
}
return false, ErrInvalidInput
}

View file

@ -1,63 +0,0 @@
package cmd
import (
"bytes"
"errors"
"io"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
)
func TestAskToConfirm(t *testing.T) {
// errAborted is an error where the user aborted the action.
errAborted := errors.New("user aborted")
cmd := &cobra.Command{
Use: "test",
Args: cobra.NoArgs,
RunE: func(cmd *cobra.Command, args []string) error {
ok, err := askToConfirm(cmd, "777")
if err != nil {
return err
}
if !ok {
return errAborted
}
return nil
},
}
testCases := map[string]struct {
input string
wantErr error
}{
"user confirms": {"y\n", nil},
"user confirms long": {"yes\n", nil},
"user disagrees": {"n\n", errAborted},
"user disagrees long": {"no\n", errAborted},
"user is first unsure, but agrees": {"what?\ny\n", nil},
"user is first unsure, but disagrees": {"wait.\nn\n", errAborted},
"repeated invalid input": {"h\nb\nq\n", ErrInvalidInput},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
out := &bytes.Buffer{}
cmd.SetOut(out)
cmd.SetErr(&bytes.Buffer{})
in := bytes.NewBufferString(tc.input)
cmd.SetIn(in)
err := cmd.Execute()
assert.ErrorIs(err, tc.wantErr)
output, err := io.ReadAll(out)
assert.NoError(err)
assert.Contains(string(output), "777")
})
}
}

View file

@ -1,71 +0,0 @@
package cmd
import (
"errors"
"fmt"
"net"
"strconv"
"strings"
"github.com/edgelesssys/constellation/cli/internal/azure"
"github.com/edgelesssys/constellation/cli/internal/gcp"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/spf13/cobra"
)
// warnAWS warns that AWS isn't supported.
func warnAWS(providerPos int) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
if cloudprovider.FromString(args[providerPos]) == cloudprovider.AWS {
return errors.New("AWS isn't supported by this version of Constellation")
}
return nil
}
}
func isCloudProvider(arg int) cobra.PositionalArgs {
return func(cmd *cobra.Command, args []string) error {
if provider := cloudprovider.FromString(args[arg]); provider == cloudprovider.Unknown {
return fmt.Errorf("argument %s isn't a valid cloud provider", args[arg])
}
return nil
}
}
func validInstanceTypeForProvider(cmd *cobra.Command, insType string, provider cloudprovider.Provider) error {
switch provider {
case cloudprovider.GCP:
for _, instanceType := range gcp.InstanceTypes {
if insType == instanceType {
return nil
}
}
cmd.SetUsageTemplate("GCP instance types:\n" + formatInstanceTypes(gcp.InstanceTypes))
cmd.SilenceUsage = false
return fmt.Errorf("%s isn't a valid GCP instance type", insType)
case cloudprovider.Azure:
for _, instanceType := range azure.InstanceTypes {
if insType == instanceType {
return nil
}
}
cmd.SetUsageTemplate("Azure instance types:\n" + formatInstanceTypes(azure.InstanceTypes))
cmd.SilenceUsage = false
return fmt.Errorf("%s isn't a valid Azure instance type", insType)
default:
return fmt.Errorf("%s isn't a valid cloud platform", provider)
}
}
func validateEndpoint(endpoint string, defaultPort int) (string, error) {
_, _, err := net.SplitHostPort(endpoint)
if err == nil {
return endpoint, nil
}
if strings.Contains(err.Error(), "missing port in address") {
return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil
}
return "", err
}

View file

@ -1,90 +0,0 @@
package cmd
import (
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsCloudProvider(t *testing.T) {
testCases := map[string]struct {
pos int
args []string
wantErr bool
}{
"gcp": {0, []string{"gcp"}, false},
"azure": {1, []string{"foo", "azure"}, false},
"foo": {0, []string{"foo"}, true},
"empty": {0, []string{""}, true},
"unknown": {0, []string{"unknown"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
testCmd := &cobra.Command{Args: isCloudProvider(tc.pos)}
err := testCmd.ValidateArgs(tc.args)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestValidateEndpoint(t *testing.T) {
testCases := map[string]struct {
endpoint string
defaultPort int
wantResult string
wantErr bool
}{
"ip and port": {
endpoint: "192.0.2.1:2",
defaultPort: 3,
wantResult: "192.0.2.1:2",
},
"hostname and port": {
endpoint: "foo:2",
defaultPort: 3,
wantResult: "foo:2",
},
"ip": {
endpoint: "192.0.2.1",
defaultPort: 3,
wantResult: "192.0.2.1:3",
},
"hostname": {
endpoint: "foo",
defaultPort: 3,
wantResult: "foo:3",
},
"invalid endpoint": {
endpoint: "foo:2:2",
defaultPort: 3,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
res, err := validateEndpoint(tc.endpoint, tc.defaultPort)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantResult, res)
})
}
}

View file

@ -1,133 +0,0 @@
package cmd
import (
"context"
"errors"
"fmt"
"github.com/edgelesssys/constellation/cli/cloud/cloudcmd"
"github.com/edgelesssys/constellation/cli/internal/proto"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/spf13/afero"
"github.com/spf13/cobra"
rpcStatus "google.golang.org/grpc/status"
)
func newVerifyCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "verify {aws|azure|gcp}",
Short: "Verify the confidential properties of a Constellation cluster",
Long: "Verify the confidential properties of a Constellation cluster.",
Args: cobra.MatchAll(
cobra.ExactArgs(1),
isCloudProvider(0),
warnAWS(0),
),
RunE: runVerify,
}
cmd.Flags().String("owner-id", "", "verify using the owner identity derived from the master secret")
cmd.Flags().String("unique-id", "", "verify using the unique cluster identity")
cmd.Flags().StringP("node-endpoint", "e", "", "endpoint of the node to verify, passed as HOST[:PORT] (required)")
must(cmd.MarkFlagRequired("node-endpoint"))
return cmd
}
func runVerify(cmd *cobra.Command, args []string) error {
provider := cloudprovider.FromString(args[0])
fileHandler := file.NewHandler(afero.NewOsFs())
protoClient := &proto.Client{}
defer protoClient.Close()
return verify(cmd.Context(), cmd, provider, fileHandler, protoClient)
}
func verify(ctx context.Context, cmd *cobra.Command, provider cloudprovider.Provider, fileHandler file.Handler, protoClient protoClient) error {
flags, err := parseVerifyFlags(cmd)
if err != nil {
return err
}
config, err := readConfig(cmd.OutOrStdout(), fileHandler, flags.configPath, provider)
if err != nil {
return err
}
validators, err := cloudcmd.NewValidators(provider, config)
if err != nil {
return err
}
if err := validators.UpdateInitPCRs(flags.ownerID, flags.clusterID); err != nil {
return err
}
if validators.Warnings() != "" {
cmd.Print(validators.Warnings())
}
if err := protoClient.Connect(flags.endpoint, validators.V()); err != nil {
return err
}
if _, err := protoClient.GetState(ctx); err != nil {
if err, ok := rpcStatus.FromError(err); ok {
return fmt.Errorf("unable to verify Constellation cluster: %s", err.Message())
}
return err
}
cmd.Println("OK")
return nil
}
func parseVerifyFlags(cmd *cobra.Command) (verifyFlags, error) {
ownerID, err := cmd.Flags().GetString("owner-id")
if err != nil {
return verifyFlags{}, err
}
clusterID, err := cmd.Flags().GetString("unique-id")
if err != nil {
return verifyFlags{}, err
}
if ownerID == "" && clusterID == "" {
return verifyFlags{}, errors.New("neither owner ID nor unique ID provided to verify the cluster")
}
endpoint, err := cmd.Flags().GetString("node-endpoint")
if err != nil {
return verifyFlags{}, err
}
endpoint, err = validateEndpoint(endpoint, constants.CoordinatorPort)
if err != nil {
return verifyFlags{}, err
}
configPath, err := cmd.Flags().GetString("config")
if err != nil {
return verifyFlags{}, err
}
return verifyFlags{
endpoint: endpoint,
configPath: configPath,
ownerID: ownerID,
clusterID: clusterID,
}, nil
}
type verifyFlags struct {
endpoint string
ownerID string
clusterID string
configPath string
}
// verifyCompletion handels the completion of CLI arguments. It is frequently called
// while the user types arguments of the command to suggest completion.
func verifyCompletion(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
switch len(args) {
case 0:
return []string{"gcp", "azure"}, cobra.ShellCompDirectiveNoFileComp
default:
return []string{}, cobra.ShellCompDirectiveError
}
}

View file

@ -1,197 +0,0 @@
package cmd
import (
"bytes"
"context"
"encoding/base64"
"errors"
"testing"
"github.com/edgelesssys/constellation/internal/cloud/cloudprovider"
"github.com/edgelesssys/constellation/internal/file"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
rpcStatus "google.golang.org/grpc/status"
)
func TestVerifyCmdArgumentValidation(t *testing.T) {
testCases := map[string]struct {
args []string
wantErr bool
}{
"no args": {[]string{}, true},
"valid azure": {[]string{"azure"}, false},
"valid gcp": {[]string{"gcp"}, false},
"invalid provider": {[]string{"invalid", "192.0.2.1", "1234"}, true},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := newVerifyCmd()
err := cmd.ValidateArgs(tc.args)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
})
}
}
func TestVerify(t *testing.T) {
zeroBase64 := base64.StdEncoding.EncodeToString([]byte("00000000000000000000000000000000"))
someErr := errors.New("failed")
testCases := map[string]struct {
setupFs func(*require.Assertions) afero.Fs
provider cloudprovider.Provider
protoClient protoClient
nodeEndpointFlag string
configFlag string
ownerIDFlag string
clusterIDFlag string
wantErr bool
}{
"gcp": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
provider: cloudprovider.GCP,
nodeEndpointFlag: "192.0.2.1:1234",
ownerIDFlag: zeroBase64,
protoClient: &stubProtoClient{},
},
"azure": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
provider: cloudprovider.Azure,
nodeEndpointFlag: "192.0.2.1:1234",
ownerIDFlag: zeroBase64,
protoClient: &stubProtoClient{},
},
"default port": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
provider: cloudprovider.GCP,
nodeEndpointFlag: "192.0.2.1",
ownerIDFlag: zeroBase64,
protoClient: &stubProtoClient{},
},
"invalid endpoint": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
provider: cloudprovider.GCP,
nodeEndpointFlag: ":::::",
ownerIDFlag: zeroBase64,
protoClient: &stubProtoClient{},
wantErr: true,
},
"neither owner id nor cluster id set": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
provider: cloudprovider.GCP,
nodeEndpointFlag: "192.0.2.1:1234",
wantErr: true,
},
"config file not existing": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
provider: cloudprovider.GCP,
ownerIDFlag: zeroBase64,
nodeEndpointFlag: "192.0.2.1:1234",
configFlag: "./file",
wantErr: true,
},
"error protoClient Connect": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
provider: cloudprovider.Azure,
nodeEndpointFlag: "192.0.2.1:1234",
ownerIDFlag: zeroBase64,
protoClient: &stubProtoClient{connectErr: someErr},
wantErr: true,
},
"error protoClient GetState": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
provider: cloudprovider.Azure,
nodeEndpointFlag: "192.0.2.1:1234",
ownerIDFlag: zeroBase64,
protoClient: &stubProtoClient{getStateErr: rpcStatus.Error(codes.Internal, "failed")},
wantErr: true,
},
"error protoClient GetState not rpc": {
setupFs: func(require *require.Assertions) afero.Fs { return afero.NewMemMapFs() },
provider: cloudprovider.Azure,
nodeEndpointFlag: "192.0.2.1:1234",
ownerIDFlag: zeroBase64,
protoClient: &stubProtoClient{getStateErr: someErr},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
cmd := newVerifyCmd()
cmd.Flags().String("config", "", "") // register persisten flag manually
out := &bytes.Buffer{}
cmd.SetOut(out)
cmd.SetErr(&bytes.Buffer{})
if tc.configFlag != "" {
require.NoError(cmd.Flags().Set("config", tc.configFlag))
}
if tc.ownerIDFlag != "" {
require.NoError(cmd.Flags().Set("owner-id", tc.ownerIDFlag))
}
if tc.clusterIDFlag != "" {
require.NoError(cmd.Flags().Set("cluster-id", tc.clusterIDFlag))
}
if tc.nodeEndpointFlag != "" {
require.NoError(cmd.Flags().Set("node-endpoint", tc.nodeEndpointFlag))
}
fileHandler := file.NewHandler(tc.setupFs(require))
ctx := context.Background()
err := verify(ctx, cmd, tc.provider, fileHandler, tc.protoClient)
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Contains(out.String(), "OK")
}
})
}
}
func TestVerifyCompletion(t *testing.T) {
testCases := map[string]struct {
args []string
toComplete string
wantResult []string
wantShellCD cobra.ShellCompDirective
}{
"first arg": {
args: []string{},
toComplete: "az",
wantResult: []string{"gcp", "azure"},
wantShellCD: cobra.ShellCompDirectiveNoFileComp,
},
"additional arg": {
args: []string{"gcp", "foo"},
wantResult: []string{},
wantShellCD: cobra.ShellCompDirectiveError,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
cmd := &cobra.Command{}
result, shellCD := verifyCompletion(cmd, tc.args, tc.toComplete)
assert.Equal(tc.wantResult, result)
assert.Equal(tc.wantShellCD, shellCD)
})
}
}

View file

@ -1,19 +0,0 @@
package cmd
import (
"github.com/edgelesssys/constellation/internal/constants"
"github.com/spf13/cobra"
)
func newVersionCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "version",
Short: "Display version of this CLI",
Long: "Display version of this CLI.",
Args: cobra.NoArgs,
Run: func(cmd *cobra.Command, args []string) {
cmd.Printf("CLI Version: v%s \n", constants.CliVersion)
},
}
return cmd
}

View file

@ -1,25 +0,0 @@
package cmd
import (
"bytes"
"io"
"testing"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/stretchr/testify/assert"
)
func TestVersionCmd(t *testing.T) {
assert := assert.New(t)
cmd := newVersionCmd()
b := &bytes.Buffer{}
cmd.SetOut(b)
err := cmd.Execute()
assert.NoError(err)
s, err := io.ReadAll(b)
assert.NoError(err)
assert.Contains(string(s), constants.CliVersion)
}

View file

@ -1,9 +0,0 @@
package cmd
import wgquick "github.com/nmiculinic/wg-quick-go"
type vpnHandler interface {
Create(coordinatorPubKey string, coordinatorPubIP string, clientPrivKey string, clientVPNIP string, mtu int) (*wgquick.Config, error)
Apply(*wgquick.Config) error
Marshal(*wgquick.Config) ([]byte, error)
}

View file

@ -1,25 +0,0 @@
package cmd
import wgquick "github.com/nmiculinic/wg-quick-go"
type stubVPNHandler struct {
configured bool
marshalRes string
createErr error
applyErr error
marshalErr error
}
func (c *stubVPNHandler) Create(coordinatorPubKey string, coordinatorPubIP string, clientPrivKey string, clientVPNIP string, mtu int) (*wgquick.Config, error) {
return &wgquick.Config{}, c.createErr
}
func (c *stubVPNHandler) Apply(*wgquick.Config) error {
c.configured = true
return c.applyErr
}
func (c *stubVPNHandler) Marshal(*wgquick.Config) ([]byte, error) {
return []byte(c.marshalRes), c.marshalErr
}