/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package initserver

import (
	"bytes"
	"context"
	"errors"
	"io"
	"net"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/edgelesssys/constellation/v2/bootstrapper/initproto"
	"github.com/edgelesssys/constellation/v2/internal/atls"
	"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
	"github.com/edgelesssys/constellation/v2/internal/crypto/testvector"
	"github.com/edgelesssys/constellation/v2/internal/file"
	kmssetup "github.com/edgelesssys/constellation/v2/internal/kms/setup"
	"github.com/edgelesssys/constellation/v2/internal/kms/uri"
	"github.com/edgelesssys/constellation/v2/internal/logger"
	"github.com/edgelesssys/constellation/v2/internal/versions/components"
	"github.com/spf13/afero"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"
	"golang.org/x/crypto/bcrypt"
	"google.golang.org/grpc"
)

// TestMain runs the package's tests under goleak so that goroutines still
// alive after the run are reported, except for the known offenders listed below.
func TestMain(m *testing.M) {
	ignored := []goleak.Option{
		// https://github.com/census-instrumentation/opencensus-go/issues/1262
		goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"),
		goleak.IgnoreAnyFunction("github.com/bazelbuild/rules_go/go/tools/bzltestutil.RegisterTimeoutHandler.func1"),
	}

	goleak.VerifyTestMain(m, ignored...)
}

// TestNew checks that New rejects a missing or unobtainable init secret hash
// and otherwise returns a Server with its dependencies populated.
func TestNew(t *testing.T) {
	fh := file.NewHandler(afero.NewMemMapFs())

	testCases := map[string]struct {
		metadata stubMetadata
		wantErr  bool
	}{
		"success": {
			metadata: stubMetadata{initSecretHashVal: []byte("hash")},
		},
		"empty init secret hash": {
			metadata: stubMetadata{initSecretHashVal: nil},
			wantErr:  true,
		},
		"metadata error": {
			metadata: stubMetadata{initSecretHashErr: errors.New("error")},
			wantErr:  true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			server, err := New(
				context.TODO(), newFakeLock(), &stubClusterInitializer{}, atls.NewFakeIssuer(variant.Dummy{}),
				&stubDisk{}, fh, &tc.metadata, logger.NewTest(t),
			)
			if tc.wantErr {
				assert.Error(err)
				return
			}

			// The checks below dereference server, so abort the subtest right
			// away if construction failed instead of panicking on a nil pointer.
			require.NoError(err)
			require.NotNil(server)

			assert.NotEmpty(server.initSecretHash)
			assert.NotNil(server.log)
			assert.NotNil(server.nodeLock)
			assert.NotNil(server.initializer)
			assert.NotNil(server.grpcServer)
			assert.NotNil(server.fileHandler)
			assert.NotNil(server.disk)
		})
	}
}

// TestInit drives Server.Init with stubbed dependencies and checks, per case,
// whether an error is returned and whether the gRPC server gets shut down.
func TestInit(t *testing.T) {
	someErr := errors.New("failed")

	// heldLock is taken once up front so the "node locked" case starts out
	// with a lock that can no longer be acquired.
	heldLock := newFakeLock()
	acquired, lockErr := heldLock.TryLockOnce(nil)
	require.True(t, acquired)
	require.Nil(t, lockErr)

	initSecret := []byte("password")
	initSecretHash, err := bcrypt.GenerateFromPassword(initSecret, bcrypt.DefaultCost)
	require.NoError(t, err)

	masterSecret := uri.MasterSecret{Key: []byte("secret"), Salt: []byte("salt")}

	// Fixture builders: every case gets its own fresh instance.
	validReq := func() *initproto.InitRequest {
		return &initproto.InitRequest{
			InitSecret: initSecret,
			KmsUri:     masterSecret.EncodeToURI(),
			StorageUri: uri.NoStoreURI,
		}
	}
	noLogs := func() stubJournaldCollector {
		return stubJournaldCollector{logPipe: &stubReadCloser{reader: bytes.NewReader([]byte{})}}
	}
	memFiles := func() file.Handler {
		return file.NewHandler(afero.NewMemMapFs())
	}

	testCases := map[string]struct {
		nodeLock       *fakeLock
		initializer    ClusterInitializer
		disk           encryptedDisk
		fileHandler    file.Handler
		req            *initproto.InitRequest
		stream         stubStream
		logCollector   stubJournaldCollector
		initSecretHash []byte
		wantErr        bool
		wantShutdown   bool
	}{
		"successful init": {
			nodeLock:       newFakeLock(),
			initializer:    &stubClusterInitializer{},
			disk:           &stubDisk{},
			fileHandler:    memFiles(),
			req:            validReq(),
			stream:         stubStream{},
			logCollector:   noLogs(),
			initSecretHash: initSecretHash,
			wantShutdown:   true,
		},
		"node locked": {
			nodeLock:       heldLock,
			initializer:    &stubClusterInitializer{},
			disk:           &stubDisk{},
			fileHandler:    memFiles(),
			req:            validReq(),
			stream:         stubStream{},
			logCollector:   noLogs(),
			initSecretHash: initSecretHash,
			wantErr:        true,
		},
		"disk open error": {
			nodeLock:       newFakeLock(),
			initializer:    &stubClusterInitializer{},
			disk:           &stubDisk{openErr: someErr},
			fileHandler:    memFiles(),
			req:            validReq(),
			stream:         stubStream{},
			logCollector:   noLogs(),
			initSecretHash: initSecretHash,
			wantErr:        true,
			wantShutdown:   true,
		},
		"disk uuid error": {
			nodeLock:       newFakeLock(),
			initializer:    &stubClusterInitializer{},
			disk:           &stubDisk{uuidErr: someErr},
			fileHandler:    memFiles(),
			req:            validReq(),
			stream:         stubStream{},
			logCollector:   noLogs(),
			initSecretHash: initSecretHash,
			wantErr:        true,
			wantShutdown:   true,
		},
		"disk update passphrase error": {
			nodeLock:       newFakeLock(),
			initializer:    &stubClusterInitializer{},
			disk:           &stubDisk{updatePassphraseErr: someErr},
			fileHandler:    memFiles(),
			req:            validReq(),
			stream:         stubStream{},
			logCollector:   noLogs(),
			initSecretHash: initSecretHash,
			wantErr:        true,
			wantShutdown:   true,
		},
		"write state file error": {
			nodeLock:       newFakeLock(),
			initializer:    &stubClusterInitializer{},
			disk:           &stubDisk{},
			fileHandler:    file.NewHandler(afero.NewReadOnlyFs(afero.NewMemMapFs())),
			req:            validReq(),
			stream:         stubStream{},
			logCollector:   noLogs(),
			initSecretHash: initSecretHash,
			wantErr:        true,
			wantShutdown:   true,
		},
		"initialize cluster error": {
			nodeLock:       newFakeLock(),
			initializer:    &stubClusterInitializer{initClusterErr: someErr},
			disk:           &stubDisk{},
			fileHandler:    memFiles(),
			req:            validReq(),
			stream:         stubStream{},
			logCollector:   noLogs(),
			initSecretHash: initSecretHash,
			wantErr:        true,
			wantShutdown:   true,
		},
		"wrong initSecret": {
			nodeLock:       newFakeLock(),
			initializer:    &stubClusterInitializer{},
			disk:           &stubDisk{},
			fileHandler:    memFiles(),
			req:            &initproto.InitRequest{InitSecret: []byte("wrongpassword")},
			stream:         stubStream{},
			logCollector:   noLogs(),
			initSecretHash: initSecretHash,
			wantErr:        true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			stopper := newStubServeStopper()
			server := &Server{
				nodeLock:          tc.nodeLock,
				initializer:       tc.initializer,
				disk:              tc.disk,
				fileHandler:       tc.fileHandler,
				log:               logger.NewTest(t),
				grpcServer:        stopper,
				cleaner:           &fakeCleaner{serveStopper: stopper},
				initSecretHash:    tc.initSecretHash,
				journaldCollector: &tc.logCollector,
			}

			err := server.Init(tc.req, &tc.stream)

			if !tc.wantErr {
				// NOTE(review): wantShutdown is only evaluated on the error
				// path below; the success case never waits for the shutdown.
				for _, res := range tc.stream.res {
					assert.NotNil(res.GetInitSuccess())
				}
				assert.NoError(err)
				assert.False(server.nodeLock.TryLockOnce(nil)) // lock should be locked
				return
			}

			assert.Error(err)
			if !tc.wantShutdown {
				return
			}

			// The fake cleaner stops the server from a separate goroutine,
			// so give it up to a second to signal.
			select {
			case <-stopper.shutdownCalled:
			case <-time.After(time.Second):
				t.Fatal("grpc server did not shut down")
			}
		})
	}
}

// TestSendLogsWithMessage checks Server.sendLogsWithMessage against a stubbed
// journald collector and stream. On success the test expects the first
// response to carry the failure message and every following response to carry
// the collected log content; collection, read and send failures must surface
// as an error.
func TestSendLogsWithMessage(t *testing.T) {
	someError := errors.New("failed")

	testCases := map[string]struct {
		logCollector           journaldCollection
		stream                 stubStream
		failureMessage         string
		expectedResult         string
		expectedFailureMessage string
		wantErr                bool
	}{
		"success": {
			logCollector:           &stubJournaldCollector{logPipe: &stubReadCloser{reader: bytes.NewReader([]byte("asdf"))}},
			stream:                 stubStream{},
			failureMessage:         "fdsa",
			expectedResult:         "asdf",
			expectedFailureMessage: "fdsa",
		},
		"fail collection": {
			logCollector:           &stubJournaldCollector{collectErr: someError},
			failureMessage:         "fdsa",
			wantErr:                true,
			expectedFailureMessage: "fdsa",
		},
		"fail to send": {
			logCollector:           &stubJournaldCollector{logPipe: &stubReadCloser{reader: bytes.NewReader([]byte("asdf"))}},
			stream:                 stubStream{sendError: someError},
			failureMessage:         "fdsa",
			wantErr:                true,
			expectedFailureMessage: "fdsa",
		},
		"fail to read": {
			logCollector:           &stubJournaldCollector{logPipe: &stubReadCloser{readErr: someError}},
			failureMessage:         "fdsa",
			wantErr:                true,
			expectedFailureMessage: "fdsa",
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			serverStopper := newStubServeStopper()
			server := &Server{
				grpcServer:        serverStopper,
				journaldCollector: tc.logCollector,
			}

			err := server.sendLogsWithMessage(&tc.stream, errors.New(tc.failureMessage))

			if tc.wantErr {
				assert.Error(err)
				return
			}

			// Confirm success and that something was sent before indexing into
			// the recorded responses; otherwise res[0] / res[1:] would panic.
			require.NoError(err)
			require.NotEmpty(tc.stream.res)

			// testify's Equal takes (expected, actual).
			assert.Equal(tc.expectedFailureMessage, tc.stream.res[0].GetInitFailure().GetError())
			for _, res := range tc.stream.res[1:] {
				assert.Equal(tc.expectedResult, string(res.GetLog().Log))
			}
		})
	}
}

// TestSetupDisk feeds the HKDF0xFF test vector through Server.setupDisk with
// the disk UUID in lower and in upper case. The fake disk rejects any
// passphrase other than the vector's output, so both spellings must lead to
// the same key.
func TestSetupDisk(t *testing.T) {
	vector := testvector.HKDF0xFF

	testCases := map[string]struct {
		uuid      string
		masterKey []byte
		salt      []byte
		wantKey   []byte
	}{
		"lower case uuid": {
			uuid:      strings.ToLower(vector.Info),
			masterKey: vector.Secret,
			salt:      vector.Salt,
			wantKey:   vector.Output,
		},
		"upper case uuid": {
			uuid:      strings.ToUpper(vector.Info),
			masterKey: vector.Secret,
			salt:      vector.Salt,
			wantKey:   vector.Output,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			ctx := context.Background()

			secret := uri.MasterSecret{Key: tc.masterKey, Salt: tc.salt}
			kms, err := kmssetup.KMS(ctx, uri.NoStoreURI, secret.EncodeToURI())
			require.NoError(t, err)

			server := &Server{
				disk: &fakeDisk{uuid: tc.uuid, wantKey: tc.wantKey},
			}
			assert.NoError(t, server.setupDisk(ctx, kms))
		})
	}
}

// fakeDisk is a disk double that reports a fixed UUID and only accepts the
// passphrase stored in wantKey.
type fakeDisk struct {
	uuid    string
	wantKey []byte
}

// Open never fails and returns a function that does nothing.
func (d *fakeDisk) Open() (func(), error) {
	noop := func() {}
	return noop, nil
}

// Close is a no-op that always succeeds.
func (d *fakeDisk) Close() error {
	return nil
}

// UUID returns the configured uuid and no error.
func (d *fakeDisk) UUID() (string, error) {
	return d.uuid, nil
}

// UpdatePassphrase succeeds only if passphrase equals wantKey interpreted as a string.
func (d *fakeDisk) UpdatePassphrase(passphrase string) error {
	if want := string(d.wantKey); passphrase == want {
		return nil
	}
	return errors.New("wrong passphrase")
}

// MarkDiskForReset is a no-op that always succeeds.
func (d *fakeDisk) MarkDiskForReset() error {
	return nil
}

// stubDisk is a disk double whose methods return the configured values and
// errors. It also remembers whether UpdatePassphrase was invoked.
type stubDisk struct {
	openErr                error
	uuid                   string
	uuidErr                error
	updatePassphraseErr    error
	updatePassphraseCalled bool
}

// Open returns a do-nothing function together with openErr.
func (d *stubDisk) Open() (func(), error) {
	noop := func() {}
	return noop, d.openErr
}

// UUID returns the configured uuid together with uuidErr.
func (d *stubDisk) UUID() (string, error) {
	id, err := d.uuid, d.uuidErr
	return id, err
}

// UpdatePassphrase ignores its argument, records the call and returns updatePassphraseErr.
func (d *stubDisk) UpdatePassphrase(_ string) error {
	d.updatePassphraseCalled = true
	return d.updatePassphraseErr
}

// MarkDiskForReset is a no-op that always succeeds.
func (d *stubDisk) MarkDiskForReset() error {
	return nil
}

// stubClusterInitializer is a cluster initializer double that returns a canned
// kubeconfig and error.
type stubClusterInitializer struct {
	initClusterKubeconfig []byte
	initClusterErr        error
}

// InitCluster disregards all arguments and returns the configured kubeconfig and error.
func (i *stubClusterInitializer) InitCluster(
	_ context.Context, _, _ string,
	_ bool, _ components.Components, _ []string, _ string,
) ([]byte, error) {
	kubeconfig, err := i.initClusterKubeconfig, i.initClusterErr
	return kubeconfig, err
}

// stubServeStopper stands in for the gRPC server and signals on
// shutdownCalled whenever GracefulStop is invoked.
type stubServeStopper struct {
	shutdownCalled chan struct{}
}

// newStubServeStopper creates a stubServeStopper whose signal channel has room
// for one pending notification.
func newStubServeStopper() *stubServeStopper {
	signals := make(chan struct{}, 1)
	return &stubServeStopper{shutdownCalled: signals}
}

// Serve must never be reached by the tests and panics if it is.
func (s *stubServeStopper) Serve(_ net.Listener) error {
	panic("should not be called in a test")
}

// GracefulStop sends one notification on shutdownCalled.
func (s *stubServeStopper) GracefulStop() {
	var signal struct{}
	s.shutdownCalled <- signal
}

// fakeLock is a node lock double backed by a mutex that is never released.
type fakeLock struct {
	state *sync.Mutex
}

// newFakeLock returns a fakeLock that has not been acquired yet.
func newFakeLock() *fakeLock {
	return &fakeLock{state: new(sync.Mutex)}
}

// TryLockOnce attempts to take the underlying mutex without blocking and
// reports whether that succeeded. The argument is ignored and the error is
// always nil.
func (l *fakeLock) TryLockOnce(_ []byte) (bool, error) {
	locked := l.state.TryLock()
	return locked, nil
}

// fakeCleaner is a cleaner double that wraps a serveStopper.
type fakeCleaner struct {
	serveStopper
}

// Clean triggers GracefulStop on the wrapped serveStopper from a new goroutine
// and returns without waiting for it.
func (f *fakeCleaner) Clean() {
	// this is not the correct way to do this, but it's fine for testing
	go f.GracefulStop()
}

// stubMetadata is a metadata double that serves a canned init secret hash and error.
type stubMetadata struct {
	initSecretHashVal []byte
	initSecretHashErr error
}

// InitSecretHash ignores the context and returns the configured hash and error.
func (m *stubMetadata) InitSecretHash(_ context.Context) ([]byte, error) {
	hash, err := m.initSecretHashVal, m.initSecretHashErr
	return hash, err
}

// stubStream is a response stream double. It collects the responses handed to
// Send in res, or fails every Send with sendError when that is set.
type stubStream struct {
	res       []*initproto.InitResponse
	sendError error
	grpc.ServerStream
}

// Send returns sendError if one is configured; otherwise it records m and succeeds.
func (s *stubStream) Send(m *initproto.InitResponse) error {
	if err := s.sendError; err != nil {
		// a failed send delivers nothing, so the message is not recorded
		return err
	}
	s.res = append(s.res, m)
	return nil
}

// Context returns a background context.
func (s *stubStream) Context() context.Context {
	return context.Background()
}

// stubJournaldCollector is a journald collector double that hands out a
// preconfigured log pipe and error.
type stubJournaldCollector struct {
	logPipe    io.ReadCloser
	collectErr error
}

// Start returns the configured log pipe together with collectErr.
func (s *stubJournaldCollector) Start() (io.ReadCloser, error) {
	pipe, err := s.logPipe, s.collectErr
	return pipe, err
}

// stubReadCloser is an io.ReadCloser double. Reads are served from reader
// unless readErr is set; Close returns closeErr.
type stubReadCloser struct {
	reader   io.Reader
	readErr  error
	closeErr error
}

// Read delegates to the wrapped reader when no readErr is configured.
// Otherwise it reads nothing and returns readErr.
func (s *stubReadCloser) Read(p []byte) (n int, err error) {
	if s.readErr == nil {
		return s.reader.Read(p)
	}
	return 0, s.readErr
}

// Close returns the configured closeErr.
func (s *stubReadCloser) Close() error {
	return s.closeErr
}