debugd: implement upload of multiple binaries

This commit is contained in:
Malte Poll 2023-01-20 10:11:41 +01:00 committed by Malte Poll
parent e6ac8e2a91
commit 6f56ed69f8
21 changed files with 2040 additions and 661 deletions

View file

@ -0,0 +1,47 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package filetransfer
import (
"errors"
pb "github.com/edgelesssys/constellation/v2/debugd/service"
)
// recvChunkStream is a wrapper around a RecvFilesStream that only returns chunks.
type recvChunkStream struct {
stream RecvFilesStream
}
// Recv receives a FileTransferMessage and returns the chunk.
func (s *recvChunkStream) Recv() (*pb.Chunk, error) {
msg, err := s.stream.Recv()
if err != nil {
return nil, err
}
chunk := msg.GetChunk()
if chunk == nil {
return nil, errors.New("expected chunk")
}
return chunk, nil
}
// sendChunkStream is a wrapper around a SendFilesStream that wraps chunks for every message.
type sendChunkStream struct {
stream SendFilesStream
}
// Send wraps the given chunk in a FileTransferMessage and sends it.
func (s *sendChunkStream) Send(chunk *pb.Chunk) error {
chunkMessage := &pb.FileTransferMessage_Chunk{
Chunk: chunk,
}
message := &pb.FileTransferMessage{
Kind: chunkMessage,
}
return s.stream.Send(message)
}

View file

@ -0,0 +1,134 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package filetransfer
import (
"errors"
"testing"
pb "github.com/edgelesssys/constellation/v2/debugd/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRecv(t *testing.T) {
testCases := map[string]struct {
stream stubRecvFilesStream
wantChunk *pb.Chunk
wantErr bool
}{
"chunk is received": {
stream: stubRecvFilesStream{
msg: &pb.FileTransferMessage{
Kind: &pb.FileTransferMessage_Chunk{
Chunk: &pb.Chunk{
Content: []byte("test"),
},
},
},
},
wantChunk: &pb.Chunk{
Content: []byte("test"),
},
},
"wrong type": {
stream: stubRecvFilesStream{
msg: &pb.FileTransferMessage{
Kind: &pb.FileTransferMessage_Header{},
},
},
wantErr: true,
},
"empty msg": {
stream: stubRecvFilesStream{},
wantErr: true,
},
"recv fails": {
stream: stubRecvFilesStream{
recvErr: errors.New("someErr"),
},
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
stream := recvChunkStream{stream: &tc.stream}
chunk, err := stream.Recv()
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantChunk, chunk)
})
}
}
func TestSend(t *testing.T) {
testCases := map[string]struct {
stream stubSendFilesStream
chunk *pb.Chunk
wantMsgs []*pb.FileTransferMessage
wantErr bool
}{
"chunk is wrapped correctly": {
chunk: &pb.Chunk{
Content: []byte("test"),
},
wantMsgs: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Chunk{
Chunk: &pb.Chunk{
Content: []byte("test"),
},
},
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
stream := sendChunkStream{stream: &tc.stream}
err := stream.Send(tc.chunk)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.EqualValues(tc.wantMsgs, tc.stream.msgs)
})
}
}
type stubRecvFilesStream struct {
msg *pb.FileTransferMessage
recvErr error
}
func (s *stubRecvFilesStream) Recv() (*pb.FileTransferMessage, error) {
return s.msg, s.recvErr
}
type stubSendFilesStream struct {
msgs []*pb.FileTransferMessage
sendErr error
}
func (s *stubSendFilesStream) Send(msg *pb.FileTransferMessage) error {
s.msgs = append(s.msgs, msg)
return s.sendErr
}

View file

@ -0,0 +1,233 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package filetransfer
import (
"errors"
"io"
"io/fs"
"sync"
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer/streamer"
pb "github.com/edgelesssys/constellation/v2/debugd/service"
"github.com/edgelesssys/constellation/v2/internal/logger"
"go.uber.org/zap"
)
// RecvFilesStream is a stream that receives FileTransferMessages.
type RecvFilesStream interface {
Recv() (*pb.FileTransferMessage, error)
}
// SendFilesStream is a stream that sends FileTransferMessages.
type SendFilesStream interface {
Send(*pb.FileTransferMessage) error
}
// FileTransferer manages sending and receiving of files.
type FileTransferer struct {
mux sync.RWMutex
log *logger.Logger
receiveStarted bool
receiveFinished bool
files []FileStat
streamer streamReadWriter
showProgress bool
}
// New creates a new FileTransferer.
func New(log *logger.Logger, streamer streamReadWriter, showProgress bool) *FileTransferer {
return &FileTransferer{
log: log,
streamer: streamer,
showProgress: showProgress,
}
}
// SendFiles sends files to the given stream.
func (s *FileTransferer) SendFiles(stream SendFilesStream) error {
s.mux.RLock()
defer s.mux.RUnlock()
if !s.receiveFinished {
return errors.New("cannot send files before receiving them")
}
for _, file := range s.files {
if err := s.handleFileSend(stream, file); err != nil {
return err
}
}
return nil
}
// RecvFiles receives files from the given stream.
func (s *FileTransferer) RecvFiles(stream RecvFilesStream) (err error) {
if err := s.startRecv(); err != nil {
return err
}
defer func() {
if err != nil {
s.abortRecv()
} else {
s.finishRecv()
}
}()
var done bool
for !done && err == nil {
done, err = s.handleFileRecv(stream)
}
return err
}
// GetFiles returns the a copy of the list of files that have been received.
func (s *FileTransferer) GetFiles() []FileStat {
s.mux.RLock()
defer s.mux.RUnlock()
res := make([]FileStat, len(s.files))
copy(res, s.files)
return res
}
// SetFiles sets the list of files that can be sent.
func (s *FileTransferer) SetFiles(files []FileStat) {
s.mux.Lock()
defer s.mux.Unlock()
res := make([]FileStat, len(files))
copy(res, files)
s.files = res
s.receiveFinished = true
}
// CanSend returns true if the file receive has finished.
// This is called to determine if a debugd instance can request files from this server.
func (s *FileTransferer) CanSend() bool {
s.mux.RLock()
defer s.mux.RUnlock()
ret := s.receiveFinished
return ret
}
func (s *FileTransferer) handleFileSend(stream SendFilesStream, file FileStat) error {
header := &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: file.TargetPath,
Mode: uint32(file.Mode),
},
}
if file.OverrideServiceUnit != "" {
header.Header.OverrideServiceUnit = &file.OverrideServiceUnit
}
if err := stream.Send(&pb.FileTransferMessage{Kind: header}); err != nil {
return err
}
sendChunkStream := &sendChunkStream{stream: stream}
return s.streamer.ReadStream(file.SourcePath, sendChunkStream, debugd.Chunksize, s.showProgress)
}
// handleFileRecv handles the file receive of a single file.
// It returns true if the stream is finished (all of the file consumed) and false otherwise.
func (s *FileTransferer) handleFileRecv(stream RecvFilesStream) (bool, error) {
// first message must be a header message
msg, err := stream.Recv()
switch {
case err == nil:
// nop
case errors.Is(err, io.EOF):
return true, nil // stream is finished
default:
return false, err
}
header := msg.GetHeader()
if header == nil {
return false, errors.New("first message must be a header message")
}
s.log.Infof("Starting file receive of %q", header.TargetPath)
s.addFile(FileStat{
SourcePath: header.TargetPath,
TargetPath: header.TargetPath,
Mode: fs.FileMode(header.Mode),
OverrideServiceUnit: func() string {
if header.OverrideServiceUnit != nil {
return *header.OverrideServiceUnit
}
return ""
}(),
})
recvChunkStream := &recvChunkStream{stream: stream}
if err := s.streamer.WriteStream(header.TargetPath, recvChunkStream, s.showProgress); err != nil {
s.log.With(zap.Error(err)).Errorf("Receive of file %q failed", header.TargetPath)
return false, err
}
s.log.Infof("Finished file receive of %q", header.TargetPath)
return false, nil
}
// startRecv marks the file receive as started. It returns an error if receiving has already started.
func (s *FileTransferer) startRecv() error {
s.mux.Lock()
defer s.mux.Unlock()
switch {
case s.receiveFinished:
return ErrReceiveFinished
case s.receiveStarted:
return ErrReceiveRunning
}
s.receiveStarted = true
return nil
}
// abortRecv marks the file receive as failed.
// This allows for a retry of the file receive.
func (s *FileTransferer) abortRecv() {
s.mux.Lock()
defer s.mux.Unlock()
s.receiveStarted = false
s.files = nil
}
// finishRecv marks the file receive as completed.
// This allows other debugd instances to request files from this server.
func (s *FileTransferer) finishRecv() {
s.mux.Lock()
defer s.mux.Unlock()
s.receiveStarted = false
s.receiveFinished = true
}
// addFile adds a file to the list of received files.
func (s *FileTransferer) addFile(file FileStat) {
s.mux.Lock()
defer s.mux.Unlock()
s.files = append(s.files, file)
}
// FileStat contains the metadata of a file that can be up/downloaded.
type FileStat struct {
SourcePath string
TargetPath string
Mode fs.FileMode
OverrideServiceUnit string // optional name of the service unit to override
}
var (
// ErrReceiveRunning is returned if a file receive is already running.
ErrReceiveRunning = errors.New("receive already running")
// ErrReceiveFinished is returned if a file receive has already finished.
ErrReceiveFinished = errors.New("receive finished")
)
const (
// ShowProgress indicates that progress should be shown.
ShowProgress = true
// DontShowProgress indicates that progress should not be shown.
DontShowProgress = false
)
type streamReadWriter interface {
WriteStream(filename string, stream streamer.ReadChunkStream, showProgress bool) error
ReadStream(filename string, stream streamer.WriteChunkStream, chunksize uint, showProgress bool) error
}

View file

@ -0,0 +1,411 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package filetransfer
import (
"errors"
"io"
"testing"
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer/streamer"
pb "github.com/edgelesssys/constellation/v2/debugd/service"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestSendFiles(t *testing.T) {
testCases := map[string]struct {
files *[]FileStat
sendErr error
readStreamErr error
wantHeaders []*pb.FileTransferMessage
wantErr bool
}{
"can send files": {
files: &[]FileStat{
{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
{
TargetPath: "testfileB",
Mode: 0o644,
OverrideServiceUnit: "somesvcB",
},
},
wantHeaders: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
},
},
},
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileB",
Mode: 0o644,
OverrideServiceUnit: func() *string { s := "somesvcB"; return &s }(),
},
},
},
},
},
"no files set": {
wantErr: true,
},
"send fails": {
files: &[]FileStat{
{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
},
sendErr: errors.New("send failed"),
wantErr: true,
},
"read stream fails": {
files: &[]FileStat{
{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
},
readStreamErr: errors.New("read stream failed"),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
streamer := &stubStreamReadWriter{readStreamErr: tc.readStreamErr}
stream := &stubSendFilesStream{sendErr: tc.sendErr}
transfer := New(logger.NewTest(t), streamer, false)
if tc.files != nil {
transfer.SetFiles(*tc.files)
}
err := transfer.SendFiles(stream)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantHeaders, stream.msgs)
})
}
}
func TestRecvFiles(t *testing.T) {
testCases := map[string]struct {
msgs []*pb.FileTransferMessage
recvAlreadyStarted bool
recvAlreadyFinished bool
recvErr error
writeStreamErr error
wantFiles []FileStat
wantErr bool
}{
"can recv files": {
msgs: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
},
},
},
// Chunk messages left out since they would be consumed by the streamReadWriter
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileB",
Mode: 0o644,
},
},
},
// Chunk messages left out since they would be consumed by the streamReadWriter
},
wantFiles: []FileStat{
{
SourcePath: "testfileA",
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
{
SourcePath: "testfileB",
TargetPath: "testfileB",
Mode: 0o644,
},
},
},
"no messages": {},
"recv fails": {
recvErr: errors.New("recv failed"),
wantErr: true,
},
"first recv does not yield file header": {
msgs: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Chunk{},
},
},
wantErr: true,
},
"write stream fails": {
msgs: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
},
},
},
// Chunk messages left out since they would be consumed by the streamReadWriter
},
writeStreamErr: errors.New("write stream failed"),
wantErr: true,
},
"recv has already started": {
msgs: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
},
},
},
// Chunk messages left out since they would be consumed by the streamReadWriter
},
recvAlreadyStarted: true,
wantErr: true,
},
"recv has already finished": {
msgs: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
},
},
},
// Chunk messages left out since they would be consumed by the streamReadWriter
},
recvAlreadyFinished: true,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
streamer := &stubStreamReadWriter{writeStreamErr: tc.writeStreamErr}
stream := &fakeRecvFilesStream{msgs: tc.msgs, recvErr: tc.recvErr}
transfer := New(logger.NewTest(t), streamer, false)
if tc.recvAlreadyStarted {
transfer.receiveStarted = true
}
if tc.recvAlreadyFinished {
transfer.receiveFinished = true
}
err := transfer.RecvFiles(stream)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantFiles, transfer.files)
})
}
}
func TestGetSetFiles(t *testing.T) {
testCases := map[string]struct {
setFiles *[]FileStat
wantFiles []FileStat
wantErr bool
}{
"no files": {
wantFiles: []FileStat{},
},
"files": {
setFiles: &[]FileStat{
{
SourcePath: "testfileA",
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
},
wantFiles: []FileStat{
{
SourcePath: "testfileA",
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
streamer := &dummyStreamReadWriter{}
transfer := New(logger.NewTest(t), streamer, false)
if tc.setFiles != nil {
transfer.SetFiles(*tc.setFiles)
}
gotFiles := transfer.GetFiles()
assert.Equal(tc.wantFiles, gotFiles)
assert.Equal(tc.setFiles != nil, transfer.receiveFinished)
})
}
}
func TestCanSend(t *testing.T) {
assert := assert.New(t)
streamer := &stubStreamReadWriter{}
stream := &stubRecvFilesStream{recvErr: io.EOF}
transfer := New(logger.NewTest(t), streamer, false)
assert.False(transfer.CanSend())
// manual set
transfer.SetFiles(nil)
assert.True(transfer.CanSend())
// reset
transfer.receiveStarted = false
transfer.receiveFinished = false
transfer.files = nil
assert.False(transfer.CanSend())
// receive files (empty)
assert.NoError(transfer.RecvFiles(stream))
assert.True(transfer.CanSend())
}
func TestConcurrency(t *testing.T) {
ft := New(logger.NewTest(t), &stubStreamReadWriter{}, false)
sendFiles := func() {
_ = ft.SendFiles(&stubSendFilesStream{})
}
recvFiles := func() {
_ = ft.RecvFiles(&stubRecvFilesStream{})
}
getFiles := func() {
_ = ft.GetFiles()
}
setFiles := func() {
ft.SetFiles([]FileStat{{SourcePath: "file", TargetPath: "file", Mode: 0o644}})
}
canSend := func() {
_ = ft.CanSend()
}
go sendFiles()
go sendFiles()
go sendFiles()
go sendFiles()
go recvFiles()
go recvFiles()
go recvFiles()
go recvFiles()
go getFiles()
go getFiles()
go getFiles()
go getFiles()
go setFiles()
go setFiles()
go setFiles()
go setFiles()
go canSend()
go canSend()
go canSend()
go canSend()
}
type stubStreamReadWriter struct {
readStreamFilename string
readStreamErr error
writeStreamFilename string
writeStreamErr error
}
func (s *stubStreamReadWriter) ReadStream(filename string, _ streamer.WriteChunkStream, _ uint, _ bool) error {
s.readStreamFilename = filename
return s.readStreamErr
}
func (s *stubStreamReadWriter) WriteStream(filename string, _ streamer.ReadChunkStream, _ bool) error {
s.writeStreamFilename = filename
return s.writeStreamErr
}
type fakeRecvFilesStream struct {
msgs []*pb.FileTransferMessage
pos int
recvErr error
}
func (s *fakeRecvFilesStream) Recv() (*pb.FileTransferMessage, error) {
if s.recvErr != nil {
return nil, s.recvErr
}
if s.pos < len(s.msgs) {
s.pos++
return s.msgs[s.pos-1], nil
}
return nil, io.EOF
}
type dummyStreamReadWriter struct{}
func (s *dummyStreamReadWriter) ReadStream(_ string, _ streamer.WriteChunkStream, _ uint, _ bool) error {
panic("dummy")
}
func (s *dummyStreamReadWriter) WriteStream(_ string, _ streamer.ReadChunkStream, _ bool) error {
panic("dummy")
}

View file

@ -0,0 +1,146 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package streamer
import (
"errors"
"fmt"
"io"
"os"
"sync"
pb "github.com/edgelesssys/constellation/v2/debugd/service"
"github.com/schollz/progressbar/v3"
"github.com/spf13/afero"
)
// FileStreamer handles reading and writing of a file using a stream of chunks.
type FileStreamer struct {
fs afero.Fs
mux sync.RWMutex
}
// ReadChunkStream is abstraction over a gRPC stream that allows us to receive chunks via gRPC.
type ReadChunkStream interface {
Recv() (*pb.Chunk, error)
}
// WriteChunkStream is abstraction over a gRPC stream that allows us to send chunks via gRPC.
type WriteChunkStream interface {
Send(chunk *pb.Chunk) error
}
// New creates a new FileStreamer.
func New(fs afero.Fs) *FileStreamer {
return &FileStreamer{
fs: fs,
mux: sync.RWMutex{},
}
}
// WriteStream opens a file to write to and streams chunks from a gRPC stream into the file.
func (f *FileStreamer) WriteStream(filename string, stream ReadChunkStream, showProgress bool) error {
f.mux.Lock()
defer f.mux.Unlock()
file, err := f.fs.OpenFile(filename, os.O_WRONLY|os.O_CREATE, os.ModePerm)
if err != nil {
return fmt.Errorf("open %v for writing: %w", filename, err)
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("performing stat on %v to get the file size: %w", filename, err)
}
var bar *progressbar.ProgressBar
if showProgress {
bar = newProgressBar(stat.Size())
defer bar.Close()
}
return writeInner(file, stream, bar)
}
// ReadStream opens a file to read from and streams its contents chunkwise over gRPC.
func (f *FileStreamer) ReadStream(filename string, stream WriteChunkStream, chunksize uint, showProgress bool) error {
if chunksize == 0 {
return errors.New("invalid chunksize")
}
// fail if file is currently RW locked
if f.mux.TryRLock() {
defer f.mux.RUnlock()
} else {
return errors.New("a file is opened for writing so cannot read at this time")
}
file, err := f.fs.OpenFile(filename, os.O_RDONLY, 0o755)
if err != nil {
return fmt.Errorf("open %v for reading: %w", filename, err)
}
defer file.Close()
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("performing stat on %v to get the file size: %w", filename, err)
}
var bar *progressbar.ProgressBar
if showProgress {
bar = newProgressBar(stat.Size())
defer bar.Close()
}
return readInner(file, stream, chunksize, bar)
}
// readInner reads from a an io.Reader and sends chunks over a gRPC stream.
func readInner(fp io.Reader, stream WriteChunkStream, chunksize uint, bar *progressbar.ProgressBar) error {
buf := make([]byte, chunksize)
for {
n, readErr := fp.Read(buf)
isLast := errors.Is(readErr, io.EOF)
if readErr != nil && !isLast {
return fmt.Errorf("reading file chunk: %w", readErr)
}
if err := stream.Send(&pb.Chunk{Content: buf[:n], Last: isLast}); err != nil {
return fmt.Errorf("sending chunk: %w", err)
}
if bar != nil {
_ = bar.Add(n)
}
if isLast {
return nil
}
}
}
// writeInner writes chunks from a gRPC stream to an io.Writer.
func writeInner(fp io.Writer, stream ReadChunkStream, bar *progressbar.ProgressBar) error {
for {
chunk, recvErr := stream.Recv()
if recvErr != nil {
return fmt.Errorf("reading stream: %w", recvErr)
}
if _, err := fp.Write(chunk.Content); err != nil {
return fmt.Errorf("writing chunk to disk: %w", err)
}
if bar != nil {
_ = bar.Add(len(chunk.Content))
}
if chunk.Last {
return nil
}
}
}
// newProgressBar creates a new progress bar.
func newProgressBar(size int64) *progressbar.ProgressBar {
return progressbar.NewOptions64(
size,
progressbar.OptionSetDescription("transferring file"),
progressbar.OptionShowBytes(true),
progressbar.OptionClearOnFinish(),
)
}

View file

@ -0,0 +1,217 @@
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package streamer
import (
"errors"
"io"
"testing"
pb "github.com/edgelesssys/constellation/v2/debugd/service"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestWriteStream(t *testing.T) {
filename := "testfile"
testCases := map[string]struct {
readChunkStream fakeReadChunkStream
fs afero.Fs
showProgress bool
wantFile []byte
wantErr bool
}{
"stream works": {
readChunkStream: fakeReadChunkStream{
chunks: [][]byte{
[]byte("test"),
},
},
fs: afero.NewMemMapFs(),
wantFile: []byte("test"),
wantErr: false,
},
"chunking works": {
readChunkStream: fakeReadChunkStream{
chunks: [][]byte{
[]byte("te"),
[]byte("st"),
},
},
fs: afero.NewMemMapFs(),
wantFile: []byte("test"),
wantErr: false,
},
"showProgress works": {
readChunkStream: fakeReadChunkStream{
chunks: [][]byte{
[]byte("test"),
},
},
fs: afero.NewMemMapFs(),
showProgress: true,
wantFile: []byte("test"),
wantErr: false,
},
"Open fails": {
fs: afero.NewReadOnlyFs(afero.NewMemMapFs()),
wantErr: true,
},
"recv fails": {
readChunkStream: fakeReadChunkStream{
recvErr: errors.New("someErr"),
},
fs: afero.NewMemMapFs(),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
writer := New(tc.fs)
err := writer.WriteStream(filename, &tc.readChunkStream, tc.showProgress)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
fileContents, err := afero.ReadFile(tc.fs, filename)
require.NoError(err)
assert.Equal(tc.wantFile, fileContents)
})
}
}
func TestReadStream(t *testing.T) {
correctFilename := "testfile"
eof := []byte{}
testCases := map[string]struct {
writeChunkStream stubWriteChunkStream
filename string
chunksize uint
showProgress bool
wantChunks [][]byte
wantErr bool
}{
"stream works": {
writeChunkStream: stubWriteChunkStream{},
filename: correctFilename,
chunksize: 4,
wantChunks: [][]byte{
[]byte("test"),
eof,
},
wantErr: false,
},
"chunking works": {
writeChunkStream: stubWriteChunkStream{},
filename: correctFilename,
chunksize: 2,
wantChunks: [][]byte{
[]byte("te"),
[]byte("st"),
eof,
},
wantErr: false,
},
"chunksize of 0 detected": {
writeChunkStream: stubWriteChunkStream{},
filename: correctFilename,
chunksize: 0,
wantErr: true,
},
"showProgress works": {
writeChunkStream: stubWriteChunkStream{},
filename: correctFilename,
chunksize: 4,
showProgress: true,
wantChunks: [][]byte{
[]byte("test"),
eof,
},
wantErr: false,
},
"Open fails": {
filename: "incorrect-filename",
chunksize: 4,
wantErr: true,
},
"send fails": {
writeChunkStream: stubWriteChunkStream{
sendErr: errors.New("someErr"),
},
filename: correctFilename,
chunksize: 4,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
fs := afero.NewMemMapFs()
assert.NoError(afero.WriteFile(fs, correctFilename, []byte("test"), 0o755))
reader := New(fs)
err := reader.ReadStream(tc.filename, &tc.writeChunkStream, tc.chunksize, tc.showProgress)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantChunks, tc.writeChunkStream.chunks)
})
}
}
type fakeReadChunkStream struct {
chunks [][]byte
pos int
recvErr error
}
func (s *fakeReadChunkStream) Recv() (*pb.Chunk, error) {
if s.recvErr != nil {
return nil, s.recvErr
}
isLastChunk := s.pos == len(s.chunks)-1
if s.pos < len(s.chunks) {
result := &pb.Chunk{Content: s.chunks[s.pos], Last: isLastChunk}
s.pos++
return result, nil
}
return nil, io.EOF
}
type stubWriteChunkStream struct {
chunks [][]byte
sendErr error
}
func (s *stubWriteChunkStream) Send(chunk *pb.Chunk) error {
cpy := make([]byte, len(chunk.Content))
copy(cpy, chunk.Content)
s.chunks = append(s.chunks, cpy)
return s.sendErr
}