/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package server

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"

	"github.com/edgelesssys/constellation/v2/internal/grpc/testdialer"
	"github.com/edgelesssys/constellation/v2/internal/logger"
	"github.com/edgelesssys/constellation/v2/verify/verifyproto"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"
)

func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m, goleak.IgnoreAnyFunction("github.com/bazelbuild/rules_go/go/tools/bzltestutil.RegisterTimeoutHandler.func1"))
}

func TestRun(t *testing.T) {
	assert := assert.New(t)
	closedErr := errors.New("closed")

	var err error
	var wg sync.WaitGroup
	s := &Server{
		log:    logger.NewTest(t),
		issuer: stubIssuer{attestation: []byte("quote")},
	}

	httpListener, grpcListener := setUpTestListeners()
	wg.Add(1)
	go func() {
		defer wg.Done()
		err = s.Run(httpListener, grpcListener)
	}()
	assert.NoError(httpListener.Close())
	wg.Wait()
	assert.Equal(err, closedErr)

	httpListener, grpcListener = setUpTestListeners()
	wg.Add(1)
	go func() {
		defer wg.Done()
		err = s.Run(httpListener, grpcListener)
	}()
	assert.NoError(grpcListener.Close())
	wg.Wait()
	assert.Equal(err, closedErr)

	httpListener, grpcListener = setUpTestListeners()
	wg.Add(1)
	go func() {
		defer wg.Done()
		err = s.Run(httpListener, grpcListener)
	}()
	go assert.NoError(grpcListener.Close())
	go assert.NoError(httpListener.Close())
	wg.Wait()
	assert.Equal(err, closedErr)
}

func TestGetAttestationGRPC(t *testing.T) {
	testCases := map[string]struct {
		issuer  stubIssuer
		request *verifyproto.GetAttestationRequest
		wantErr bool
	}{
		"success": {
			issuer: stubIssuer{attestation: []byte("quote")},
			request: &verifyproto.GetAttestationRequest{
				Nonce: []byte("nonce"),
			},
		},
		"issuer fails": {
			issuer: stubIssuer{issueErr: errors.New("issuer error")},
			request: &verifyproto.GetAttestationRequest{
				Nonce: []byte("nonce"),
			},
			wantErr: true,
		},
		"no nonce": {
			issuer:  stubIssuer{attestation: []byte("quote")},
			request: &verifyproto.GetAttestationRequest{},
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)

			server := &Server{
				log:    logger.NewTest(t),
				issuer: tc.issuer,
			}

			resp, err := server.GetAttestation(context.Background(), tc.request)
			if tc.wantErr {
				assert.Error(err)
			} else {
				assert.NoError(err)
				assert.Equal(tc.issuer.attestation, resp.Attestation)
			}
		})
	}
}

func TestGetAttestationHTTP(t *testing.T) {
	testCases := map[string]struct {
		request string
		issuer  stubIssuer
		wantErr bool
	}{
		"success": {
			request: "?nonce=" + base64.URLEncoding.EncodeToString([]byte("nonce")),
			issuer:  stubIssuer{attestation: []byte("quote")},
		},
		"invalid nonce in query": {
			request: "?nonce=not-base-64",
			issuer:  stubIssuer{attestation: []byte("quote")},
			wantErr: true,
		},
		"no nonce in query": {
			request: "?foo=bar",
			issuer:  stubIssuer{attestation: []byte("quote")},
			wantErr: true,
		},
		"empty nonce in query": {
			request: "?nonce=",
			issuer:  stubIssuer{attestation: []byte("quote")},
			wantErr: true,
		},
		"issuer fails": {
			request: "?nonce=" + base64.URLEncoding.EncodeToString([]byte("nonce")),
			issuer:  stubIssuer{issueErr: errors.New("errors")},
			wantErr: true,
		},
	}

	for name, tc := range testCases {
		t.Run(name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)

			server := &Server{
				log:    logger.NewTest(t),
				issuer: tc.issuer,
			}

			httpServer := httptest.NewServer(http.HandlerFunc(server.getAttestationHTTP))
			defer httpServer.Close()

			req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, httpServer.URL+tc.request, nil)
			require.NoError(err)
			resp, err := http.DefaultClient.Do(req)
			require.NoError(err)
			defer resp.Body.Close()

			if tc.wantErr {
				assert.NotEqual(http.StatusOK, resp.StatusCode)
				return
			}
			assert.Equal(http.StatusOK, resp.StatusCode)
			quote, err := io.ReadAll(resp.Body)
			require.NoError(err)

			var rawQuote attestation
			require.NoError(json.Unmarshal(quote, &rawQuote))

			assert.Equal(tc.issuer.attestation, rawQuote.Data)
		})
	}
}

func setUpTestListeners() (net.Listener, net.Listener) {
	httpListener := testdialer.NewBufconnDialer().GetListener(net.JoinHostPort("192.0.2.1", "8080"))
	grpcListener := testdialer.NewBufconnDialer().GetListener(net.JoinHostPort("192.0.2.1", "8081"))
	return httpListener, grpcListener
}

type stubIssuer struct {
	attestation []byte
	issueErr    error
}

func (i stubIssuer) Issue(_ context.Context, _ []byte, _ []byte) ([]byte, error) {
	return i.attestation, i.issueErr
}