/* Copyright (c) Edgeless Systems GmbH SPDX-License-Identifier: AGPL-3.0-only */ package cmd import ( "bytes" "context" "crypto/sha256" "crypto/sha512" "encoding/base64" "encoding/hex" "encoding/json" "errors" "fmt" "net" "sort" "strconv" "strings" "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/atls" azuretdx "github.com/edgelesssys/constellation/v2/internal/attestation/azure/tdx" "github.com/edgelesssys/constellation/v2/internal/attestation/choose" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/snp" "github.com/edgelesssys/constellation/v2/internal/attestation/variant" "github.com/edgelesssys/constellation/v2/internal/attestation/vtpm" "github.com/edgelesssys/constellation/v2/internal/config" "github.com/edgelesssys/constellation/v2/internal/constants" "github.com/edgelesssys/constellation/v2/internal/constellation/state" "github.com/edgelesssys/constellation/v2/internal/crypto" "github.com/edgelesssys/constellation/v2/internal/file" "github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/verify" "github.com/edgelesssys/constellation/v2/verify/verifyproto" "github.com/google/go-sev-guest/proto/sevsnp" "github.com/google/go-tdx-guest/abi" "github.com/google/go-tdx-guest/proto/tdx" "github.com/google/go-tpm-tools/proto/attest" tpmProto "github.com/google/go-tpm-tools/proto/tpm" "github.com/spf13/afero" "github.com/spf13/cobra" "github.com/spf13/pflag" "google.golang.org/grpc" ) // NewVerifyCmd returns a new cobra.Command for the verify command. func NewVerifyCmd() *cobra.Command { cmd := &cobra.Command{ Use: "verify", Short: "Verify the confidential properties of a Constellation cluster", Long: "Verify the confidential properties of a Constellation cluster.\n" + "If arguments aren't specified, values are read from `" + constants.StateFilename + "`.", Args: cobra.ExactArgs(0), RunE: runVerify, } cmd.Flags().String("cluster-id", "", "expected cluster identifier") cmd.Flags().StringP("output", "o", "", "print the attestation document in the output format {json|raw}") cmd.Flags().StringP("node-endpoint", "e", "", "endpoint of the node to verify, passed as HOST[:PORT]") return cmd } type verifyFlags struct { rootFlags endpoint string ownerID string clusterID string output string } func (f *verifyFlags) parse(flags *pflag.FlagSet) error { if err := f.rootFlags.parse(flags); err != nil { return err } var err error f.output, err = flags.GetString("output") if err != nil { return fmt.Errorf("getting 'output' flag: %w", err) } f.endpoint, err = flags.GetString("node-endpoint") if err != nil { return fmt.Errorf("getting 'node-endpoint' flag: %w", err) } f.clusterID, err = flags.GetString("cluster-id") if err != nil { return fmt.Errorf("getting 'cluster-id' flag: %w", err) } return nil } type verifyCmd struct { fileHandler file.Handler flags verifyFlags log debugLog } func runVerify(cmd *cobra.Command, _ []string) error { log, err := newCLILogger(cmd) if err != nil { return fmt.Errorf("creating logger: %w", err) } fileHandler := file.NewHandler(afero.NewOsFs()) verifyClient := &constellationVerifier{ dialer: dialer.New(nil, nil, &net.Dialer{}), log: log, } v := &verifyCmd{ fileHandler: fileHandler, log: log, } if err := v.flags.parse(cmd.Flags()); err != nil { return err } v.log.Debug("Using flags", "clusterID", v.flags.clusterID, "endpoint", v.flags.endpoint, "ownerID", v.flags.ownerID) fetcher := attestationconfigapi.NewFetcher() return v.verify(cmd, verifyClient, fetcher) } func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, configFetcher attestationconfigapi.Fetcher) error { c.log.Debug(fmt.Sprintf("Loading configuration file from %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))) conf, err := config.New(c.fileHandler, constants.ConfigFilename, configFetcher, c.flags.force) var configValidationErr *config.ValidationError if errors.As(err, &configValidationErr) { cmd.PrintErrln(configValidationErr.LongMessage()) } if err != nil { return fmt.Errorf("loading config file: %w", err) } stateFile, err := state.ReadFromFile(c.fileHandler, constants.StateFilename) if err != nil { stateFile = state.New() // A state file is only required if the user has not provided IP or ID flags } ownerID, clusterID, err := c.validateIDFlags(cmd, stateFile) if err != nil { return err } endpoint, err := c.validateEndpointFlag(cmd, stateFile) if err != nil { return err } var maaURL string if stateFile.Infrastructure.Azure != nil { maaURL = stateFile.Infrastructure.Azure.AttestationURL } conf.UpdateMAAURL(maaURL) c.log.Debug("Updating expected PCRs") attConfig := conf.GetAttestationConfig() if err := updateInitMeasurements(attConfig, ownerID, clusterID); err != nil { return fmt.Errorf("updating expected PCRs: %w", err) } c.log.Debug(fmt.Sprintf("Creating aTLS Validator for %q", conf.GetAttestationConfig().GetVariant())) validator, err := choose.Validator(attConfig, warnLogger{cmd: cmd, log: c.log}) if err != nil { return fmt.Errorf("creating aTLS validator: %w", err) } nonce, err := crypto.GenerateRandomBytes(32) if err != nil { return fmt.Errorf("generating random nonce: %w", err) } c.log.Debug(fmt.Sprintf("Generated random nonce: %x", nonce)) rawAttestationDoc, err := verifyClient.Verify( cmd.Context(), endpoint, &verifyproto.GetAttestationRequest{ Nonce: nonce, }, validator, ) if err != nil { return fmt.Errorf("verifying: %w", err) } var attDocOutput string switch c.flags.output { case "json": attDocOutput, err = formatJSON(cmd.Context(), rawAttestationDoc, attConfig, c.log) if err != nil { return fmt.Errorf("printing attestation document: %w", err) } case "raw": attDocOutput = fmt.Sprintf("Attestation Document:\n%s\n", rawAttestationDoc) case "": attDocOutput, err = formatDefault(cmd.Context(), rawAttestationDoc, attConfig, c.log) if err != nil { return fmt.Errorf("printing attestation document: %w", err) } default: return fmt.Errorf("invalid output value for formatter: %s", c.flags.output) } cmd.Println(attDocOutput) cmd.PrintErrln("Verification OK") return nil } func (c *verifyCmd) validateIDFlags(cmd *cobra.Command, stateFile *state.State) (ownerID, clusterID string, err error) { ownerID, clusterID = c.flags.ownerID, c.flags.clusterID if c.flags.clusterID == "" { cmd.PrintErrf("Using ID from %q. Specify --cluster-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename)) clusterID = stateFile.ClusterValues.ClusterID } if ownerID == "" { // We don't want to print warnings until this is implemented again // cmd.PrintErrf("Using ID from %q. Specify --owner-id to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename)) ownerID = stateFile.ClusterValues.OwnerID } // Validate if ownerID == "" && clusterID == "" { return "", "", errors.New("cluster-id not provided to verify the cluster") } return ownerID, clusterID, nil } func (c *verifyCmd) validateEndpointFlag(cmd *cobra.Command, stateFile *state.State) (string, error) { endpoint := c.flags.endpoint if endpoint == "" { cmd.PrintErrf("Using endpoint from %q. Specify --node-endpoint to override this.\n", c.flags.pathPrefixer.PrefixPrintablePath(constants.StateFilename)) endpoint = stateFile.Infrastructure.ClusterEndpoint } endpoint, err := addPortIfMissing(endpoint, constants.VerifyServiceNodePortGRPC) if err != nil { return "", fmt.Errorf("validating endpoint argument: %w", err) } return endpoint, nil } // formatJSON returns the json formatted attestation doc. func formatJSON(ctx context.Context, docString []byte, attestationCfg config.AttestationCfg, log debugLog, ) (string, error) { doc, err := unmarshalAttDoc(docString, attestationCfg.GetVariant()) if err != nil { return "", fmt.Errorf("unmarshalling attestation document: %w", err) } switch attestationCfg.GetVariant() { case variant.AWSSEVSNP{}, variant.AzureSEVSNP{}, variant.GCPSEVSNP{}: return snpFormatJSON(ctx, doc.InstanceInfo, attestationCfg, log) case variant.AzureTDX{}: return tdxFormatJSON(doc.InstanceInfo, attestationCfg) default: return "", fmt.Errorf("json output is not supported for variant %s", attestationCfg.GetVariant()) } } func snpFormatJSON(ctx context.Context, instanceInfoRaw []byte, attestationCfg config.AttestationCfg, log debugLog, ) (string, error) { var instanceInfo snp.InstanceInfo if err := json.Unmarshal(instanceInfoRaw, &instanceInfo); err != nil { return "", fmt.Errorf("unmarshalling instance info: %w", err) } report, err := verify.NewReport(ctx, instanceInfo, attestationCfg, log) if err != nil { return "", fmt.Errorf("parsing SNP report: %w", err) } jsonBytes, err := json.Marshal(report) return string(jsonBytes), err } func tdxFormatJSON(instanceInfoRaw []byte, attestationCfg config.AttestationCfg) (string, error) { var rawQuote []byte if attestationCfg.GetVariant().Equal(variant.AzureTDX{}) { var instanceInfo azuretdx.InstanceInfo if err := json.Unmarshal(instanceInfoRaw, &instanceInfo); err != nil { return "", fmt.Errorf("unmarshalling instance info: %w", err) } rawQuote = instanceInfo.AttestationReport } tdxQuote, err := abi.QuoteToProto(rawQuote) if err != nil { return "", fmt.Errorf("converting quote to proto: %w", err) } quote, ok := tdxQuote.(*tdx.QuoteV4) if !ok { return "", fmt.Errorf("unexpected quote type: %T", tdxQuote) } quoteJSON, err := json.Marshal(quote) return string(quoteJSON), err } // format returns the formatted attestation doc. func formatDefault(ctx context.Context, docString []byte, attestationCfg config.AttestationCfg, log debugLog, ) (string, error) { b := &strings.Builder{} b.WriteString("Attestation Document:\n") doc, err := unmarshalAttDoc(docString, attestationCfg.GetVariant()) if err != nil { return "", fmt.Errorf("unmarshalling attestation document: %w", err) } if err := parseQuotes(b, doc.Attestation.Quotes, attestationCfg.GetMeasurements()); err != nil { return "", fmt.Errorf("parse quote: %w", err) } // If we have a non SNP variant, print only the PCRs if !(attestationCfg.GetVariant().Equal(variant.AzureSEVSNP{}) || attestationCfg.GetVariant().Equal(variant.AWSSEVSNP{}) || attestationCfg.GetVariant().Equal(variant.GCPSEVSNP{})) { return b.String(), nil } // SNP reports contain extra information that we can print var instanceInfo snp.InstanceInfo if err := json.Unmarshal(doc.InstanceInfo, &instanceInfo); err != nil { return "", fmt.Errorf("unmarshalling instance info: %w", err) } report, err := verify.NewReport(ctx, instanceInfo, attestationCfg, log) if err != nil { return "", fmt.Errorf("parsing SNP report: %w", err) } return report.FormatString(b) } // parseQuotes parses the base64-encoded quotes and writes their details to the output builder. func parseQuotes(b *strings.Builder, quotes []*tpmProto.Quote, expectedPCRs measurements.M) error { writeIndentfln(b, 1, "Quote:") var pcrNumbers []uint32 for pcrNum := range expectedPCRs { pcrNumbers = append(pcrNumbers, pcrNum) } sort.Slice(pcrNumbers, func(i, j int) bool { return pcrNumbers[i] < pcrNumbers[j] }) for _, pcrNum := range pcrNumbers { expectedPCR := expectedPCRs[pcrNum] pcrIdx, err := vtpm.GetSHA256QuoteIndex(quotes) if err != nil { return fmt.Errorf("get SHA256 quote index: %w", err) } actualPCR, ok := quotes[pcrIdx].Pcrs.Pcrs[pcrNum] if !ok { return fmt.Errorf("PCR %d not found in quote", pcrNum) } writeIndentfln(b, 2, "PCR %d (Strict: %t):", pcrNum, !expectedPCR.ValidationOpt) writeIndentfln(b, 3, "Expected:\t%x", expectedPCR.Expected) writeIndentfln(b, 3, "Actual:\t\t%x", actualPCR) } return nil } type constellationVerifier struct { dialer grpcInsecureDialer log debugLog } // Verify retrieves an attestation statement from the Constellation and verifies it using the validator. func (v *constellationVerifier) Verify( ctx context.Context, endpoint string, req *verifyproto.GetAttestationRequest, validator atls.Validator, ) ([]byte, error) { v.log.Debug(fmt.Sprintf("Dialing endpoint: %q", endpoint)) conn, err := v.dialer.DialInsecure(endpoint) if err != nil { return nil, fmt.Errorf("dialing init server: %w", err) } defer conn.Close() client := verifyproto.NewAPIClient(conn) v.log.Debug("Sending attestation request") resp, err := client.GetAttestation(ctx, req) if err != nil { return nil, fmt.Errorf("getting attestation: %w", err) } v.log.Debug("Verifying attestation") signedData, err := validator.Validate(ctx, resp.Attestation, req.Nonce) if err != nil { return nil, fmt.Errorf("validating attestation: %w", err) } if !bytes.Equal(signedData, []byte(constants.ConstellationVerifyServiceUserData)) { return nil, errors.New("signed data in attestation does not match expected user data") } return resp.Attestation, nil } type verifyClient interface { Verify(ctx context.Context, endpoint string, req *verifyproto.GetAttestationRequest, validator atls.Validator) ([]byte, error) } type grpcInsecureDialer interface { DialInsecure(endpoint string) (conn *grpc.ClientConn, err error) } // writeIndentfln writes a formatted string to the builder with the given indentation level // and a newline at the end. func writeIndentfln(b *strings.Builder, indentLvl int, format string, args ...any) { for i := 0; i < indentLvl; i++ { b.WriteByte('\t') } b.WriteString(fmt.Sprintf(format+"\n", args...)) } func addPortIfMissing(endpoint string, defaultPort int) (string, error) { if endpoint == "" { return "", errors.New("endpoint is empty") } _, _, err := net.SplitHostPort(endpoint) if err == nil { return endpoint, nil } if strings.Contains(err.Error(), "missing port in address") { return net.JoinHostPort(endpoint, strconv.Itoa(defaultPort)), nil } return "", err } // UpdateInitMeasurements sets the owner and cluster measurement values in the attestation config depending on the // attestation variant. func updateInitMeasurements(config config.AttestationCfg, ownerID, clusterID string) error { m := config.GetMeasurements() switch config.GetVariant() { case variant.AWSNitroTPM{}, variant.AWSSEVSNP{}, variant.AzureTrustedLaunch{}, variant.AzureSEVSNP{}, variant.AzureTDX{}, // AzureTDX also uses a vTPM for measurements variant.GCPSEVES{}, variant.GCPSEVSNP{}, variant.QEMUVTPM{}: if err := updateMeasurementTPM(m, uint32(measurements.PCRIndexOwnerID), ownerID); err != nil { return err } return updateMeasurementTPM(m, uint32(measurements.PCRIndexClusterID), clusterID) case variant.QEMUTDX{}: // Measuring ownerID is currently not implemented for Constellation // Since adding support for measuring ownerID to TDX would require additional code changes, // the current implementation does not support it, but can be changed if we decide to add support in the future return updateMeasurementTDX(m, uint32(measurements.TDXIndexClusterID), clusterID) default: return errors.New("selecting attestation variant: unknown attestation variant") } } // updateMeasurementTDX updates the TDX measurement value in the attestation config for the given measurement index. func updateMeasurementTDX(m measurements.M, measurementIdx uint32, encoded string) error { if encoded == "" { delete(m, measurementIdx) return nil } decoded, err := decodeMeasurement(encoded) if err != nil { return err } // new_measurement_value := hash(old_measurement_value || data_to_extend) // Since we use the DG.MR.RTMR.EXTEND call to extend the register, data_to_extend is the hash of our input hashedInput := sha512.Sum384(decoded) oldExpected := m[measurementIdx].Expected expectedMeasurementSum := sha512.Sum384(append(oldExpected[:], hashedInput[:]...)) m[measurementIdx] = measurements.Measurement{ Expected: expectedMeasurementSum[:], ValidationOpt: m[measurementIdx].ValidationOpt, } return nil } // updateMeasurementTPM updates the TPM measurement value in the attestation config for the given measurement index. func updateMeasurementTPM(m measurements.M, measurementIdx uint32, encoded string) error { if encoded == "" { delete(m, measurementIdx) return nil } decoded, err := decodeMeasurement(encoded) if err != nil { return err } // new_pcr_value := hash(old_pcr_value || data_to_extend) // Since we use the TPM2_PCR_Event call to extend the PCR, data_to_extend is the hash of our input hashedInput := sha256.Sum256(decoded) oldExpected := m[measurementIdx].Expected expectedMeasurement := sha256.Sum256(append(oldExpected[:], hashedInput[:]...)) m[measurementIdx] = measurements.Measurement{ Expected: expectedMeasurement[:], ValidationOpt: m[measurementIdx].ValidationOpt, } return nil } // decodeMeasurement is a utility function that decodes the given string as hex or base64. func decodeMeasurement(encoded string) ([]byte, error) { decoded, err := hex.DecodeString(encoded) if err != nil { hexErr := err decoded, err = base64.StdEncoding.DecodeString(encoded) if err != nil { return nil, fmt.Errorf("input [%s] could neither be hex decoded (%w) nor base64 decoded (%w)", encoded, hexErr, err) } } return decoded, nil } func unmarshalAttDoc(attDocJSON []byte, attestationVariant variant.Variant) (vtpm.AttestationDocument, error) { attDoc := vtpm.AttestationDocument{ Attestation: &attest.Attestation{}, } // Explicitly initialize this struct, as TeeAttestation // is a "oneof" protobuf field, which needs an explicit // type to be set to be unmarshaled correctly. switch attestationVariant { case variant.AzureTDX{}: attDoc.Attestation.TeeAttestation = &attest.Attestation_TdxAttestation{ TdxAttestation: &tdx.QuoteV4{}, } default: attDoc.Attestation.TeeAttestation = &attest.Attestation_SevSnpAttestation{ SevSnpAttestation: &sevsnp.Attestation{}, } } err := json.Unmarshal(attDocJSON, &attDoc) return attDoc, err }