/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package filetransfer

import (
	"errors"
	"io"
	"testing"

	"github.com/edgelesssys/constellation/v2/debugd/internal/filetransfer/streamer"
	pb "github.com/edgelesssys/constellation/v2/debugd/service"
	"github.com/edgelesssys/constellation/v2/internal/logger"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"
)

// TestMain hands the package's test run to goleak.VerifyTestMain, which
// reports goroutines that are still alive once the tests have finished.
func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m)
}

func TestSendFiles(t *testing.T) {
	testCases := map[string]struct {
		files           *[]FileStat
		receiveFinished bool
		sendErr         error
		readStreamErr   error
		wantHeaders     []*pb.FileTransferMessage
		wantErr         bool
	}{
		"can send files": {
			files: &[]FileStat{
				{
					TargetPath:          "testfileA",
					Mode:                0o644,
					OverrideServiceUnit: "somesvcA",
				},
				{
					TargetPath:          "testfileB",
					Mode:                0o644,
					OverrideServiceUnit: "somesvcB",
				},
			},
			receiveFinished: true,
			wantHeaders: []*pb.FileTransferMessage{
				{
					Kind: &pb.FileTransferMessage_Header{
						Header: &pb.FileTransferHeader{
							TargetPath:          "testfileA",
							Mode:                0o644,
							OverrideServiceUnit: func() *string { s := "somesvcA"; return &s }(),
						},
					},
				},
				{
					Kind: &pb.FileTransferMessage_Header{
						Header: &pb.FileTransferHeader{
							TargetPath:          "testfileB",
							Mode:                0o644,
							OverrideServiceUnit: func() *string { s := "somesvcB"; return &s }(),
						},
					},
				},
			},
		},
		"not finished receiving": {
			files: &[]FileStat{
				{
					TargetPath:          "testfileA",
					Mode:                0o644,
					OverrideServiceUnit: "somesvcA",
				},
				{
					TargetPath:          "testfileB",
					Mode:                0o644,
					OverrideServiceUnit: "somesvcB",
				},
			},
			receiveFinished: false,
			wantErr:         true,
		},
		"send fails": {
			files: &[]FileStat{
				{
					TargetPath:          "testfileA",
					Mode:                0o644,
					OverrideServiceUnit: "somesvcA",
				},
			},
			receiveFinished: true,
			sendErr:         errors.New("send failed"),
			wantErr:         true,
		},
		"read stream fails": {
			files: &[]FileStat{
				{
					TargetPath:          "testfileA",
					Mode:                0o644,
					OverrideServiceUnit: "somesvcA",
				},
			},
			receiveFinished: true,
			readStreamErr:   errors.New("read stream failed"),
			wantErr:         true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			streamer := &stubStreamReadWriter{readStreamErr: tc.readStreamErr}
			stream := &stubSendFilesStream{sendErr: tc.sendErr}
			transfer := &FileTransferer{
				log:          logger.NewTest(t),
				streamer:     streamer,
				showProgress: false,
			}
			if tc.files != nil {
				transfer.files = *tc.files
			}
			transfer.receiveFinished.Store(tc.receiveFinished)

			err := transfer.SendFiles(stream)

			if tc.wantErr {
				assert.Error(err)
				return
			}
			require.NoError(err)
			assert.Equal(tc.wantHeaders, stream.msgs)
		})
	}
}

// TestRecvFiles drives FileTransferer.RecvFiles with a fake receive stream
// and a stubbed streamer.
//
// Each case configures the messages replayed by the stream, optional errors
// injected into Recv and WriteStream, and whether the transferer is already
// marked as having started or finished receiving. Cases that expect success
// compare the transferer's recorded files against wantFiles.
//
// Only header messages are listed in the cases: chunk messages are left out
// since they would be consumed by the streamReadWriter.
func TestRecvFiles(t *testing.T) {
	// headerA returns a fresh header message for testfileA with a service
	// unit override.
	headerA := func() *pb.FileTransferMessage {
		unit := "somesvcA"
		return &pb.FileTransferMessage{
			Kind: &pb.FileTransferMessage_Header{
				Header: &pb.FileTransferHeader{
					TargetPath:          "testfileA",
					Mode:                0o644,
					OverrideServiceUnit: &unit,
				},
			},
		}
	}
	// headerB returns a fresh header message for testfileB without a
	// service unit override.
	headerB := func() *pb.FileTransferMessage {
		return &pb.FileTransferMessage{
			Kind: &pb.FileTransferMessage_Header{
				Header: &pb.FileTransferHeader{
					TargetPath: "testfileB",
					Mode:       0o644,
				},
			},
		}
	}

	testCases := map[string]struct {
		msgs                []*pb.FileTransferMessage
		recvAlreadyStarted  bool
		recvAlreadyFinished bool
		recvErr             error
		writeStreamErr      error
		wantFiles           []FileStat
		wantErr             bool
	}{
		"can recv files": {
			msgs: []*pb.FileTransferMessage{headerA(), headerB()},
			wantFiles: []FileStat{
				{
					SourcePath:          "testfileA",
					TargetPath:          "testfileA",
					Mode:                0o644,
					OverrideServiceUnit: "somesvcA",
				},
				{
					SourcePath: "testfileB",
					TargetPath: "testfileB",
					Mode:       0o644,
				},
			},
		},
		"no messages": {},
		"recv fails": {
			recvErr: errors.New("recv failed"),
			wantErr: true,
		},
		"first recv does not yield file header": {
			msgs: []*pb.FileTransferMessage{
				{Kind: &pb.FileTransferMessage_Chunk{}},
			},
			wantErr: true,
		},
		"write stream fails": {
			msgs:           []*pb.FileTransferMessage{headerA()},
			writeStreamErr: errors.New("write stream failed"),
			wantErr:        true,
		},
		"recv has already started": {
			msgs:               []*pb.FileTransferMessage{headerA()},
			recvAlreadyStarted: true,
			wantErr:            true,
		},
		"recv has already finished": {
			msgs:                []*pb.FileTransferMessage{headerA()},
			recvAlreadyFinished: true,
			wantErr:             true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			recvStream := &fakeRecvFilesStream{msgs: tc.msgs, recvErr: tc.recvErr}
			ft := New(
				logger.NewTest(t),
				&stubStreamReadWriter{writeStreamErr: tc.writeStreamErr},
				false,
			)
			// Pre-set the receive state only when the case asks for it, so the
			// state produced by New is otherwise left untouched.
			if tc.recvAlreadyStarted {
				ft.receiveStarted = true
			}
			if tc.recvAlreadyFinished {
				ft.receiveFinished.Store(true)
			}

			err := ft.RecvFiles(recvStream)

			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			require.NoError(t, err)
			assert.Equal(t, tc.wantFiles, ft.files)
		})
	}
}

// TestGetSetFiles checks the SetFiles/GetFiles pair on a fresh transferer.
//
// For each case it optionally calls SetFiles, then asserts that GetFiles
// returns wantFiles and that receiveFinished is true exactly when SetFiles
// was called. Neither call is expected to fail, so the table carries no
// error column.
func TestGetSetFiles(t *testing.T) {
	fileA := FileStat{
		SourcePath:          "testfileA",
		TargetPath:          "testfileA",
		Mode:                0o644,
		OverrideServiceUnit: "somesvcA",
	}

	testCases := map[string]struct {
		// setFiles is nil when SetFiles must not be called at all.
		setFiles  *[]FileStat
		wantFiles []FileStat
	}{
		"no files": {
			wantFiles: []FileStat{},
		},
		"files": {
			setFiles:  &[]FileStat{fileA},
			wantFiles: []FileStat{fileA},
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			// The dummy streamer panics on use, so this test also fails
			// loudly if GetFiles or SetFiles ever touches the streamer.
			transfer := New(logger.NewTest(t), &dummyStreamReadWriter{}, false)

			filesWereSet := tc.setFiles != nil
			if filesWereSet {
				transfer.SetFiles(*tc.setFiles)
			}

			assert.Equal(t, tc.wantFiles, transfer.GetFiles())
			assert.Equal(t, filesWereSet, transfer.receiveFinished.Load())
		})
	}
}

// TestConcurrency calls SendFiles, RecvFiles, GetFiles and SetFiles on one
// shared FileTransferer from several goroutines at once. Return values are
// ignored on purpose: the test exists so that the race detector (go test
// -race) can observe concurrent access, not to check results.
//
// The test blocks until every goroutine has returned. Without that, the test
// function could return before any of the calls ran, leaving the concurrent
// access unexercised and letting goroutines use the logger created from t
// after the test had already completed.
func TestConcurrency(t *testing.T) {
	ft := New(logger.NewTest(t), &stubStreamReadWriter{}, false)

	sendFiles := func() {
		_ = ft.SendFiles(&stubSendFilesStream{})
	}

	recvFiles := func() {
		_ = ft.RecvFiles(&stubRecvFilesStream{})
	}

	getFiles := func() {
		_ = ft.GetFiles()
	}

	setFiles := func() {
		ft.SetFiles([]FileStat{{SourcePath: "file", TargetPath: "file", Mode: 0o644}})
	}

	// Each operation is started this many times, in the listed order.
	const goroutinesPerOp = 4
	ops := []func(){sendFiles, recvFiles, getFiles, setFiles}
	total := len(ops) * goroutinesPerOp

	// Buffered so a finishing goroutine never blocks on signalling.
	done := make(chan struct{}, total)
	for _, op := range ops {
		for i := 0; i < goroutinesPerOp; i++ {
			// op is passed as an argument so each goroutine gets its own
			// copy regardless of the Go version's loop variable semantics.
			go func(op func()) {
				defer func() { done <- struct{}{} }()
				op()
			}(op)
		}
	}

	// Wait for all goroutines before the test (and t) goes out of scope.
	for i := 0; i < total; i++ {
		<-done
	}
}

// stubStreamReadWriter is a test double for the streamer passed to
// FileTransferer. Its methods ignore their arguments and return the
// errors configured here.
type stubStreamReadWriter struct {
	// readStreamErr is returned by ReadStream.
	readStreamErr  error
	// writeStreamErr is returned by WriteStream.
	writeStreamErr error
}

// ReadStream ignores all arguments and returns the configured readStreamErr
// (nil if none was set).
func (s *stubStreamReadWriter) ReadStream(_ string, _ streamer.WriteChunkStream, _ uint, _ bool) error {
	return s.readStreamErr
}

// WriteStream ignores all arguments and returns the configured writeStreamErr
// (nil if none was set).
func (s *stubStreamReadWriter) WriteStream(_ string, _ streamer.ReadChunkStream, _ bool) error {
	return s.writeStreamErr
}

// fakeRecvFilesStream is a test double for a receive stream that replays a
// fixed list of messages through Recv.
type fakeRecvFilesStream struct {
	// msgs are the messages handed out by Recv, in order.
	msgs    []*pb.FileTransferMessage
	// pos is the index into msgs of the next message to hand out.
	pos     int
	// recvErr, if non-nil, is returned by every Recv call instead of a message.
	recvErr error
}

// Recv hands out the next queued message.
//
// A configured recvErr takes precedence and is returned on every call without
// consuming a message. Once all messages have been handed out, Recv returns
// io.EOF.
func (s *fakeRecvFilesStream) Recv() (*pb.FileTransferMessage, error) {
	switch {
	case s.recvErr != nil:
		return nil, s.recvErr
	case s.pos >= len(s.msgs):
		return nil, io.EOF
	}

	next := s.msgs[s.pos]
	s.pos++
	return next, nil
}

// dummyStreamReadWriter is a placeholder streamer for tests that must not
// stream anything: both of its methods panic when called.
type dummyStreamReadWriter struct{}

// ReadStream always panics; the dummy is only meant to be passed around,
// never used.
func (s *dummyStreamReadWriter) ReadStream(_ string, _ streamer.WriteChunkStream, _ uint, _ bool) error {
	panic("dummy")
}

// WriteStream always panics; the dummy is only meant to be passed around,
// never used.
func (s *dummyStreamReadWriter) WriteStream(_ string, _ streamer.ReadChunkStream, _ bool) error {
	panic("dummy")
}