/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package trustedlaunch

import (
	"bytes"
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"crypto/x509/pkix"
	"io"
	"net/http"
	"testing"
	"time"

	"github.com/edgelesssys/constellation/v2/internal/attestation/simulator"
	"github.com/edgelesssys/constellation/v2/internal/crypto"
	tpmclient "github.com/google/go-tpm-tools/client"
	"github.com/google/go-tpm/tpm2"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestGetAttestationCert(t *testing.T) {
	require := require.New(t)
	tpm, err := simulator.OpenSimulatedTPM()
	require.NoError(err)
	defer tpm.Close()

	// create key in TPM
	tpmAk, err := tpmclient.NewCachedKey(tpm, tpm2.HandleOwner, tpm2.Public{
		Type:       tpm2.AlgRSA,
		NameAlg:    tpm2.AlgSHA256,
		Attributes: tpm2.FlagFixedTPM | tpm2.FlagFixedParent | tpm2.FlagSensitiveDataOrigin | tpm2.FlagUserWithAuth | tpm2.FlagNoDA | tpm2.FlagRestricted | tpm2.FlagSign,
		RSAParameters: &tpm2.RSAParams{
			Sign: &tpm2.SigScheme{
				Alg:  tpm2.AlgRSASSA,
				Hash: tpm2.AlgSHA256,
			},
			KeyBits: 2048,
		},
	}, tpmAkIdx)
	require.NoError(err)
	defer tpmAk.Close()
	akPub, err := tpmAk.PublicArea().Encode()
	require.NoError(err)

	// root certificate
	rootKey, rootTemplate := fillCertTemplate(t, &x509.Certificate{
		Subject:               pkix.Name{CommonName: "root CA"},
		IsCA:                  true,
		BasicConstraintsValid: true,
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
	})
	rootCert := newTestCert(t, rootTemplate, rootTemplate, rootKey.Public(), rootKey)

	// intermediate certificate
	intermediateKey, intermediateTemplate := fillCertTemplate(t, &x509.Certificate{
		Subject:               pkix.Name{CommonName: "intermediate CA"},
		Issuer:                rootTemplate.Subject,
		IsCA:                  true,
		BasicConstraintsValid: true,
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
	})
	intermediateCert := newTestCert(t, intermediateTemplate, rootTemplate, intermediateKey.Public(), rootKey)

	// define NV index once to avoid the need for fancy error handling later
	require.NoError(tpm2.NVDefineSpace(
		tpm, tpm2.HandleOwner, tpmAkCertIdx, "", "", []byte{},
		tpm2.AttrOwnerWrite|tpm2.AttrOwnerRead|tpm2.AttrAuthRead|tpm2.AttrAuthWrite|tpm2.AttrNoDA, 1,
	))

	defaultAkCertFunc := func(*testing.T) *x509.Certificate {
		t.Helper()
		_, certTemplate := fillCertTemplate(t, &x509.Certificate{
			IssuingCertificateURL: []string{
				"192.0.2.1/ca.crt",
			},
			Subject: pkix.Name{CommonName: "AK Certificate"},
			Issuer:  intermediateCert.Subject,
		})
		return newTestCert(t, certTemplate, intermediateCert, tpmAk.PublicKey(), intermediateKey)
	}

	testCases := map[string]struct {
		crlServer       roundTripFunc
		getAkCert       func(*testing.T) *x509.Certificate
		wantIssueErr    bool
		wantValidateErr bool
	}{
		"success": {
			crlServer: func(req *http.Request) *http.Response {
				return &http.Response{
					StatusCode: http.StatusOK,
					Body:       io.NopCloser(bytes.NewReader(intermediateCert.Raw)),
				}
			},
			getAkCert: defaultAkCertFunc,
		},
		"intermediate cert is fetched from multiple URLs": {
			crlServer: func(req *http.Request) *http.Response {
				if req.URL.String() == "192.0.2.1/ca.crt" {
					return &http.Response{StatusCode: http.StatusNotFound}
				}
				return &http.Response{
					StatusCode: http.StatusOK,
					Body:       io.NopCloser(bytes.NewReader(intermediateCert.Raw)),
				}
			},
			getAkCert: func(*testing.T) *x509.Certificate {
				t.Helper()
				_, certTemplate := fillCertTemplate(t, &x509.Certificate{
					IssuingCertificateURL: []string{
						"192.0.2.1/ca.crt",
						"192.0.2.2/ca.crt",
					},
					Subject: pkix.Name{CommonName: "AK Certificate"},
					Issuer:  intermediateCert.Subject,
				})
				return newTestCert(t, certTemplate, intermediateCert, tpmAk.PublicKey(), intermediateKey)
			},
		},
		"intermediate cert cannot be fetched": {
			crlServer: func(req *http.Request) *http.Response {
				return &http.Response{StatusCode: http.StatusNotFound}
			},
			getAkCert:    defaultAkCertFunc,
			wantIssueErr: true,
		},
		"intermediate cert is not signed by root cert": {
			crlServer: func(req *http.Request) *http.Response {
				return &http.Response{
					StatusCode: http.StatusOK,
					Body:       io.NopCloser(bytes.NewReader(rootCert.Raw)),
				}
			},
			getAkCert:       defaultAkCertFunc,
			wantValidateErr: true,
		},
		"ak does not match ak cert public key": {
			crlServer: func(req *http.Request) *http.Response {
				return &http.Response{
					StatusCode: http.StatusOK,
					Body:       io.NopCloser(bytes.NewReader(intermediateCert.Raw)),
				}
			},
			getAkCert: func(*testing.T) *x509.Certificate {
				t.Helper()
				key, certTemplate := fillCertTemplate(t, &x509.Certificate{
					IssuingCertificateURL: []string{
						"192.0.2.1/ca.crt",
					},
					Subject: pkix.Name{CommonName: "AK Certificate"},
					Issuer:  intermediateCert.Subject,
				})
				return newTestCert(t, certTemplate, intermediateCert, key.Public(), intermediateKey)
			},
			wantValidateErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			akCert := tc.getAkCert(t).Raw

			// write akCert to TPM
			require.NoError(tpm2.NVUndefineSpace(tpm, "", tpm2.HandleOwner, tpmAkCertIdx))
			require.NoError(tpm2.NVDefineSpace(
				tpm, tpm2.HandleOwner, tpmAkCertIdx, "", "", []byte{},
				tpm2.AttrOwnerWrite|tpm2.AttrOwnerRead|tpm2.AttrAuthRead|tpm2.AttrAuthWrite|tpm2.AttrNoDA,
				uint16(len(akCert)),
			))
			require.NoError(tpm2.NVWrite(tpm, tpm2.HandleOwner, tpmAkCertIdx, "", akCert, 0))

			issuer := NewIssuer()
			issuer.hClient = newTestClient(tc.crlServer)

			certs, err := issuer.getAttestationCert(tpm)
			if tc.wantIssueErr {
				assert.Error(err)
				return
			}
			require.NoError(err)

			validator := NewValidator(map[uint32][]byte{}, []uint32{}, nil)
			cert, err := x509.ParseCertificate(rootCert.Raw)
			require.NoError(err)
			roots := x509.NewCertPool()
			roots.AddCert(cert)
			validator.roots = roots

			key, err := validator.verifyAttestationKey(akPub, certs)
			if tc.wantValidateErr {
				assert.Error(err)
				return
			}
			assert.NoError(err)
			rsaKey, ok := key.(*rsa.PublicKey)
			require.True(ok)
			assert.True(rsaKey.Equal(tpmAk.PublicKey()))
		})
	}
}

// roundTripFunc lets a plain function serve as an http.RoundTripper, so tests
// can decide the response for every outgoing request in-process.
type roundTripFunc func(req *http.Request) *http.Response

// RoundTrip hands req to the wrapped function and returns whatever response it
// produces. It never reports a transport-level error.
func (rt roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp := rt(req)
	return resp, nil
}

// newTestClient builds an *http.Client whose transport is fn, so every request
// made through it is answered by fn instead of going over the network.
func newTestClient(fn roundTripFunc) *http.Client {
	client := &http.Client{}
	client.Transport = fn
	return client
}

// newTestCert creates a certificate from template, issued by parent and signed
// with priv, embedding pub as the certificate's public key. It returns the
// re-parsed certificate and fails the test immediately if creation or parsing
// fails.
func newTestCert(t *testing.T, template *x509.Certificate, parent *x509.Certificate, pub, priv any) *x509.Certificate {
	t.Helper()

	der, err := x509.CreateCertificate(rand.Reader, template, parent, pub, priv)
	require.NoError(t, err)

	parsed, err := x509.ParseCertificate(der)
	require.NoError(t, err)

	return parsed
}

// fillCertTemplate generates a fresh 2048-bit RSA key. It also completes
// template in place with a new serial number and a validity window. The window
// starts two hours in the past (presumably to tolerate clock skew) and ends one
// year from now.
//
// It returns the generated key and the same template pointer it was given.
// The test fails immediately if key or serial-number generation fails.
func fillCertTemplate(t *testing.T, template *x509.Certificate) (*rsa.PrivateKey, *x509.Certificate) {
	t.Helper()
	require := require.New(t)

	// This helper serves root, intermediate and leaf (AK) certificates alike,
	// so the failure message must not claim a specific certificate role.
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	require.NoError(err, "generating certificate key failed")

	serialNumber, err := crypto.GenerateCertificateSerialNumber()
	require.NoError(err)
	now := time.Now()

	template.SerialNumber = serialNumber
	template.NotBefore = now.Add(-2 * time.Hour)
	template.NotAfter = now.Add(24 * 365 * time.Hour)
	return key, template
}