/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: AGPL-3.0-only
*/

package watcher

import (
	"encoding/asn1"
	"encoding/hex"
	"fmt"
	"path/filepath"
	"strconv"
	"sync"

	"github.com/edgelesssys/constellation/v2/internal/atls"
	"github.com/edgelesssys/constellation/v2/internal/attestation/aws"
	"github.com/edgelesssys/constellation/v2/internal/attestation/azure/snp"
	"github.com/edgelesssys/constellation/v2/internal/attestation/azure/trustedlaunch"
	"github.com/edgelesssys/constellation/v2/internal/attestation/gcp"
	"github.com/edgelesssys/constellation/v2/internal/attestation/qemu"
	"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
	"github.com/edgelesssys/constellation/v2/internal/constants"
	"github.com/edgelesssys/constellation/v2/internal/file"
	"github.com/edgelesssys/constellation/v2/internal/logger"
)

// Updatable implements an updatable atls.Validator.
//
// It embeds the currently active validator and replaces it whenever Update is
// called. Validate and Update both take the mutex, so an attestation document
// is never validated while the validator is being swapped.
type Updatable struct {
	// log receives progress and debug messages emitted by Update.
	log          *logger.Logger
	// mux serializes Validate and Update.
	mux          sync.Mutex
	// newValidator builds the provider specific validator from the values
	// loaded by Update. It is chosen once in NewValidator.
	newValidator newValidatorFunc
	// fileHandler is used to read the expected values from disk.
	fileHandler  file.Handler
	// csp is the cloud provider parsed from the string passed to NewValidator.
	csp          cloudprovider.Provider
	// azureCVM selects the SNP validator on Azure and makes Update load the
	// idkeydigest settings in addition to the measurements.
	azureCVM     bool
	// Validator is the currently active validator; it is replaced by Update.
	atls.Validator
}

// NewValidator initializes a new updatable validator.
//
// csp is parsed with cloudprovider.FromString and decides which attestation
// package constructs the underlying validator. On Azure, azureCVM picks the SNP
// validator; otherwise the trusted launch validator is used. An unknown
// provider yields an error. Before returning, Update is called once so the
// returned value already wraps a validator; any error from that initial update
// is returned as is.
func NewValidator(log *logger.Logger, csp string, fileHandler file.Handler, azureCVM bool) (*Updatable, error) {
	provider := cloudprovider.FromString(csp)

	var build newValidatorFunc
	switch provider {
	case cloudprovider.AWS:
		build = func(measurements map[uint32][]byte, enforcedPCRs []uint32, _ []byte, _ bool, l *logger.Logger) atls.Validator {
			return aws.NewValidator(measurements, enforcedPCRs, l)
		}
	case cloudprovider.GCP:
		build = func(measurements map[uint32][]byte, enforcedPCRs []uint32, _ []byte, _ bool, l *logger.Logger) atls.Validator {
			return gcp.NewValidator(measurements, enforcedPCRs, l)
		}
	case cloudprovider.QEMU:
		build = func(measurements map[uint32][]byte, enforcedPCRs []uint32, _ []byte, _ bool, l *logger.Logger) atls.Validator {
			return qemu.NewValidator(measurements, enforcedPCRs, l)
		}
	case cloudprovider.Azure:
		// Trusted launch is the default on Azure and ignores the idkeydigest
		// settings; confidential VMs switch to the SNP validator.
		build = func(measurements map[uint32][]byte, enforcedPCRs []uint32, _ []byte, _ bool, l *logger.Logger) atls.Validator {
			return trustedlaunch.NewValidator(measurements, enforcedPCRs, l)
		}
		if azureCVM {
			build = func(measurements map[uint32][]byte, enforcedPCRs []uint32, digest []byte, enforceDigest bool, l *logger.Logger) atls.Validator {
				return snp.NewValidator(measurements, enforcedPCRs, digest, enforceDigest, l)
			}
		}
	default:
		return nil, fmt.Errorf("unknown cloud service provider: %q", csp)
	}

	updatable := &Updatable{
		log:          log,
		newValidator: build,
		fileHandler:  fileHandler,
		csp:          provider,
		azureCVM:     azureCVM,
	}

	// Load the expected values right away so the embedded validator is set.
	if err := updatable.Update(); err != nil {
		return nil, err
	}
	return updatable, nil
}

// Validate forwards attDoc and nonce to the currently active validator.
// The mutex is held until the call returns, so Update cannot replace the
// validator while a validation is in progress.
func (u *Updatable) Validate(attDoc []byte, nonce []byte) (userData []byte, err error) {
	u.mux.Lock()
	defer u.mux.Unlock()

	userData, err = u.Validator.Validate(attDoc, nonce)
	return userData, err
}

// OID returns the Object Identifier of the currently active validator.
//
// The mutex is held while the embedded validator is accessed, because Update
// replaces that field under the same mutex; reading it without the lock would
// be a data race.
func (u *Updatable) OID() asn1.ObjectIdentifier {
	u.mux.Lock()
	defer u.mux.Unlock()
	return u.Validator.OID()
}

// Update switches out the underlying validator.
//
// It loads the expected measurements and the enforced PCRs via
// fileHandler.ReadJSON from files below constants.ServiceBasePath. For Azure
// CVMs it additionally reads the enforceIdKeyDigest flag (parsed with
// strconv.ParseBool) and the hex encoded idkeydigest. The validator is only
// replaced after every value was read and parsed successfully; on error the
// previously active validator stays in place and the error, annotated with the
// file that failed, is returned. The mutex is held for the whole call.
func (u *Updatable) Update() error {
	u.mux.Lock()
	defer u.mux.Unlock()

	u.log.Infof("Updating expected measurements")

	measurementsPath := filepath.Join(constants.ServiceBasePath, constants.MeasurementsFilename)
	var measurements map[uint32][]byte
	if err := u.fileHandler.ReadJSON(measurementsPath, &measurements); err != nil {
		return fmt.Errorf("reading measurements from %q: %w", measurementsPath, err)
	}
	u.log.Debugf("New measurements: %v", measurements)

	enforcedPath := filepath.Join(constants.ServiceBasePath, constants.EnforcedPCRsFilename)
	var enforced []uint32
	if err := u.fileHandler.ReadJSON(enforcedPath, &enforced); err != nil {
		return fmt.Errorf("reading enforced PCRs from %q: %w", enforcedPath, err)
	}
	u.log.Debugf("Enforced PCRs: %v", enforced)

	// The idkeydigest settings only exist for Azure CVMs; for every other
	// configuration the zero values are handed to the validator constructor.
	var idkeydigest []byte
	var enforceIDKeyDigest bool
	if u.csp == cloudprovider.Azure && u.azureCVM {
		u.log.Infof("Updating enforceIdKeyDigest value")
		enforcePath := filepath.Join(constants.ServiceBasePath, constants.EnforceIDKeyDigestFilename)
		enforceRaw, err := u.fileHandler.Read(enforcePath)
		if err != nil {
			return fmt.Errorf("reading enforceIdKeyDigest from %q: %w", enforcePath, err)
		}
		enforceIDKeyDigest, err = strconv.ParseBool(string(enforceRaw))
		if err != nil {
			return fmt.Errorf("parsing content of EnforceIdKeyDigestFilename: %s: %w", enforceRaw, err)
		}
		u.log.Debugf("New enforceIdKeyDigest value: %v", enforceIDKeyDigest)

		u.log.Infof("Updating expected idkeydigest")
		digestPath := filepath.Join(constants.ServiceBasePath, constants.IDKeyDigestFilename)
		idkeydigestRaw, err := u.fileHandler.Read(digestPath)
		if err != nil {
			return fmt.Errorf("reading idkeydigest from %q: %w", digestPath, err)
		}
		idkeydigest, err = hex.DecodeString(string(idkeydigestRaw))
		if err != nil {
			return fmt.Errorf("parsing hexstring: %s: %w", idkeydigestRaw, err)
		}
		u.log.Debugf("New idkeydigest: %x", idkeydigest)
	}

	u.Validator = u.newValidator(measurements, enforced, idkeydigest, enforceIDKeyDigest, u.log)

	return nil
}

// newValidatorFunc constructs an atls.Validator from the expected values
// loaded by Update. idkeydigest and enforceIdKeyDigest are only consumed by
// the Azure SNP constructor; the other constructors in NewValidator ignore them.
type newValidatorFunc func(measurements map[uint32][]byte, enforcedPCRs []uint32, idkeydigest []byte, enforceIdKeyDigest bool, log *logger.Logger) atls.Validator