/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package cmd

import (
	"bytes"
	"context"
	"testing"

	"github.com/edgelesssys/constellation/v2/cli/internal/cloudcmd"
	"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
	"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
	"github.com/edgelesssys/constellation/v2/internal/config"
	"github.com/edgelesssys/constellation/v2/internal/constants"
	"github.com/edgelesssys/constellation/v2/internal/constellation/helm"
	"github.com/edgelesssys/constellation/v2/internal/constellation/kubecmd"
	"github.com/edgelesssys/constellation/v2/internal/constellation/state"
	"github.com/edgelesssys/constellation/v2/internal/file"
	"github.com/edgelesssys/constellation/v2/internal/kms/uri"
	"github.com/edgelesssys/constellation/v2/internal/logger"
	"github.com/edgelesssys/constellation/v2/internal/semver"
	"github.com/edgelesssys/constellation/v2/internal/versions"
	"github.com/spf13/afero"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
)

// TestUpgradeApply drives applyCmd.apply through a table of scenarios using stubbed
// and mocked infrastructure (Terraform), Helm and Kubernetes appliers. It covers the
// success path, phase skipping, user abort, and propagation of errors from each
// dependency.
func TestUpgradeApply(t *testing.T) {
	// fsWithStateFileAndTfState returns an in-memory file handler that contains the
	// Terraform working directory and a default Azure state file.
	fsWithStateFileAndTfState := func() file.Handler {
		fh := file.NewHandler(afero.NewMemMapFs())
		require.NoError(t, fh.MkdirAll(constants.TerraformWorkingDir))
		require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultStateFile(cloudprovider.Azure)))
		return fh
	}

	testCases := map[string]struct {
		helmUpgrader      helmApplier
		kubeUpgrader      *stubKubernetesUpgrader
		fh                func() file.Handler
		// fhAssertions, if set, runs extra checks on the file system after a successful apply.
		fhAssertions      func(require *require.Assertions, assert *assert.Assertions, fh file.Handler)
		terraformUpgrader cloudApplier
		fetchImageErr     error
		wantErr           bool
		// customK8sVersion, if non-empty, overrides KubernetesVersion in the written config.
		customK8sVersion  string
		flags             applyFlags
		// stdin is fed to the command as its input stream (e.g. to answer a prompt with "no").
		stdin             string
	}{
		"success": {
			kubeUpgrader:      &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
			helmUpgrader:      &stubHelmApplier{},
			terraformUpgrader: &stubTerraformUpgrader{},
			flags:             applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
			fh:                fsWithStateFileAndTfState,
			fhAssertions: func(require *require.Assertions, assert *assert.Assertions, fh file.Handler) {
				gotState, err := state.ReadFromFile(fh, constants.StateFilename)
				require.NoError(err)
				assert.Equal("v1", gotState.Version)
				assert.Equal(defaultStateFile(cloudprovider.Azure), gotState)
			},
		},
		"id file and state file do not exist": {
			kubeUpgrader:      &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
			helmUpgrader:      &stubHelmApplier{},
			terraformUpgrader: &stubTerraformUpgrader{},
			flags:             applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
			fh: func() file.Handler {
				return file.NewHandler(afero.NewMemMapFs())
			},
			wantErr: true,
		},
		"nodeVersion some error": {
			kubeUpgrader: &stubKubernetesUpgrader{
				currentConfig:  config.DefaultForAzureSEVSNP(),
				nodeVersionErr: assert.AnError,
			},
			helmUpgrader:      &stubHelmApplier{},
			terraformUpgrader: &stubTerraformUpgrader{},
			wantErr:           true,
			flags:             applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
			fh:                fsWithStateFileAndTfState,
		},
		// kubecmd.ErrInProgress from the node upgrade is expected NOT to fail the apply.
		"nodeVersion in progress error": {
			kubeUpgrader: &stubKubernetesUpgrader{
				currentConfig:  config.DefaultForAzureSEVSNP(),
				nodeVersionErr: kubecmd.ErrInProgress,
			},
			helmUpgrader:      &stubHelmApplier{},
			terraformUpgrader: &stubTerraformUpgrader{},
			flags:             applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
			fh:                fsWithStateFileAndTfState,
		},
		"helm other error": {
			kubeUpgrader: &stubKubernetesUpgrader{
				currentConfig: config.DefaultForAzureSEVSNP(),
			},
			helmUpgrader:      &stubHelmApplier{err: assert.AnError},
			terraformUpgrader: &stubTerraformUpgrader{},
			wantErr:           true,
			flags:             applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
			fh:                fsWithStateFileAndTfState,
		},
		// No "yes" flag and a pending Terraform diff: answering "no" on stdin
		// (presumably to a confirmation prompt) must end the apply with an error.
		"abort": {
			kubeUpgrader: &stubKubernetesUpgrader{
				currentConfig: config.DefaultForAzureSEVSNP(),
			},
			helmUpgrader:      &stubHelmApplier{},
			terraformUpgrader: &stubTerraformUpgrader{terraformDiff: true},
			wantErr:           true,
			stdin:             "no\n",
			fh:                fsWithStateFileAndTfState,
		},
		"abort, restore terraform err": {
			kubeUpgrader: &stubKubernetesUpgrader{
				currentConfig: config.DefaultForAzureSEVSNP(),
			},
			helmUpgrader:      stubHelmApplier{},
			terraformUpgrader: &stubTerraformUpgrader{terraformDiff: true, rollbackWorkspaceErr: assert.AnError},
			wantErr:           true,
			stdin:             "no\n",
			fh:                fsWithStateFileAndTfState,
		},
		"plan terraform error": {
			kubeUpgrader: &stubKubernetesUpgrader{
				currentConfig: config.DefaultForAzureSEVSNP(),
			},
			helmUpgrader:      stubHelmApplier{},
			terraformUpgrader: &stubTerraformUpgrader{planTerraformErr: assert.AnError},
			wantErr:           true,
			flags:             applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
			fh:                fsWithStateFileAndTfState,
		},
		"apply terraform error": {
			kubeUpgrader: &stubKubernetesUpgrader{
				currentConfig: config.DefaultForAzureSEVSNP(),
			},
			helmUpgrader: stubHelmApplier{},
			terraformUpgrader: &stubTerraformUpgrader{
				applyTerraformErr: assert.AnError,
				terraformDiff:     true,
			},
			wantErr: true,
			flags:   applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
			fh:      fsWithStateFileAndTfState,
		},
		// One patch release below the first supported K8s version is still accepted.
		// NOTE(review): this assumes that version's patch number is > 0 — confirm.
		"outdated K8s patch version": {
			kubeUpgrader: &stubKubernetesUpgrader{
				currentConfig: config.DefaultForAzureSEVSNP(),
			},
			helmUpgrader:      stubHelmApplier{},
			terraformUpgrader: &stubTerraformUpgrader{},
			customK8sVersion: func() string {
				v, err := semver.New(versions.SupportedK8sVersions()[0])
				require.NoError(t, err)
				return semver.NewFromInt(v.Major(), v.Minor(), v.Patch()-1, "").String()
			}(),
			flags: applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
			fh:    fsWithStateFileAndTfState,
		},
		"outdated K8s version": {
			kubeUpgrader: &stubKubernetesUpgrader{
				currentConfig: config.DefaultForAzureSEVSNP(),
			},
			helmUpgrader:      stubHelmApplier{},
			terraformUpgrader: &stubTerraformUpgrader{},
			customK8sVersion:  "v1.20.0",
			flags:             applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
			wantErr:           true,
			fh:                fsWithStateFileAndTfState,
		},
		"skip all upgrade phases": {
			kubeUpgrader: &stubKubernetesUpgrader{
				currentConfig: config.DefaultForAzureSEVSNP(),
			},
			helmUpgrader:      &mockApplier{}, // mocks ensure that no methods are called
			terraformUpgrader: &mockTerraformUpgrader{},
			flags: applyFlags{
				skipPhases: newPhases(skipInfrastructurePhase, skipAttestationConfigPhase, skipCertSANsPhase, skipHelmPhase, skipK8sPhase, skipImagePhase),
				yes:        true,
			},
			fh: fsWithStateFileAndTfState,
		},
		"skip all phases except node upgrade": {
			kubeUpgrader: &stubKubernetesUpgrader{
				currentConfig: config.DefaultForAzureSEVSNP(),
			},
			helmUpgrader:      &mockApplier{}, // mocks ensure that no methods are called
			terraformUpgrader: &mockTerraformUpgrader{},
			flags: applyFlags{
				skipPhases: newPhases(skipInfrastructurePhase, skipAttestationConfigPhase, skipCertSANsPhase, skipHelmPhase, skipK8sPhase),
				yes:        true,
			},
			fh: fsWithStateFileAndTfState,
		},
		// Without a Terraform working directory the apply must still succeed when the
		// infrastructure phase is skipped; the mock fails the test if Terraform is touched.
		"no tf state, infra phase skipped": {
			kubeUpgrader: &stubKubernetesUpgrader{
				currentConfig: config.DefaultForAzureSEVSNP(),
			},
			helmUpgrader:      &stubHelmApplier{},
			terraformUpgrader: &mockTerraformUpgrader{},
			flags: applyFlags{
				yes:        true,
				skipPhases: newPhases(skipInfrastructurePhase),
			},
			fh: func() file.Handler {
				fh := file.NewHandler(afero.NewMemMapFs())
				require.NoError(t, fh.WriteYAML(constants.StateFilename, defaultStateFile(cloudprovider.Azure)))
				return fh
			},
		},
		// The cluster reports an Azure Trusted Launch attestation config.
		// NOTE(review): the test config is presumably Azure SEV-SNP, so this is a variant change and must fail.
		"attempt to change attestation variant": {
			kubeUpgrader:      &stubKubernetesUpgrader{currentConfig: &config.AzureTrustedLaunch{}},
			helmUpgrader:      &stubHelmApplier{},
			terraformUpgrader: &stubTerraformUpgrader{},
			flags:             applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
			fh:                fsWithStateFileAndTfState,
			wantErr:           true,
		},
		"image fetching fails": {
			kubeUpgrader:      &stubKubernetesUpgrader{currentConfig: config.DefaultForAzureSEVSNP()},
			helmUpgrader:      &stubHelmApplier{},
			terraformUpgrader: &stubTerraformUpgrader{},
			fetchImageErr:     assert.AnError,
			flags:             applyFlags{yes: true, skipPhases: skipPhases{skipInitPhase: struct{}{}}},
			fh:                fsWithStateFileAndTfState,
			wantErr:           true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)
			cmd := newUpgradeApplyCmd()
			cmd.SetIn(bytes.NewBufferString(tc.stdin))

			cfg := defaultConfigWithExpectedMeasurements(t, config.Default(), cloudprovider.Azure)
			if tc.customK8sVersion != "" {
				cfg.KubernetesVersion = versions.ValidK8sVersion(tc.customK8sVersion)
			}
			// Every case gets an (empty) admin kubeconfig, the config file and a master
			// secret on top of whatever files tc.fh provides.
			fh := tc.fh()
			require.NoError(fh.Write(constants.AdminConfFilename, []byte{}))
			require.NoError(fh.WriteYAML(constants.ConfigFilename, cfg))
			require.NoError(fh.WriteJSON(constants.MasterSecretFilename, uri.MasterSecret{}))

			upgrader := &applyCmd{
				fileHandler: fh,
				flags:       tc.flags,
				log:         logger.NewTest(t),
				spinner:     &nopSpinner{},
				merger:      &stubMerger{},
				newInfraApplier: func(ctx context.Context) (cloudApplier, func(), error) {
					return tc.terraformUpgrader, func() {}, nil
				},
				applier: &stubConstellApplier{
					stubKubernetesUpgrader: tc.kubeUpgrader,
					helmApplier:            tc.helmUpgrader,
				},
				imageFetcher: &stubImageFetcher{fetchReferenceErr: tc.fetchImageErr},
			}
			err := upgrader.apply(cmd, stubAttestationFetcher{}, "test")
			if tc.wantErr {
				assert.Error(err)
				return
			}
			assert.NoError(err)
			// The node image upgrade must have run exactly when the image phase was not skipped.
			assert.Equal(!tc.flags.skipPhases.contains(skipImagePhase), tc.kubeUpgrader.calledNodeUpgrade,
				"incorrect node upgrade skipping behavior")

			if tc.fhAssertions != nil {
				tc.fhAssertions(require, assert, fh)
			}
		})
	}
}

// stubKubernetesUpgrader is a test double for the Kubernetes upgrade operations
// used by applyCmd. It records which upgrade and backup methods were invoked and
// hands back the preconfigured results.
type stubKubernetesUpgrader struct {
	nodeVersionErr                 error                 // returned by UpgradeNodeImage
	kubernetesVersionErr           error                 // returned by UpgradeKubernetesVersion
	currentConfig                  config.AttestationCfg // returned by GetClusterAttestationConfig
	getClusterAttestationConfigErr error                 // returned by GetClusterAttestationConfig
	calledNodeUpgrade              bool                  // set by UpgradeNodeImage
	calledKubernetesUpgrade        bool                  // set by UpgradeKubernetesVersion
	backupCRDsErr                  error                 // returned by BackupCRDs
	backupCRDsCalled               bool                  // set by BackupCRDs
	backupCRsErr                   error                 // returned by BackupCRs
	backupCRsCalled                bool                  // set by BackupCRs
}

// BackupCRDs marks itself as called and returns an empty (non-nil) CRD list
// together with backupCRDsErr.
func (s *stubKubernetesUpgrader) BackupCRDs(context.Context, file.Handler, string) ([]apiextensionsv1.CustomResourceDefinition, error) {
	s.backupCRDsCalled = true
	crds := make([]apiextensionsv1.CustomResourceDefinition, 0)
	return crds, s.backupCRDsErr
}

// BackupCRs marks itself as called and returns backupCRsErr.
func (s *stubKubernetesUpgrader) BackupCRs(context.Context, file.Handler, []apiextensionsv1.CustomResourceDefinition, string) error {
	s.backupCRsCalled = true
	return s.backupCRsErr
}

// UpgradeNodeImage marks the node upgrade as called and returns nodeVersionErr.
func (s *stubKubernetesUpgrader) UpgradeNodeImage(context.Context, semver.Semver, string, bool) error {
	s.calledNodeUpgrade = true
	return s.nodeVersionErr
}

// UpgradeKubernetesVersion marks the Kubernetes upgrade as called and returns
// kubernetesVersionErr.
func (s *stubKubernetesUpgrader) UpgradeKubernetesVersion(context.Context, versions.ValidK8sVersion, bool) error {
	s.calledKubernetesUpgrade = true
	return s.kubernetesVersionErr
}

// ApplyJoinConfig is a no-op that always succeeds.
func (s *stubKubernetesUpgrader) ApplyJoinConfig(context.Context, config.AttestationCfg, []byte) error {
	return nil
}

// GetClusterAttestationConfig returns currentConfig and getClusterAttestationConfigErr,
// regardless of the requested variant.
func (s *stubKubernetesUpgrader) GetClusterAttestationConfig(context.Context, variant.Variant) (config.AttestationCfg, error) {
	cfg, err := s.currentConfig, s.getClusterAttestationConfigErr
	return cfg, err
}

// ExtendClusterConfigCertSANs is a no-op that always succeeds.
func (s *stubKubernetesUpgrader) ExtendClusterConfigCertSANs(context.Context, string, string, []string) error {
	return nil
}

// stubTerraformUpgrader is a value-receiver test double for the infrastructure
// applier. Each method ignores its arguments and returns the preconfigured
// result; no Terraform workspace is touched.
type stubTerraformUpgrader struct {
	terraformDiff        bool  // boolean result of Plan
	planTerraformErr     error // error result of Plan
	applyTerraformErr    error // error result of Apply
	rollbackWorkspaceErr error // error result of RestoreWorkspace
}

// Plan returns terraformDiff and planTerraformErr.
func (s stubTerraformUpgrader) Plan(context.Context, *config.Config) (bool, error) {
	return s.terraformDiff, s.planTerraformErr
}

// Apply returns a zero-valued state.Infrastructure together with applyTerraformErr.
func (s stubTerraformUpgrader) Apply(context.Context, cloudprovider.Provider, variant.Variant, cloudcmd.RollbackBehavior) (state.Infrastructure, error) {
	infra := state.Infrastructure{}
	return infra, s.applyTerraformErr
}

// RestoreWorkspace returns rollbackWorkspaceErr.
func (s stubTerraformUpgrader) RestoreWorkspace() error {
	return s.rollbackWorkspaceErr
}

// WorkingDirIsEmpty always reports a non-empty working directory and no error.
func (s stubTerraformUpgrader) WorkingDirIsEmpty() (bool, error) {
	isEmpty := false
	return isEmpty, nil
}

// mockTerraformUpgrader is a testify mock for the infrastructure applier.
// With no expectations registered via On, any method call makes testify panic.
// Test cases use this to assert that the infrastructure phase is never reached.
type mockTerraformUpgrader struct {
	mock.Mock
}

// Plan records the call and returns the expectation's (bool, error) results.
func (m *mockTerraformUpgrader) Plan(ctx context.Context, conf *config.Config) (bool, error) {
	ret := m.Called(ctx, conf)
	hasDiff, err := ret.Bool(0), ret.Error(1)
	return hasDiff, err
}

// Apply records the call and returns the expectation's (state.Infrastructure, error)
// results. The first return value must be registered as a state.Infrastructure.
func (m *mockTerraformUpgrader) Apply(ctx context.Context, provider cloudprovider.Provider, attestationVariant variant.Variant, rollback cloudcmd.RollbackBehavior) (state.Infrastructure, error) {
	ret := m.Called(ctx, provider, attestationVariant, rollback)
	infra := ret.Get(0).(state.Infrastructure)
	return infra, ret.Error(1)
}

// RestoreWorkspace records the call and returns the expectation's error.
func (m *mockTerraformUpgrader) RestoreWorkspace() error {
	ret := m.Called()
	return ret.Error(0)
}

// WorkingDirIsEmpty records the call and returns the expectation's (bool, error) results.
func (m *mockTerraformUpgrader) WorkingDirIsEmpty() (bool, error) {
	ret := m.Called()
	isEmpty, err := ret.Bool(0), ret.Error(1)
	return isEmpty, err
}

// mockApplier is a testify mock for the Helm chart applier.
// With no expectations registered via On, any method call makes testify panic.
// Test cases use this to assert that the Helm phase is never reached.
type mockApplier struct {
	mock.Mock
}

// PrepareHelmCharts records the call with testify and returns the matching
// expectation's results: a helm.Applier, a bool and an error.
//
// The recorded arguments mirror the method's parameters one-to-one and in order,
// so expectations are written as
// On("PrepareHelmCharts", helmOpts, stateFile, str, masterSecret, openStackCfg).
func (m *mockApplier) PrepareHelmCharts(
	helmOpts helm.Options, stateFile *state.State, str string, masterSecret uri.MasterSecret, openStackCfg *config.OpenStackConfig,
) (helm.Applier, bool, error) {
	// Each parameter is recorded exactly once. The previous version recorded helmOpts
	// twice, so expectations matching the real signature could never match.
	args := m.Called(helmOpts, stateFile, str, masterSecret, openStackCfg)
	// Comma-ok lets an expectation return a nil applier (e.g. alongside an error)
	// without a type-assertion panic.
	applier, _ := args.Get(0).(helm.Applier)
	return applier, args.Bool(1), args.Error(2)
}

// stubImageFetcher is a test double for the image fetcher. FetchReference
// ignores its arguments and returns the configured reference and error.
type stubImageFetcher struct {
	reference         string // reference returned by FetchReference
	fetchReferenceErr error  // error returned by FetchReference
}

// FetchReference returns reference and fetchReferenceErr unchanged.
func (s *stubImageFetcher) FetchReference(
	context.Context, cloudprovider.Provider, variant.Variant, string, string, bool,
) (string, error) {
	ref, err := s.reference, s.fetchReferenceErr
	return ref, err
}