mirror of
https://github.com/edgelesssys/constellation.git
synced 2025-02-25 01:10:16 -05:00
debugd: send requests over lb (#2346)
This commit is contained in:
parent
49c37b3969
commit
548bb2dfa6
@ -88,8 +88,8 @@ func (f *Fetcher) DiscoverDebugdIPs(ctx context.Context) ([]string, error) {
|
|||||||
return ips, nil
|
return ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DiscoverLoadbalancerIP gets load balancer IP from metadata API.
|
// DiscoverLoadBalancerIP gets load balancer IP from metadata API.
|
||||||
func (f *Fetcher) DiscoverLoadbalancerIP(ctx context.Context) (string, error) {
|
func (f *Fetcher) DiscoverLoadBalancerIP(ctx context.Context) (string, error) {
|
||||||
lbHost, _, err := f.metaAPI.GetLoadBalancerEndpoint(ctx)
|
lbHost, _, err := f.metaAPI.GetLoadBalancerEndpoint(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("retrieving load balancer endpoint: %w", err)
|
return "", fmt.Errorf("retrieving load balancer endpoint: %w", err)
|
||||||
|
@ -121,7 +121,7 @@ func TestDiscoverDebugIPs(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDiscoverLoadbalancerIP(t *testing.T) {
|
func TestDiscoverLoadBalancerIP(t *testing.T) {
|
||||||
ip := "192.0.2.1"
|
ip := "192.0.2.1"
|
||||||
someErr := errors.New("failed")
|
someErr := errors.New("failed")
|
||||||
|
|
||||||
@ -148,7 +148,7 @@ func TestDiscoverLoadbalancerIP(t *testing.T) {
|
|||||||
metaAPI: tc.metaAPI,
|
metaAPI: tc.metaAPI,
|
||||||
}
|
}
|
||||||
|
|
||||||
ip, err := fetcher.DiscoverLoadbalancerIP(context.Background())
|
ip, err := fetcher.DiscoverLoadBalancerIP(context.Background())
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
assert.Error(err)
|
assert.Error(err)
|
||||||
|
@ -19,6 +19,7 @@ import (
|
|||||||
// Fetcher retrieves other debugd IPs from cloud provider metadata.
|
// Fetcher retrieves other debugd IPs from cloud provider metadata.
|
||||||
type Fetcher interface {
|
type Fetcher interface {
|
||||||
DiscoverDebugdIPs(ctx context.Context) ([]string, error)
|
DiscoverDebugdIPs(ctx context.Context) ([]string, error)
|
||||||
|
DiscoverLoadBalancerIP(ctx context.Context) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scheduler schedules fetching of metadata using timers.
|
// Scheduler schedules fetching of metadata using timers.
|
||||||
@ -51,23 +52,35 @@ func (s *Scheduler) Start(ctx context.Context, wg *sync.WaitGroup) {
|
|||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
ips, err := s.fetcher.DiscoverDebugdIPs(ctx)
|
|
||||||
if err != nil {
|
|
||||||
s.log.With(zap.Error(err)).Warnf("Discovering debugd IPs failed")
|
|
||||||
}
|
|
||||||
if err == nil {
|
|
||||||
s.log.With(zap.Strings("ips", ips)).Infof("Discovered instances")
|
|
||||||
s.download(ctx, ips)
|
|
||||||
if s.deploymentDone && s.infoDone {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ips, err := s.fetcher.DiscoverDebugdIPs(ctx)
|
||||||
|
if err != nil {
|
||||||
|
s.log.With(zap.Error(err)).Warnf("Discovering debugd IPs failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
lbip, err := s.fetcher.DiscoverLoadBalancerIP(ctx)
|
||||||
|
if err != nil {
|
||||||
|
s.log.With(zap.Error(err)).Warnf("Discovering load balancer IP failed")
|
||||||
|
} else {
|
||||||
|
ips = append(ips, lbip)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(ips) == 0 {
|
||||||
|
s.log.With(zap.Error(err)).Warnf("No debugd IPs discovered")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
s.log.With(zap.Strings("ips", ips)).Infof("Discovered instances")
|
||||||
|
s.download(ctx, ips)
|
||||||
|
if s.deploymentDone && s.infoDone {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
@ -33,32 +33,47 @@ func TestSchedulerStart(t *testing.T) {
|
|||||||
wantInfoDownloads []string
|
wantInfoDownloads []string
|
||||||
}{
|
}{
|
||||||
"no errors occur": {
|
"no errors occur": {
|
||||||
fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}},
|
fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}, loadBalancerIP: "192.0.2.3"},
|
||||||
downloader: stubDownloader{},
|
downloader: stubDownloader{},
|
||||||
wantDiscoverCount: 1,
|
wantDiscoverCount: 2,
|
||||||
wantDeploymentDownloads: []string{"192.0.2.1"},
|
wantDeploymentDownloads: []string{"192.0.2.1"},
|
||||||
wantInfoDownloads: []string{"192.0.2.1"},
|
wantInfoDownloads: []string{"192.0.2.1"},
|
||||||
},
|
},
|
||||||
"download deployment fails": {
|
"no load balancer is discovered": {
|
||||||
fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}},
|
fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}},
|
||||||
downloader: stubDownloader{downloadDeploymentErrs: []error{someErr, someErr}},
|
downloader: stubDownloader{},
|
||||||
wantDiscoverCount: 2,
|
wantDiscoverCount: 2,
|
||||||
wantDeploymentDownloads: []string{"192.0.2.1", "192.0.2.2", "192.0.2.1"},
|
wantDeploymentDownloads: []string{"192.0.2.1"},
|
||||||
|
wantInfoDownloads: []string{"192.0.2.1"},
|
||||||
|
},
|
||||||
|
"no nodes are discovered": {
|
||||||
|
fetcher: stubFetcher{loadBalancerIP: "192.0.2.3"},
|
||||||
|
downloader: stubDownloader{},
|
||||||
|
wantDiscoverCount: 2,
|
||||||
|
wantDeploymentDownloads: []string{"192.0.2.3"},
|
||||||
|
wantInfoDownloads: []string{"192.0.2.3"},
|
||||||
|
},
|
||||||
|
"download deployment fails": {
|
||||||
|
fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}, loadBalancerIP: "192.0.2.3"},
|
||||||
|
downloader: stubDownloader{downloadDeploymentErrs: []error{someErr, someErr, someErr}},
|
||||||
|
wantDiscoverCount: 4,
|
||||||
|
wantDeploymentDownloads: []string{"192.0.2.1", "192.0.2.2", "192.0.2.3", "192.0.2.1"},
|
||||||
wantInfoDownloads: []string{"192.0.2.1"},
|
wantInfoDownloads: []string{"192.0.2.1"},
|
||||||
},
|
},
|
||||||
"download info fails": {
|
"download info fails": {
|
||||||
fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}},
|
fetcher: stubFetcher{ips: []string{"192.0.2.1", "192.0.2.2"}, loadBalancerIP: "192.0.2.3"},
|
||||||
downloader: stubDownloader{downloadInfoErrs: []error{someErr, someErr}},
|
downloader: stubDownloader{downloadInfoErrs: []error{someErr, someErr, someErr}},
|
||||||
wantDiscoverCount: 2,
|
wantDiscoverCount: 4,
|
||||||
wantDeploymentDownloads: []string{"192.0.2.1"},
|
wantDeploymentDownloads: []string{"192.0.2.1"},
|
||||||
wantInfoDownloads: []string{"192.0.2.1", "192.0.2.2", "192.0.2.1"},
|
wantInfoDownloads: []string{"192.0.2.1", "192.0.2.2", "192.0.2.3", "192.0.2.1"},
|
||||||
},
|
},
|
||||||
"endpoint discovery fails": {
|
"endpoint discovery fails": {
|
||||||
fetcher: stubFetcher{
|
fetcher: stubFetcher{
|
||||||
discoverErrs: []error{someErr, someErr, someErr},
|
discoverErrs: []error{someErr, someErr, someErr},
|
||||||
|
discoverLoadBalancerIPErr: someErr,
|
||||||
ips: []string{"192.0.2.1", "192.0.2.2"},
|
ips: []string{"192.0.2.1", "192.0.2.2"},
|
||||||
},
|
},
|
||||||
wantDiscoverCount: 4,
|
wantDiscoverCount: 8,
|
||||||
wantDeploymentDownloads: []string{"192.0.2.1"},
|
wantDeploymentDownloads: []string{"192.0.2.1"},
|
||||||
wantInfoDownloads: []string{"192.0.2.1"},
|
wantInfoDownloads: []string{"192.0.2.1"},
|
||||||
},
|
},
|
||||||
@ -90,7 +105,11 @@ type stubFetcher struct {
|
|||||||
ips []string
|
ips []string
|
||||||
discoverErrs []error
|
discoverErrs []error
|
||||||
discoverErrIdx int
|
discoverErrIdx int
|
||||||
|
|
||||||
discoverCalls int
|
discoverCalls int
|
||||||
|
|
||||||
|
loadBalancerIP string
|
||||||
|
discoverLoadBalancerIPErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *stubFetcher) DiscoverDebugdIPs(_ context.Context) ([]string, error) {
|
func (s *stubFetcher) DiscoverDebugdIPs(_ context.Context) ([]string, error) {
|
||||||
@ -104,6 +123,11 @@ func (s *stubFetcher) DiscoverDebugdIPs(_ context.Context) ([]string, error) {
|
|||||||
return s.ips, nil
|
return s.ips, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *stubFetcher) DiscoverLoadBalancerIP(_ context.Context) (string, error) {
|
||||||
|
s.discoverCalls++
|
||||||
|
return s.loadBalancerIP, s.discoverLoadBalancerIPErr
|
||||||
|
}
|
||||||
|
|
||||||
type stubDownloader struct {
|
type stubDownloader struct {
|
||||||
downloadDeploymentErrs []error
|
downloadDeploymentErrs []error
|
||||||
downloadDeploymentErrIdx int
|
downloadDeploymentErrIdx int
|
||||||
|
@ -133,9 +133,6 @@ func (s *debugdServer) UploadFiles(stream pb.Debugd_UploadFilesServer) error {
|
|||||||
// DownloadFiles streams the previously received files to other instances.
|
// DownloadFiles streams the previously received files to other instances.
|
||||||
func (s *debugdServer) DownloadFiles(_ *pb.DownloadFilesRequest, stream pb.Debugd_DownloadFilesServer) error {
|
func (s *debugdServer) DownloadFiles(_ *pb.DownloadFilesRequest, stream pb.Debugd_DownloadFilesServer) error {
|
||||||
s.log.Infof("Sending files to other instance")
|
s.log.Infof("Sending files to other instance")
|
||||||
if !s.transfer.CanSend() {
|
|
||||||
return errors.New("cannot send files at this time")
|
|
||||||
}
|
|
||||||
return s.transfer.SendFiles(stream)
|
return s.transfer.SendFiles(stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -185,5 +182,4 @@ type fileTransferer interface {
|
|||||||
RecvFiles(stream filetransfer.RecvFilesStream) error
|
RecvFiles(stream filetransfer.RecvFilesStream) error
|
||||||
SendFiles(stream filetransfer.SendFilesStream) error
|
SendFiles(stream filetransfer.SendFilesStream) error
|
||||||
GetFiles() []filetransfer.FileStat
|
GetFiles() []filetransfer.FileStat
|
||||||
CanSend() bool
|
|
||||||
}
|
}
|
||||||
|
@ -228,10 +228,6 @@ func TestDownloadFiles(t *testing.T) {
|
|||||||
canSend: true,
|
canSend: true,
|
||||||
wantSendFileCalls: 1,
|
wantSendFileCalls: 1,
|
||||||
},
|
},
|
||||||
"transfer is not ready for sending": {
|
|
||||||
request: &pb.DownloadFilesRequest{},
|
|
||||||
wantRecvErr: true,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for name, tc := range testCases {
|
for name, tc := range testCases {
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
|
"github.com/edgelesssys/constellation/v2/debugd/internal/debugd"
|
||||||
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer/streamer"
|
"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer/streamer"
|
||||||
@ -33,10 +34,10 @@ type SendFilesStream interface {
|
|||||||
|
|
||||||
// FileTransferer manages sending and receiving of files.
|
// FileTransferer manages sending and receiving of files.
|
||||||
type FileTransferer struct {
|
type FileTransferer struct {
|
||||||
mux sync.RWMutex
|
fileMux sync.RWMutex
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
receiveStarted bool
|
receiveStarted bool
|
||||||
receiveFinished bool
|
receiveFinished atomic.Bool
|
||||||
files []FileStat
|
files []FileStat
|
||||||
streamer streamReadWriter
|
streamer streamReadWriter
|
||||||
showProgress bool
|
showProgress bool
|
||||||
@ -52,12 +53,15 @@ func New(log *logger.Logger, streamer streamReadWriter, showProgress bool) *File
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendFiles sends files to the given stream.
|
// SendFiles sends files to the given stream.
|
||||||
|
// If the FileTransferer has not received any files to send, an error is returned.
|
||||||
func (s *FileTransferer) SendFiles(stream SendFilesStream) error {
|
func (s *FileTransferer) SendFiles(stream SendFilesStream) error {
|
||||||
s.mux.RLock()
|
if !s.receiveFinished.Load() {
|
||||||
defer s.mux.RUnlock()
|
|
||||||
if !s.receiveFinished {
|
|
||||||
return errors.New("cannot send files before receiving them")
|
return errors.New("cannot send files before receiving them")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.fileMux.RLock()
|
||||||
|
defer s.fileMux.RUnlock()
|
||||||
|
|
||||||
for _, file := range s.files {
|
for _, file := range s.files {
|
||||||
if err := s.handleFileSend(stream, file); err != nil {
|
if err := s.handleFileSend(stream, file); err != nil {
|
||||||
return err
|
return err
|
||||||
@ -68,8 +72,8 @@ func (s *FileTransferer) SendFiles(stream SendFilesStream) error {
|
|||||||
|
|
||||||
// RecvFiles receives files from the given stream.
|
// RecvFiles receives files from the given stream.
|
||||||
func (s *FileTransferer) RecvFiles(stream RecvFilesStream) (err error) {
|
func (s *FileTransferer) RecvFiles(stream RecvFilesStream) (err error) {
|
||||||
s.mux.Lock()
|
s.fileMux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.fileMux.Unlock()
|
||||||
if err := s.startRecv(); err != nil {
|
if err := s.startRecv(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -89,30 +93,23 @@ func (s *FileTransferer) RecvFiles(stream RecvFilesStream) (err error) {
|
|||||||
|
|
||||||
// GetFiles returns the a copy of the list of files that have been received.
|
// GetFiles returns the a copy of the list of files that have been received.
|
||||||
func (s *FileTransferer) GetFiles() []FileStat {
|
func (s *FileTransferer) GetFiles() []FileStat {
|
||||||
s.mux.RLock()
|
s.fileMux.RLock()
|
||||||
defer s.mux.RUnlock()
|
defer s.fileMux.RUnlock()
|
||||||
res := make([]FileStat, len(s.files))
|
res := make([]FileStat, len(s.files))
|
||||||
copy(res, s.files)
|
copy(res, s.files)
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetFiles sets the list of files that can be sent.
|
// SetFiles sets the list of files that can be sent.
|
||||||
|
// This function is used for a sender which has not received any files through
|
||||||
|
// this FileTransferer i.e. the CLI.
|
||||||
func (s *FileTransferer) SetFiles(files []FileStat) {
|
func (s *FileTransferer) SetFiles(files []FileStat) {
|
||||||
s.mux.Lock()
|
s.fileMux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.fileMux.Unlock()
|
||||||
res := make([]FileStat, len(files))
|
res := make([]FileStat, len(files))
|
||||||
copy(res, files)
|
copy(res, files)
|
||||||
s.files = res
|
s.files = res
|
||||||
s.receiveFinished = true
|
s.receiveFinished.Store(true)
|
||||||
}
|
|
||||||
|
|
||||||
// CanSend returns true if the file receive has finished.
|
|
||||||
// This is called to determine if a debugd instance can request files from this server.
|
|
||||||
func (s *FileTransferer) CanSend() bool {
|
|
||||||
s.mux.RLock()
|
|
||||||
defer s.mux.RUnlock()
|
|
||||||
ret := s.receiveFinished
|
|
||||||
return ret
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FileTransferer) handleFileSend(stream SendFilesStream, file FileStat) error {
|
func (s *FileTransferer) handleFileSend(stream SendFilesStream, file FileStat) error {
|
||||||
@ -173,7 +170,7 @@ func (s *FileTransferer) handleFileRecv(stream RecvFilesStream) (bool, error) {
|
|||||||
// startRecv marks the file receive as started. It returns an error if receiving has already started.
|
// startRecv marks the file receive as started. It returns an error if receiving has already started.
|
||||||
func (s *FileTransferer) startRecv() error {
|
func (s *FileTransferer) startRecv() error {
|
||||||
switch {
|
switch {
|
||||||
case s.receiveFinished:
|
case s.receiveFinished.Load():
|
||||||
return ErrReceiveFinished
|
return ErrReceiveFinished
|
||||||
case s.receiveStarted:
|
case s.receiveStarted:
|
||||||
return ErrReceiveRunning
|
return ErrReceiveRunning
|
||||||
@ -193,7 +190,7 @@ func (s *FileTransferer) abortRecv() {
|
|||||||
// This allows other debugd instances to request files from this server.
|
// This allows other debugd instances to request files from this server.
|
||||||
func (s *FileTransferer) finishRecv() {
|
func (s *FileTransferer) finishRecv() {
|
||||||
s.receiveStarted = false
|
s.receiveStarted = false
|
||||||
s.receiveFinished = true
|
s.receiveFinished.Store(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// addFile adds a file to the list of received files.
|
// addFile adds a file to the list of received files.
|
||||||
|
@ -26,6 +26,7 @@ func TestMain(m *testing.M) {
|
|||||||
func TestSendFiles(t *testing.T) {
|
func TestSendFiles(t *testing.T) {
|
||||||
testCases := map[string]struct {
|
testCases := map[string]struct {
|
||||||
files *[]FileStat
|
files *[]FileStat
|
||||||
|
receiveFinished bool
|
||||||
sendErr error
|
sendErr error
|
||||||
readStreamErr error
|
readStreamErr error
|
||||||
wantHeaders []*pb.FileTransferMessage
|
wantHeaders []*pb.FileTransferMessage
|
||||||
@ -44,6 +45,7 @@ func TestSendFiles(t *testing.T) {
|
|||||||
OverrideServiceUnit: "somesvcB",
|
OverrideServiceUnit: "somesvcB",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
receiveFinished: true,
|
||||||
wantHeaders: []*pb.FileTransferMessage{
|
wantHeaders: []*pb.FileTransferMessage{
|
||||||
{
|
{
|
||||||
Kind: &pb.FileTransferMessage_Header{
|
Kind: &pb.FileTransferMessage_Header{
|
||||||
@ -65,7 +67,20 @@ func TestSendFiles(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"no files set": {
|
"not finished receiving": {
|
||||||
|
files: &[]FileStat{
|
||||||
|
{
|
||||||
|
TargetPath: "testfileA",
|
||||||
|
Mode: 0o644,
|
||||||
|
OverrideServiceUnit: "somesvcA",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
TargetPath: "testfileB",
|
||||||
|
Mode: 0o644,
|
||||||
|
OverrideServiceUnit: "somesvcB",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
receiveFinished: false,
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
"send fails": {
|
"send fails": {
|
||||||
@ -76,6 +91,7 @@ func TestSendFiles(t *testing.T) {
|
|||||||
OverrideServiceUnit: "somesvcA",
|
OverrideServiceUnit: "somesvcA",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
receiveFinished: true,
|
||||||
sendErr: errors.New("send failed"),
|
sendErr: errors.New("send failed"),
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
@ -87,6 +103,7 @@ func TestSendFiles(t *testing.T) {
|
|||||||
OverrideServiceUnit: "somesvcA",
|
OverrideServiceUnit: "somesvcA",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
receiveFinished: true,
|
||||||
readStreamErr: errors.New("read stream failed"),
|
readStreamErr: errors.New("read stream failed"),
|
||||||
wantErr: true,
|
wantErr: true,
|
||||||
},
|
},
|
||||||
@ -99,10 +116,16 @@ func TestSendFiles(t *testing.T) {
|
|||||||
|
|
||||||
streamer := &stubStreamReadWriter{readStreamErr: tc.readStreamErr}
|
streamer := &stubStreamReadWriter{readStreamErr: tc.readStreamErr}
|
||||||
stream := &stubSendFilesStream{sendErr: tc.sendErr}
|
stream := &stubSendFilesStream{sendErr: tc.sendErr}
|
||||||
transfer := New(logger.NewTest(t), streamer, false)
|
transfer := &FileTransferer{
|
||||||
if tc.files != nil {
|
log: logger.NewTest(t),
|
||||||
transfer.SetFiles(*tc.files)
|
streamer: streamer,
|
||||||
|
showProgress: false,
|
||||||
}
|
}
|
||||||
|
if tc.files != nil {
|
||||||
|
transfer.files = *tc.files
|
||||||
|
}
|
||||||
|
transfer.receiveFinished.Store(tc.receiveFinished)
|
||||||
|
|
||||||
err := transfer.SendFiles(stream)
|
err := transfer.SendFiles(stream)
|
||||||
|
|
||||||
if tc.wantErr {
|
if tc.wantErr {
|
||||||
@ -236,7 +259,7 @@ func TestRecvFiles(t *testing.T) {
|
|||||||
transfer.receiveStarted = true
|
transfer.receiveStarted = true
|
||||||
}
|
}
|
||||||
if tc.recvAlreadyFinished {
|
if tc.recvAlreadyFinished {
|
||||||
transfer.receiveFinished = true
|
transfer.receiveFinished.Store(true)
|
||||||
}
|
}
|
||||||
err := transfer.RecvFiles(stream)
|
err := transfer.RecvFiles(stream)
|
||||||
|
|
||||||
@ -290,34 +313,11 @@ func TestGetSetFiles(t *testing.T) {
|
|||||||
}
|
}
|
||||||
gotFiles := transfer.GetFiles()
|
gotFiles := transfer.GetFiles()
|
||||||
assert.Equal(tc.wantFiles, gotFiles)
|
assert.Equal(tc.wantFiles, gotFiles)
|
||||||
assert.Equal(tc.setFiles != nil, transfer.receiveFinished)
|
assert.Equal(tc.setFiles != nil, transfer.receiveFinished.Load())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCanSend(t *testing.T) {
|
|
||||||
assert := assert.New(t)
|
|
||||||
|
|
||||||
streamer := &stubStreamReadWriter{}
|
|
||||||
stream := &stubRecvFilesStream{recvErr: io.EOF}
|
|
||||||
transfer := New(logger.NewTest(t), streamer, false)
|
|
||||||
assert.False(transfer.CanSend())
|
|
||||||
|
|
||||||
// manual set
|
|
||||||
transfer.SetFiles(nil)
|
|
||||||
assert.True(transfer.CanSend())
|
|
||||||
|
|
||||||
// reset
|
|
||||||
transfer.receiveStarted = false
|
|
||||||
transfer.receiveFinished = false
|
|
||||||
transfer.files = nil
|
|
||||||
assert.False(transfer.CanSend())
|
|
||||||
|
|
||||||
// receive files (empty)
|
|
||||||
assert.NoError(transfer.RecvFiles(stream))
|
|
||||||
assert.True(transfer.CanSend())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConcurrency(t *testing.T) {
|
func TestConcurrency(t *testing.T) {
|
||||||
ft := New(logger.NewTest(t), &stubStreamReadWriter{}, false)
|
ft := New(logger.NewTest(t), &stubStreamReadWriter{}, false)
|
||||||
|
|
||||||
@ -337,10 +337,6 @@ func TestConcurrency(t *testing.T) {
|
|||||||
ft.SetFiles([]FileStat{{SourcePath: "file", TargetPath: "file", Mode: 0o644}})
|
ft.SetFiles([]FileStat{{SourcePath: "file", TargetPath: "file", Mode: 0o644}})
|
||||||
}
|
}
|
||||||
|
|
||||||
canSend := func() {
|
|
||||||
_ = ft.CanSend()
|
|
||||||
}
|
|
||||||
|
|
||||||
go sendFiles()
|
go sendFiles()
|
||||||
go sendFiles()
|
go sendFiles()
|
||||||
go sendFiles()
|
go sendFiles()
|
||||||
@ -357,10 +353,6 @@ func TestConcurrency(t *testing.T) {
|
|||||||
go setFiles()
|
go setFiles()
|
||||||
go setFiles()
|
go setFiles()
|
||||||
go setFiles()
|
go setFiles()
|
||||||
go canSend()
|
|
||||||
go canSend()
|
|
||||||
go canSend()
|
|
||||||
go canSend()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type stubStreamReadWriter struct {
|
type stubStreamReadWriter struct {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user