/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package proto

import (
	"context"
	"errors"
	"io"

	"github.com/edgelesssys/constellation/disk-mapper/recoverproto"
	"github.com/edgelesssys/constellation/internal/atls"
	"github.com/edgelesssys/constellation/internal/attestation"
	"github.com/edgelesssys/constellation/internal/crypto"
	"github.com/edgelesssys/constellation/internal/grpc/atlscredentials"
	"go.uber.org/multierr"
	"google.golang.org/grpc"
)

// RecoverClient wraps a recoverAPI client and the connection to it.
// The zero value is an unconnected client; call Connect before Recover.
type RecoverClient struct {
	// conn is the gRPC connection established by Connect and released by Close.
	// It is nil while the client is not connected.
	conn       *grpc.ClientConn
	// recoverapi is the recovery API client created on top of conn by Connect.
	// Recover refuses to run while it is nil.
	recoverapi recoverproto.APIClient
}

// Connect dials the server at endpoint and stores the resulting connection
// and recovery API client on c. The given validator is used for the
// attestation of the connection.
//
// The new connection is dialed before any existing one is touched, so a
// failed dial returns the error and leaves a previously established
// connection in place. On success, a previously held connection is closed
// and replaced. The caller must release the connection with Close().
func (c *RecoverClient) Connect(endpoint string, validators atls.Validator) error {
	transportCreds := grpc.WithTransportCredentials(
		atlscredentials.New(nil, []atls.Validator{validators}),
	)

	newConn, dialErr := grpc.Dial(endpoint, transportCreds)
	if dialErr != nil {
		return dialErr
	}

	// Best effort: an error from closing the old connection is deliberately
	// ignored, since the old connection is replaced right below either way.
	_ = c.Close()

	c.conn, c.recoverapi = newConn, recoverproto.NewAPIClient(newConn)
	return nil
}

// Close closes the grpc connection of the client and resets the client to
// its unconnected state.
// Close is idempotent and can be called on non connected clients
// without returning an error.
//
// The client state is cleared even if closing the underlying connection
// fails: the error is returned once, and subsequent calls return nil instead
// of retrying the close on the same connection. Because the API client is
// cleared as well, Recover on a closed client reports that the client is not
// connected rather than issuing a call on a closed connection.
func (c *RecoverClient) Close() error {
	if c.conn == nil {
		return nil
	}

	// Detach the connection before closing it so the client never keeps a
	// reference to a connection that has already been (attempted to be) closed.
	conn := c.conn
	c.conn = nil
	c.recoverapi = nil

	return conn.Close()
}

// Recover pushes the measurement secret and the state disk key to a
// constellation instance in recovery mode.
//
// The exchange over the recovery stream is:
//  1. send the measurement secret derived from masterSecret and salt,
//  2. receive a response carrying the UUID of the instance's state disk,
//  3. send the state disk key derived from masterSecret, salt, and that UUID,
//  4. wait for a final response or the end of the stream.
//
// Recover returns an error if the client is not connected, if a key
// derivation fails, or if any stream operation fails. An error from closing
// the send direction of the stream is appended to the returned error.
func (c *RecoverClient) Recover(ctx context.Context, masterSecret, salt []byte) (retErr error) {
	if c.recoverapi == nil {
		return errors.New("client is not connected")
	}

	measurementSecret, err := attestation.DeriveMeasurementSecret(masterSecret, salt)
	if err != nil {
		return err
	}

	// Tie the stream to a context that is canceled when Recover returns.
	// Without this, every early return below would leave the stream (and its
	// resources) alive until the caller's context ends.
	streamCtx, cancel := context.WithCancel(ctx)
	defer cancel()

	recoverclient, err := c.recoverapi.Recover(streamCtx)
	if err != nil {
		return err
	}
	// Registered after cancel, so it runs first: the send direction is closed
	// before the stream context is canceled.
	defer func() {
		if err := recoverclient.CloseSend(); err != nil {
			multierr.AppendInto(&retErr, err)
		}
	}()

	if err := recoverclient.Send(&recoverproto.RecoverMessage{
		Request: &recoverproto.RecoverMessage_MeasurementSecret{
			MeasurementSecret: measurementSecret,
		},
	}); err != nil {
		return err
	}
	res, err := recoverclient.Recv()
	if err != nil {
		return err
	}

	// The state disk key is bound to the disk UUID reported by the server.
	stateDiskKey, err := deriveStateDiskKey(masterSecret, salt, res.DiskUuid)
	if err != nil {
		return err
	}

	if err := recoverclient.Send(&recoverproto.RecoverMessage{
		Request: &recoverproto.RecoverMessage_StateDiskKey{
			StateDiskKey: stateDiskKey,
		},
	}); err != nil {
		return err
	}

	// The server may end the stream without a further message; io.EOF is
	// therefore treated as success here.
	if _, err := recoverclient.Recv(); err != nil && !errors.Is(err, io.EOF) {
		return err
	}
	return nil
}

// deriveStateDiskKey derives the key for the state disk identified by
// diskUUID from masterKey and salt. The derivation info is
// crypto.HKDFInfoPrefix followed by the disk UUID, and the requested key
// length is crypto.StateDiskKeyLength.
func deriveStateDiskKey(masterKey, salt []byte, diskUUID string) ([]byte, error) {
	info := []byte(crypto.HKDFInfoPrefix + diskUUID)
	return crypto.DeriveKey(masterKey, salt, info, crypto.StateDiskKeyLength)
}