attestation: add context to Issue and Validate methods (#1532)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2023-03-29 09:06:10 +02:00 committed by GitHub
parent 7c27d67953
commit db5660e3d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 43 additions and 34 deletions

View File

@ -464,7 +464,7 @@ type testValidator struct {
pcrs measurements.M pcrs measurements.M
} }
func (v *testValidator) Validate(attDoc []byte, _ []byte) ([]byte, error) { func (v *testValidator) Validate(_ context.Context, attDoc []byte, _ []byte) ([]byte, error) {
var attestation struct { var attestation struct {
UserData []byte UserData []byte
PCRs map[uint32][]byte PCRs map[uint32][]byte
@ -486,7 +486,7 @@ type testIssuer struct {
pcrs map[uint32][]byte pcrs map[uint32][]byte
} }
func (i *testIssuer) Issue(userData []byte, _ []byte) ([]byte, error) { func (i *testIssuer) Issue(_ context.Context, userData []byte, _ []byte) ([]byte, error) {
return json.Marshal( return json.Marshal(
struct { struct {
UserData []byte UserData []byte

View File

@ -232,7 +232,7 @@ func (v *constellationVerifier) Verify(
} }
v.log.Debugf("Verifying attestation") v.log.Debugf("Verifying attestation")
signedData, err := validator.Validate(resp.Attestation, req.Nonce) signedData, err := validator.Validate(ctx, resp.Attestation, req.Nonce)
if err != nil { if err != nil {
return fmt.Errorf("validating attestation: %w", err) return fmt.Errorf("validating attestation: %w", err)
} }

View File

@ -9,6 +9,7 @@ package atls
import ( import (
"bytes" "bytes"
"context"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
@ -28,6 +29,8 @@ import (
"github.com/edgelesssys/constellation/v2/internal/oid" "github.com/edgelesssys/constellation/v2/internal/oid"
) )
const attestationTimeout = 30 * time.Second
// CreateAttestationServerTLSConfig creates a tls.Config object with a self-signed certificate and an embedded attestation document. // CreateAttestationServerTLSConfig creates a tls.Config object with a self-signed certificate and an embedded attestation document.
// Pass a list of validators to enable mutual aTLS. // Pass a list of validators to enable mutual aTLS.
// If issuer is nil, no attestation will be embedded. // If issuer is nil, no attestation will be embedded.
@ -73,13 +76,13 @@ func CreateAttestationClientTLSConfig(issuer Issuer, validators []Validator) (*t
// Issuer issues an attestation document. // Issuer issues an attestation document.
type Issuer interface { type Issuer interface {
oid.Getter oid.Getter
Issue(userData []byte, nonce []byte) (quote []byte, err error) Issue(ctx context.Context, userData []byte, nonce []byte) (quote []byte, err error)
} }
// Validator is able to validate an attestation document. // Validator is able to validate an attestation document.
type Validator interface { type Validator interface {
oid.Getter oid.Getter
Validate(attDoc []byte, nonce []byte) ([]byte, error) Validate(ctx context.Context, attDoc []byte, nonce []byte) ([]byte, error)
} }
// getATLSConfigForClientFunc returns a config setup function that is called once for every client connecting to the server. // getATLSConfigForClientFunc returns a config setup function that is called once for every client connecting to the server.
@ -129,7 +132,7 @@ func getATLSConfigForClientFunc(issuer Issuer, validators []Validator) (func(*tl
// getCertificate creates a client or server certificate for aTLS connections. // getCertificate creates a client or server certificate for aTLS connections.
// The certificate uses certificate extensions to embed an attestation document generated using nonce. // The certificate uses certificate extensions to embed an attestation document generated using nonce.
func getCertificate(issuer Issuer, priv, pub any, nonce []byte) (*tls.Certificate, error) { func getCertificate(ctx context.Context, issuer Issuer, priv, pub any, nonce []byte) (*tls.Certificate, error) {
serialNumber, err := crypto.GenerateCertificateSerialNumber() serialNumber, err := crypto.GenerateCertificateSerialNumber()
if err != nil { if err != nil {
return nil, err return nil, err
@ -145,7 +148,7 @@ func getCertificate(issuer Issuer, priv, pub any, nonce []byte) (*tls.Certificat
} }
// create attestation document using the nonce send by the remote party // create attestation document using the nonce send by the remote party
attDoc, err := issuer.Issue(hash, nonce) attDoc, err := issuer.Issue(ctx, hash, nonce)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -200,7 +203,10 @@ func verifyEmbeddedReport(validators []Validator, cert *x509.Certificate, hash,
for _, ex := range cert.Extensions { for _, ex := range cert.Extensions {
for _, validator := range validators { for _, validator := range validators {
if ex.Id.Equal(validator.OID()) { if ex.Id.Equal(validator.OID()) {
userData, err := validator.Validate(ex.Value, nonce) ctx, cancel := context.WithTimeout(context.Background(), attestationTimeout)
defer cancel()
userData, err := validator.Validate(ctx, ex.Value, nonce)
if err != nil { if err != nil {
return err return err
} }
@ -308,7 +314,7 @@ func (c *clientConnection) getCertificate(cri *tls.CertificateRequestInfo) (*tls
return nil, fmt.Errorf("decode nonce: %w", err) return nil, fmt.Errorf("decode nonce: %w", err)
} }
return getCertificate(c.issuer, priv, &priv.PublicKey, serverNonce) return getCertificate(cri.Context(), c.issuer, priv, &priv.PublicKey, serverNonce)
} }
// serverConnection holds state for server to client connections. // serverConnection holds state for server to client connections.
@ -340,7 +346,7 @@ func (c *serverConnection) getCertificate(chi *tls.ClientHelloInfo) (*tls.Certif
} }
// create aTLS certificate using the nonce as extracted from the client-hello message // create aTLS certificate using the nonce as extracted from the client-hello message
return getCertificate(c.issuer, c.privKey, &c.privKey.PublicKey, clientNonce) return getCertificate(chi.Context(), c.issuer, c.privKey, &c.privKey.PublicKey, clientNonce)
} }
// FakeIssuer fakes an issuer and can be used for tests. // FakeIssuer fakes an issuer and can be used for tests.
@ -354,7 +360,7 @@ func NewFakeIssuer(oid oid.Getter) *FakeIssuer {
} }
// Issue marshals the user data and returns it. // Issue marshals the user data and returns it.
func (FakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) { func (FakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal(FakeAttestationDoc{UserData: userData, Nonce: nonce}) return json.Marshal(FakeAttestationDoc{UserData: userData, Nonce: nonce})
} }
@ -375,7 +381,7 @@ func NewFakeValidators(oid oid.Getter) []Validator {
} }
// Validate unmarshals the attestation document and verifies the nonce. // Validate unmarshals the attestation document and verifies the nonce.
func (v FakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) { func (v FakeValidator) Validate(_ context.Context, attDoc []byte, nonce []byte) ([]byte, error) {
var doc FakeAttestationDoc var doc FakeAttestationDoc
if err := json.Unmarshal(attDoc, &doc); err != nil { if err := json.Unmarshal(attDoc, &doc); err != nil {
return nil, err return nil, err

View File

@ -107,7 +107,7 @@ func NewIssuer(
} }
// Issue generates an attestation document using a TPM. // Issue generates an attestation document using a TPM.
func (i *Issuer) Issue(userData []byte, nonce []byte) (res []byte, err error) { func (i *Issuer) Issue(ctx context.Context, userData []byte, nonce []byte) (res []byte, err error) {
i.log.Infof("Issuing attestation statement") i.log.Infof("Issuing attestation statement")
defer func() { defer func() {
if err != nil { if err != nil {
@ -136,7 +136,7 @@ func (i *Issuer) Issue(userData []byte, nonce []byte) (res []byte, err error) {
} }
// Fetch instance info of the VM // Fetch instance info of the VM
instanceInfo, err := i.getInstanceInfo(context.TODO(), tpm, extraData) // TODO(daniel-weisse): update Issue/Validate to use context instanceInfo, err := i.getInstanceInfo(ctx, tpm, extraData)
if err != nil { if err != nil {
return nil, fmt.Errorf("fetching instance info: %w", err) return nil, fmt.Errorf("fetching instance info: %w", err)
} }
@ -181,7 +181,7 @@ func NewValidator(expected measurements.M, getTrustedKey GetTPMTrustedAttestatio
} }
// Validate a TPM based attestation. // Validate a TPM based attestation.
func (v *Validator) Validate(attDocRaw []byte, nonce []byte) (userData []byte, err error) { func (v *Validator) Validate(ctx context.Context, attDocRaw []byte, nonce []byte) (userData []byte, err error) {
v.log.Infof("Validating attestation document") v.log.Infof("Validating attestation document")
defer func() { defer func() {
if err != nil { if err != nil {
@ -197,7 +197,7 @@ func (v *Validator) Validate(attDocRaw []byte, nonce []byte) (userData []byte, e
extraData := makeExtraData(attDoc.UserData, nonce) extraData := makeExtraData(attDoc.UserData, nonce)
// Verify and retrieve the trusted attestation public key using the provided instance info // Verify and retrieve the trusted attestation public key using the provided instance info
aKP, err := v.getTrustedKey(context.TODO(), attDoc, extraData) aKP, err := v.getTrustedKey(ctx, attDoc, extraData)
if err != nil { if err != nil {
return nil, fmt.Errorf("validating attestation public key: %w", err) return nil, fmt.Errorf("validating attestation public key: %w", err)
} }

View File

@ -84,7 +84,9 @@ func TestValidate(t *testing.T) {
nonce := []byte{1, 2, 3, 4} nonce := []byte{1, 2, 3, 4}
challenge := []byte("Constellation") challenge := []byte("Constellation")
attDocRaw, err := issuer.Issue(challenge, nonce) ctx := context.Background()
attDocRaw, err := issuer.Issue(ctx, challenge, nonce)
require.NoError(err) require.NoError(err)
var attDoc AttestationDocument var attDoc AttestationDocument
@ -93,26 +95,26 @@ func TestValidate(t *testing.T) {
require.Equal(challenge, attDoc.UserData) require.Equal(challenge, attDoc.UserData)
// valid test // valid test
out, err := validator.Validate(attDocRaw, nonce) out, err := validator.Validate(ctx, attDocRaw, nonce)
require.NoError(err) require.NoError(err)
require.Equal(challenge, out) require.Equal(challenge, out)
// validation must fail after bootstrapping (change of enforced PCR) // validation must fail after bootstrapping (change of enforced PCR)
require.NoError(MarkNodeAsBootstrapped(tpmOpen, []byte{2})) require.NoError(MarkNodeAsBootstrapped(tpmOpen, []byte{2}))
attDocBootstrappedRaw, err := issuer.Issue(challenge, nonce) attDocBootstrappedRaw, err := issuer.Issue(ctx, challenge, nonce)
require.NoError(err) require.NoError(err)
_, err = validator.Validate(attDocBootstrappedRaw, nonce) _, err = validator.Validate(ctx, attDocBootstrappedRaw, nonce)
require.Error(err) require.Error(err)
// userData must be bound to PCR state // userData must be bound to PCR state
attDocBootstrappedRaw, err = issuer.Issue([]byte{2, 3}, nonce) attDocBootstrappedRaw, err = issuer.Issue(ctx, []byte{2, 3}, nonce)
require.NoError(err) require.NoError(err)
var attDocBootstrapped AttestationDocument var attDocBootstrapped AttestationDocument
require.NoError(json.Unmarshal(attDocBootstrappedRaw, &attDocBootstrapped)) require.NoError(json.Unmarshal(attDocBootstrappedRaw, &attDocBootstrapped))
attDocBootstrapped.Attestation = attDoc.Attestation attDocBootstrapped.Attestation = attDoc.Attestation
attDocBootstrappedRaw, err = json.Marshal(attDocBootstrapped) attDocBootstrappedRaw, err = json.Marshal(attDocBootstrapped)
require.NoError(err) require.NoError(err)
_, err = validator.Validate(attDocBootstrappedRaw, nonce) _, err = validator.Validate(ctx, attDocBootstrappedRaw, nonce)
require.Error(err) require.Error(err)
expectedPCRs := measurements.M{ expectedPCRs := measurements.M{
@ -141,7 +143,7 @@ func TestValidate(t *testing.T) {
fakeValidateCVM, fakeValidateCVM,
warnLog, warnLog,
) )
out, err = warningValidator.Validate(attDocRaw, nonce) out, err = warningValidator.Validate(ctx, attDocRaw, nonce)
require.NoError(err) require.NoError(err)
assert.Equal(t, challenge, out) assert.Equal(t, challenge, out)
assert.Len(t, warnLog.warnings, 4) assert.Len(t, warnLog.warnings, 4)
@ -240,7 +242,7 @@ func TestValidate(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
assert := assert.New(t) assert := assert.New(t)
_, err = tc.validator.Validate(tc.attDoc, tc.nonce) _, err = tc.validator.Validate(ctx, tc.attDoc, tc.nonce)
if tc.wantErr { if tc.wantErr {
assert.Error(err) assert.Error(err)
} else { } else {
@ -316,7 +318,7 @@ func TestFailIssuer(t *testing.T) {
tc.issuer.log = logger.NewTest(t) tc.issuer.log = logger.NewTest(t)
_, err := tc.issuer.Issue(tc.userData, tc.nonce) _, err := tc.issuer.Issue(context.Background(), tc.userData, tc.nonce)
assert.Error(err) assert.Error(err)
}) })
} }

View File

@ -86,7 +86,7 @@ type fakeIssuer struct {
fakeOID fakeOID
} }
func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) { func (fakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce}) return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce})
} }
@ -95,7 +95,7 @@ type fakeValidator struct {
err error err error
} }
func (v fakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) { func (v fakeValidator) Validate(_ context.Context, attDoc []byte, nonce []byte) ([]byte, error) {
var doc fakeDoc var doc fakeDoc
if err := json.Unmarshal(attDoc, &doc); err != nil { if err := json.Unmarshal(attDoc, &doc); err != nil {
return nil, err return nil, err

View File

@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only
package watcher package watcher
import ( import (
"context"
"encoding/asn1" "encoding/asn1"
"encoding/json" "encoding/json"
"errors" "errors"
@ -49,10 +50,10 @@ func NewValidator(log *logger.Logger, variant oid.Getter, fileHandler file.Handl
} }
// Validate calls the validators Validate method, and prevents any updates during the call. // Validate calls the validators Validate method, and prevents any updates during the call.
func (u *Updatable) Validate(attDoc []byte, nonce []byte) ([]byte, error) { func (u *Updatable) Validate(ctx context.Context, attDoc []byte, nonce []byte) ([]byte, error) {
u.mux.Lock() u.mux.Lock()
defer u.mux.Unlock() defer u.mux.Unlock()
return u.Validator.Validate(attDoc, nonce) return u.Validator.Validate(ctx, attDoc, nonce)
} }
// OID returns the validators Object Identifier. // OID returns the validators Object Identifier.

View File

@ -270,7 +270,7 @@ type fakeIssuer struct {
oid.Getter oid.Getter
} }
func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) { func (fakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce}) return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce})
} }

View File

@ -111,7 +111,7 @@ func (s *Server) GetAttestation(ctx context.Context, req *verifyproto.GetAttesta
} }
log.Infof("Creating attestation") log.Infof("Creating attestation")
statement, err := s.issuer.Issue([]byte(constants.ConstellationVerifyServiceUserData), req.Nonce) statement, err := s.issuer.Issue(ctx, []byte(constants.ConstellationVerifyServiceUserData), req.Nonce)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "issuing attestation statement: %v", err) return nil, status.Errorf(codes.Internal, "issuing attestation statement: %v", err)
} }
@ -139,7 +139,7 @@ func (s *Server) getAttestationHTTP(w http.ResponseWriter, r *http.Request) {
} }
log.Infof("Creating attestation") log.Infof("Creating attestation")
quote, err := s.issuer.Issue([]byte(constants.ConstellationVerifyServiceUserData), nonce) quote, err := s.issuer.Issue(r.Context(), []byte(constants.ConstellationVerifyServiceUserData), nonce)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("issuing attestation statement: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("issuing attestation statement: %v", err), http.StatusInternalServerError)
return return
@ -154,5 +154,5 @@ func (s *Server) getAttestationHTTP(w http.ResponseWriter, r *http.Request) {
// AttestationIssuer issues an attestation document for the provided userData and nonce. // AttestationIssuer issues an attestation document for the provided userData and nonce.
type AttestationIssuer interface { type AttestationIssuer interface {
Issue(userData []byte, nonce []byte) (quote []byte, err error) Issue(ctx context.Context, userData []byte, nonce []byte) (quote []byte, err error)
} }

View File

@ -197,6 +197,6 @@ type stubIssuer struct {
issueErr error issueErr error
} }
func (i stubIssuer) Issue(_ []byte, _ []byte) ([]byte, error) { func (i stubIssuer) Issue(_ context.Context, _ []byte, _ []byte) ([]byte, error) {
return i.attestation, i.issueErr return i.attestation, i.issueErr
} }