constellation/debugd/internal/filetransfer/filetransfer.go
2024-02-08 14:20:01 +00:00

227 lines
6.3 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
// Package filetransfer implements the exchange of files between cdbg and debugd,
// and between pairs of debugd instances.
package filetransfer
import (
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"sync"
"sync/atomic"
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer/streamer"
pb "github.com/edgelesssys/constellation/v2/debugd/service"
)
type (
	// RecvFilesStream is the receiving side of a file transfer: each Recv call
	// yields the next FileTransferMessage, or an error (io.EOF once the sender
	// is done).
	RecvFilesStream interface {
		Recv() (*pb.FileTransferMessage, error)
	}

	// SendFilesStream is the sending side of a file transfer: each Send call
	// transmits one FileTransferMessage.
	SendFilesStream interface {
		Send(*pb.FileTransferMessage) error
	}
)
// FileTransferer manages sending and receiving of files.
//
// files and receiveStarted are guarded by fileMux. receiveFinished is an
// atomic flag because SendFiles reads it before acquiring fileMux. log,
// streamer and showProgress are only assigned in New.
type FileTransferer struct {
// fileMux is held exclusively by RecvFiles and SetFiles and for reading by
// GetFiles and SendFiles.
fileMux sync.RWMutex
// log receives progress and error messages of file receives.
log *slog.Logger
// receiveStarted is true while RecvFiles is receiving (set by startRecv,
// cleared by abortRecv/finishRecv).
receiveStarted bool
// receiveFinished becomes true once files were received successfully or set
// via SetFiles; nothing in this file ever resets it to false.
receiveFinished atomic.Bool
// files holds the received (or set) files that can be sent on.
files []FileStat
// streamer moves file contents between local files and chunk streams.
streamer streamReadWriter
// showProgress is forwarded to every streamer call.
showProgress bool
}
// New returns a FileTransferer that uses streamer to read and write file
// contents and logs to log. showProgress is forwarded to every streamer call.
// The returned transferer has neither received nor been given any files yet.
func New(log *slog.Logger, streamer streamReadWriter, showProgress bool) *FileTransferer {
	t := new(FileTransferer)
	t.log = log
	t.streamer = streamer
	t.showProgress = showProgress
	return t
}
// SendFiles sends every known file to the given stream, one after another.
//
// The files are those received via RecvFiles or provided via SetFiles. If
// neither has completed yet, an error is returned and nothing is sent; a
// transferer that finished with an empty file list sends nothing and returns
// nil. Sending stops at the first file that fails.
func (s *FileTransferer) SendFiles(stream SendFilesStream) error {
	if ready := s.receiveFinished.Load(); !ready {
		return errors.New("cannot send files before receiving them")
	}

	s.fileMux.RLock()
	defer s.fileMux.RUnlock()

	for i := range s.files {
		if err := s.handleFileSend(stream, s.files[i]); err != nil {
			return err
		}
	}
	return nil
}
// RecvFiles receives files from the given stream until the stream reports
// io.EOF in place of a new file header.
//
// If any file fails, every file recorded during this call is discarded and
// the receive may be retried. On success the transferer is marked finished,
// which allows the received files to be forwarded with SendFiles.
//
// It returns ErrReceiveFinished if files were already received or set via
// SetFiles, and ErrReceiveRunning if a receive is flagged as in progress.
//
// NOTE(review): the exclusive lock is held for the entire receive and the
// in-progress flag is cleared before it is released, so a concurrent caller
// blocks here rather than getting ErrReceiveRunning — confirm this is intended.
func (s *FileTransferer) RecvFiles(stream RecvFilesStream) (err error) {
	s.fileMux.Lock()
	defer s.fileMux.Unlock()

	if startErr := s.startRecv(); startErr != nil {
		return startErr
	}
	// Runs before the Unlock above (LIFO) and commits or rolls back depending
	// on the final value of the named result.
	defer func() {
		if err == nil {
			s.finishRecv()
			return
		}
		s.abortRecv()
	}()

	for {
		done, recvErr := s.handleFileRecv(stream)
		if recvErr != nil {
			return recvErr
		}
		if done {
			return nil
		}
	}
}
// GetFiles returns a copy of the list of files that have been received or set
// via SetFiles. Modifying the returned slice does not affect the transferer,
// and the result is never nil. It waits while a receive holds the lock.
func (s *FileTransferer) GetFiles() []FileStat {
	s.fileMux.RLock()
	defer s.fileMux.RUnlock()
	return append(make([]FileStat, 0, len(s.files)), s.files...)
}
// SetFiles replaces the list of files that can be sent with a copy of files
// and marks the transferer as finished. Afterwards SendFiles may be used and
// RecvFiles returns ErrReceiveFinished.
// It is meant for a sender that never receives files through this
// FileTransferer, i.e. the CLI.
func (s *FileTransferer) SetFiles(files []FileStat) {
	// Copy first so later changes to the caller's slice cannot leak in.
	snapshot := append(make([]FileStat, 0, len(files)), files...)

	s.fileMux.Lock()
	defer s.fileMux.Unlock()
	s.files = snapshot
	s.receiveFinished.Store(true)
}
// handleFileSend sends a single file over stream. It first sends a header
// message with the target path, the file mode and (only if non-empty) the
// service-unit override. It then streams the content of file.SourcePath through
// s.streamer using debugd.Chunksize as the chunk size.
func (s *FileTransferer) handleFileSend(stream SendFilesStream, file FileStat) error {
	hdr := &pb.FileTransferHeader{
		TargetPath: file.TargetPath,
		Mode:       uint32(file.Mode),
	}
	if unit := file.OverrideServiceUnit; unit != "" {
		hdr.OverrideServiceUnit = &unit
	}

	msg := &pb.FileTransferMessage{Kind: &pb.FileTransferMessage_Header{Header: hdr}}
	if err := stream.Send(msg); err != nil {
		return err
	}

	return s.streamer.ReadStream(file.SourcePath, &sendChunkStream{stream: stream}, debugd.Chunksize, s.showProgress)
}
// handleFileRecv receives a single file from stream.
//
// The first message must be a header describing the file. The remaining chunks
// are written to the header's TargetPath by s.streamer. The file is recorded
// in s.files before its content is written, so after a failed write it stays in
// the list until the caller discards it (RecvFiles does so via abortRecv).
//
// It returns true when the stream ended with io.EOF instead of a new header,
// meaning all files have been consumed. It returns false after each
// successfully received file and on any error.
func (s *FileTransferer) handleFileRecv(stream RecvFilesStream) (bool, error) {
	msg, err := stream.Recv()
	if errors.Is(err, io.EOF) {
		// No further header: the sender has finished sending all files.
		return true, nil
	}
	if err != nil {
		return false, err
	}

	header := msg.GetHeader()
	if header == nil {
		return false, errors.New("first message must be a header message")
	}

	target := header.TargetPath
	var overrideUnit string
	if header.OverrideServiceUnit != nil {
		overrideUnit = *header.OverrideServiceUnit
	}

	s.log.Info(fmt.Sprintf("Starting file receive of %q", target))
	// The file is written to target, so target is also the path it is read
	// from when this instance sends it on (handleFileSend uses SourcePath).
	s.addFile(FileStat{
		SourcePath:          target,
		TargetPath:          target,
		Mode:                fs.FileMode(header.Mode),
		OverrideServiceUnit: overrideUnit,
	})

	if writeErr := s.streamer.WriteStream(target, &recvChunkStream{stream: stream}, s.showProgress); writeErr != nil {
		s.log.With(slog.Any("error", writeErr)).Error(fmt.Sprintf("Receive of file %q failed", target))
		return false, writeErr
	}
	s.log.Info(fmt.Sprintf("Finished file receive of %q", target))
	return false, nil
}
// startRecv flags a receive as in progress. It fails with ErrReceiveFinished
// once files were received or set, and with ErrReceiveRunning while another
// receive is flagged. Callers hold fileMux exclusively.
func (s *FileTransferer) startRecv() error {
	if s.receiveFinished.Load() {
		return ErrReceiveFinished
	}
	if s.receiveStarted {
		return ErrReceiveRunning
	}
	s.receiveStarted = true
	return nil
}
// abortRecv rolls back a failed receive. It discards every file recorded so
// far and clears the in-progress flag, so RecvFiles can be retried.
// Callers hold fileMux exclusively.
func (s *FileTransferer) abortRecv() {
	s.files = nil
	s.receiveStarted = false
}
// finishRecv commits a successful receive. Afterwards SendFiles may serve the
// received files to other debugd instances, and RecvFiles returns
// ErrReceiveFinished. Callers hold fileMux exclusively.
func (s *FileTransferer) finishRecv() {
	s.receiveFinished.Store(true)
	s.receiveStarted = false
}
// addFile appends f to the list of received files.
// Callers hold fileMux exclusively.
func (s *FileTransferer) addFile(f FileStat) {
	s.files = append(s.files, f)
}
// FileStat describes one file that can be transferred: where its content is
// read from, where it is written to, and the file mode sent along with it.
type FileStat struct {
	// SourcePath is the local path the content is read from when sending.
	SourcePath string
	// TargetPath is the path the content is written to by the receiver.
	TargetPath string
	// Mode is the file mode transmitted in the transfer header.
	Mode fs.FileMode
	// OverrideServiceUnit optionally names a service unit to override; the
	// empty string means no override.
	OverrideServiceUnit string
}
// Sentinel errors returned by RecvFiles; compare them with errors.Is.
var (
	// ErrReceiveFinished is returned if a file receive has already finished.
	ErrReceiveFinished = errors.New("receive finished")
	// ErrReceiveRunning is returned if a file receive is already running.
	ErrReceiveRunning = errors.New("receive already running")
)
// Named boolean values, presumably meant for the showProgress argument of New,
// that make call sites self-describing.
const (
	// DontShowProgress indicates that progress should not be shown.
	DontShowProgress = false
	// ShowProgress indicates that progress should be shown.
	ShowProgress = !DontShowProgress
)
// streamReadWriter moves file contents between the local filesystem and a
// chunk stream. It is the abstraction FileTransferer uses for file I/O.
type streamReadWriter interface {
	// WriteStream consumes stream and stores the content at filename; the
	// receive path calls it with the header's target path.
	WriteStream(filename string, stream streamer.ReadChunkStream, showProgress bool) error
	// ReadStream reads filename and emits it to stream in pieces of
	// chunksize; the send path calls it with the file's source path.
	ReadStream(filename string, stream streamer.WriteChunkStream, chunksize uint, showProgress bool) error
}