/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package cmd

import (
	"bytes"
	"context"
	"errors"
	"net"
	"strconv"
	"testing"
	"time"

	"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
	"github.com/edgelesssys/constellation/v2/cli/internal/clusterid"
	"github.com/edgelesssys/constellation/v2/disk-mapper/recoverproto"
	"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
	"github.com/edgelesssys/constellation/v2/internal/config"
	"github.com/edgelesssys/constellation/v2/internal/constants"
	"github.com/edgelesssys/constellation/v2/internal/crypto/testvector"
	"github.com/edgelesssys/constellation/v2/internal/file"
	"github.com/edgelesssys/constellation/v2/internal/grpc/atlscredentials"
	"github.com/edgelesssys/constellation/v2/internal/grpc/dialer"
	"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
	"github.com/spf13/afero"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

func TestRecoverCmdArgumentValidation(t *testing.T) {
	testCases := map[string]struct {
		args    []string
		wantErr bool
	}{
		"no args":            {[]string{}, false},
		"too many arguments": {[]string{"abc"}, true},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			cmd := NewRecoverCmd()
			err := cmd.ValidateArgs(tc.args)

			if tc.wantErr {
				assert.Error(err)
			} else {
				assert.NoError(err)
			}
		})
	}
}

// TestRecover runs the recover command against a stubDoer whose scripted
// error sequence drives the command's retry logic. Each case checks whether
// the command fails, the number of successful recoveries it reports, and the
// message it prints when no node was recovered.
func TestRecover(t *testing.T) {
	someErr := errors.New("error")
	unavailableErr := status.Error(codes.Unavailable, "unavailable")
	// Transient handshake failure attributed to the load balancer; see the
	// "transient LB errors are retried" case for the expected handling.
	lbErr := status.Error(codes.Unavailable, `connection error: desc = "transport: authentication handshake failed: read tcp`)

	testCases := map[string]struct {
		doer            *stubDoer
		masterSecret    testvector.HKDF
		endpoint        string
		configFlag      string
		successfulCalls int
		wantErr         bool
	}{
		"works": {
			doer:            &stubDoer{returns: []error{nil}},
			endpoint:        "192.0.2.90",
			masterSecret:    testvector.HKDFZero,
			successfulCalls: 1,
		},
		"missing config": {
			doer:         &stubDoer{returns: []error{nil}},
			endpoint:     "192.0.2.89",
			masterSecret: testvector.HKDFZero,
			configFlag:   "nonexistent-config",
			wantErr:      true,
		},
		"success multiple nodes": {
			doer:            &stubDoer{returns: []error{nil, nil}},
			endpoint:        "192.0.2.90",
			masterSecret:    testvector.HKDFZero,
			successfulCalls: 2,
		},
		"no nodes to recover does not error": {
			doer:            &stubDoer{returns: []error{unavailableErr}},
			endpoint:        "192.0.2.90",
			masterSecret:    testvector.HKDFZero,
			successfulCalls: 0,
		},
		"error on first node": {
			doer:            &stubDoer{returns: []error{someErr, nil}},
			endpoint:        "192.0.2.90",
			masterSecret:    testvector.HKDFZero,
			successfulCalls: 0,
			wantErr:         true,
		},
		"unavailable error is retried once": {
			doer:            &stubDoer{returns: []error{unavailableErr, nil}},
			endpoint:        "192.0.2.90",
			masterSecret:    testvector.HKDFZero,
			successfulCalls: 1,
		},
		"unavailable error is not retried twice": {
			doer:            &stubDoer{returns: []error{unavailableErr, unavailableErr, nil}},
			endpoint:        "192.0.2.90",
			masterSecret:    testvector.HKDFZero,
			successfulCalls: 0,
		},
		"unavailable error is not retried twice after success": {
			doer:            &stubDoer{returns: []error{nil, unavailableErr, unavailableErr, nil}},
			endpoint:        "192.0.2.90",
			masterSecret:    testvector.HKDFZero,
			successfulCalls: 1,
		},
		"transient LB errors are retried": {
			doer:            &stubDoer{returns: []error{lbErr, lbErr, lbErr, nil}},
			endpoint:        "192.0.2.90",
			masterSecret:    testvector.HKDFZero,
			successfulCalls: 1,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			cmd := NewRecoverCmd()
			cmd.SetContext(context.Background())
			cmd.Flags().String("config", constants.ConfigFilename, "") // register persistent flag manually
			out := &bytes.Buffer{}
			cmd.SetOut(out)
			cmd.SetErr(out)
			require.NoError(cmd.Flags().Set("endpoint", tc.endpoint))

			if tc.configFlag != "" {
				require.NoError(cmd.Flags().Set("config", tc.configFlag))
			}

			fs := afero.NewMemMapFs()
			fileHandler := file.NewHandler(fs)

			// Named cfg (not config) so the imported config package stays
			// usable for the rest of this scope.
			cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.GCP)
			require.NoError(fileHandler.WriteYAML(constants.ConfigFilename, cfg))

			require.NoError(fileHandler.WriteJSON(
				"constellation-mastersecret.json",
				masterSecret{Key: tc.masterSecret.Secret, Salt: tc.masterSecret.Salt},
				file.OptNone,
			))

			// The stubDoer ignores the dialer, so a nil dialer is sufficient.
			newDialer := func(*cloudcmd.Validator) *dialer.Dialer { return nil }

			err := recover(cmd, fileHandler, time.Millisecond, tc.doer, newDialer)
			if tc.wantErr {
				assert.Error(err)
				if tc.successfulCalls > 0 {
					assert.Contains(out.String(), strconv.Itoa(tc.successfulCalls))
				}
				return
			}

			assert.NoError(err)
			if tc.successfulCalls > 0 {
				assert.Contains(out.String(), "Pushed recovery key.")
				assert.Contains(out.String(), strconv.Itoa(tc.successfulCalls))
			} else {
				assert.Contains(out.String(), "No control-plane nodes in need of recovery found.")
			}
		})
	}
}

// TestParseRecoverFlags checks how parseRecoverFlags combines command-line
// flags with the cluster ID file. Without an explicit endpoint, the IP from the
// ID file is used; a missing ID file or a malformed endpoint is an error.
func TestParseRecoverFlags(t *testing.T) {
	testCases := map[string]struct {
		args        []string
		wantFlags   recoverFlags
		writeIDFile bool
		wantErr     bool
	}{
		"no flags": {
			wantFlags: recoverFlags{
				endpoint:   "192.0.2.42:9999",
				secretPath: "constellation-mastersecret.json",
			},
			writeIDFile: true,
		},
		"no flags, no ID file": {
			wantFlags: recoverFlags{
				endpoint:   "192.0.2.42:9999",
				secretPath: "constellation-mastersecret.json",
			},
			wantErr: true,
		},
		"invalid endpoint": {
			args:    []string{"-e", "192.0.2.42:2:2"},
			wantErr: true,
		},
		"all args set": {
			args: []string{"-e", "192.0.2.42:2", "--config", "config-path", "--master-secret", "/path/super-secret.json"},
			wantFlags: recoverFlags{
				endpoint:   "192.0.2.42:2",
				secretPath: "/path/super-secret.json",
				configPath: "config-path",
			},
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			// Prepare the in-memory filesystem, optionally seeding the ID file.
			handler := file.NewHandler(afero.NewMemMapFs())
			if tc.writeIDFile {
				idFile := &clusterid.File{IP: "192.0.2.42"}
				require.NoError(handler.WriteJSON(constants.ClusterIDsFileName, idFile))
			}

			// --config is a persistent flag normally provided by a parent
			// command, so register it here by hand.
			cmd := NewRecoverCmd()
			cmd.Flags().String("config", "", "")
			require.NoError(cmd.ParseFlags(tc.args))

			got, err := parseRecoverFlags(cmd, handler)
			if !tc.wantErr {
				assert.NoError(err)
				assert.Equal(tc.wantFlags, got)
				return
			}
			assert.Error(err)
		})
	}
}

func TestDoRecovery(t *testing.T) {
	someErr := errors.New("error")
	testCases := map[string]struct {
		recoveryServer *stubRecoveryServer
		wantErr        bool
	}{
		"success": {
			recoveryServer: &stubRecoveryServer{
				actions: [][]func(stream recoverproto.API_RecoverServer) error{{
					func(stream recoverproto.API_RecoverServer) error {
						_, err := stream.Recv()
						return err
					},
					func(stream recoverproto.API_RecoverServer) error {
						return stream.Send(&recoverproto.RecoverResponse{
							DiskUuid: "00000000-0000-0000-0000-000000000000",
						})
					},
					func(stream recoverproto.API_RecoverServer) error {
						_, err := stream.Recv()
						return err
					},
				}},
			},
		},
		"error on first recv": {
			recoveryServer: &stubRecoveryServer{
				actions: [][]func(stream recoverproto.API_RecoverServer) error{
					{
						func(stream recoverproto.API_RecoverServer) error {
							return someErr
						},
					},
				},
			},
			wantErr: true,
		},
		"error on send": {
			recoveryServer: &stubRecoveryServer{
				actions: [][]func(stream recoverproto.API_RecoverServer) error{
					{
						func(stream recoverproto.API_RecoverServer) error {
							_, err := stream.Recv()
							return err
						},
						func(stream recoverproto.API_RecoverServer) error {
							return someErr
						},
					},
				},
			},
			wantErr: true,
		},
		"error on second recv": {
			recoveryServer: &stubRecoveryServer{
				actions: [][]func(stream recoverproto.API_RecoverServer) error{
					{
						func(stream recoverproto.API_RecoverServer) error {
							_, err := stream.Recv()
							return err
						},
						func(stream recoverproto.API_RecoverServer) error {
							return stream.Send(&recoverproto.RecoverResponse{
								DiskUuid: "00000000-0000-0000-0000-000000000000",
							})
						},
						func(stream recoverproto.API_RecoverServer) error {
							return someErr
						},
					},
				},
			},
			wantErr: true,
		},
		"final message is an error": {
			recoveryServer: &stubRecoveryServer{
				actions: [][]func(stream recoverproto.API_RecoverServer) error{{
					func(stream recoverproto.API_RecoverServer) error {
						_, err := stream.Recv()
						return err
					},
					func(stream recoverproto.API_RecoverServer) error {
						return stream.Send(&recoverproto.RecoverResponse{
							DiskUuid: "00000000-0000-0000-0000-000000000000",
						})
					},
					func(stream recoverproto.API_RecoverServer) error {
						_, err := stream.Recv()
						return err
					},
					func(stream recoverproto.API_RecoverServer) error {
						return someErr
					},
				}},
			},
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			netDialer := testdialer.NewBufconnDialer()
			serverCreds := atlscredentials.New(nil, nil)
			recoverServer := grpc.NewServer(grpc.Creds(serverCreds))
			recoverproto.RegisterAPIServer(recoverServer, tc.recoveryServer)
			addr := net.JoinHostPort("192.0.42.42", strconv.Itoa(constants.RecoveryPort))
			listener := netDialer.GetListener(addr)
			go recoverServer.Serve(listener)
			defer recoverServer.GracefulStop()

			recoverDoer := &recoverDoer{
				dialer:            dialer.New(nil, nil, netDialer),
				endpoint:          addr,
				measurementSecret: []byte("measurement-secret"),
				getDiskKey: func(string) ([]byte, error) {
					return []byte("disk-key"), nil
				},
			}

			err := recoverDoer.Do(context.Background())
			if tc.wantErr {
				assert.Error(err)
			} else {
				assert.NoError(err)
			}
		})
	}
}

// TestDeriveStateDiskKey checks getStateDiskKeyFunc against the HKDF test
// vectors: deriving with the vector's secret, salt and info must reproduce
// the vector's expected output.
func TestDeriveStateDiskKey(t *testing.T) {
	testCases := map[string]struct {
		masterSecret testvector.HKDF
	}{
		"all zero": {masterSecret: testvector.HKDFZero},
		"all 0xff": {masterSecret: testvector.HKDF0xFF},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			vector := tc.masterSecret
			deriveKey := getStateDiskKeyFunc(vector.Secret, vector.Salt)

			got, err := deriveKey(vector.Info)

			assert.NoError(t, err)
			assert.Equal(t, vector.Output, got)
		})
	}
}

// stubRecoveryServer is a scripted recoverproto API server. The n-th call to
// Recover runs the n-th entry of actions against the stream and stops at the
// first step that returns an error. When no scripted calls are left, Recover
// answers with a codes.Unavailable status error.
type stubRecoveryServer struct {
	// actions holds one list of steps per Recover call, consumed in order.
	actions [][]func(recoverproto.API_RecoverServer) error
	// calls counts how many entries of actions have been consumed.
	calls int
	recoverproto.UnimplementedAPIServer
}

// Recover runs the next scripted step list on stream and returns the first
// step error, or nil if every step succeeds.
func (s *stubRecoveryServer) Recover(stream recoverproto.API_RecoverServer) error {
	if s.calls >= len(s.actions) {
		return status.Error(codes.Unavailable, "server is unavailable")
	}
	steps := s.actions[s.calls]
	s.calls++

	for i := range steps {
		if err := steps[i](stream); err != nil {
			return err
		}
	}
	return nil
}

// stubDoer is a scripted recovery doer. Each call to Do returns the next error
// queued in returns. Once the queue is exhausted, or if it starts out empty,
// Do returns a codes.Unavailable status error on every further call.
type stubDoer struct {
	returns []error
}

// Do pops and returns the next scripted error.
func (d *stubDoer) Do(context.Context) error {
	// Guard against an empty script; indexing d.returns[0] would panic.
	if len(d.returns) == 0 {
		return status.Error(codes.Unavailable, "unavailable")
	}
	err := d.returns[0]
	d.returns = d.returns[1:]
	return err
}

// setDialer is a no-op; the stub never dials.
func (d *stubDoer) setDialer(grpcDialer, string) {}

// setSecrets is a no-op; the stub does not use the provided key material.
func (d *stubDoer) setSecrets(func(string) ([]byte, error), []byte) {}