mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-08-02 03:56:07 -04:00
AB#2327 move debugd code into internal folder (#403)
* move debugd code into internal folder * Fix paths in CMakeLists.txt Signed-off-by: Fabian Kammel <fk@edgeless.systems>
This commit is contained in:
parent
708c6e057e
commit
5b40e0cc77
25 changed files with 31 additions and 31 deletions
31
debugd/internal/debugd/constants.go
Normal file
31
debugd/internal/debugd/constants.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package debugd
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
DebugdMetadataFlag = "constellation-debugd"
|
||||
GRPCTimeout = 5 * time.Minute
|
||||
SSHCheckInterval = 30 * time.Second
|
||||
DiscoverDebugdInterval = 30 * time.Second
|
||||
BootstrapperDownloadRetryBackoff = 1 * time.Minute
|
||||
BootstrapperDeployFilename = "/opt/bootstrapper"
|
||||
Chunksize = 1024
|
||||
BootstrapperSystemdUnitName = "bootstrapper.service"
|
||||
BootstrapperSystemdUnitContents = `[Unit]
|
||||
Description=Constellation Bootstrapper
|
||||
Wants=network-online.target
|
||||
After=network-online.target
|
||||
[Service]
|
||||
Type=simple
|
||||
RemainAfterExit=yes
|
||||
Restart=on-failure
|
||||
EnvironmentFile=/etc/constellation.env
|
||||
ExecStartPre=-setenforce Permissive
|
||||
ExecStartPre=/usr/bin/mkdir -p /opt/cni/bin/
|
||||
# merge all CNI binaries in writable folder until containerd can use multiple CNI bins: https://github.com/containerd/containerd/issues/6600
|
||||
ExecStartPre=/bin/sh -c "/usr/bin/cp /usr/libexec/cni/* /opt/cni/bin/"
|
||||
ExecStart=/opt/bootstrapper
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
`
|
||||
)
|
117
debugd/internal/debugd/deploy/download.go
Normal file
117
debugd/internal/debugd/deploy/download.go
Normal file
|
@ -0,0 +1,117 @@
|
|||
package deploy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
|
||||
"github.com/edgelesssys/constellation/debugd/internal/debugd"
|
||||
pb "github.com/edgelesssys/constellation/debugd/service"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/deploy/ssh"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
// Download downloads a bootstrapper from a given debugd instance.
|
||||
type Download struct {
|
||||
log *logger.Logger
|
||||
dialer NetDialer
|
||||
writer streamToFileWriter
|
||||
serviceManager serviceManager
|
||||
attemptedDownloads map[string]time.Time
|
||||
}
|
||||
|
||||
// New creates a new Download.
|
||||
func New(log *logger.Logger, dialer NetDialer, serviceManager serviceManager, writer streamToFileWriter) *Download {
|
||||
return &Download{
|
||||
log: log,
|
||||
dialer: dialer,
|
||||
writer: writer,
|
||||
serviceManager: serviceManager,
|
||||
attemptedDownloads: map[string]time.Time{},
|
||||
}
|
||||
}
|
||||
|
||||
// DownloadDeployment will open a new grpc connection to another instance, attempting to download a bootstrapper from that instance.
|
||||
func (d *Download) DownloadDeployment(ctx context.Context, ip string) ([]ssh.UserKey, error) {
|
||||
log := d.log.With(zap.String("ip", ip))
|
||||
serverAddr := net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort))
|
||||
|
||||
// only retry download from same endpoint after backoff
|
||||
if lastAttempt, ok := d.attemptedDownloads[serverAddr]; ok && time.Since(lastAttempt) < debugd.BootstrapperDownloadRetryBackoff {
|
||||
return nil, fmt.Errorf("download failed too recently: %v / %v", time.Since(lastAttempt), debugd.BootstrapperDownloadRetryBackoff)
|
||||
}
|
||||
|
||||
log.Infof("Connecting to server")
|
||||
d.attemptedDownloads[serverAddr] = time.Now()
|
||||
conn, err := d.dial(ctx, serverAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connecting to other instance via gRPC: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
client := pb.NewDebugdClient(conn)
|
||||
|
||||
log.Infof("Trying to download bootstrapper")
|
||||
stream, err := client.DownloadBootstrapper(ctx, &pb.DownloadBootstrapperRequest{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("starting bootstrapper download from other instance: %w", err)
|
||||
}
|
||||
if err := d.writer.WriteStream(debugd.BootstrapperDeployFilename, stream, true); err != nil {
|
||||
return nil, fmt.Errorf("streaming bootstrapper from other instance: %w", err)
|
||||
}
|
||||
log.Infof("Successfully downloaded bootstrapper")
|
||||
|
||||
log.Infof("Trying to download ssh keys")
|
||||
resp, err := client.DownloadAuthorizedKeys(ctx, &pb.DownloadAuthorizedKeysRequest{})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("downloading authorized keys: %w", err)
|
||||
}
|
||||
|
||||
var keys []ssh.UserKey
|
||||
for _, key := range resp.Keys {
|
||||
keys = append(keys, ssh.UserKey{Username: key.Username, PublicKey: key.KeyValue})
|
||||
}
|
||||
|
||||
// after the upload succeeds, try to restart the bootstrapper
|
||||
restartAction := ServiceManagerRequest{
|
||||
Unit: debugd.BootstrapperSystemdUnitName,
|
||||
Action: Restart,
|
||||
}
|
||||
if err := d.serviceManager.SystemdAction(ctx, restartAction); err != nil {
|
||||
return nil, fmt.Errorf("restarting bootstrapper: %w", err)
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (d *Download) dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
||||
return grpc.DialContext(ctx, target,
|
||||
d.grpcWithDialer(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
}
|
||||
|
||||
func (d *Download) grpcWithDialer() grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return d.dialer.DialContext(ctx, "tcp", addr)
|
||||
})
|
||||
}
|
||||
|
||||
type serviceManager interface {
|
||||
SystemdAction(ctx context.Context, request ServiceManagerRequest) error
|
||||
}
|
||||
|
||||
type streamToFileWriter interface {
|
||||
WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error
|
||||
}
|
||||
|
||||
// NetDialer can open a net.Conn.
|
||||
type NetDialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
187
debugd/internal/debugd/deploy/download_test.go
Normal file
187
debugd/internal/debugd/deploy/download_test.go
Normal file
|
@ -0,0 +1,187 @@
|
|||
package deploy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
|
||||
"github.com/edgelesssys/constellation/debugd/internal/debugd"
|
||||
pb "github.com/edgelesssys/constellation/debugd/service"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/deploy/ssh"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/goleak"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m,
|
||||
// https://github.com/census-instrumentation/opencensus-go/issues/1262
|
||||
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestDownloadBootstrapper(t *testing.T) {
|
||||
filename := "/opt/bootstrapper"
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
server fakeDownloadServer
|
||||
serviceManager stubServiceManager
|
||||
attemptedDownloads map[string]time.Time
|
||||
wantChunks [][]byte
|
||||
wantDownloadErr bool
|
||||
wantFile bool
|
||||
wantSystemdAction bool
|
||||
wantDeployed bool
|
||||
wantKeys []ssh.UserKey
|
||||
}{
|
||||
"download works": {
|
||||
server: fakeDownloadServer{
|
||||
chunks: [][]byte{[]byte("test")},
|
||||
keys: []*pb.AuthorizedKey{{Username: "name", KeyValue: "key"}},
|
||||
},
|
||||
attemptedDownloads: map[string]time.Time{},
|
||||
wantChunks: [][]byte{[]byte("test")},
|
||||
wantDownloadErr: false,
|
||||
wantFile: true,
|
||||
wantSystemdAction: true,
|
||||
wantDeployed: true,
|
||||
wantKeys: []ssh.UserKey{{Username: "name", PublicKey: "key"}},
|
||||
},
|
||||
"second download is not attempted twice": {
|
||||
server: fakeDownloadServer{chunks: [][]byte{[]byte("test")}},
|
||||
attemptedDownloads: map[string]time.Time{"192.0.2.0:" + strconv.Itoa(constants.DebugdPort): time.Now()},
|
||||
wantDownloadErr: true,
|
||||
},
|
||||
"download rpc call error is detected": {
|
||||
server: fakeDownloadServer{downladErr: someErr},
|
||||
attemptedDownloads: map[string]time.Time{},
|
||||
wantDownloadErr: true,
|
||||
},
|
||||
"download key error": {
|
||||
server: fakeDownloadServer{
|
||||
chunks: [][]byte{[]byte("test")},
|
||||
downloadAuthorizedKeysErr: someErr,
|
||||
},
|
||||
attemptedDownloads: map[string]time.Time{},
|
||||
wantDownloadErr: true,
|
||||
},
|
||||
"service restart error is detected": {
|
||||
server: fakeDownloadServer{chunks: [][]byte{[]byte("test")}},
|
||||
serviceManager: stubServiceManager{systemdActionErr: someErr},
|
||||
attemptedDownloads: map[string]time.Time{},
|
||||
wantChunks: [][]byte{[]byte("test")},
|
||||
wantDownloadErr: true,
|
||||
wantFile: true,
|
||||
wantDeployed: true,
|
||||
wantSystemdAction: false,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
ip := "192.0.2.0"
|
||||
writer := &fakeStreamToFileWriter{}
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
|
||||
grpcServ := grpc.NewServer()
|
||||
pb.RegisterDebugdServer(grpcServ, &tc.server)
|
||||
lis := dialer.GetListener(net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort)))
|
||||
go grpcServ.Serve(lis)
|
||||
defer grpcServ.GracefulStop()
|
||||
|
||||
download := &Download{
|
||||
log: logger.NewTest(t),
|
||||
dialer: dialer,
|
||||
writer: writer,
|
||||
serviceManager: &tc.serviceManager,
|
||||
attemptedDownloads: tc.attemptedDownloads,
|
||||
}
|
||||
|
||||
keys, err := download.DownloadDeployment(context.Background(), ip)
|
||||
|
||||
if tc.wantDownloadErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
}
|
||||
|
||||
if tc.wantFile {
|
||||
assert.Equal(tc.wantChunks, writer.chunks)
|
||||
assert.Equal(filename, writer.filename)
|
||||
}
|
||||
if tc.wantSystemdAction {
|
||||
assert.ElementsMatch(
|
||||
[]ServiceManagerRequest{
|
||||
{Unit: debugd.BootstrapperSystemdUnitName, Action: Restart},
|
||||
},
|
||||
tc.serviceManager.requests,
|
||||
)
|
||||
}
|
||||
assert.Equal(tc.wantKeys, keys)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type stubServiceManager struct {
|
||||
requests []ServiceManagerRequest
|
||||
systemdActionErr error
|
||||
}
|
||||
|
||||
func (s *stubServiceManager) SystemdAction(ctx context.Context, request ServiceManagerRequest) error {
|
||||
s.requests = append(s.requests, request)
|
||||
return s.systemdActionErr
|
||||
}
|
||||
|
||||
type fakeStreamToFileWriter struct {
|
||||
chunks [][]byte
|
||||
filename string
|
||||
}
|
||||
|
||||
func (f *fakeStreamToFileWriter) WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error {
|
||||
f.filename = filename
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("reading stream: %w", err)
|
||||
}
|
||||
f.chunks = append(f.chunks, chunk.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// fakeDownloadServer implements DebugdServer; only fakes DownloadBootstrapper, panics on every other rpc.
|
||||
type fakeDownloadServer struct {
|
||||
chunks [][]byte
|
||||
downladErr error
|
||||
keys []*pb.AuthorizedKey
|
||||
downloadAuthorizedKeysErr error
|
||||
|
||||
pb.UnimplementedDebugdServer
|
||||
}
|
||||
|
||||
func (f *fakeDownloadServer) DownloadBootstrapper(request *pb.DownloadBootstrapperRequest, stream pb.Debugd_DownloadBootstrapperServer) error {
|
||||
for _, chunk := range f.chunks {
|
||||
if err := stream.Send(&pb.Chunk{Content: chunk}); err != nil {
|
||||
return fmt.Errorf("sending chunk: %w", err)
|
||||
}
|
||||
}
|
||||
return f.downladErr
|
||||
}
|
||||
|
||||
func (s *fakeDownloadServer) DownloadAuthorizedKeys(context.Context, *pb.DownloadAuthorizedKeysRequest) (*pb.DownloadAuthorizedKeysResponse, error) {
|
||||
return &pb.DownloadAuthorizedKeysResponse{Keys: s.keys}, s.downloadAuthorizedKeysErr
|
||||
}
|
163
debugd/internal/debugd/deploy/service.go
Normal file
163
debugd/internal/debugd/deploy/service.go
Normal file
|
@ -0,0 +1,163 @@
|
|||
package deploy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/edgelesssys/constellation/debugd/internal/debugd"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/spf13/afero"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
systemdUnitFolder = "/etc/systemd/system"
|
||||
)
|
||||
|
||||
//go:generate stringer -type=SystemdAction
|
||||
type SystemdAction uint32
|
||||
|
||||
const (
|
||||
Unknown SystemdAction = iota
|
||||
Start
|
||||
Stop
|
||||
Restart
|
||||
Reload
|
||||
)
|
||||
|
||||
// ServiceManagerRequest describes a requested ServiceManagerAction to be performed on a specified service unit.
|
||||
type ServiceManagerRequest struct {
|
||||
Unit string
|
||||
Action SystemdAction
|
||||
}
|
||||
|
||||
// SystemdUnit describes a systemd service file including the unit name and contents.
|
||||
type SystemdUnit struct {
|
||||
Name string `yaml:"name"`
|
||||
Contents string `yaml:"contents"`
|
||||
}
|
||||
|
||||
// ServiceManager receives ServiceManagerRequests and units via channels and performs the requests / creates the unit files.
|
||||
type ServiceManager struct {
|
||||
log *logger.Logger
|
||||
dbus dbusClient
|
||||
fs afero.Fs
|
||||
systemdUnitFilewriteLock sync.Mutex
|
||||
}
|
||||
|
||||
// NewServiceManager creates a new ServiceManager.
|
||||
func NewServiceManager(log *logger.Logger) *ServiceManager {
|
||||
fs := afero.NewOsFs()
|
||||
return &ServiceManager{
|
||||
log: log,
|
||||
dbus: &dbusWrapper{},
|
||||
fs: fs,
|
||||
systemdUnitFilewriteLock: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
type dbusClient interface {
|
||||
// NewSystemConnectionContext establishes a connection to the system bus and authenticates.
|
||||
// Callers should call Close() when done with the connection.
|
||||
NewSystemdConnectionContext(ctx context.Context) (dbusConn, error)
|
||||
}
|
||||
|
||||
type dbusConn interface {
|
||||
// StartUnitContext enqueues a start job and depending jobs, if any (unless otherwise
|
||||
// specified by the mode string).
|
||||
StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error)
|
||||
// StopUnitContext is similar to StartUnitContext, but stops the specified unit
|
||||
// rather than starting it.
|
||||
StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error)
|
||||
// RestartUnitContext restarts a service. If a service is restarted that isn't
|
||||
// running it will be started.
|
||||
RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error)
|
||||
// ReloadContext instructs systemd to scan for and reload unit files. This is
|
||||
// an equivalent to systemctl daemon-reload.
|
||||
ReloadContext(ctx context.Context) error
|
||||
}
|
||||
|
||||
// SystemdAction will perform a systemd action on a service unit (start, stop, restart, reload).
|
||||
func (s *ServiceManager) SystemdAction(ctx context.Context, request ServiceManagerRequest) error {
|
||||
log := s.log.With(zap.String("unit", request.Unit), zap.String("action", request.Action.String()))
|
||||
conn, err := s.dbus.NewSystemdConnectionContext(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("establishing systemd connection: %w", err)
|
||||
}
|
||||
|
||||
resultChan := make(chan string, 1)
|
||||
switch request.Action {
|
||||
case Start:
|
||||
_, err = conn.StartUnitContext(ctx, request.Unit, "replace", resultChan)
|
||||
case Stop:
|
||||
_, err = conn.StopUnitContext(ctx, request.Unit, "replace", resultChan)
|
||||
case Restart:
|
||||
_, err = conn.RestartUnitContext(ctx, request.Unit, "replace", resultChan)
|
||||
case Reload:
|
||||
err = conn.ReloadContext(ctx)
|
||||
default:
|
||||
return fmt.Errorf("unknown systemd action: %s", request.Action.String())
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("performing systemd action %v on unit %v: %w", request.Action, request.Unit, err)
|
||||
}
|
||||
|
||||
if request.Action == Reload {
|
||||
log.Infof("daemon-reload succeeded")
|
||||
return nil
|
||||
}
|
||||
// Wait for the action to finish and then check if it was
|
||||
// successful or not.
|
||||
result := <-resultChan
|
||||
|
||||
switch result {
|
||||
case "done":
|
||||
log.Infof("%s on systemd unit %s succeeded", request.Action, request.Unit)
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("performing action %q on systemd unit %q failed: expected %q but received %q", request.Action.String(), request.Unit, "done", result)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteSystemdUnitFile will write a systemd unit to disk.
|
||||
func (s *ServiceManager) WriteSystemdUnitFile(ctx context.Context, unit SystemdUnit) error {
|
||||
log := s.log.With(zap.String("unitFile", fmt.Sprintf("%s/%s", systemdUnitFolder, unit.Name)))
|
||||
log.Infof("Writing systemd unit file")
|
||||
s.systemdUnitFilewriteLock.Lock()
|
||||
defer s.systemdUnitFilewriteLock.Unlock()
|
||||
if err := afero.WriteFile(s.fs, fmt.Sprintf("%s/%s", systemdUnitFolder, unit.Name), []byte(unit.Contents), 0o644); err != nil {
|
||||
return fmt.Errorf("writing systemd unit file \"%v\": %w", unit.Name, err)
|
||||
}
|
||||
|
||||
if err := s.SystemdAction(ctx, ServiceManagerRequest{Unit: unit.Name, Action: Reload}); err != nil {
|
||||
return fmt.Errorf("performing systemd daemon-reload: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("Wrote systemd unit file and performed daemon-reload")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeployDefaultServiceUnit will write the default "bootstrapper.service" unit file.
|
||||
func DeployDefaultServiceUnit(ctx context.Context, serviceManager *ServiceManager) error {
|
||||
if err := serviceManager.WriteSystemdUnitFile(ctx, SystemdUnit{
|
||||
Name: debugd.BootstrapperSystemdUnitName,
|
||||
Contents: debugd.BootstrapperSystemdUnitContents,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("writing systemd unit file %q: %w", debugd.BootstrapperSystemdUnitName, err)
|
||||
}
|
||||
|
||||
// try to start the default service if the binary exists but ignore failure.
|
||||
// this is meant to start the bootstrapper after a reboot
|
||||
// if a bootstrapper binary was uploaded before.
|
||||
if ok, err := afero.Exists(serviceManager.fs, debugd.BootstrapperDeployFilename); ok && err == nil {
|
||||
_ = serviceManager.SystemdAction(ctx, ServiceManagerRequest{
|
||||
Unit: debugd.BootstrapperSystemdUnitName,
|
||||
Action: Start,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
242
debugd/internal/debugd/deploy/service_test.go
Normal file
242
debugd/internal/debugd/deploy/service_test.go
Normal file
|
@ -0,0 +1,242 @@
|
|||
package deploy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSystemdAction(t *testing.T) {
|
||||
unitName := "example.service"
|
||||
|
||||
testCases := map[string]struct {
|
||||
dbus stubDbus
|
||||
action SystemdAction
|
||||
wantErr bool
|
||||
}{
|
||||
"start works": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
result: "done",
|
||||
},
|
||||
},
|
||||
action: Start,
|
||||
wantErr: false,
|
||||
},
|
||||
"stop works": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
result: "done",
|
||||
},
|
||||
},
|
||||
action: Stop,
|
||||
wantErr: false,
|
||||
},
|
||||
"restart works": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
result: "done",
|
||||
},
|
||||
},
|
||||
action: Restart,
|
||||
wantErr: false,
|
||||
},
|
||||
"reload works": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{},
|
||||
},
|
||||
action: Reload,
|
||||
wantErr: false,
|
||||
},
|
||||
"unknown action": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{},
|
||||
},
|
||||
action: Unknown,
|
||||
wantErr: true,
|
||||
},
|
||||
"action fails": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
actionErr: errors.New("action fails"),
|
||||
},
|
||||
},
|
||||
action: Start,
|
||||
wantErr: true,
|
||||
},
|
||||
"action result is failure": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
result: "failure",
|
||||
},
|
||||
},
|
||||
action: Start,
|
||||
wantErr: true,
|
||||
},
|
||||
"newConn fails": {
|
||||
dbus: stubDbus{
|
||||
connErr: errors.New("newConn fails"),
|
||||
},
|
||||
action: Start,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
manager := ServiceManager{
|
||||
log: logger.NewTest(t),
|
||||
dbus: &tc.dbus,
|
||||
fs: fs,
|
||||
systemdUnitFilewriteLock: sync.Mutex{},
|
||||
}
|
||||
err := manager.SystemdAction(context.Background(), ServiceManagerRequest{
|
||||
Unit: unitName,
|
||||
Action: tc.action,
|
||||
})
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteSystemdUnitFile(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
dbus stubDbus
|
||||
unit SystemdUnit
|
||||
readonly bool
|
||||
wantErr bool
|
||||
wantFileContents string
|
||||
}{
|
||||
"start works": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
result: "done",
|
||||
},
|
||||
},
|
||||
unit: SystemdUnit{
|
||||
Name: "test.service",
|
||||
Contents: "testservicefilecontents",
|
||||
},
|
||||
wantErr: false,
|
||||
wantFileContents: "testservicefilecontents",
|
||||
},
|
||||
"write fails": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
result: "done",
|
||||
},
|
||||
},
|
||||
unit: SystemdUnit{
|
||||
Name: "test.service",
|
||||
Contents: "testservicefilecontents",
|
||||
},
|
||||
readonly: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"systemd reload fails": {
|
||||
dbus: stubDbus{
|
||||
conn: &fakeDbusConn{
|
||||
actionErr: errors.New("reload error"),
|
||||
},
|
||||
},
|
||||
unit: SystemdUnit{
|
||||
Name: "test.service",
|
||||
Contents: "testservicefilecontents",
|
||||
},
|
||||
readonly: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
assert.NoError(fs.MkdirAll(systemdUnitFolder, 0o755))
|
||||
if tc.readonly {
|
||||
fs = afero.NewReadOnlyFs(fs)
|
||||
}
|
||||
manager := ServiceManager{
|
||||
log: logger.NewTest(t),
|
||||
dbus: &tc.dbus,
|
||||
fs: fs,
|
||||
systemdUnitFilewriteLock: sync.Mutex{},
|
||||
}
|
||||
err := manager.WriteSystemdUnitFile(context.Background(), tc.unit)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
fileContents, err := afero.ReadFile(fs, fmt.Sprintf("%s/%s", systemdUnitFolder, tc.unit.Name))
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantFileContents, string(fileContents))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type stubDbus struct {
|
||||
conn dbusConn
|
||||
connErr error
|
||||
}
|
||||
|
||||
func (s *stubDbus) NewSystemdConnectionContext(ctx context.Context) (dbusConn, error) {
|
||||
return s.conn, s.connErr
|
||||
}
|
||||
|
||||
type dbusConnActionInput struct {
|
||||
name string
|
||||
mode string
|
||||
}
|
||||
|
||||
type fakeDbusConn struct {
|
||||
inputs []dbusConnActionInput
|
||||
result string
|
||||
|
||||
jobID int
|
||||
actionErr error
|
||||
}
|
||||
|
||||
func (c *fakeDbusConn) StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
|
||||
c.inputs = append(c.inputs, dbusConnActionInput{name: name, mode: mode})
|
||||
ch <- c.result
|
||||
|
||||
return c.jobID, c.actionErr
|
||||
}
|
||||
|
||||
func (c *fakeDbusConn) StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
|
||||
c.inputs = append(c.inputs, dbusConnActionInput{name: name, mode: mode})
|
||||
ch <- c.result
|
||||
|
||||
return c.jobID, c.actionErr
|
||||
}
|
||||
|
||||
func (c *fakeDbusConn) RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
|
||||
c.inputs = append(c.inputs, dbusConnActionInput{name: name, mode: mode})
|
||||
ch <- c.result
|
||||
|
||||
return c.jobID, c.actionErr
|
||||
}
|
||||
|
||||
func (c *fakeDbusConn) ReloadContext(ctx context.Context) error {
|
||||
return c.actionErr
|
||||
}
|
27
debugd/internal/debugd/deploy/systemdaction_string.go
Normal file
27
debugd/internal/debugd/deploy/systemdaction_string.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
// Code generated by "stringer -type=SystemdAction"; DO NOT EDIT.
|
||||
|
||||
package deploy
|
||||
|
||||
import "strconv"
|
||||
|
||||
func _() {
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
var x [1]struct{}
|
||||
_ = x[Unknown-0]
|
||||
_ = x[Start-1]
|
||||
_ = x[Stop-2]
|
||||
_ = x[Restart-3]
|
||||
_ = x[Reload-4]
|
||||
}
|
||||
|
||||
const _SystemdAction_name = "UnknownStartStopRestartReload"
|
||||
|
||||
var _SystemdAction_index = [...]uint8{0, 7, 12, 16, 23, 29}
|
||||
|
||||
func (i SystemdAction) String() string {
|
||||
if i >= SystemdAction(len(_SystemdAction_index)-1) {
|
||||
return "SystemdAction(" + strconv.FormatInt(int64(i), 10) + ")"
|
||||
}
|
||||
return _SystemdAction_name[_SystemdAction_index[i]:_SystemdAction_index[i+1]]
|
||||
}
|
40
debugd/internal/debugd/deploy/wrappers.go
Normal file
40
debugd/internal/debugd/deploy/wrappers.go
Normal file
|
@ -0,0 +1,40 @@
|
|||
package deploy
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/coreos/go-systemd/v22/dbus"
|
||||
)
|
||||
|
||||
// wraps go-systemd dbus.
|
||||
type dbusWrapper struct{}
|
||||
|
||||
func (d *dbusWrapper) NewSystemdConnectionContext(ctx context.Context) (dbusConn, error) {
|
||||
conn, err := dbus.NewSystemdConnectionContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dbusConnWrapper{
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type dbusConnWrapper struct {
|
||||
conn *dbus.Conn
|
||||
}
|
||||
|
||||
func (c *dbusConnWrapper) StartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
|
||||
return c.conn.StartUnitContext(ctx, name, mode, ch)
|
||||
}
|
||||
|
||||
func (c *dbusConnWrapper) StopUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
|
||||
return c.conn.StopUnitContext(ctx, name, mode, ch)
|
||||
}
|
||||
|
||||
func (c *dbusConnWrapper) RestartUnitContext(ctx context.Context, name string, mode string, ch chan<- string) (int, error) {
|
||||
return c.conn.RestartUnitContext(ctx, name, mode, ch)
|
||||
}
|
||||
|
||||
func (c *dbusConnWrapper) ReloadContext(ctx context.Context) error {
|
||||
return c.conn.ReloadContext(ctx)
|
||||
}
|
130
debugd/internal/debugd/metadata/cloudprovider/cloudprovider.go
Normal file
130
debugd/internal/debugd/metadata/cloudprovider/cloudprovider.go
Normal file
|
@ -0,0 +1,130 @@
|
|||
package cloudprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
azurecloud "github.com/edgelesssys/constellation/bootstrapper/cloudprovider/azure"
|
||||
gcpcloud "github.com/edgelesssys/constellation/bootstrapper/cloudprovider/gcp"
|
||||
qemucloud "github.com/edgelesssys/constellation/bootstrapper/cloudprovider/qemu"
|
||||
"github.com/edgelesssys/constellation/bootstrapper/role"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/metadata"
|
||||
"github.com/edgelesssys/constellation/internal/deploy/ssh"
|
||||
)
|
||||
|
||||
type providerMetadata interface {
|
||||
// List retrieves all instances belonging to the current constellation.
|
||||
List(ctx context.Context) ([]metadata.InstanceMetadata, error)
|
||||
// Self retrieves the current instance.
|
||||
Self(ctx context.Context) (metadata.InstanceMetadata, error)
|
||||
// GetLoadBalancerEndpoint returns the endpoint of the load balancer.
|
||||
GetLoadBalancerEndpoint(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
// Fetcher checks the metadata service to search for instances that were set up for debugging and cloud provider specific SSH keys.
|
||||
type Fetcher struct {
|
||||
metaAPI providerMetadata
|
||||
}
|
||||
|
||||
// NewGCP creates a new GCP fetcher.
|
||||
func NewGCP(ctx context.Context) (*Fetcher, error) {
|
||||
gcpClient, err := gcpcloud.NewClient(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
metaAPI := gcpcloud.New(gcpClient)
|
||||
|
||||
return &Fetcher{
|
||||
metaAPI: metaAPI,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewAzure creates a new Azure fetcher.
|
||||
func NewAzure(ctx context.Context) (*Fetcher, error) {
|
||||
metaAPI, err := azurecloud.NewMetadata(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Fetcher{
|
||||
metaAPI: metaAPI,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewQEMU() *Fetcher {
|
||||
return &Fetcher{
|
||||
metaAPI: &qemucloud.Metadata{},
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Fetcher) Role(ctx context.Context) (role.Role, error) {
|
||||
self, err := f.metaAPI.Self(ctx)
|
||||
if err != nil {
|
||||
return role.Unknown, fmt.Errorf("retrieving role from cloud provider metadata: %w", err)
|
||||
}
|
||||
|
||||
return self.Role, nil
|
||||
}
|
||||
|
||||
// DiscoverDebugdIPs will query the metadata of all instances and return any ips of instances already set up for debugging.
|
||||
func (f *Fetcher) DiscoverDebugdIPs(ctx context.Context) ([]string, error) {
|
||||
self, err := f.metaAPI.Self(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retrieving own instance: %w", err)
|
||||
}
|
||||
instances, err := f.metaAPI.List(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retrieving instances: %w", err)
|
||||
}
|
||||
// filter own instance from instance list
|
||||
for i, instance := range instances {
|
||||
if instance.ProviderID == self.ProviderID {
|
||||
instances = append(instances[:i], instances[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
var ips []string
|
||||
for _, instance := range instances {
|
||||
if instance.VPCIP != "" {
|
||||
ips = append(ips, instance.VPCIP)
|
||||
}
|
||||
}
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (f *Fetcher) DiscoverLoadbalancerIP(ctx context.Context) (string, error) {
|
||||
lbEndpoint, err := f.metaAPI.GetLoadBalancerEndpoint(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("retrieving load balancer endpoint: %w", err)
|
||||
}
|
||||
|
||||
// The port of the endpoint is not the port we need. We need to strip it off.
|
||||
//
|
||||
// TODO: Tag the specific load balancer we are looking for with a distinct tag.
|
||||
// Change the GetLoadBalancerEndpoint method to return the endpoint of a load
|
||||
// balancer with a given tag.
|
||||
lbIP, _, err := net.SplitHostPort(lbEndpoint)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parsing load balancer endpoint: %w", err)
|
||||
}
|
||||
|
||||
return lbIP, nil
|
||||
}
|
||||
|
||||
// FetchSSHKeys will query the metadata of the current instance and deploys any SSH keys found.
|
||||
func (f *Fetcher) FetchSSHKeys(ctx context.Context) ([]ssh.UserKey, error) {
|
||||
self, err := f.metaAPI.Self(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retrieving ssh keys from cloud provider metadata: %w", err)
|
||||
}
|
||||
|
||||
keys := []ssh.UserKey{}
|
||||
for username, userKeys := range self.SSHKeys {
|
||||
for _, keyValue := range userKeys {
|
||||
keys = append(keys, ssh.UserKey{Username: username, PublicKey: keyValue})
|
||||
}
|
||||
}
|
||||
|
||||
return keys, nil
|
||||
}
|
|
@ -0,0 +1,244 @@
|
|||
package cloudprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/bootstrapper/role"
|
||||
"github.com/edgelesssys/constellation/internal/cloud/metadata"
|
||||
"github.com/edgelesssys/constellation/internal/deploy/ssh"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m,
|
||||
// https://github.com/census-instrumentation/opencensus-go/issues/1262
|
||||
goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
|
||||
)
|
||||
}
|
||||
|
||||
func TestRole(t *testing.T) {
|
||||
instance1 := metadata.InstanceMetadata{Role: role.ControlPlane}
|
||||
instance2 := metadata.InstanceMetadata{Role: role.Worker}
|
||||
|
||||
testCases := map[string]struct {
|
||||
meta *stubMetadata
|
||||
wantErr bool
|
||||
wantRole role.Role
|
||||
}{
|
||||
"control plane": {
|
||||
meta: &stubMetadata{selfRes: instance1},
|
||||
wantRole: role.ControlPlane,
|
||||
},
|
||||
"worker": {
|
||||
meta: &stubMetadata{selfRes: instance2},
|
||||
wantRole: role.Worker,
|
||||
},
|
||||
"self fails": {
|
||||
meta: &stubMetadata{selfErr: errors.New("some err")},
|
||||
wantErr: true,
|
||||
wantRole: role.Unknown,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
fetcher := Fetcher{tc.meta}
|
||||
|
||||
role, err := fetcher.Role(context.Background())
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantRole, role)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoverDebugIPs(t *testing.T) {
|
||||
err := errors.New("some err")
|
||||
|
||||
testCases := map[string]struct {
|
||||
meta stubMetadata
|
||||
wantIPs []string
|
||||
wantErr bool
|
||||
}{
|
||||
"disovery works": {
|
||||
meta: stubMetadata{
|
||||
listRes: []metadata.InstanceMetadata{
|
||||
{
|
||||
VPCIP: "192.0.2.0",
|
||||
},
|
||||
{
|
||||
VPCIP: "192.0.2.1",
|
||||
},
|
||||
{
|
||||
VPCIP: "192.0.2.2",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantIPs: []string{
|
||||
"192.0.2.1", "192.0.2.2",
|
||||
},
|
||||
},
|
||||
"retrieve fails": {
|
||||
meta: stubMetadata{
|
||||
listErr: err,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fetcher := Fetcher{
|
||||
metaAPI: &tc.meta,
|
||||
}
|
||||
ips, err := fetcher.DiscoverDebugdIPs(context.Background())
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.ElementsMatch(tc.wantIPs, ips)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiscoverLoadbalancerIP(t *testing.T) {
|
||||
ip := "192.0.2.1"
|
||||
endpoint := ip + ":1234"
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
metaAPI providerMetadata
|
||||
wantIP string
|
||||
wantErr bool
|
||||
}{
|
||||
"discovery works": {
|
||||
metaAPI: &stubMetadata{getLBEndpointRes: endpoint},
|
||||
wantIP: ip,
|
||||
},
|
||||
"get endpoint fails": {
|
||||
metaAPI: &stubMetadata{getLBEndpointErr: someErr},
|
||||
wantErr: true,
|
||||
},
|
||||
"invalid endpoint": {
|
||||
metaAPI: &stubMetadata{getLBEndpointRes: "invalid"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
fetcher := &Fetcher{
|
||||
metaAPI: tc.metaAPI,
|
||||
}
|
||||
|
||||
ip, err := fetcher.DiscoverLoadbalancerIP(context.Background())
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
} else {
|
||||
assert.NoError(err)
|
||||
assert.Equal(tc.wantIP, ip)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFetchSSHKeys(t *testing.T) {
|
||||
err := errors.New("some err")
|
||||
|
||||
testCases := map[string]struct {
|
||||
meta stubMetadata
|
||||
wantKeys []ssh.UserKey
|
||||
wantErr bool
|
||||
}{
|
||||
"fetch works": {
|
||||
meta: stubMetadata{
|
||||
selfRes: metadata.InstanceMetadata{
|
||||
Name: "name",
|
||||
ProviderID: "provider-id",
|
||||
SSHKeys: map[string][]string{"bob": {"ssh-rsa bobskey"}},
|
||||
},
|
||||
},
|
||||
wantKeys: []ssh.UserKey{
|
||||
{
|
||||
Username: "bob",
|
||||
PublicKey: "ssh-rsa bobskey",
|
||||
},
|
||||
},
|
||||
},
|
||||
"retrieve fails": {
|
||||
meta: stubMetadata{
|
||||
selfErr: err,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fetcher := Fetcher{
|
||||
metaAPI: &tc.meta,
|
||||
}
|
||||
keys, err := fetcher.FetchSSHKeys(context.Background())
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.ElementsMatch(tc.wantKeys, keys)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type stubMetadata struct {
|
||||
listRes []metadata.InstanceMetadata
|
||||
listErr error
|
||||
selfRes metadata.InstanceMetadata
|
||||
selfErr error
|
||||
getInstanceRes metadata.InstanceMetadata
|
||||
getInstanceErr error
|
||||
getLBEndpointRes string
|
||||
getLBEndpointErr error
|
||||
supportedRes bool
|
||||
}
|
||||
|
||||
func (m *stubMetadata) List(ctx context.Context) ([]metadata.InstanceMetadata, error) {
|
||||
return m.listRes, m.listErr
|
||||
}
|
||||
|
||||
func (m *stubMetadata) Self(ctx context.Context) (metadata.InstanceMetadata, error) {
|
||||
return m.selfRes, m.selfErr
|
||||
}
|
||||
|
||||
func (m *stubMetadata) GetInstance(ctx context.Context, providerID string) (metadata.InstanceMetadata, error) {
|
||||
return m.getInstanceRes, m.getInstanceErr
|
||||
}
|
||||
|
||||
func (m *stubMetadata) GetLoadBalancerEndpoint(ctx context.Context) (string, error) {
|
||||
return m.getLBEndpointRes, m.getLBEndpointErr
|
||||
}
|
||||
|
||||
func (m *stubMetadata) Supported() bool {
|
||||
return m.supportedRes
|
||||
}
|
31
debugd/internal/debugd/metadata/fallback/fallback.go
Normal file
31
debugd/internal/debugd/metadata/fallback/fallback.go
Normal file
|
@ -0,0 +1,31 @@
|
|||
package fallback
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/edgelesssys/constellation/bootstrapper/role"
|
||||
"github.com/edgelesssys/constellation/internal/deploy/ssh"
|
||||
)
|
||||
|
||||
// Fetcher implements metadata.Fetcher interface but does not actually fetch cloud provider metadata.
|
||||
type Fetcher struct{}
|
||||
|
||||
func (f Fetcher) Role(_ context.Context) (role.Role, error) {
|
||||
// Fallback fetcher does not try to fetch role
|
||||
return role.Unknown, nil
|
||||
}
|
||||
|
||||
func (f Fetcher) DiscoverDebugdIPs(ctx context.Context) ([]string, error) {
|
||||
// Fallback fetcher does not try to discover debugd IPs
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f Fetcher) DiscoverLoadbalancerIP(ctx context.Context) (string, error) {
|
||||
// Fallback fetcher does not try to discover loadbalancer IP
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (f Fetcher) FetchSSHKeys(ctx context.Context) ([]ssh.UserKey, error) {
|
||||
// Fallback fetcher does not try to fetch ssh keys
|
||||
return nil, nil
|
||||
}
|
33
debugd/internal/debugd/metadata/fallback/fallback_test.go
Normal file
33
debugd/internal/debugd/metadata/fallback/fallback_test.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package fallback
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestDiscoverDebugdIPs(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
fetcher := Fetcher{}
|
||||
ips, err := fetcher.DiscoverDebugdIPs(context.Background())
|
||||
|
||||
assert.NoError(err)
|
||||
assert.Empty(ips)
|
||||
}
|
||||
|
||||
func TestFetchSSHKeys(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
fetcher := Fetcher{}
|
||||
keys, err := fetcher.FetchSSHKeys(context.Background())
|
||||
|
||||
assert.NoError(err)
|
||||
assert.Empty(keys)
|
||||
}
|
143
debugd/internal/debugd/metadata/scheduler.go
Normal file
143
debugd/internal/debugd/metadata/scheduler.go
Normal file
|
@ -0,0 +1,143 @@
|
|||
package metadata
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/bootstrapper/role"
|
||||
"github.com/edgelesssys/constellation/debugd/internal/debugd"
|
||||
"github.com/edgelesssys/constellation/internal/deploy/ssh"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Fetcher retrieves other debugd IPs and SSH keys from cloud provider metadata.
|
||||
type Fetcher interface {
|
||||
Role(ctx context.Context) (role.Role, error)
|
||||
DiscoverDebugdIPs(ctx context.Context) ([]string, error)
|
||||
FetchSSHKeys(ctx context.Context) ([]ssh.UserKey, error)
|
||||
DiscoverLoadbalancerIP(ctx context.Context) (string, error)
|
||||
}
|
||||
|
||||
// Scheduler schedules fetching of metadata using timers.
|
||||
type Scheduler struct {
|
||||
log *logger.Logger
|
||||
fetcher Fetcher
|
||||
ssh sshDeployer
|
||||
downloader downloader
|
||||
}
|
||||
|
||||
// NewScheduler returns a new scheduler.
|
||||
func NewScheduler(log *logger.Logger, fetcher Fetcher, ssh sshDeployer, downloader downloader) *Scheduler {
|
||||
return &Scheduler{
|
||||
log: log,
|
||||
fetcher: fetcher,
|
||||
ssh: ssh,
|
||||
downloader: downloader,
|
||||
}
|
||||
}
|
||||
|
||||
// Start will start the loops for discovering debugd endpoints and ssh keys.
|
||||
func (s *Scheduler) Start(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
wg.Add(2)
|
||||
go s.discoveryLoop(ctx, wg)
|
||||
go s.sshLoop(ctx, wg)
|
||||
}
|
||||
|
||||
// discoveryLoop discovers new debugd endpoints from cloud-provider metadata periodically.
|
||||
func (s *Scheduler) discoveryLoop(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
// execute debugd discovery once at the start to skip wait for first tick
|
||||
ips, err := s.fetcher.DiscoverDebugdIPs(ctx)
|
||||
if err != nil {
|
||||
s.log.With(zap.Error(err)).Errorf("Discovering debugd IPs failed")
|
||||
} else {
|
||||
if s.downloadDeployment(ctx, ips) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(debugd.DiscoverDebugdInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
var err error
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ips, err = s.fetcher.DiscoverDebugdIPs(ctx)
|
||||
if err != nil {
|
||||
s.log.With(zap.Error(err)).Errorf("Discovering debugd IPs failed")
|
||||
continue
|
||||
}
|
||||
s.log.With(zap.Strings("ips", ips)).Infof("Discovered instances")
|
||||
if s.downloadDeployment(ctx, ips) {
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sshLoop discovers new ssh keys from cloud provider metadata periodically.
|
||||
func (s *Scheduler) sshLoop(ctx context.Context, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
ticker := time.NewTicker(debugd.SSHCheckInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
keys, err := s.fetcher.FetchSSHKeys(ctx)
|
||||
if err != nil {
|
||||
s.log.With(zap.Error(err)).Errorf("Fetching SSH keys failed")
|
||||
} else {
|
||||
s.deploySSHKeys(ctx, keys)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ticker.C:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// downloadDeployment tries to download deployment from a list of ips and logs errors encountered.
|
||||
func (s *Scheduler) downloadDeployment(ctx context.Context, ips []string) (success bool) {
|
||||
for _, ip := range ips {
|
||||
keys, err := s.downloader.DownloadDeployment(ctx, ip)
|
||||
if err == nil {
|
||||
s.deploySSHKeys(ctx, keys)
|
||||
return true
|
||||
}
|
||||
if errors.Is(err, fs.ErrExist) {
|
||||
// bootstrapper was already uploaded
|
||||
s.log.Infof("Bootstrapper was already uploaded.")
|
||||
return true
|
||||
}
|
||||
s.log.With(zap.Error(err), zap.String("peer", ip)).Errorf("Downloading deployment from peer failed")
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// deploySSHKeys tries to deploy a list of SSH keys and logs errors encountered.
|
||||
func (s *Scheduler) deploySSHKeys(ctx context.Context, keys []ssh.UserKey) {
|
||||
for _, key := range keys {
|
||||
err := s.ssh.DeployAuthorizedKey(ctx, key)
|
||||
if err != nil {
|
||||
s.log.With(zap.Error(err), zap.Any("key", key)).Errorf("Deploying SSH key failed")
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type downloader interface {
|
||||
DownloadDeployment(ctx context.Context, ip string) ([]ssh.UserKey, error)
|
||||
}
|
||||
|
||||
type sshDeployer interface {
|
||||
DeployAuthorizedKey(ctx context.Context, sshKey ssh.UserKey) error
|
||||
}
|
134
debugd/internal/debugd/metadata/scheduler_test.go
Normal file
134
debugd/internal/debugd/metadata/scheduler_test.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
package metadata
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/bootstrapper/role"
|
||||
"github.com/edgelesssys/constellation/internal/deploy/ssh"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestSchedulerStart(t *testing.T) {
|
||||
someErr := errors.New("failed")
|
||||
|
||||
testCases := map[string]struct {
|
||||
fetcher stubFetcher
|
||||
ssh stubSSHDeployer
|
||||
downloader stubDownloader
|
||||
timeout time.Duration
|
||||
wantSSHKeys []ssh.UserKey
|
||||
wantDebugdDownloads []string
|
||||
}{
|
||||
"scheduler works and calls fetcher functions at least once": {},
|
||||
"ssh keys are fetched": {
|
||||
fetcher: stubFetcher{
|
||||
keys: []ssh.UserKey{{Username: "test", PublicKey: "testkey"}},
|
||||
},
|
||||
wantSSHKeys: []ssh.UserKey{{Username: "test", PublicKey: "testkey"}},
|
||||
},
|
||||
"download for discovered debugd ips is started": {
|
||||
fetcher: stubFetcher{
|
||||
ips: []string{"192.0.2.1", "192.0.2.2"},
|
||||
},
|
||||
downloader: stubDownloader{downloadErr: someErr},
|
||||
wantDebugdDownloads: []string{"192.0.2.1", "192.0.2.2"},
|
||||
},
|
||||
"if download is successful, second download is not attempted": {
|
||||
fetcher: stubFetcher{
|
||||
ips: []string{"192.0.2.1", "192.0.2.2"},
|
||||
},
|
||||
wantDebugdDownloads: []string{"192.0.2.1"},
|
||||
},
|
||||
"endpoint discovery can fail": {
|
||||
fetcher: stubFetcher{discoverErr: someErr},
|
||||
},
|
||||
"ssh key fetch can fail": {
|
||||
fetcher: stubFetcher{fetchSSHKeysErr: someErr},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), tc.timeout)
|
||||
defer cancel()
|
||||
scheduler := Scheduler{
|
||||
log: logger.NewTest(t),
|
||||
fetcher: &tc.fetcher,
|
||||
ssh: &tc.ssh,
|
||||
downloader: &tc.downloader,
|
||||
}
|
||||
wg.Add(1)
|
||||
go scheduler.Start(ctx, wg)
|
||||
|
||||
wg.Wait()
|
||||
assert.Equal(tc.wantSSHKeys, tc.ssh.sshKeys)
|
||||
assert.Equal(tc.wantDebugdDownloads, tc.downloader.ips)
|
||||
assert.Greater(tc.fetcher.discoverCalls, 0)
|
||||
assert.Greater(tc.fetcher.fetchSSHKeysCalls, 0)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type stubFetcher struct {
|
||||
discoverCalls int
|
||||
fetchSSHKeysCalls int
|
||||
|
||||
ips []string
|
||||
keys []ssh.UserKey
|
||||
discoverErr error
|
||||
fetchSSHKeysErr error
|
||||
}
|
||||
|
||||
func (s *stubFetcher) Role(_ context.Context) (role.Role, error) {
|
||||
return role.Unknown, nil
|
||||
}
|
||||
|
||||
func (s *stubFetcher) DiscoverDebugdIPs(ctx context.Context) ([]string, error) {
|
||||
s.discoverCalls++
|
||||
return s.ips, s.discoverErr
|
||||
}
|
||||
|
||||
func (s *stubFetcher) FetchSSHKeys(ctx context.Context) ([]ssh.UserKey, error) {
|
||||
s.fetchSSHKeysCalls++
|
||||
return s.keys, s.fetchSSHKeysErr
|
||||
}
|
||||
|
||||
func (s *stubFetcher) DiscoverLoadbalancerIP(ctx context.Context) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
type stubSSHDeployer struct {
|
||||
sshKeys []ssh.UserKey
|
||||
|
||||
deployErr error
|
||||
}
|
||||
|
||||
func (s *stubSSHDeployer) DeployAuthorizedKey(ctx context.Context, sshKey ssh.UserKey) error {
|
||||
s.sshKeys = append(s.sshKeys, sshKey)
|
||||
|
||||
return s.deployErr
|
||||
}
|
||||
|
||||
type stubDownloader struct {
|
||||
ips []string
|
||||
downloadErr error
|
||||
keys []ssh.UserKey
|
||||
}
|
||||
|
||||
func (s *stubDownloader) DownloadDeployment(ctx context.Context, ip string) ([]ssh.UserKey, error) {
|
||||
s.ips = append(s.ips, ip)
|
||||
return s.keys, s.downloadErr
|
||||
}
|
162
debugd/internal/debugd/server/server.go
Normal file
162
debugd/internal/debugd/server/server.go
Normal file
|
@ -0,0 +1,162 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
|
||||
"github.com/edgelesssys/constellation/debugd/internal/debugd"
|
||||
"github.com/edgelesssys/constellation/debugd/internal/debugd/deploy"
|
||||
pb "github.com/edgelesssys/constellation/debugd/service"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/deploy/ssh"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
)
|
||||
|
||||
type debugdServer struct {
|
||||
log *logger.Logger
|
||||
ssh sshDeployer
|
||||
serviceManager serviceManager
|
||||
streamer streamer
|
||||
pb.UnimplementedDebugdServer
|
||||
}
|
||||
|
||||
// New creates a new debugdServer according to the gRPC spec.
|
||||
func New(log *logger.Logger, ssh sshDeployer, serviceManager serviceManager, streamer streamer) pb.DebugdServer {
|
||||
return &debugdServer{
|
||||
log: log,
|
||||
ssh: ssh,
|
||||
serviceManager: serviceManager,
|
||||
streamer: streamer,
|
||||
}
|
||||
}
|
||||
|
||||
// UploadAuthorizedKeys receives a list of authorized keys and forwards them to a channel.
|
||||
func (s *debugdServer) UploadAuthorizedKeys(ctx context.Context, in *pb.UploadAuthorizedKeysRequest) (*pb.UploadAuthorizedKeysResponse, error) {
|
||||
s.log.Infof("Uploading authorized keys")
|
||||
for _, key := range in.Keys {
|
||||
if err := s.ssh.DeployAuthorizedKey(ctx, ssh.UserKey{Username: key.Username, PublicKey: key.KeyValue}); err != nil {
|
||||
s.log.With(zap.Error(err)).Errorf("Uploading authorized keys failed")
|
||||
return &pb.UploadAuthorizedKeysResponse{
|
||||
Status: pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_FAILURE,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return &pb.UploadAuthorizedKeysResponse{
|
||||
Status: pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_SUCCESS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UploadBootstrapper receives a bootstrapper binary in a stream of chunks and writes to a file.
|
||||
func (s *debugdServer) UploadBootstrapper(stream pb.Debugd_UploadBootstrapperServer) error {
|
||||
startAction := deploy.ServiceManagerRequest{
|
||||
Unit: debugd.BootstrapperSystemdUnitName,
|
||||
Action: deploy.Start,
|
||||
}
|
||||
var responseStatus pb.UploadBootstrapperStatus
|
||||
defer func() {
|
||||
if err := s.serviceManager.SystemdAction(stream.Context(), startAction); err != nil {
|
||||
s.log.With(zap.Error(err)).Errorf("Starting uploaded bootstrapper failed")
|
||||
if responseStatus == pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS {
|
||||
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_START_FAILED
|
||||
}
|
||||
}
|
||||
stream.SendAndClose(&pb.UploadBootstrapperResponse{
|
||||
Status: responseStatus,
|
||||
})
|
||||
}()
|
||||
s.log.Infof("Starting bootstrapper upload")
|
||||
if err := s.streamer.WriteStream(debugd.BootstrapperDeployFilename, stream, true); err != nil {
|
||||
if errors.Is(err, fs.ErrExist) {
|
||||
// bootstrapper was already uploaded
|
||||
s.log.Warnf("Bootstrapper already uploaded")
|
||||
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_FILE_EXISTS
|
||||
return nil
|
||||
}
|
||||
s.log.With(zap.Error(err)).Errorf("Uploading bootstrapper failed")
|
||||
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_UPLOAD_FAILED
|
||||
return fmt.Errorf("uploading bootstrapper: %w", err)
|
||||
}
|
||||
|
||||
s.log.Infof("Successfully uploaded bootstrapper")
|
||||
responseStatus = pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS
|
||||
return nil
|
||||
}
|
||||
|
||||
// DownloadBootstrapper streams the local bootstrapper binary to other instances.
|
||||
func (s *debugdServer) DownloadBootstrapper(request *pb.DownloadBootstrapperRequest, stream pb.Debugd_DownloadBootstrapperServer) error {
|
||||
s.log.Infof("Sending bootstrapper to other instance")
|
||||
return s.streamer.ReadStream(debugd.BootstrapperDeployFilename, stream, debugd.Chunksize, true)
|
||||
}
|
||||
|
||||
// DownloadAuthorizedKeys streams the local authorized keys to other instances.
|
||||
func (s *debugdServer) DownloadAuthorizedKeys(_ context.Context, req *pb.DownloadAuthorizedKeysRequest) (*pb.DownloadAuthorizedKeysResponse, error) {
|
||||
s.log.Infof("Sending authorized keys to other instance")
|
||||
|
||||
var authKeys []*pb.AuthorizedKey
|
||||
for _, key := range s.ssh.GetAuthorizedKeys() {
|
||||
authKeys = append(authKeys, &pb.AuthorizedKey{
|
||||
Username: key.Username,
|
||||
KeyValue: key.PublicKey,
|
||||
})
|
||||
}
|
||||
|
||||
return &pb.DownloadAuthorizedKeysResponse{Keys: authKeys}, nil
|
||||
}
|
||||
|
||||
// UploadSystemServiceUnits receives systemd service units, writes them to a service file and schedules a daemon-reload.
|
||||
func (s *debugdServer) UploadSystemServiceUnits(ctx context.Context, in *pb.UploadSystemdServiceUnitsRequest) (*pb.UploadSystemdServiceUnitsResponse, error) {
|
||||
s.log.Infof("Uploading systemd service units")
|
||||
for _, unit := range in.Units {
|
||||
if err := s.serviceManager.WriteSystemdUnitFile(ctx, deploy.SystemdUnit{Name: unit.Name, Contents: unit.Contents}); err != nil {
|
||||
return &pb.UploadSystemdServiceUnitsResponse{Status: pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_FAILURE}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return &pb.UploadSystemdServiceUnitsResponse{Status: pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_SUCCESS}, nil
|
||||
}
|
||||
|
||||
// Start will start the gRPC server and block.
|
||||
func Start(log *logger.Logger, wg *sync.WaitGroup, serv pb.DebugdServer) {
|
||||
defer wg.Done()
|
||||
|
||||
grpcLog := log.Named("gRPC")
|
||||
grpcLog.WithIncreasedLevel(zap.WarnLevel).ReplaceGRPCLogger()
|
||||
|
||||
grpcServer := grpc.NewServer(
|
||||
grpcLog.GetServerStreamInterceptor(),
|
||||
grpcLog.GetServerUnaryInterceptor(),
|
||||
grpc.KeepaliveParams(keepalive.ServerParameters{Time: 15 * time.Second}),
|
||||
)
|
||||
pb.RegisterDebugdServer(grpcServer, serv)
|
||||
lis, err := net.Listen("tcp", net.JoinHostPort("0.0.0.0", strconv.Itoa(constants.DebugdPort)))
|
||||
if err != nil {
|
||||
log.With(zap.Error(err)).Fatalf("Listening failed")
|
||||
}
|
||||
log.Infof("gRPC server is waiting for connections")
|
||||
grpcServer.Serve(lis)
|
||||
}
|
||||
|
||||
type sshDeployer interface {
|
||||
DeployAuthorizedKey(ctx context.Context, sshKey ssh.UserKey) error
|
||||
GetAuthorizedKeys() []ssh.UserKey
|
||||
}
|
||||
|
||||
type serviceManager interface {
|
||||
SystemdAction(ctx context.Context, request deploy.ServiceManagerRequest) error
|
||||
WriteSystemdUnitFile(ctx context.Context, unit deploy.SystemdUnit) error
|
||||
}
|
||||
|
||||
type streamer interface {
|
||||
WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error
|
||||
ReadStream(filename string, stream bootstrapper.WriteChunkStream, chunksize uint, showProgress bool) error
|
||||
}
|
490
debugd/internal/debugd/server/server_test.go
Normal file
490
debugd/internal/debugd/server/server_test.go
Normal file
|
@ -0,0 +1,490 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/debugd/internal/bootstrapper"
|
||||
"github.com/edgelesssys/constellation/debugd/internal/debugd/deploy"
|
||||
pb "github.com/edgelesssys/constellation/debugd/service"
|
||||
"github.com/edgelesssys/constellation/internal/constants"
|
||||
"github.com/edgelesssys/constellation/internal/deploy/ssh"
|
||||
"github.com/edgelesssys/constellation/internal/grpc/testdialer"
|
||||
"github.com/edgelesssys/constellation/internal/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestUploadAuthorizedKeys(t *testing.T) {
|
||||
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
|
||||
|
||||
testCases := map[string]struct {
|
||||
ssh stubSSHDeployer
|
||||
serviceManager stubServiceManager
|
||||
request *pb.UploadAuthorizedKeysRequest
|
||||
wantErr bool
|
||||
wantResponseStatus pb.UploadAuthorizedKeysStatus
|
||||
wantKeys []ssh.UserKey
|
||||
}{
|
||||
"upload authorized keys works": {
|
||||
request: &pb.UploadAuthorizedKeysRequest{
|
||||
Keys: []*pb.AuthorizedKey{
|
||||
{
|
||||
Username: "testuser",
|
||||
KeyValue: "teskey",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantResponseStatus: pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_SUCCESS,
|
||||
wantKeys: []ssh.UserKey{
|
||||
{
|
||||
Username: "testuser",
|
||||
PublicKey: "teskey",
|
||||
},
|
||||
},
|
||||
},
|
||||
"deploy fails": {
|
||||
request: &pb.UploadAuthorizedKeysRequest{
|
||||
Keys: []*pb.AuthorizedKey{
|
||||
{
|
||||
Username: "testuser",
|
||||
KeyValue: "teskey",
|
||||
},
|
||||
},
|
||||
},
|
||||
ssh: stubSSHDeployer{deployErr: errors.New("ssh key deployment error")},
|
||||
wantResponseStatus: pb.UploadAuthorizedKeysStatus_UPLOAD_AUTHORIZED_KEYS_FAILURE,
|
||||
wantKeys: []ssh.UserKey{
|
||||
{
|
||||
Username: "testuser",
|
||||
PublicKey: "teskey",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
serv := debugdServer{
|
||||
log: logger.NewTest(t),
|
||||
ssh: &tc.ssh,
|
||||
serviceManager: &tc.serviceManager,
|
||||
streamer: &fakeStreamer{},
|
||||
}
|
||||
|
||||
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
client := pb.NewDebugdClient(conn)
|
||||
resp, err := client.UploadAuthorizedKeys(context.Background(), tc.request)
|
||||
|
||||
grpcServ.GracefulStop()
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantResponseStatus, resp.Status)
|
||||
assert.ElementsMatch(tc.ssh.sshKeys, tc.wantKeys)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUploadBootstrapper(t *testing.T) {
|
||||
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
|
||||
|
||||
testCases := map[string]struct {
|
||||
ssh stubSSHDeployer
|
||||
serviceManager stubServiceManager
|
||||
streamer fakeStreamer
|
||||
uploadChunks [][]byte
|
||||
wantErr bool
|
||||
wantResponseStatus pb.UploadBootstrapperStatus
|
||||
wantFile bool
|
||||
wantChunks [][]byte
|
||||
}{
|
||||
"upload works": {
|
||||
uploadChunks: [][]byte{
|
||||
[]byte("test"),
|
||||
},
|
||||
wantFile: true,
|
||||
wantChunks: [][]byte{
|
||||
[]byte("test"),
|
||||
},
|
||||
wantResponseStatus: pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_SUCCESS,
|
||||
},
|
||||
"recv fails": {
|
||||
streamer: fakeStreamer{
|
||||
writeStreamErr: errors.New("recv error"),
|
||||
},
|
||||
wantResponseStatus: pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_UPLOAD_FAILED,
|
||||
wantErr: true,
|
||||
},
|
||||
"starting bootstrapper fails": {
|
||||
uploadChunks: [][]byte{
|
||||
[]byte("test"),
|
||||
},
|
||||
serviceManager: stubServiceManager{
|
||||
systemdActionErr: errors.New("starting bootstrapper error"),
|
||||
},
|
||||
wantFile: true,
|
||||
wantChunks: [][]byte{
|
||||
[]byte("test"),
|
||||
},
|
||||
wantResponseStatus: pb.UploadBootstrapperStatus_UPLOAD_BOOTSTRAPPER_START_FAILED,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
serv := debugdServer{
|
||||
log: logger.NewTest(t),
|
||||
ssh: &tc.ssh,
|
||||
serviceManager: &tc.serviceManager,
|
||||
streamer: &tc.streamer,
|
||||
}
|
||||
|
||||
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
client := pb.NewDebugdClient(conn)
|
||||
stream, err := client.UploadBootstrapper(context.Background())
|
||||
require.NoError(err)
|
||||
require.NoError(fakeWrite(stream, tc.uploadChunks))
|
||||
resp, err := stream.CloseAndRecv()
|
||||
|
||||
grpcServ.GracefulStop()
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantResponseStatus, resp.Status)
|
||||
if tc.wantFile {
|
||||
assert.Equal(tc.wantChunks, tc.streamer.writeStreamChunks)
|
||||
assert.Equal("/opt/bootstrapper", tc.streamer.writeStreamFilename)
|
||||
} else {
|
||||
assert.Empty(tc.streamer.writeStreamChunks)
|
||||
assert.Empty(tc.streamer.writeStreamFilename)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadBootstrapper(t *testing.T) {
|
||||
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
|
||||
|
||||
testCases := map[string]struct {
|
||||
ssh stubSSHDeployer
|
||||
serviceManager stubServiceManager
|
||||
request *pb.DownloadBootstrapperRequest
|
||||
streamer fakeStreamer
|
||||
wantErr bool
|
||||
wantChunks [][]byte
|
||||
}{
|
||||
"download works": {
|
||||
request: &pb.DownloadBootstrapperRequest{},
|
||||
streamer: fakeStreamer{
|
||||
readStreamChunks: [][]byte{
|
||||
[]byte("test"),
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
wantChunks: [][]byte{
|
||||
[]byte("test"),
|
||||
},
|
||||
},
|
||||
"download fails": {
|
||||
request: &pb.DownloadBootstrapperRequest{},
|
||||
streamer: fakeStreamer{
|
||||
readStreamErr: errors.New("read bootstrapper fails"),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
serv := debugdServer{
|
||||
log: logger.NewTest(t),
|
||||
ssh: &tc.ssh,
|
||||
serviceManager: &tc.serviceManager,
|
||||
streamer: &tc.streamer,
|
||||
}
|
||||
|
||||
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
client := pb.NewDebugdClient(conn)
|
||||
stream, err := client.DownloadBootstrapper(context.Background(), tc.request)
|
||||
require.NoError(err)
|
||||
chunks, err := fakeRead(stream)
|
||||
grpcServ.GracefulStop()
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantChunks, chunks)
|
||||
assert.Equal("/opt/bootstrapper", tc.streamer.readStreamFilename)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDownloadAuthorizedKeys(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
|
||||
deployer := &stubSSHDeployer{
|
||||
sshKeys: []ssh.UserKey{
|
||||
{Username: "test1", PublicKey: "foo"},
|
||||
{Username: "test2", PublicKey: "bar"},
|
||||
},
|
||||
}
|
||||
|
||||
serv := debugdServer{
|
||||
log: logger.NewTest(t),
|
||||
ssh: deployer,
|
||||
}
|
||||
|
||||
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
defer grpcServ.GracefulStop()
|
||||
client := pb.NewDebugdClient(conn)
|
||||
|
||||
resp, err := client.DownloadAuthorizedKeys(context.Background(), &pb.DownloadAuthorizedKeysRequest{})
|
||||
|
||||
assert.NoError(err)
|
||||
wantKeys := []*pb.AuthorizedKey{
|
||||
{Username: "test1", KeyValue: "foo"},
|
||||
{Username: "test2", KeyValue: "bar"},
|
||||
}
|
||||
assert.ElementsMatch(wantKeys, resp.Keys)
|
||||
}
|
||||
|
||||
func TestUploadSystemServiceUnits(t *testing.T) {
|
||||
endpoint := "192.0.2.1:" + strconv.Itoa(constants.DebugdPort)
|
||||
|
||||
testCases := map[string]struct {
|
||||
ssh stubSSHDeployer
|
||||
serviceManager stubServiceManager
|
||||
request *pb.UploadSystemdServiceUnitsRequest
|
||||
wantErr bool
|
||||
wantResponseStatus pb.UploadSystemdServiceUnitsStatus
|
||||
wantUnitFiles []deploy.SystemdUnit
|
||||
}{
|
||||
"upload systemd service units": {
|
||||
request: &pb.UploadSystemdServiceUnitsRequest{
|
||||
Units: []*pb.ServiceUnit{
|
||||
{
|
||||
Name: "test.service",
|
||||
Contents: "testcontents",
|
||||
},
|
||||
},
|
||||
},
|
||||
wantResponseStatus: pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_SUCCESS,
|
||||
wantUnitFiles: []deploy.SystemdUnit{
|
||||
{
|
||||
Name: "test.service",
|
||||
Contents: "testcontents",
|
||||
},
|
||||
},
|
||||
},
|
||||
"writing fails": {
|
||||
request: &pb.UploadSystemdServiceUnitsRequest{
|
||||
Units: []*pb.ServiceUnit{
|
||||
{
|
||||
Name: "test.service",
|
||||
Contents: "testcontents",
|
||||
},
|
||||
},
|
||||
},
|
||||
serviceManager: stubServiceManager{
|
||||
writeSystemdUnitFileErr: errors.New("write error"),
|
||||
},
|
||||
wantResponseStatus: pb.UploadSystemdServiceUnitsStatus_UPLOAD_SYSTEMD_SERVICE_UNITS_FAILURE,
|
||||
wantUnitFiles: []deploy.SystemdUnit{
|
||||
{
|
||||
Name: "test.service",
|
||||
Contents: "testcontents",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
serv := debugdServer{
|
||||
log: logger.NewTest(t),
|
||||
ssh: &tc.ssh,
|
||||
serviceManager: &tc.serviceManager,
|
||||
streamer: &fakeStreamer{},
|
||||
}
|
||||
grpcServ, conn, err := setupServerWithConn(endpoint, &serv)
|
||||
require.NoError(err)
|
||||
defer conn.Close()
|
||||
client := pb.NewDebugdClient(conn)
|
||||
resp, err := client.UploadSystemServiceUnits(context.Background(), tc.request)
|
||||
|
||||
grpcServ.GracefulStop()
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
require.NotNil(resp.Status)
|
||||
assert.Equal(tc.wantResponseStatus, resp.Status)
|
||||
assert.ElementsMatch(tc.wantUnitFiles, tc.serviceManager.unitFiles)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type stubSSHDeployer struct {
|
||||
sshKeys []ssh.UserKey
|
||||
|
||||
deployErr error
|
||||
}
|
||||
|
||||
func (s *stubSSHDeployer) DeployAuthorizedKey(ctx context.Context, sshKey ssh.UserKey) error {
|
||||
s.sshKeys = append(s.sshKeys, sshKey)
|
||||
|
||||
return s.deployErr
|
||||
}
|
||||
|
||||
func (s *stubSSHDeployer) GetAuthorizedKeys() []ssh.UserKey {
|
||||
return s.sshKeys
|
||||
}
|
||||
|
||||
type stubServiceManager struct {
|
||||
requests []deploy.ServiceManagerRequest
|
||||
unitFiles []deploy.SystemdUnit
|
||||
systemdActionErr error
|
||||
writeSystemdUnitFileErr error
|
||||
}
|
||||
|
||||
func (s *stubServiceManager) SystemdAction(ctx context.Context, request deploy.ServiceManagerRequest) error {
|
||||
s.requests = append(s.requests, request)
|
||||
return s.systemdActionErr
|
||||
}
|
||||
|
||||
func (s *stubServiceManager) WriteSystemdUnitFile(ctx context.Context, unit deploy.SystemdUnit) error {
|
||||
s.unitFiles = append(s.unitFiles, unit)
|
||||
return s.writeSystemdUnitFileErr
|
||||
}
|
||||
|
||||
type netDialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
func dial(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) {
|
||||
return grpc.DialContext(ctx, target,
|
||||
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return dialer.DialContext(ctx, "tcp", addr)
|
||||
}),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
}
|
||||
|
||||
type fakeStreamer struct {
|
||||
writeStreamChunks [][]byte
|
||||
writeStreamFilename string
|
||||
writeStreamErr error
|
||||
readStreamChunks [][]byte
|
||||
readStreamFilename string
|
||||
readStreamErr error
|
||||
}
|
||||
|
||||
func (f *fakeStreamer) WriteStream(filename string, stream bootstrapper.ReadChunkStream, showProgress bool) error {
|
||||
f.writeStreamFilename = filename
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return f.writeStreamErr
|
||||
}
|
||||
return fmt.Errorf("reading stream: %w", err)
|
||||
}
|
||||
f.writeStreamChunks = append(f.writeStreamChunks, chunk.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeStreamer) ReadStream(filename string, stream bootstrapper.WriteChunkStream, chunksize uint, showProgress bool) error {
|
||||
f.readStreamFilename = filename
|
||||
for _, chunk := range f.readStreamChunks {
|
||||
if err := stream.Send(&pb.Chunk{Content: chunk}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return f.readStreamErr
|
||||
}
|
||||
|
||||
func setupServerWithConn(endpoint string, serv *debugdServer) (*grpc.Server, *grpc.ClientConn, error) {
|
||||
dialer := testdialer.NewBufconnDialer()
|
||||
grpcServ := grpc.NewServer()
|
||||
pb.RegisterDebugdServer(grpcServ, serv)
|
||||
lis := dialer.GetListener(endpoint)
|
||||
go grpcServ.Serve(lis)
|
||||
|
||||
conn, err := dial(context.Background(), dialer, endpoint)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return grpcServ, conn, nil
|
||||
}
|
||||
|
||||
func fakeWrite(stream bootstrapper.WriteChunkStream, chunks [][]byte) error {
|
||||
for _, chunk := range chunks {
|
||||
err := stream.Send(&pb.Chunk{
|
||||
Content: chunk,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func fakeRead(stream bootstrapper.ReadChunkStream) ([][]byte, error) {
|
||||
var chunks [][]byte
|
||||
for {
|
||||
chunk, err := stream.Recv()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return chunks, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
chunks = append(chunks, chunk.Content)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue