debugd: implement upload of multiple binaries

This commit is contained in:
Malte Poll 2023-01-20 10:11:41 +01:00 committed by Malte Poll
parent e6ac8e2a91
commit 6f56ed69f8
21 changed files with 2040 additions and 661 deletions

View file

@ -8,13 +8,13 @@ package deploy
import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"github.com/edgelesssys/constellation/v2/debugd/internal/bootstrapper"
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer"
pb "github.com/edgelesssys/constellation/v2/debugd/service"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/logger"
@ -27,19 +27,19 @@ import (
type Download struct {
log *logger.Logger
dialer NetDialer
writer streamToFileWriter
transfer fileTransferer
serviceManager serviceManager
info infoSetter
}
// New creates a new Download.
func New(log *logger.Logger, dialer NetDialer, serviceManager serviceManager,
writer streamToFileWriter, info infoSetter,
transfer fileTransferer, info infoSetter,
) *Download {
return &Download{
log: log,
dialer: dialer,
writer: writer,
transfer: transfer,
info: info,
serviceManager: serviceManager,
}
@ -47,6 +47,10 @@ func New(log *logger.Logger, dialer NetDialer, serviceManager serviceManager,
// DownloadInfo will try to download the info from another instance.
func (d *Download) DownloadInfo(ctx context.Context, ip string) error {
if d.info.Received() {
return nil
}
log := d.log.With(zap.String("ip", ip))
serverAddr := net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort))
@ -66,7 +70,7 @@ func (d *Download) DownloadInfo(ctx context.Context, ip string) error {
return d.info.SetProto(resp.Info)
}
// DownloadDeployment will open a new grpc connection to another instance, attempting to download a bootstrapper from that instance.
// DownloadDeployment will open a new grpc connection to another instance, attempting to download files from that instance.
func (d *Download) DownloadDeployment(ctx context.Context, ip string) error {
log := d.log.With(zap.String("ip", ip))
serverAddr := net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort))
@ -77,23 +81,38 @@ func (d *Download) DownloadDeployment(ctx context.Context, ip string) error {
}
defer closer.Close()
log.Infof("Trying to download bootstrapper")
stream, err := client.DownloadBootstrapper(ctx, &pb.DownloadBootstrapperRequest{})
log.Infof("Trying to download files")
stream, err := client.DownloadFiles(ctx, &pb.DownloadFilesRequest{})
if err != nil {
return fmt.Errorf("starting bootstrapper download from other instance: %w", err)
return fmt.Errorf("starting file download from other instance: %w", err)
}
if err := d.writer.WriteStream(debugd.BootstrapperDeployFilename, stream, true); err != nil {
return fmt.Errorf("streaming bootstrapper from other instance: %w", err)
}
log.Infof("Successfully downloaded bootstrapper")
// after the upload succeeds, try to restart the bootstrapper
restartAction := ServiceManagerRequest{
Unit: debugd.BootstrapperSystemdUnitName,
Action: Restart,
err = d.transfer.RecvFiles(stream)
switch {
case err == nil:
d.log.Infof("Downloading files succeeded")
case errors.Is(err, filetransfer.ErrReceiveRunning):
d.log.Warnf("Download already in progress")
return err
case errors.Is(err, filetransfer.ErrReceiveFinished):
d.log.Warnf("Download already finished")
return nil
default:
d.log.With(zap.Error(err)).Errorf("Downloading files failed")
return err
}
if err := d.serviceManager.SystemdAction(ctx, restartAction); err != nil {
return fmt.Errorf("restarting bootstrapper: %w", err)
files := d.transfer.GetFiles()
for _, file := range files {
if file.OverrideServiceUnit == "" {
continue
}
if err := d.serviceManager.OverrideServiceUnitExecStart(
ctx, file.OverrideServiceUnit, file.TargetPath,
); err != nil {
// continue on error to allow other units to be overridden
d.log.With(zap.Error(err)).Errorf("Failed to override service unit %s", file.OverrideServiceUnit)
}
}
return nil
@ -123,14 +142,17 @@ func (d *Download) grpcWithDialer() grpc.DialOption {
// infoSetter is the destination for debugd info obtained from another instance.
type infoSetter interface {
// SetProto is handed the info entries from a peer's response by DownloadInfo;
// its error is returned to the caller of DownloadInfo unchanged.
SetProto(infos []*pb.Info) error
// Received is checked first by DownloadInfo, which returns nil without
// contacting the peer when it reports true.
Received() bool
}
// serviceManager is the subset of systemd operations that Download depends on.
type serviceManager interface {
// SystemdAction performs the requested action on a systemd unit.
SystemdAction(ctx context.Context, request ServiceManagerRequest) error
// OverrideServiceUnitExecStart overrides the ExecStart of the named unit.
// DownloadDeployment calls it once per received file that names a unit,
// passing the file's target path as execStart.
OverrideServiceUnitExecStart(ctx context.Context, unitName string, execStart string) error
}
type streamToFileWriter interface {
WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error
type fileTransferer interface {
RecvFiles(stream filetransfer.RecvFilesStream) error
GetFiles() []filetransfer.FileStat
}
// NetDialer can open a net.Conn.

View file

@ -9,14 +9,11 @@ package deploy
import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"testing"
"github.com/edgelesssys/constellation/v2/debugd/internal/bootstrapper"
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer"
pb "github.com/edgelesssys/constellation/v2/debugd/service"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
@ -33,41 +30,72 @@ func TestMain(m *testing.M) {
)
}
func TestDownloadBootstrapper(t *testing.T) {
filename := "/run/state/bin/bootstrapper"
someErr := errors.New("failed")
func TestDownloadDeployment(t *testing.T) {
testCases := map[string]struct {
server fakeDownloadServer
serviceManager stubServiceManager
wantChunks [][]byte
wantDownloadErr bool
wantFile bool
wantSystemdAction bool
wantDeployed bool
files []filetransfer.FileStat
recvFilesErr error
overrideServiceUnitErr error
wantErr bool
wantOverrideCalls []struct{ UnitName, ExecStart string }
}{
"download works": {
server: fakeDownloadServer{
chunks: [][]byte{[]byte("test")},
files: []filetransfer.FileStat{
{
SourcePath: "source/testfileA",
TargetPath: "target/testfileA",
Mode: 0o644,
OverrideServiceUnit: "unitA",
},
{
SourcePath: "source/testfileB",
TargetPath: "target/testfileB",
Mode: 0o644,
},
},
wantOverrideCalls: []struct{ UnitName, ExecStart string }{
{"unitA", "target/testfileA"},
},
wantChunks: [][]byte{[]byte("test")},
wantDownloadErr: false,
wantFile: true,
wantSystemdAction: true,
wantDeployed: true,
},
"download rpc call error is detected": {
server: fakeDownloadServer{downladErr: someErr},
wantDownloadErr: true,
"recv files error is detected": {
recvFilesErr: errors.New("some error"),
wantErr: true,
},
"service restart error is detected": {
server: fakeDownloadServer{chunks: [][]byte{[]byte("test")}},
serviceManager: stubServiceManager{systemdActionErr: someErr},
wantChunks: [][]byte{[]byte("test")},
wantDownloadErr: true,
wantFile: true,
wantDeployed: true,
wantSystemdAction: false,
"recv already running": {
recvFilesErr: filetransfer.ErrReceiveRunning,
wantErr: true,
},
"recv already finished": {
files: []filetransfer.FileStat{
{
SourcePath: "source/testfileA",
TargetPath: "target/testfileA",
Mode: 0o644,
OverrideServiceUnit: "unitA",
},
},
recvFilesErr: filetransfer.ErrReceiveFinished,
wantErr: false,
},
"service unit fail does not stop further tries": {
files: []filetransfer.FileStat{
{
SourcePath: "source/testfileA",
TargetPath: "target/testfileA",
Mode: 0o644,
OverrideServiceUnit: "unitA",
},
{
SourcePath: "source/testfileB",
TargetPath: "target/testfileB",
Mode: 0o644,
OverrideServiceUnit: "unitB",
},
},
overrideServiceUnitErr: errors.New("some error"),
wantOverrideCalls: []struct{ UnitName, ExecStart string }{
{"unitA", "target/testfileA"},
{"unitB", "target/testfileB"},
},
},
}
@ -76,11 +104,13 @@ func TestDownloadBootstrapper(t *testing.T) {
assert := assert.New(t)
ip := "192.0.2.0"
writer := &fakeStreamToFileWriter{}
transfer := &stubTransfer{recvFilesErr: tc.recvFilesErr, files: tc.files}
serviceMgr := &stubServiceManager{overrideServiceUnitExecStartErr: tc.overrideServiceUnitErr}
dialer := testdialer.NewBufconnDialer()
server := &stubDownloadServer{}
grpcServ := grpc.NewServer()
pb.RegisterDebugdServer(grpcServ, &tc.server)
pb.RegisterDebugdServer(grpcServ, server)
lis := dialer.GetListener(net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort)))
go grpcServ.Serve(lis)
defer grpcServ.GracefulStop()
@ -88,30 +118,19 @@ func TestDownloadBootstrapper(t *testing.T) {
download := &Download{
log: logger.NewTest(t),
dialer: dialer,
writer: writer,
serviceManager: &tc.serviceManager,
transfer: transfer,
serviceManager: serviceMgr,
}
err := download.DownloadDeployment(context.Background(), ip)
if tc.wantDownloadErr {
if tc.wantErr {
assert.Error(err)
} else {
assert.NoError(err)
}
if tc.wantFile {
assert.Equal(tc.wantChunks, writer.chunks)
assert.Equal(filename, writer.filename)
}
if tc.wantSystemdAction {
assert.ElementsMatch(
[]ServiceManagerRequest{
{Unit: debugd.BootstrapperSystemdUnitName, Action: Restart},
},
tc.serviceManager.requests,
)
}
assert.Equal(tc.wantOverrideCalls, serviceMgr.overrideCalls)
})
}
}
@ -189,6 +208,9 @@ func TestDownloadInfo(t *testing.T) {
// stubServiceManager is a test double for the serviceManager interface.
type stubServiceManager struct {
// requests presumably collects the requests passed to SystemdAction — TODO confirm,
// the recording statement is not visible in this view.
requests []ServiceManagerRequest
// systemdActionErr is returned by SystemdAction.
systemdActionErr error
// overrideCalls collects the arguments of every OverrideServiceUnitExecStart call, in call order.
overrideCalls []struct{ UnitName, ExecStart string }
// overrideServiceUnitExecStartErr is returned by OverrideServiceUnitExecStart.
overrideServiceUnitExecStartErr error
}
func (s *stubServiceManager) SystemdAction(ctx context.Context, request ServiceManagerRequest) error {
@ -196,39 +218,34 @@ func (s *stubServiceManager) SystemdAction(ctx context.Context, request ServiceM
return s.systemdActionErr
}
type fakeStreamToFileWriter struct {
chunks [][]byte
filename string
// OverrideServiceUnitExecStart appends the received unit name and exec start
// to overrideCalls and returns the error configured on the stub.
func (s *stubServiceManager) OverrideServiceUnitExecStart(ctx context.Context, unitName string, execStart string) error {
	call := struct{ UnitName, ExecStart string }{unitName, execStart}
	s.overrideCalls = append(s.overrideCalls, call)
	return s.overrideServiceUnitExecStartErr
}
func (f *fakeStreamToFileWriter) WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error {
f.filename = filename
for {
chunk, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return nil
}
return fmt.Errorf("reading stream: %w", err)
}
f.chunks = append(f.chunks, chunk.Content)
}
// stubTransfer is a test double for the fileTransferer interface.
type stubTransfer struct {
// recvFilesErr is returned by RecvFiles.
recvFilesErr error
// files is returned by GetFiles.
files []filetransfer.FileStat
}
// fakeDownloadServer implements DebugdServer; only fakes DownloadBootstrapper, panics on every other rpc.
type fakeDownloadServer struct {
chunks [][]byte
// RecvFiles ignores the stream and returns the configured recvFilesErr.
func (t *stubTransfer) RecvFiles(_ filetransfer.RecvFilesStream) error {
return t.recvFilesErr
}
// GetFiles returns the configured list of file stats.
func (t *stubTransfer) GetFiles() []filetransfer.FileStat {
return t.files
}
// stubDownloadServer implements DebugdServer for tests. Only DownloadFiles is stubbed;
// every other rpc is served by the embedded pb.UnimplementedDebugdServer.
// NOTE(review): an earlier version of this comment said other rpcs panic — confirm
// against the generated pb.UnimplementedDebugdServer.
type stubDownloadServer struct {
// downladErr (sic) is returned by DownloadFiles.
downladErr error
pb.UnimplementedDebugdServer
}
func (s *fakeDownloadServer) DownloadBootstrapper(request *pb.DownloadBootstrapperRequest, stream pb.Debugd_DownloadBootstrapperServer) error {
for _, chunk := range s.chunks {
if err := stream.Send(&pb.Chunk{Content: chunk}); err != nil {
return fmt.Errorf("sending chunk: %w", err)
}
}
// DownloadFiles sends nothing on the stream and returns the configured downladErr.
func (s *stubDownloadServer) DownloadFiles(request *pb.DownloadFilesRequest, stream pb.Debugd_DownloadFilesServer) error {
return s.downladErr
}
@ -244,6 +261,7 @@ func (s *stubDebugdServer) GetInfo(ctx context.Context, request *pb.GetInfoReque
// stubInfoSetter is a test double for the infoSetter interface.
type stubInfoSetter struct {
// info holds the entries most recently passed to SetProto.
info []*pb.Info
// received is returned by Received.
received bool
// setProtoErr is returned by SetProto.
setProtoErr error
}
@ -251,3 +269,7 @@ func (s *stubInfoSetter) SetProto(infos []*pb.Info) error {
s.info = infos
return s.setProtoErr
}
// Received returns the configured received flag.
func (s *stubInfoSetter) Received() bool {
return s.received
}

View file

@ -9,9 +9,12 @@ package deploy
import (
"context"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/spf13/afero"
"go.uber.org/zap"
@ -21,6 +24,10 @@ const (
systemdUnitFolder = "/run/systemd/system"
)
// systemdUnitNameRegexp is a regular expression that matches valid systemd unit names.
// This is only the unit name, without the .service suffix.
// It accepts one or more ASCII letters, digits, '@', '.', '_', '-' and '\'; anything else,
// including whitespace and the empty string, is rejected.
// NOTE(review): systemd.unit(5) also lists ':' as a valid unit name character —
// confirm that leaving it out is intentional.
var systemdUnitNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9@._\-\\]+$`)
// SystemdAction encodes the available actions.
//
//go:generate stringer -type=SystemdAction
@ -73,7 +80,7 @@ func NewServiceManager(log *logger.Logger) *ServiceManager {
type dbusClient interface {
// NewSystemConnectionContext establishes a connection to the system bus and authenticates.
// Callers should call Close() when done with the connection.
NewSystemdConnectionContext(ctx context.Context) (dbusConn, error)
NewSystemConnectionContext(ctx context.Context) (dbusConn, error)
}
type dbusConn interface {
@ -89,15 +96,18 @@ type dbusConn interface {
// ReloadContext instructs systemd to scan for and reload unit files. This is
// an equivalent to systemctl daemon-reload.
ReloadContext(ctx context.Context) error
// Close closes the connection.
Close()
}
// SystemdAction will perform a systemd action on a service unit (start, stop, restart, reload).
func (s *ServiceManager) SystemdAction(ctx context.Context, request ServiceManagerRequest) error {
log := s.log.With(zap.String("unit", request.Unit), zap.String("action", request.Action.String()))
conn, err := s.dbus.NewSystemdConnectionContext(ctx)
conn, err := s.dbus.NewSystemConnectionContext(ctx)
if err != nil {
return fmt.Errorf("establishing systemd connection: %w", err)
}
defer conn.Close()
resultChan := make(chan string, 1)
switch request.Action {
@ -149,28 +159,41 @@ func (s *ServiceManager) WriteSystemdUnitFile(ctx context.Context, unit SystemdU
}
log.Infof("Wrote systemd unit file and performed daemon-reload")
return nil
}
// DefaultServiceUnit will write the default "bootstrapper.service" unit file.
func DefaultServiceUnit(ctx context.Context, serviceManager *ServiceManager) error {
if err := serviceManager.WriteSystemdUnitFile(ctx, SystemdUnit{
Name: debugd.BootstrapperSystemdUnitName,
Contents: debugd.BootstrapperSystemdUnitContents,
}); err != nil {
return fmt.Errorf("writing systemd unit file %q: %w", debugd.BootstrapperSystemdUnitName, err)
}
// try to start the default service if the binary exists but ignore failure.
// this is meant to start the bootstrapper after a reboot
// if a bootstrapper binary was uploaded before.
if ok, err := afero.Exists(serviceManager.fs, debugd.BootstrapperDeployFilename); ok && err == nil {
_ = serviceManager.SystemdAction(ctx, ServiceManagerRequest{
Unit: debugd.BootstrapperSystemdUnitName,
Action: Start,
})
// OverrideServiceUnitExecStart overrides the ExecStart of a systemd service unit.
//
// It writes a drop-in file "<systemdUnitFolder>/<unitName>.service.d/override.conf"
// that first clears ExecStart and then sets it to execStart, asks systemd to
// reload its unit files and finally restarts "<unitName>.service".
//
// unitName is the unit name without the ".service" suffix and must match
// systemdUnitNameRegexp. execStart is written verbatim into the drop-in; it must
// be non-empty and must not contain CR or LF, since either would allow injecting
// additional directives into the unit file.
//
// An error is returned when validation fails, when the drop-in cannot be
// written, or when the restart fails. A failing daemon-reload is only logged.
func (s *ServiceManager) OverrideServiceUnitExecStart(ctx context.Context, unitName, execStart string) error {
	// Log the file that is actually written, not the (non-existent) path of the unit itself.
	overridePath := filepath.Join(systemdUnitFolder, unitName+".service.d", "override.conf")
	log := s.log.With(zap.String("unitFile", overridePath))
	log.Infof("Overriding systemd unit file execStart")

	if !systemdUnitNameRegexp.MatchString(unitName) {
		return fmt.Errorf("unit name %q is invalid", unitName)
	}
	// An empty value would only reset ExecStart and leave the unit without a command to run.
	if execStart == "" {
		return fmt.Errorf("execStart must not be empty")
	}
	// validate execStart (no newlines)
	if strings.ContainsAny(execStart, "\r\n") {
		return fmt.Errorf("execStart must not contain newlines")
	}

	// The empty "ExecStart=" line resets the command list inherited from the
	// original unit before the replacement is added.
	overrideUnitContents := fmt.Sprintf("[Service]\nExecStart=\nExecStart=%s\n", execStart)

	s.systemdUnitFilewriteLock.Lock()
	defer s.systemdUnitFilewriteLock.Unlock()

	overrideDir := filepath.Dir(overridePath)
	if err := s.fs.MkdirAll(overrideDir, os.ModePerm); err != nil {
		return fmt.Errorf("creating systemd unit file override directory %q: %w", overrideDir, err)
	}
	if err := afero.WriteFile(s.fs, overridePath, []byte(overrideUnitContents), 0o644); err != nil {
		return fmt.Errorf("writing systemd unit override file %q: %w", overridePath, err)
	}

	if err := s.SystemdAction(ctx, ServiceManagerRequest{Unit: unitName, Action: Reload}); err != nil {
		// do not return early here
		// the "daemon-reload" command may return an unrelated error
		// and there is no way to know if the override was successful
		log.Warnf("Failed to perform systemd daemon-reload: %v", err)
	}
	if err := s.SystemdAction(ctx, ServiceManagerRequest{Unit: unitName + ".service", Action: Restart}); err != nil {
		log.Warnf("Failed to perform unit restart: %v", err)
		return fmt.Errorf("performing systemd unit restart: %w", err)
	}

	log.Infof("Overrode systemd unit file execStart, performed daemon-reload and restarted unit %v", unitName)
	return nil
}

View file

@ -158,7 +158,7 @@ func TestWriteSystemdUnitFile(t *testing.T) {
"systemd reload fails": {
dbus: stubDbus{
conn: &fakeDbusConn{
actionErr: errors.New("reload error"),
reloadErr: errors.New("reload error"),
},
},
unit: SystemdUnit{
@ -200,12 +200,127 @@ func TestWriteSystemdUnitFile(t *testing.T) {
}
}
// TestOverrideServiceUnitExecStart exercises ServiceManager.OverrideServiceUnitExecStart
// against an in-memory filesystem and a fake dbus connection.
//
// Cases with wantErr only check that an error is returned. All other cases also
// check the contents of the written drop-in file, the unit actions sent over
// dbus and the number of daemon-reloads.
func TestOverrideServiceUnitExecStart(t *testing.T) {
	testCases := map[string]struct {
		dbus                stubDbus
		unitName, execStart string
		readonly            bool
		wantErr             bool
		wantFileContents    string
		wantActionCalls     []dbusConnActionInput
		wantReloads         int
	}{
		"override works": {
			dbus: stubDbus{
				conn: &fakeDbusConn{
					result: "done",
				},
			},
			unitName:         "test",
			execStart:        "/run/state/bin/test",
			wantFileContents: "[Service]\nExecStart=\nExecStart=/run/state/bin/test\n",
			wantActionCalls: []dbusConnActionInput{
				{name: "test.service", mode: "replace"},
			},
			wantReloads: 1,
		},
		"unit name invalid": {
			dbus: stubDbus{
				conn: &fakeDbusConn{
					result: "done",
				},
			},
			unitName:  "invalid name",
			execStart: "/run/state/bin/test",
			wantErr:   true,
		},
		"exec start invalid": {
			dbus: stubDbus{
				conn: &fakeDbusConn{
					result: "done",
				},
			},
			unitName:  "test",
			execStart: "/run/state/bin/\r\ntest",
			wantErr:   true,
		},
		"write fails": {
			dbus: stubDbus{
				conn: &fakeDbusConn{
					result: "done",
				},
			},
			unitName:  "test",
			execStart: "/run/state/bin/test",
			readonly:  true,
			wantErr:   true,
		},
		"reload fails but restart is still attempted": {
			dbus: stubDbus{
				conn: &fakeDbusConn{
					result:    "done",
					reloadErr: errors.New("reload error"),
				},
			},
			unitName:         "test",
			execStart:        "/run/state/bin/test",
			wantFileContents: "[Service]\nExecStart=\nExecStart=/run/state/bin/test\n",
			wantActionCalls: []dbusConnActionInput{
				{name: "test.service", mode: "replace"},
			},
			wantReloads: 1,
		},
		"restart fails": {
			dbus: stubDbus{
				conn: &fakeDbusConn{
					result:    "done",
					actionErr: errors.New("action error"),
				},
			},
			unitName:  "test",
			execStart: "/run/state/bin/test",
			wantErr:   true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			// Fixture setup must succeed, otherwise the assertions below are meaningless.
			fs := afero.NewMemMapFs()
			require.NoError(fs.MkdirAll(systemdUnitFolder, 0o755))
			if tc.readonly {
				// A read-only filesystem makes creating the drop-in fail.
				fs = afero.NewReadOnlyFs(fs)
			}
			manager := ServiceManager{
				log:                      logger.NewTest(t),
				dbus:                     &tc.dbus,
				fs:                       fs,
				systemdUnitFilewriteLock: sync.Mutex{},
			}

			err := manager.OverrideServiceUnitExecStart(context.Background(), tc.unitName, tc.execStart)
			if tc.wantErr {
				assert.Error(err)
				return
			}
			require.NoError(err)

			// Derive the drop-in location from the test case instead of hard-coding unit "test".
			overridePath := systemdUnitFolder + "/" + tc.unitName + ".service.d/override.conf"
			fileContents, err := afero.ReadFile(fs, overridePath)
			require.NoError(err)
			assert.Equal(tc.wantFileContents, string(fileContents))

			conn, ok := tc.dbus.conn.(*fakeDbusConn)
			require.True(ok, "test case must use a *fakeDbusConn")
			assert.Equal(tc.wantActionCalls, conn.inputs)
			assert.Equal(tc.wantReloads, conn.reloadCalls)
		})
	}
}
// stubDbus is a test double for the dbusClient interface.
type stubDbus struct {
// conn is the connection handed out by NewSystemConnectionContext.
conn dbusConn
// connErr is the error returned alongside conn.
connErr error
}
func (s *stubDbus) NewSystemdConnectionContext(ctx context.Context) (dbusConn, error) {
func (s *stubDbus) NewSystemConnectionContext(ctx context.Context) (dbusConn, error) {
return s.conn, s.connErr
}
@ -215,11 +330,13 @@ type dbusConnActionInput struct {
}
type fakeDbusConn struct {
inputs []dbusConnActionInput
result string
inputs []dbusConnActionInput
result string
reloadCalls int
jobID int
actionErr error
reloadErr error
}
func (c *fakeDbusConn) StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
@ -244,5 +361,9 @@ func (c *fakeDbusConn) RestartUnitContext(ctx context.Context, name string, mode
}
func (c *fakeDbusConn) ReloadContext(ctx context.Context) error {
return c.actionErr
c.reloadCalls++
return c.reloadErr
}
// Close does nothing; it exists so that fakeDbusConn provides the Close method of dbusConn.
func (c *fakeDbusConn) Close() {}

View file

@ -15,8 +15,8 @@ import (
// dbusWrapper wraps go-systemd dbus.
// It carries no state; its NewSystemConnectionContext method opens the connection
// through the dbus package.
type dbusWrapper struct{}
func (d *dbusWrapper) NewSystemdConnectionContext(ctx context.Context) (dbusConn, error) {
conn, err := dbus.NewSystemdConnectionContext(ctx)
func (d *dbusWrapper) NewSystemConnectionContext(ctx context.Context) (dbusConn, error) {
conn, err := dbus.NewSystemConnectionContext(ctx)
if err != nil {
return nil, err
}
@ -44,3 +44,7 @@ func (c *dbusConnWrapper) RestartUnitContext(ctx context.Context, name string, m
// ReloadContext forwards to ReloadContext of the wrapped connection and returns its error.
func (c *dbusConnWrapper) ReloadContext(ctx context.Context) error {
return c.conn.ReloadContext(ctx)
}
// Close closes the wrapped connection.
func (c *dbusConnWrapper) Close() {
c.conn.Close()
}