From db5660e3d619cda5d2ea5bbf2262a74f7c159621 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Wei=C3=9Fe?= <66256922+daniel-weisse@users.noreply.github.com> Date: Wed, 29 Mar 2023 09:06:10 +0200 Subject: [PATCH] attestation: add context to Issue and Validate methods (#1532) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniel Weiße --- cli/internal/cmd/init_test.go | 4 ++-- cli/internal/cmd/verify.go | 2 +- internal/atls/atls.go | 24 ++++++++++++------- internal/attestation/vtpm/attestation.go | 8 +++---- internal/attestation/vtpm/attestation_test.go | 20 +++++++++------- .../atlscredentials/atlscredentials_test.go | 4 ++-- internal/watcher/validator.go | 5 ++-- internal/watcher/validator_test.go | 2 +- verify/server/server.go | 6 ++--- verify/server/server_test.go | 2 +- 10 files changed, 43 insertions(+), 34 deletions(-) diff --git a/cli/internal/cmd/init_test.go b/cli/internal/cmd/init_test.go index a2694414c..86a44f040 100644 --- a/cli/internal/cmd/init_test.go +++ b/cli/internal/cmd/init_test.go @@ -464,7 +464,7 @@ type testValidator struct { pcrs measurements.M } -func (v *testValidator) Validate(attDoc []byte, _ []byte) ([]byte, error) { +func (v *testValidator) Validate(_ context.Context, attDoc []byte, _ []byte) ([]byte, error) { var attestation struct { UserData []byte PCRs map[uint32][]byte @@ -486,7 +486,7 @@ type testIssuer struct { pcrs map[uint32][]byte } -func (i *testIssuer) Issue(userData []byte, _ []byte) ([]byte, error) { +func (i *testIssuer) Issue(_ context.Context, userData []byte, _ []byte) ([]byte, error) { return json.Marshal( struct { UserData []byte diff --git a/cli/internal/cmd/verify.go b/cli/internal/cmd/verify.go index f1d0eb197..c6812be8e 100644 --- a/cli/internal/cmd/verify.go +++ b/cli/internal/cmd/verify.go @@ -232,7 +232,7 @@ func (v *constellationVerifier) Verify( } v.log.Debugf("Verifying attestation") - signedData, err := validator.Validate(resp.Attestation, req.Nonce) + signedData, err := validator.Validate(ctx, resp.Attestation, req.Nonce) if err != nil { return fmt.Errorf("validating attestation: %w", err) } diff --git a/internal/atls/atls.go b/internal/atls/atls.go index a1a5d589d..b52a2417f 100644 --- a/internal/atls/atls.go +++ b/internal/atls/atls.go @@ -9,6 +9,7 @@ package atls import ( "bytes" + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -28,6 +29,8 @@ import ( "github.com/edgelesssys/constellation/v2/internal/oid" ) +const attestationTimeout = 30 * time.Second + // CreateAttestationServerTLSConfig creates a tls.Config object with a self-signed certificate and an embedded attestation document. // Pass a list of validators to enable mutual aTLS. // If issuer is nil, no attestation will be embedded. @@ -73,13 +76,13 @@ func CreateAttestationClientTLSConfig(issuer Issuer, validators []Validator) (*t // Issuer issues an attestation document. type Issuer interface { oid.Getter - Issue(userData []byte, nonce []byte) (quote []byte, err error) + Issue(ctx context.Context, userData []byte, nonce []byte) (quote []byte, err error) } // Validator is able to validate an attestation document. type Validator interface { oid.Getter - Validate(attDoc []byte, nonce []byte) ([]byte, error) + Validate(ctx context.Context, attDoc []byte, nonce []byte) ([]byte, error) } // getATLSConfigForClientFunc returns a config setup function that is called once for every client connecting to the server. @@ -129,7 +132,7 @@ func getATLSConfigForClientFunc(issuer Issuer, validators []Validator) (func(*tl // getCertificate creates a client or server certificate for aTLS connections. // The certificate uses certificate extensions to embed an attestation document generated using nonce. -func getCertificate(issuer Issuer, priv, pub any, nonce []byte) (*tls.Certificate, error) { +func getCertificate(ctx context.Context, issuer Issuer, priv, pub any, nonce []byte) (*tls.Certificate, error) { serialNumber, err := crypto.GenerateCertificateSerialNumber() if err != nil { return nil, err @@ -145,7 +148,7 @@ func getCertificate(issuer Issuer, priv, pub any, nonce []byte) (*tls.Certificat } // create attestation document using the nonce send by the remote party - attDoc, err := issuer.Issue(hash, nonce) + attDoc, err := issuer.Issue(ctx, hash, nonce) if err != nil { return nil, err } @@ -200,7 +203,10 @@ func verifyEmbeddedReport(validators []Validator, cert *x509.Certificate, hash, for _, ex := range cert.Extensions { for _, validator := range validators { if ex.Id.Equal(validator.OID()) { - userData, err := validator.Validate(ex.Value, nonce) + ctx, cancel := context.WithTimeout(context.Background(), attestationTimeout) + defer cancel() + + userData, err := validator.Validate(ctx, ex.Value, nonce) if err != nil { return err } @@ -308,7 +314,7 @@ func (c *clientConnection) getCertificate(cri *tls.CertificateRequestInfo) (*tls return nil, fmt.Errorf("decode nonce: %w", err) } - return getCertificate(c.issuer, priv, &priv.PublicKey, serverNonce) + return getCertificate(cri.Context(), c.issuer, priv, &priv.PublicKey, serverNonce) } // serverConnection holds state for server to client connections. @@ -340,7 +346,7 @@ func (c *serverConnection) getCertificate(chi *tls.ClientHelloInfo) (*tls.Certif } // create aTLS certificate using the nonce as extracted from the client-hello message - return getCertificate(c.issuer, c.privKey, &c.privKey.PublicKey, clientNonce) + return getCertificate(chi.Context(), c.issuer, c.privKey, &c.privKey.PublicKey, clientNonce) } // FakeIssuer fakes an issuer and can be used for tests. @@ -354,7 +360,7 @@ func NewFakeIssuer(oid oid.Getter) *FakeIssuer { } // Issue marshals the user data and returns it. -func (FakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) { +func (FakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) { return json.Marshal(FakeAttestationDoc{UserData: userData, Nonce: nonce}) } @@ -375,7 +381,7 @@ func NewFakeValidators(oid oid.Getter) []Validator { } // Validate unmarshals the attestation document and verifies the nonce. -func (v FakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) { +func (v FakeValidator) Validate(_ context.Context, attDoc []byte, nonce []byte) ([]byte, error) { var doc FakeAttestationDoc if err := json.Unmarshal(attDoc, &doc); err != nil { return nil, err diff --git a/internal/attestation/vtpm/attestation.go b/internal/attestation/vtpm/attestation.go index fed14529a..96a1894a0 100644 --- a/internal/attestation/vtpm/attestation.go +++ b/internal/attestation/vtpm/attestation.go @@ -107,7 +107,7 @@ func NewIssuer( } // Issue generates an attestation document using a TPM. -func (i *Issuer) Issue(userData []byte, nonce []byte) (res []byte, err error) { +func (i *Issuer) Issue(ctx context.Context, userData []byte, nonce []byte) (res []byte, err error) { i.log.Infof("Issuing attestation statement") defer func() { if err != nil { @@ -136,7 +136,7 @@ func (i *Issuer) Issue(userData []byte, nonce []byte) (res []byte, err error) { } // Fetch instance info of the VM - instanceInfo, err := i.getInstanceInfo(context.TODO(), tpm, extraData) // TODO(daniel-weisse): update Issue/Validate to use context + instanceInfo, err := i.getInstanceInfo(ctx, tpm, extraData) if err != nil { return nil, fmt.Errorf("fetching instance info: %w", err) } @@ -181,7 +181,7 @@ func NewValidator(expected measurements.M, getTrustedKey GetTPMTrustedAttestatio } // Validate a TPM based attestation. -func (v *Validator) Validate(attDocRaw []byte, nonce []byte) (userData []byte, err error) { +func (v *Validator) Validate(ctx context.Context, attDocRaw []byte, nonce []byte) (userData []byte, err error) { v.log.Infof("Validating attestation document") defer func() { if err != nil { @@ -197,7 +197,7 @@ func (v *Validator) Validate(attDocRaw []byte, nonce []byte) (userData []byte, e extraData := makeExtraData(attDoc.UserData, nonce) // Verify and retrieve the trusted attestation public key using the provided instance info - aKP, err := v.getTrustedKey(context.TODO(), attDoc, extraData) + aKP, err := v.getTrustedKey(ctx, attDoc, extraData) if err != nil { return nil, fmt.Errorf("validating attestation public key: %w", err) } diff --git a/internal/attestation/vtpm/attestation_test.go b/internal/attestation/vtpm/attestation_test.go index 5dd64247e..db5238880 100644 --- a/internal/attestation/vtpm/attestation_test.go +++ b/internal/attestation/vtpm/attestation_test.go @@ -84,7 +84,9 @@ func TestValidate(t *testing.T) { nonce := []byte{1, 2, 3, 4} challenge := []byte("Constellation") - attDocRaw, err := issuer.Issue(challenge, nonce) + ctx := context.Background() + + attDocRaw, err := issuer.Issue(ctx, challenge, nonce) require.NoError(err) var attDoc AttestationDocument @@ -93,26 +95,26 @@ func TestValidate(t *testing.T) { require.Equal(challenge, attDoc.UserData) // valid test - out, err := validator.Validate(attDocRaw, nonce) + out, err := validator.Validate(ctx, attDocRaw, nonce) require.NoError(err) require.Equal(challenge, out) // validation must fail after bootstrapping (change of enforced PCR) require.NoError(MarkNodeAsBootstrapped(tpmOpen, []byte{2})) - attDocBootstrappedRaw, err := issuer.Issue(challenge, nonce) + attDocBootstrappedRaw, err := issuer.Issue(ctx, challenge, nonce) require.NoError(err) - _, err = validator.Validate(attDocBootstrappedRaw, nonce) + _, err = validator.Validate(ctx, attDocBootstrappedRaw, nonce) require.Error(err) // userData must be bound to PCR state - attDocBootstrappedRaw, err = issuer.Issue([]byte{2, 3}, nonce) + attDocBootstrappedRaw, err = issuer.Issue(ctx, []byte{2, 3}, nonce) require.NoError(err) var attDocBootstrapped AttestationDocument require.NoError(json.Unmarshal(attDocBootstrappedRaw, &attDocBootstrapped)) attDocBootstrapped.Attestation = attDoc.Attestation attDocBootstrappedRaw, err = json.Marshal(attDocBootstrapped) require.NoError(err) - _, err = validator.Validate(attDocBootstrappedRaw, nonce) + _, err = validator.Validate(ctx, attDocBootstrappedRaw, nonce) require.Error(err) expectedPCRs := measurements.M{ @@ -141,7 +143,7 @@ func TestValidate(t *testing.T) { fakeValidateCVM, warnLog, ) - out, err = warningValidator.Validate(attDocRaw, nonce) + out, err = warningValidator.Validate(ctx, attDocRaw, nonce) require.NoError(err) assert.Equal(t, challenge, out) assert.Len(t, warnLog.warnings, 4) @@ -240,7 +242,7 @@ func TestValidate(t *testing.T) { t.Run(name, func(t *testing.T) { assert := assert.New(t) - _, err = tc.validator.Validate(tc.attDoc, tc.nonce) + _, err = tc.validator.Validate(ctx, tc.attDoc, tc.nonce) if tc.wantErr { assert.Error(err) } else { @@ -316,7 +318,7 @@ func TestFailIssuer(t *testing.T) { tc.issuer.log = logger.NewTest(t) - _, err := tc.issuer.Issue(tc.userData, tc.nonce) + _, err := tc.issuer.Issue(context.Background(), tc.userData, tc.nonce) assert.Error(err) }) } diff --git a/internal/grpc/atlscredentials/atlscredentials_test.go b/internal/grpc/atlscredentials/atlscredentials_test.go index f52ecdfa9..5c8a072ef 100644 --- a/internal/grpc/atlscredentials/atlscredentials_test.go +++ b/internal/grpc/atlscredentials/atlscredentials_test.go @@ -86,7 +86,7 @@ type fakeIssuer struct { fakeOID } -func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) { +func (fakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) { return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce}) } @@ -95,7 +95,7 @@ type fakeValidator struct { err error } -func (v fakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) { +func (v fakeValidator) Validate(_ context.Context, attDoc []byte, nonce []byte) ([]byte, error) { var doc fakeDoc if err := json.Unmarshal(attDoc, &doc); err != nil { return nil, err diff --git a/internal/watcher/validator.go b/internal/watcher/validator.go index 1ebced7a5..5113f47f7 100644 --- a/internal/watcher/validator.go +++ b/internal/watcher/validator.go @@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only package watcher import ( + "context" "encoding/asn1" "encoding/json" "errors" @@ -49,10 +50,10 @@ func NewValidator(log *logger.Logger, variant oid.Getter, fileHandler file.Handl } // Validate calls the validators Validate method, and prevents any updates during the call. -func (u *Updatable) Validate(attDoc []byte, nonce []byte) ([]byte, error) { +func (u *Updatable) Validate(ctx context.Context, attDoc []byte, nonce []byte) ([]byte, error) { u.mux.Lock() defer u.mux.Unlock() - return u.Validator.Validate(attDoc, nonce) + return u.Validator.Validate(ctx, attDoc, nonce) } // OID returns the validators Object Identifier. diff --git a/internal/watcher/validator_test.go b/internal/watcher/validator_test.go index d8755a034..4264cd6d2 100644 --- a/internal/watcher/validator_test.go +++ b/internal/watcher/validator_test.go @@ -270,7 +270,7 @@ type fakeIssuer struct { oid.Getter } -func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) { +func (fakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) { return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce}) } diff --git a/verify/server/server.go b/verify/server/server.go index 85c26ff69..39790b947 100644 --- a/verify/server/server.go +++ b/verify/server/server.go @@ -111,7 +111,7 @@ func (s *Server) GetAttestation(ctx context.Context, req *verifyproto.GetAttesta } log.Infof("Creating attestation") - statement, err := s.issuer.Issue([]byte(constants.ConstellationVerifyServiceUserData), req.Nonce) + statement, err := s.issuer.Issue(ctx, []byte(constants.ConstellationVerifyServiceUserData), req.Nonce) if err != nil { return nil, status.Errorf(codes.Internal, "issuing attestation statement: %v", err) } @@ -139,7 +139,7 @@ func (s *Server) getAttestationHTTP(w http.ResponseWriter, r *http.Request) { } log.Infof("Creating attestation") - quote, err := s.issuer.Issue([]byte(constants.ConstellationVerifyServiceUserData), nonce) + quote, err := s.issuer.Issue(r.Context(), []byte(constants.ConstellationVerifyServiceUserData), nonce) if err != nil { http.Error(w, fmt.Sprintf("issuing attestation statement: %v", err), http.StatusInternalServerError) return @@ -154,5 +154,5 @@ func (s *Server) getAttestationHTTP(w http.ResponseWriter, r *http.Request) { // AttestationIssuer issues an attestation document for the provided userData and nonce. type AttestationIssuer interface { - Issue(userData []byte, nonce []byte) (quote []byte, err error) + Issue(ctx context.Context, userData []byte, nonce []byte) (quote []byte, err error) } diff --git a/verify/server/server_test.go b/verify/server/server_test.go index d170db612..16f84d5c5 100644 --- a/verify/server/server_test.go +++ b/verify/server/server_test.go @@ -197,6 +197,6 @@ type stubIssuer struct { issueErr error } -func (i stubIssuer) Issue(_ []byte, _ []byte) ([]byte, error) { +func (i stubIssuer) Issue(_ context.Context, _ []byte, _ []byte) ([]byte, error) { return i.attestation, i.issueErr }