mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-11-13 00:50:38 -05:00
debugd: implement upload of multiple binaries
This commit is contained in:
parent
e6ac8e2a91
commit
6f56ed69f8
21 changed files with 2040 additions and 661 deletions
47
debugd/internal/filetransfer/chunkstream.go
Normal file
47
debugd/internal/filetransfer/chunkstream.go
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package filetransfer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
pb "github.com/edgelesssys/constellation/v2/debugd/service"
|
||||
)
|
||||
|
||||
// recvChunkStream is a wrapper around a RecvFilesStream that only returns chunks.
|
||||
type recvChunkStream struct {
|
||||
stream RecvFilesStream
|
||||
}
|
||||
|
||||
// Recv receives a FileTransferMessage and returns the chunk.
|
||||
func (s *recvChunkStream) Recv() (*pb.Chunk, error) {
|
||||
msg, err := s.stream.Recv()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chunk := msg.GetChunk()
|
||||
if chunk == nil {
|
||||
return nil, errors.New("expected chunk")
|
||||
}
|
||||
return chunk, nil
|
||||
}
|
||||
|
||||
// sendChunkStream is a wrapper around a SendFilesStream that wraps chunks for every message.
|
||||
type sendChunkStream struct {
|
||||
stream SendFilesStream
|
||||
}
|
||||
|
||||
// Send wraps the given chunk in a FileTransferMessage and sends it.
|
||||
func (s *sendChunkStream) Send(chunk *pb.Chunk) error {
|
||||
chunkMessage := &pb.FileTransferMessage_Chunk{
|
||||
Chunk: chunk,
|
||||
}
|
||||
message := &pb.FileTransferMessage{
|
||||
Kind: chunkMessage,
|
||||
}
|
||||
return s.stream.Send(message)
|
||||
}
|
||||
134
debugd/internal/filetransfer/chunkstream_test.go
Normal file
134
debugd/internal/filetransfer/chunkstream_test.go
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package filetransfer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
pb "github.com/edgelesssys/constellation/v2/debugd/service"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRecv(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
stream stubRecvFilesStream
|
||||
wantChunk *pb.Chunk
|
||||
wantErr bool
|
||||
}{
|
||||
"chunk is received": {
|
||||
stream: stubRecvFilesStream{
|
||||
msg: &pb.FileTransferMessage{
|
||||
Kind: &pb.FileTransferMessage_Chunk{
|
||||
Chunk: &pb.Chunk{
|
||||
Content: []byte("test"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantChunk: &pb.Chunk{
|
||||
Content: []byte("test"),
|
||||
},
|
||||
},
|
||||
"wrong type": {
|
||||
stream: stubRecvFilesStream{
|
||||
msg: &pb.FileTransferMessage{
|
||||
Kind: &pb.FileTransferMessage_Header{},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"empty msg": {
|
||||
stream: stubRecvFilesStream{},
|
||||
wantErr: true,
|
||||
},
|
||||
"recv fails": {
|
||||
stream: stubRecvFilesStream{
|
||||
recvErr: errors.New("someErr"),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
stream := recvChunkStream{stream: &tc.stream}
|
||||
chunk, err := stream.Recv()
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantChunk, chunk)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSend(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
stream stubSendFilesStream
|
||||
chunk *pb.Chunk
|
||||
wantMsgs []*pb.FileTransferMessage
|
||||
wantErr bool
|
||||
}{
|
||||
"chunk is wrapped correctly": {
|
||||
chunk: &pb.Chunk{
|
||||
Content: []byte("test"),
|
||||
},
|
||||
wantMsgs: []*pb.FileTransferMessage{
|
||||
{
|
||||
Kind: &pb.FileTransferMessage_Chunk{
|
||||
Chunk: &pb.Chunk{
|
||||
Content: []byte("test"),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
stream := sendChunkStream{stream: &tc.stream}
|
||||
err := stream.Send(tc.chunk)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.EqualValues(tc.wantMsgs, tc.stream.msgs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type stubRecvFilesStream struct {
|
||||
msg *pb.FileTransferMessage
|
||||
recvErr error
|
||||
}
|
||||
|
||||
func (s *stubRecvFilesStream) Recv() (*pb.FileTransferMessage, error) {
|
||||
return s.msg, s.recvErr
|
||||
}
|
||||
|
||||
type stubSendFilesStream struct {
|
||||
msgs []*pb.FileTransferMessage
|
||||
sendErr error
|
||||
}
|
||||
|
||||
func (s *stubSendFilesStream) Send(msg *pb.FileTransferMessage) error {
|
||||
s.msgs = append(s.msgs, msg)
|
||||
return s.sendErr
|
||||
}
|
||||
233
debugd/internal/filetransfer/filetransfer.go
Normal file
233
debugd/internal/filetransfer/filetransfer.go
Normal file
|
|
@ -0,0 +1,233 @@
|
|||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package filetransfer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"io/fs"
|
||||
"sync"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer/streamer"
|
||||
pb "github.com/edgelesssys/constellation/v2/debugd/service"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RecvFilesStream is a stream that receives FileTransferMessages.
|
||||
type RecvFilesStream interface {
|
||||
Recv() (*pb.FileTransferMessage, error)
|
||||
}
|
||||
|
||||
// SendFilesStream is a stream that sends FileTransferMessages.
|
||||
type SendFilesStream interface {
|
||||
Send(*pb.FileTransferMessage) error
|
||||
}
|
||||
|
||||
// FileTransferer manages sending and receiving of files.
|
||||
type FileTransferer struct {
|
||||
mux sync.RWMutex
|
||||
log *logger.Logger
|
||||
receiveStarted bool
|
||||
receiveFinished bool
|
||||
files []FileStat
|
||||
streamer streamReadWriter
|
||||
showProgress bool
|
||||
}
|
||||
|
||||
// New creates a new FileTransferer.
|
||||
func New(log *logger.Logger, streamer streamReadWriter, showProgress bool) *FileTransferer {
|
||||
return &FileTransferer{
|
||||
log: log,
|
||||
streamer: streamer,
|
||||
showProgress: showProgress,
|
||||
}
|
||||
}
|
||||
|
||||
// SendFiles sends files to the given stream.
|
||||
func (s *FileTransferer) SendFiles(stream SendFilesStream) error {
|
||||
s.mux.RLock()
|
||||
defer s.mux.RUnlock()
|
||||
if !s.receiveFinished {
|
||||
return errors.New("cannot send files before receiving them")
|
||||
}
|
||||
for _, file := range s.files {
|
||||
if err := s.handleFileSend(stream, file); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RecvFiles receives files from the given stream.
|
||||
func (s *FileTransferer) RecvFiles(stream RecvFilesStream) (err error) {
|
||||
if err := s.startRecv(); err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
s.abortRecv()
|
||||
} else {
|
||||
s.finishRecv()
|
||||
}
|
||||
}()
|
||||
var done bool
|
||||
for !done && err == nil {
|
||||
done, err = s.handleFileRecv(stream)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetFiles returns the a copy of the list of files that have been received.
|
||||
func (s *FileTransferer) GetFiles() []FileStat {
|
||||
s.mux.RLock()
|
||||
defer s.mux.RUnlock()
|
||||
res := make([]FileStat, len(s.files))
|
||||
copy(res, s.files)
|
||||
return res
|
||||
}
|
||||
|
||||
// SetFiles sets the list of files that can be sent.
|
||||
func (s *FileTransferer) SetFiles(files []FileStat) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
res := make([]FileStat, len(files))
|
||||
copy(res, files)
|
||||
s.files = res
|
||||
s.receiveFinished = true
|
||||
}
|
||||
|
||||
// CanSend returns true if the file receive has finished.
|
||||
// This is called to determine if a debugd instance can request files from this server.
|
||||
func (s *FileTransferer) CanSend() bool {
|
||||
s.mux.RLock()
|
||||
defer s.mux.RUnlock()
|
||||
ret := s.receiveFinished
|
||||
return ret
|
||||
}
|
||||
|
||||
func (s *FileTransferer) handleFileSend(stream SendFilesStream, file FileStat) error {
|
||||
header := &pb.FileTransferMessage_Header{
|
||||
Header: &pb.FileTransferHeader{
|
||||
TargetPath: file.TargetPath,
|
||||
Mode: uint32(file.Mode),
|
||||
},
|
||||
}
|
||||
if file.OverrideServiceUnit != "" {
|
||||
header.Header.OverrideServiceUnit = &file.OverrideServiceUnit
|
||||
}
|
||||
if err := stream.Send(&pb.FileTransferMessage{Kind: header}); err != nil {
|
||||
return err
|
||||
}
|
||||
sendChunkStream := &sendChunkStream{stream: stream}
|
||||
return s.streamer.ReadStream(file.SourcePath, sendChunkStream, debugd.Chunksize, s.showProgress)
|
||||
}
|
||||
|
||||
// handleFileRecv handles the file receive of a single file.
|
||||
// It returns true if the stream is finished (all of the file consumed) and false otherwise.
|
||||
func (s *FileTransferer) handleFileRecv(stream RecvFilesStream) (bool, error) {
|
||||
// first message must be a header message
|
||||
msg, err := stream.Recv()
|
||||
switch {
|
||||
case err == nil:
|
||||
// nop
|
||||
case errors.Is(err, io.EOF):
|
||||
return true, nil // stream is finished
|
||||
default:
|
||||
return false, err
|
||||
}
|
||||
header := msg.GetHeader()
|
||||
if header == nil {
|
||||
return false, errors.New("first message must be a header message")
|
||||
}
|
||||
s.log.Infof("Starting file receive of %q", header.TargetPath)
|
||||
s.addFile(FileStat{
|
||||
SourcePath: header.TargetPath,
|
||||
TargetPath: header.TargetPath,
|
||||
Mode: fs.FileMode(header.Mode),
|
||||
OverrideServiceUnit: func() string {
|
||||
if header.OverrideServiceUnit != nil {
|
||||
return *header.OverrideServiceUnit
|
||||
}
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
recvChunkStream := &recvChunkStream{stream: stream}
|
||||
if err := s.streamer.WriteStream(header.TargetPath, recvChunkStream, s.showProgress); err != nil {
|
||||
s.log.With(zap.Error(err)).Errorf("Receive of file %q failed", header.TargetPath)
|
||||
return false, err
|
||||
}
|
||||
s.log.Infof("Finished file receive of %q", header.TargetPath)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// startRecv marks the file receive as started. It returns an error if receiving has already started.
|
||||
func (s *FileTransferer) startRecv() error {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
switch {
|
||||
case s.receiveFinished:
|
||||
return ErrReceiveFinished
|
||||
case s.receiveStarted:
|
||||
return ErrReceiveRunning
|
||||
}
|
||||
s.receiveStarted = true
|
||||
return nil
|
||||
}
|
||||
|
||||
// abortRecv marks the file receive as failed.
|
||||
// This allows for a retry of the file receive.
|
||||
func (s *FileTransferer) abortRecv() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.receiveStarted = false
|
||||
s.files = nil
|
||||
}
|
||||
|
||||
// finishRecv marks the file receive as completed.
|
||||
// This allows other debugd instances to request files from this server.
|
||||
func (s *FileTransferer) finishRecv() {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.receiveStarted = false
|
||||
s.receiveFinished = true
|
||||
}
|
||||
|
||||
// addFile adds a file to the list of received files.
|
||||
func (s *FileTransferer) addFile(file FileStat) {
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
s.files = append(s.files, file)
|
||||
}
|
||||
|
||||
// FileStat contains the metadata of a file that can be up/downloaded.
|
||||
type FileStat struct {
|
||||
SourcePath string
|
||||
TargetPath string
|
||||
Mode fs.FileMode
|
||||
OverrideServiceUnit string // optional name of the service unit to override
|
||||
}
|
||||
|
||||
var (
|
||||
// ErrReceiveRunning is returned if a file receive is already running.
|
||||
ErrReceiveRunning = errors.New("receive already running")
|
||||
// ErrReceiveFinished is returned if a file receive has already finished.
|
||||
ErrReceiveFinished = errors.New("receive finished")
|
||||
)
|
||||
|
||||
const (
|
||||
// ShowProgress indicates that progress should be shown.
|
||||
ShowProgress = true
|
||||
// DontShowProgress indicates that progress should not be shown.
|
||||
DontShowProgress = false
|
||||
)
|
||||
|
||||
type streamReadWriter interface {
|
||||
WriteStream(filename string, stream streamer.ReadChunkStream, showProgress bool) error
|
||||
ReadStream(filename string, stream streamer.WriteChunkStream, chunksize uint, showProgress bool) error
|
||||
}
|
||||
411
debugd/internal/filetransfer/filetransfer_test.go
Normal file
411
debugd/internal/filetransfer/filetransfer_test.go
Normal file
|
|
@ -0,0 +1,411 @@
|
|||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package filetransfer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer/streamer"
|
||||
pb "github.com/edgelesssys/constellation/v2/debugd/service"
|
||||
"github.com/edgelesssys/constellation/v2/internal/logger"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestSendFiles(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
files *[]FileStat
|
||||
sendErr error
|
||||
readStreamErr error
|
||||
wantHeaders []*pb.FileTransferMessage
|
||||
wantErr bool
|
||||
}{
|
||||
"can send files": {
|
||||
files: &[]FileStat{
|
||||
{
|
||||
TargetPath: "testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: "somesvcA",
|
||||
},
|
||||
{
|
||||
TargetPath: "testfileB",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: "somesvcB",
|
||||
},
|
||||
},
|
||||
wantHeaders: []*pb.FileTransferMessage{
|
||||
{
|
||||
Kind: &pb.FileTransferMessage_Header{
|
||||
Header: &pb.FileTransferHeader{
|
||||
TargetPath: "testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Kind: &pb.FileTransferMessage_Header{
|
||||
Header: &pb.FileTransferHeader{
|
||||
TargetPath: "testfileB",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: func() *string { s := "somesvcB"; return &s }(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"no files set": {
|
||||
wantErr: true,
|
||||
},
|
||||
"send fails": {
|
||||
files: &[]FileStat{
|
||||
{
|
||||
TargetPath: "testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: "somesvcA",
|
||||
},
|
||||
},
|
||||
sendErr: errors.New("send failed"),
|
||||
wantErr: true,
|
||||
},
|
||||
"read stream fails": {
|
||||
files: &[]FileStat{
|
||||
{
|
||||
TargetPath: "testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: "somesvcA",
|
||||
},
|
||||
},
|
||||
readStreamErr: errors.New("read stream failed"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
streamer := &stubStreamReadWriter{readStreamErr: tc.readStreamErr}
|
||||
stream := &stubSendFilesStream{sendErr: tc.sendErr}
|
||||
transfer := New(logger.NewTest(t), streamer, false)
|
||||
if tc.files != nil {
|
||||
transfer.SetFiles(*tc.files)
|
||||
}
|
||||
err := transfer.SendFiles(stream)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantHeaders, stream.msgs)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecvFiles(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
msgs []*pb.FileTransferMessage
|
||||
recvAlreadyStarted bool
|
||||
recvAlreadyFinished bool
|
||||
recvErr error
|
||||
writeStreamErr error
|
||||
wantFiles []FileStat
|
||||
wantErr bool
|
||||
}{
|
||||
"can recv files": {
|
||||
msgs: []*pb.FileTransferMessage{
|
||||
{
|
||||
Kind: &pb.FileTransferMessage_Header{
|
||||
Header: &pb.FileTransferHeader{
|
||||
TargetPath: "testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
|
||||
},
|
||||
},
|
||||
},
|
||||
// Chunk messages left out since they would be consumed by the streamReadWriter
|
||||
{
|
||||
Kind: &pb.FileTransferMessage_Header{
|
||||
Header: &pb.FileTransferHeader{
|
||||
TargetPath: "testfileB",
|
||||
Mode: 0o644,
|
||||
},
|
||||
},
|
||||
},
|
||||
// Chunk messages left out since they would be consumed by the streamReadWriter
|
||||
},
|
||||
wantFiles: []FileStat{
|
||||
{
|
||||
SourcePath: "testfileA",
|
||||
TargetPath: "testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: "somesvcA",
|
||||
},
|
||||
{
|
||||
SourcePath: "testfileB",
|
||||
TargetPath: "testfileB",
|
||||
Mode: 0o644,
|
||||
},
|
||||
},
|
||||
},
|
||||
"no messages": {},
|
||||
"recv fails": {
|
||||
recvErr: errors.New("recv failed"),
|
||||
wantErr: true,
|
||||
},
|
||||
"first recv does not yield file header": {
|
||||
msgs: []*pb.FileTransferMessage{
|
||||
{
|
||||
Kind: &pb.FileTransferMessage_Chunk{},
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
"write stream fails": {
|
||||
msgs: []*pb.FileTransferMessage{
|
||||
{
|
||||
Kind: &pb.FileTransferMessage_Header{
|
||||
Header: &pb.FileTransferHeader{
|
||||
TargetPath: "testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
|
||||
},
|
||||
},
|
||||
},
|
||||
// Chunk messages left out since they would be consumed by the streamReadWriter
|
||||
},
|
||||
writeStreamErr: errors.New("write stream failed"),
|
||||
wantErr: true,
|
||||
},
|
||||
"recv has already started": {
|
||||
msgs: []*pb.FileTransferMessage{
|
||||
{
|
||||
Kind: &pb.FileTransferMessage_Header{
|
||||
Header: &pb.FileTransferHeader{
|
||||
TargetPath: "testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
|
||||
},
|
||||
},
|
||||
},
|
||||
// Chunk messages left out since they would be consumed by the streamReadWriter
|
||||
},
|
||||
recvAlreadyStarted: true,
|
||||
wantErr: true,
|
||||
},
|
||||
"recv has already finished": {
|
||||
msgs: []*pb.FileTransferMessage{
|
||||
{
|
||||
Kind: &pb.FileTransferMessage_Header{
|
||||
Header: &pb.FileTransferHeader{
|
||||
TargetPath: "testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
|
||||
},
|
||||
},
|
||||
},
|
||||
// Chunk messages left out since they would be consumed by the streamReadWriter
|
||||
},
|
||||
recvAlreadyFinished: true,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
streamer := &stubStreamReadWriter{writeStreamErr: tc.writeStreamErr}
|
||||
stream := &fakeRecvFilesStream{msgs: tc.msgs, recvErr: tc.recvErr}
|
||||
transfer := New(logger.NewTest(t), streamer, false)
|
||||
if tc.recvAlreadyStarted {
|
||||
transfer.receiveStarted = true
|
||||
}
|
||||
if tc.recvAlreadyFinished {
|
||||
transfer.receiveFinished = true
|
||||
}
|
||||
err := transfer.RecvFiles(stream)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantFiles, transfer.files)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSetFiles(t *testing.T) {
|
||||
testCases := map[string]struct {
|
||||
setFiles *[]FileStat
|
||||
wantFiles []FileStat
|
||||
wantErr bool
|
||||
}{
|
||||
"no files": {
|
||||
wantFiles: []FileStat{},
|
||||
},
|
||||
"files": {
|
||||
setFiles: &[]FileStat{
|
||||
{
|
||||
SourcePath: "testfileA",
|
||||
TargetPath: "testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: "somesvcA",
|
||||
},
|
||||
},
|
||||
wantFiles: []FileStat{
|
||||
{
|
||||
SourcePath: "testfileA",
|
||||
TargetPath: "testfileA",
|
||||
Mode: 0o644,
|
||||
OverrideServiceUnit: "somesvcA",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
streamer := &dummyStreamReadWriter{}
|
||||
transfer := New(logger.NewTest(t), streamer, false)
|
||||
if tc.setFiles != nil {
|
||||
transfer.SetFiles(*tc.setFiles)
|
||||
}
|
||||
gotFiles := transfer.GetFiles()
|
||||
assert.Equal(tc.wantFiles, gotFiles)
|
||||
assert.Equal(tc.setFiles != nil, transfer.receiveFinished)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanSend(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
|
||||
streamer := &stubStreamReadWriter{}
|
||||
stream := &stubRecvFilesStream{recvErr: io.EOF}
|
||||
transfer := New(logger.NewTest(t), streamer, false)
|
||||
assert.False(transfer.CanSend())
|
||||
|
||||
// manual set
|
||||
transfer.SetFiles(nil)
|
||||
assert.True(transfer.CanSend())
|
||||
|
||||
// reset
|
||||
transfer.receiveStarted = false
|
||||
transfer.receiveFinished = false
|
||||
transfer.files = nil
|
||||
assert.False(transfer.CanSend())
|
||||
|
||||
// receive files (empty)
|
||||
assert.NoError(transfer.RecvFiles(stream))
|
||||
assert.True(transfer.CanSend())
|
||||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
ft := New(logger.NewTest(t), &stubStreamReadWriter{}, false)
|
||||
|
||||
sendFiles := func() {
|
||||
_ = ft.SendFiles(&stubSendFilesStream{})
|
||||
}
|
||||
|
||||
recvFiles := func() {
|
||||
_ = ft.RecvFiles(&stubRecvFilesStream{})
|
||||
}
|
||||
|
||||
getFiles := func() {
|
||||
_ = ft.GetFiles()
|
||||
}
|
||||
|
||||
setFiles := func() {
|
||||
ft.SetFiles([]FileStat{{SourcePath: "file", TargetPath: "file", Mode: 0o644}})
|
||||
}
|
||||
|
||||
canSend := func() {
|
||||
_ = ft.CanSend()
|
||||
}
|
||||
|
||||
go sendFiles()
|
||||
go sendFiles()
|
||||
go sendFiles()
|
||||
go sendFiles()
|
||||
go recvFiles()
|
||||
go recvFiles()
|
||||
go recvFiles()
|
||||
go recvFiles()
|
||||
go getFiles()
|
||||
go getFiles()
|
||||
go getFiles()
|
||||
go getFiles()
|
||||
go setFiles()
|
||||
go setFiles()
|
||||
go setFiles()
|
||||
go setFiles()
|
||||
go canSend()
|
||||
go canSend()
|
||||
go canSend()
|
||||
go canSend()
|
||||
}
|
||||
|
||||
type stubStreamReadWriter struct {
|
||||
readStreamFilename string
|
||||
readStreamErr error
|
||||
|
||||
writeStreamFilename string
|
||||
writeStreamErr error
|
||||
}
|
||||
|
||||
func (s *stubStreamReadWriter) ReadStream(filename string, _ streamer.WriteChunkStream, _ uint, _ bool) error {
|
||||
s.readStreamFilename = filename
|
||||
return s.readStreamErr
|
||||
}
|
||||
|
||||
func (s *stubStreamReadWriter) WriteStream(filename string, _ streamer.ReadChunkStream, _ bool) error {
|
||||
s.writeStreamFilename = filename
|
||||
return s.writeStreamErr
|
||||
}
|
||||
|
||||
type fakeRecvFilesStream struct {
|
||||
msgs []*pb.FileTransferMessage
|
||||
pos int
|
||||
recvErr error
|
||||
}
|
||||
|
||||
func (s *fakeRecvFilesStream) Recv() (*pb.FileTransferMessage, error) {
|
||||
if s.recvErr != nil {
|
||||
return nil, s.recvErr
|
||||
}
|
||||
|
||||
if s.pos < len(s.msgs) {
|
||||
s.pos++
|
||||
return s.msgs[s.pos-1], nil
|
||||
}
|
||||
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
type dummyStreamReadWriter struct{}
|
||||
|
||||
func (s *dummyStreamReadWriter) ReadStream(_ string, _ streamer.WriteChunkStream, _ uint, _ bool) error {
|
||||
panic("dummy")
|
||||
}
|
||||
|
||||
func (s *dummyStreamReadWriter) WriteStream(_ string, _ streamer.ReadChunkStream, _ bool) error {
|
||||
panic("dummy")
|
||||
}
|
||||
146
debugd/internal/filetransfer/streamer/streamer.go
Normal file
146
debugd/internal/filetransfer/streamer/streamer.go
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package streamer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
pb "github.com/edgelesssys/constellation/v2/debugd/service"
|
||||
"github.com/schollz/progressbar/v3"
|
||||
"github.com/spf13/afero"
|
||||
)
|
||||
|
||||
// FileStreamer handles reading and writing of a file using a stream of chunks.
|
||||
type FileStreamer struct {
|
||||
fs afero.Fs
|
||||
mux sync.RWMutex
|
||||
}
|
||||
|
||||
// ReadChunkStream is abstraction over a gRPC stream that allows us to receive chunks via gRPC.
|
||||
type ReadChunkStream interface {
|
||||
Recv() (*pb.Chunk, error)
|
||||
}
|
||||
|
||||
// WriteChunkStream is abstraction over a gRPC stream that allows us to send chunks via gRPC.
|
||||
type WriteChunkStream interface {
|
||||
Send(chunk *pb.Chunk) error
|
||||
}
|
||||
|
||||
// New creates a new FileStreamer.
|
||||
func New(fs afero.Fs) *FileStreamer {
|
||||
return &FileStreamer{
|
||||
fs: fs,
|
||||
mux: sync.RWMutex{},
|
||||
}
|
||||
}
|
||||
|
||||
// WriteStream opens a file to write to and streams chunks from a gRPC stream into the file.
|
||||
func (f *FileStreamer) WriteStream(filename string, stream ReadChunkStream, showProgress bool) error {
|
||||
f.mux.Lock()
|
||||
defer f.mux.Unlock()
|
||||
file, err := f.fs.OpenFile(filename, os.O_WRONLY|os.O_CREATE, os.ModePerm)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open %v for writing: %w", filename, err)
|
||||
}
|
||||
defer file.Close()
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("performing stat on %v to get the file size: %w", filename, err)
|
||||
}
|
||||
|
||||
var bar *progressbar.ProgressBar
|
||||
if showProgress {
|
||||
bar = newProgressBar(stat.Size())
|
||||
defer bar.Close()
|
||||
}
|
||||
|
||||
return writeInner(file, stream, bar)
|
||||
}
|
||||
|
||||
// ReadStream opens a file to read from and streams its contents chunkwise over gRPC.
|
||||
func (f *FileStreamer) ReadStream(filename string, stream WriteChunkStream, chunksize uint, showProgress bool) error {
|
||||
if chunksize == 0 {
|
||||
return errors.New("invalid chunksize")
|
||||
}
|
||||
// fail if file is currently RW locked
|
||||
if f.mux.TryRLock() {
|
||||
defer f.mux.RUnlock()
|
||||
} else {
|
||||
return errors.New("a file is opened for writing so cannot read at this time")
|
||||
}
|
||||
file, err := f.fs.OpenFile(filename, os.O_RDONLY, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open %v for reading: %w", filename, err)
|
||||
}
|
||||
defer file.Close()
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("performing stat on %v to get the file size: %w", filename, err)
|
||||
}
|
||||
|
||||
var bar *progressbar.ProgressBar
|
||||
if showProgress {
|
||||
bar = newProgressBar(stat.Size())
|
||||
defer bar.Close()
|
||||
}
|
||||
|
||||
return readInner(file, stream, chunksize, bar)
|
||||
}
|
||||
|
||||
// readInner reads from a an io.Reader and sends chunks over a gRPC stream.
|
||||
func readInner(fp io.Reader, stream WriteChunkStream, chunksize uint, bar *progressbar.ProgressBar) error {
|
||||
buf := make([]byte, chunksize)
|
||||
for {
|
||||
n, readErr := fp.Read(buf)
|
||||
isLast := errors.Is(readErr, io.EOF)
|
||||
if readErr != nil && !isLast {
|
||||
return fmt.Errorf("reading file chunk: %w", readErr)
|
||||
}
|
||||
if err := stream.Send(&pb.Chunk{Content: buf[:n], Last: isLast}); err != nil {
|
||||
return fmt.Errorf("sending chunk: %w", err)
|
||||
}
|
||||
if bar != nil {
|
||||
_ = bar.Add(n)
|
||||
}
|
||||
if isLast {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// writeInner writes chunks from a gRPC stream to an io.Writer.
|
||||
func writeInner(fp io.Writer, stream ReadChunkStream, bar *progressbar.ProgressBar) error {
|
||||
for {
|
||||
chunk, recvErr := stream.Recv()
|
||||
if recvErr != nil {
|
||||
return fmt.Errorf("reading stream: %w", recvErr)
|
||||
}
|
||||
if _, err := fp.Write(chunk.Content); err != nil {
|
||||
return fmt.Errorf("writing chunk to disk: %w", err)
|
||||
}
|
||||
if bar != nil {
|
||||
_ = bar.Add(len(chunk.Content))
|
||||
}
|
||||
if chunk.Last {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newProgressBar creates a new progress bar.
|
||||
func newProgressBar(size int64) *progressbar.ProgressBar {
|
||||
return progressbar.NewOptions64(
|
||||
size,
|
||||
progressbar.OptionSetDescription("transferring file"),
|
||||
progressbar.OptionShowBytes(true),
|
||||
progressbar.OptionClearOnFinish(),
|
||||
)
|
||||
}
|
||||
217
debugd/internal/filetransfer/streamer/streamer_test.go
Normal file
217
debugd/internal/filetransfer/streamer/streamer_test.go
Normal file
|
|
@ -0,0 +1,217 @@
|
|||
/*
|
||||
Copyright (c) Edgeless Systems GmbH
|
||||
|
||||
SPDX-License-Identifier: AGPL-3.0-only
|
||||
*/
|
||||
|
||||
package streamer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
pb "github.com/edgelesssys/constellation/v2/debugd/service"
|
||||
"github.com/spf13/afero"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestWriteStream(t *testing.T) {
|
||||
filename := "testfile"
|
||||
|
||||
testCases := map[string]struct {
|
||||
readChunkStream fakeReadChunkStream
|
||||
fs afero.Fs
|
||||
showProgress bool
|
||||
wantFile []byte
|
||||
wantErr bool
|
||||
}{
|
||||
"stream works": {
|
||||
readChunkStream: fakeReadChunkStream{
|
||||
chunks: [][]byte{
|
||||
[]byte("test"),
|
||||
},
|
||||
},
|
||||
fs: afero.NewMemMapFs(),
|
||||
wantFile: []byte("test"),
|
||||
wantErr: false,
|
||||
},
|
||||
"chunking works": {
|
||||
readChunkStream: fakeReadChunkStream{
|
||||
chunks: [][]byte{
|
||||
[]byte("te"),
|
||||
[]byte("st"),
|
||||
},
|
||||
},
|
||||
fs: afero.NewMemMapFs(),
|
||||
wantFile: []byte("test"),
|
||||
wantErr: false,
|
||||
},
|
||||
"showProgress works": {
|
||||
readChunkStream: fakeReadChunkStream{
|
||||
chunks: [][]byte{
|
||||
[]byte("test"),
|
||||
},
|
||||
},
|
||||
fs: afero.NewMemMapFs(),
|
||||
showProgress: true,
|
||||
wantFile: []byte("test"),
|
||||
wantErr: false,
|
||||
},
|
||||
"Open fails": {
|
||||
fs: afero.NewReadOnlyFs(afero.NewMemMapFs()),
|
||||
wantErr: true,
|
||||
},
|
||||
"recv fails": {
|
||||
readChunkStream: fakeReadChunkStream{
|
||||
recvErr: errors.New("someErr"),
|
||||
},
|
||||
fs: afero.NewMemMapFs(),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
writer := New(tc.fs)
|
||||
err := writer.WriteStream(filename, &tc.readChunkStream, tc.showProgress)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
fileContents, err := afero.ReadFile(tc.fs, filename)
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantFile, fileContents)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadStream(t *testing.T) {
|
||||
correctFilename := "testfile"
|
||||
eof := []byte{}
|
||||
|
||||
testCases := map[string]struct {
|
||||
writeChunkStream stubWriteChunkStream
|
||||
filename string
|
||||
chunksize uint
|
||||
showProgress bool
|
||||
wantChunks [][]byte
|
||||
wantErr bool
|
||||
}{
|
||||
"stream works": {
|
||||
writeChunkStream: stubWriteChunkStream{},
|
||||
filename: correctFilename,
|
||||
chunksize: 4,
|
||||
wantChunks: [][]byte{
|
||||
[]byte("test"),
|
||||
eof,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
"chunking works": {
|
||||
writeChunkStream: stubWriteChunkStream{},
|
||||
filename: correctFilename,
|
||||
chunksize: 2,
|
||||
wantChunks: [][]byte{
|
||||
[]byte("te"),
|
||||
[]byte("st"),
|
||||
eof,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
"chunksize of 0 detected": {
|
||||
writeChunkStream: stubWriteChunkStream{},
|
||||
filename: correctFilename,
|
||||
chunksize: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
"showProgress works": {
|
||||
writeChunkStream: stubWriteChunkStream{},
|
||||
filename: correctFilename,
|
||||
chunksize: 4,
|
||||
showProgress: true,
|
||||
wantChunks: [][]byte{
|
||||
[]byte("test"),
|
||||
eof,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
"Open fails": {
|
||||
filename: "incorrect-filename",
|
||||
chunksize: 4,
|
||||
wantErr: true,
|
||||
},
|
||||
"send fails": {
|
||||
writeChunkStream: stubWriteChunkStream{
|
||||
sendErr: errors.New("someErr"),
|
||||
},
|
||||
filename: correctFilename,
|
||||
chunksize: 4,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, tc := range testCases {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
require := require.New(t)
|
||||
|
||||
fs := afero.NewMemMapFs()
|
||||
assert.NoError(afero.WriteFile(fs, correctFilename, []byte("test"), 0o755))
|
||||
reader := New(fs)
|
||||
err := reader.ReadStream(tc.filename, &tc.writeChunkStream, tc.chunksize, tc.showProgress)
|
||||
|
||||
if tc.wantErr {
|
||||
assert.Error(err)
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.Equal(tc.wantChunks, tc.writeChunkStream.chunks)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type fakeReadChunkStream struct {
|
||||
chunks [][]byte
|
||||
pos int
|
||||
recvErr error
|
||||
}
|
||||
|
||||
func (s *fakeReadChunkStream) Recv() (*pb.Chunk, error) {
|
||||
if s.recvErr != nil {
|
||||
return nil, s.recvErr
|
||||
}
|
||||
|
||||
isLastChunk := s.pos == len(s.chunks)-1
|
||||
|
||||
if s.pos < len(s.chunks) {
|
||||
result := &pb.Chunk{Content: s.chunks[s.pos], Last: isLastChunk}
|
||||
s.pos++
|
||||
return result, nil
|
||||
}
|
||||
|
||||
return nil, io.EOF
|
||||
}
|
||||
|
||||
type stubWriteChunkStream struct {
|
||||
chunks [][]byte
|
||||
sendErr error
|
||||
}
|
||||
|
||||
func (s *stubWriteChunkStream) Send(chunk *pb.Chunk) error {
|
||||
cpy := make([]byte, len(chunk.Content))
|
||||
copy(cpy, chunk.Content)
|
||||
s.chunks = append(s.chunks, cpy)
|
||||
return s.sendErr
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue