attestation: add context to Issue and Validate methods (#1532)

Signed-off-by: Daniel Weiße <dw@edgeless.systems>
This commit is contained in:
Daniel Weiße 2023-03-29 09:06:10 +02:00 committed by GitHub
parent 7c27d67953
commit db5660e3d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 43 additions and 34 deletions

View File

@ -464,7 +464,7 @@ type testValidator struct {
pcrs measurements.M
}
func (v *testValidator) Validate(attDoc []byte, _ []byte) ([]byte, error) {
func (v *testValidator) Validate(_ context.Context, attDoc []byte, _ []byte) ([]byte, error) {
var attestation struct {
UserData []byte
PCRs map[uint32][]byte
@ -486,7 +486,7 @@ type testIssuer struct {
pcrs map[uint32][]byte
}
func (i *testIssuer) Issue(userData []byte, _ []byte) ([]byte, error) {
func (i *testIssuer) Issue(_ context.Context, userData []byte, _ []byte) ([]byte, error) {
return json.Marshal(
struct {
UserData []byte

View File

@ -232,7 +232,7 @@ func (v *constellationVerifier) Verify(
}
v.log.Debugf("Verifying attestation")
signedData, err := validator.Validate(resp.Attestation, req.Nonce)
signedData, err := validator.Validate(ctx, resp.Attestation, req.Nonce)
if err != nil {
return fmt.Errorf("validating attestation: %w", err)
}

View File

@ -9,6 +9,7 @@ package atls
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
@ -28,6 +29,8 @@ import (
"github.com/edgelesssys/constellation/v2/internal/oid"
)
const attestationTimeout = 30 * time.Second
// CreateAttestationServerTLSConfig creates a tls.Config object with a self-signed certificate and an embedded attestation document.
// Pass a list of validators to enable mutual aTLS.
// If issuer is nil, no attestation will be embedded.
@ -73,13 +76,13 @@ func CreateAttestationClientTLSConfig(issuer Issuer, validators []Validator) (*t
// Issuer issues an attestation document.
type Issuer interface {
oid.Getter
Issue(userData []byte, nonce []byte) (quote []byte, err error)
Issue(ctx context.Context, userData []byte, nonce []byte) (quote []byte, err error)
}
// Validator is able to validate an attestation document.
type Validator interface {
oid.Getter
Validate(attDoc []byte, nonce []byte) ([]byte, error)
Validate(ctx context.Context, attDoc []byte, nonce []byte) ([]byte, error)
}
// getATLSConfigForClientFunc returns a config setup function that is called once for every client connecting to the server.
@ -129,7 +132,7 @@ func getATLSConfigForClientFunc(issuer Issuer, validators []Validator) (func(*tl
// getCertificate creates a client or server certificate for aTLS connections.
// The certificate uses certificate extensions to embed an attestation document generated using nonce.
func getCertificate(issuer Issuer, priv, pub any, nonce []byte) (*tls.Certificate, error) {
func getCertificate(ctx context.Context, issuer Issuer, priv, pub any, nonce []byte) (*tls.Certificate, error) {
serialNumber, err := crypto.GenerateCertificateSerialNumber()
if err != nil {
return nil, err
@ -145,7 +148,7 @@ func getCertificate(issuer Issuer, priv, pub any, nonce []byte) (*tls.Certificat
}
// create attestation document using the nonce send by the remote party
attDoc, err := issuer.Issue(hash, nonce)
attDoc, err := issuer.Issue(ctx, hash, nonce)
if err != nil {
return nil, err
}
@ -200,7 +203,10 @@ func verifyEmbeddedReport(validators []Validator, cert *x509.Certificate, hash,
for _, ex := range cert.Extensions {
for _, validator := range validators {
if ex.Id.Equal(validator.OID()) {
userData, err := validator.Validate(ex.Value, nonce)
ctx, cancel := context.WithTimeout(context.Background(), attestationTimeout)
defer cancel()
userData, err := validator.Validate(ctx, ex.Value, nonce)
if err != nil {
return err
}
@ -308,7 +314,7 @@ func (c *clientConnection) getCertificate(cri *tls.CertificateRequestInfo) (*tls
return nil, fmt.Errorf("decode nonce: %w", err)
}
return getCertificate(c.issuer, priv, &priv.PublicKey, serverNonce)
return getCertificate(cri.Context(), c.issuer, priv, &priv.PublicKey, serverNonce)
}
// serverConnection holds state for server to client connections.
@ -340,7 +346,7 @@ func (c *serverConnection) getCertificate(chi *tls.ClientHelloInfo) (*tls.Certif
}
// create aTLS certificate using the nonce as extracted from the client-hello message
return getCertificate(c.issuer, c.privKey, &c.privKey.PublicKey, clientNonce)
return getCertificate(chi.Context(), c.issuer, c.privKey, &c.privKey.PublicKey, clientNonce)
}
// FakeIssuer fakes an issuer and can be used for tests.
@ -354,7 +360,7 @@ func NewFakeIssuer(oid oid.Getter) *FakeIssuer {
}
// Issue marshals the user data and returns it.
func (FakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
func (FakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal(FakeAttestationDoc{UserData: userData, Nonce: nonce})
}
@ -375,7 +381,7 @@ func NewFakeValidators(oid oid.Getter) []Validator {
}
// Validate unmarshals the attestation document and verifies the nonce.
func (v FakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
func (v FakeValidator) Validate(_ context.Context, attDoc []byte, nonce []byte) ([]byte, error) {
var doc FakeAttestationDoc
if err := json.Unmarshal(attDoc, &doc); err != nil {
return nil, err

View File

@ -107,7 +107,7 @@ func NewIssuer(
}
// Issue generates an attestation document using a TPM.
func (i *Issuer) Issue(userData []byte, nonce []byte) (res []byte, err error) {
func (i *Issuer) Issue(ctx context.Context, userData []byte, nonce []byte) (res []byte, err error) {
i.log.Infof("Issuing attestation statement")
defer func() {
if err != nil {
@ -136,7 +136,7 @@ func (i *Issuer) Issue(userData []byte, nonce []byte) (res []byte, err error) {
}
// Fetch instance info of the VM
instanceInfo, err := i.getInstanceInfo(context.TODO(), tpm, extraData) // TODO(daniel-weisse): update Issue/Validate to use context
instanceInfo, err := i.getInstanceInfo(ctx, tpm, extraData)
if err != nil {
return nil, fmt.Errorf("fetching instance info: %w", err)
}
@ -181,7 +181,7 @@ func NewValidator(expected measurements.M, getTrustedKey GetTPMTrustedAttestatio
}
// Validate a TPM based attestation.
func (v *Validator) Validate(attDocRaw []byte, nonce []byte) (userData []byte, err error) {
func (v *Validator) Validate(ctx context.Context, attDocRaw []byte, nonce []byte) (userData []byte, err error) {
v.log.Infof("Validating attestation document")
defer func() {
if err != nil {
@ -197,7 +197,7 @@ func (v *Validator) Validate(attDocRaw []byte, nonce []byte) (userData []byte, e
extraData := makeExtraData(attDoc.UserData, nonce)
// Verify and retrieve the trusted attestation public key using the provided instance info
aKP, err := v.getTrustedKey(context.TODO(), attDoc, extraData)
aKP, err := v.getTrustedKey(ctx, attDoc, extraData)
if err != nil {
return nil, fmt.Errorf("validating attestation public key: %w", err)
}

View File

@ -84,7 +84,9 @@ func TestValidate(t *testing.T) {
nonce := []byte{1, 2, 3, 4}
challenge := []byte("Constellation")
attDocRaw, err := issuer.Issue(challenge, nonce)
ctx := context.Background()
attDocRaw, err := issuer.Issue(ctx, challenge, nonce)
require.NoError(err)
var attDoc AttestationDocument
@ -93,26 +95,26 @@ func TestValidate(t *testing.T) {
require.Equal(challenge, attDoc.UserData)
// valid test
out, err := validator.Validate(attDocRaw, nonce)
out, err := validator.Validate(ctx, attDocRaw, nonce)
require.NoError(err)
require.Equal(challenge, out)
// validation must fail after bootstrapping (change of enforced PCR)
require.NoError(MarkNodeAsBootstrapped(tpmOpen, []byte{2}))
attDocBootstrappedRaw, err := issuer.Issue(challenge, nonce)
attDocBootstrappedRaw, err := issuer.Issue(ctx, challenge, nonce)
require.NoError(err)
_, err = validator.Validate(attDocBootstrappedRaw, nonce)
_, err = validator.Validate(ctx, attDocBootstrappedRaw, nonce)
require.Error(err)
// userData must be bound to PCR state
attDocBootstrappedRaw, err = issuer.Issue([]byte{2, 3}, nonce)
attDocBootstrappedRaw, err = issuer.Issue(ctx, []byte{2, 3}, nonce)
require.NoError(err)
var attDocBootstrapped AttestationDocument
require.NoError(json.Unmarshal(attDocBootstrappedRaw, &attDocBootstrapped))
attDocBootstrapped.Attestation = attDoc.Attestation
attDocBootstrappedRaw, err = json.Marshal(attDocBootstrapped)
require.NoError(err)
_, err = validator.Validate(attDocBootstrappedRaw, nonce)
_, err = validator.Validate(ctx, attDocBootstrappedRaw, nonce)
require.Error(err)
expectedPCRs := measurements.M{
@ -141,7 +143,7 @@ func TestValidate(t *testing.T) {
fakeValidateCVM,
warnLog,
)
out, err = warningValidator.Validate(attDocRaw, nonce)
out, err = warningValidator.Validate(ctx, attDocRaw, nonce)
require.NoError(err)
assert.Equal(t, challenge, out)
assert.Len(t, warnLog.warnings, 4)
@ -240,7 +242,7 @@ func TestValidate(t *testing.T) {
t.Run(name, func(t *testing.T) {
assert := assert.New(t)
_, err = tc.validator.Validate(tc.attDoc, tc.nonce)
_, err = tc.validator.Validate(ctx, tc.attDoc, tc.nonce)
if tc.wantErr {
assert.Error(err)
} else {
@ -316,7 +318,7 @@ func TestFailIssuer(t *testing.T) {
tc.issuer.log = logger.NewTest(t)
_, err := tc.issuer.Issue(tc.userData, tc.nonce)
_, err := tc.issuer.Issue(context.Background(), tc.userData, tc.nonce)
assert.Error(err)
})
}

View File

@ -86,7 +86,7 @@ type fakeIssuer struct {
fakeOID
}
func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
func (fakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce})
}
@ -95,7 +95,7 @@ type fakeValidator struct {
err error
}
func (v fakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
func (v fakeValidator) Validate(_ context.Context, attDoc []byte, nonce []byte) ([]byte, error) {
var doc fakeDoc
if err := json.Unmarshal(attDoc, &doc); err != nil {
return nil, err

View File

@ -7,6 +7,7 @@ SPDX-License-Identifier: AGPL-3.0-only
package watcher
import (
"context"
"encoding/asn1"
"encoding/json"
"errors"
@ -49,10 +50,10 @@ func NewValidator(log *logger.Logger, variant oid.Getter, fileHandler file.Handl
}
// Validate calls the validators Validate method, and prevents any updates during the call.
func (u *Updatable) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
func (u *Updatable) Validate(ctx context.Context, attDoc []byte, nonce []byte) ([]byte, error) {
u.mux.Lock()
defer u.mux.Unlock()
return u.Validator.Validate(attDoc, nonce)
return u.Validator.Validate(ctx, attDoc, nonce)
}
// OID returns the validators Object Identifier.

View File

@ -270,7 +270,7 @@ type fakeIssuer struct {
oid.Getter
}
func (fakeIssuer) Issue(userData []byte, nonce []byte) ([]byte, error) {
func (fakeIssuer) Issue(_ context.Context, userData []byte, nonce []byte) ([]byte, error) {
return json.Marshal(fakeDoc{UserData: userData, Nonce: nonce})
}

View File

@ -111,7 +111,7 @@ func (s *Server) GetAttestation(ctx context.Context, req *verifyproto.GetAttesta
}
log.Infof("Creating attestation")
statement, err := s.issuer.Issue([]byte(constants.ConstellationVerifyServiceUserData), req.Nonce)
statement, err := s.issuer.Issue(ctx, []byte(constants.ConstellationVerifyServiceUserData), req.Nonce)
if err != nil {
return nil, status.Errorf(codes.Internal, "issuing attestation statement: %v", err)
}
@ -139,7 +139,7 @@ func (s *Server) getAttestationHTTP(w http.ResponseWriter, r *http.Request) {
}
log.Infof("Creating attestation")
quote, err := s.issuer.Issue([]byte(constants.ConstellationVerifyServiceUserData), nonce)
quote, err := s.issuer.Issue(r.Context(), []byte(constants.ConstellationVerifyServiceUserData), nonce)
if err != nil {
http.Error(w, fmt.Sprintf("issuing attestation statement: %v", err), http.StatusInternalServerError)
return
@ -154,5 +154,5 @@ func (s *Server) getAttestationHTTP(w http.ResponseWriter, r *http.Request) {
// AttestationIssuer issues an attestation document for the provided userData and nonce.
type AttestationIssuer interface {
Issue(userData []byte, nonce []byte) (quote []byte, err error)
Issue(ctx context.Context, userData []byte, nonce []byte) (quote []byte, err error)
}

View File

@ -197,6 +197,6 @@ type stubIssuer struct {
issueErr error
}
func (i stubIssuer) Issue(_ []byte, _ []byte) ([]byte, error) {
func (i stubIssuer) Issue(_ context.Context, _ []byte, _ []byte) ([]byte, error) {
return i.attestation, i.issueErr
}