/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package snp

import (
	"bytes"
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"regexp"
	"testing"

	"github.com/edgelesssys/constellation/v2/internal/attestation"
	"github.com/edgelesssys/constellation/v2/internal/attestation/idkeydigest"
	"github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
	"github.com/edgelesssys/constellation/v2/internal/attestation/snp/testdata"
	"github.com/edgelesssys/constellation/v2/internal/attestation/vtpm"
	"github.com/edgelesssys/constellation/v2/internal/config"
	"github.com/edgelesssys/constellation/v2/internal/logger"
	"github.com/google/go-sev-guest/abi"
	"github.com/google/go-sev-guest/kds"
	spb "github.com/google/go-sev-guest/proto/sevsnp"
	"github.com/google/go-sev-guest/validate"
	"github.com/google/go-sev-guest/verify"
	"github.com/google/go-sev-guest/verify/trust"
	"github.com/google/go-tpm-tools/client"
	"github.com/google/go-tpm-tools/proto/attest"
	"github.com/google/go-tpm/legacy/tpm2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestNewValidator tests the creation of a new validator.
func TestNewValidator(t *testing.T) {
	require := require.New(t)

	testCases := map[string]struct {
		cfg    *config.AzureSEVSNP
		logger attestation.Logger
	}{
		"success": {
			cfg:    config.DefaultForAzureSEVSNP(),
			logger: logger.NewTest(t),
		},
		"nil logger": {
			cfg:    config.DefaultForAzureSEVSNP(),
			logger: nil,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(_ *testing.T) {
			validator := NewValidator(tc.cfg, tc.logger)
			require.NotNil(validator)
			require.NotNil(validator.log)
		})
	}
}

// stubHTTPSGetter is a canned HTTPS getter for tests. It either fails every
// request with err or hands the URL to urlResponseMatcher.
type stubHTTPSGetter struct {
	urlResponseMatcher *urlResponseMatcher // maps responses to requested URLs
	err                error
}

// newStubHTTPSGetter builds a stubHTTPSGetter from a response matcher and an
// error; a non-nil error makes every Get call fail with it.
func newStubHTTPSGetter(urlResponseMatcher *urlResponseMatcher, err error) *stubHTTPSGetter {
	getter := new(stubHTTPSGetter)
	getter.urlResponseMatcher = urlResponseMatcher
	getter.err = err
	return getter
}

// Get returns the configured error if one is set; otherwise it returns whatever
// the response matcher yields for url.
func (s *stubHTTPSGetter) Get(url string) ([]byte, error) {
	if s.err == nil {
		return s.urlResponseMatcher.match(url)
	}
	return nil, s.err
}

// urlResponseMatcher maps AMD KDS request URLs to canned responses and records
// which kinds of request a test case expects to see.
type urlResponseMatcher struct {
	certChainResponse    []byte
	wantCertChainRequest bool
	vcekResponse         []byte
	wantVcekRequest      bool
}

// stubCertChainURL is the exact URL treated as a certificate chain request.
const stubCertChainURL = "https://kdsintf.amd.com/vcek/v1/Milan/cert_chain"

// stubVCEKURLPattern recognizes VCEK requests. It is compiled once at package
// scope instead of on every match call. The pattern is unanchored, so it
// matches any URL that contains the Milan VCEK prefix.
var stubVCEKURLPattern = regexp.MustCompile(`https:\/\/kdsintf.amd.com\/vcek\/v1\/Milan\/.*`)

// match returns the canned response for url.
//
// The certificate chain URL is checked first because it would also satisfy the
// VCEK pattern. A request of a kind the test case did not opt into (via
// wantCertChainRequest / wantVcekRequest) yields an error, as does any URL that
// is neither a certificate chain nor a VCEK request.
func (m *urlResponseMatcher) match(url string) ([]byte, error) {
	switch {
	case url == stubCertChainURL:
		if !m.wantCertChainRequest {
			return nil, errors.New("unexpected cert_chain request")
		}
		return m.certChainResponse, nil
	case stubVCEKURLPattern.MatchString(url):
		if !m.wantVcekRequest {
			return nil, errors.New("unexpected VCEK request")
		}
		return m.vcekResponse, nil
	default:
		return nil, fmt.Errorf("unexpected URL: %s", url)
	}
}

// TestCheckIDKeyDigest exercises checkIDKeyDigest with matching and mismatching
// ID key digests under the different enforcement policies.
func TestCheckIDKeyDigest(t *testing.T) {
	// buildConfig returns the default Azure SEV-SNP config with the given
	// accepted digests and enforcement policy applied.
	buildConfig := func(enforcementPolicy idkeydigest.Enforcement, digestStrings []string) *config.AzureSEVSNP {
		accepted := idkeydigest.List{}
		for i := range digestStrings {
			accepted = append(accepted, []byte(digestStrings[i]))
		}
		snpCfg := config.DefaultForAzureSEVSNP()
		snpCfg.FirmwareSignerConfig.EnforcementPolicy = enforcementPolicy
		snpCfg.FirmwareSignerConfig.AcceptedKeyDigests = accepted
		return snpCfg
	}
	// buildReport returns an attestation whose report carries only the given ID key digest.
	buildReport := func(idKeyDigest string) *spb.Attestation {
		return &spb.Attestation{
			Report: &spb.Report{IdKeyDigest: []byte(idKeyDigest)},
		}
	}
	// buildValidator returns a Validator whose MAA validation is stubbed to
	// return validateTokenErr.
	buildValidator := func(cfg *config.AzureSEVSNP, validateTokenErr error) *Validator {
		v := NewValidator(cfg, logger.NewTest(t))
		v.maa = &stubMaaValidator{validateTokenErr: validateTokenErr}
		return v
	}

	testCases := map[string]struct {
		idKeyDigest          string
		acceptedIDKeyDigests []string
		enforcementPolicy    idkeydigest.Enforcement
		validateMaaTokenErr  error
		wantErr              bool
	}{
		"matching digest": {
			acceptedIDKeyDigests: []string{"test"},
			idKeyDigest:          "test",
		},
		"no accepted digests": {
			acceptedIDKeyDigests: []string{},
			idKeyDigest:          "test",
			wantErr:              true,
		},
		"mismatching digest, enforce": {
			acceptedIDKeyDigests: []string{"other"},
			idKeyDigest:          "test",
			wantErr:              true,
		},
		"mismatching digest, maaFallback": {
			acceptedIDKeyDigests: []string{"other"},
			idKeyDigest:          "test",
			enforcementPolicy:    idkeydigest.MAAFallback,
		},
		"mismatching digest, maaFallback errors": {
			acceptedIDKeyDigests: []string{"other"},
			idKeyDigest:          "test",
			enforcementPolicy:    idkeydigest.MAAFallback,
			validateMaaTokenErr:  errors.New("maa fallback failed"),
			wantErr:              true,
		},
		"mismatching digest, warnOnly": {
			acceptedIDKeyDigests: []string{"other"},
			idKeyDigest:          "test",
			enforcementPolicy:    idkeydigest.WarnOnly,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			validator := buildValidator(
				buildConfig(tc.enforcementPolicy, tc.acceptedIDKeyDigests),
				tc.validateMaaTokenErr,
			)
			attestationReport := buildReport(tc.idKeyDigest)

			err := validator.checkIDKeyDigest(context.Background(), attestationReport, "", nil)
			if !tc.wantErr {
				require.NoError(t, err)
				return
			}
			require.Error(t, err)
		})
	}
}

// stubMaaValidator is a test double for MAA token validation that always
// answers with a preconfigured result.
type stubMaaValidator struct {
	validateTokenErr error
}

// validateToken ignores all of its arguments and returns validateTokenErr.
func (m *stubMaaValidator) validateToken(_ context.Context, _, _ string, _ []byte) error {
	return m.validateTokenErr
}

// TestTrustedKeyFromSNP tests the verification and validation of the attestation report
// performed by Validator.getTrustedKey.
//
// Every case builds an attestation document from (possibly modified) test data,
// injects stub getter/verifier/validator implementations into a Validator and
// checks whether getTrustedKey succeeds or fails as expected.
func TestTrustedKeyFromSNP(t *testing.T) {
	cgo := os.Getenv("CGO_ENABLED")
	if cgo == "0" {
		t.Skip("skipping test because CGO is disabled and tpm simulator requires it")
	}

	// Create an attestation key on a simulated TPM; its encoded public area is
	// used as AkPub in every attestation document below.
	tpm, err := simulator.OpenSimulatedTPM()
	require.NoError(t, err)
	defer tpm.Close()
	key, err := client.AttestationKeyRSA(tpm)
	require.NoError(t, err)
	defer key.Close()
	akPub, err := key.PublicArea().Encode()
	require.NoError(t, err)

	// NOTE: defaultCfg is a single config instance shared by all subtests. Its
	// FirmwareSignerConfig is overwritten per case in the loop below, and the
	// "launch tcb < minimum launch tcb" case changes its MicrocodeVersion.
	defaultCfg := config.DefaultForAzureSEVSNP()
	defaultReport := hex.EncodeToString(testdata.AttestationReport)
	defaultRuntimeData := hex.EncodeToString(testdata.RuntimeData)
	defaultIDKeyDigestOld, err := hex.DecodeString("57e229e0ffe5fa92d0faddff6cae0e61c926fc9ef9afd20a8b8cfcf7129db9338cbe5bf3f6987733a2bf65d06dc38fc1")
	require.NoError(t, err)
	defaultIDKeyDigest := idkeydigest.NewList([][]byte{defaultIDKeyDigestOld})
	// defaultVerifier delegates to verify.SnpAttestation; skipVerifier returns nil
	// without verifying, which lets cases with a modified report reach the
	// validation step.
	defaultVerifier := &stubAttestationVerifier{}
	skipVerifier := &stubAttestationVerifier{skipCheck: true}
	defaultValidator := &stubAttestationValidator{}

	// reportTransformer unpacks the hex-encoded report, applies the given transformations and re-encodes it.
	// The transformations run immediately, i.e. while the test table below is being constructed.
	reportTransformer := func(reportHex string, transformations func(*spb.Report)) string {
		rawReport, err := hex.DecodeString(reportHex)
		require.NoError(t, err)
		report, err := abi.ReportToProto(rawReport)
		require.NoError(t, err)
		transformations(report)
		reportBytes, err := abi.ReportToAbiBytes(report)
		require.NoError(t, err)
		return hex.EncodeToString(reportBytes)
	}

	// Each case describes one getTrustedKey invocation:
	//   - report / runtimeData: hex-encoded inputs, decoded by newStubInstanceInfo.
	//   - vcek / certChain: certificates placed into the instance info as
	//     ReportSigner / CertChain; omitted in the "fetch ..." cases.
	//   - acceptedIDKeyDigests / enforcementPolicy: written into
	//     defaultCfg.FirmwareSignerConfig before the call.
	//   - getter / verifier / validator: stubs injected into the Validator under test.
	//   - wantErr: whether getTrustedKey must fail; assertion, if set, is run on
	//     that error for additional checks.
	testCases := map[string]struct {
		report               string
		runtimeData          string
		vcek                 []byte
		certChain            []byte
		acceptedIDKeyDigests idkeydigest.List
		enforcementPolicy    idkeydigest.Enforcement
		getter               *stubHTTPSGetter
		verifier             *stubAttestationVerifier
		validator            *stubAttestationValidator
		wantErr              bool
		assertion            func(*assert.Assertions, error)
	}{
		"success": {
			report:               defaultReport,
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
		},
		"certificate fetch error": {
			report:               defaultReport,
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			getter: newStubHTTPSGetter(
				nil,
				assert.AnError,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "retrieving VCEK certificate from AMD KDS")
			},
		},
		"fetch vcek and certchain": {
			report:               defaultReport,
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{
					vcekResponse:         testdata.AmdKdsVCEK,
					wantVcekRequest:      true,
					certChainResponse:    testdata.CertChain,
					wantCertChainRequest: true,
				},
				nil,
			),
		},
		"fetch vcek": {
			report:               defaultReport,
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{
					vcekResponse:    testdata.AmdKdsVCEK,
					wantVcekRequest: true,
				},
				nil,
			),
		},
		"fetch certchain": {
			report:               defaultReport,
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{
					certChainResponse:    testdata.CertChain,
					wantCertChainRequest: true,
				},
				nil,
			),
		},
		"invalid report signature": {
			report: reportTransformer(defaultReport, func(r *spb.Report) {
				r.Signature = make([]byte, 512)
			}),
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "report signature verification error")
			},
		},
		"invalid vcek": {
			report:               defaultReport,
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			vcek:                 []byte("invalid"),
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{
					vcekResponse:    []byte("invalid"),
					wantVcekRequest: true,
				},
				nil,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "could not interpret VCEK DER bytes: x509: malformed certificate")
			},
		},
		"invalid certchain fall back to embedded": {
			report:               defaultReport,
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            []byte("invalid"),
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{
					certChainResponse:    []byte("invalid"),
					wantCertChainRequest: true,
				},
				nil,
			),
			wantErr: true,
		},
		"invalid runtime data": {
			report: defaultReport,
			// Dropping the last 10 hex characters (5 bytes) truncates the runtime data.
			runtimeData:          defaultRuntimeData[:len(defaultRuntimeData)-10],
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "validating HCLAkPub: unmarshalling json: unexpected end of JSON input")
			},
		},
		"inacceptable idkeydigest (wrong size), enforce": {
			report:               defaultReport,
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: idkeydigest.List{[]byte{0x00}},
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "bad hash size in TrustedIDKeyHashes")
			},
		},
		"inacceptable idkeydigest (wrong value), enforce": {
			report:               defaultReport,
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: idkeydigest.List{make([]byte, 48)},
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "report ID key not trusted")
			},
		},
		"inacceptable idkeydigest, warn only": {
			report:               defaultReport,
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: idkeydigest.List{[]byte{0x00}},
			enforcementPolicy:    idkeydigest.WarnOnly,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
		},
		"launch tcb < minimum launch tcb": {
			report: reportTransformer(defaultReport, func(r *spb.Report) {
				launchTcb := kds.DecomposeTCBVersion(kds.TCBVersion(r.LaunchTcb))
				// NOTE: side effect on the shared defaultCfg. This closure runs while
				// the table is being built, so the changed microcode minimum is in
				// place for every subtest, not only this one.
				defaultCfg.MicrocodeVersion.Value = 10
				launchTcb.UcodeSpl = 9
				newLaunchTcb, err := kds.ComposeTCBParts(launchTcb)
				require.NoError(t, err)
				r.LaunchTcb = uint64(newLaunchTcb)
			}),
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.WarnOnly,
			verifier:             skipVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "is lower than the policy minimum launch TCB")
			},
		},
		"reported tcb < minimum tcb": {
			report: reportTransformer(defaultReport, func(r *spb.Report) {
				reportedTcb := kds.DecomposeTCBVersion(kds.TCBVersion(r.ReportedTcb))
				// Reads the minimum from the shared defaultCfg, which the
				// "launch tcb < minimum launch tcb" case above also modifies.
				reportedTcb.UcodeSpl = defaultCfg.MicrocodeVersion.Value - 1
				newReportedTcb, err := kds.ComposeTCBParts(reportedTcb)
				require.NoError(t, err)
				r.ReportedTcb = uint64(newReportedTcb)
			}),
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.WarnOnly,
			verifier:             skipVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "is lower than the policy minimum TCB")
			},
		},
		"current tcb < committed tcb": {
			report: reportTransformer(defaultReport, func(r *spb.Report) {
				r.CurrentTcb = r.CommittedTcb - 1
			}),
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.WarnOnly,
			verifier:             skipVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "is lower than the report's COMMITTED_TCB")
			},
		},
		"current tcb < tcb in vcek": {
			report: reportTransformer(defaultReport, func(r *spb.Report) {
				currentTcb := kds.DecomposeTCBVersion(kds.TCBVersion(r.CurrentTcb))
				currentTcb.UcodeSpl = 0x5c // testdata.AzureThimVCEK has ucode version 0x5d
				newCurrentTcb, err := kds.ComposeTCBParts(currentTcb)
				require.NoError(t, err)
				r.CurrentTcb = uint64(newCurrentTcb)
			}),
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.WarnOnly,
			verifier:             skipVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "is lower than the TCB of the V[CL]EK certificate")
			},
		},
		"reported tcb != tcb in vcek": {
			report: reportTransformer(defaultReport, func(r *spb.Report) {
				r.ReportedTcb = uint64(0)
			}),
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.WarnOnly,
			verifier:             skipVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "does not match the TCB of the V[CL]EK certificate")
			},
		},
		"vmpl != 0": {
			report: reportTransformer(defaultReport, func(r *spb.Report) {
				r.Vmpl = 1
			}),
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             skipVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			certChain:            testdata.CertChain,
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "report VMPL 1 is not 0")
			},
		},
		"invalid ASK": {
			report:               defaultReport,
			runtimeData:          defaultRuntimeData,
			acceptedIDKeyDigests: defaultIDKeyDigest,
			enforcementPolicy:    idkeydigest.Equal,
			verifier:             defaultVerifier,
			validator:            defaultValidator,
			vcek:                 testdata.AzureThimVCEK,
			// Corrupt a private copy of the chain so the shared test data stays intact.
			certChain: func() []byte {
				c := make([]byte, len(testdata.CertChain))
				copy(c, testdata.CertChain)
				c[1676] = 0x41 // somewhere in the ASK signature
				return c
			}(),
			getter: newStubHTTPSGetter(
				&urlResponseMatcher{},
				nil,
			),
			wantErr: true,
			assertion: func(assert *assert.Assertions, err error) {
				assert.ErrorContains(err, "crypto/rsa: verification error")
			},
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			// This is important. Without this call, the trust module caches certificates across testcases.
			defer trust.ClearProductCertCache()

			// Assemble the attestation document: JSON-encoded instance info plus the
			// simulated TPM's attestation key.
			instanceInfo, err := newStubInstanceInfo(tc.vcek, tc.certChain, tc.report, tc.runtimeData)
			require.NoError(err)

			statement, err := json.Marshal(instanceInfo)
			require.NoError(err)

			attDoc := vtpm.AttestationDocument{
				InstanceInfo: statement,
				Attestation: &attest.Attestation{
					AkPub: akPub,
				},
			}

			// Apply this case's ID key digest policy to the shared config.
			defaultCfg.FirmwareSignerConfig = config.SNPFirmwareSignerConfig{
				AcceptedKeyDigests: tc.acceptedIDKeyDigests,
				EnforcementPolicy:  tc.enforcementPolicy,
			}

			validator := &Validator{
				hclValidator:         &stubAttestationKey{},
				config:               defaultCfg,
				log:                  logger.NewTest(t),
				getter:               tc.getter,
				attestationVerifier:  tc.verifier,
				attestationValidator: tc.validator,
			}

			key, err := validator.getTrustedKey(context.Background(), attDoc, nil)
			if tc.wantErr {
				assert.Error(err)
				if tc.assertion != nil {
					tc.assertion(assert, err)
				}
			} else {
				assert.NoError(err)
				assert.NotNil(key)
			}
		})
	}
}

// stubAttestationVerifier wraps verify.SnpAttestation and lets tests bypass it.
type stubAttestationVerifier struct {
	skipCheck bool // when true, SNPAttestation returns nil without verifying anything
}

// SNPAttestation verifies the attestation with verify.SnpAttestation using the
// given options, unless skipCheck is set, in which case it reports success
// without verifying.
func (v *stubAttestationVerifier) SNPAttestation(attestation *spb.Attestation, options *verify.Options) error {
	if !v.skipCheck {
		return verify.SnpAttestation(attestation, options)
	}
	return nil
}

// stubAttestationValidator wraps validate.SnpAttestation and lets tests bypass it.
type stubAttestationValidator struct {
	skipCheck bool // when true, SNPAttestation returns nil without validating anything
}

// SNPAttestation validates the attestation report against the given set of
// constraints via validate.SnpAttestation, unless skipCheck is set, in which
// case it reports success without validating.
func (v *stubAttestationValidator) SNPAttestation(attestation *spb.Attestation, options *validate.Options) error {
	if !v.skipCheck {
		return validate.SnpAttestation(attestation, options)
	}
	return nil
}

// stubInstanceInfo is the instance info payload the tests JSON-encode into the
// attestation document.
type stubInstanceInfo struct {
	AttestationReport []byte
	ReportSigner      []byte
	CertChain         []byte
	Azure             *stubAzureInstanceInfo
}

// stubAzureInstanceInfo holds the Azure-specific part of stubInstanceInfo.
type stubAzureInstanceInfo struct {
	RuntimeData []byte
}

// newStubInstanceInfo builds a stubInstanceInfo from raw certificates and
// hex-encoded report and runtime data.
//
// vcek is stored as ReportSigner and certChain as CertChain, both unmodified.
// report and runtimeData are hex-decoded first; if either is not valid hex, a
// zero stubInstanceInfo and an error wrapping the decode error are returned.
func newStubInstanceInfo(vcek, certChain []byte, report, runtimeData string) (stubInstanceInfo, error) {
	validReport, err := hex.DecodeString(report)
	if err != nil {
		// %w keeps the decode error in the chain for errors.Is / errors.As.
		return stubInstanceInfo{}, fmt.Errorf("invalid hex string report: %w", err)
	}

	decodedRuntime, err := hex.DecodeString(runtimeData)
	if err != nil {
		return stubInstanceInfo{}, fmt.Errorf("invalid hex string runtimeData: %w", err)
	}

	return stubInstanceInfo{
		AttestationReport: validReport,
		ReportSigner:      vcek,
		CertChain:         certChain,
		Azure: &stubAzureInstanceInfo{
			RuntimeData: decodedRuntime,
		},
	}, nil
}

// stubAttestationKey is the test implementation used as the Validator's hclValidator.
type stubAttestationKey struct{}

// Validate checks that runtimeDataRaw is JSON that unmarshals into the (empty)
// stub and that its SHA-256 digest equals the first 32 bytes of reportData.
//
// It returns an error if the JSON cannot be unmarshalled, or if reportData is
// shorter than a SHA-256 digest or does not start with the expected digest.
// The RSA parameters are ignored.
func (s *stubAttestationKey) Validate(runtimeDataRaw []byte, reportData []byte, _ *tpm2.RSAParams) error {
	if err := json.Unmarshal(runtimeDataRaw, s); err != nil {
		return fmt.Errorf("unmarshalling json: %w", err)
	}

	sum := sha256.Sum256(runtimeDataRaw)
	// Guard the length first: slicing a too-short reportData would panic
	// instead of reporting the mismatch.
	if len(reportData) < sha256.Size || !bytes.Equal(sum[:], reportData[:sha256.Size]) {
		return errors.New("unexpected runtimeData digest in TPM")
	}

	return nil
}