diff --git a/cli/internal/cmd/BUILD.bazel b/cli/internal/cmd/BUILD.bazel index 6f29a7965..c2decf328 100644 --- a/cli/internal/cmd/BUILD.bazel +++ b/cli/internal/cmd/BUILD.bazel @@ -111,6 +111,9 @@ go_library( "@io_k8s_sigs_yaml//:yaml", "@org_golang_x_mod//semver", "@org_golang_google_grpc//:grpc", + "@com_github_google_go_tdx_guest//abi", + "@com_github_google_go_tdx_guest//proto/tdx", + "//internal/attestation/azure/tdx", ] + select({ "@io_bazel_rules_go//go/platform:android_amd64": [ "@org_golang_x_sys//unix", diff --git a/cli/internal/cmd/verify.go b/cli/internal/cmd/verify.go index 5a84c3df0..5bce0af2a 100644 --- a/cli/internal/cmd/verify.go +++ b/cli/internal/cmd/verify.go @@ -21,10 +21,9 @@ import ( "strconv" "strings" - tpmProto "github.com/google/go-tpm-tools/proto/tpm" - "github.com/edgelesssys/constellation/v2/internal/api/attestationconfigapi" "github.com/edgelesssys/constellation/v2/internal/atls" + azuretdx "github.com/edgelesssys/constellation/v2/internal/attestation/azure/tdx" "github.com/edgelesssys/constellation/v2/internal/attestation/choose" "github.com/edgelesssys/constellation/v2/internal/attestation/measurements" "github.com/edgelesssys/constellation/v2/internal/attestation/snp" @@ -38,6 +37,10 @@ import ( "github.com/edgelesssys/constellation/v2/internal/grpc/dialer" "github.com/edgelesssys/constellation/v2/internal/verify" "github.com/edgelesssys/constellation/v2/verify/verifyproto" + + "github.com/google/go-tdx-guest/abi" + "github.com/google/go-tdx-guest/proto/tdx" + tpmProto "github.com/google/go-tpm-tools/proto/tpm" "github.com/spf13/afero" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -106,24 +109,7 @@ func runVerify(cmd *cobra.Command, _ []string) error { dialer: dialer.New(nil, nil, &net.Dialer{}), log: log, } - formatterFactory := func(output string, attestation variant.Variant, log debugLog) (attestationDocFormatter, error) { - if output == "json" && - (!attestation.Equal(variant.AzureSEVSNP{}) && - !attestation.Equal(variant.AWSSEVSNP{}) && - !attestation.Equal(variant.GCPSEVSNP{})) { - return nil, errors.New("json output is only supported for SEV-SNP") - } - switch output { - case "json": - return &jsonAttestationDocFormatter{log}, nil - case "raw": - return &rawAttestationDocFormatter{log}, nil - case "": - return &defaultAttestationDocFormatter{log}, nil - default: - return nil, fmt.Errorf("invalid output value for formatter: %s", output) - } - } + v := &verifyCmd{ fileHandler: fileHandler, log: log, @@ -132,13 +118,12 @@ func runVerify(cmd *cobra.Command, _ []string) error { return err } v.log.Debug("Using flags", "clusterID", v.flags.clusterID, "endpoint", v.flags.endpoint, "ownerID", v.flags.ownerID) + fetcher := attestationconfigapi.NewFetcher() - return v.verify(cmd, verifyClient, formatterFactory, fetcher) + return v.verify(cmd, verifyClient, fetcher) } -type formatterFactory func(output string, attestation variant.Variant, log debugLog) (attestationDocFormatter, error) - -func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error { +func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, configFetcher attestationconfigapi.Fetcher) error { c.log.Debug(fmt.Sprintf("Loading configuration file from %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))) conf, err := config.New(c.fileHandler, constants.ConfigFilename, configFetcher, c.flags.force) var configValidationErr *config.ValidationError @@ -202,20 +187,21 @@ func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factor return fmt.Errorf("verifying: %w", err) } - // certificates are only available for Azure SEV-SNP and AWS SEV-SNP - formatter, err := factory(c.flags.output, conf.GetAttestationConfig().GetVariant(), c.log) - if err != nil { - return fmt.Errorf("creating formatter: %w", err) + var attDocOutput string + switch c.flags.output { + case "json": + attDocOutput, err = formatJSON(cmd.Context(), rawAttestationDoc, attConfig, c.log) + case "raw": + attDocOutput = fmt.Sprintf("Attestation Document:\n%s\n", rawAttestationDoc) + case "": + attDocOutput, err = formatDefault(cmd.Context(), rawAttestationDoc, attConfig, c.log) + default: + return fmt.Errorf("invalid output value for formatter: %s", c.flags.output) } - attDocOutput, err := formatter.format( - cmd.Context(), - rawAttestationDoc, - (!attConfig.GetVariant().Equal(variant.AzureSEVSNP{}) && !attConfig.GetVariant().Equal(variant.AWSSEVSNP{})), - attConfig, - ) if err != nil { return fmt.Errorf("printing attestation document: %w", err) } + cmd.Println(attDocOutput) cmd.PrintErrln("Verification OK") @@ -255,82 +241,92 @@ func (c *verifyCmd) validateEndpointFlag(cmd *cobra.Command, stateFile *state.St return endpoint, nil } -// an attestationDocFormatter formats the attestation document. -type attestationDocFormatter interface { - // format returns the raw or formatted attestation doc depending on the rawOutput argument. - format(ctx context.Context, docString string, PCRsOnly bool, attestationCfg config.AttestationCfg) (string, error) -} - -type jsonAttestationDocFormatter struct { - log debugLog -} - -// format returns the json formatted attestation doc. -func (f *jsonAttestationDocFormatter) format(ctx context.Context, docString string, _ bool, - attestationCfg config.AttestationCfg, +// formatJSON returns the json formatted attestation doc. +func formatJSON(ctx context.Context, docString string, attestationCfg config.AttestationCfg, log debugLog, ) (string, error) { - var doc attestationDoc + var doc vtpm.AttestationDocument if err := json.Unmarshal([]byte(docString), &doc); err != nil { - return "", fmt.Errorf("unmarshal attestation document: %w", err) + return "", fmt.Errorf("unmarshalling attestation document: %w", err) } - instanceInfo, err := extractInstanceInfo(doc) - if err != nil { + switch attestationCfg.GetVariant() { + case variant.AWSSEVSNP{}, variant.AzureSEVSNP{}, variant.GCPSEVSNP{}: + return snpFormatJSON(ctx, doc.InstanceInfo, attestationCfg, log) + case variant.AzureTDX{}: + return tdxFormatJSON(doc.InstanceInfo, attestationCfg) + default: + return "", fmt.Errorf("json output is not supported for variant %s", attestationCfg.GetVariant()) + } +} + +func snpFormatJSON(ctx context.Context, instanceInfoRaw []byte, attestationCfg config.AttestationCfg, log debugLog, +) (string, error) { + var instanceInfo snp.InstanceInfo + if err := json.Unmarshal(instanceInfoRaw, &instanceInfo); err != nil { return "", fmt.Errorf("unmarshalling instance info: %w", err) } - report, err := verify.NewReport(ctx, instanceInfo, attestationCfg, f.log) + report, err := verify.NewReport(ctx, instanceInfo, attestationCfg, log) if err != nil { return "", fmt.Errorf("parsing SNP report: %w", err) } jsonBytes, err := json.Marshal(report) - return string(jsonBytes), err } -type rawAttestationDocFormatter struct { - log debugLog -} +func tdxFormatJSON(instanceInfoRaw []byte, attestationCfg config.AttestationCfg) (string, error) { + var rawQuote []byte -// format returns the raw attestation doc. -func (f *rawAttestationDocFormatter) format(_ context.Context, docString string, _ bool, - _ config.AttestationCfg, -) (string, error) { - b := &strings.Builder{} - b.WriteString("Attestation Document:\n") - b.WriteString(fmt.Sprintf("%s\n", docString)) - return b.String(), nil -} + if attestationCfg.GetVariant().Equal(variant.AzureTDX{}) { + var instanceInfo azuretdx.InstanceInfo + if err := json.Unmarshal(instanceInfoRaw, &instanceInfo); err != nil { + return "", fmt.Errorf("unmarshalling instance info: %w", err) + } + rawQuote = instanceInfo.AttestationReport + } -type defaultAttestationDocFormatter struct { - log debugLog + tdxQuote, err := abi.QuoteToProto(rawQuote) + if err != nil { + return "", fmt.Errorf("converting quote to proto: %w", err) + } + quote, ok := tdxQuote.(*tdx.QuoteV4) + if !ok { + return "", fmt.Errorf("unexpected quote type: %T", tdxQuote) + } + + quoteJSON, err := json.Marshal(quote) + return string(quoteJSON), err } // format returns the formatted attestation doc. -func (f *defaultAttestationDocFormatter) format(ctx context.Context, docString string, PCRsOnly bool, - attestationCfg config.AttestationCfg, +func formatDefault(ctx context.Context, docString string, attestationCfg config.AttestationCfg, log debugLog, ) (string, error) { b := &strings.Builder{} b.WriteString("Attestation Document:\n") - var doc attestationDoc + var doc vtpm.AttestationDocument if err := json.Unmarshal([]byte(docString), &doc); err != nil { return "", fmt.Errorf("unmarshal attestation document: %w", err) } - if err := f.parseQuotes(b, doc.Attestation.Quotes, attestationCfg.GetMeasurements()); err != nil { + if err := parseQuotes(b, doc.Attestation.Quotes, attestationCfg.GetMeasurements()); err != nil { return "", fmt.Errorf("parse quote: %w", err) } - if PCRsOnly { + + // If we have a non SNP variant, print only the PCRs + if !(attestationCfg.GetVariant().Equal(variant.AzureSEVSNP{}) || + attestationCfg.GetVariant().Equal(variant.AWSSEVSNP{}) || + attestationCfg.GetVariant().Equal(variant.GCPSEVSNP{})) { return b.String(), nil } - instanceInfo, err := extractInstanceInfo(doc) - if err != nil { + // SNP reports contain extra information that we can print + var instanceInfo snp.InstanceInfo + if err := json.Unmarshal(doc.InstanceInfo, &instanceInfo); err != nil { return "", fmt.Errorf("unmarshalling instance info: %w", err) } - report, err := verify.NewReport(ctx, instanceInfo, attestationCfg, f.log) + report, err := verify.NewReport(ctx, instanceInfo, attestationCfg, log) if err != nil { return "", fmt.Errorf("parsing SNP report: %w", err) } @@ -339,7 +335,7 @@ func (f *defaultAttestationDocFormatter) format(ctx context.Context, docString s } // parseQuotes parses the base64-encoded quotes and writes their details to the output builder. -func (f *defaultAttestationDocFormatter) parseQuotes(b *strings.Builder, quotes []*tpmProto.Quote, expectedPCRs measurements.M) error { +func parseQuotes(b *strings.Builder, quotes []*tpmProto.Quote, expectedPCRs measurements.M) error { writeIndentfln(b, 1, "Quote:") var pcrNumbers []uint32 @@ -366,18 +362,6 @@ func (f *defaultAttestationDocFormatter) parseQuotes(b *strings.Builder, quotes return nil } -// attestationDoc is the attestation document returned by the verifier. -type attestationDoc struct { - Attestation struct { - AkPub string `json:"ak_pub"` - Quotes []*tpmProto.Quote `json:"quotes"` - EventLog string `json:"event_log"` - TeeAttestation interface{} `json:"TeeAttestation"` - } `json:"Attestation"` - InstanceInfo string `json:"InstanceInfo"` - UserData string `json:"UserData"` -} - type constellationVerifier struct { dialer grpcInsecureDialer log debugLog @@ -432,19 +416,6 @@ func writeIndentfln(b *strings.Builder, indentLvl int, format string, args ...an b.WriteString(fmt.Sprintf(format+"\n", args...)) } -func extractInstanceInfo(doc attestationDoc) (snp.InstanceInfo, error) { - instanceInfoString, err := base64.StdEncoding.DecodeString(doc.InstanceInfo) - if err != nil { - return snp.InstanceInfo{}, fmt.Errorf("decode instance info: %w", err) - } - - var instanceInfo snp.InstanceInfo - if err := json.Unmarshal(instanceInfoString, &instanceInfo); err != nil { - return snp.InstanceInfo{}, fmt.Errorf("unmarshal instance info: %w", err) - } - return instanceInfo, nil -} - func addPortIfMissing(endpoint string, defaultPort int) (string, error) { if endpoint == "" { return "", errors.New("endpoint is empty") diff --git a/cli/internal/cmd/verify_test.go b/cli/internal/cmd/verify_test.go index a695a7c2f..9968a4ab4 100644 --- a/cli/internal/cmd/verify_test.go +++ b/cli/internal/cmd/verify_test.go @@ -47,7 +47,6 @@ func TestVerify(t *testing.T) { testCases := map[string]struct { provider cloudprovider.Provider protoClient *stubVerifyClient - formatter *stubAttDocFormatter nodeEndpointFlag string clusterIDFlag string stateFile *state.State @@ -62,7 +61,6 @@ func TestVerify(t *testing.T) { protoClient: &stubVerifyClient{}, stateFile: defaultStateFile(cloudprovider.GCP), wantEndpoint: "192.0.2.1:1234", - formatter: &stubAttDocFormatter{}, }, "azure": { provider: cloudprovider.Azure, @@ -71,7 +69,6 @@ func TestVerify(t *testing.T) { protoClient: &stubVerifyClient{}, stateFile: defaultStateFile(cloudprovider.Azure), wantEndpoint: "192.0.2.1:1234", - formatter: &stubAttDocFormatter{}, }, "default port": { provider: cloudprovider.GCP, @@ -80,7 +77,6 @@ func TestVerify(t *testing.T) { protoClient: &stubVerifyClient{}, stateFile: defaultStateFile(cloudprovider.GCP), wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC), - formatter: &stubAttDocFormatter{}, }, "endpoint not set": { provider: cloudprovider.GCP, @@ -91,8 +87,7 @@ func TestVerify(t *testing.T) { s.Infrastructure.ClusterEndpoint = "" return s }(), - formatter: &stubAttDocFormatter{}, - wantErr: true, + wantErr: true, }, "endpoint from state file": { provider: cloudprovider.GCP, @@ -104,7 +99,6 @@ func TestVerify(t *testing.T) { return s }(), wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC), - formatter: &stubAttDocFormatter{}, }, "override endpoint from details file": { provider: cloudprovider.GCP, @@ -117,7 +111,6 @@ func TestVerify(t *testing.T) { return s }(), wantEndpoint: "192.0.2.2:1234", - formatter: &stubAttDocFormatter{}, }, "invalid endpoint": { provider: cloudprovider.GCP, @@ -125,7 +118,6 @@ func TestVerify(t *testing.T) { clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, stateFile: defaultStateFile(cloudprovider.GCP), - formatter: &stubAttDocFormatter{}, wantErr: true, }, "neither owner id nor cluster id set": { @@ -137,7 +129,6 @@ func TestVerify(t *testing.T) { s.ClusterValues.ClusterID = "" return s }(), - formatter: &stubAttDocFormatter{}, protoClient: &stubVerifyClient{}, wantErr: true, }, @@ -151,14 +142,12 @@ func TestVerify(t *testing.T) { return s }(), wantEndpoint: "192.0.2.1:1234", - formatter: &stubAttDocFormatter{}, }, "config file not existing": { provider: cloudprovider.GCP, clusterIDFlag: zeroBase64, nodeEndpointFlag: "192.0.2.1:1234", stateFile: defaultStateFile(cloudprovider.GCP), - formatter: &stubAttDocFormatter{}, skipConfigCreation: true, wantErr: true, }, @@ -168,7 +157,6 @@ func TestVerify(t *testing.T) { clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{verifyErr: rpcStatus.Error(codes.Internal, "failed")}, stateFile: defaultStateFile(cloudprovider.Azure), - formatter: &stubAttDocFormatter{}, wantErr: true, }, "error protoClient GetState not rpc": { @@ -177,17 +165,6 @@ func TestVerify(t *testing.T) { clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{verifyErr: someErr}, stateFile: defaultStateFile(cloudprovider.Azure), - formatter: &stubAttDocFormatter{}, - wantErr: true, - }, - "format error": { - provider: cloudprovider.Azure, - nodeEndpointFlag: "192.0.2.1:1234", - clusterIDFlag: zeroBase64, - protoClient: &stubVerifyClient{}, - stateFile: defaultStateFile(cloudprovider.Azure), - wantEndpoint: "192.0.2.1:1234", - formatter: &stubAttDocFormatter{formatErr: someErr}, wantErr: true, }, } @@ -214,12 +191,10 @@ func TestVerify(t *testing.T) { flags: verifyFlags{ clusterID: tc.clusterIDFlag, endpoint: tc.nodeEndpointFlag, + output: "raw", }, } - formatterFac := func(_ string, _ variant.Variant, _ debugLog) (attestationDocFormatter, error) { - return tc.formatter, nil - } - err := v.verify(cmd, tc.protoClient, formatterFac, stubAttestationFetcher{}) + err := v.verify(cmd, tc.protoClient, stubAttestationFetcher{}) if tc.wantErr { assert.Error(err) } else { @@ -231,36 +206,20 @@ func TestVerify(t *testing.T) { } } -type stubAttDocFormatter struct { - formatErr error -} - -func (f *stubAttDocFormatter) format(_ context.Context, _ string, _ bool, _ config.AttestationCfg) (string, error) { - return "", f.formatErr -} - -func TestFormat(t *testing.T) { - formatter := func() *defaultAttestationDocFormatter { - return &defaultAttestationDocFormatter{ - log: logger.NewTest(t), - } - } - +func TestFormatDefault(t *testing.T) { testCases := map[string]struct { - formatter *defaultAttestationDocFormatter - doc string - wantErr bool + doc string + wantErr bool }{ "invalid doc": { - formatter: formatter(), - doc: "invalid", - wantErr: true, + doc: "invalid", + wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { - _, err := tc.formatter.format(context.Background(), tc.doc, false, nil) + _, err := formatDefault(context.Background(), tc.doc, nil, logger.NewTest(t)) if tc.wantErr { assert.Error(t, err) } else { @@ -502,9 +461,8 @@ func TestParseQuotes(t *testing.T) { assert := assert.New(t) b := &strings.Builder{} - parser := &defaultAttestationDocFormatter{} - err := parser.parseQuotes(b, tc.quotes, tc.expectedPCRs) + err := parseQuotes(b, tc.quotes, tc.expectedPCRs) if tc.wantErr { assert.Error(err) } else { diff --git a/internal/attestation/azure/tdx/issuer.go b/internal/attestation/azure/tdx/issuer.go index e04b066a6..082616635 100644 --- a/internal/attestation/azure/tdx/issuer.go +++ b/internal/attestation/azure/tdx/issuer.go @@ -90,7 +90,7 @@ func (i *Issuer) getInstanceInfo(ctx context.Context, tpm io.ReadWriteCloser, _ return nil, fmt.Errorf("getting quote: %w", err) } - instanceInfo := instanceInfo{ + instanceInfo := InstanceInfo{ AttestationReport: quote, RuntimeData: runtimeData, } diff --git a/internal/attestation/azure/tdx/tdx.go b/internal/attestation/azure/tdx/tdx.go index 815a43ae2..eaee6161a 100644 --- a/internal/attestation/azure/tdx/tdx.go +++ b/internal/attestation/azure/tdx/tdx.go @@ -19,7 +19,8 @@ More specifically: */ package tdx -type instanceInfo struct { +// InstanceInfo wraps the TDX report with additional Azure specific runtime data. +type InstanceInfo struct { AttestationReport []byte RuntimeData []byte } diff --git a/internal/attestation/azure/tdx/validator.go b/internal/attestation/azure/tdx/validator.go index 5b090dae9..02a8d3d6d 100644 --- a/internal/attestation/azure/tdx/validator.go +++ b/internal/attestation/azure/tdx/validator.go @@ -58,7 +58,7 @@ func NewValidator(cfg *config.AzureTDX, log attestation.Logger) *Validator { } func (v *Validator) getTrustedTPMKey(_ context.Context, attDoc vtpm.AttestationDocument, _ []byte) (crypto.PublicKey, error) { - var instanceInfo instanceInfo + var instanceInfo InstanceInfo if err := json.Unmarshal(attDoc.InstanceInfo, &instanceInfo); err != nil { return nil, err } diff --git a/internal/verify/verify.go b/internal/verify/verify.go index d674c4237..27e5db853 100644 --- a/internal/verify/verify.go +++ b/internal/verify/verify.go @@ -157,7 +157,7 @@ func getCertChain(cfg config.AttestationCfg) ([]byte, error) { return certChain, nil } -// FormatString builds a string representation of a report that is inteded for console output. +// FormatString builds a string representation of a report that is intended for console output. func (r *Report) FormatString(b *strings.Builder) (string, error) { if len(r.ReportSigner) != 1 { return "", fmt.Errorf("expected exactly one report signing certificate, found %d", len(r.ReportSigner))