/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

/*
Package rejoinclient handles the automatic rejoining of a restarting node.

It does so by continuously sending rejoin requests to the JoinService of available control-plane endpoints.
*/
package rejoinclient

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"strconv"
	"time"

	"github.com/edgelesssys/constellation/v2/internal/cloud/metadata"
	"github.com/edgelesssys/constellation/v2/internal/constants"
	"github.com/edgelesssys/constellation/v2/internal/role"
	"github.com/edgelesssys/constellation/v2/joinservice/joinproto"
	"google.golang.org/grpc"
	"k8s.io/utils/clock"
)

// Defaults installed by New:
//   - interval is the tick period Start waits between two rounds of
//     endpoint discovery and rejoin attempts.
//   - timeout is the deadline of each context created by timeoutCtx, i.e. for
//     one endpoint discovery round or one rejoin request.
const (
	interval = 30 * time.Second
	timeout  = 30 * time.Second
)

// RejoinClient is a client for requesting the needed information
// for rejoining a cluster as a restarting worker or control-plane node.
type RejoinClient struct {
	// Identity of this node: diskUUID is stored by Start and sent with every
	// rejoin request; nodeInfo supplies the role and VPC IP used to drop this
	// node from the endpoint list when it is a control-plane node.
	diskUUID string
	nodeInfo metadata.InstanceMetadata

	// Timing: timeout is the deadline for contexts created by timeoutCtx,
	// interval is the ticker period between retry rounds, and clock creates
	// that ticker (New installs clock.RealClock).
	timeout  time.Duration
	interval time.Duration
	clock    clock.WithTicker

	// External dependencies: dialer opens the gRPC connection to an endpoint,
	// metadataAPI provides the load balancer endpoint and the instance list.
	dialer      grpcDialer
	metadataAPI metadataAPI

	log *slog.Logger
}

// New builds a RejoinClient for the node described by nodeInfo.
//
// dial is used to connect to JoinService endpoints, meta to discover those
// endpoints, and log for all output of the client. The retry interval and the
// request timeout are set to the package defaults and the real clock is used.
func New(dial grpcDialer, nodeInfo metadata.InstanceMetadata,
	meta metadataAPI, log *slog.Logger,
) *RejoinClient {
	// Start from the fixed defaults, then attach the caller's dependencies.
	client := &RejoinClient{
		timeout:  timeout,
		interval: interval,
		clock:    clock.RealClock{},
	}
	client.nodeInfo = nodeInfo
	client.dialer = dial
	client.metadataAPI = meta
	client.log = log
	return client
}

// Start runs the rejoin loop until a rejoin ticket has been obtained.
//
// Each round fetches the current JoinService endpoints and asks them, one
// after another, for a rejoin ticket for diskUUID. If the round fails, Start
// waits for the next tick of the retry interval and tries again.
//
// On success the state disk key and measurement secret from the ticket are
// returned. If ctx is done while waiting for the next round, Start returns
// nil for both values.
func (c *RejoinClient) Start(ctx context.Context, diskUUID string) (diskKey, measurementSecret []byte) {
	c.log.Info("Starting RejoinClient")
	c.diskUUID = diskUUID

	retryTicker := c.clock.NewTicker(c.interval)
	defer retryTicker.Stop()
	defer c.log.Info("RejoinClient stopped")

	for {
		endpoints, discoverErr := c.getJoinEndpoints()
		if discoverErr == nil {
			c.log.With(slog.Any("endpoints", endpoints)).Info("Received list with JoinService endpoints")
			key, secret, rejoinErr := c.tryRejoinWithAvailableServices(ctx, endpoints)
			if rejoinErr == nil {
				c.log.Info("Successfully retrieved rejoin ticket")
				return key, secret
			}
		} else {
			c.log.With(slog.Any("error", discoverErr)).Error("Failed to get control-plane endpoints")
		}

		// Round failed: wait for the next tick unless the caller gave up.
		select {
		case <-ctx.Done():
			return nil, nil
		case <-retryTicker.C():
		}
	}
}

// tryRejoinWithAvailableServices asks each endpoint in order for a rejoin ticket
// and returns the key material of the first ticket it receives.
//
// A failing endpoint is logged and skipped. If ctx is done after a failed
// request, the remaining endpoints are not tried and ctx.Err() is returned.
// If no endpoint issues a ticket, an error is returned.
func (c *RejoinClient) tryRejoinWithAvailableServices(ctx context.Context, endpoints []string) (diskKey, measurementSecret []byte, err error) {
	for _, endpoint := range endpoints {
		c.log.With(slog.String("endpoint", endpoint)).Info("Requesting rejoin ticket")

		ticket, requestErr := c.requestRejoinTicket(endpoint)
		if requestErr != nil {
			c.log.With(slog.Any("error", requestErr), slog.String("endpoint", endpoint)).Warn("Failed to rejoin on endpoint")

			// stop requesting additional endpoints if the context is done
			if ctxErr := ctx.Err(); ctxErr != nil {
				return nil, nil, ctxErr
			}
			continue
		}

		return ticket.StateDiskKey, ticket.MeasurementSecret, nil
	}

	c.log.Error("Failed to rejoin on all endpoints")
	return nil, nil, errors.New("failed to join on all endpoints")
}

// requestRejoinTicket dials endpoint and issues a single rejoin ticket request
// for the client's disk UUID.
//
// The request runs under a fresh timeout context from timeoutCtx. The
// connection is closed before returning. Dial errors and request errors are
// returned unchanged.
func (c *RejoinClient) requestRejoinTicket(endpoint string) (*joinproto.IssueRejoinTicketResponse, error) {
	ctx, cancel := c.timeoutCtx()
	defer cancel()

	conn, dialErr := c.dialer.Dial(endpoint)
	if dialErr != nil {
		return nil, dialErr
	}
	defer conn.Close()

	client := joinproto.NewAPIClient(conn)
	request := &joinproto.IssueRejoinTicketRequest{DiskUuid: c.diskUUID}
	return client.IssueRejoinTicket(ctx, request)
}

// getJoinEndpoints assembles the list of JoinService endpoints to contact.
//
// The first entry is the load balancer host, followed by the VPC IP of every
// control-plane instance reported by the metadata API that has one. All
// entries use the JoinService node port; the port reported for the load
// balancer is ignored. If *this* node is a control-plane node, entries
// pointing at its own VPC IP are removed.
//
// Both metadata lookups share one timeout context. An error from either
// lookup is wrapped and returned.
func (c *RejoinClient) getJoinEndpoints() ([]string, error) {
	ctx, cancel := c.timeoutCtx()
	defer cancel()

	joinPort := strconv.Itoa(constants.JoinServiceNodePort)

	lbHost, _, err := c.metadataAPI.GetLoadBalancerEndpoint(ctx)
	if err != nil {
		return nil, fmt.Errorf("retrieving load balancer endpoint from cloud provider: %w", err)
	}

	instances, err := c.metadataAPI.List(ctx)
	if err != nil {
		return nil, fmt.Errorf("retrieving instances list from cloud provider: %w", err)
	}

	endpoints := []string{net.JoinHostPort(lbHost, joinPort)}
	for _, instance := range instances {
		// Only control-plane nodes with a known VPC IP are candidates.
		if instance.Role != role.ControlPlane || instance.VPCIP == "" {
			continue
		}
		endpoints = append(endpoints, net.JoinHostPort(instance.VPCIP, joinPort))
	}

	if c.nodeInfo.Role != role.ControlPlane {
		return endpoints, nil
	}
	return removeSelfFromEndpoints(c.nodeInfo.VPCIP, endpoints), nil
}

// removeSelfFromEndpoints returns the endpoints whose host part differs from self.
// Endpoints that cannot be split into host and port are dropped as well.
// The input slice is not modified; the result is nil if nothing remains.
func removeSelfFromEndpoints(self string, endpoints []string) []string {
	var kept []string
	for i := range endpoints {
		host, _, splitErr := net.SplitHostPort(endpoints[i])
		if splitErr != nil {
			// Not a valid host:port pair, so it is unusable as an endpoint.
			continue
		}
		if host == self {
			continue
		}
		kept = append(kept, endpoints[i])
	}
	return kept
}

// timeoutCtx returns a new context that expires after the client's timeout,
// together with its cancel function. The context is derived from
// context.Background, not from any context passed in by a caller.
func (c *RejoinClient) timeoutCtx() (context.Context, context.CancelFunc) {
	parent := context.Background()
	return context.WithTimeout(parent, c.timeout)
}

// grpcDialer creates gRPC client connections.
// The RejoinClient uses it to connect to a JoinService endpoint.
type grpcDialer interface {
	// Dial opens a client connection to target.
	Dial(target string) (*grpc.ClientConn, error)
}

// metadataAPI is the subset of the cloud metadata API the RejoinClient needs
// to discover JoinService endpoints.
type metadataAPI interface {
	// List retrieves all instances belonging to the current constellation.
	List(ctx context.Context) ([]metadata.InstanceMetadata, error)
	// GetLoadBalancerEndpoint retrieves the load balancer endpoint.
	// Within this file only the returned host is used; the port is discarded.
	GetLoadBalancerEndpoint(ctx context.Context) (host, port string, err error)
}