constellation/debugd/internal/filetransfer/filetransfer_test.go
2023-09-19 16:10:22 +02:00

399 lines
9.3 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package filetransfer
import (
"errors"
"io"
"testing"
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer/streamer"
pb "github.com/edgelesssys/constellation/v2/debugd/service"
"github.com/edgelesssys/constellation/v2/internal/logger"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestSendFiles(t *testing.T) {
testCases := map[string]struct {
files *[]FileStat
receiveFinished bool
sendErr error
readStreamErr error
wantHeaders []*pb.FileTransferMessage
wantErr bool
}{
"can send files": {
files: &[]FileStat{
{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
{
TargetPath: "testfileB",
Mode: 0o644,
OverrideServiceUnit: "somesvcB",
},
},
receiveFinished: true,
wantHeaders: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
},
},
},
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileB",
Mode: 0o644,
OverrideServiceUnit: func() *string { s := "somesvcB"; return &s }(),
},
},
},
},
},
"not finished receiving": {
files: &[]FileStat{
{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
{
TargetPath: "testfileB",
Mode: 0o644,
OverrideServiceUnit: "somesvcB",
},
},
receiveFinished: false,
wantErr: true,
},
"send fails": {
files: &[]FileStat{
{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
},
receiveFinished: true,
sendErr: errors.New("send failed"),
wantErr: true,
},
"read stream fails": {
files: &[]FileStat{
{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
},
receiveFinished: true,
readStreamErr: errors.New("read stream failed"),
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
streamer := &stubStreamReadWriter{readStreamErr: tc.readStreamErr}
stream := &stubSendFilesStream{sendErr: tc.sendErr}
transfer := &FileTransferer{
log: logger.NewTest(t),
streamer: streamer,
showProgress: false,
}
if tc.files != nil {
transfer.files = *tc.files
}
transfer.receiveFinished.Store(tc.receiveFinished)
err := transfer.SendFiles(stream)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantHeaders, stream.msgs)
})
}
}
func TestRecvFiles(t *testing.T) {
testCases := map[string]struct {
msgs []*pb.FileTransferMessage
recvAlreadyStarted bool
recvAlreadyFinished bool
recvErr error
writeStreamErr error
wantFiles []FileStat
wantErr bool
}{
"can recv files": {
msgs: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
},
},
},
// Chunk messages left out since they would be consumed by the streamReadWriter
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileB",
Mode: 0o644,
},
},
},
// Chunk messages left out since they would be consumed by the streamReadWriter
},
wantFiles: []FileStat{
{
SourcePath: "testfileA",
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
{
SourcePath: "testfileB",
TargetPath: "testfileB",
Mode: 0o644,
},
},
},
"no messages": {},
"recv fails": {
recvErr: errors.New("recv failed"),
wantErr: true,
},
"first recv does not yield file header": {
msgs: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Chunk{},
},
},
wantErr: true,
},
"write stream fails": {
msgs: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
},
},
},
// Chunk messages left out since they would be consumed by the streamReadWriter
},
writeStreamErr: errors.New("write stream failed"),
wantErr: true,
},
"recv has already started": {
msgs: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
},
},
},
// Chunk messages left out since they would be consumed by the streamReadWriter
},
recvAlreadyStarted: true,
wantErr: true,
},
"recv has already finished": {
msgs: []*pb.FileTransferMessage{
{
Kind: &pb.FileTransferMessage_Header{
Header: &pb.FileTransferHeader{
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
},
},
},
// Chunk messages left out since they would be consumed by the streamReadWriter
},
recvAlreadyFinished: true,
wantErr: true,
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
require := require.New(t)
streamer := &stubStreamReadWriter{writeStreamErr: tc.writeStreamErr}
stream := &fakeRecvFilesStream{msgs: tc.msgs, recvErr: tc.recvErr}
transfer := New(logger.NewTest(t), streamer, false)
if tc.recvAlreadyStarted {
transfer.receiveStarted = true
}
if tc.recvAlreadyFinished {
transfer.receiveFinished.Store(true)
}
err := transfer.RecvFiles(stream)
if tc.wantErr {
assert.Error(err)
return
}
require.NoError(err)
assert.Equal(tc.wantFiles, transfer.files)
})
}
}
func TestGetSetFiles(t *testing.T) {
testCases := map[string]struct {
setFiles *[]FileStat
wantFiles []FileStat
wantErr bool
}{
"no files": {
wantFiles: []FileStat{},
},
"files": {
setFiles: &[]FileStat{
{
SourcePath: "testfileA",
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
},
wantFiles: []FileStat{
{
SourcePath: "testfileA",
TargetPath: "testfileA",
Mode: 0o644,
OverrideServiceUnit: "somesvcA",
},
},
},
}
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
streamer := &dummyStreamReadWriter{}
transfer := New(logger.NewTest(t), streamer, false)
if tc.setFiles != nil {
transfer.SetFiles(*tc.setFiles)
}
gotFiles := transfer.GetFiles()
assert.Equal(tc.wantFiles, gotFiles)
assert.Equal(tc.setFiles != nil, transfer.receiveFinished.Load())
})
}
}
func TestConcurrency(t *testing.T) {
ft := New(logger.NewTest(t), &stubStreamReadWriter{}, false)
sendFiles := func() {
_ = ft.SendFiles(&stubSendFilesStream{})
}
recvFiles := func() {
_ = ft.RecvFiles(&stubRecvFilesStream{})
}
getFiles := func() {
_ = ft.GetFiles()
}
setFiles := func() {
ft.SetFiles([]FileStat{{SourcePath: "file", TargetPath: "file", Mode: 0o644}})
}
go sendFiles()
go sendFiles()
go sendFiles()
go sendFiles()
go recvFiles()
go recvFiles()
go recvFiles()
go recvFiles()
go getFiles()
go getFiles()
go getFiles()
go getFiles()
go setFiles()
go setFiles()
go setFiles()
go setFiles()
}
type stubStreamReadWriter struct {
readStreamErr error
writeStreamErr error
}
func (s *stubStreamReadWriter) ReadStream(_ string, _ streamer.WriteChunkStream, _ uint, _ bool) error {
return s.readStreamErr
}
func (s *stubStreamReadWriter) WriteStream(_ string, _ streamer.ReadChunkStream, _ bool) error {
return s.writeStreamErr
}
type fakeRecvFilesStream struct {
msgs []*pb.FileTransferMessage
pos int
recvErr error
}
func (s *fakeRecvFilesStream) Recv() (*pb.FileTransferMessage, error) {
if s.recvErr != nil {
return nil, s.recvErr
}
if s.pos < len(s.msgs) {
s.pos++
return s.msgs[s.pos-1], nil
}
return nil, io.EOF
}
type dummyStreamReadWriter struct{}
func (s *dummyStreamReadWriter) ReadStream(_ string, _ streamer.WriteChunkStream, _ uint, _ bool) error {
panic("dummy")
}
func (s *dummyStreamReadWriter) WriteStream(_ string, _ streamer.ReadChunkStream, _ bool) error {
panic("dummy")
}