/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package nitrotpm

import (
	"context"
	"errors"
	"os"
	"testing"

	"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
	"github.com/aws/smithy-go/middleware"
	"github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
	tpmclient "github.com/google/go-tpm-tools/client"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestGetAttestationKey checks that getAttestationKey returns the same RSA
// attestation key that tpmclient.AttestationKeyRSA created on a simulated TPM.
func TestGetAttestationKey(t *testing.T) {
	cgo := os.Getenv("CGO_ENABLED")
	if cgo == "0" {
		t.Skip("skipping test because CGO is disabled")
	}
	require := require.New(t)
	assert := assert.New(t)

	tpm, err := simulator.OpenSimulatedTPM()
	require.NoError(err)
	defer tpm.Close()

	// Create the attestation key in RSA format. Everything below depends on
	// this succeeding, so abort immediately instead of cascading failures.
	tpmAk, err := tpmclient.AttestationKeyRSA(tpm)
	require.NoError(err)
	require.NotNil(tpmAk)

	// Get the cached, already created key.
	getAk, err := getAttestationKey(tpm)
	require.NoError(err)
	require.NotNil(getAk)

	// If everything worked fine, tpmAk and getAk are the same key.
	assert.Equal(tpmAk, getAk)
}

// TestGetInstanceInfo checks that the function returned by getInstanceInfo
// propagates errors from the metadata API and returns non-nil data otherwise.
func TestGetInstanceInfo(t *testing.T) {
	cgo := os.Getenv("CGO_ENABLED")
	if cgo == "0" {
		t.Skip("skipping test because CGO is disabled and tpm simulator requires it")
	}
	// In the wantErr cases, the failure is driven by instanceErr on the stub;
	// the document fields only describe the scenario.
	testCases := map[string]struct {
		client  stubMetadataAPI
		wantErr bool
	}{
		"invalid region": {
			client: stubMetadataAPI{
				instanceDoc: imds.InstanceIdentityDocument{
					Region: "invalid-region",
				},
				instanceErr: errors.New("failed"),
			},
			wantErr: true,
		},
		"valid region": {
			client: stubMetadataAPI{
				instanceDoc: imds.InstanceIdentityDocument{
					Region: "us-east-2",
				},
			},
		},
		"invalid imageID": {
			client: stubMetadataAPI{
				instanceDoc: imds.InstanceIdentityDocument{
					ImageID: "ami-fail",
				},
				instanceErr: errors.New("failed"),
			},
			wantErr: true,
		},
		"valid imageID": {
			client: stubMetadataAPI{
				instanceDoc: imds.InstanceIdentityDocument{
					ImageID: "ami-09e7c7f5617a47830",
				},
			},
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			require := require.New(t)
			assert := assert.New(t)

			// A nil TPM would make the deferred Close panic, so stop here on failure.
			tpm, err := simulator.OpenSimulatedTPM()
			require.NoError(err)
			defer tpm.Close()

			// Calling a nil func would panic, so stop here on failure.
			instanceInfoFunc := getInstanceInfo(&tc.client)
			require.NotNil(instanceInfoFunc)

			info, err := instanceInfoFunc(context.Background(), tpm, nil)
			if tc.wantErr {
				assert.Error(err)
				assert.Nil(info)
				return
			}
			assert.NoError(err)
			assert.NotNil(info)
		})
	}
}

// stubMetadataAPI is a test double for the EC2 instance metadata client.
type stubMetadataAPI struct {
	// instanceDoc is the instance identity document configured for this stub.
	instanceDoc imds.InstanceIdentityDocument
	// instanceErr is the error returned from GetInstanceIdentityDocument.
	instanceErr error
}

// GetInstanceIdentityDocument returns the stub's configured instanceDoc.
// If instanceErr is set, it returns a nil output together with that error,
// mirroring a real client that produces no document on failure.
func (c *stubMetadataAPI) GetInstanceIdentityDocument(context.Context, *imds.GetInstanceIdentityDocumentInput, ...func(*imds.Options)) (*imds.GetInstanceIdentityDocumentOutput, error) {
	if c.instanceErr != nil {
		return nil, c.instanceErr
	}

	return &imds.GetInstanceIdentityDocumentOutput{
		InstanceIdentityDocument: c.instanceDoc,
		ResultMetadata:           middleware.Metadata{},
	}, nil
}