constellation/debugd/internal/debugd/deploy/download.go
2023-01-25 09:58:56 +01:00

161 lines
4.4 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package deploy
import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer"
pb "github.com/edgelesssys/constellation/v2/debugd/service"
"github.com/edgelesssys/constellation/v2/internal/constants"
"github.com/edgelesssys/constellation/v2/internal/logger"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
// Download downloads a bootstrapper from a given debugd instance.
type Download struct {
log *logger.Logger
dialer NetDialer
transfer fileTransferer
serviceManager serviceManager
info infoSetter
}
// New creates a new Download.
func New(log *logger.Logger, dialer NetDialer, serviceManager serviceManager,
transfer fileTransferer, info infoSetter,
) *Download {
return &Download{
log: log,
dialer: dialer,
transfer: transfer,
info: info,
serviceManager: serviceManager,
}
}
// DownloadInfo will try to download the info from another instance.
func (d *Download) DownloadInfo(ctx context.Context, ip string) error {
if d.info.Received() {
return nil
}
log := d.log.With(zap.String("ip", ip))
serverAddr := net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort))
client, closer, err := d.newClient(ctx, serverAddr, log)
if err != nil {
return err
}
defer closer.Close()
log.Infof("Trying to download info")
resp, err := client.GetInfo(ctx, &pb.GetInfoRequest{})
if err != nil {
return fmt.Errorf("getting info from other instance: %w", err)
}
log.Infof("Successfully downloaded info")
return d.info.SetProto(resp.Info)
}
// DownloadDeployment will open a new grpc connection to another instance, attempting to download files from that instance.
func (d *Download) DownloadDeployment(ctx context.Context, ip string) error {
log := d.log.With(zap.String("ip", ip))
serverAddr := net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort))
client, closer, err := d.newClient(ctx, serverAddr, log)
if err != nil {
return err
}
defer closer.Close()
log.Infof("Trying to download files")
stream, err := client.DownloadFiles(ctx, &pb.DownloadFilesRequest{})
if err != nil {
return fmt.Errorf("starting file download from other instance: %w", err)
}
err = d.transfer.RecvFiles(stream)
switch {
case err == nil:
d.log.Infof("Downloading files succeeded")
case errors.Is(err, filetransfer.ErrReceiveRunning):
d.log.Warnf("Download already in progress")
return err
case errors.Is(err, filetransfer.ErrReceiveFinished):
d.log.Warnf("Download already finished")
return nil
default:
d.log.With(zap.Error(err)).Errorf("Downloading files failed")
return err
}
files := d.transfer.GetFiles()
for _, file := range files {
if file.OverrideServiceUnit == "" {
continue
}
if err := d.serviceManager.OverrideServiceUnitExecStart(
ctx, file.OverrideServiceUnit, file.TargetPath,
); err != nil {
// continue on error to allow other units to be overridden
d.log.With(zap.Error(err)).Errorf("Failed to override service unit %s", file.OverrideServiceUnit)
}
}
return nil
}
func (d *Download) newClient(ctx context.Context, serverAddr string, log *logger.Logger) (pb.DebugdClient, io.Closer, error) {
log.Infof("Connecting to server")
conn, err := d.dial(ctx, serverAddr)
if err != nil {
return nil, nil, fmt.Errorf("connecting to other instance via gRPC: %w", err)
}
return pb.NewDebugdClient(conn), conn, nil
}
func (d *Download) dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
return grpc.DialContext(ctx, target,
d.grpcWithDialer(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
}
func (d *Download) grpcWithDialer() grpc.DialOption {
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return d.dialer.DialContext(ctx, "tcp", addr)
})
}
type infoSetter interface {
SetProto(infos []*pb.Info) error
Received() bool
}
type serviceManager interface {
SystemdAction(ctx context.Context, request ServiceManagerRequest) error
OverrideServiceUnitExecStart(ctx context.Context, unitName string, execStart string) error
}
type fileTransferer interface {
RecvFiles(stream filetransfer.RecvFilesStream) error
GetFiles() []filetransfer.FileStat
}
// NetDialer can open a net.Conn.
type NetDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}