From 574a0812b052a4da26b50d3c0ba28286410fd3d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Wei=C3=9Fe?= Date: Tue, 11 Jun 2024 11:41:11 +0200 Subject: [PATCH] Remove formater factory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniel Weiße --- cli/internal/cmd/verify.go | 131 ++++++++++---------------------- cli/internal/cmd/verify_test.go | 62 +++------------ internal/verify/verify.go | 2 +- 3 files changed, 52 insertions(+), 143 deletions(-) diff --git a/cli/internal/cmd/verify.go b/cli/internal/cmd/verify.go index 5a84c3df0..9ba5ed546 100644 --- a/cli/internal/cmd/verify.go +++ b/cli/internal/cmd/verify.go @@ -106,24 +106,7 @@ func runVerify(cmd *cobra.Command, _ []string) error { dialer: dialer.New(nil, nil, &net.Dialer{}), log: log, } - formatterFactory := func(output string, attestation variant.Variant, log debugLog) (attestationDocFormatter, error) { - if output == "json" && - (!attestation.Equal(variant.AzureSEVSNP{}) && - !attestation.Equal(variant.AWSSEVSNP{}) && - !attestation.Equal(variant.GCPSEVSNP{})) { - return nil, errors.New("json output is only supported for SEV-SNP") - } - switch output { - case "json": - return &jsonAttestationDocFormatter{log}, nil - case "raw": - return &rawAttestationDocFormatter{log}, nil - case "": - return &defaultAttestationDocFormatter{log}, nil - default: - return nil, fmt.Errorf("invalid output value for formatter: %s", output) - } - } + v := &verifyCmd{ fileHandler: fileHandler, log: log, @@ -132,13 +115,12 @@ func runVerify(cmd *cobra.Command, _ []string) error { return err } v.log.Debug("Using flags", "clusterID", v.flags.clusterID, "endpoint", v.flags.endpoint, "ownerID", v.flags.ownerID) + fetcher := attestationconfigapi.NewFetcher() - return v.verify(cmd, verifyClient, formatterFactory, fetcher) + return v.verify(cmd, verifyClient, fetcher) } -type formatterFactory func(output string, attestation variant.Variant, log debugLog) (attestationDocFormatter, error) - -func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factory formatterFactory, configFetcher attestationconfigapi.Fetcher) error { +func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, configFetcher attestationconfigapi.Fetcher) error { c.log.Debug(fmt.Sprintf("Loading configuration file from %q", c.flags.pathPrefixer.PrefixPrintablePath(constants.ConfigFilename))) conf, err := config.New(c.fileHandler, constants.ConfigFilename, configFetcher, c.flags.force) var configValidationErr *config.ValidationError @@ -202,20 +184,27 @@ func (c *verifyCmd) verify(cmd *cobra.Command, verifyClient verifyClient, factor return fmt.Errorf("verifying: %w", err) } - // certificates are only available for Azure SEV-SNP and AWS SEV-SNP - formatter, err := factory(c.flags.output, conf.GetAttestationConfig().GetVariant(), c.log) - if err != nil { - return fmt.Errorf("creating formatter: %w", err) + var attDocOutput string + switch c.flags.output { + case "json": + if !(attConfig.GetVariant().Equal(variant.AzureSEVSNP{}) || + attConfig.GetVariant().Equal(variant.AWSSEVSNP{}) || + attConfig.GetVariant().Equal(variant.GCPSEVSNP{})) { + return errors.New("json output is only supported for SEV-SNP") + } + + attDocOutput, err = formatJSON(cmd.Context(), rawAttestationDoc, attConfig, c.log) + case "raw": + attDocOutput = fmt.Sprintf("Attestation Document:\n%s\n", rawAttestationDoc) + case "": + attDocOutput, err = formatDefault(cmd.Context(), rawAttestationDoc, attConfig, c.log) + default: + return fmt.Errorf("invalid output value for formatter: %s", c.flags.output) } - attDocOutput, err := formatter.format( - cmd.Context(), - rawAttestationDoc, - (!attConfig.GetVariant().Equal(variant.AzureSEVSNP{}) && !attConfig.GetVariant().Equal(variant.AWSSEVSNP{})), - attConfig, - ) if err != nil { return fmt.Errorf("printing attestation document: %w", err) } + cmd.Println(attDocOutput) cmd.PrintErrln("Verification OK") @@ -255,30 +244,19 @@ func (c *verifyCmd) validateEndpointFlag(cmd *cobra.Command, stateFile *state.St return endpoint, nil } -// an attestationDocFormatter formats the attestation document. -type attestationDocFormatter interface { - // format returns the raw or formatted attestation doc depending on the rawOutput argument. - format(ctx context.Context, docString string, PCRsOnly bool, attestationCfg config.AttestationCfg) (string, error) -} - -type jsonAttestationDocFormatter struct { - log debugLog -} - -// format returns the json formatted attestation doc. -func (f *jsonAttestationDocFormatter) format(ctx context.Context, docString string, _ bool, - attestationCfg config.AttestationCfg, +// formatJSON returns the json formatted attestation doc. +func formatJSON(ctx context.Context, docString string, attestationCfg config.AttestationCfg, log debugLog, ) (string, error) { var doc attestationDoc if err := json.Unmarshal([]byte(docString), &doc); err != nil { - return "", fmt.Errorf("unmarshal attestation document: %w", err) + return "", fmt.Errorf("unmarshalling attestation document: %w", err) } - instanceInfo, err := extractInstanceInfo(doc) - if err != nil { + var instanceInfo snp.InstanceInfo + if err := json.Unmarshal(doc.InstanceInfo, &instanceInfo); err != nil { return "", fmt.Errorf("unmarshalling instance info: %w", err) } - report, err := verify.NewReport(ctx, instanceInfo, attestationCfg, f.log) + report, err := verify.NewReport(ctx, instanceInfo, attestationCfg, log) if err != nil { return "", fmt.Errorf("parsing SNP report: %w", err) } @@ -288,27 +266,8 @@ func (f *jsonAttestationDocFormatter) format(ctx context.Context, docString stri return string(jsonBytes), err } -type rawAttestationDocFormatter struct { - log debugLog -} - -// format returns the raw attestation doc. -func (f *rawAttestationDocFormatter) format(_ context.Context, docString string, _ bool, - _ config.AttestationCfg, -) (string, error) { - b := &strings.Builder{} - b.WriteString("Attestation Document:\n") - b.WriteString(fmt.Sprintf("%s\n", docString)) - return b.String(), nil -} - -type defaultAttestationDocFormatter struct { - log debugLog -} - // format returns the formatted attestation doc. -func (f *defaultAttestationDocFormatter) format(ctx context.Context, docString string, PCRsOnly bool, - attestationCfg config.AttestationCfg, +func formatDefault(ctx context.Context, docString string, attestationCfg config.AttestationCfg, log debugLog, ) (string, error) { b := &strings.Builder{} b.WriteString("Attestation Document:\n") @@ -318,19 +277,24 @@ func (f *defaultAttestationDocFormatter) format(ctx context.Context, docString s return "", fmt.Errorf("unmarshal attestation document: %w", err) } - if err := f.parseQuotes(b, doc.Attestation.Quotes, attestationCfg.GetMeasurements()); err != nil { + if err := parseQuotes(b, doc.Attestation.Quotes, attestationCfg.GetMeasurements()); err != nil { return "", fmt.Errorf("parse quote: %w", err) } - if PCRsOnly { + + // If we have a non SNP variant, print only the PCRs + if !(attestationCfg.GetVariant().Equal(variant.AzureSEVSNP{}) || + attestationCfg.GetVariant().Equal(variant.AWSSEVSNP{}) || + attestationCfg.GetVariant().Equal(variant.GCPSEVSNP{})) { return b.String(), nil } - instanceInfo, err := extractInstanceInfo(doc) - if err != nil { + // SNP reports contain extra information that we can print + var instanceInfo snp.InstanceInfo + if err := json.Unmarshal(doc.InstanceInfo, &instanceInfo); err != nil { return "", fmt.Errorf("unmarshalling instance info: %w", err) } - report, err := verify.NewReport(ctx, instanceInfo, attestationCfg, f.log) + report, err := verify.NewReport(ctx, instanceInfo, attestationCfg, log) if err != nil { return "", fmt.Errorf("parsing SNP report: %w", err) } @@ -339,7 +303,7 @@ func (f *defaultAttestationDocFormatter) format(ctx context.Context, docString s } // parseQuotes parses the base64-encoded quotes and writes their details to the output builder. -func (f *defaultAttestationDocFormatter) parseQuotes(b *strings.Builder, quotes []*tpmProto.Quote, expectedPCRs measurements.M) error { +func parseQuotes(b *strings.Builder, quotes []*tpmProto.Quote, expectedPCRs measurements.M) error { writeIndentfln(b, 1, "Quote:") var pcrNumbers []uint32 @@ -374,8 +338,8 @@ type attestationDoc struct { EventLog string `json:"event_log"` TeeAttestation interface{} `json:"TeeAttestation"` } `json:"Attestation"` - InstanceInfo string `json:"InstanceInfo"` - UserData string `json:"UserData"` + InstanceInfo []byte `json:"InstanceInfo"` + UserData []byte `json:"UserData"` } type constellationVerifier struct { @@ -432,19 +396,6 @@ func writeIndentfln(b *strings.Builder, indentLvl int, format string, args ...an b.WriteString(fmt.Sprintf(format+"\n", args...)) } -func extractInstanceInfo(doc attestationDoc) (snp.InstanceInfo, error) { - instanceInfoString, err := base64.StdEncoding.DecodeString(doc.InstanceInfo) - if err != nil { - return snp.InstanceInfo{}, fmt.Errorf("decode instance info: %w", err) - } - - var instanceInfo snp.InstanceInfo - if err := json.Unmarshal(instanceInfoString, &instanceInfo); err != nil { - return snp.InstanceInfo{}, fmt.Errorf("unmarshal instance info: %w", err) - } - return instanceInfo, nil -} - func addPortIfMissing(endpoint string, defaultPort int) (string, error) { if endpoint == "" { return "", errors.New("endpoint is empty") diff --git a/cli/internal/cmd/verify_test.go b/cli/internal/cmd/verify_test.go index a695a7c2f..9968a4ab4 100644 --- a/cli/internal/cmd/verify_test.go +++ b/cli/internal/cmd/verify_test.go @@ -47,7 +47,6 @@ func TestVerify(t *testing.T) { testCases := map[string]struct { provider cloudprovider.Provider protoClient *stubVerifyClient - formatter *stubAttDocFormatter nodeEndpointFlag string clusterIDFlag string stateFile *state.State @@ -62,7 +61,6 @@ func TestVerify(t *testing.T) { protoClient: &stubVerifyClient{}, stateFile: defaultStateFile(cloudprovider.GCP), wantEndpoint: "192.0.2.1:1234", - formatter: &stubAttDocFormatter{}, }, "azure": { provider: cloudprovider.Azure, @@ -71,7 +69,6 @@ func TestVerify(t *testing.T) { protoClient: &stubVerifyClient{}, stateFile: defaultStateFile(cloudprovider.Azure), wantEndpoint: "192.0.2.1:1234", - formatter: &stubAttDocFormatter{}, }, "default port": { provider: cloudprovider.GCP, @@ -80,7 +77,6 @@ func TestVerify(t *testing.T) { protoClient: &stubVerifyClient{}, stateFile: defaultStateFile(cloudprovider.GCP), wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC), - formatter: &stubAttDocFormatter{}, }, "endpoint not set": { provider: cloudprovider.GCP, @@ -91,8 +87,7 @@ func TestVerify(t *testing.T) { s.Infrastructure.ClusterEndpoint = "" return s }(), - formatter: &stubAttDocFormatter{}, - wantErr: true, + wantErr: true, }, "endpoint from state file": { provider: cloudprovider.GCP, @@ -104,7 +99,6 @@ func TestVerify(t *testing.T) { return s }(), wantEndpoint: "192.0.2.1:" + strconv.Itoa(constants.VerifyServiceNodePortGRPC), - formatter: &stubAttDocFormatter{}, }, "override endpoint from details file": { provider: cloudprovider.GCP, @@ -117,7 +111,6 @@ func TestVerify(t *testing.T) { return s }(), wantEndpoint: "192.0.2.2:1234", - formatter: &stubAttDocFormatter{}, }, "invalid endpoint": { provider: cloudprovider.GCP, @@ -125,7 +118,6 @@ func TestVerify(t *testing.T) { clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{}, stateFile: defaultStateFile(cloudprovider.GCP), - formatter: &stubAttDocFormatter{}, wantErr: true, }, "neither owner id nor cluster id set": { @@ -137,7 +129,6 @@ func TestVerify(t *testing.T) { s.ClusterValues.ClusterID = "" return s }(), - formatter: &stubAttDocFormatter{}, protoClient: &stubVerifyClient{}, wantErr: true, }, @@ -151,14 +142,12 @@ func TestVerify(t *testing.T) { return s }(), wantEndpoint: "192.0.2.1:1234", - formatter: &stubAttDocFormatter{}, }, "config file not existing": { provider: cloudprovider.GCP, clusterIDFlag: zeroBase64, nodeEndpointFlag: "192.0.2.1:1234", stateFile: defaultStateFile(cloudprovider.GCP), - formatter: &stubAttDocFormatter{}, skipConfigCreation: true, wantErr: true, }, @@ -168,7 +157,6 @@ func TestVerify(t *testing.T) { clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{verifyErr: rpcStatus.Error(codes.Internal, "failed")}, stateFile: defaultStateFile(cloudprovider.Azure), - formatter: &stubAttDocFormatter{}, wantErr: true, }, "error protoClient GetState not rpc": { @@ -177,17 +165,6 @@ func TestVerify(t *testing.T) { clusterIDFlag: zeroBase64, protoClient: &stubVerifyClient{verifyErr: someErr}, stateFile: defaultStateFile(cloudprovider.Azure), - formatter: &stubAttDocFormatter{}, - wantErr: true, - }, - "format error": { - provider: cloudprovider.Azure, - nodeEndpointFlag: "192.0.2.1:1234", - clusterIDFlag: zeroBase64, - protoClient: &stubVerifyClient{}, - stateFile: defaultStateFile(cloudprovider.Azure), - wantEndpoint: "192.0.2.1:1234", - formatter: &stubAttDocFormatter{formatErr: someErr}, wantErr: true, }, } @@ -214,12 +191,10 @@ func TestVerify(t *testing.T) { flags: verifyFlags{ clusterID: tc.clusterIDFlag, endpoint: tc.nodeEndpointFlag, + output: "raw", }, } - formatterFac := func(_ string, _ variant.Variant, _ debugLog) (attestationDocFormatter, error) { - return tc.formatter, nil - } - err := v.verify(cmd, tc.protoClient, formatterFac, stubAttestationFetcher{}) + err := v.verify(cmd, tc.protoClient, stubAttestationFetcher{}) if tc.wantErr { assert.Error(err) } else { @@ -231,36 +206,20 @@ func TestVerify(t *testing.T) { } } -type stubAttDocFormatter struct { - formatErr error -} - -func (f *stubAttDocFormatter) format(_ context.Context, _ string, _ bool, _ config.AttestationCfg) (string, error) { - return "", f.formatErr -} - -func TestFormat(t *testing.T) { - formatter := func() *defaultAttestationDocFormatter { - return &defaultAttestationDocFormatter{ - log: logger.NewTest(t), - } - } - +func TestFormatDefault(t *testing.T) { testCases := map[string]struct { - formatter *defaultAttestationDocFormatter - doc string - wantErr bool + doc string + wantErr bool }{ "invalid doc": { - formatter: formatter(), - doc: "invalid", - wantErr: true, + doc: "invalid", + wantErr: true, }, } for name, tc := range testCases { t.Run(name, func(t *testing.T) { - _, err := tc.formatter.format(context.Background(), tc.doc, false, nil) + _, err := formatDefault(context.Background(), tc.doc, nil, logger.NewTest(t)) if tc.wantErr { assert.Error(t, err) } else { @@ -502,9 +461,8 @@ func TestParseQuotes(t *testing.T) { assert := assert.New(t) b := &strings.Builder{} - parser := &defaultAttestationDocFormatter{} - err := parser.parseQuotes(b, tc.quotes, tc.expectedPCRs) + err := parseQuotes(b, tc.quotes, tc.expectedPCRs) if tc.wantErr { assert.Error(err) } else { diff --git a/internal/verify/verify.go b/internal/verify/verify.go index d674c4237..27e5db853 100644 --- a/internal/verify/verify.go +++ b/internal/verify/verify.go @@ -157,7 +157,7 @@ func getCertChain(cfg config.AttestationCfg) ([]byte, error) { return certChain, nil } -// FormatString builds a string representation of a report that is inteded for console output. +// FormatString builds a string representation of a report that is intended for console output. func (r *Report) FormatString(b *strings.Builder) (string, error) { if len(r.ReportSigner) != 1 { return "", fmt.Errorf("expected exactly one report signing certificate, found %d", len(r.ReportSigner))