mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-05-20 23:20:35 -04:00
161 lines
4.4 KiB
Go
161 lines
4.4 KiB
Go
/*
|
|
Copyright (c) Edgeless Systems GmbH
|
|
|
|
SPDX-License-Identifier: AGPL-3.0-only
|
|
*/
|
|
|
|
package deploy
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
|
|
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer"
|
|
pb "github.com/edgelesssys/constellation/v2/debugd/service"
|
|
"github.com/edgelesssys/constellation/v2/internal/constants"
|
|
"github.com/edgelesssys/constellation/v2/internal/logger"
|
|
"go.uber.org/zap"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
)
|
|
|
|
// Download downloads a bootstrapper from a given debugd instance.
|
|
type Download struct {
|
|
log *logger.Logger
|
|
dialer NetDialer
|
|
transfer fileTransferer
|
|
serviceManager serviceManager
|
|
info infoSetter
|
|
}
|
|
|
|
// New creates a new Download.
|
|
func New(log *logger.Logger, dialer NetDialer, serviceManager serviceManager,
|
|
transfer fileTransferer, info infoSetter,
|
|
) *Download {
|
|
return &Download{
|
|
log: log,
|
|
dialer: dialer,
|
|
transfer: transfer,
|
|
info: info,
|
|
serviceManager: serviceManager,
|
|
}
|
|
}
|
|
|
|
// DownloadInfo will try to download the info from another instance.
|
|
func (d *Download) DownloadInfo(ctx context.Context, ip string) error {
|
|
if d.info.Received() {
|
|
return nil
|
|
}
|
|
|
|
log := d.log.With(zap.String("ip", ip))
|
|
serverAddr := net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort))
|
|
|
|
client, closer, err := d.newClient(ctx, serverAddr, log)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer closer.Close()
|
|
|
|
log.Infof("Trying to download info")
|
|
resp, err := client.GetInfo(ctx, &pb.GetInfoRequest{})
|
|
if err != nil {
|
|
return fmt.Errorf("getting info from other instance: %w", err)
|
|
}
|
|
log.Infof("Successfully downloaded info")
|
|
|
|
return d.info.SetProto(resp.Info)
|
|
}
|
|
|
|
// DownloadDeployment will open a new grpc connection to another instance, attempting to download files from that instance.
|
|
func (d *Download) DownloadDeployment(ctx context.Context, ip string) error {
|
|
log := d.log.With(zap.String("ip", ip))
|
|
serverAddr := net.JoinHostPort(ip, strconv.Itoa(constants.DebugdPort))
|
|
|
|
client, closer, err := d.newClient(ctx, serverAddr, log)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer closer.Close()
|
|
|
|
log.Infof("Trying to download files")
|
|
stream, err := client.DownloadFiles(ctx, &pb.DownloadFilesRequest{})
|
|
if err != nil {
|
|
return fmt.Errorf("starting file download from other instance: %w", err)
|
|
}
|
|
|
|
err = d.transfer.RecvFiles(stream)
|
|
switch {
|
|
case err == nil:
|
|
d.log.Infof("Downloading files succeeded")
|
|
case errors.Is(err, filetransfer.ErrReceiveRunning):
|
|
d.log.Warnf("Download already in progress")
|
|
return err
|
|
case errors.Is(err, filetransfer.ErrReceiveFinished):
|
|
d.log.Warnf("Download already finished")
|
|
return nil
|
|
default:
|
|
d.log.With(zap.Error(err)).Errorf("Downloading files failed")
|
|
return err
|
|
}
|
|
|
|
files := d.transfer.GetFiles()
|
|
for _, file := range files {
|
|
if file.OverrideServiceUnit == "" {
|
|
continue
|
|
}
|
|
if err := d.serviceManager.OverrideServiceUnitExecStart(
|
|
ctx, file.OverrideServiceUnit, file.TargetPath,
|
|
); err != nil {
|
|
// continue on error to allow other units to be overridden
|
|
d.log.With(zap.Error(err)).Errorf("Failed to override service unit %s", file.OverrideServiceUnit)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (d *Download) newClient(ctx context.Context, serverAddr string, log *logger.Logger) (pb.DebugdClient, io.Closer, error) {
|
|
log.Infof("Connecting to server")
|
|
conn, err := d.dial(ctx, serverAddr)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("connecting to other instance via gRPC: %w", err)
|
|
}
|
|
return pb.NewDebugdClient(conn), conn, nil
|
|
}
|
|
|
|
func (d *Download) dial(ctx context.Context, target string) (*grpc.ClientConn, error) {
|
|
return grpc.DialContext(ctx, target,
|
|
d.grpcWithDialer(),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
)
|
|
}
|
|
|
|
func (d *Download) grpcWithDialer() grpc.DialOption {
|
|
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
|
return d.dialer.DialContext(ctx, "tcp", addr)
|
|
})
|
|
}
|
|
|
|
type infoSetter interface {
|
|
SetProto(infos []*pb.Info) error
|
|
Received() bool
|
|
}
|
|
|
|
type serviceManager interface {
|
|
SystemdAction(ctx context.Context, request ServiceManagerRequest) error
|
|
OverrideServiceUnitExecStart(ctx context.Context, unitName string, execStart string) error
|
|
}
|
|
|
|
type fileTransferer interface {
|
|
RecvFiles(stream filetransfer.RecvFilesStream) error
|
|
GetFiles() []filetransfer.FileStat
|
|
}
|
|
|
|
// NetDialer can open a net.Conn.
|
|
type NetDialer interface {
|
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
|
}
|