AB#2327 move debugd code into internal folder (#403)

* move debugd code into internal folder
* Fix paths in CMakeLists.txt
Signed-off-by: Fabian Kammel <fk@edgeless.systems>
Fabian Kammel 2022-08-26 11:58:18 +02:00 committed by GitHub
parent 708c6e057e
commit 5b40e0cc77
25 changed files with 31 additions and 31 deletions


@ -0,0 +1,138 @@
package bootstrapper
import (
"errors"
"fmt"
"io"
"os"
"sync"
pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/schollz/progressbar/v3"
"github.com/spf13/afero"
)
// FileStreamer handles reading and writing of a file using a stream of chunks.
type FileStreamer struct {
fs afero.Fs
mux sync.RWMutex
}
// ReadChunkStream is an abstraction over a gRPC stream that allows us to receive chunks via gRPC.
type ReadChunkStream interface {
Recv() (*pb.Chunk, error)
}
// WriteChunkStream is an abstraction over a gRPC stream that allows us to send chunks via gRPC.
type WriteChunkStream interface {
Send(chunk *pb.Chunk) error
}
// NewFileStreamer creates a new FileStreamer.
func NewFileStreamer(fs afero.Fs) *FileStreamer {
return &FileStreamer{
fs: fs,
mux: sync.RWMutex{},
}
}
// WriteStream opens a file to write to and streams chunks from a gRPC stream into the file.
func (f *FileStreamer) WriteStream(filename string, stream ReadChunkStream, showProgress bool) error {
// try to read from stream once before acquiring write lock
chunk, err := stream.Recv()
if err != nil {
return fmt.Errorf("reading stream: %w", err)
}
f.mux.Lock()
defer f.mux.Unlock()
file, err := f.fs.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0o755)
if err != nil {
return fmt.Errorf("open %v for writing: %w", filename, err)
}
defer file.Close()
var bar *progressbar.ProgressBar
if showProgress {
bar = progressbar.NewOptions64(
-1,
progressbar.OptionSetDescription("receiving bootstrapper"),
progressbar.OptionShowBytes(true),
progressbar.OptionClearOnFinish(),
)
defer bar.Close()
}
for {
if err != nil {
if errors.Is(err, io.EOF) {
break
}
_ = file.Close()
_ = f.fs.Remove(filename)
return fmt.Errorf("reading stream: %w", err)
}
if _, err := file.Write(chunk.Content); err != nil {
_ = file.Close()
_ = f.fs.Remove(filename)
return fmt.Errorf("writing chunk to disk: %w", err)
}
if showProgress {
_ = bar.Add(len(chunk.Content))
}
chunk, err = stream.Recv()
}
return nil
}
// ReadStream opens a file to read from and streams its contents chunkwise over gRPC.
func (f *FileStreamer) ReadStream(filename string, stream WriteChunkStream, chunksize uint, showProgress bool) error {
if chunksize == 0 {
return errors.New("invalid chunksize")
}
// fail if file is currently RW locked
if f.mux.TryRLock() {
defer f.mux.RUnlock()
} else {
return errors.New("file is opened for writing cannot be read at this time")
}
file, err := f.fs.OpenFile(filename, os.O_RDONLY, 0o755)
if err != nil {
return fmt.Errorf("open %v for reading: %w", filename, err)
}
defer file.Close()
var bar *progressbar.ProgressBar
if showProgress {
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("performing stat on %v to get the file size: %w", filename, err)
}
bar = progressbar.NewOptions64(
stat.Size(),
progressbar.OptionSetDescription("uploading bootstrapper"),
progressbar.OptionShowBytes(true),
progressbar.OptionClearOnFinish(),
)
defer bar.Close()
}
buf := make([]byte, chunksize)
for {
n, err := file.Read(buf)
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return fmt.Errorf("reading file chunk: %w", err)
}
if err = stream.Send(&pb.Chunk{Content: buf[:n]}); err != nil {
return fmt.Errorf("sending chunk: %w", err)
}
if showProgress {
_ = bar.Add(n)
}
}
}
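
As an illustration, a minimal client-side usage sketch (not part of this change): ReadStream driven over an UploadBootstrapper gRPC stream. The helper name and file path are placeholders.

package main

import (
	"context"

	"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
	pb "github.com/edgelesssys/constellation/debugd/service"
	"github.com/spf13/afero"
)

// uploadBootstrapper is a hypothetical helper: the client-side upload stream
// returned by UploadBootstrapper implements Send(*pb.Chunk) error and thereby
// satisfies bootstrapper.WriteChunkStream.
func uploadBootstrapper(ctx context.Context, client pb.DebugdClient) error {
	stream, err := client.UploadBootstrapper(ctx)
	if err != nil {
		return err
	}
	streamer := bootstrapper.NewFileStreamer(afero.NewOsFs())
	// Stream the local binary in 1 KiB chunks and render a progress bar.
	if err := streamer.ReadStream("./bootstrapper", stream, 1024, true); err != nil {
		return err
	}
	_, err = stream.CloseAndRecv()
	return err
}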


@ -0,0 +1,205 @@
package bootstrapper
import (
"errors"
"io"
"testing"
pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestWriteStream(t *testing.T) {
filename := "testfile"
testCases := map[string]struct {
readChunkStream fakeReadChunkStream
fs afero.Fs
showProgress bool
wantFile []byte
wantErr bool
}{
"stream works": {
readChunkStream: fakeReadChunkStream{
chunks: [][]byte{
[]byte("test"),
},
},
fs: afero.NewMemMapFs(),
wantFile: []byte("test"),
wantErr: false,
},
"chunking works": {
readChunkStream: fakeReadChunkStream{
chunks: [][]byte{
[]byte("te"),
[]byte("st"),
},
},
fs: afero.NewMemMapFs(),
wantFile: []byte("test"),
wantErr: false,
},
"showProgress works": {
readChunkStream: fakeReadChunkStream{
chunks: [][]byte{
[]byte("test"),
},
},
fs: afero.NewMemMapFs(),
showProgress: true,
wantFile: []byte("test"),
wantErr: false,
},
"Open fails": {
fs: afero.NewReadOnlyFs(afero.NewMemMapFs()),
wantErr: true,
},
"recv fails": {
readChunkStream: fakeReadChunkStream{
recvErr: errors.New("someErr"),
},
fs: afero.NewMemMapFs(),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
writer := NewFileStreamer(tc.fs)
err := writer.WriteStream(filename, &tc.readChunkStream, tc.showProgress)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
fileContents, err := afero.ReadFile(tc.fs, filename)
require.NoError(err)
assert.Equal(tc.wantFile, fileContents)
})
}
}
func TestReadStream(t *testing.T) {
correctFilename := "testfile"
testCases := map[string]struct {
writeChunkStream stubWriteChunkStream
filename string
chunksize uint
showProgress bool
wantChunks [][]byte
wantErr bool
}{
"stream works": {
writeChunkStream: stubWriteChunkStream{},
filename: correctFilename,
chunksize: 4,
wantChunks: [][]byte{
[]byte("test"),
},
wantErr: false,
},
"chunking works": {
writeChunkStream: stubWriteChunkStream{},
filename: correctFilename,
chunksize: 2,
wantChunks: [][]byte{
[]byte("te"),
[]byte("st"),
},
wantErr: false,
},
"chunksize of 0 detected": {
writeChunkStream: stubWriteChunkStream{},
filename: correctFilename,
chunksize: 0,
wantErr: true,
},
"showProgress works": {
writeChunkStream: stubWriteChunkStream{},
filename: correctFilename,
chunksize: 4,
showProgress: true,
wantChunks: [][]byte{
[]byte("test"),
},
wantErr: false,
},
"Open fails": {
filename: "incorrect-filename",
chunksize: 4,
wantErr: true,
},
"send fails": {
writeChunkStream: stubWriteChunkStream{
sendErr: errors.New("someErr"),
},
filename: correctFilename,
chunksize: 4,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fs := afero.NewMemMapFs()
assert.NoError(afero.WriteFile(fs, correctFilename, []byte("test"), 0o755))
reader := NewFileStreamer(fs)
err := reader.ReadStream(tc.filename, &tc.writeChunkStream, tc.chunksize, tc.showProgress)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantChunks, tc.writeChunkStream.chunks)
})
}
}
type fakeReadChunkStream struct {
chunks [][]byte
pos int
recvErr error
}
func (s *fakeReadChunkStream) Recv() (*pb.Chunk, error) {
if s.recvErr != nil {
return nil, s.recvErr
}
if s.pos < len(s.chunks) {
result := &pb.Chunk{Content: s.chunks[s.pos]}
s.pos++
return result, nil
}
return nil, io.EOF
}
type stubWriteChunkStream struct {
chunks [][]byte
sendErr error
}
func (s *stubWriteChunkStream) Send(chunk *pb.Chunk) error {
cpy := make([]byte, len(chunk.Content))
copy(cpy, chunk.Content)
s.chunks = append(s.chunks, cpy)
return s.sendErr
}


@ -0,0 +1,181 @@
package cmd
import (
"context"
"fmt"
"log"
"net"
"strconv"
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
"github.com/edgelesssys/constellation/debugd/internal/cdbg/config"
"github.com/edgelesssys/constellation/debugd/internal/debugd"
depl "github.com/edgelesssys/constellation/debugd/internal/debugd/deploy"
pb "github.com/edgelesssys/constellation/debugd/service"
configc "github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/file"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
func newDeployCmd() *cobra.Command {
deployCmd := &cobra.Command{
Use: "deploy",
Short: "Deploys a self-compiled bootstrapper binary and SSH keys on the current constellation",
Long: `Deploys a self-compiled bootstrapper binary and SSH keys on the current constellation.
Uses config provided by --config and reads constellation config from its default location.
If required, you can override the IP addresses that are used for a deployment by specifying "--ips" and a list of IP addresses.
Specifying --bootstrapper will upload the bootstrapper from the specified path.`,
RunE: runDeploy,
Example: "cdbg deploy\ncdbg deploy --config /path/to/config\ncdbg deploy --bootstrapper /path/to/bootstrapper --ips 192.0.2.1,192.0.2.2,192.0.2.3 --config /path/to/config",
}
deployCmd.Flags().StringSlice("ips", nil, "override the ips that the bootstrapper will be uploaded to (defaults to ips from constellation config)")
deployCmd.Flags().String("bootstrapper", "", "override the path to the bootstrapper binary uploaded to instances (defaults to path set in config)")
return deployCmd
}
func runDeploy(cmd *cobra.Command, args []string) error {
debugConfigName, err := cmd.Flags().GetString("cdbg-config")
if err != nil {
return err
}
configName, err := cmd.Flags().GetString("config")
if err != nil {
return fmt.Errorf("parsing config path argument: %w", err)
}
fileHandler := file.NewHandler(afero.NewOsFs())
debugConfig, err := config.FromFile(fileHandler, debugConfigName)
if err != nil {
return err
}
constellationConfig, err := configc.FromFile(fileHandler, configName)
if err != nil {
return err
}
return deploy(cmd, fileHandler, constellationConfig, debugConfig, bootstrapper.NewFileStreamer(afero.NewOsFs()))
}
func deploy(cmd *cobra.Command, fileHandler file.Handler, constellationConfig *configc.Config, debugConfig *config.CDBGConfig, reader fileToStreamReader) error {
overrideBootstrapperPath, err := cmd.Flags().GetString("bootstrapper")
if err != nil {
return err
}
if len(overrideBootstrapperPath) > 0 {
debugConfig.ConstellationDebugConfig.BootstrapperPath = overrideBootstrapperPath
}
if !constellationConfig.IsImageDebug() {
log.Println("WARN: constellation image does not look like a debug image. Are you using a debug image?")
}
ips, err := cmd.Flags().GetStringSlice("ips")
if err != nil {
return err
}
if len(ips) == 0 {
var idFile clusterIDsFile
if err := fileHandler.ReadJSON(constants.ClusterIDsFileName, &idFile); err != nil {
return fmt.Errorf("reading cluster IDs file: %w", err)
}
ips = []string{idFile.IP}
}
for _, ip := range ips {
input := deployOnEndpointInput{
debugdEndpoint: net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort)),
bootstrapperPath: debugConfig.ConstellationDebugConfig.BootstrapperPath,
reader: reader,
authorizedKeys: debugConfig.ConstellationDebugConfig.AuthorizedKeys,
systemdUnits: debugConfig.ConstellationDebugConfig.SystemdUnits,
}
if err := deployOnEndpoint(cmd.Context(), input); err != nil {
return err
}
}
return nil
}
type deployOnEndpointInput struct {
debugdEndpoint string
bootstrapperPath string
reader fileToStreamReader
authorizedKeys []configc.UserKey
systemdUnits []depl.SystemdUnit
}
// deployOnEndpoint deploys SSH public keys, systemd units and a locally built bootstrapper binary to a debugd endpoint.
func deployOnEndpoint(ctx context.Context, in deployOnEndpointInput) error {
log.Printf("Deploying on %v\n", in.debugdEndpoint)
dialCTX, cancel := context.WithTimeout(ctx, debugd.GRPCTimeout)
defer cancel()
conn, err := grpc.DialContext(dialCTX, in.debugdEndpoint, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return fmt.Errorf("connecting to other instance via gRPC: %w", err)
}
defer conn.Close()
client := pb.NewDebugdClient(conn)
log.Println("Uploading authorized keys")
pbKeys := []*pb.AuthorizedKey{}
for _, key := range in.authorizedKeys {
pbKeys = append(pbKeys, &pb.AuthorizedKey{
Username: key.Username,
KeyValue: key.PublicKey,
})
}
authorizedKeysResponse, err := client.UploadAuthorizedKeys(ctx, &pb.UploadAuthorizedKeysRequest{Keys: pbKeys}, grpc.WaitForReady(true))
if err != nil || authorizedKeysResponse.Status != pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_SUCCESS {
return fmt.Errorf("uploading bootstrapper to instance %v failed: %v / %w", in.debugdEndpoint, authorizedKeysResponse, err)
}
if len(in.systemdUnits) > 0 {
log.Println("Uploading systemd unit files")
pbUnits := []*pb.ServiceUnit{}
for _, unit := range in.systemdUnits {
pbUnits = append(pbUnits, &pb.ServiceUnit{
Name: unit.Name,
Contents: unit.Contents,
})
}
uploadSystemdServiceUnitsResponse, err := client.UploadSystemServiceUnits(ctx, &pb.UploadSystemdServiceUnitsRequest{Units: pbUnits})
if err != nil || uploadSystemdServiceUnitsResponse.Status != pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_SUCCESS {
return fmt.Errorf("uploading systemd service unit to instance %v failed: %v / %w", in.debugdEndpoint, uploadSystemdServiceUnitsResponse, err)
}
}
stream, err := client.UploadBootstrapper(ctx)
if err != nil {
return fmt.Errorf("starting bootstrapper upload to instance %v: %w", in.debugdEndpoint, err)
}
streamErr := in.reader.ReadStream(in.bootstrapperPath, stream, debugd.Chunksize, true)
uploadResponse, closeErr := stream.CloseAndRecv()
if closeErr != nil {
return fmt.Errorf("closing upload stream after uploading bootstrapper to %v: %w", in.debugdEndpoint, closeErr)
}
if uploadResponse.Status == pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_FILE_EXISTS {
log.Println("Bootstrapper was already uploaded")
return nil
}
if uploadResponse.Status != pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS || streamErr != nil {
return fmt.Errorf("uploading bootstrapper to instance %v failed: %v / %w", in.debugdEndpoint, uploadResponse, streamErr)
}
log.Println("Uploaded bootstrapper")
return nil
}
type fileToStreamReader interface {
ReadStream(filename string, stream bootstrapper.WriteChunkStream, chunksize uint, showProgress bool) error
}
type clusterIDsFile struct {
ClusterID string
OwnerID string
IP string
}


@ -0,0 +1,29 @@
package cmd
import (
"os"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/spf13/cobra"
)
func newRootCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "cdbg",
Short: "Constellation debugging client",
Long: `cdbg is the constellation debugging client.
It connects to CoreOS instances running debugd and deploys a self-compiled version of the bootstrapper.`,
}
cmd.PersistentFlags().String("config", constants.ConfigFilename, "Constellation config file")
cmd.PersistentFlags().String("cdbg-config", constants.DebugdConfigFilename, "debugd config file")
cmd.AddCommand(newDeployCmd())
return cmd
}
// Execute starts the CLI.
func Execute() {
cmd := newRootCmd()
if err := cmd.Execute(); err != nil {
os.Exit(1)
}
}


@ -0,0 +1,35 @@
package config
import (
"errors"
"fmt"
"io/fs"
"github.com/edgelesssys/constellation/debugd/internal/debugd/deploy"
configc "github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/file"
)
// CDBGConfig describes the cdbg config file.
type CDBGConfig struct {
ConstellationDebugConfig ConstellationDebugdConfig `yaml:"cdbg"`
}
// ConstellationDebugdConfig is the cdbg specific configuration.
type ConstellationDebugdConfig struct {
AuthorizedKeys []configc.UserKey `yaml:"authorizedKeys"`
BootstrapperPath string `yaml:"bootstrapperPath"`
SystemdUnits []deploy.SystemdUnit `yaml:"systemdUnits,omitempty"`
}
// FromFile reads a debug configuration.
func FromFile(fileHandler file.Handler, name string) (*CDBGConfig, error) {
conf := &CDBGConfig{}
if err := fileHandler.ReadYAML(name, conf); err != nil {
if errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("%s not found - consult the README on how to setup cdbg", name)
}
return nil, fmt.Errorf("loading config from file %s: %w", name, err)
}
return conf, nil
}
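
To make the structure above concrete, here is a hedged sketch of what a matching cdbg config file could look like and how FromFile loads it; the authorizedKeys field names are assumptions based on the internal config package, not taken from this diff.

package main

import (
	"fmt"

	"github.com/edgelesssys/constellation/debugd/internal/cdbg/config"
	"github.com/edgelesssys/constellation/internal/file"
	"github.com/spf13/afero"
)

func main() {
	// Hypothetical cdbg-conf.yaml; the authorizedKeys field names (username,
	// publicKey) are assumptions based on the internal config package's yaml tags.
	conf := `cdbg:
  authorizedKeys:
    - username: bob
      publicKey: ssh-rsa AAAA... bob@example.com
  bootstrapperPath: ./bootstrapper
`
	fs := afero.NewMemMapFs()
	if err := afero.WriteFile(fs, "cdbg-conf.yaml", []byte(conf), 0o644); err != nil {
		panic(err)
	}
	parsed, err := config.FromFile(file.NewHandler(fs), "cdbg-conf.yaml")
	if err != nil {
		panic(err)
	}
	fmt.Println(parsed.ConstellationDebugConfig.BootstrapperPath)
}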


@ -0,0 +1,78 @@
package state
import (
"errors"
"github.com/edgelesssys/constellation/internal/cloud/cloudtypes"
"github.com/edgelesssys/constellation/internal/config"
"github.com/edgelesssys/constellation/internal/state"
)
// Code in this file is mostly copied from constellation-controlPlane
// TODO: import as package from controlPlane once it is properly refactored
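// GetScalingGroupsFromConfig returns the control-plane and worker scaling groups of the constellation described by the given state.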
func GetScalingGroupsFromConfig(stat state.ConstellationState, config *config.Config) (controlPlanes, workers cloudtypes.ScalingGroup, err error) {
switch {
case len(stat.GCPControlPlaneInstances) != 0:
return getGCPInstances(stat, config)
case len(stat.AzureControlPlaneInstances) != 0:
return getAzureInstances(stat, config)
case len(stat.QEMUControlPlaneInstances) != 0:
return getQEMUInstances(stat, config)
default:
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no instances to init")
}
}
func getGCPInstances(stat state.ConstellationState, _ *config.Config) (controlPlanes, workers cloudtypes.ScalingGroup, err error) {
if len(stat.GCPControlPlaneInstances) == 0 {
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no control-plane workers available, can't create Constellation without any instance")
}
// GroupID of controlPlanes is empty, since they currently do not scale.
controlPlanes = cloudtypes.ScalingGroup{Instances: stat.GCPControlPlaneInstances}
if len(stat.GCPWorkerInstances) == 0 {
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no worker workers available, can't create Constellation with one instance")
}
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
workers = cloudtypes.ScalingGroup{Instances: stat.GCPWorkerInstances}
return
}
func getAzureInstances(stat state.ConstellationState, _ *config.Config) (controlPlanes, workers cloudtypes.ScalingGroup, err error) {
if len(stat.AzureControlPlaneInstances) == 0 {
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no control-plane workers available, can't create Constellation cluster without any instance")
}
// GroupID of controlPlanes is empty, since they currently do not scale.
controlPlanes = cloudtypes.ScalingGroup{Instances: stat.AzureControlPlaneInstances}
if len(stat.AzureWorkerInstances) == 0 {
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no worker workers available, can't create Constellation cluster with one instance")
}
// TODO: make min / max configurable and abstract autoscaling for different cloud providers
workers = cloudtypes.ScalingGroup{Instances: stat.AzureWorkerInstances}
return
}
func getQEMUInstances(stat state.ConstellationState, _ *config.Config) (controlPlanes, workers cloudtypes.ScalingGroup, err error) {
controlPlaneMap := stat.QEMUControlPlaneInstances
if len(controlPlaneMap) == 0 {
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no controlPlanes available, can't create Constellation without any instance")
}
// QEMU does not support autoscaling
controlPlanes = cloudtypes.ScalingGroup{Instances: stat.QEMUControlPlaneInstances}
if len(stat.QEMUWorkerInstances) == 0 {
return cloudtypes.ScalingGroup{}, cloudtypes.ScalingGroup{}, errors.New("no workers available, can't create Constellation with one instance")
}
// QEMU does not support autoscaling
workers = cloudtypes.ScalingGroup{Instances: stat.QEMUWorkerInstances}
return
}


@ -0,0 +1,31 @@
package debugd
import "time"
const (
DebugdMetadataFlag = "constellation-debugd"
GRPCTimeout = 5 * time.Minute
SSHCheckInterval = 30 * time.Second
DiscoverDebugdInterval = 30 * time.Second
BootstrapperDownloadRetryBackoff = 1 * time.Minute
BootstrapperDeployFilename = "/opt/bootstrapper"
Chunksize = 1024
BootstrapperSystemdUnitName = "bootstrapper.service"
BootstrapperSystemdUnitContents = `[Unit]
Description=Constellation Bootstrapper
Wants=network-online.target
After=network-online.target
[Service]
Type=simple
RemainAfterExit=yes
Restart=on-failure
EnvironmentFile=/etc/constellation.env
ExecStartPre=-setenforce Permissive
ExecStartPre=/usr/bin/mkdir -p /opt/cni/bin/
# merge all CNI binaries in writable folder until containerd can use multiple CNI bins: https://github.com/containerd/containerd/issues/6600
ExecStartPre=/bin/sh -c "/usr/bin/cp /usr/libexec/cni/* /opt/cni/bin/"
ExecStart=/opt/bootstrapper
[Install]
WantedBy=multi-user.target
`
)


@ -0,0 +1,117 @@
package deploy
import (
"context"
"fmt"
"net"
"strconv"
"time"
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
"github.com/edgelesssys/constellation/debugd/internal/debugd"
pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/logger"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// Download downloads a bootstrapper from a given debugd instance.
type Download struct {
log *logger.Logger
dialer NetDialer
writer streamToFileWriter
serviceManager serviceManager
attemptedDownloads map[string]time.Time
}
// New creates a new Download.
func New(log *logger.Logger, dialer NetDialer, serviceManager serviceManager, writer streamToFileWriter) *Download {
return &Download{
log: log,
dialer: dialer,
writer: writer,
serviceManager: serviceManager,
attemptedDownloads: map[string]time.Time{},
}
}
// DownloadDeployment will open a new grpc connection to another instance, attempting to download a bootstrapper from that instance.
func (d *Download) DownloadDeployment(ctx context.Context, ip string) ([]ssh.UserKey, error) {
log := d.log.With(zap.String("ip", ip))
serverAddr := net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort))
// only retry download from same endpoint after backoff
if lastAttempt, ok := d.attemptedDownloads[serverAddr]; ok && time.Since(lastAttempt) < debugd.BootstrapperDownloadRetryBackoff {
return nil, fmt.Errorf("download failed too recently: %v / %v", time.Since(lastAttempt), debugd.BootstrapperDownloadRetryBackoff)
}
log.Infof("Connecting to server")
d.attemptedDownloads[serverAddr] = time.Now()
conn, err := d.dial(ctx, serverAddr)
if err != nil {
return nil, fmt.Errorf("connecting to other instance via gRPC: %w", err)
}
defer conn.Close()
client := pb.NewDebugdClient(conn)
log.Infof("Trying to download bootstrapper")
stream, err := client.DownloadBootstrapper(ctx, &pb.DownloadBootstrapperRequest{})
if err != nil {
return nil, fmt.Errorf("starting bootstrapper download from other instance: %w", err)
}
if err := d.writer.WriteStream(debugd.BootstrapperDeployFilename, stream, true); err != nil {
return nil, fmt.Errorf("streaming bootstrapper from other instance: %w", err)
}
log.Infof("Successfully downloaded bootstrapper")
log.Infof("Trying to download ssh keys")
resp, err := client.DownloadAuthorizedKeys(ctx, &pb.DownloadAuthorizedKeysRequest{})
if err != nil {
return nil, fmt.Errorf("downloading authorized keys: %w", err)
}
var keys []ssh.UserKey
for _, key := range resp.Keys {
keys = append(keys, ssh.UserKey{Username: key.Username, PublicKey: key.KeyValue})
}
// after the upload succeeds, try to restart the bootstrapper
restartAction := ServiceManagerRequest{
Unit: debugd.BootstrapperSystemdUnitName,
Action: Restart,
}
if err := d.serviceManager.SystemdAction(ctx, restartAction); err != nil {
return nil, fmt.Errorf("restarting bootstrapper: %w", err)
}
return keys, nil
}
func (d *Download) dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, target,
d.grpcWithDialer(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
}
func (d *Download) grpcWithDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return d.dialer.DialContext(ctx, "tcp", addr)
})
}
type serviceManager interface {
SystemdAction(ctx context.Context, request ServiceManagerRequest) error
}
type streamToFileWriter interface {
WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error
}
// NetDialer can open a net.Conn.
type NetDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
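
A hedged wiring sketch follows, showing how Download could be constructed from the concrete dependencies defined in this package and the bootstrapper package; the helper name is a placeholder.

package main

import (
	"context"
	"net"

	"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
	"github.com/edgelesssys/constellation/debugd/internal/debugd/deploy"
	"github.com/edgelesssys/constellation/internal/deploy/ssh"
	"github.com/edgelesssys/constellation/internal/logger"
	"github.com/spf13/afero"
)

// fetchFromPeer is a hypothetical helper name; the dependencies mirror the
// concrete types used elsewhere in debugd.
func fetchFromPeer(ctx context.Context, log *logger.Logger, ip string) ([]ssh.UserKey, error) {
	download := deploy.New(
		log,
		&net.Dialer{},                 // *net.Dialer satisfies NetDialer
		deploy.NewServiceManager(log), // restarts the bootstrapper unit after the download
		bootstrapper.NewFileStreamer(afero.NewOsFs()), // writes received chunks to disk
	)
	return download.DownloadDeployment(ctx, ip)
}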


@ -0,0 +1,187 @@
package deploy
import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"testing"
"time"
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
"github.com/edgelesssys/constellation/debugd/internal/debugd"
pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"google.golang.org/grpc"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
}
func TestDownloadBootstrapper(t *testing.T) {
filename := "/opt/bootstrapper"
someErr := errors.New("failed")
testCases := map[string]struct {
server fakeDownloadServer
serviceManager stubServiceManager
attemptedDownloads map[string]time.Time
wantChunks [][]byte
wantDownloadErr bool
wantFile bool
wantSystemdAction bool
wantDeployed bool
wantKeys []ssh.UserKey
}{
"download works": {
server: fakeDownloadServer{
chunks: [][]byte{[]byte("test")},
keys: []*pb.AuthorizedKey{{Username: "name", KeyValue: "key"}},
},
attemptedDownloads: map[string]time.Time{},
wantChunks: [][]byte{[]byte("test")},
wantDownloadErr: false,
wantFile: true,
wantSystemdAction: true,
wantDeployed: true,
wantKeys: []ssh.UserKey{{Username: "name", PublicKey: "key"}},
},
"second download is not attempted twice": {
server: fakeDownloadServer{chunks: [][]byte{[]byte("test")}},
attemptedDownloads: map[string]time.Time{"192.0.2.0:" + strconv.Itoa(constants.DebugdPort): time.Now()},
wantDownloadErr: true,
},
"download rpc call error is detected": {
server: fakeDownloadServer{downloadErr: someErr},
attemptedDownloads: map[string]time.Time{},
wantDownloadErr: true,
},
"download key error": {
server: fakeDownloadServer{
chunks: [][]byte{[]byte("test")},
downloadAuthorizedKeysErr: someErr,
},
attemptedDownloads: map[string]time.Time{},
wantDownloadErr: true,
},
"service restart error is detected": {
server: fakeDownloadServer{chunks: [][]byte{[]byte("test")}},
serviceManager: stubServiceManager{systemdActionErr: someErr},
attemptedDownloads: map[string]time.Time{},
wantChunks: [][]byte{[]byte("test")},
wantDownloadErr: true,
wantFile: true,
wantDeployed: true,
wantSystemdAction: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ip := "192.0.2.0"
writer := &fakeStreamToFileWriter{}
dialer := testdialer.NewBufconnDialer()
grpcServ := grpc.NewServer()
pb.RegisterDebugdServer(grpcServ, &tc.server)
lis := dialer.GetListener(net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort)))
go grpcServ.Serve(lis)
defer grpcServ.GracefulStop()
download := &Download{
log: logger.NewTest(t),
dialer: dialer,
writer: writer,
serviceManager: &tc.serviceManager,
attemptedDownloads: tc.attemptedDownloads,
}
keys, err := download.DownloadDeployment(context.Background(), ip)
if tc.wantDownloadErr {
assert.Error(err)
} else {
assert.NoError(err)
}
if tc.wantFile {
assert.Equal(tc.wantChunks, writer.chunks)
assert.Equal(filename, writer.filename)
}
if tc.wantSystemdAction {
assert.ElementsMatch(
[]ServiceManagerRequest{
{Unit: debugd.BootstrapperSystemdUnitName, Action: Restart},
},
tc.serviceManager.requests,
)
}
assert.Equal(tc.wantKeys, keys)
})
}
}
type stubServiceManager struct {
requests []ServiceManagerRequest
systemdActionErr error
}
func (s *stubServiceManager) SystemdAction(ctx context.Context, request ServiceManagerRequest) error {
s.requests = append(s.requests, request)
return s.systemdActionErr
}
type fakeStreamToFileWriter struct {
chunks [][]byte
filename string
}
func (f *fakeStreamToFileWriter) WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error {
f.filename = filename
for {
chunk, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return fmt.Errorf("reading stream: %w", err)
}
f.chunks = append(f.chunks, chunk.Content)
}
}
// fakeDownloadServer implements DebugdServer; it fakes DownloadBootstrapper and DownloadAuthorizedKeys, while every other RPC returns an Unimplemented error via the embedded UnimplementedDebugdServer.
type fakeDownloadServer struct {
chunks [][]byte
downloadErr error
keys []*pb.AuthorizedKey
downloadAuthorizedKeysErr error
pb.UnimplementedDebugdServer
}
func (f *fakeDownloadServer) DownloadBootstrapper(request *pb.DownloadBootstrapperRequest, stream pb.Debugd_DownloadBootstrapperServer) error {
for _, chunk := range f.chunks {
if err := stream.Send(&pb.Chunk{Content: chunk}); err != nil {
return fmt.Errorf("sending chunk: %w", err)
}
}
return f.downloadErr
}
func (f *fakeDownloadServer) DownloadAuthorizedKeys(context.Context, *pb.DownloadAuthorizedKeysRequest) (*pb.DownloadAuthorizedKeysResponse, error) {
return &pb.DownloadAuthorizedKeysResponse{Keys: f.keys}, f.downloadAuthorizedKeysErr
}


@ -0,0 +1,163 @@
package deploy
import (
"context"
"fmt"
"sync"
"github.com/edgelesssys/constellation/debugd/internal/debugd"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/spf13/afero"
"go.uber.org/zap"
)
const (
systemdUnitFolder = "/etc/systemd/system"
)
//go:generate stringer -type=SystemdAction
type SystemdAction uint32
const (
Unknown SystemdAction = iota
Start
Stop
Restart
Reload
)
// ServiceManagerRequest describes a requested ServiceManagerAction to be performed on a specified service unit.
type ServiceManagerRequest struct {
Unit string
Action SystemdAction
}
// SystemdUnit describes a systemd service file including the unit name and contents.
type SystemdUnit struct {
Name string `yaml:"name"`
Contents string `yaml:"contents"`
}
// ServiceManager receives ServiceManagerRequests and systemd units, performs the requested actions and writes the unit files to disk.
type ServiceManager struct {
log *logger.Logger
dbus dbusClient
fs afero.Fs
systemdUnitFilewriteLock sync.Mutex
}
// NewServiceManager creates a new ServiceManager.
func NewServiceManager(log *logger.Logger) *ServiceManager {
fs := afero.NewOsFs()
return &ServiceManager{
log: log,
dbus: &dbusWrapper{},
fs: fs,
systemdUnitFilewriteLock: sync.Mutex{},
}
}
type dbusClient interface {
// NewSystemdConnectionContext establishes a connection to the system bus and authenticates.
// Callers should call Close() when done with the connection.
NewSystemdConnectionContext(ctx context.Context) (dbusConn, error)
}
type dbusConn interface {
// StartUnitContext enqueues a start job and depending jobs, if any (unless otherwise
// specified by the mode string).
StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error)
// StopUnitContext is similar to StartUnitContext, but stops the specified unit
// rather than starting it.
StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error)
// RestartUnitContext restarts a service. If a service is restarted that isn't
// running it will be started.
RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error)
// ReloadContext instructs systemd to scan for and reload unit files. This is
// an equivalent to systemctl daemon-reload.
ReloadContext(ctx context.Context) error
}
// SystemdAction will perform a systemd action on a service unit (start, stop, restart, reload).
func (s *ServiceManager) SystemdAction(ctx context.Context, request ServiceManagerRequest) error {
log := s.log.With(zap.String("unit", request.Unit), zap.String("action", request.Action.String()))
conn, err := s.dbus.NewSystemdConnectionContext(ctx)
if err != nil {
return fmt.Errorf("establishing systemd connection: %w", err)
}
resultChan := make(chan string, 1)
switch request.Action {
case Start:
_, err = conn.StartUnitContext(ctx, request.Unit, "replace", resultChan)
case Stop:
_, err = conn.StopUnitContext(ctx, request.Unit, "replace", resultChan)
case Restart:
_, err = conn.RestartUnitContext(ctx, request.Unit, "replace", resultChan)
case Reload:
err = conn.ReloadContext(ctx)
default:
return fmt.Errorf("unknown systemd action: %s", request.Action.String())
}
if err != nil {
return fmt.Errorf("performing systemd action %v on unit %v: %w", request.Action, request.Unit, err)
}
if request.Action == Reload {
log.Infof("daemon-reload succeeded")
return nil
}
// Wait for the action to finish and then check if it was
// successful or not.
result := <-resultChan
switch result {
case "done":
log.Infof("%s on systemd unit %s succeeded", request.Action, request.Unit)
return nil
default:
return fmt.Errorf("performing action %q on systemd unit %q failed: expected %q but received %q", request.Action.String(), request.Unit, "done", result)
}
}
// WriteSystemdUnitFile will write a systemd unit to disk.
func (s *ServiceManager) WriteSystemdUnitFile(ctx context.Context, unit SystemdUnit) error {
log := s.log.With(zap.String("unitFile", fmt.Sprintf("%s/%s", systemdUnitFolder, unit.Name)))
log.Infof("Writing systemd unit file")
s.systemdUnitFilewriteLock.Lock()
defer s.systemdUnitFilewriteLock.Unlock()
if err := afero.WriteFile(s.fs, fmt.Sprintf("%s/%s", systemdUnitFolder, unit.Name), []byte(unit.Contents), 0o644); err != nil {
return fmt.Errorf("writing systemd unit file \"%v\": %w", unit.Name, err)
}
if err := s.SystemdAction(ctx, ServiceManagerRequest{Unit: unit.Name, Action: Reload}); err != nil {
return fmt.Errorf("performing systemd daemon-reload: %w", err)
}
log.Infof("Wrote systemd unit file and performed daemon-reload")
return nil
}
// DeployDefaultServiceUnit will write the default "bootstrapper.service" unit file.
func DeployDefaultServiceUnit(ctx context.Context, serviceManager *ServiceManager) error {
if err := serviceManager.WriteSystemdUnitFile(ctx, SystemdUnit{
Name: debugd.BootstrapperSystemdUnitName,
Contents: debugd.BootstrapperSystemdUnitContents,
}); err != nil {
return fmt.Errorf("writing systemd unit file %q: %w", debugd.BootstrapperSystemdUnitName, err)
}
// try to start the default service if the binary exists but ignore failure.
// this is meant to start the bootstrapper after a reboot
// if a bootstrapper binary was uploaded before.
if ok, err := afero.Exists(serviceManager.fs, debugd.BootstrapperDeployFilename); ok && err == nil {
_ = serviceManager.SystemdAction(ctx, ServiceManagerRequest{
Unit: debugd.BootstrapperSystemdUnitName,
Action: Start,
})
}
return nil
}
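
A minimal sketch of issuing a restart through ServiceManager, using the unit name constant from the debugd package; the helper name is a placeholder.

package main

import (
	"context"

	"github.com/edgelesssys/constellation/debugd/internal/debugd"
	"github.com/edgelesssys/constellation/debugd/internal/debugd/deploy"
	"github.com/edgelesssys/constellation/internal/logger"
)

// restartBootstrapper is a hypothetical helper; the logger is taken as a
// parameter to avoid guessing its construction.
func restartBootstrapper(ctx context.Context, log *logger.Logger) error {
	mgr := deploy.NewServiceManager(log)
	return mgr.SystemdAction(ctx, deploy.ServiceManagerRequest{
		Unit:   debugd.BootstrapperSystemdUnitName,
		Action: deploy.Restart,
	})
}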


@ -0,0 +1,242 @@
package deploy
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSystemdAction(t *testing.T) {
unitName := "example.service"
testCases := map[string]struct {
dbus stubDbus
action SystemdAction
wantErr bool
}{
"start works": {
dbus: stubDbus{
conn: &fakeDbusConn{
result: "done",
},
},
action: Start,
wantErr: false,
},
"stop works": {
dbus: stubDbus{
conn: &fakeDbusConn{
result: "done",
},
},
action: Stop,
wantErr: false,
},
"restart works": {
dbus: stubDbus{
conn: &fakeDbusConn{
result: "done",
},
},
action: Restart,
wantErr: false,
},
"reload works": {
dbus: stubDbus{
conn: &fakeDbusConn{},
},
action: Reload,
wantErr: false,
},
"unknown action": {
dbus: stubDbus{
conn: &fakeDbusConn{},
},
action: Unknown,
wantErr: true,
},
"action fails": {
dbus: stubDbus{
conn: &fakeDbusConn{
actionErr: errors.New("action fails"),
},
},
action: Start,
wantErr: true,
},
"action result is failure": {
dbus: stubDbus{
conn: &fakeDbusConn{
result: "failure",
},
},
action: Start,
wantErr: true,
},
"newConn fails": {
dbus: stubDbus{
connErr: errors.New("newConn fails"),
},
action: Start,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fs := afero.NewMemMapFs()
manager := ServiceManager{
log: logger.NewTest(t),
dbus: &tc.dbus,
fs: fs,
systemdUnitFilewriteLock: sync.Mutex{},
}
err := manager.SystemdAction(context.Background(), ServiceManagerRequest{
Unit: unitName,
Action: tc.action,
})
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
})
}
}
func TestWriteSystemdUnitFile(t *testing.T) {
testCases := map[string]struct {
dbus stubDbus
unit SystemdUnit
readonly bool
wantErr bool
wantFileContents string
}{
"start works": {
dbus: stubDbus{
conn: &fakeDbusConn{
result: "done",
},
},
unit: SystemdUnit{
Name: "test.service",
Contents: "testservicefilecontents",
},
wantErr: false,
wantFileContents: "testservicefilecontents",
},
"write fails": {
dbus: stubDbus{
conn: &fakeDbusConn{
result: "done",
},
},
unit: SystemdUnit{
Name: "test.service",
Contents: "testservicefilecontents",
},
readonly: true,
wantErr: true,
},
"systemd reload fails": {
dbus: stubDbus{
conn: &fakeDbusConn{
actionErr: errors.New("reload error"),
},
},
unit: SystemdUnit{
Name: "test.service",
Contents: "testservicefilecontents",
},
readonly: false,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fs := afero.NewMemMapFs()
assert.NoError(fs.MkdirAll(systemdUnitFolder, 0o755))
if tc.readonly {
fs = afero.NewReadOnlyFs(fs)
}
manager := ServiceManager{
log: logger.NewTest(t),
dbus: &tc.dbus,
fs: fs,
systemdUnitFilewriteLock: sync.Mutex{},
}
err := manager.WriteSystemdUnitFile(context.Background(), tc.unit)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
fileContents, err := afero.ReadFile(fs, fmt.Sprintf("%s/%s", systemdUnitFolder, tc.unit.Name))
assert.NoError(err)
assert.Equal(tc.wantFileContents, string(fileContents))
})
}
}
type stubDbus struct {
conn dbusConn
connErr error
}
func (s *stubDbus) NewSystemdConnectionContext(ctx context.Context) (dbusConn, error) {
return s.conn, s.connErr
}
type dbusConnActionInput struct {
name string
mode string
}
type fakeDbusConn struct {
inputs []dbusConnActionInput
result string
jobID int
actionErr error
}
func (c *fakeDbusConn) StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
c.inputs = append(c.inputs, dbusConnActionInput{name: name, mode: mode})
ch <- c.result
return c.jobID, c.actionErr
}
func (c *fakeDbusConn) StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
c.inputs = append(c.inputs, dbusConnActionInput{name: name, mode: mode})
ch <- c.result
return c.jobID, c.actionErr
}
func (c *fakeDbusConn) RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
c.inputs = append(c.inputs, dbusConnActionInput{name: name, mode: mode})
ch <- c.result
return c.jobID, c.actionErr
}
func (c *fakeDbusConn) ReloadContext(ctx context.Context) error {
return c.actionErr
}


@ -0,0 +1,27 @@
// Code generated by "stringer -type=SystemdAction"; DO NOT EDIT.
package deploy
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Unknown-0]
_ = x[Start-1]
_ = x[Stop-2]
_ = x[Restart-3]
_ = x[Reload-4]
}
const _SystemdAction_name = "UnknownStartStopRestartReload"
var _SystemdAction_index = [...]uint8{0, 7, 12, 16, 23, 29}
func (i SystemdAction) String() string {
if i >= SystemdAction(len(_SystemdAction_index)-1) {
return "SystemdAction(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _SystemdAction_name[_SystemdAction_index[i]:_SystemdAction_index[i+1]]
}


@ -0,0 +1,40 @@
package deploy
import (
"context"
"github.com/coreos/go-systemd/v22/dbus"
)
// dbusWrapper wraps the go-systemd dbus package.
type dbusWrapper struct{}
func (d *dbusWrapper) NewSystemdConnectionContext(ctx context.Context) (dbusConn, error) {
conn, err := dbus.NewSystemdConnectionContext(ctx)
if err != nil {
return nil, err
}
return &dbusConnWrapper{
conn: conn,
}, nil
}
type dbusConnWrapper struct {
conn *dbus.Conn
}
func (c *dbusConnWrapper) StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.conn.StartUnitContext(ctx, name, mode, ch)
}
func (c *dbusConnWrapper) StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.conn.StopUnitContext(ctx, name, mode, ch)
}
func (c *dbusConnWrapper) RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.conn.RestartUnitContext(ctx, name, mode, ch)
}
func (c *dbusConnWrapper) ReloadContext(ctx context.Context) error {
return c.conn.ReloadContext(ctx)
}


@ -0,0 +1,130 @@
package cloudprovider
import (
"context"
"fmt"
"net"
azurecloud "github.com/edgelesssys/constellation/bootstrapper/cloudprovider/azure"
gcpcloud "github.com/edgelesssys/constellation/bootstrapper/cloudprovider/gcp"
qemucloud "github.com/edgelesssys/constellation/bootstrapper/cloudprovider/qemu"
"github.com/edgelesssys/constellation/bootstrapper/role"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
)
type providerMetadata interface {
// List retrieves all instances belonging to the current constellation.
List(ctx context.Context) ([]metadata.InstanceMetadata, error)
// Self retrieves the current instance.
Self(ctx context.Context) (metadata.InstanceMetadata, error)
// GetLoadBalancerEndpoint returns the endpoint of the load balancer.
GetLoadBalancerEndpoint(ctx context.Context) (string, error)
}
// Fetcher checks the metadata service to search for instances that were set up for debugging and cloud provider specific SSH keys.
type Fetcher struct {
metaAPI providerMetadata
}
// NewGCP creates a new GCP fetcher.
func NewGCP(ctx context.Context) (*Fetcher, error) {
gcpClient, err := gcpcloud.NewClient(ctx)
if err != nil {
return nil, err
}
metaAPI := gcpcloud.New(gcpClient)
return &Fetcher{
metaAPI: metaAPI,
}, nil
}
// NewAzure creates a new Azure fetcher.
func NewAzure(ctx context.Context) (*Fetcher, error) {
metaAPI, err := azurecloud.NewMetadata(ctx)
if err != nil {
return nil, err
}
return &Fetcher{
metaAPI: metaAPI,
}, nil
}
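// NewQEMU creates a new QEMU fetcher.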
func NewQEMU() *Fetcher {
return &Fetcher{
metaAPI: &qemucloud.Metadata{},
}
}
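// Role returns the role of the current instance as reported by the cloud provider metadata.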
func (f *Fetcher) Role(ctx context.Context) (role.Role, error) {
self, err := f.metaAPI.Self(ctx)
if err != nil {
return role.Unknown, fmt.Errorf("retrieving role from cloud provider metadata: %w", err)
}
return self.Role, nil
}
// DiscoverDebugdIPs will query the metadata of all instances and return the IPs of instances already set up for debugging.
func (f *Fetcher) DiscoverDebugdIPs(ctx context.Context) ([]string, error) {
self, err := f.metaAPI.Self(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving own instance: %w", err)
}
instances, err := f.metaAPI.List(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving instances: %w", err)
}
// filter own instance from instance list
for i, instance := range instances {
if instance.ProviderID == self.ProviderID {
instances = append(instances[:i], instances[i+1:]...)
break
}
}
var ips []string
for _, instance := range instances {
if instance.VPCIP != "" {
ips = append(ips, instance.VPCIP)
}
}
return ips, nil
}
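// DiscoverLoadbalancerIP retrieves the load balancer IP from the cloud provider metadata, stripping the port from the returned endpoint.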
func (f *Fetcher) DiscoverLoadbalancerIP(ctx context.Context) (string, error) {
lbEndpoint, err := f.metaAPI.GetLoadBalancerEndpoint(ctx)
if err != nil {
return "", fmt.Errorf("retrieving load balancer endpoint: %w", err)
}
// The port of the endpoint is not the port we need. We need to strip it off.
//
// TODO: Tag the specific load balancer we are looking for with a distinct tag.
// Change the GetLoadBalancerEndpoint method to return the endpoint of a load
// balancer with a given tag.
lbIP, _, err := net.SplitHostPort(lbEndpoint)
if err != nil {
return "", fmt.Errorf("parsing load balancer endpoint: %w", err)
}
return lbIP, nil
}
// FetchSSHKeys will query the metadata of the current instance and deploys any SSH keys found.
func (f *Fetcher) FetchSSHKeys(ctx context.Context) ([]ssh.UserKey, error) {
self, err := f.metaAPI.Self(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving ssh keys from cloud provider metadata: %w", err)
}
keys := []ssh.UserKey{}
for username, userKeys := range self.SSHKeys {
for _, keyValue := range userKeys {
keys = append(keys, ssh.UserKey{Username: username, PublicKey: keyValue})
}
}
return keys, nil
}
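
A hedged sketch of discovering other debugd instances through the GCP fetcher; the import path is assumed from the new internal layout.

package main

import (
	"context"
	"fmt"

	"github.com/edgelesssys/constellation/debugd/internal/debugd/metadata/cloudprovider"
)

func main() {
	ctx := context.Background()
	// NewGCP only works on a GCP instance with access to the metadata server.
	fetcher, err := cloudprovider.NewGCP(ctx)
	if err != nil {
		panic(err)
	}
	ips, err := fetcher.DiscoverDebugdIPs(ctx)
	if err != nil {
		panic(err)
	}
	fmt.Println("discovered debugd instances:", ips)
}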


@ -0,0 +1,244 @@
package cloudprovider
import (
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/bootstrapper/role"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
}
func TestRole(t *testing.T) {
instance1 := metadata.InstanceMetadata{Role: role.ControlPlane}
instance2 := metadata.InstanceMetadata{Role: role.Worker}
testCases := map[string]struct {
meta *stubMetadata
wantErr bool
wantRole role.Role
}{
"control plane": {
meta: &stubMetadata{selfRes: instance1},
wantRole: role.ControlPlane,
},
"worker": {
meta: &stubMetadata{selfRes: instance2},
wantRole: role.Worker,
},
"self fails": {
meta: &stubMetadata{selfErr: errors.New("some err")},
wantErr: true,
wantRole: role.Unknown,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
fetcher := Fetcher{tc.meta}
role, err := fetcher.Role(context.Background())
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.wantRole, role)
}
})
}
}
func TestDiscoverDebugIPs(t *testing.T) {
err := errors.New("some err")
testCases := map[string]struct {
meta stubMetadata
wantIPs []string
wantErr bool
}{
"disovery works": {
meta: stubMetadata{
listRes: []metadata.InstanceMetadata{
{
VPCIP: "192.0.2.0",
},
{
VPCIP: "192.0.2.1",
},
{
VPCIP: "192.0.2.2",
},
},
},
wantIPs: []string{
"192.0.2.1", "192.0.2.2",
},
},
"retrieve fails": {
meta: stubMetadata{
listErr: err,
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fetcher := Fetcher{
metaAPI: &tc.meta,
}
ips, err := fetcher.DiscoverDebugdIPs(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.ElementsMatch(tc.wantIPs, ips)
})
}
}
func TestDiscoverLoadbalancerIP(t *testing.T) {
ip := "192.0.2.1"
endpoint := ip + ":1234"
someErr := errors.New("failed")
testCases := map[string]struct {
metaAPI providerMetadata
wantIP string
wantErr bool
}{
"discovery works": {
metaAPI: &stubMetadata{getLBEndpointRes: endpoint},
wantIP: ip,
},
"get endpoint fails": {
metaAPI: &stubMetadata{getLBEndpointErr: someErr},
wantErr: true,
},
"invalid endpoint": {
metaAPI: &stubMetadata{getLBEndpointRes: "invalid"},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
fetcher := &Fetcher{
metaAPI: tc.metaAPI,
}
ip, err := fetcher.DiscoverLoadbalancerIP(context.Background())
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.wantIP, ip)
}
})
}
}
func TestFetchSSHKeys(t *testing.T) {
err := errors.New("some err")
testCases := map[string]struct {
meta stubMetadata
wantKeys []ssh.UserKey
wantErr bool
}{
"fetch works": {
meta: stubMetadata{
selfRes: metadata.InstanceMetadata{
Name: "name",
ProviderID: "provider-id",
SSHKeys: map[string][]string{"bob": {"ssh-rsa bobskey"}},
},
},
wantKeys: []ssh.UserKey{
{
Username: "bob",
PublicKey: "ssh-rsa bobskey",
},
},
},
"retrieve fails": {
meta: stubMetadata{
selfErr: err,
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fetcher := Fetcher{
metaAPI: &tc.meta,
}
keys, err := fetcher.FetchSSHKeys(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.ElementsMatch(tc.wantKeys, keys)
})
}
}
type stubMetadata struct {
listRes []metadata.InstanceMetadata
listErr error
selfRes metadata.InstanceMetadata
selfErr error
getInstanceRes metadata.InstanceMetadata
getInstanceErr error
getLBEndpointRes string
getLBEndpointErr error
supportedRes bool
}
func (m *stubMetadata) List(ctx context.Context) ([]metadata.InstanceMetadata, error) {
return m.listRes, m.listErr
}
func (m *stubMetadata) Self(ctx context.Context) (metadata.InstanceMetadata, error) {
return m.selfRes, m.selfErr
}
func (m *stubMetadata) GetInstance(ctx context.Context, providerID string) (metadata.InstanceMetadata, error) {
return m.getInstanceRes, m.getInstanceErr
}
func (m *stubMetadata) GetLoadBalancerEndpoint(ctx context.Context) (string, error) {
return m.getLBEndpointRes, m.getLBEndpointErr
}
func (m *stubMetadata) Supported() bool {
return m.supportedRes
}


@ -0,0 +1,31 @@
package fallback
import (
"context"
"github.com/edgelesssys/constellation/bootstrapper/role"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
)
// Fetcher implements the metadata.Fetcher interface but does not actually fetch cloud provider metadata.
type Fetcher struct{}
func (f Fetcher) Role(_ context.Context) (role.Role, error) {
// Fallback fetcher does not try to fetch role
return role.Unknown, nil
}
func (f Fetcher) DiscoverDebugdIPs(ctx context.Context) ([]string, error) {
// Fallback fetcher does not try to discover debugd IPs
return nil, nil
}
func (f Fetcher) DiscoverLoadbalancerIP(ctx context.Context) (string, error) {
// Fallback fetcher does not try to discover loadbalancer IP
return "", nil
}
func (f Fetcher) FetchSSHKeys(ctx context.Context) ([]ssh.UserKey, error) {
// Fallback fetcher does not try to fetch ssh keys
return nil, nil
}


@ -0,0 +1,33 @@
package fallback
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestDiscoverDebugdIPs(t *testing.T) {
assert := assert.New(t)
fetcher := Fetcher{}
ips, err := fetcher.DiscoverDebugdIPs(context.Background())
assert.NoError(err)
assert.Empty(ips)
}
func TestFetchSSHKeys(t *testing.T) {
assert := assert.New(t)
fetcher := Fetcher{}
keys, err := fetcher.FetchSSHKeys(context.Background())
assert.NoError(err)
assert.Empty(keys)
}


@ -0,0 +1,143 @@
package metadata
import (
"context"
"errors"
"io/fs"
"sync"
"time"
"github.com/edgelesssys/constellation/bootstrapper/role"
"github.com/edgelesssys/constellation/debugd/internal/debugd"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/logger"
"go.uber.org/zap"
)
// Fetcher retrieves other debugd IPs and SSH keys from cloud provider metadata.
type Fetcher interface {
Role(ctx context.Context) (role.Role, error)
DiscoverDebugdIPs(ctx context.Context) ([]string, error)
FetchSSHKeys(ctx context.Context) ([]ssh.UserKey, error)
DiscoverLoadbalancerIP(ctx context.Context) (string, error)
}
// Scheduler schedules fetching of metadata using timers.
type Scheduler struct {
log *logger.Logger
fetcher Fetcher
ssh sshDeployer
downloader downloader
}
// NewScheduler returns a new scheduler.
func NewScheduler(log *logger.Logger, fetcher Fetcher, ssh sshDeployer, downloader downloader) *Scheduler {
return &Scheduler{
log: log,
fetcher: fetcher,
ssh: ssh,
downloader: downloader,
}
}
// Start will start the loops for discovering debugd endpoints and ssh keys.
func (s *Scheduler) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
wg.Add(2)
go s.discoveryLoop(ctx, wg)
go s.sshLoop(ctx, wg)
}
// discoveryLoop discovers new debugd endpoints from cloud-provider metadata periodically.
func (s *Scheduler) discoveryLoop(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
// execute debugd discovery once at the start to avoid waiting for the first tick
ips, err := s.fetcher.DiscoverDebugdIPs(ctx)
if err != nil {
s.log.With(zap.Error(err)).Errorf("Discovering debugd IPs failed")
} else {
if s.downloadDeployment(ctx, ips) {
return
}
}
ticker := time.NewTicker(debugd.DiscoverDebugdInterval)
defer ticker.Stop()
for {
var err error
select {
case <-ticker.C:
ips, err = s.fetcher.DiscoverDebugdIPs(ctx)
if err != nil {
s.log.With(zap.Error(err)).Errorf("Discovering debugd IPs failed")
continue
}
s.log.With(zap.Strings("ips", ips)).Infof("Discovered instances")
if s.downloadDeployment(ctx, ips) {
return
}
case <-ctx.Done():
return
}
}
}
// sshLoop discovers new ssh keys from cloud provider metadata periodically.
func (s *Scheduler) sshLoop(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
ticker := time.NewTicker(debugd.SSHCheckInterval)
defer ticker.Stop()
for {
keys, err := s.fetcher.FetchSSHKeys(ctx)
if err != nil {
s.log.With(zap.Error(err)).Errorf("Fetching SSH keys failed")
} else {
s.deploySSHKeys(ctx, keys)
}
select {
case <-ticker.C:
case <-ctx.Done():
return
}
}
}
// downloadDeployment tries to download a deployment from a list of IPs and logs any errors encountered.
func (s *Scheduler) downloadDeployment(ctx context.Context, ips []string) (success bool) {
for _, ip := range ips {
keys, err := s.downloader.DownloadDeployment(ctx, ip)
if err == nil {
s.deploySSHKeys(ctx, keys)
return true
}
if errors.Is(err, fs.ErrExist) {
// bootstrapper was already uploaded
s.log.Infof("Bootstrapper was already uploaded.")
return true
}
s.log.With(zap.Error(err), zap.String("peer", ip)).Errorf("Downloading deployment from peer failed")
}
return false
}
// deploySSHKeys tries to deploy a list of SSH keys and logs errors encountered.
func (s *Scheduler) deploySSHKeys(ctx context.Context, keys []ssh.UserKey) {
for _, key := range keys {
err := s.ssh.DeployAuthorizedKey(ctx, key)
if err != nil {
s.log.With(zap.Error(err), zap.Any("key", key)).Errorf("Deploying SSH key failed")
continue
}
}
}
type downloader interface {
DownloadDeployment(ctx context.Context, ip string) ([]ssh.UserKey, error)
}
type sshDeployer interface {
DeployAuthorizedKey(ctx context.Context, sshKey ssh.UserKey) error
}
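
A hedged wiring sketch that starts the Scheduler with the fallback fetcher; the no-op SSH deployer and the import paths are assumptions made for the example.

package main

import (
	"context"
	"sync"

	"github.com/edgelesssys/constellation/debugd/internal/debugd/deploy"
	"github.com/edgelesssys/constellation/debugd/internal/debugd/metadata"
	"github.com/edgelesssys/constellation/debugd/internal/debugd/metadata/fallback"
	"github.com/edgelesssys/constellation/internal/deploy/ssh"
	"github.com/edgelesssys/constellation/internal/logger"
)

// noopSSHDeployer is an illustrative stand-in that satisfies the scheduler's
// sshDeployer interface without touching the host.
type noopSSHDeployer struct{}

func (noopSSHDeployer) DeployAuthorizedKey(_ context.Context, _ ssh.UserKey) error { return nil }

func runScheduler(ctx context.Context, log *logger.Logger, download *deploy.Download) {
	wg := &sync.WaitGroup{}
	sched := metadata.NewScheduler(log, fallback.Fetcher{}, noopSSHDeployer{}, download)
	wg.Add(1)
	go sched.Start(ctx, wg)
	wg.Wait()
}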


@ -0,0 +1,134 @@
package metadata
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/edgelesssys/constellation/bootstrapper/role"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestSchedulerStart(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
fetcher stubFetcher
ssh stubSSHDeployer
downloader stubDownloader
timeout time.Duration
wantSSHKeys []ssh.UserKey
wantDebugdDownloads []string
}{
"scheduler works and calls fetcher functions at least once": {},
"ssh keys are fetched": {
fetcher: stubFetcher{
keys: []ssh.UserKey{{Username: "test", PublicKey: "testkey"}},
},
wantSSHKeys: []ssh.UserKey{{Username: "test", PublicKey: "testkey"}},
},
"download for discovered debugd ips is started": {
fetcher: stubFetcher{
ips: []string{"192.0.2.1", "192.0.2.2"},
},
downloader: stubDownloader{downloadErr: someErr},
wantDebugdDownloads: []string{"192.0.2.1", "192.0.2.2"},
},
"if download is successful, second download is not attempted": {
fetcher: stubFetcher{
ips: []string{"192.0.2.1", "192.0.2.2"},
},
wantDebugdDownloads: []string{"192.0.2.1"},
},
"endpoint discovery can fail": {
fetcher: stubFetcher{discoverErr: someErr},
},
"ssh key fetch can fail": {
fetcher: stubFetcher{fetchSSHKeysErr: someErr},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
wg := &sync.WaitGroup{}
ctx, cancel := context.WithTimeout(context.Background(), tc.timeout)
defer cancel()
scheduler := Scheduler{
log: logger.NewTest(t),
fetcher: &tc.fetcher,
ssh: &tc.ssh,
downloader: &tc.downloader,
}
wg.Add(1)
go scheduler.Start(ctx, wg)
wg.Wait()
assert.Equal(tc.wantSSHKeys, tc.ssh.sshKeys)
assert.Equal(tc.wantDebugdDownloads, tc.downloader.ips)
assert.Greater(tc.fetcher.discoverCalls, 0)
assert.Greater(tc.fetcher.fetchSSHKeysCalls, 0)
})
}
}
type stubFetcher struct {
discoverCalls int
fetchSSHKeysCalls int
ips []string
keys []ssh.UserKey
discoverErr error
fetchSSHKeysErr error
}
func (s *stubFetcher) Role(_ context.Context) (role.Role, error) {
return role.Unknown, nil
}
func (s *stubFetcher) DiscoverDebugdIPs(ctx context.Context) ([]string, error) {
s.discoverCalls++
return s.ips, s.discoverErr
}
func (s *stubFetcher) FetchSSHKeys(ctx context.Context) ([]ssh.UserKey, error) {
s.fetchSSHKeysCalls++
return s.keys, s.fetchSSHKeysErr
}
func (s *stubFetcher) DiscoverLoadbalancerIP(ctx context.Context) (string, error) {
return "", nil
}
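// stubSSHDeployer records deployed SSH keys and returns a configurable error.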
type stubSSHDeployer struct {
sshKeys []ssh.UserKey
deployErr error
}
func (s *stubSSHDeployer) DeployAuthorizedKey(ctx context.Context, sshKey ssh.UserKey) error {
s.sshKeys = append(s.sshKeys, sshKey)
return s.deployErr
}
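// stubDownloader records the IPs it was asked to download from and returns configurable keys and an error.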
type stubDownloader struct {
ips []string
downloadErr error
keys []ssh.UserKey
}
func (s *stubDownloader) DownloadDeployment(ctx context.Context, ip string) ([]ssh.UserKey, error) {
s.ips = append(s.ips, ip)
return s.keys, s.downloadErr
}

View file

@ -0,0 +1,162 @@
package server
import (
"context"
"errors"
"fmt"
"io/fs"
"net"
"strconv"
"sync"
"time"
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
"github.com/edgelesssys/constellation/debugd/internal/debugd"
"github.com/edgelesssys/constellation/debugd/internal/debugd/deploy"
pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/logger"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
)
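// debugdServer implements the gRPC endpoints of the debugd service.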
type debugdServer struct {
log *logger.Logger
ssh sshDeployer
serviceManager serviceManager
streamer streamer
pb.UnimplementedDebugdServer
}
// New creates a new debugdServer according to the gRPC spec.
func New(log *logger.Logger, ssh sshDeployer, serviceManager serviceManager, streamer streamer) pb.DebugdServer {
return &debugdServer{
log: log,
ssh: ssh,
serviceManager: serviceManager,
streamer: streamer,
}
}
// UploadAuthorizedKeys receives a list of authorized keys and deploys them on the system.
func (s *debugdServer) UploadAuthorizedKeys(ctx context.Context, in *pb.UploadAuthorizedKeysRequest) (*pb.UploadAuthorizedKeysResponse, error) {
s.log.Infof("Uploading authorized keys")
for _, key := range in.Keys {
if err := s.ssh.DeployAuthorizedKey(ctx, ssh.UserKey{Username: key.Username, PublicKey: key.KeyValue}); err != nil {
s.log.With(zap.Error(err)).Errorf("Uploading authorized keys failed")
return &pb.UploadAuthorizedKeysResponse{
Status: pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_FAILURE,
}, nil
}
}
return &pb.UploadAuthorizedKeysResponse{
Status: pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_SUCCESS,
}, nil
}
// UploadBootstrapper receives a bootstrapper binary in a stream of chunks and writes it to a file.
func (s *debugdServer) UploadBootstrapper(stream pb.Debugd_UploadBootstrapperServer) error {
startAction := deploy.ServiceManagerRequest{
Unit: debugd.BootstrapperSystemdUnitName,
Action: deploy.Start,
}
var responseStatus pb.UploadBootstrapperStatus
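// Once the upload handler returns, try to start the bootstrapper unit and send the final status to the client.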
defer func() {
if err := s.serviceManager.SystemdAction(stream.Context(), startAction); err != nil {
s.log.With(zap.Error(err)).Errorf("Starting uploaded bootstrapper failed")
if responseStatus == pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS {
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_START_FAILED
}
}
stream.SendAndClose(&pb.UploadBootstrapperResponse{
Status: responseStatus,
})
}()
s.log.Infof("Starting bootstrapper upload")
if err := s.streamer.WriteStream(debugd.BootstrapperDeployFilename, stream, true); err != nil {
if errors.Is(err, fs.ErrExist) {
// bootstrapper was already uploaded
s.log.Warnf("Bootstrapper already uploaded")
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_FILE_EXISTS
return nil
}
s.log.With(zap.Error(err)).Errorf("Uploading bootstrapper failed")
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_UPLOAD_FAILED
return fmt.Errorf("uploading bootstrapper: %w", err)
}
s.log.Infof("Successfully uploaded bootstrapper")
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS
return nil
}
// DownloadBootstrapper streams the local bootstrapper binary to other instances.
func (s *debugdServer) DownloadBootstrapper(request *pb.DownloadBootstrapperRequest, stream pb.Debugd_DownloadBootstrapperServer) error {
s.log.Infof("Sending bootstrapper to other instance")
return s.streamer.ReadStream(debugd.BootstrapperDeployFilename, stream, debugd.Chunksize, true)
}
// DownloadAuthorizedKeys sends the local authorized keys to the requesting instance.
func (s *debugdServer) DownloadAuthorizedKeys(_ context.Context, req *pb.DownloadAuthorizedKeysRequest) (*pb.DownloadAuthorizedKeysResponse, error) {
s.log.Infof("Sending authorized keys to other instance")
var authKeys []*pb.AuthorizedKey
for _, key := range s.ssh.GetAuthorizedKeys() {
authKeys = append(authKeys, &pb.AuthorizedKey{
Username: key.Username,
KeyValue: key.PublicKey,
})
}
return &pb.DownloadAuthorizedKeysResponse{Keys: authKeys}, nil
}
// UploadSystemServiceUnits receives systemd service units, writes them to unit files and schedules a daemon-reload.
func (s *debugdServer) UploadSystemServiceUnits(ctx context.Context, in *pb.UploadSystemdServiceUnitsRequest) (*pb.UploadSystemdServiceUnitsResponse, error) {
s.log.Infof("Uploading systemd service units")
for _, unit := range in.Units {
if err := s.serviceManager.WriteSystemdUnitFile(ctx, deploy.SystemdUnit{Name: unit.Name, Contents: unit.Contents}); err != nil {
return &pb.UploadSystemdServiceUnitsResponse{Status: pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_FAILURE}, nil
}
}
return &pb.UploadSystemdServiceUnitsResponse{Status: pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_SUCCESS}, nil
}
// Start starts the gRPC server and blocks.
func Start(log *logger.Logger, wg *sync.WaitGroup, serv pb.DebugdServer) {
defer wg.Done()
grpcLog := log.Named("gRPC")
grpcLog.WithIncreasedLevel(zap.WarnLevel).ReplaceGRPCLogger()
grpcServer := grpc.NewServer(
grpcLog.GetServerStreamInterceptor(),
grpcLog.GetServerUnaryInterceptor(),
grpc.KeepaliveParams(keepalive.ServerParameters{Time: 15 * time.Second}),
)
pb.RegisterDebugdServer(grpcServer, serv)
lis, err := net.Listen("tcp", net.JoinHostPort("0.0.0.0", strconv.Itoa(constants.DebugdPort)))
if err != nil {
log.With(zap.Error(err)).Fatalf("Listening failed")
}
log.Infof("gRPC server is waiting for connections")
grpcServer.Serve(lis)
}
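// sshDeployer deploys SSH authorized keys and lists the keys deployed so far.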
type sshDeployer interface {
DeployAuthorizedKey(ctx context.Context, sshKey ssh.UserKey) error
GetAuthorizedKeys() []ssh.UserKey
}
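// serviceManager can perform systemd actions and write systemd unit files.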
type serviceManager interface {
SystemdAction(ctx context.Context, request deploy.ServiceManagerRequest) error
WriteSystemdUnitFile(ctx context.Context, unit deploy.SystemdUnit) error
}
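// streamer writes a chunk stream to a file and reads a file back as a chunk stream.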
type streamer interface {
WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error
ReadStream(filename string, stream bootstrapper.WriteChunkStream, chunksize uint, showProgress bool) error
}

View file

@ -0,0 +1,490 @@
package server
import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"testing"
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
"github.com/edgelesssys/constellation/debugd/internal/debugd/deploy"
pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestUploadAuthorizedKeys(t *testing.T) {
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
testCases := map[string]struct {
ssh stubSSHDeployer
serviceManager stubServiceManager
request *pb.UploadAuthorizedKeysRequest
wantErr bool
wantResponseStatus pb.UploadAuthorizedKeysStatus
wantKeys []ssh.UserKey
}{
"upload authorized keys works": {
request: &pb.UploadAuthorizedKeysRequest{
Keys: []*pb.AuthorizedKey{
{
Username: "testuser",
KeyValue: "teskey",
},
},
},
wantResponseStatus: pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_SUCCESS,
wantKeys: []ssh.UserKey{
{
Username: "testuser",
PublicKey: "teskey",
},
},
},
"deploy fails": {
request: &pb.UploadAuthorizedKeysRequest{
Keys: []*pb.AuthorizedKey{
{
Username: "testuser",
KeyValue: "teskey",
},
},
},
ssh: stubSSHDeployer{deployErr: errors.New("ssh key deployment error")},
wantResponseStatus: pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_FAILURE,
wantKeys: []ssh.UserKey{
{
Username: "testuser",
PublicKey: "teskey",
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
serv := debugdServer{
log: logger.NewTest(t),
ssh: &tc.ssh,
serviceManager: &tc.serviceManager,
streamer: &fakeStreamer{},
}
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
require.NoError(err)
defer conn.Close()
client := pb.NewDebugdClient(conn)
resp, err := client.UploadAuthorizedKeys(context.Background(), tc.request)
grpcServ.GracefulStop()
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantResponseStatus, resp.Status)
assert.ElementsMatch(tc.ssh.sshKeys, tc.wantKeys)
})
}
}
func TestUploadBootstrapper(t *testing.T) {
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
testCases := map[string]struct {
ssh stubSSHDeployer
serviceManager stubServiceManager
streamer fakeStreamer
uploadChunks [][]byte
wantErr bool
wantResponseStatus pb.UploadBootstrapperStatus
wantFile bool
wantChunks [][]byte
}{
"upload works": {
uploadChunks: [][]byte{
[]byte("test"),
},
wantFile: true,
wantChunks: [][]byte{
[]byte("test"),
},
wantResponseStatus: pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS,
},
"recv fails": {
streamer: fakeStreamer{
writeStreamErr: errors.New("recv error"),
},
wantResponseStatus: pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_UPLOAD_FAILED,
wantErr: true,
},
"starting bootstrapper fails": {
uploadChunks: [][]byte{
[]byte("test"),
},
serviceManager: stubServiceManager{
systemdActionErr: errors.New("starting bootstrapper error"),
},
wantFile: true,
wantChunks: [][]byte{
[]byte("test"),
},
wantResponseStatus: pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_START_FAILED,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
serv := debugdServer{
log: logger.NewTest(t),
ssh: &tc.ssh,
serviceManager: &tc.serviceManager,
streamer: &tc.streamer,
}
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
require.NoError(err)
defer conn.Close()
client := pb.NewDebugdClient(conn)
stream, err := client.UploadBootstrapper(context.Background())
require.NoError(err)
require.NoError(fakeWrite(stream, tc.uploadChunks))
resp, err := stream.CloseAndRecv()
grpcServ.GracefulStop()
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantResponseStatus, resp.Status)
if tc.wantFile {
assert.Equal(tc.wantChunks, tc.streamer.writeStreamChunks)
assert.Equal("/opt/bootstrapper", tc.streamer.writeStreamFilename)
} else {
assert.Empty(tc.streamer.writeStreamChunks)
assert.Empty(tc.streamer.writeStreamFilename)
}
})
}
}
func TestDownloadBootstrapper(t *testing.T) {
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
testCases := map[string]struct {
ssh stubSSHDeployer
serviceManager stubServiceManager
request *pb.DownloadBootstrapperRequest
streamer fakeStreamer
wantErr bool
wantChunks [][]byte
}{
"download works": {
request: &pb.DownloadBootstrapperRequest{},
streamer: fakeStreamer{
readStreamChunks: [][]byte{
[]byte("test"),
},
},
wantErr: false,
wantChunks: [][]byte{
[]byte("test"),
},
},
"download fails": {
request: &pb.DownloadBootstrapperRequest{},
streamer: fakeStreamer{
readStreamErr: errors.New("read bootstrapper fails"),
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
serv := debugdServer{
log: logger.NewTest(t),
ssh: &tc.ssh,
serviceManager: &tc.serviceManager,
streamer: &tc.streamer,
}
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
require.NoError(err)
defer conn.Close()
client := pb.NewDebugdClient(conn)
stream, err := client.DownloadBootstrapper(context.Background(), tc.request)
require.NoError(err)
chunks, err := fakeRead(stream)
grpcServ.GracefulStop()
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantChunks, chunks)
assert.Equal("/opt/bootstrapper", tc.streamer.readStreamFilename)
})
}
}
func TestDownloadAuthorizedKeys(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
deployer := &stubSSHDeployer{
sshKeys: []ssh.UserKey{
{Username: "test1", PublicKey: "foo"},
{Username: "test2", PublicKey: "bar"},
},
}
serv := debugdServer{
log: logger.NewTest(t),
ssh: deployer,
}
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
require.NoError(err)
defer conn.Close()
defer grpcServ.GracefulStop()
client := pb.NewDebugdClient(conn)
resp, err := client.DownloadAuthorizedKeys(context.Background(), &pb.DownloadAuthorizedKeysRequest{})
assert.NoError(err)
wantKeys := []*pb.AuthorizedKey{
{Username: "test1", KeyValue: "foo"},
{Username: "test2", KeyValue: "bar"},
}
assert.ElementsMatch(wantKeys, resp.Keys)
}
func TestUploadSystemServiceUnits(t *testing.T) {
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
testCases := map[string]struct {
ssh stubSSHDeployer
serviceManager stubServiceManager
request *pb.UploadSystemdServiceUnitsRequest
wantErr bool
wantResponseStatus pb.UploadSystemdServiceUnitsStatus
wantUnitFiles []deploy.SystemdUnit
}{
"upload systemd service units": {
request: &pb.UploadSystemdServiceUnitsRequest{
Units: []*pb.ServiceUnit{
{
Name: "test.service",
Contents: "testcontents",
},
},
},
wantResponseStatus: pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_SUCCESS,
wantUnitFiles: []deploy.SystemdUnit{
{
Name: "test.service",
Contents: "testcontents",
},
},
},
"writing fails": {
request: &pb.UploadSystemdServiceUnitsRequest{
Units: []*pb.ServiceUnit{
{
Name: "test.service",
Contents: "testcontents",
},
},
},
serviceManager: stubServiceManager{
writeSystemdUnitFileErr: errors.New("write error"),
},
wantResponseStatus: pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_FAILURE,
wantUnitFiles: []deploy.SystemdUnit{
{
Name: "test.service",
Contents: "testcontents",
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
serv := debugdServer{
log: logger.NewTest(t),
ssh: &tc.ssh,
serviceManager: &tc.serviceManager,
streamer: &fakeStreamer{},
}
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
require.NoError(err)
defer conn.Close()
client := pb.NewDebugdClient(conn)
resp, err := client.UploadSystemServiceUnits(context.Background(), tc.request)
grpcServ.GracefulStop()
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
require.NotNil(resp.Status)
assert.Equal(tc.wantResponseStatus, resp.Status)
assert.ElementsMatch(tc.wantUnitFiles, tc.serviceManager.unitFiles)
})
}
}
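// stubSSHDeployer records deployed SSH keys and returns a configurable error.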
type stubSSHDeployer struct {
sshKeys []ssh.UserKey
deployErr error
}
func (s *stubSSHDeployer) DeployAuthorizedKey(ctx context.Context, sshKey ssh.UserKey) error {
s.sshKeys = append(s.sshKeys, sshKey)
return s.deployErr
}
func (s *stubSSHDeployer) GetAuthorizedKeys() []ssh.UserKey {
return s.sshKeys
}
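// stubServiceManager records systemd requests and written unit files and returns configurable errors.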
type stubServiceManager struct {
requests []deploy.ServiceManagerRequest
unitFiles []deploy.SystemdUnit
systemdActionErr error
writeSystemdUnitFileErr error
}
func (s *stubServiceManager) SystemdAction(ctx context.Context, request deploy.ServiceManagerRequest) error {
s.requests = append(s.requests, request)
return s.systemdActionErr
}
func (s *stubServiceManager) WriteSystemdUnitFile(ctx context.Context, unit deploy.SystemdUnit) error {
s.unitFiles = append(s.unitFiles, unit)
return s.writeSystemdUnitFileErr
}
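// netDialer is the subset of a dialer needed to establish a net.Conn for the gRPC client in tests.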
type netDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
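// dial creates an insecure gRPC client connection to target using the given dialer.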
func dial(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, target,
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", addr)
}),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
}
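// fakeStreamer implements the streamer interface for tests: it records written chunks and replays predefined chunks on read.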
type fakeStreamer struct {
writeStreamChunks [][]byte
writeStreamFilename string
writeStreamErr error
readStreamChunks [][]byte
readStreamFilename string
readStreamErr error
}
func (f *fakeStreamer) WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error {
f.writeStreamFilename = filename
for {
chunk, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return f.writeStreamErr
}
return fmt.Errorf("reading stream: %w", err)
}
f.writeStreamChunks = append(f.writeStreamChunks, chunk.Content)
}
}
func (f *fakeStreamer) ReadStream(filename string, stream bootstrapper.WriteChunkStream, chunksize uint, showProgress bool) error {
f.readStreamFilename = filename
for _, chunk := range f.readStreamChunks {
if err := stream.Send(&pb.Chunk{Content: chunk}); err != nil {
panic(err)
}
}
return f.readStreamErr
}
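// setupServerWithConn starts a debugd gRPC server on a bufconn listener and returns it together with a client connection.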
func setupServerWithConn(endpoint string, serv *debugdServer) (*grpc.Server, *grpc.ClientConn, error) {
dialer := testdialer.NewBufconnDialer()
grpcServ := grpc.NewServer()
pb.RegisterDebugdServer(grpcServ, serv)
lis := dialer.GetListener(endpoint)
go grpcServ.Serve(lis)
conn, err := dial(context.Background(), dialer, endpoint)
if err != nil {
return nil, nil, err
}
return grpcServ, conn, nil
}
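// fakeWrite sends the given chunks over the upload stream.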
func fakeWrite(stream bootstrapper.WriteChunkStream, chunks [][]byte) error {
for _, chunk := range chunks {
err := stream.Send(&pb.Chunk{
Content: chunk,
})
if err != nil {
return err
}
}
return nil
}
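// fakeRead receives chunks from the download stream until EOF and collects them.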
func fakeRead(stream bootstrapper.ReadChunkStream) ([][]byte, error) {
var chunks [][]byte
for {
chunk, err := stream.Recv()
if err != nil {
if err == io.EOF {
return chunks, nil
}
return nil, err
}
chunks = append(chunks, chunk.Content)
}
}