/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package server

import (
	"context"
	"errors"
	"testing"

	"github.com/edgelesssys/constellation/v2/internal/kms/kms"
	"github.com/edgelesssys/constellation/v2/internal/logger"
	"github.com/edgelesssys/constellation/v2/keyservice/keyserviceproto"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/goleak"
)

func TestMain(m *testing.M) {
	goleak.VerifyTestMain(m, goleak.IgnoreAnyFunction("github.com/bazelbuild/rules_go/go/tools/bzltestutil.RegisterTimeoutHandler.func1"))
}

func TestGetDataKey(t *testing.T) {
	assert := assert.New(t)
	require := require.New(t)

	log := logger.NewTest(t)

	kms := &stubKMS{derivedKey: []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5}}
	api := New(log, kms)

	res, err := api.GetDataKey(context.Background(), &keyserviceproto.GetDataKeyRequest{DataKeyId: "1", Length: 32})
	require.NoError(err)
	assert.Equal(kms.derivedKey, res.DataKey)

	// Test no data key id
	res, err = api.GetDataKey(context.Background(), &keyserviceproto.GetDataKeyRequest{Length: 32})
	require.Error(err)
	assert.Nil(res)

	// Test no / zero key length
	res, err = api.GetDataKey(context.Background(), &keyserviceproto.GetDataKeyRequest{DataKeyId: "1"})
	require.Error(err)
	assert.Nil(res)

	// Test derive key error
	api = New(log, &stubKMS{deriveKeyErr: errors.New("error")})
	res, err = api.GetDataKey(context.Background(), &keyserviceproto.GetDataKeyRequest{DataKeyId: "1", Length: 32})
	assert.Error(err)
	assert.Nil(res)
}

type stubKMS struct {
	kms.CloudKMS
	masterKey    []byte
	derivedKey   []byte
	deriveKeyErr error
}

func (c *stubKMS) CreateKEK(_ context.Context, _ string, kek []byte) error {
	c.masterKey = kek
	return nil
}

func (c *stubKMS) GetDEK(_ context.Context, _ string, _ int) ([]byte, error) {
	if c.deriveKeyErr != nil {
		return nil, c.deriveKeyErr
	}
	return c.derivedKey, nil
}