AB#2327 move debugd code into internal folder (#403)

* move debugd code into internal folder
* Fix paths in CMakeLists.txt
Signed-off-by: Fabian Kammel <fk@edgeless.systems>
This commit is contained in:
Fabian Kammel 2022-08-26 11:58:18 +02:00 committed by GitHub
parent 708c6e057e
commit 5b40e0cc77
25 changed files with 31 additions and 31 deletions

View file

@ -0,0 +1,31 @@
package debugd
import "time"
const (
DebugdMetadataFlag = "constellation-debugd"
GRPCTimeout = 5 * time.Minute
SSHCheckInterval = 30 * time.Second
DiscoverDebugdInterval = 30 * time.Second
BootstrapperDownloadRetryBackoff = 1 * time.Minute
BootstrapperDeployFilename = "/opt/bootstrapper"
Chunksize = 1024
BootstrapperSystemdUnitName = "bootstrapper.service"
BootstrapperSystemdUnitContents = `[Unit]
Description=Constellation Bootstrapper
Wants=network-online.target
After=network-online.target
[Service]
Type=simple
RemainAfterExit=yes
Restart=on-failure
EnvironmentFile=/etc/constellation.env
ExecStartPre=-setenforce Permissive
ExecStartPre=/usr/bin/mkdir -p /opt/cni/bin/
# merge all CNI binaries in writable folder until containerd can use multiple CNI bins: https://github.com/containerd/containerd/issues/6600
ExecStartPre=/bin/sh -c "/usr/bin/cp /usr/libexec/cni/* /opt/cni/bin/"
ExecStart=/opt/bootstrapper
[Install]
WantedBy=multi-user.target
`
)

View file

@ -0,0 +1,117 @@
package deploy
import (
"context"
"fmt"
"net"
"strconv"
"time"
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
"github.com/edgelesssys/constellation/debugd/internal/debugd"
pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/logger"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// Download downloads a bootstrapper from a given debugd instance.
type Download struct {
log *logger.Logger
dialer NetDialer
writer streamToFileWriter
serviceManager serviceManager
attemptedDownloads map[string]time.Time
}
// New creates a new Download.
func New(log *logger.Logger, dialer NetDialer, serviceManager serviceManager, writer streamToFileWriter) *Download {
return &Download{
log: log,
dialer: dialer,
writer: writer,
serviceManager: serviceManager,
attemptedDownloads: map[string]time.Time{},
}
}
// DownloadDeployment will open a new grpc connection to another instance, attempting to download a bootstrapper from that instance.
func (d *Download) DownloadDeployment(ctx context.Context, ip string) ([]ssh.UserKey, error) {
log := d.log.With(zap.String("ip", ip))
serverAddr := net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort))
// only retry download from same endpoint after backoff
if lastAttempt, ok := d.attemptedDownloads[serverAddr]; ok && time.Since(lastAttempt) < debugd.BootstrapperDownloadRetryBackoff {
return nil, fmt.Errorf("download failed too recently: %v / %v", time.Since(lastAttempt), debugd.BootstrapperDownloadRetryBackoff)
}
log.Infof("Connecting to server")
d.attemptedDownloads[serverAddr] = time.Now()
conn, err := d.dial(ctx, serverAddr)
if err != nil {
return nil, fmt.Errorf("connecting to other instance via gRPC: %w", err)
}
defer conn.Close()
client := pb.NewDebugdClient(conn)
log.Infof("Trying to download bootstrapper")
stream, err := client.DownloadBootstrapper(ctx, &pb.DownloadBootstrapperRequest{})
if err != nil {
return nil, fmt.Errorf("starting bootstrapper download from other instance: %w", err)
}
if err := d.writer.WriteStream(debugd.BootstrapperDeployFilename, stream, true); err != nil {
return nil, fmt.Errorf("streaming bootstrapper from other instance: %w", err)
}
log.Infof("Successfully downloaded bootstrapper")
log.Infof("Trying to download ssh keys")
resp, err := client.DownloadAuthorizedKeys(ctx, &pb.DownloadAuthorizedKeysRequest{})
if err != nil {
return nil, fmt.Errorf("downloading authorized keys: %w", err)
}
var keys []ssh.UserKey
for _, key := range resp.Keys {
keys = append(keys, ssh.UserKey{Username: key.Username, PublicKey: key.KeyValue})
}
// after the upload succeeds, try to restart the bootstrapper
restartAction := ServiceManagerRequest{
Unit: debugd.BootstrapperSystemdUnitName,
Action: Restart,
}
if err := d.serviceManager.SystemdAction(ctx, restartAction); err != nil {
return nil, fmt.Errorf("restarting bootstrapper: %w", err)
}
return keys, nil
}
func (d *Download) dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, target,
d.grpcWithDialer(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
}
func (d *Download) grpcWithDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return d.dialer.DialContext(ctx, "tcp", addr)
})
}
type serviceManager interface {
SystemdAction(ctx context.Context, request ServiceManagerRequest) error
}
type streamToFileWriter interface {
WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error
}
// NetDialer can open a net.Conn.
type NetDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

View file

@ -0,0 +1,187 @@
package deploy
import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"testing"
"time"
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
"github.com/edgelesssys/constellation/debugd/internal/debugd"
pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
"google.golang.org/grpc"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
}
func TestDownloadBootstrapper(t *testing.T) {
filename := "/opt/bootstrapper"
someErr := errors.New("failed")
testCases := map[string]struct {
server fakeDownloadServer
serviceManager stubServiceManager
attemptedDownloads map[string]time.Time
wantChunks [][]byte
wantDownloadErr bool
wantFile bool
wantSystemdAction bool
wantDeployed bool
wantKeys []ssh.UserKey
}{
"download works": {
server: fakeDownloadServer{
chunks: [][]byte{[]byte("test")},
keys: []*pb.AuthorizedKey{{Username: "name", KeyValue: "key"}},
},
attemptedDownloads: map[string]time.Time{},
wantChunks: [][]byte{[]byte("test")},
wantDownloadErr: false,
wantFile: true,
wantSystemdAction: true,
wantDeployed: true,
wantKeys: []ssh.UserKey{{Username: "name", PublicKey: "key"}},
},
"second download is not attempted twice": {
server: fakeDownloadServer{chunks: [][]byte{[]byte("test")}},
attemptedDownloads: map[string]time.Time{"192.0.2.0:" + strconv.Itoa(constants.DebugdPort): time.Now()},
wantDownloadErr: true,
},
"download rpc call error is detected": {
server: fakeDownloadServer{downladErr: someErr},
attemptedDownloads: map[string]time.Time{},
wantDownloadErr: true,
},
"download key error": {
server: fakeDownloadServer{
chunks: [][]byte{[]byte("test")},
downloadAuthorizedKeysErr: someErr,
},
attemptedDownloads: map[string]time.Time{},
wantDownloadErr: true,
},
"service restart error is detected": {
server: fakeDownloadServer{chunks: [][]byte{[]byte("test")}},
serviceManager: stubServiceManager{systemdActionErr: someErr},
attemptedDownloads: map[string]time.Time{},
wantChunks: [][]byte{[]byte("test")},
wantDownloadErr: true,
wantFile: true,
wantDeployed: true,
wantSystemdAction: false,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
ip := "192.0.2.0"
writer := &fakeStreamToFileWriter{}
dialer := testdialer.NewBufconnDialer()
grpcServ := grpc.NewServer()
pb.RegisterDebugdServer(grpcServ, &tc.server)
lis := dialer.GetListener(net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort)))
go grpcServ.Serve(lis)
defer grpcServ.GracefulStop()
download := &Download{
log: logger.NewTest(t),
dialer: dialer,
writer: writer,
serviceManager: &tc.serviceManager,
attemptedDownloads: tc.attemptedDownloads,
}
keys, err := download.DownloadDeployment(context.Background(), ip)
if tc.wantDownloadErr {
assert.Error(err)
} else {
assert.NoError(err)
}
if tc.wantFile {
assert.Equal(tc.wantChunks, writer.chunks)
assert.Equal(filename, writer.filename)
}
if tc.wantSystemdAction {
assert.ElementsMatch(
[]ServiceManagerRequest{
{Unit: debugd.BootstrapperSystemdUnitName, Action: Restart},
},
tc.serviceManager.requests,
)
}
assert.Equal(tc.wantKeys, keys)
})
}
}
type stubServiceManager struct {
requests []ServiceManagerRequest
systemdActionErr error
}
func (s *stubServiceManager) SystemdAction(ctx context.Context, request ServiceManagerRequest) error {
s.requests = append(s.requests, request)
return s.systemdActionErr
}
type fakeStreamToFileWriter struct {
chunks [][]byte
filename string
}
func (f *fakeStreamToFileWriter) WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error {
f.filename = filename
for {
chunk, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return fmt.Errorf("reading stream: %w", err)
}
f.chunks = append(f.chunks, chunk.Content)
}
}
// fakeDownloadServer implements DebugdServer; only fakes DownloadBootstrapper, panics on every other rpc.
type fakeDownloadServer struct {
chunks [][]byte
downladErr error
keys []*pb.AuthorizedKey
downloadAuthorizedKeysErr error
pb.UnimplementedDebugdServer
}
func (f *fakeDownloadServer) DownloadBootstrapper(request *pb.DownloadBootstrapperRequest, stream pb.Debugd_DownloadBootstrapperServer) error {
for _, chunk := range f.chunks {
if err := stream.Send(&pb.Chunk{Content: chunk}); err != nil {
return fmt.Errorf("sending chunk: %w", err)
}
}
return f.downladErr
}
func (s *fakeDownloadServer) DownloadAuthorizedKeys(context.Context, *pb.DownloadAuthorizedKeysRequest) (*pb.DownloadAuthorizedKeysResponse, error) {
return &pb.DownloadAuthorizedKeysResponse{Keys: s.keys}, s.downloadAuthorizedKeysErr
}

View file

@ -0,0 +1,163 @@
package deploy
import (
"context"
"fmt"
"sync"
"github.com/edgelesssys/constellation/debugd/internal/debugd"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/spf13/afero"
"go.uber.org/zap"
)
const (
systemdUnitFolder = "/etc/systemd/system"
)
//go:generate stringer -type=SystemdAction
type SystemdAction uint32
const (
Unknown SystemdAction = iota
Start
Stop
Restart
Reload
)
// ServiceManagerRequest describes a requested ServiceManagerAction to be performed on a specified service unit.
type ServiceManagerRequest struct {
Unit string
Action SystemdAction
}
// SystemdUnit describes a systemd service file including the unit name and contents.
type SystemdUnit struct {
Name string `yaml:"name"`
Contents string `yaml:"contents"`
}
// ServiceManager receives ServiceManagerRequests and units via channels and performs the requests / creates the unit files.
type ServiceManager struct {
log *logger.Logger
dbus dbusClient
fs afero.Fs
systemdUnitFilewriteLock sync.Mutex
}
// NewServiceManager creates a new ServiceManager.
func NewServiceManager(log *logger.Logger) *ServiceManager {
fs := afero.NewOsFs()
return &ServiceManager{
log: log,
dbus: &dbusWrapper{},
fs: fs,
systemdUnitFilewriteLock: sync.Mutex{},
}
}
type dbusClient interface {
// NewSystemConnectionContext establishes a connection to the system bus and authenticates.
// Callers should call Close() when done with the connection.
NewSystemdConnectionContext(ctx context.Context) (dbusConn, error)
}
type dbusConn interface {
// StartUnitContext enqueues a start job and depending jobs, if any (unless otherwise
// specified by the mode string).
StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error)
// StopUnitContext is similar to StartUnitContext, but stops the specified unit
// rather than starting it.
StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error)
// RestartUnitContext restarts a service. If a service is restarted that isn't
// running it will be started.
RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error)
// ReloadContext instructs systemd to scan for and reload unit files. This is
// an equivalent to systemctl daemon-reload.
ReloadContext(ctx context.Context) error
}
// SystemdAction will perform a systemd action on a service unit (start, stop, restart, reload).
func (s *ServiceManager) SystemdAction(ctx context.Context, request ServiceManagerRequest) error {
log := s.log.With(zap.String("unit", request.Unit), zap.String("action", request.Action.String()))
conn, err := s.dbus.NewSystemdConnectionContext(ctx)
if err != nil {
return fmt.Errorf("establishing systemd connection: %w", err)
}
resultChan := make(chan string, 1)
switch request.Action {
case Start:
_, err = conn.StartUnitContext(ctx, request.Unit, "replace", resultChan)
case Stop:
_, err = conn.StopUnitContext(ctx, request.Unit, "replace", resultChan)
case Restart:
_, err = conn.RestartUnitContext(ctx, request.Unit, "replace", resultChan)
case Reload:
err = conn.ReloadContext(ctx)
default:
return fmt.Errorf("unknown systemd action: %s", request.Action.String())
}
if err != nil {
return fmt.Errorf("performing systemd action %v on unit %v: %w", request.Action, request.Unit, err)
}
if request.Action == Reload {
log.Infof("daemon-reload succeeded")
return nil
}
// Wait for the action to finish and then check if it was
// successful or not.
result := <-resultChan
switch result {
case "done":
log.Infof("%s on systemd unit %s succeeded", request.Action, request.Unit)
return nil
default:
return fmt.Errorf("performing action %q on systemd unit %q failed: expected %q but received %q", request.Action.String(), request.Unit, "done", result)
}
}
// WriteSystemdUnitFile will write a systemd unit to disk.
func (s *ServiceManager) WriteSystemdUnitFile(ctx context.Context, unit SystemdUnit) error {
log := s.log.With(zap.String("unitFile", fmt.Sprintf("%s/%s", systemdUnitFolder, unit.Name)))
log.Infof("Writing systemd unit file")
s.systemdUnitFilewriteLock.Lock()
defer s.systemdUnitFilewriteLock.Unlock()
if err := afero.WriteFile(s.fs, fmt.Sprintf("%s/%s", systemdUnitFolder, unit.Name), []byte(unit.Contents), 0o644); err != nil {
return fmt.Errorf("writing systemd unit file \"%v\": %w", unit.Name, err)
}
if err := s.SystemdAction(ctx, ServiceManagerRequest{Unit: unit.Name, Action: Reload}); err != nil {
return fmt.Errorf("performing systemd daemon-reload: %w", err)
}
log.Infof("Wrote systemd unit file and performed daemon-reload")
return nil
}
// DeployDefaultServiceUnit will write the default "bootstrapper.service" unit file.
func DeployDefaultServiceUnit(ctx context.Context, serviceManager *ServiceManager) error {
if err := serviceManager.WriteSystemdUnitFile(ctx, SystemdUnit{
Name: debugd.BootstrapperSystemdUnitName,
Contents: debugd.BootstrapperSystemdUnitContents,
}); err != nil {
return fmt.Errorf("writing systemd unit file %q: %w", debugd.BootstrapperSystemdUnitName, err)
}
// try to start the default service if the binary exists but ignore failure.
// this is meant to start the bootstrapper after a reboot
// if a bootstrapper binary was uploaded before.
if ok, err := afero.Exists(serviceManager.fs, debugd.BootstrapperDeployFilename); ok && err == nil {
_ = serviceManager.SystemdAction(ctx, ServiceManagerRequest{
Unit: debugd.BootstrapperSystemdUnitName,
Action: Start,
})
}
return nil
}

View file

@ -0,0 +1,242 @@
package deploy
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSystemdAction(t *testing.T) {
unitName := "example.service"
testCases := map[string]struct {
dbus stubDbus
action SystemdAction
wantErr bool
}{
"start works": {
dbus: stubDbus{
conn: &fakeDbusConn{
result: "done",
},
},
action: Start,
wantErr: false,
},
"stop works": {
dbus: stubDbus{
conn: &fakeDbusConn{
result: "done",
},
},
action: Stop,
wantErr: false,
},
"restart works": {
dbus: stubDbus{
conn: &fakeDbusConn{
result: "done",
},
},
action: Restart,
wantErr: false,
},
"reload works": {
dbus: stubDbus{
conn: &fakeDbusConn{},
},
action: Reload,
wantErr: false,
},
"unknown action": {
dbus: stubDbus{
conn: &fakeDbusConn{},
},
action: Unknown,
wantErr: true,
},
"action fails": {
dbus: stubDbus{
conn: &fakeDbusConn{
actionErr: errors.New("action fails"),
},
},
action: Start,
wantErr: true,
},
"action result is failure": {
dbus: stubDbus{
conn: &fakeDbusConn{
result: "failure",
},
},
action: Start,
wantErr: true,
},
"newConn fails": {
dbus: stubDbus{
connErr: errors.New("newConn fails"),
},
action: Start,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fs := afero.NewMemMapFs()
manager := ServiceManager{
log: logger.NewTest(t),
dbus: &tc.dbus,
fs: fs,
systemdUnitFilewriteLock: sync.Mutex{},
}
err := manager.SystemdAction(context.Background(), ServiceManagerRequest{
Unit: unitName,
Action: tc.action,
})
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
})
}
}
func TestWriteSystemdUnitFile(t *testing.T) {
testCases := map[string]struct {
dbus stubDbus
unit SystemdUnit
readonly bool
wantErr bool
wantFileContents string
}{
"start works": {
dbus: stubDbus{
conn: &fakeDbusConn{
result: "done",
},
},
unit: SystemdUnit{
Name: "test.service",
Contents: "testservicefilecontents",
},
wantErr: false,
wantFileContents: "testservicefilecontents",
},
"write fails": {
dbus: stubDbus{
conn: &fakeDbusConn{
result: "done",
},
},
unit: SystemdUnit{
Name: "test.service",
Contents: "testservicefilecontents",
},
readonly: true,
wantErr: true,
},
"systemd reload fails": {
dbus: stubDbus{
conn: &fakeDbusConn{
actionErr: errors.New("reload error"),
},
},
unit: SystemdUnit{
Name: "test.service",
Contents: "testservicefilecontents",
},
readonly: false,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fs := afero.NewMemMapFs()
assert.NoError(fs.MkdirAll(systemdUnitFolder, 0o755))
if tc.readonly {
fs = afero.NewReadOnlyFs(fs)
}
manager := ServiceManager{
log: logger.NewTest(t),
dbus: &tc.dbus,
fs: fs,
systemdUnitFilewriteLock: sync.Mutex{},
}
err := manager.WriteSystemdUnitFile(context.Background(), tc.unit)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
fileContents, err := afero.ReadFile(fs, fmt.Sprintf("%s/%s", systemdUnitFolder, tc.unit.Name))
assert.NoError(err)
assert.Equal(tc.wantFileContents, string(fileContents))
})
}
}
type stubDbus struct {
conn dbusConn
connErr error
}
func (s *stubDbus) NewSystemdConnectionContext(ctx context.Context) (dbusConn, error) {
return s.conn, s.connErr
}
type dbusConnActionInput struct {
name string
mode string
}
type fakeDbusConn struct {
inputs []dbusConnActionInput
result string
jobID int
actionErr error
}
func (c *fakeDbusConn) StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
c.inputs = append(c.inputs, dbusConnActionInput{name: name, mode: mode})
ch <- c.result
return c.jobID, c.actionErr
}
func (c *fakeDbusConn) StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
c.inputs = append(c.inputs, dbusConnActionInput{name: name, mode: mode})
ch <- c.result
return c.jobID, c.actionErr
}
func (c *fakeDbusConn) RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
c.inputs = append(c.inputs, dbusConnActionInput{name: name, mode: mode})
ch <- c.result
return c.jobID, c.actionErr
}
func (c *fakeDbusConn) ReloadContext(ctx context.Context) error {
return c.actionErr
}

View file

@ -0,0 +1,27 @@
// Code generated by "stringer -type=SystemdAction"; DO NOT EDIT.
package deploy
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[Unknown-0]
_ = x[Start-1]
_ = x[Stop-2]
_ = x[Restart-3]
_ = x[Reload-4]
}
const _SystemdAction_name = "UnknownStartStopRestartReload"
var _SystemdAction_index = [...]uint8{0, 7, 12, 16, 23, 29}
func (i SystemdAction) String() string {
if i >= SystemdAction(len(_SystemdAction_index)-1) {
return "SystemdAction(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _SystemdAction_name[_SystemdAction_index[i]:_SystemdAction_index[i+1]]
}

View file

@ -0,0 +1,40 @@
package deploy
import (
"context"
"github.com/coreos/go-systemd/v22/dbus"
)
// wraps go-systemd dbus.
type dbusWrapper struct{}
func (d *dbusWrapper) NewSystemdConnectionContext(ctx context.Context) (dbusConn, error) {
conn, err := dbus.NewSystemdConnectionContext(ctx)
if err != nil {
return nil, err
}
return &dbusConnWrapper{
conn: conn,
}, nil
}
type dbusConnWrapper struct {
conn *dbus.Conn
}
func (c *dbusConnWrapper) StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.conn.StartUnitContext(ctx, name, mode, ch)
}
func (c *dbusConnWrapper) StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.conn.StopUnitContext(ctx, name, mode, ch)
}
func (c *dbusConnWrapper) RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
return c.conn.RestartUnitContext(ctx, name, mode, ch)
}
func (c *dbusConnWrapper) ReloadContext(ctx context.Context) error {
return c.conn.ReloadContext(ctx)
}

View file

@ -0,0 +1,130 @@
package cloudprovider
import (
"context"
"fmt"
"net"
azurecloud "github.com/edgelesssys/constellation/bootstrapper/cloudprovider/azure"
gcpcloud "github.com/edgelesssys/constellation/bootstrapper/cloudprovider/gcp"
qemucloud "github.com/edgelesssys/constellation/bootstrapper/cloudprovider/qemu"
"github.com/edgelesssys/constellation/bootstrapper/role"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
)
type providerMetadata interface {
// List retrieves all instances belonging to the current constellation.
List(ctx context.Context) ([]metadata.InstanceMetadata, error)
// Self retrieves the current instance.
Self(ctx context.Context) (metadata.InstanceMetadata, error)
// GetLoadBalancerEndpoint returns the endpoint of the load balancer.
GetLoadBalancerEndpoint(ctx context.Context) (string, error)
}
// Fetcher checks the metadata service to search for instances that were set up for debugging and cloud provider specific SSH keys.
type Fetcher struct {
metaAPI providerMetadata
}
// NewGCP creates a new GCP fetcher.
func NewGCP(ctx context.Context) (*Fetcher, error) {
gcpClient, err := gcpcloud.NewClient(ctx)
if err != nil {
return nil, err
}
metaAPI := gcpcloud.New(gcpClient)
return &Fetcher{
metaAPI: metaAPI,
}, nil
}
// NewAzure creates a new Azure fetcher.
func NewAzure(ctx context.Context) (*Fetcher, error) {
metaAPI, err := azurecloud.NewMetadata(ctx)
if err != nil {
return nil, err
}
return &Fetcher{
metaAPI: metaAPI,
}, nil
}
func NewQEMU() *Fetcher {
return &Fetcher{
metaAPI: &qemucloud.Metadata{},
}
}
func (f *Fetcher) Role(ctx context.Context) (role.Role, error) {
self, err := f.metaAPI.Self(ctx)
if err != nil {
return role.Unknown, fmt.Errorf("retrieving role from cloud provider metadata: %w", err)
}
return self.Role, nil
}
// DiscoverDebugdIPs will query the metadata of all instances and return any ips of instances already set up for debugging.
func (f *Fetcher) DiscoverDebugdIPs(ctx context.Context) ([]string, error) {
self, err := f.metaAPI.Self(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving own instance: %w", err)
}
instances, err := f.metaAPI.List(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving instances: %w", err)
}
// filter own instance from instance list
for i, instance := range instances {
if instance.ProviderID == self.ProviderID {
instances = append(instances[:i], instances[i+1:]...)
break
}
}
var ips []string
for _, instance := range instances {
if instance.VPCIP != "" {
ips = append(ips, instance.VPCIP)
}
}
return ips, nil
}
func (f *Fetcher) DiscoverLoadbalancerIP(ctx context.Context) (string, error) {
lbEndpoint, err := f.metaAPI.GetLoadBalancerEndpoint(ctx)
if err != nil {
return "", fmt.Errorf("retrieving load balancer endpoint: %w", err)
}
// The port of the endpoint is not the port we need. We need to strip it off.
//
// TODO: Tag the specific load balancer we are looking for with a distinct tag.
// Change the GetLoadBalancerEndpoint method to return the endpoint of a load
// balancer with a given tag.
lbIP, _, err := net.SplitHostPort(lbEndpoint)
if err != nil {
return "", fmt.Errorf("parsing load balancer endpoint: %w", err)
}
return lbIP, nil
}
// FetchSSHKeys will query the metadata of the current instance and deploys any SSH keys found.
func (f *Fetcher) FetchSSHKeys(ctx context.Context) ([]ssh.UserKey, error) {
self, err := f.metaAPI.Self(ctx)
if err != nil {
return nil, fmt.Errorf("retrieving ssh keys from cloud provider metadata: %w", err)
}
keys := []ssh.UserKey{}
for username, userKeys := range self.SSHKeys {
for _, keyValue := range userKeys {
keys = append(keys, ssh.UserKey{Username: username, PublicKey: keyValue})
}
}
return keys, nil
}

View file

@ -0,0 +1,244 @@
package cloudprovider
import (
"context"
"errors"
"testing"
"github.com/edgelesssys/constellation/bootstrapper/role"
"github.com/edgelesssys/constellation/internal/cloud/metadata"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m,
// https://github.com/census-instrumentation/opencensus-go/issues/1262
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
)
}
func TestRole(t *testing.T) {
instance1 := metadata.InstanceMetadata{Role: role.ControlPlane}
instance2 := metadata.InstanceMetadata{Role: role.Worker}
testCases := map[string]struct {
meta *stubMetadata
wantErr bool
wantRole role.Role
}{
"control plane": {
meta: &stubMetadata{selfRes: instance1},
wantRole: role.ControlPlane,
},
"worker": {
meta: &stubMetadata{selfRes: instance2},
wantRole: role.Worker,
},
"self fails": {
meta: &stubMetadata{selfErr: errors.New("some err")},
wantErr: true,
wantRole: role.Unknown,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
fetcher := Fetcher{tc.meta}
role, err := fetcher.Role(context.Background())
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.wantRole, role)
}
})
}
}
func TestDiscoverDebugIPs(t *testing.T) {
err := errors.New("some err")
testCases := map[string]struct {
meta stubMetadata
wantIPs []string
wantErr bool
}{
"disovery works": {
meta: stubMetadata{
listRes: []metadata.InstanceMetadata{
{
VPCIP: "192.0.2.0",
},
{
VPCIP: "192.0.2.1",
},
{
VPCIP: "192.0.2.2",
},
},
},
wantIPs: []string{
"192.0.2.1", "192.0.2.2",
},
},
"retrieve fails": {
meta: stubMetadata{
listErr: err,
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fetcher := Fetcher{
metaAPI: &tc.meta,
}
ips, err := fetcher.DiscoverDebugdIPs(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.ElementsMatch(tc.wantIPs, ips)
})
}
}
func TestDiscoverLoadbalancerIP(t *testing.T) {
ip := "192.0.2.1"
endpoint := ip + ":1234"
someErr := errors.New("failed")
testCases := map[string]struct {
metaAPI providerMetadata
wantIP string
wantErr bool
}{
"discovery works": {
metaAPI: &stubMetadata{getLBEndpointRes: endpoint},
wantIP: ip,
},
"get endpoint fails": {
metaAPI: &stubMetadata{getLBEndpointErr: someErr},
wantErr: true,
},
"invalid endpoint": {
metaAPI: &stubMetadata{getLBEndpointRes: "invalid"},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
fetcher := &Fetcher{
metaAPI: tc.metaAPI,
}
ip, err := fetcher.DiscoverLoadbalancerIP(context.Background())
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
assert.Equal(tc.wantIP, ip)
}
})
}
}
func TestFetchSSHKeys(t *testing.T) {
err := errors.New("some err")
testCases := map[string]struct {
meta stubMetadata
wantKeys []ssh.UserKey
wantErr bool
}{
"fetch works": {
meta: stubMetadata{
selfRes: metadata.InstanceMetadata{
Name: "name",
ProviderID: "provider-id",
SSHKeys: map[string][]string{"bob": {"ssh-rsa bobskey"}},
},
},
wantKeys: []ssh.UserKey{
{
Username: "bob",
PublicKey: "ssh-rsa bobskey",
},
},
},
"retrieve fails": {
meta: stubMetadata{
selfErr: err,
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fetcher := Fetcher{
metaAPI: &tc.meta,
}
keys, err := fetcher.FetchSSHKeys(context.Background())
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.ElementsMatch(tc.wantKeys, keys)
})
}
}
type stubMetadata struct {
listRes []metadata.InstanceMetadata
listErr error
selfRes metadata.InstanceMetadata
selfErr error
getInstanceRes metadata.InstanceMetadata
getInstanceErr error
getLBEndpointRes string
getLBEndpointErr error
supportedRes bool
}
func (m *stubMetadata) List(ctx context.Context) ([]metadata.InstanceMetadata, error) {
return m.listRes, m.listErr
}
func (m *stubMetadata) Self(ctx context.Context) (metadata.InstanceMetadata, error) {
return m.selfRes, m.selfErr
}
func (m *stubMetadata) GetInstance(ctx context.Context, providerID string) (metadata.InstanceMetadata, error) {
return m.getInstanceRes, m.getInstanceErr
}
func (m *stubMetadata) GetLoadBalancerEndpoint(ctx context.Context) (string, error) {
return m.getLBEndpointRes, m.getLBEndpointErr
}
func (m *stubMetadata) Supported() bool {
return m.supportedRes
}

View file

@ -0,0 +1,31 @@
package fallback
import (
"context"
"github.com/edgelesssys/constellation/bootstrapper/role"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
)
// Fetcher implements metadata.Fetcher interface but does not actually fetch cloud provider metadata.
type Fetcher struct{}
func (f Fetcher) Role(_ context.Context) (role.Role, error) {
// Fallback fetcher does not try to fetch role
return role.Unknown, nil
}
func (f Fetcher) DiscoverDebugdIPs(ctx context.Context) ([]string, error) {
// Fallback fetcher does not try to discover debugd IPs
return nil, nil
}
func (f Fetcher) DiscoverLoadbalancerIP(ctx context.Context) (string, error) {
// Fallback fetcher does not try to discover loadbalancer IP
return "", nil
}
func (f Fetcher) FetchSSHKeys(ctx context.Context) ([]ssh.UserKey, error) {
// Fallback fetcher does not try to fetch ssh keys
return nil, nil
}

View file

@ -0,0 +1,33 @@
package fallback
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestDiscoverDebugdIPs(t *testing.T) {
assert := assert.New(t)
fetcher := Fetcher{}
ips, err := fetcher.DiscoverDebugdIPs(context.Background())
assert.NoError(err)
assert.Empty(ips)
}
func TestFetchSSHKeys(t *testing.T) {
assert := assert.New(t)
fetcher := Fetcher{}
keys, err := fetcher.FetchSSHKeys(context.Background())
assert.NoError(err)
assert.Empty(keys)
}

View file

@ -0,0 +1,143 @@
package metadata
import (
"context"
"errors"
"io/fs"
"sync"
"time"
"github.com/edgelesssys/constellation/bootstrapper/role"
"github.com/edgelesssys/constellation/debugd/internal/debugd"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/logger"
"go.uber.org/zap"
)
// Fetcher retrieves other debugd IPs and SSH keys from cloud provider metadata.
type Fetcher interface {
Role(ctx context.Context) (role.Role, error)
DiscoverDebugdIPs(ctx context.Context) ([]string, error)
FetchSSHKeys(ctx context.Context) ([]ssh.UserKey, error)
DiscoverLoadbalancerIP(ctx context.Context) (string, error)
}
// Scheduler schedules fetching of metadata using timers.
type Scheduler struct {
log *logger.Logger
fetcher Fetcher
ssh sshDeployer
downloader downloader
}
// NewScheduler returns a new scheduler.
func NewScheduler(log *logger.Logger, fetcher Fetcher, ssh sshDeployer, downloader downloader) *Scheduler {
return &Scheduler{
log: log,
fetcher: fetcher,
ssh: ssh,
downloader: downloader,
}
}
// Start will start the loops for discovering debugd endpoints and ssh keys.
func (s *Scheduler) Start(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
wg.Add(2)
go s.discoveryLoop(ctx, wg)
go s.sshLoop(ctx, wg)
}
// discoveryLoop discovers new debugd endpoints from cloud-provider metadata periodically.
func (s *Scheduler) discoveryLoop(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
// execute debugd discovery once at the start to skip wait for first tick
ips, err := s.fetcher.DiscoverDebugdIPs(ctx)
if err != nil {
s.log.With(zap.Error(err)).Errorf("Discovering debugd IPs failed")
} else {
if s.downloadDeployment(ctx, ips) {
return
}
}
ticker := time.NewTicker(debugd.DiscoverDebugdInterval)
defer ticker.Stop()
for {
var err error
select {
case <-ticker.C:
ips, err = s.fetcher.DiscoverDebugdIPs(ctx)
if err != nil {
s.log.With(zap.Error(err)).Errorf("Discovering debugd IPs failed")
continue
}
s.log.With(zap.Strings("ips", ips)).Infof("Discovered instances")
if s.downloadDeployment(ctx, ips) {
return
}
case <-ctx.Done():
return
}
}
}
// sshLoop discovers new ssh keys from cloud provider metadata periodically.
func (s *Scheduler) sshLoop(ctx context.Context, wg *sync.WaitGroup) {
defer wg.Done()
ticker := time.NewTicker(debugd.SSHCheckInterval)
defer ticker.Stop()
for {
keys, err := s.fetcher.FetchSSHKeys(ctx)
if err != nil {
s.log.With(zap.Error(err)).Errorf("Fetching SSH keys failed")
} else {
s.deploySSHKeys(ctx, keys)
}
select {
case <-ticker.C:
case <-ctx.Done():
return
}
}
}
// downloadDeployment tries to download deployment from a list of ips and logs errors encountered.
func (s *Scheduler) downloadDeployment(ctx context.Context, ips []string) (success bool) {
for _, ip := range ips {
keys, err := s.downloader.DownloadDeployment(ctx, ip)
if err == nil {
s.deploySSHKeys(ctx, keys)
return true
}
if errors.Is(err, fs.ErrExist) {
// bootstrapper was already uploaded
s.log.Infof("Bootstrapper was already uploaded.")
return true
}
s.log.With(zap.Error(err), zap.String("peer", ip)).Errorf("Downloading deployment from peer failed")
}
return false
}
// deploySSHKeys tries to deploy a list of SSH keys and logs errors encountered.
func (s *Scheduler) deploySSHKeys(ctx context.Context, keys []ssh.UserKey) {
for _, key := range keys {
err := s.ssh.DeployAuthorizedKey(ctx, key)
if err != nil {
s.log.With(zap.Error(err), zap.Any("key", key)).Errorf("Deploying SSH key failed")
continue
}
}
}
type downloader interface {
DownloadDeployment(ctx context.Context, ip string) ([]ssh.UserKey, error)
}
type sshDeployer interface {
DeployAuthorizedKey(ctx context.Context, sshKey ssh.UserKey) error
}

View file

@ -0,0 +1,134 @@
package metadata
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/edgelesssys/constellation/bootstrapper/role"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestSchedulerStart(t *testing.T) {
someErr := errors.New("failed")
testCases := map[string]struct {
fetcher stubFetcher
ssh stubSSHDeployer
downloader stubDownloader
timeout time.Duration
wantSSHKeys []ssh.UserKey
wantDebugdDownloads []string
}{
"scheduler works and calls fetcher functions at least once": {},
"ssh keys are fetched": {
fetcher: stubFetcher{
keys: []ssh.UserKey{{Username: "test", PublicKey: "testkey"}},
},
wantSSHKeys: []ssh.UserKey{{Username: "test", PublicKey: "testkey"}},
},
"download for discovered debugd ips is started": {
fetcher: stubFetcher{
ips: []string{"192.0.2.1", "192.0.2.2"},
},
downloader: stubDownloader{downloadErr: someErr},
wantDebugdDownloads: []string{"192.0.2.1", "192.0.2.2"},
},
"if download is successful, second download is not attempted": {
fetcher: stubFetcher{
ips: []string{"192.0.2.1", "192.0.2.2"},
},
wantDebugdDownloads: []string{"192.0.2.1"},
},
"endpoint discovery can fail": {
fetcher: stubFetcher{discoverErr: someErr},
},
"ssh key fetch can fail": {
fetcher: stubFetcher{fetchSSHKeysErr: someErr},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
wg := &sync.WaitGroup{}
ctx, cancel := context.WithTimeout(context.Background(), tc.timeout)
defer cancel()
scheduler := Scheduler{
log: logger.NewTest(t),
fetcher: &tc.fetcher,
ssh: &tc.ssh,
downloader: &tc.downloader,
}
wg.Add(1)
go scheduler.Start(ctx, wg)
wg.Wait()
assert.Equal(tc.wantSSHKeys, tc.ssh.sshKeys)
assert.Equal(tc.wantDebugdDownloads, tc.downloader.ips)
assert.Greater(tc.fetcher.discoverCalls, 0)
assert.Greater(tc.fetcher.fetchSSHKeysCalls, 0)
})
}
}
type stubFetcher struct {
discoverCalls int
fetchSSHKeysCalls int
ips []string
keys []ssh.UserKey
discoverErr error
fetchSSHKeysErr error
}
func (s *stubFetcher) Role(_ context.Context) (role.Role, error) {
return role.Unknown, nil
}
func (s *stubFetcher) DiscoverDebugdIPs(ctx context.Context) ([]string, error) {
s.discoverCalls++
return s.ips, s.discoverErr
}
func (s *stubFetcher) FetchSSHKeys(ctx context.Context) ([]ssh.UserKey, error) {
s.fetchSSHKeysCalls++
return s.keys, s.fetchSSHKeysErr
}
func (s *stubFetcher) DiscoverLoadbalancerIP(ctx context.Context) (string, error) {
return "", nil
}
type stubSSHDeployer struct {
sshKeys []ssh.UserKey
deployErr error
}
func (s *stubSSHDeployer) DeployAuthorizedKey(ctx context.Context, sshKey ssh.UserKey) error {
s.sshKeys = append(s.sshKeys, sshKey)
return s.deployErr
}
type stubDownloader struct {
ips []string
downloadErr error
keys []ssh.UserKey
}
func (s *stubDownloader) DownloadDeployment(ctx context.Context, ip string) ([]ssh.UserKey, error) {
s.ips = append(s.ips, ip)
return s.keys, s.downloadErr
}

View file

@ -0,0 +1,162 @@
package server
import (
"context"
"errors"
"fmt"
"io/fs"
"net"
"strconv"
"sync"
"time"
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
"github.com/edgelesssys/constellation/debugd/internal/debugd"
"github.com/edgelesssys/constellation/debugd/internal/debugd/deploy"
pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/logger"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
)
type debugdServer struct {
log *logger.Logger
ssh sshDeployer
serviceManager serviceManager
streamer streamer
pb.UnimplementedDebugdServer
}
// New creates a new debugdServer according to the gRPC spec.
func New(log *logger.Logger, ssh sshDeployer, serviceManager serviceManager, streamer streamer) pb.DebugdServer {
return &debugdServer{
log: log,
ssh: ssh,
serviceManager: serviceManager,
streamer: streamer,
}
}
// UploadAuthorizedKeys receives a list of authorized keys and forwards them to a channel.
func (s *debugdServer) UploadAuthorizedKeys(ctx context.Context, in *pb.UploadAuthorizedKeysRequest) (*pb.UploadAuthorizedKeysResponse, error) {
s.log.Infof("Uploading authorized keys")
for _, key := range in.Keys {
if err := s.ssh.DeployAuthorizedKey(ctx, ssh.UserKey{Username: key.Username, PublicKey: key.KeyValue}); err != nil {
s.log.With(zap.Error(err)).Errorf("Uploading authorized keys failed")
return &pb.UploadAuthorizedKeysResponse{
Status: pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_FAILURE,
}, nil
}
}
return &pb.UploadAuthorizedKeysResponse{
Status: pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_SUCCESS,
}, nil
}
// UploadBootstrapper receives a bootstrapper binary in a stream of chunks and writes to a file.
func (s *debugdServer) UploadBootstrapper(stream pb.Debugd_UploadBootstrapperServer) error {
startAction := deploy.ServiceManagerRequest{
Unit: debugd.BootstrapperSystemdUnitName,
Action: deploy.Start,
}
var responseStatus pb.UploadBootstrapperStatus
defer func() {
if err := s.serviceManager.SystemdAction(stream.Context(), startAction); err != nil {
s.log.With(zap.Error(err)).Errorf("Starting uploaded bootstrapper failed")
if responseStatus == pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS {
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_START_FAILED
}
}
stream.SendAndClose(&pb.UploadBootstrapperResponse{
Status: responseStatus,
})
}()
s.log.Infof("Starting bootstrapper upload")
if err := s.streamer.WriteStream(debugd.BootstrapperDeployFilename, stream, true); err != nil {
if errors.Is(err, fs.ErrExist) {
// bootstrapper was already uploaded
s.log.Warnf("Bootstrapper already uploaded")
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_FILE_EXISTS
return nil
}
s.log.With(zap.Error(err)).Errorf("Uploading bootstrapper failed")
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_UPLOAD_FAILED
return fmt.Errorf("uploading bootstrapper: %w", err)
}
s.log.Infof("Successfully uploaded bootstrapper")
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS
return nil
}
// DownloadBootstrapper streams the local bootstrapper binary to other instances.
func (s *debugdServer) DownloadBootstrapper(request *pb.DownloadBootstrapperRequest, stream pb.Debugd_DownloadBootstrapperServer) error {
s.log.Infof("Sending bootstrapper to other instance")
return s.streamer.ReadStream(debugd.BootstrapperDeployFilename, stream, debugd.Chunksize, true)
}
// DownloadAuthorizedKeys streams the local authorized keys to other instances.
func (s *debugdServer) DownloadAuthorizedKeys(_ context.Context, req *pb.DownloadAuthorizedKeysRequest) (*pb.DownloadAuthorizedKeysResponse, error) {
s.log.Infof("Sending authorized keys to other instance")
var authKeys []*pb.AuthorizedKey
for _, key := range s.ssh.GetAuthorizedKeys() {
authKeys = append(authKeys, &pb.AuthorizedKey{
Username: key.Username,
KeyValue: key.PublicKey,
})
}
return &pb.DownloadAuthorizedKeysResponse{Keys: authKeys}, nil
}
// UploadSystemServiceUnits receives systemd service units, writes them to a service file and schedules a daemon-reload.
func (s *debugdServer) UploadSystemServiceUnits(ctx context.Context, in *pb.UploadSystemdServiceUnitsRequest) (*pb.UploadSystemdServiceUnitsResponse, error) {
s.log.Infof("Uploading systemd service units")
for _, unit := range in.Units {
if err := s.serviceManager.WriteSystemdUnitFile(ctx, deploy.SystemdUnit{Name: unit.Name, Contents: unit.Contents}); err != nil {
return &pb.UploadSystemdServiceUnitsResponse{Status: pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_FAILURE}, nil
}
}
return &pb.UploadSystemdServiceUnitsResponse{Status: pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_SUCCESS}, nil
}
// Start will start the gRPC server and block.
func Start(log *logger.Logger, wg *sync.WaitGroup, serv pb.DebugdServer) {
defer wg.Done()
grpcLog := log.Named("gRPC")
grpcLog.WithIncreasedLevel(zap.WarnLevel).ReplaceGRPCLogger()
grpcServer := grpc.NewServer(
grpcLog.GetServerStreamInterceptor(),
grpcLog.GetServerUnaryInterceptor(),
grpc.KeepaliveParams(keepalive.ServerParameters{Time: 15 * time.Second}),
)
pb.RegisterDebugdServer(grpcServer, serv)
lis, err := net.Listen("tcp", net.JoinHostPort("0.0.0.0", strconv.Itoa(constants.DebugdPort)))
if err != nil {
log.With(zap.Error(err)).Fatalf("Listening failed")
}
log.Infof("gRPC server is waiting for connections")
grpcServer.Serve(lis)
}
type sshDeployer interface {
DeployAuthorizedKey(ctx context.Context, sshKey ssh.UserKey) error
GetAuthorizedKeys() []ssh.UserKey
}
type serviceManager interface {
SystemdAction(ctx context.Context, request deploy.ServiceManagerRequest) error
WriteSystemdUnitFile(ctx context.Context, unit deploy.SystemdUnit) error
}
type streamer interface {
WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error
ReadStream(filename string, stream bootstrapper.WriteChunkStream, chunksize uint, showProgress bool) error
}

View file

@ -0,0 +1,490 @@
package server
import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"testing"
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
"github.com/edgelesssys/constellation/debugd/internal/debugd/deploy"
pb "github.com/edgelesssys/constellation/debugd/service"
"github.com/edgelesssys/constellation/internal/constants"
"github.com/edgelesssys/constellation/internal/deploy/ssh"
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
"github.com/edgelesssys/constellation/internal/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestUploadAuthorizedKeys(t *testing.T) {
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
testCases := map[string]struct {
ssh stubSSHDeployer
serviceManager stubServiceManager
request *pb.UploadAuthorizedKeysRequest
wantErr bool
wantResponseStatus pb.UploadAuthorizedKeysStatus
wantKeys []ssh.UserKey
}{
"upload authorized keys works": {
request: &pb.UploadAuthorizedKeysRequest{
Keys: []*pb.AuthorizedKey{
{
Username: "testuser",
KeyValue: "teskey",
},
},
},
wantResponseStatus: pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_SUCCESS,
wantKeys: []ssh.UserKey{
{
Username: "testuser",
PublicKey: "teskey",
},
},
},
"deploy fails": {
request: &pb.UploadAuthorizedKeysRequest{
Keys: []*pb.AuthorizedKey{
{
Username: "testuser",
KeyValue: "teskey",
},
},
},
ssh: stubSSHDeployer{deployErr: errors.New("ssh key deployment error")},
wantResponseStatus: pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_FAILURE,
wantKeys: []ssh.UserKey{
{
Username: "testuser",
PublicKey: "teskey",
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
serv := debugdServer{
log: logger.NewTest(t),
ssh: &tc.ssh,
serviceManager: &tc.serviceManager,
streamer: &fakeStreamer{},
}
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
require.NoError(err)
defer conn.Close()
client := pb.NewDebugdClient(conn)
resp, err := client.UploadAuthorizedKeys(context.Background(), tc.request)
grpcServ.GracefulStop()
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantResponseStatus, resp.Status)
assert.ElementsMatch(tc.ssh.sshKeys, tc.wantKeys)
})
}
}
func TestUploadBootstrapper(t *testing.T) {
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
testCases := map[string]struct {
ssh stubSSHDeployer
serviceManager stubServiceManager
streamer fakeStreamer
uploadChunks [][]byte
wantErr bool
wantResponseStatus pb.UploadBootstrapperStatus
wantFile bool
wantChunks [][]byte
}{
"upload works": {
uploadChunks: [][]byte{
[]byte("test"),
},
wantFile: true,
wantChunks: [][]byte{
[]byte("test"),
},
wantResponseStatus: pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS,
},
"recv fails": {
streamer: fakeStreamer{
writeStreamErr: errors.New("recv error"),
},
wantResponseStatus: pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_UPLOAD_FAILED,
wantErr: true,
},
"starting bootstrapper fails": {
uploadChunks: [][]byte{
[]byte("test"),
},
serviceManager: stubServiceManager{
systemdActionErr: errors.New("starting bootstrapper error"),
},
wantFile: true,
wantChunks: [][]byte{
[]byte("test"),
},
wantResponseStatus: pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_START_FAILED,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
serv := debugdServer{
log: logger.NewTest(t),
ssh: &tc.ssh,
serviceManager: &tc.serviceManager,
streamer: &tc.streamer,
}
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
require.NoError(err)
defer conn.Close()
client := pb.NewDebugdClient(conn)
stream, err := client.UploadBootstrapper(context.Background())
require.NoError(err)
require.NoError(fakeWrite(stream, tc.uploadChunks))
resp, err := stream.CloseAndRecv()
grpcServ.GracefulStop()
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantResponseStatus, resp.Status)
if tc.wantFile {
assert.Equal(tc.wantChunks, tc.streamer.writeStreamChunks)
assert.Equal("/opt/bootstrapper", tc.streamer.writeStreamFilename)
} else {
assert.Empty(tc.streamer.writeStreamChunks)
assert.Empty(tc.streamer.writeStreamFilename)
}
})
}
}
func TestDownloadBootstrapper(t *testing.T) {
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
testCases := map[string]struct {
ssh stubSSHDeployer
serviceManager stubServiceManager
request *pb.DownloadBootstrapperRequest
streamer fakeStreamer
wantErr bool
wantChunks [][]byte
}{
"download works": {
request: &pb.DownloadBootstrapperRequest{},
streamer: fakeStreamer{
readStreamChunks: [][]byte{
[]byte("test"),
},
},
wantErr: false,
wantChunks: [][]byte{
[]byte("test"),
},
},
"download fails": {
request: &pb.DownloadBootstrapperRequest{},
streamer: fakeStreamer{
readStreamErr: errors.New("read bootstrapper fails"),
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
serv := debugdServer{
log: logger.NewTest(t),
ssh: &tc.ssh,
serviceManager: &tc.serviceManager,
streamer: &tc.streamer,
}
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
require.NoError(err)
defer conn.Close()
client := pb.NewDebugdClient(conn)
stream, err := client.DownloadBootstrapper(context.Background(), tc.request)
require.NoError(err)
chunks, err := fakeRead(stream)
grpcServ.GracefulStop()
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantChunks, chunks)
assert.Equal("/opt/bootstrapper", tc.streamer.readStreamFilename)
})
}
}
func TestDownloadAuthorizedKeys(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
deployer := &stubSSHDeployer{
sshKeys: []ssh.UserKey{
{Username: "test1", PublicKey: "foo"},
{Username: "test2", PublicKey: "bar"},
},
}
serv := debugdServer{
log: logger.NewTest(t),
ssh: deployer,
}
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
require.NoError(err)
defer conn.Close()
defer grpcServ.GracefulStop()
client := pb.NewDebugdClient(conn)
resp, err := client.DownloadAuthorizedKeys(context.Background(), &pb.DownloadAuthorizedKeysRequest{})
assert.NoError(err)
wantKeys := []*pb.AuthorizedKey{
{Username: "test1", KeyValue: "foo"},
{Username: "test2", KeyValue: "bar"},
}
assert.ElementsMatch(wantKeys, resp.Keys)
}
func TestUploadSystemServiceUnits(t *testing.T) {
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
testCases := map[string]struct {
ssh stubSSHDeployer
serviceManager stubServiceManager
request *pb.UploadSystemdServiceUnitsRequest
wantErr bool
wantResponseStatus pb.UploadSystemdServiceUnitsStatus
wantUnitFiles []deploy.SystemdUnit
}{
"upload systemd service units": {
request: &pb.UploadSystemdServiceUnitsRequest{
Units: []*pb.ServiceUnit{
{
Name: "test.service",
Contents: "testcontents",
},
},
},
wantResponseStatus: pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_SUCCESS,
wantUnitFiles: []deploy.SystemdUnit{
{
Name: "test.service",
Contents: "testcontents",
},
},
},
"writing fails": {
request: &pb.UploadSystemdServiceUnitsRequest{
Units: []*pb.ServiceUnit{
{
Name: "test.service",
Contents: "testcontents",
},
},
},
serviceManager: stubServiceManager{
writeSystemdUnitFileErr: errors.New("write error"),
},
wantResponseStatus: pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_FAILURE,
wantUnitFiles: []deploy.SystemdUnit{
{
Name: "test.service",
Contents: "testcontents",
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
serv := debugdServer{
log: logger.NewTest(t),
ssh: &tc.ssh,
serviceManager: &tc.serviceManager,
streamer: &fakeStreamer{},
}
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
require.NoError(err)
defer conn.Close()
client := pb.NewDebugdClient(conn)
resp, err := client.UploadSystemServiceUnits(context.Background(), tc.request)
grpcServ.GracefulStop()
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
require.NotNil(resp.Status)
assert.Equal(tc.wantResponseStatus, resp.Status)
assert.ElementsMatch(tc.wantUnitFiles, tc.serviceManager.unitFiles)
})
}
}
type stubSSHDeployer struct {
sshKeys []ssh.UserKey
deployErr error
}
func (s *stubSSHDeployer) DeployAuthorizedKey(ctx context.Context, sshKey ssh.UserKey) error {
s.sshKeys = append(s.sshKeys, sshKey)
return s.deployErr
}
func (s *stubSSHDeployer) GetAuthorizedKeys() []ssh.UserKey {
return s.sshKeys
}
type stubServiceManager struct {
requests []deploy.ServiceManagerRequest
unitFiles []deploy.SystemdUnit
systemdActionErr error
writeSystemdUnitFileErr error
}
func (s *stubServiceManager) SystemdAction(ctx context.Context, request deploy.ServiceManagerRequest) error {
s.requests = append(s.requests, request)
return s.systemdActionErr
}
func (s *stubServiceManager) WriteSystemdUnitFile(ctx context.Context, unit deploy.SystemdUnit) error {
s.unitFiles = append(s.unitFiles, unit)
return s.writeSystemdUnitFileErr
}
type netDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
func dial(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, target,
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", addr)
}),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
}
type fakeStreamer struct {
writeStreamChunks [][]byte
writeStreamFilename string
writeStreamErr error
readStreamChunks [][]byte
readStreamFilename string
readStreamErr error
}
func (f *fakeStreamer) WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error {
f.writeStreamFilename = filename
for {
chunk, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return f.writeStreamErr
}
return fmt.Errorf("reading stream: %w", err)
}
f.writeStreamChunks = append(f.writeStreamChunks, chunk.Content)
}
}
func (f *fakeStreamer) ReadStream(filename string, stream bootstrapper.WriteChunkStream, chunksize uint, showProgress bool) error {
f.readStreamFilename = filename
for _, chunk := range f.readStreamChunks {
if err := stream.Send(&pb.Chunk{Content: chunk}); err != nil {
panic(err)
}
}
return f.readStreamErr
}
func setupServerWithConn(endpoint string, serv *debugdServer) (*grpc.Server, *grpc.ClientConn, error) {
dialer := testdialer.NewBufconnDialer()
grpcServ := grpc.NewServer()
pb.RegisterDebugdServer(grpcServ, serv)
lis := dialer.GetListener(endpoint)
go grpcServ.Serve(lis)
conn, err := dial(context.Background(), dialer, endpoint)
if err != nil {
return nil, nil, err
}
return grpcServ, conn, nil
}
func fakeWrite(stream bootstrapper.WriteChunkStream, chunks [][]byte) error {
for _, chunk := range chunks {
err := stream.Send(&pb.Chunk{
Content: chunk,
})
if err != nil {
return err
}
}
return nil
}
func fakeRead(stream bootstrapper.ReadChunkStream) ([][]byte, error) {
var chunks [][]byte
for {
chunk, err := stream.Recv()
if err != nil {
if err == io.EOF {
return chunks, nil
}
return nil, err
}
chunks = append(chunks, chunk.Content)
}
}