mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-01-07 05:38:03 -05:00
227 lines
6.3 KiB
Go
227 lines
6.3 KiB
Go
/*
|
|
Copyright (c) Edgeless Systems GmbH
|
|
|
|
SPDX-License-Identifier: AGPL-3.0-only
|
|
*/
|
|
|
|
// Package filetransfer implements the exchange of files between cdbg <-> debugd
// and between debugd <-> debugd pairs.
|
|
package filetransfer
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"log/slog"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
|
|
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer/streamer"
|
|
pb "github.com/edgelesssys/constellation/v2/debugd/service"
|
|
)
|
|
|
|
// RecvFilesStream is a stream that receives FileTransferMessages.
|
|
type RecvFilesStream interface {
|
|
Recv() (*pb.FileTransferMessage, error)
|
|
}
|
|
|
|
// SendFilesStream is a stream that sends FileTransferMessages.
|
|
type SendFilesStream interface {
|
|
Send(*pb.FileTransferMessage) error
|
|
}
|
|
|
|
// FileTransferer manages sending and receiving of files.
|
|
type FileTransferer struct {
|
|
fileMux sync.RWMutex
|
|
log *slog.Logger
|
|
receiveStarted bool
|
|
receiveFinished atomic.Bool
|
|
files []FileStat
|
|
streamer streamReadWriter
|
|
showProgress bool
|
|
}
|
|
|
|
// New creates a new FileTransferer.
|
|
func New(log *slog.Logger, streamer streamReadWriter, showProgress bool) *FileTransferer {
|
|
return &FileTransferer{
|
|
log: log,
|
|
streamer: streamer,
|
|
showProgress: showProgress,
|
|
}
|
|
}
|
|
|
|
// SendFiles sends files to the given stream.
|
|
// If the FileTransferer has not received any files to send, an error is returned.
|
|
func (s *FileTransferer) SendFiles(stream SendFilesStream) error {
|
|
if !s.receiveFinished.Load() {
|
|
return errors.New("cannot send files before receiving them")
|
|
}
|
|
|
|
s.fileMux.RLock()
|
|
defer s.fileMux.RUnlock()
|
|
|
|
for _, file := range s.files {
|
|
if err := s.handleFileSend(stream, file); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RecvFiles receives files from the given stream.
|
|
func (s *FileTransferer) RecvFiles(stream RecvFilesStream) (err error) {
|
|
s.fileMux.Lock()
|
|
defer s.fileMux.Unlock()
|
|
if err := s.startRecv(); err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if err != nil {
|
|
s.abortRecv()
|
|
} else {
|
|
s.finishRecv()
|
|
}
|
|
}()
|
|
var done bool
|
|
for !done && err == nil {
|
|
done, err = s.handleFileRecv(stream)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// GetFiles returns the a copy of the list of files that have been received.
|
|
func (s *FileTransferer) GetFiles() []FileStat {
|
|
s.fileMux.RLock()
|
|
defer s.fileMux.RUnlock()
|
|
res := make([]FileStat, len(s.files))
|
|
copy(res, s.files)
|
|
return res
|
|
}
|
|
|
|
// SetFiles sets the list of files that can be sent.
|
|
// This function is used for a sender which has not received any files through
|
|
// this FileTransferer i.e. the CLI.
|
|
func (s *FileTransferer) SetFiles(files []FileStat) {
|
|
s.fileMux.Lock()
|
|
defer s.fileMux.Unlock()
|
|
res := make([]FileStat, len(files))
|
|
copy(res, files)
|
|
s.files = res
|
|
s.receiveFinished.Store(true)
|
|
}
|
|
|
|
func (s *FileTransferer) handleFileSend(stream SendFilesStream, file FileStat) error {
|
|
header := &pb.FileTransferMessage_Header{
|
|
Header: &pb.FileTransferHeader{
|
|
TargetPath: file.TargetPath,
|
|
Mode: uint32(file.Mode),
|
|
},
|
|
}
|
|
if file.OverrideServiceUnit != "" {
|
|
header.Header.OverrideServiceUnit = &file.OverrideServiceUnit
|
|
}
|
|
if err := stream.Send(&pb.FileTransferMessage{Kind: header}); err != nil {
|
|
return err
|
|
}
|
|
sendChunkStream := &sendChunkStream{stream: stream}
|
|
return s.streamer.ReadStream(file.SourcePath, sendChunkStream, debugd.Chunksize, s.showProgress)
|
|
}
|
|
|
|
// handleFileRecv handles the file receive of a single file.
|
|
// It returns true if the stream is finished (all of the file consumed) and false otherwise.
|
|
func (s *FileTransferer) handleFileRecv(stream RecvFilesStream) (bool, error) {
|
|
// first message must be a header message
|
|
msg, err := stream.Recv()
|
|
switch {
|
|
case err == nil:
|
|
// nop
|
|
case errors.Is(err, io.EOF):
|
|
return true, nil // stream is finished
|
|
default:
|
|
return false, err
|
|
}
|
|
header := msg.GetHeader()
|
|
if header == nil {
|
|
return false, errors.New("first message must be a header message")
|
|
}
|
|
s.log.Info(fmt.Sprintf("Starting file receive of %q", header.TargetPath))
|
|
s.addFile(FileStat{
|
|
SourcePath: header.TargetPath,
|
|
TargetPath: header.TargetPath,
|
|
Mode: fs.FileMode(header.Mode),
|
|
OverrideServiceUnit: func() string {
|
|
if header.OverrideServiceUnit != nil {
|
|
return *header.OverrideServiceUnit
|
|
}
|
|
return ""
|
|
}(),
|
|
})
|
|
recvChunkStream := &recvChunkStream{stream: stream}
|
|
if err := s.streamer.WriteStream(header.TargetPath, recvChunkStream, s.showProgress); err != nil {
|
|
s.log.With(slog.Any("error", err)).Error(fmt.Sprintf("Receive of file %q failed", header.TargetPath))
|
|
return false, err
|
|
}
|
|
s.log.Info(fmt.Sprintf("Finished file receive of %q", header.TargetPath))
|
|
return false, nil
|
|
}
|
|
|
|
// startRecv marks the file receive as started. It returns an error if receiving has already started.
|
|
func (s *FileTransferer) startRecv() error {
|
|
switch {
|
|
case s.receiveFinished.Load():
|
|
return ErrReceiveFinished
|
|
case s.receiveStarted:
|
|
return ErrReceiveRunning
|
|
}
|
|
s.receiveStarted = true
|
|
return nil
|
|
}
|
|
|
|
// abortRecv marks the file receive as failed.
|
|
// This allows for a retry of the file receive.
|
|
func (s *FileTransferer) abortRecv() {
|
|
s.receiveStarted = false
|
|
s.files = nil
|
|
}
|
|
|
|
// finishRecv marks the file receive as completed.
|
|
// This allows other debugd instances to request files from this server.
|
|
func (s *FileTransferer) finishRecv() {
|
|
s.receiveStarted = false
|
|
s.receiveFinished.Store(true)
|
|
}
|
|
|
|
// addFile adds a file to the list of received files.
|
|
func (s *FileTransferer) addFile(file FileStat) {
|
|
s.files = append(s.files, file)
|
|
}
|
|
|
|
// FileStat describes a file that can be uploaded or downloaded.
type FileStat struct {
	// SourcePath is the local path whose contents are streamed when sending.
	SourcePath string
	// TargetPath is the path the file is written to on the receiving side.
	TargetPath string
	// Mode is the file mode transmitted along with the file.
	Mode fs.FileMode
	// OverrideServiceUnit optionally names a service unit to override; "" means none.
	OverrideServiceUnit string
}
|
|
|
|
// Errors returned by RecvFiles when a new receive cannot be started.
var (
	// ErrReceiveFinished is returned if a file receive has already finished.
	ErrReceiveFinished = errors.New("receive finished")
	// ErrReceiveRunning is returned if a file receive is already running.
	ErrReceiveRunning = errors.New("receive already running")
)
|
|
|
|
// ShowProgress and DontShowProgress give readable names to boolean
// showProgress arguments, such as the one taken by New.
const (
	// DontShowProgress indicates that progress should not be shown.
	DontShowProgress = false
	// ShowProgress indicates that progress should be shown.
	ShowProgress = !DontShowProgress
)
|
|
|
|
type streamReadWriter interface {
|
|
WriteStream(filename string, stream streamer.ReadChunkStream, showProgress bool) error
|
|
ReadStream(filename string, stream streamer.WriteChunkStream, chunksize uint, showProgress bool) error
|
|
}
|