mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-08-02 03:56:07 -04:00
debugd: implement upload of multiple binaries
This commit is contained in:
parent
e6ac8e2a91
commit
6f56ed69f8
21 changed files with 2040 additions and 661 deletions
|
@ -14,7 +14,9 @@ const (
|
|||
GRPCTimeout = 5 * time.Minute
|
||||
DiscoverDebugdInterval = 30 * time.Second
|
||||
DownloadRetryBackoff = 1 * time.Minute
|
||||
BinaryAccessMode = 0o755 // -rwxr-xr-x
|
||||
BootstrapperDeployFilename = "/run/state/bin/bootstrapper"
|
||||
UpgradeAgentDeployFilename = "/run/state/bin/upgrade-agent"
|
||||
Chunksize = 1024
|
||||
BootstrapperSystemdUnitName = "bootstrapper.service"
|
||||
BootstrapperSystemdUnitContents = `[Unit]
|
||||
|
|
|
@ -8,13 +8,13 @@ package deploy
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/bootstrapper"
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer"
|
||||
pb "github.com/edgelesssys/constellation/v2/debugd/service"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
|
@ -27,19 +27,19 @@ import (
|
|||
type Download struct {
|
||||
log *logger.Logger
|
||||
dialer NetDialer
|
||||
writer streamToFileWriter
|
||||
transfer fileTransferer
|
||||
serviceManager serviceManager
|
||||
info infoSetter
|
||||
}
|
||||
|
||||
// New creates a new Download.
|
||||
func New(log *logger.Logger, dialer NetDialer, serviceManager serviceManager,
|
||||
writer streamToFileWriter, info infoSetter,
|
||||
transfer fileTransferer, info infoSetter,
|
||||
) *Download {
|
||||
return &Download{
|
||||
log: log,
|
||||
dialer: dialer,
|
||||
writer: writer,
|
||||
transfer: transfer,
|
||||
info: info,
|
||||
serviceManager: serviceManager,
|
||||
}
|
||||
|
@ -47,6 +47,10 @@ func New(log *logger.Logger, dialer NetDialer, serviceManager serviceManager,
|
|||
|
||||
// DownloadInfo will try to download the info from another instance.
|
||||
func (d *Download) DownloadInfo(ctx context.Context, ip string) error {
|
||||
if d.info.Received() {
|
||||
return nil
|
||||
}
|
||||
|
||||
log := d.log.With(zap.String("ip", ip))
|
||||
serverAddr := net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort))
|
||||
|
||||
|
@ -66,7 +70,7 @@ func (d *Download) DownloadInfo(ctx context.Context, ip string) error {
|
|||
return d.info.SetProto(resp.Info)
|
||||
}
|
||||
|
||||
// DownloadDeployment will open a new grpc connection to another instance, attempting to download a bootstrapper from that instance.
|
||||
// DownloadDeployment will open a new grpc connection to another instance, attempting to download files from that instance.
|
||||
func (d *Download) DownloadDeployment(ctx context.Context, ip string) error {
|
||||
log := d.log.With(zap.String("ip", ip))
|
||||
serverAddr := net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort))
|
||||
|
@ -77,23 +81,38 @@ func (d *Download) DownloadDeployment(ctx context.Context, ip string) error {
|
|||
}
|
||||
defer closer.Close()
|
||||
|
||||
log.Infof("Trying to download bootstrapper")
|
||||
stream, err := client.DownloadBootstrapper(ctx, &pb.DownloadBootstrapperRequest{})
|
||||
log.Infof("Trying to download files")
|
||||
stream, err := client.DownloadFiles(ctx, &pb.DownloadFilesRequest{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting bootstrapper download from other instance: %w", err)
|
||||
return fmt.Errorf("starting file download from other instance: %w", err)
|
||||
}
|
||||
if err := d.writer.WriteStream(debugd.BootstrapperDeployFilename, stream, true); err != nil {
|
||||
return fmt.Errorf("streaming bootstrapper from other instance: %w", err)
|
||||
}
|
||||
log.Infof("Successfully downloaded bootstrapper")
|
||||
|
||||
// after the upload succeeds, try to restart the bootstrapper
|
||||
restartAction := ServiceManagerRequest{
|
||||
Unit: debugd.BootstrapperSystemdUnitName,
|
||||
Action: Restart,
|
||||
err = d.transfer.RecvFiles(stream)
|
||||
switch {
|
||||
case err == nil:
|
||||
d.log.Infof("Downloading files succeeded")
|
||||
case errors.Is(err, filetransfer.ErrReceiveRunning):
|
||||
d.log.Warnf("Download already in progress")
|
||||
return err
|
||||
case errors.Is(err, filetransfer.ErrReceiveFinished):
|
||||
d.log.Warnf("Download already finished")
|
||||
return nil
|
||||
default:
|
||||
d.log.With(zap.Error(err)).Errorf("Downloading files failed")
|
||||
return err
|
||||
}
|
||||
if err := d.serviceManager.SystemdAction(ctx, restartAction); err != nil {
|
||||
return fmt.Errorf("restarting bootstrapper: %w", err)
|
||||
|
||||
files := d.transfer.GetFiles()
|
||||
for _, file := range files {
|
||||
if file.OverrideServiceUnit == "" {
|
||||
continue
|
||||
}
|
||||
if err := d.serviceManager.OverrideServiceUnitExecStart(
|
||||
ctx, file.OverrideServiceUnit, file.TargetPath,
|
||||
); err != nil {
|
||||
// continue on error to allow other units to be overridden
|
||||
d.log.With(zap.Error(err)).Errorf("Failed to override service unit %s", file.OverrideServiceUnit)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -123,14 +142,17 @@ func (d *Download) grpcWithDialer() grpc.DialOption {
|
|||
|
||||
type infoSetter interface {
|
||||
SetProto(infos []*pb.Info) error
|
||||
Received() bool
|
||||
}
|
||||
|
||||
type serviceManager interface {
|
||||
SystemdAction(ctx context.Context, request ServiceManagerRequest) error
|
||||
OverrideServiceUnitExecStart(ctx context.Context, unitName string, execStart string) error
|
||||
}
|
||||
|
||||
type streamToFileWriter interface {
|
||||
WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error
|
||||
type fileTransferer interface {
|
||||
RecvFiles(stream filetransfer.RecvFilesStream) error
|
||||
GetFiles() []filetransfer.FileStat
|
||||
}
|
||||
|
||||
// NetDialer can open a net.Conn.
|
||||
|
|
|
@ -9,14 +9,11 @@ package deploy
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/bootstrapper"
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer"
|
||||
pb "github.com/edgelesssys/constellation/v2/debugd/service"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
|
||||
|
@ -33,41 +30,72 @@ func TestMain(m *testing.M) {
|
|||
)
|
||||
}
|
||||
|
||||
func TestDownloadBootstrapper(t *testing.T) {
|
||||
filename := "/run/state/bin/bootstrapper"
|
||||
someErr := errors.New("failed")
|
||||
|
||||
func TestDownloadDeployment(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
server fakeDownloadServer
|
||||
serviceManager stubServiceManager
|
||||
wantChunks [][]byte
|
||||
wantDownloadErr bool
|
||||
wantFile bool
|
||||
wantSystemdAction bool
|
||||
wantDeployed bool
|
||||
files []filetransfer.FileStat
|
||||
recvFilesErr error
|
||||
overrideServiceUnitErr error
|
||||
wantErr bool
|
||||
wantOverrideCalls []struct{ UnitName, ExecStart string }
|
||||
}{
|
||||
"download works": {
|
||||
server: fakeDownloadServer{
|
||||
chunks: [][]byte{[]byte("test")},
|
||||
files: []filetransfer.FileStat{
|
||||
{
|
||||
SourcePath: "source/testfileA",
|
||||
TargetPath: "target/testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: "unitA",
|
||||
},
|
||||
{
|
||||
SourcePath: "source/testfileB",
|
||||
TargetPath: "target/testfileB",
|
||||
Mode: 0o644,
|
||||
},
|
||||
},
|
||||
wantOverrideCalls: []struct{ UnitName, ExecStart string }{
|
||||
{"unitA", "target/testfileA"},
|
||||
},
|
||||
wantChunks: [][]byte{[]byte("test")},
|
||||
wantDownloadErr: false,
|
||||
wantFile: true,
|
||||
wantSystemdAction: true,
|
||||
wantDeployed: true,
|
||||
},
|
||||
"download rpc call error is detected": {
|
||||
server: fakeDownloadServer{downladErr: someErr},
|
||||
wantDownloadErr: true,
|
||||
"recv files error is detected": {
|
||||
recvFilesErr: errors.New("some error"),
|
||||
wantErr: true,
|
||||
},
|
||||
"service restart error is detected": {
|
||||
server: fakeDownloadServer{chunks: [][]byte{[]byte("test")}},
|
||||
serviceManager: stubServiceManager{systemdActionErr: someErr},
|
||||
wantChunks: [][]byte{[]byte("test")},
|
||||
wantDownloadErr: true,
|
||||
wantFile: true,
|
||||
wantDeployed: true,
|
||||
wantSystemdAction: false,
|
||||
"recv already running": {
|
||||
recvFilesErr: filetransfer.ErrReceiveRunning,
|
||||
wantErr: true,
|
||||
},
|
||||
"recv already finished": {
|
||||
files: []filetransfer.FileStat{
|
||||
{
|
||||
SourcePath: "source/testfileA",
|
||||
TargetPath: "target/testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: "unitA",
|
||||
},
|
||||
},
|
||||
recvFilesErr: filetransfer.ErrReceiveFinished,
|
||||
wantErr: false,
|
||||
},
|
||||
"service unit fail does not stop further tries": {
|
||||
files: []filetransfer.FileStat{
|
||||
{
|
||||
SourcePath: "source/testfileA",
|
||||
TargetPath: "target/testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: "unitA",
|
||||
},
|
||||
{
|
||||
SourcePath: "source/testfileB",
|
||||
TargetPath: "target/testfileB",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: "unitB",
|
||||
},
|
||||
},
|
||||
overrideServiceUnitErr: errors.New("some error"),
|
||||
wantOverrideCalls: []struct{ UnitName, ExecStart string }{
|
||||
{"unitA", "target/testfileA"},
|
||||
{"unitB", "target/testfileB"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -76,11 +104,13 @@ func TestDownloadBootstrapper(t *testing.T) {
|
|||
assert := assert.New(t)
|
||||
|
||||
ip := "192.0.2.0"
|
||||
writer := &fakeStreamToFileWriter{}
|
||||
transfer := &stubTransfer{recvFilesErr: tc.recvFilesErr, files: tc.files}
|
||||
serviceMgr := &stubServiceManager{overrideServiceUnitExecStartErr: tc.overrideServiceUnitErr}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
|
||||
server := &stubDownloadServer{}
|
||||
grpcServ := grpc.NewServer()
|
||||
pb.RegisterDebugdServer(grpcServ, &tc.server)
|
||||
pb.RegisterDebugdServer(grpcServ, server)
|
||||
lis := dialer.GetListener(net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort)))
|
||||
go grpcServ.Serve(lis)
|
||||
defer grpcServ.GracefulStop()
|
||||
|
@ -88,30 +118,19 @@ func TestDownloadBootstrapper(t *testing.T) {
|
|||
download := &Download{
|
||||
log: logger.NewTest(t),
|
||||
dialer: dialer,
|
||||
writer: writer,
|
||||
serviceManager: &tc.serviceManager,
|
||||
transfer: transfer,
|
||||
serviceManager: serviceMgr,
|
||||
}
|
||||
|
||||
err := download.DownloadDeployment(context.Background(), ip)
|
||||
|
||||
if tc.wantDownloadErr {
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
if tc.wantFile {
|
||||
assert.Equal(tc.wantChunks, writer.chunks)
|
||||
assert.Equal(filename, writer.filename)
|
||||
}
|
||||
if tc.wantSystemdAction {
|
||||
assert.ElementsMatch(
|
||||
[]ServiceManagerRequest{
|
||||
{Unit: debugd.BootstrapperSystemdUnitName, Action: Restart},
|
||||
},
|
||||
tc.serviceManager.requests,
|
||||
)
|
||||
}
|
||||
assert.Equal(tc.wantOverrideCalls, serviceMgr.overrideCalls)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -189,6 +208,9 @@ func TestDownloadInfo(t *testing.T) {
|
|||
type stubServiceManager struct {
|
||||
requests []ServiceManagerRequest
|
||||
systemdActionErr error
|
||||
|
||||
overrideCalls []struct{ UnitName, ExecStart string }
|
||||
overrideServiceUnitExecStartErr error
|
||||
}
|
||||
|
||||
func (s *stubServiceManager) SystemdAction(ctx context.Context, request ServiceManagerRequest) error {
|
||||
|
@ -196,39 +218,34 @@ func (s *stubServiceManager) SystemdAction(ctx context.Context, request ServiceM
|
|||
return s.systemdActionErr
|
||||
}
|
||||
|
||||
type fakeStreamToFileWriter struct {
|
||||
chunks [][]byte
|
||||
filename string
|
||||
func (s *stubServiceManager) OverrideServiceUnitExecStart(ctx context.Context, unitName string, execStart string) error {
|
||||
s.overrideCalls = append(s.overrideCalls, struct {
|
||||
UnitName, ExecStart string
|
||||
}{UnitName: unitName, ExecStart: execStart})
|
||||
return s.overrideServiceUnitExecStartErr
|
||||
}
|
||||
|
||||
func (f *fakeStreamToFileWriter) WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error {
|
||||
f.filename = filename
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("reading stream: %w", err)
|
||||
}
|
||||
f.chunks = append(f.chunks, chunk.Content)
|
||||
}
|
||||
type stubTransfer struct {
|
||||
recvFilesErr error
|
||||
files []filetransfer.FileStat
|
||||
}
|
||||
|
||||
// fakeDownloadServer implements DebugdServer; only fakes DownloadBootstrapper, panics on every other rpc.
|
||||
type fakeDownloadServer struct {
|
||||
chunks [][]byte
|
||||
func (t *stubTransfer) RecvFiles(_ filetransfer.RecvFilesStream) error {
|
||||
return t.recvFilesErr
|
||||
}
|
||||
|
||||
func (t *stubTransfer) GetFiles() []filetransfer.FileStat {
|
||||
return t.files
|
||||
}
|
||||
|
||||
// stubDownloadServer implements DebugdServer; only stubs DownloadFiles, panics on every other rpc.
|
||||
type stubDownloadServer struct {
|
||||
downladErr error
|
||||
|
||||
pb.UnimplementedDebugdServer
|
||||
}
|
||||
|
||||
func (s *fakeDownloadServer) DownloadBootstrapper(request *pb.DownloadBootstrapperRequest, stream pb.Debugd_DownloadBootstrapperServer) error {
|
||||
for _, chunk := range s.chunks {
|
||||
if err := stream.Send(&pb.Chunk{Content: chunk}); err != nil {
|
||||
return fmt.Errorf("sending chunk: %w", err)
|
||||
}
|
||||
}
|
||||
func (s *stubDownloadServer) DownloadFiles(request *pb.DownloadFilesRequest, stream pb.Debugd_DownloadFilesServer) error {
|
||||
return s.downladErr
|
||||
}
|
||||
|
||||
|
@ -244,6 +261,7 @@ func (s *stubDebugdServer) GetInfo(ctx context.Context, request *pb.GetInfoReque
|
|||
|
||||
type stubInfoSetter struct {
|
||||
info []*pb.Info
|
||||
received bool
|
||||
setProtoErr error
|
||||
}
|
||||
|
||||
|
@ -251,3 +269,7 @@ func (s *stubInfoSetter) SetProto(infos []*pb.Info) error {
|
|||
s.info = infos
|
||||
return s.setProtoErr
|
||||
}
|
||||
|
||||
func (s *stubInfoSetter) Received() bool {
|
||||
return s.received
|
||||
}
|
||||
|
|
|
@ -9,9 +9,12 @@ package deploy
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/spf13/afero"
|
||||
"go.uber.org/zap"
|
||||
|
@ -21,6 +24,10 @@ const (
|
|||
systemdUnitFolder = "/run/systemd/system"
|
||||
)
|
||||
|
||||
// systemdUnitNameRegexp is a regular expression that matches valid systemd unit names.
|
||||
// This is only the unit name, without the .service suffix.
|
||||
var systemdUnitNameRegexp = regexp.MustCompile(`^[a-zA-Z0-9@._\-\\]+$`)
|
||||
|
||||
// SystemdAction encodes the available actions.
|
||||
//
|
||||
//go:generate stringer -type=SystemdAction
|
||||
|
@ -73,7 +80,7 @@ func NewServiceManager(log *logger.Logger) *ServiceManager {
|
|||
type dbusClient interface {
|
||||
// NewSystemConnectionContext establishes a connection to the system bus and authenticates.
|
||||
// Callers should call Close() when done with the connection.
|
||||
NewSystemdConnectionContext(ctx context.Context) (dbusConn, error)
|
||||
NewSystemConnectionContext(ctx context.Context) (dbusConn, error)
|
||||
}
|
||||
|
||||
type dbusConn interface {
|
||||
|
@ -89,15 +96,18 @@ type dbusConn interface {
|
|||
// ReloadContext instructs systemd to scan for and reload unit files. This is
|
||||
// an equivalent to systemctl daemon-reload.
|
||||
ReloadContext(ctx context.Context) error
|
||||
// Close closes the connection.
|
||||
Close()
|
||||
}
|
||||
|
||||
// SystemdAction will perform a systemd action on a service unit (start, stop, restart, reload).
|
||||
func (s *ServiceManager) SystemdAction(ctx context.Context, request ServiceManagerRequest) error {
|
||||
log := s.log.With(zap.String("unit", request.Unit), zap.String("action", request.Action.String()))
|
||||
conn, err := s.dbus.NewSystemdConnectionContext(ctx)
|
||||
conn, err := s.dbus.NewSystemConnectionContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("establishing systemd connection: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
resultChan := make(chan string, 1)
|
||||
switch request.Action {
|
||||
|
@ -149,28 +159,41 @@ func (s *ServiceManager) WriteSystemdUnitFile(ctx context.Context, unit SystemdU
|
|||
}
|
||||
|
||||
log.Infof("Wrote systemd unit file and performed daemon-reload")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DefaultServiceUnit will write the default "bootstrapper.service" unit file.
|
||||
func DefaultServiceUnit(ctx context.Context, serviceManager *ServiceManager) error {
|
||||
if err := serviceManager.WriteSystemdUnitFile(ctx, SystemdUnit{
|
||||
Name: debugd.BootstrapperSystemdUnitName,
|
||||
Contents: debugd.BootstrapperSystemdUnitContents,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("writing systemd unit file %q: %w", debugd.BootstrapperSystemdUnitName, err)
|
||||
}
|
||||
|
||||
// try to start the default service if the binary exists but ignore failure.
|
||||
// this is meant to start the bootstrapper after a reboot
|
||||
// if a bootstrapper binary was uploaded before.
|
||||
if ok, err := afero.Exists(serviceManager.fs, debugd.BootstrapperDeployFilename); ok && err == nil {
|
||||
_ = serviceManager.SystemdAction(ctx, ServiceManagerRequest{
|
||||
Unit: debugd.BootstrapperSystemdUnitName,
|
||||
Action: Start,
|
||||
})
|
||||
// OverrideServiceUnitExecStart will override the ExecStart of a systemd unit.
|
||||
func (s *ServiceManager) OverrideServiceUnitExecStart(ctx context.Context, unitName, execStart string) error {
|
||||
log := s.log.With(zap.String("unitFile", fmt.Sprintf("%s/%s", systemdUnitFolder, unitName)))
|
||||
log.Infof("Overriding systemd unit file execStart")
|
||||
if !systemdUnitNameRegexp.MatchString(unitName) {
|
||||
return fmt.Errorf("unit name %q is invalid", unitName)
|
||||
}
|
||||
// validate execStart (no newlines)
|
||||
if strings.Contains(execStart, "\n") || strings.Contains(execStart, "\r") {
|
||||
return fmt.Errorf("execStart must not contain newlines")
|
||||
}
|
||||
overrideUnitContents := fmt.Sprintf("[Service]\nExecStart=\nExecStart=%s\n", execStart)
|
||||
s.systemdUnitFilewriteLock.Lock()
|
||||
defer s.systemdUnitFilewriteLock.Unlock()
|
||||
path := filepath.Join(systemdUnitFolder, unitName+".service.d", "override.conf")
|
||||
if err := s.fs.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
|
||||
return fmt.Errorf("creating systemd unit file override directory %q: %w", filepath.Dir(path), err)
|
||||
}
|
||||
if err := afero.WriteFile(s.fs, path, []byte(overrideUnitContents), 0o644); err != nil {
|
||||
return fmt.Errorf("writing systemd unit override file %q: %w", unitName, err)
|
||||
}
|
||||
if err := s.SystemdAction(ctx, ServiceManagerRequest{Unit: unitName, Action: Reload}); err != nil {
|
||||
// do not return early here
|
||||
// the "daemon-reload" command may return an unrelated error
|
||||
// and there is no way to know if the override was successful
|
||||
log.Warnf("Failed to perform systemd daemon-reload: %v", err)
|
||||
}
|
||||
if err := s.SystemdAction(ctx, ServiceManagerRequest{Unit: unitName + ".service", Action: Restart}); err != nil {
|
||||
log.Warnf("Failed to perform unit restart: %v", err)
|
||||
return fmt.Errorf("performing systemd unit restart: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Overrode systemd unit file execStart, performed daemon-reload and restarted unit %v", unitName)
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -158,7 +158,7 @@ func TestWriteSystemdUnitFile(t *testing.T) {
|
|||
"systemd reload fails": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
actionErr: errors.New("reload error"),
|
||||
reloadErr: errors.New("reload error"),
|
||||
},
|
||||
},
|
||||
unit: SystemdUnit{
|
||||
|
@ -200,12 +200,127 @@ func TestWriteSystemdUnitFile(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestOverrideServiceUnitExecStart(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
dbus stubDbus
|
||||
unitName, execStart string
|
||||
readonly bool
|
||||
wantErr bool
|
||||
wantFileContents string
|
||||
wantActionCalls []dbusConnActionInput
|
||||
wantReloads int
|
||||
}{
|
||||
"override works": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
result: "done",
|
||||
},
|
||||
},
|
||||
unitName: "test",
|
||||
execStart: "/run/state/bin/test",
|
||||
wantFileContents: "[Service]\nExecStart=\nExecStart=/run/state/bin/test\n",
|
||||
wantActionCalls: []dbusConnActionInput{
|
||||
{name: "test.service", mode: "replace"},
|
||||
},
|
||||
wantReloads: 1,
|
||||
},
|
||||
"unit name invalid": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
result: "done",
|
||||
},
|
||||
},
|
||||
unitName: "invalid name",
|
||||
execStart: "/run/state/bin/test",
|
||||
wantErr: true,
|
||||
},
|
||||
"exec start invalid": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
result: "done",
|
||||
},
|
||||
},
|
||||
unitName: "test",
|
||||
execStart: "/run/state/bin/\r\ntest",
|
||||
wantErr: true,
|
||||
},
|
||||
"write fails": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
result: "done",
|
||||
},
|
||||
},
|
||||
unitName: "test",
|
||||
execStart: "/run/state/bin/test",
|
||||
readonly: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"reload fails but restart is still attempted": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
result: "done",
|
||||
reloadErr: errors.New("reload error"),
|
||||
},
|
||||
},
|
||||
unitName: "test",
|
||||
execStart: "/run/state/bin/test",
|
||||
wantFileContents: "[Service]\nExecStart=\nExecStart=/run/state/bin/test\n",
|
||||
wantActionCalls: []dbusConnActionInput{
|
||||
{name: "test.service", mode: "replace"},
|
||||
},
|
||||
wantReloads: 1,
|
||||
},
|
||||
"restart fails": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
result: "done",
|
||||
actionErr: errors.New("action error"),
|
||||
},
|
||||
},
|
||||
unitName: "test",
|
||||
execStart: "/run/state/bin/test",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
assert.NoError(fs.MkdirAll(systemdUnitFolder, 0o755))
|
||||
if tc.readonly {
|
||||
fs = afero.NewReadOnlyFs(fs)
|
||||
}
|
||||
manager := ServiceManager{
|
||||
log: logger.NewTest(t),
|
||||
dbus: &tc.dbus,
|
||||
fs: fs,
|
||||
systemdUnitFilewriteLock: sync.Mutex{},
|
||||
}
|
||||
err := manager.OverrideServiceUnitExecStart(context.Background(), tc.unitName, tc.execStart)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
fileContents, err := afero.ReadFile(fs, "/run/systemd/system/test.service.d/override.conf")
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantFileContents, string(fileContents))
|
||||
assert.Equal(tc.wantActionCalls, tc.dbus.conn.(*fakeDbusConn).inputs)
|
||||
assert.Equal(tc.wantReloads, tc.dbus.conn.(*fakeDbusConn).reloadCalls)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type stubDbus struct {
|
||||
conn dbusConn
|
||||
connErr error
|
||||
}
|
||||
|
||||
func (s *stubDbus) NewSystemdConnectionContext(ctx context.Context) (dbusConn, error) {
|
||||
func (s *stubDbus) NewSystemConnectionContext(ctx context.Context) (dbusConn, error) {
|
||||
return s.conn, s.connErr
|
||||
}
|
||||
|
||||
|
@ -215,11 +330,13 @@ type dbusConnActionInput struct {
|
|||
}
|
||||
|
||||
type fakeDbusConn struct {
|
||||
inputs []dbusConnActionInput
|
||||
result string
|
||||
inputs []dbusConnActionInput
|
||||
result string
|
||||
reloadCalls int
|
||||
|
||||
jobID int
|
||||
actionErr error
|
||||
reloadErr error
|
||||
}
|
||||
|
||||
func (c *fakeDbusConn) StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
|
||||
|
@ -244,5 +361,9 @@ func (c *fakeDbusConn) RestartUnitContext(ctx context.Context, name string, mode
|
|||
}
|
||||
|
||||
func (c *fakeDbusConn) ReloadContext(ctx context.Context) error {
|
||||
return c.actionErr
|
||||
c.reloadCalls++
|
||||
|
||||
return c.reloadErr
|
||||
}
|
||||
|
||||
func (c *fakeDbusConn) Close() {}
|
||||
|
|
|
@ -15,8 +15,8 @@ import (
|
|||
// wraps go-systemd dbus.
|
||||
type dbusWrapper struct{}
|
||||
|
||||
func (d *dbusWrapper) NewSystemdConnectionContext(ctx context.Context) (dbusConn, error) {
|
||||
conn, err := dbus.NewSystemdConnectionContext(ctx)
|
||||
func (d *dbusWrapper) NewSystemConnectionContext(ctx context.Context) (dbusConn, error) {
|
||||
conn, err := dbus.NewSystemConnectionContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -44,3 +44,7 @@ func (c *dbusConnWrapper) RestartUnitContext(ctx context.Context, name string, m
|
|||
func (c *dbusConnWrapper) ReloadContext(ctx context.Context) error {
|
||||
return c.conn.ReloadContext(ctx)
|
||||
}
|
||||
|
||||
func (c *dbusConnWrapper) Close() {
|
||||
c.conn.Close()
|
||||
}
|
||||
|
|
|
@ -29,6 +29,13 @@ func NewMap() *Map {
|
|||
}
|
||||
}
|
||||
|
||||
// Received returns true if the info map has been set.
|
||||
func (i *Map) Received() bool {
|
||||
i.mux.RLock()
|
||||
defer i.mux.RUnlock()
|
||||
return i.received
|
||||
}
|
||||
|
||||
// Get returns the value of the info with the given key.
|
||||
func (i *Map) Get(key string) (string, bool, error) {
|
||||
i.mux.RLock()
|
||||
|
@ -67,7 +74,7 @@ func (i *Map) SetProto(infos []*servicepb.Info) error {
|
|||
defer i.mux.Unlock()
|
||||
|
||||
if i.received {
|
||||
return errors.New("info already set")
|
||||
return ErrInfoAlreadySet
|
||||
}
|
||||
|
||||
infoMap := make(map[string]string)
|
||||
|
@ -114,3 +121,6 @@ func (i *Map) GetProto() ([]*servicepb.Info, error) {
|
|||
}
|
||||
return infos, nil
|
||||
}
|
||||
|
||||
// ErrInfoAlreadySet is returned if the info map has already been set.
|
||||
var ErrInfoAlreadySet = errors.New("info already set")
|
||||
|
|
|
@ -284,6 +284,10 @@ func TestConcurrency(t *testing.T) {
|
|||
_, _ = i.GetProto()
|
||||
}
|
||||
|
||||
received := func() {
|
||||
_ = i.Received()
|
||||
}
|
||||
|
||||
go get()
|
||||
go get()
|
||||
go get()
|
||||
|
@ -300,4 +304,8 @@ func TestConcurrency(t *testing.T) {
|
|||
go getProto()
|
||||
go getProto()
|
||||
go getProto()
|
||||
go received()
|
||||
go received()
|
||||
go received()
|
||||
go received()
|
||||
}
|
||||
|
|
|
@ -9,20 +9,18 @@ package server
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/bootstrapper"
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd/deploy"
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd/info"
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer"
|
||||
pb "github.com/edgelesssys/constellation/v2/debugd/service"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"go.uber.org/multierr"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
@ -31,18 +29,18 @@ import (
|
|||
type debugdServer struct {
|
||||
log *logger.Logger
|
||||
serviceManager serviceManager
|
||||
streamer streamer
|
||||
transfer fileTransferer
|
||||
info *info.Map
|
||||
|
||||
pb.UnimplementedDebugdServer
|
||||
}
|
||||
|
||||
// New creates a new debugdServer according to the gRPC spec.
|
||||
func New(log *logger.Logger, serviceManager serviceManager, streamer streamer, infos *info.Map) pb.DebugdServer {
|
||||
func New(log *logger.Logger, serviceManager serviceManager, transfer fileTransferer, infos *info.Map) pb.DebugdServer {
|
||||
return &debugdServer{
|
||||
log: log,
|
||||
serviceManager: serviceManager,
|
||||
streamer: streamer,
|
||||
transfer: transfer,
|
||||
info: infos,
|
||||
}
|
||||
}
|
||||
|
@ -55,13 +53,23 @@ func (s *debugdServer) SetInfo(ctx context.Context, req *pb.SetInfoRequest) (*pb
|
|||
s.log.Infof("Info is empty")
|
||||
}
|
||||
|
||||
if err := s.info.SetProto(req.Info); err != nil {
|
||||
s.log.With(zap.Error(err)).Errorf("Setting info failed")
|
||||
return &pb.SetInfoResponse{}, err
|
||||
setProtoErr := s.info.SetProto(req.Info)
|
||||
if errors.Is(setProtoErr, info.ErrInfoAlreadySet) {
|
||||
s.log.Warnf("Setting info failed (already set)")
|
||||
return &pb.SetInfoResponse{
|
||||
Status: pb.SetInfoStatus_SET_INFO_ALREADY_SET,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if setProtoErr != nil {
|
||||
s.log.With(zap.Error(setProtoErr)).Errorf("Setting info failed")
|
||||
return nil, setProtoErr
|
||||
}
|
||||
s.log.Infof("Info set")
|
||||
|
||||
return &pb.SetInfoResponse{}, nil
|
||||
return &pb.SetInfoResponse{
|
||||
Status: pb.SetInfoStatus_SET_INFO_SUCCESS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetInfo returns the info of the debugd instance.
|
||||
|
@ -76,46 +84,66 @@ func (s *debugdServer) GetInfo(ctx context.Context, req *pb.GetInfoRequest) (*pb
|
|||
return &pb.GetInfoResponse{Info: info}, nil
|
||||
}
|
||||
|
||||
// UploadBootstrapper receives a bootstrapper binary in a stream of chunks and writes to a file.
|
||||
func (s *debugdServer) UploadBootstrapper(stream pb.Debugd_UploadBootstrapperServer) error {
|
||||
startAction := deploy.ServiceManagerRequest{
|
||||
Unit: debugd.BootstrapperSystemdUnitName,
|
||||
Action: deploy.Start,
|
||||
}
|
||||
var responseStatus pb.UploadBootstrapperStatus
|
||||
defer func() {
|
||||
if err := s.serviceManager.SystemdAction(stream.Context(), startAction); err != nil {
|
||||
s.log.With(zap.Error(err)).Errorf("Starting uploaded bootstrapper failed")
|
||||
if responseStatus == pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS {
|
||||
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_START_FAILED
|
||||
}
|
||||
}
|
||||
stream.SendAndClose(&pb.UploadBootstrapperResponse{
|
||||
Status: responseStatus,
|
||||
// UploadFiles receives a stream of files (each consisting of a header and a stream of chunks) and writes them to the filesystem.
|
||||
func (s *debugdServer) UploadFiles(stream pb.Debugd_UploadFilesServer) error {
|
||||
s.log.Infof("Received UploadFiles request")
|
||||
err := s.transfer.RecvFiles(stream)
|
||||
switch {
|
||||
case err == nil:
|
||||
s.log.Infof("Uploading files succeeded")
|
||||
case errors.Is(err, filetransfer.ErrReceiveRunning):
|
||||
s.log.Warnf("Upload already in progress")
|
||||
return stream.SendAndClose(&pb.UploadFilesResponse{
|
||||
Status: pb.UploadFilesStatus_UPLOAD_FILES_ALREADY_STARTED,
|
||||
})
|
||||
case errors.Is(err, filetransfer.ErrReceiveFinished):
|
||||
s.log.Warnf("Upload already finished")
|
||||
return stream.SendAndClose(&pb.UploadFilesResponse{
|
||||
Status: pb.UploadFilesStatus_UPLOAD_FILES_ALREADY_FINISHED,
|
||||
})
|
||||
default:
|
||||
s.log.With(zap.Error(err)).Errorf("Uploading files failed")
|
||||
return stream.SendAndClose(&pb.UploadFilesResponse{
|
||||
Status: pb.UploadFilesStatus_UPLOAD_FILES_UPLOAD_FAILED,
|
||||
})
|
||||
}()
|
||||
s.log.Infof("Starting bootstrapper upload")
|
||||
if err := s.streamer.WriteStream(debugd.BootstrapperDeployFilename, stream, true); err != nil {
|
||||
if errors.Is(err, fs.ErrExist) {
|
||||
// bootstrapper was already uploaded
|
||||
s.log.Warnf("Bootstrapper already uploaded")
|
||||
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_FILE_EXISTS
|
||||
return nil
|
||||
}
|
||||
s.log.With(zap.Error(err)).Errorf("Uploading bootstrapper failed")
|
||||
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_UPLOAD_FAILED
|
||||
return fmt.Errorf("uploading bootstrapper: %w", err)
|
||||
}
|
||||
|
||||
s.log.Infof("Successfully uploaded bootstrapper")
|
||||
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS
|
||||
return nil
|
||||
files := s.transfer.GetFiles()
|
||||
var overrideUnitErr error
|
||||
for _, file := range files {
|
||||
if file.OverrideServiceUnit == "" {
|
||||
continue
|
||||
}
|
||||
// continue on error to allow other units to be overridden
|
||||
// TODO: switch to native go multierror once 1.20 is released
|
||||
// err = s.serviceManager.OverrideServiceUnitExecStart(stream.Context(), file.OverrideServiceUnit, file.TargetPath)
|
||||
// if err != nil {
|
||||
// overrideUnitErr = errors.Join(overrideUnitErr, err)
|
||||
// }
|
||||
err = s.serviceManager.OverrideServiceUnitExecStart(stream.Context(), file.OverrideServiceUnit, file.TargetPath)
|
||||
if err != nil {
|
||||
overrideUnitErr = multierr.Append(overrideUnitErr, err)
|
||||
}
|
||||
}
|
||||
|
||||
if overrideUnitErr != nil {
|
||||
s.log.With(zap.Error(overrideUnitErr)).Errorf("Overriding service units failed")
|
||||
return stream.SendAndClose(&pb.UploadFilesResponse{
|
||||
Status: pb.UploadFilesStatus_UPLOAD_FILES_START_FAILED,
|
||||
})
|
||||
}
|
||||
return stream.SendAndClose(&pb.UploadFilesResponse{
|
||||
Status: pb.UploadFilesStatus_UPLOAD_FILES_SUCCESS,
|
||||
})
|
||||
}
|
||||
|
||||
// DownloadBootstrapper streams the local bootstrapper binary to other instances.
|
||||
func (s *debugdServer) DownloadBootstrapper(request *pb.DownloadBootstrapperRequest, stream pb.Debugd_DownloadBootstrapperServer) error {
|
||||
s.log.Infof("Sending bootstrapper to other instance")
|
||||
return s.streamer.ReadStream(debugd.BootstrapperDeployFilename, stream, debugd.Chunksize, true)
|
||||
// DownloadFiles streams the previously received files to other instances.
|
||||
func (s *debugdServer) DownloadFiles(request *pb.DownloadFilesRequest, stream pb.Debugd_DownloadFilesServer) error {
|
||||
s.log.Infof("Sending files to other instance")
|
||||
if !s.transfer.CanSend() {
|
||||
return errors.New("cannot send files at this time")
|
||||
}
|
||||
return s.transfer.SendFiles(stream)
|
||||
}
|
||||
|
||||
// UploadSystemServiceUnits receives systemd service units, writes them to a service file and schedules a daemon-reload.
|
||||
|
@ -157,9 +185,12 @@ func Start(log *logger.Logger, wg *sync.WaitGroup, serv pb.DebugdServer) {
|
|||
type serviceManager interface {
|
||||
SystemdAction(ctx context.Context, request deploy.ServiceManagerRequest) error
|
||||
WriteSystemdUnitFile(ctx context.Context, unit deploy.SystemdUnit) error
|
||||
OverrideServiceUnitExecStart(ctx context.Context, unitName string, execStart string) error
|
||||
}
|
||||
|
||||
type streamer interface {
|
||||
WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error
|
||||
ReadStream(filename string, stream bootstrapper.WriteChunkStream, chunksize uint, showProgress bool) error
|
||||
type fileTransferer interface {
|
||||
RecvFiles(stream filetransfer.RecvFilesStream) error
|
||||
SendFiles(stream filetransfer.SendFilesStream) error
|
||||
GetFiles() []filetransfer.FileStat
|
||||
CanSend() bool
|
||||
}
|
||||
|
|
|
@ -9,15 +9,14 @@ package server
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/bootstrapper"
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd/deploy"
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd/info"
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer"
|
||||
pb "github.com/edgelesssys/constellation/v2/debugd/service"
|
||||
"github.com/edgelesssys/constellation/v2/internal/constants"
|
||||
"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
|
||||
|
@ -40,21 +39,23 @@ func TestSetInfo(t *testing.T) {
|
|||
info *info.Map
|
||||
infoReceived bool
|
||||
setInfo []*pb.Info
|
||||
wantErr bool
|
||||
wantStatus pb.SetInfoStatus
|
||||
}{
|
||||
"set info works": {
|
||||
setInfo: []*pb.Info{{Key: "foo", Value: "bar"}},
|
||||
info: info.NewMap(),
|
||||
setInfo: []*pb.Info{{Key: "foo", Value: "bar"}},
|
||||
info: info.NewMap(),
|
||||
wantStatus: pb.SetInfoStatus_SET_INFO_SUCCESS,
|
||||
},
|
||||
"set empty info works": {
|
||||
setInfo: []*pb.Info{},
|
||||
info: info.NewMap(),
|
||||
setInfo: []*pb.Info{},
|
||||
info: info.NewMap(),
|
||||
wantStatus: pb.SetInfoStatus_SET_INFO_SUCCESS,
|
||||
},
|
||||
"set fails when info already set": {
|
||||
info: info.NewMap(),
|
||||
infoReceived: true,
|
||||
setInfo: []*pb.Info{{Key: "foo", Value: "bar"}},
|
||||
wantErr: true,
|
||||
wantStatus: pb.SetInfoStatus_SET_INFO_ALREADY_SET,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -78,19 +79,16 @@ func TestSetInfo(t *testing.T) {
|
|||
defer conn.Close()
|
||||
client := pb.NewDebugdClient(conn)
|
||||
|
||||
_, err = client.SetInfo(context.Background(), &pb.SetInfoRequest{Info: tc.setInfo})
|
||||
setInfoStatus, err := client.SetInfo(context.Background(), &pb.SetInfoRequest{Info: tc.setInfo})
|
||||
grpcServ.GracefulStop()
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantStatus, setInfoStatus.Status)
|
||||
for i := range tc.setInfo {
|
||||
value, ok, err := tc.info.Get(tc.setInfo[i].Key)
|
||||
assert.NoError(err)
|
||||
for i := range tc.setInfo {
|
||||
value, ok, err := tc.info.Get(tc.setInfo[i].Key)
|
||||
assert.NoError(err)
|
||||
assert.True(ok)
|
||||
assert.Equal(tc.setInfo[i].Value, value)
|
||||
}
|
||||
assert.True(ok)
|
||||
assert.Equal(tc.setInfo[i].Value, value)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -152,35 +150,36 @@ func TestGetInfo(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestUploadBootstrapper(t *testing.T) {
|
||||
func TestUploadFiles(t *testing.T) {
|
||||
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
|
||||
|
||||
testCases := map[string]struct {
|
||||
serviceManager stubServiceManager
|
||||
streamer fakeStreamer
|
||||
uploadChunks [][]byte
|
||||
wantErr bool
|
||||
wantResponseStatus pb.UploadBootstrapperStatus
|
||||
wantFile bool
|
||||
wantChunks [][]byte
|
||||
files []filetransfer.FileStat
|
||||
recvFilesErr error
|
||||
wantResponseStatus pb.UploadFilesStatus
|
||||
wantOverrideCalls []struct{ UnitName, ExecStart string }
|
||||
}{
|
||||
"upload works": {
|
||||
uploadChunks: [][]byte{[]byte("test")},
|
||||
wantFile: true,
|
||||
wantChunks: [][]byte{[]byte("test")},
|
||||
wantResponseStatus: pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS,
|
||||
files: []filetransfer.FileStat{
|
||||
{SourcePath: "source/testA", TargetPath: "target/testA", Mode: 0o644, OverrideServiceUnit: "testA"},
|
||||
{SourcePath: "source/testB", TargetPath: "target/testB", Mode: 0o644},
|
||||
},
|
||||
wantOverrideCalls: []struct{ UnitName, ExecStart string }{
|
||||
{"testA", "target/testA"},
|
||||
},
|
||||
wantResponseStatus: pb.UploadFilesStatus_UPLOAD_FILES_SUCCESS,
|
||||
},
|
||||
"recv fails": {
|
||||
streamer: fakeStreamer{writeStreamErr: errors.New("recv error")},
|
||||
wantResponseStatus: pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_UPLOAD_FAILED,
|
||||
wantErr: true,
|
||||
recvFilesErr: errors.New("recv error"),
|
||||
wantResponseStatus: pb.UploadFilesStatus_UPLOAD_FILES_UPLOAD_FAILED,
|
||||
},
|
||||
"starting bootstrapper fails": {
|
||||
uploadChunks: [][]byte{[]byte("test")},
|
||||
serviceManager: stubServiceManager{systemdActionErr: errors.New("starting bootstrapper error")},
|
||||
wantFile: true,
|
||||
wantChunks: [][]byte{[]byte("test")},
|
||||
wantResponseStatus: pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_START_FAILED,
|
||||
"upload in progress": {
|
||||
recvFilesErr: filetransfer.ErrReceiveRunning,
|
||||
wantResponseStatus: pb.UploadFilesStatus_UPLOAD_FILES_ALREADY_STARTED,
|
||||
},
|
||||
"upload already finished": {
|
||||
recvFilesErr: filetransfer.ErrReceiveFinished,
|
||||
wantResponseStatus: pb.UploadFilesStatus_UPLOAD_FILES_ALREADY_FINISHED,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -189,60 +188,49 @@ func TestUploadBootstrapper(t *testing.T) {
|
|||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
serviceMgr := &stubServiceManager{}
|
||||
transfer := &stubTransfer{files: tc.files, recvFilesErr: tc.recvFilesErr}
|
||||
|
||||
serv := debugdServer{
|
||||
log: logger.NewTest(t),
|
||||
serviceManager: &tc.serviceManager,
|
||||
streamer: &tc.streamer,
|
||||
serviceManager: serviceMgr,
|
||||
transfer: transfer,
|
||||
}
|
||||
|
||||
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
client := pb.NewDebugdClient(conn)
|
||||
stream, err := client.UploadBootstrapper(context.Background())
|
||||
stream, err := client.UploadFiles(context.Background())
|
||||
require.NoError(err)
|
||||
require.NoError(fakeWrite(stream, tc.uploadChunks))
|
||||
resp, err := stream.CloseAndRecv()
|
||||
|
||||
grpcServ.GracefulStop()
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantResponseStatus, resp.Status)
|
||||
if tc.wantFile {
|
||||
assert.Equal(tc.wantChunks, tc.streamer.writeStreamChunks)
|
||||
assert.Equal("/run/state/bin/bootstrapper", tc.streamer.writeStreamFilename)
|
||||
} else {
|
||||
assert.Empty(tc.streamer.writeStreamChunks)
|
||||
assert.Empty(tc.streamer.writeStreamFilename)
|
||||
}
|
||||
assert.Equal(tc.wantOverrideCalls, serviceMgr.overrideCalls)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadBootstrapper(t *testing.T) {
|
||||
func TestDownloadFiles(t *testing.T) {
|
||||
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
|
||||
|
||||
testCases := map[string]struct {
|
||||
serviceManager stubServiceManager
|
||||
request *pb.DownloadBootstrapperRequest
|
||||
streamer fakeStreamer
|
||||
wantErr bool
|
||||
wantChunks [][]byte
|
||||
request *pb.DownloadFilesRequest
|
||||
canSend bool
|
||||
wantRecvErr bool
|
||||
wantSendFileCalls int
|
||||
}{
|
||||
"download works": {
|
||||
request: &pb.DownloadBootstrapperRequest{},
|
||||
streamer: fakeStreamer{readStreamChunks: [][]byte{[]byte("test")}},
|
||||
wantErr: false,
|
||||
wantChunks: [][]byte{[]byte("test")},
|
||||
request: &pb.DownloadFilesRequest{},
|
||||
canSend: true,
|
||||
wantSendFileCalls: 1,
|
||||
},
|
||||
"download fails": {
|
||||
request: &pb.DownloadBootstrapperRequest{},
|
||||
streamer: fakeStreamer{readStreamErr: errors.New("read bootstrapper fails")},
|
||||
wantErr: true,
|
||||
"transfer is not ready for sending": {
|
||||
request: &pb.DownloadFilesRequest{},
|
||||
wantRecvErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -251,28 +239,29 @@ func TestDownloadBootstrapper(t *testing.T) {
|
|||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
transfer := &stubTransfer{canSend: tc.canSend}
|
||||
serv := debugdServer{
|
||||
log: logger.NewTest(t),
|
||||
serviceManager: &tc.serviceManager,
|
||||
streamer: &tc.streamer,
|
||||
log: logger.NewTest(t),
|
||||
transfer: transfer,
|
||||
}
|
||||
|
||||
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
client := pb.NewDebugdClient(conn)
|
||||
stream, err := client.DownloadBootstrapper(context.Background(), tc.request)
|
||||
stream, err := client.DownloadFiles(context.Background(), tc.request)
|
||||
require.NoError(err)
|
||||
chunks, err := fakeRead(stream)
|
||||
grpcServ.GracefulStop()
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
_, recvErr := stream.Recv()
|
||||
if tc.wantRecvErr {
|
||||
require.Error(recvErr)
|
||||
} else {
|
||||
require.ErrorIs(recvErr, io.EOF)
|
||||
}
|
||||
require.NoError(stream.CloseSend())
|
||||
grpcServ.GracefulStop()
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantChunks, chunks)
|
||||
assert.Equal("/run/state/bin/bootstrapper", tc.streamer.readStreamFilename)
|
||||
|
||||
assert.Equal(tc.wantSendFileCalls, transfer.sendFilesCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -334,7 +323,6 @@ func TestUploadSystemServiceUnits(t *testing.T) {
|
|||
serv := debugdServer{
|
||||
log: logger.NewTest(t),
|
||||
serviceManager: &tc.serviceManager,
|
||||
streamer: &fakeStreamer{},
|
||||
}
|
||||
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
|
||||
require.NoError(err)
|
||||
|
@ -357,10 +345,13 @@ func TestUploadSystemServiceUnits(t *testing.T) {
|
|||
}
|
||||
|
||||
type stubServiceManager struct {
|
||||
requests []deploy.ServiceManagerRequest
|
||||
unitFiles []deploy.SystemdUnit
|
||||
systemdActionErr error
|
||||
writeSystemdUnitFileErr error
|
||||
requests []deploy.ServiceManagerRequest
|
||||
unitFiles []deploy.SystemdUnit
|
||||
overrideCalls []struct{ UnitName, ExecStart string }
|
||||
|
||||
systemdActionErr error
|
||||
writeSystemdUnitFileErr error
|
||||
overrideServiceUnitExecStartErr error
|
||||
}
|
||||
|
||||
func (s *stubServiceManager) SystemdAction(ctx context.Context, request deploy.ServiceManagerRequest) error {
|
||||
|
@ -373,6 +364,13 @@ func (s *stubServiceManager) WriteSystemdUnitFile(ctx context.Context, unit depl
|
|||
return s.writeSystemdUnitFileErr
|
||||
}
|
||||
|
||||
func (s *stubServiceManager) OverrideServiceUnitExecStart(ctx context.Context, unitName string, execStart string) error {
|
||||
s.overrideCalls = append(s.overrideCalls, struct {
|
||||
UnitName, ExecStart string
|
||||
}{UnitName: unitName, ExecStart: execStart})
|
||||
return s.overrideServiceUnitExecStartErr
|
||||
}
|
||||
|
||||
type netDialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
@ -386,37 +384,31 @@ func dial(ctx context.Context, dialer netDialer, target string) (*grpc.ClientCon
|
|||
)
|
||||
}
|
||||
|
||||
type fakeStreamer struct {
|
||||
writeStreamChunks [][]byte
|
||||
writeStreamFilename string
|
||||
writeStreamErr error
|
||||
readStreamChunks [][]byte
|
||||
readStreamFilename string
|
||||
readStreamErr error
|
||||
type stubTransfer struct {
|
||||
recvFilesCount int
|
||||
sendFilesCount int
|
||||
files []filetransfer.FileStat
|
||||
canSend bool
|
||||
recvFilesErr error
|
||||
sendFilesErr error
|
||||
}
|
||||
|
||||
func (f *fakeStreamer) WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error {
|
||||
f.writeStreamFilename = filename
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return f.writeStreamErr
|
||||
}
|
||||
return fmt.Errorf("reading stream: %w", err)
|
||||
}
|
||||
f.writeStreamChunks = append(f.writeStreamChunks, chunk.Content)
|
||||
}
|
||||
func (t *stubTransfer) RecvFiles(_ filetransfer.RecvFilesStream) error {
|
||||
t.recvFilesCount++
|
||||
return t.recvFilesErr
|
||||
}
|
||||
|
||||
func (f *fakeStreamer) ReadStream(filename string, stream bootstrapper.WriteChunkStream, chunksize uint, showProgress bool) error {
|
||||
f.readStreamFilename = filename
|
||||
for _, chunk := range f.readStreamChunks {
|
||||
if err := stream.Send(&pb.Chunk{Content: chunk}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return f.readStreamErr
|
||||
func (t *stubTransfer) SendFiles(_ filetransfer.SendFilesStream) error {
|
||||
t.sendFilesCount++
|
||||
return t.sendFilesErr
|
||||
}
|
||||
|
||||
func (t *stubTransfer) GetFiles() []filetransfer.FileStat {
|
||||
return t.files
|
||||
}
|
||||
|
||||
func (t *stubTransfer) CanSend() bool {
|
||||
return t.canSend
|
||||
}
|
||||
|
||||
func setupServerWithConn(endpoint string, serv *debugdServer) (*grpc.Server, *grpc.ClientConn, error) {
|
||||
|
@ -433,29 +425,3 @@ func setupServerWithConn(endpoint string, serv *debugdServer) (*grpc.Server, *gr
|
|||
|
||||
return grpcServ, conn, nil
|
||||
}
|
||||
|
||||
func fakeWrite(stream bootstrapper.WriteChunkStream, chunks [][]byte) error {
|
||||
for _, chunk := range chunks {
|
||||
err := stream.Send(&pb.Chunk{
|
||||
Content: chunk,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func fakeRead(stream bootstrapper.ReadChunkStream) ([][]byte, error) {
|
||||
var chunks [][]byte
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return chunks, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
chunks = append(chunks, chunk.Content)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue