Thomas Tendyck bd63aa3c6b add license headers
sed -i '1i/*\nCopyright (c) Edgeless Systems GmbH\n\nSPDX-License-Identifier: AGPL-3.0-only\n*/\n' `grep -rL --include='*.go' 'DO NOT EDIT'`
gofumpt -w .
2022-09-05 09:17:25 +02:00

401 lines
12 KiB
Go

/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
package atls
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"math/big"
"time"
"github.com/edgelesssys/constellation/internal/crypto"
"github.com/edgelesssys/constellation/internal/oid"
)
// CreateAttestationServerTLSConfig returns a server-side tls.Config that presents a
// self-signed certificate carrying an attestation document produced by issuer.
// A non-empty list of validators turns on mutual aTLS: clients must then present an
// attested certificate of their own.
// When issuer is nil, the server certificate carries no attestation document.
func CreateAttestationServerTLSConfig(issuer Issuer, validators []Validator) (*tls.Config, error) {
	perClient, err := getATLSConfigForClientFunc(issuer, validators)
	if err != nil {
		return nil, err
	}
	// All real settings are produced per ClientHello by perClient.
	cfg := &tls.Config{GetConfigForClient: perClient}
	return cfg, nil
}
// CreateAttestationClientTLSConfig returns a client-side tls.Config that accepts a server
// certificate only if its embedded attestation document checks out.
//
// ATTENTION: The tls.Config ensures freshness of the server's attestation only for the first connection it is used for.
// If freshness is required, you must create a new tls.Config for each connection or ensure freshness on the protocol level.
// If freshness is not required, you can reuse this tls.Config.
//
// With no validators, the server's attestation document is not verified.
// With a nil issuer, the client cannot take part in mutual aTLS.
func CreateAttestationClientTLSConfig(issuer Issuer, validators []Validator) (*tls.Config, error) {
	nonce, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
	if err != nil {
		return nil, err
	}

	conn := &clientConnection{
		issuer:      issuer,
		validators:  validators,
		clientNonce: nonce,
	}

	cfg := &tls.Config{
		MinVersion: tls.VersionTLS12,
		// The SNI field is repurposed to carry the nonce to the server.
		ServerName: base64.StdEncoding.EncodeToString(nonce),
		// Standard chain verification is disabled; conn.verify performs the aTLS checks instead.
		InsecureSkipVerify:    true,
		VerifyPeerCertificate: conn.verify,
		// Supplies an attested client certificate when the server requests one (mutual aTLS).
		GetClientCertificate: conn.getCertificate,
	}
	return cfg, nil
}
// Issuer produces attestation documents that are embedded into aTLS certificates.
// The embedded oid.Getter supplies the OID under which the document is stored as a
// certificate extension.
type Issuer interface {
	oid.Getter
	// Issue returns an attestation document for userData and nonce.
	// aTLS passes the SHA-256 hash of the certificate's public key as userData and
	// the nonce received from the remote party as nonce.
	Issue(userData, nonce []byte) (quote []byte, err error)
}
// Validator checks attestation documents found in aTLS certificate extensions.
// The embedded oid.Getter supplies the extension OID this validator is responsible for.
type Validator interface {
	oid.Getter
	// Validate checks attDoc against nonce and returns the user data embedded in it.
	// aTLS then compares that user data with the hash of the certificate's public key.
	Validate(attDoc, nonce []byte) ([]byte, error)
}
// getATLSConfigForClientFunc builds the tls.Config.GetConfigForClient callback used by
// aTLS servers. The callback runs once per incoming ClientHello and returns a config
// tied to a freshly generated server nonce, so every client gets its own nonce.
// A single ECDSA P-256 key pair is generated up front and shared by all connections.
func getATLSConfigForClientFunc(issuer Issuer, validators []Validator) (func(*tls.ClientHelloInfo) (*tls.Config, error), error) {
	serverKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, err
	}
	mutual := len(validators) > 0

	configFor := func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
		// Draw a new nonce for this connection.
		nonce, err := crypto.GenerateRandomBytes(crypto.RNGLengthDefault)
		if err != nil {
			return nil, err
		}

		conn := &serverConnection{
			privKey:     serverKey,
			issuer:      issuer,
			validators:  validators,
			serverNonce: nonce,
		}
		cfg := &tls.Config{
			MinVersion:            tls.VersionTLS12,
			GetCertificate:        conn.getCertificate,
			VerifyPeerCertificate: conn.verify,
		}
		if !mutual {
			return cfg, nil
		}

		// Any client certificate passes the TLS layer; conn.verify does the real check.
		cfg.ClientAuth = tls.RequireAnyClientCert
		// Workaround: the nonce reaches the client as the CN of the sole acceptable client CA.
		pool, err := encodeNonceToCertPool(nonce, serverKey)
		if err != nil {
			return nil, fmt.Errorf("encode nonce: %w", err)
		}
		cfg.ClientCAs = pool
		return cfg, nil
	}
	return configFor, nil
}
// getCertificate builds a self-signed client or server certificate for aTLS.
// If issuer is non-nil, it embeds an attestation document as a certificate extension
// (keyed by issuer.OID()). The document is issued over the hash of pub and over nonce.
// priv signs the certificate and becomes the PrivateKey of the returned tls.Certificate.
func getCertificate(issuer Issuer, priv, pub any, nonce []byte) (*tls.Certificate, error) {
	serial, err := crypto.GenerateCertificateSerialNumber()
	if err != nil {
		return nil, err
	}

	var exts []pkix.Extension
	if issuer != nil {
		// The public-key hash is the user data, tying the attestation to this certificate's key.
		keyHash, err := hashPublicKey(pub)
		if err != nil {
			return nil, err
		}
		// The nonce was sent by the remote party that will verify this certificate.
		doc, err := issuer.Issue(keyHash, nonce)
		if err != nil {
			return nil, err
		}
		exts = []pkix.Extension{{Id: issuer.OID(), Value: doc}}
	}

	now := time.Now()
	tmpl := &x509.Certificate{
		SerialNumber: serial,
		Subject:      pkix.Name{CommonName: "Constellation"},
		// NOTE(review): the ±2h window presumably tolerates clock skew between peers.
		NotBefore:       now.Add(-2 * time.Hour),
		NotAfter:        now.Add(2 * time.Hour),
		ExtraExtensions: exts,
	}
	// Self-signed: the template acts as both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, pub, priv)
	if err != nil {
		return nil, err
	}
	return &tls.Certificate{Certificate: [][]byte{der}, PrivateKey: priv}, nil
}
// processCertificate parses the first certificate in rawCerts and checks that it is
// validly self-signed (it is used as its own and only trusted root).
// On success it returns the parsed certificate and the SHA-256 hash of its public key,
// as computed by hashPublicKey. On any failure both results are nil and the error is
// non-nil. The verified-chains parameter is unused.
func processCertificate(rawCerts [][]byte, _ [][]*x509.Certificate) (*x509.Certificate, []byte, error) {
	if len(rawCerts) == 0 {
		return nil, nil, errors.New("rawCerts is empty")
	}
	cert, err := x509.ParseCertificate(rawCerts[0])
	if err != nil {
		return nil, nil, err
	}

	// Verify the self-signed certificate against itself.
	roots := x509.NewCertPool()
	roots.AddCert(cert)
	if _, err := cert.Verify(x509.VerifyOptions{Roots: roots}); err != nil {
		return nil, nil, err
	}

	// The public-key hash is what the embedded attestation document must carry as user data.
	hash, err := hashPublicKey(cert.PublicKey)
	if err != nil {
		// Do not hand back a half-populated result alongside an error.
		return nil, nil, err
	}
	return cert, hash, nil
}
// verifyEmbeddedReport checks the attestation document embedded in an aTLS certificate.
// It scans cert's extensions in order and hands the first extension whose OID matches
// one of validators to that validator, together with nonce. That single result is final.
// Verification succeeds only if the user data returned by the validator equals hash
// (the certificate's public-key hash). Without any matching extension it fails.
func verifyEmbeddedReport(validators []Validator, cert *x509.Certificate, hash, nonce []byte) error {
	for _, ext := range cert.Extensions {
		for _, v := range validators {
			if !ext.Id.Equal(v.OID()) {
				continue
			}
			userData, err := v.Validate(ext.Value, nonce)
			switch {
			case err != nil:
				return err
			case !bytes.Equal(userData, hash):
				return errors.New("certificate hash does not match user data")
			default:
				return nil
			}
		}
	}
	return errors.New("certificate does not contain attestation document")
}
// hashPublicKey returns the SHA-256 digest of pub's PKIX (SubjectPublicKeyInfo, DER)
// encoding. It returns an error if pub cannot be marshaled by x509.MarshalPKIXPublicKey.
func hashPublicKey(pub any) ([]byte, error) {
	der, err := x509.MarshalPKIXPublicKey(pub)
	if err != nil {
		return nil, err
	}
	h := sha256.New()
	h.Write(der) // hash.Hash.Write never returns an error
	return h.Sum(nil), nil
}
// encodeNonceToCertPool returns a cert pool holding exactly one certificate, self-signed
// with privKey, whose subject CN is the standard base64 encoding of nonce.
// The pool is meant to be used as tls.Config.ClientCAs. Its subject name is what a client
// receives as the acceptable CA, so only the CN matters; the zero serial number is irrelevant.
func encodeNonceToCertPool(nonce []byte, privKey *ecdsa.PrivateKey) (*x509.CertPool, error) {
	carrier := &x509.Certificate{
		Subject:      pkix.Name{CommonName: base64.StdEncoding.EncodeToString(nonce)},
		SerialNumber: new(big.Int),
	}
	raw, err := x509.CreateCertificate(rand.Reader, carrier, carrier, &privKey.PublicKey, privKey)
	if err != nil {
		return nil, err
	}
	parsed, err := x509.ParseCertificate(raw)
	if err != nil {
		return nil, err
	}
	pool := x509.NewCertPool()
	pool.AddCert(parsed)
	return pool, nil
}
// decodeNonceFromAcceptableCAs is the client-side counterpart of encodeNonceToCertPool.
// It expects exactly one DER-encoded distinguished name in acceptableCAs, finds the first
// common-name attribute in it, and returns that CN base64-decoded (standard encoding).
// Errors are returned for a wrong number of entries, malformed ASN.1, a non-string CN
// value, a missing CN, or invalid base64.
func decodeNonceFromAcceptableCAs(acceptableCAs [][]byte) ([]byte, error) {
	if n := len(acceptableCAs); n != 1 {
		return nil, errors.New("unexpected acceptableCAs length")
	}
	var seq pkix.RDNSequence
	if _, err := asn1.Unmarshal(acceptableCAs[0], &seq); err != nil {
		return nil, err
	}

	// id-at-commonName (2.5.4.3), the same OID crypto/x509/pkix uses for CommonName:
	// https://github.com/golang/go/blob/19309779ac5e2f5a2fd3cbb34421dafb2855ac21/src/crypto/x509/pkix/pkix.go#L188
	cnOID := asn1.ObjectIdentifier{2, 5, 4, 3}
	for _, set := range seq {
		for _, atv := range set {
			if !atv.Type.Equal(cnOID) {
				continue
			}
			encoded, isString := atv.Value.(string)
			if !isString {
				return nil, errors.New("unexpected RDN type")
			}
			return base64.StdEncoding.DecodeString(encoded)
		}
	}
	return nil, errors.New("CN not found")
}
// clientConnection holds state for client to server connections.
// One instance is created per tls.Config by CreateAttestationClientTLSConfig, and its
// verify and getCertificate methods are installed as that config's callbacks.
type clientConnection struct {
	// issuer creates the client's attestation for mutual aTLS; if nil, the client
	// certificate carries no attestation document.
	issuer Issuer
	// validators check the server's attestation; if empty, it is not verified.
	validators []Validator
	// clientNonce is sent to the server via the TLS ServerName field and is expected
	// in the server's attestation document.
	clientNonce []byte
}
// verify is the client's VerifyPeerCertificate callback. It checks that the server's
// certificate is validly self-signed and, when validators are configured, that its
// embedded attestation document is valid for clientNonce and the certificate's key.
func (c *clientConnection) verify(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
	cert, keyHash, err := processCertificate(rawCerts, verifiedChains)
	switch {
	case err != nil:
		return err
	case len(c.validators) == 0:
		// Without validators the attestation document is deliberately not checked.
		return nil
	default:
		return verifyEmbeddedReport(c.validators, cert, keyHash, c.clientNonce)
	}
}
// getCertificate is the client's GetClientCertificate callback for mutual aTLS.
// It recovers the server's nonce from the request's acceptable CAs, generates a fresh
// ECDSA P-256 key pair, and returns a certificate attested over that nonce.
func (c *clientConnection) getCertificate(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
	// Workaround: the server transmits its nonce as the CN of the single acceptable CA.
	serverNonce, err := decodeNonceFromAcceptableCAs(cri.AcceptableCAs)
	if err != nil {
		return nil, fmt.Errorf("decode nonce: %w", err)
	}
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, err
	}
	return getCertificate(c.issuer, key, &key.PublicKey, serverNonce)
}
// serverConnection holds state for server to client connections.
// getATLSConfigForClientFunc creates a new instance for every ClientHello, and its
// verify and getCertificate methods are installed as that connection's callbacks.
type serverConnection struct {
	// issuer creates the server's attestation; if nil, the server certificate carries
	// no attestation document.
	issuer Issuer
	// validators check the client's attestation when mutual aTLS is enabled.
	validators []Validator
	// privKey is the server key shared by all connections; it signs the certificates.
	privKey *ecdsa.PrivateKey
	// serverNonce is generated for this connection, sent to the client via the
	// acceptable-CA list, and expected in the client's attestation document.
	serverNonce []byte
}
// verify is the server's VerifyPeerCertificate callback and only matters for mutual aTLS.
// It checks that the client's certificate is validly self-signed and that its embedded
// attestation document is valid for serverNonce and the certificate's key.
func (c *serverConnection) verify(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
	cert, keyHash, err := processCertificate(rawCerts, verifiedChains)
	if err == nil {
		err = verifyEmbeddedReport(c.validators, cert, keyHash, c.serverNonce)
	}
	return err
}
// getCertificate is the server's GetCertificate callback. It produces the server's aTLS
// certificate for both plain and mutual aTLS. The client's nonce is read from the
// ClientHello's server name, where the client places it base64-encoded.
func (c *serverConnection) getCertificate(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
	nonce, err := base64.StdEncoding.DecodeString(chi.ServerName)
	if err != nil {
		return nil, err
	}
	pub := &c.privKey.PublicKey
	return getCertificate(c.issuer, c.privKey, pub, nonce)
}
// FakeIssuer fakes an issuer and can be used for tests. It performs no real attestation.
type FakeIssuer struct {
	oid.Getter
}

// NewFakeIssuer returns a FakeIssuer that reports the OID of the given getter.
func NewFakeIssuer(oid oid.Getter) *FakeIssuer {
	return &FakeIssuer{Getter: oid}
}

// Issue returns the JSON encoding of a FakeAttestationDoc holding userData and nonce.
func (FakeIssuer) Issue(userData, nonce []byte) ([]byte, error) {
	doc := FakeAttestationDoc{Nonce: nonce, UserData: userData}
	return json.Marshal(doc)
}
// FakeValidator fakes a validator and can be used for tests.
// It accepts the JSON documents produced by FakeIssuer.
type FakeValidator struct {
	oid.Getter
	err error // used for package internal testing only
}

// NewFakeValidator creates a new FakeValidator with the given OID.
func NewFakeValidator(oid oid.Getter) *FakeValidator {
	return &FakeValidator{Getter: oid}
}

// NewFakeValidators returns a slice with a single FakeValidator.
func NewFakeValidators(oid oid.Getter) []Validator {
	return []Validator{NewFakeValidator(oid)}
}

// Validate unmarshals attDoc as a JSON FakeAttestationDoc and checks that its nonce
// equals nonce. It returns the document's user data. If v.err is set, that error is
// returned alongside the user data, which lets package tests force a failure.
func (v FakeValidator) Validate(attDoc []byte, nonce []byte) ([]byte, error) {
	var doc FakeAttestationDoc
	if err := json.Unmarshal(attDoc, &doc); err != nil {
		return nil, err
	}
	if !bytes.Equal(doc.Nonce, nonce) {
		// nonce is what the verifier expects; doc.Nonce is what the document actually contains.
		return nil, fmt.Errorf("invalid nonce: expected %x, got %x", nonce, doc.Nonce)
	}
	return doc.UserData, v.err
}
// FakeAttestationDoc is a fake attestation document used for testing.
// FakeIssuer.Issue encodes it as JSON and FakeValidator.Validate decodes it again.
type FakeAttestationDoc struct {
	// UserData is the data the issuer was asked to attest; in aTLS this is the hash of
	// the certificate's public key. encoding/json encodes it as base64.
	UserData []byte
	// Nonce is the nonce the issuer received from the remote party.
	Nonce []byte
}
// fakeOID is a minimal test helper providing the OID() method that Issuer and Validator
// obtain from oid.Getter. It reports whatever object identifier it wraps.
type fakeOID struct{ asn1.ObjectIdentifier }

// OID returns the wrapped object identifier.
func (f fakeOID) OID() asn1.ObjectIdentifier { return f.ObjectIdentifier }