From 869448c3e19e2572b06b5dbf13ddcef28427c1cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Wei=C3=9Fe?= <66256922+daniel-weisse@users.noreply.github.com> Date: Tue, 24 May 2022 16:33:44 +0200 Subject: [PATCH] Add mutual aTLS support (#176) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Daniel Weiße --- cli/proto/client.go | 2 +- cli/proto/recover.go | 2 +- cli/status/status.go | 2 +- coordinator/atls/atls.go | 334 ++++++++++++++++------- coordinator/atls/atls_test.go | 113 ++++++-- coordinator/cmd/coordinator/run.go | 4 +- coordinator/coordinator_test.go | 4 +- coordinator/core/legacy_test.go | 2 +- coordinator/core/reinitialize_test.go | 2 +- coordinator/oid/oid.go | 3 + coordinator/pubapi/coord_test.go | 4 +- coordinator/pubapi/node_test.go | 4 +- coordinator/util/grpcutil/dialer.go | 4 +- coordinator/util/grpcutil/dialer_test.go | 2 +- hack/pcr-reader/main.go | 2 +- state/keyservice/keyservice.go | 4 +- state/keyservice/keyservice_test.go | 2 +- state/test/integration_test.go | 2 +- test/coordinator_integration_test.go | 8 +- 19 files changed, 354 insertions(+), 146 deletions(-) diff --git a/cli/proto/client.go b/cli/proto/client.go index 7d76658f5..af7cf3a3d 100644 --- a/cli/proto/client.go +++ b/cli/proto/client.go @@ -26,7 +26,7 @@ type Client struct { // called on a client that already has a connection, the old // connection is closed. func (c *Client) Connect(endpoint string, validators []atls.Validator) error { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) + tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, validators) if err != nil { return err } diff --git a/cli/proto/recover.go b/cli/proto/recover.go index dde4cdacf..00843bef6 100644 --- a/cli/proto/recover.go +++ b/cli/proto/recover.go @@ -22,7 +22,7 @@ type KeyClient struct { // called on a client that already has a connection, the old // connection is closed. func (c *KeyClient) Connect(endpoint string, validators []atls.Validator) error { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) + tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, validators) if err != nil { return err } diff --git a/cli/status/status.go b/cli/status/status.go index fbf70fe51..d09f4d408 100644 --- a/cli/status/status.go +++ b/cli/status/status.go @@ -115,7 +115,7 @@ func (w *Waiter) WaitForAll(ctx context.Context, endpoints []string, status ...s // newAttestedConnGenerator creates a function returning a default attested grpc connection. func newAttestedConnGenerator(validators []atls.Validator) func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { return func(ctx context.Context, target string, opts ...grpc.DialOption) (ClientConn, error) { - tlsConfig, err := atls.CreateAttestationClientTLSConfig(validators) + tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, validators) if err != nil { return nil, err } diff --git a/coordinator/atls/atls.go b/coordinator/atls/atls.go index 3734866c5..f5fd4a952 100644 --- a/coordinator/atls/atls.go +++ b/coordinator/atls/atls.go @@ -19,122 +19,42 @@ import ( ) // CreateAttestationServerTLSConfig creates a tls.Config object with a self-signed certificate and an embedded attestation document. -func CreateAttestationServerTLSConfig(issuer Issuer) (*tls.Config, error) { - // generate and hash key - priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) +// Pass a list of validators to enable mutual aTLS. +func CreateAttestationServerTLSConfig(issuer Issuer, validators []Validator) (*tls.Config, error) { + if issuer == nil { + return nil, errors.New("unable to create aTLS server configuration without quote issuer") + } + + getConfigForClient, err := getATLSConfigForClientFunc(issuer, validators) if err != nil { return nil, err } - hash, err := hashPublicKey(&priv.PublicKey) - if err != nil { - return nil, err - } - - getCertificate := func(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { - serialNumber, err := util.GenerateCertificateSerialNumber() - if err != nil { - return nil, err - } - - // abuse ServerName as a channel to receive the nonce - nonce, err := base64.StdEncoding.DecodeString(chi.ServerName) - if err != nil { - return nil, err - } - - attDoc, err := issuer.Issue(hash, nonce) - if err != nil { - return nil, err - } - - // create certficate that includes the attestation document as extension - now := time.Now() - template := &x509.Certificate{ - SerialNumber: serialNumber, - Subject: pkix.Name{CommonName: "Constellation"}, - NotBefore: now.Add(-2 * time.Hour), - NotAfter: now.Add(2 * time.Hour), - ExtraExtensions: []pkix.Extension{{Id: issuer.OID(), Value: attDoc}}, - } - cert, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) - if err != nil { - return nil, err - } - - return &tls.Certificate{Certificate: [][]byte{cert}, PrivateKey: priv}, nil - } - - return &tls.Config{GetCertificate: getCertificate, MinVersion: tls.VersionTLS12}, nil -} - -// CreateAttestationClientTLSConfig creates a tls.Config object that verifies a certificate with an embedded attestation document. -func CreateAttestationClientTLSConfig(validators []Validator) (*tls.Config, error) { - nonce, err := util.GenerateRandomBytes(config.RNGLengthDefault) - if err != nil { - return nil, err - } - - verify := func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { - // parse certificate - if len(rawCerts) == 0 { - return errors.New("rawCerts is empty") - } - cert, err := x509.ParseCertificate(rawCerts[0]) - if err != nil { - return err - } - - // verify self-signed certificate - roots := x509.NewCertPool() - roots.AddCert(cert) - _, err = cert.Verify(x509.VerifyOptions{Roots: roots}) - if err != nil { - return err - } - - hash, err := hashPublicKey(cert.PublicKey) - if err != nil { - return err - } - - // verify embedded report - for _, ex := range cert.Extensions { - for _, validator := range validators { - if ex.Id.Equal(validator.OID()) { - userData, err := validator.Validate(ex.Value, nonce) - if err != nil { - return err - } - if !bytes.Equal(userData, hash) { - return errors.New("certificate hash does not match user data") - } - return nil - } - } - } - - return errors.New("certificate does not contain attestation document") - } return &tls.Config{ - VerifyPeerCertificate: verify, - InsecureSkipVerify: true, // disable default verification because we use our own verify func - ServerName: base64.StdEncoding.EncodeToString(nonce), // abuse ServerName as a channel to transmit the nonce - MinVersion: tls.VersionTLS12, + GetConfigForClient: getConfigForClient, }, nil } -// CreateUnverifiedClientTLSConfig creates a tls.Config object that skips verification of a certificate with an embedded attestation document. -func CreateUnverifiedClientTLSConfig() (*tls.Config, error) { +// CreateAttestationClientTLSConfig creates a tls.Config object that verifies a certificate with an embedded attestation document. +// If no validators are set, the server's attestation document will not be verified. +// If issuers is nil, the client will be unable to perform mutual aTLS. +func CreateAttestationClientTLSConfig(issuer Issuer, validators []Validator) (*tls.Config, error) { nonce, err := util.GenerateRandomBytes(config.RNGLengthDefault) if err != nil { return nil, err } + clientConn := &clientConnection{ + issuer: issuer, + validators: validators, + clientNonce: nonce, + } return &tls.Config{ - InsecureSkipVerify: true, // disable certificate verification - ServerName: base64.StdEncoding.EncodeToString(nonce), // abuse ServerName as a channel to transmit the nonce - MinVersion: tls.VersionTLS12, + VerifyPeerCertificate: clientConn.verify, + GetClientCertificate: clientConn.getCertificate, // use custom certificate for mutual aTLS connections + InsecureSkipVerify: true, // disable default verification because we use our own verify func + ServerName: base64.StdEncoding.EncodeToString(nonce), // abuse ServerName as a channel to transmit the nonce + MinVersion: tls.VersionTLS12, }, nil } @@ -148,6 +68,134 @@ type Validator interface { Validate(attDoc []byte, nonce []byte) ([]byte, error) } +// getATLSConfigForClientFunc returns a config setup function that is called once for every client connecting to the server. +// This allows for different server configuration for every client. +// In aTLS this is used to generate unique nonces for every client and embed them in the server's certificate. +func getATLSConfigForClientFunc(issuer Issuer, validators []Validator) (func(*tls.ClientHelloInfo) (*tls.Config, error), error) { + // generate key for the server + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + + // this function will be called once for every client + return func(chi *tls.ClientHelloInfo) (*tls.Config, error) { + // generate nonce for this connection + nonce, err := util.GenerateRandomBytes(config.RNGLengthDefault) + if err != nil { + return nil, err + } + + serverConn := &serverConnection{ + privKey: priv, + issuer: issuer, + validators: validators, + nonce: nonce, + } + + clientAuth := tls.NoClientCert + // enable mutual aTLS if any validators are set + if len(validators) > 0 { + clientAuth = tls.RequireAnyClientCert // validity of certificate will be checked by our custom verify function + } + + return &tls.Config{ + ClientAuth: clientAuth, + VerifyPeerCertificate: serverConn.verify, + GetCertificate: serverConn.getCertificate, + MinVersion: tls.VersionTLS12, + }, nil + }, nil +} + +// getCertificate creates a client or server certificate for aTLS connections. +// The certificate uses certificate extensions to embed an attestation document generated using remoteNonce. +// If localNonce is set, it is also embedded as a certificate extension. +func getCertificate(issuer Issuer, priv, pub any, remoteNonce, localNonce []byte) (*tls.Certificate, error) { + serialNumber, err := util.GenerateCertificateSerialNumber() + if err != nil { + return nil, err + } + + hash, err := hashPublicKey(pub) + if err != nil { + return nil, err + } + + // create attestation document using the nonce send by the remote party + attDoc, err := issuer.Issue(hash, remoteNonce) + if err != nil { + return nil, err + } + + extensions := []pkix.Extension{{Id: issuer.OID(), Value: attDoc}} + // embed locally generated nonce in certificate + if len(localNonce) > 0 { + extensions = append(extensions, pkix.Extension{Id: oid.ATLSNonce, Value: localNonce}) + } + + // create certificate that includes the attestation document and the server nonce as extension + now := time.Now() + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{CommonName: "Constellation"}, + NotBefore: now.Add(-2 * time.Hour), + NotAfter: now.Add(2 * time.Hour), + ExtraExtensions: extensions, + } + cert, err := x509.CreateCertificate(rand.Reader, template, template, pub, priv) + if err != nil { + return nil, err + } + + return &tls.Certificate{Certificate: [][]byte{cert}, PrivateKey: priv}, nil +} + +// processCertificate parses the certificate and verifies it. +// If successful returns the certificate and its hashed public key, an error otherwise. +func processCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) (*x509.Certificate, []byte, error) { + // parse certificate + if len(rawCerts) == 0 { + return nil, nil, errors.New("rawCerts is empty") + } + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return nil, nil, err + } + + // verify self-signed certificate + roots := x509.NewCertPool() + roots.AddCert(cert) + _, err = cert.Verify(x509.VerifyOptions{Roots: roots}) + if err != nil { + return nil, nil, err + } + + // hash of certificates public key is used as userData in the embedded attestation document + hash, err := hashPublicKey(cert.PublicKey) + return cert, hash, err +} + +// verifyEmbeddedReport verifies an aTLS certificate by validating the attestation document embedded in the TLS certificate. +func verifyEmbeddedReport(validators []Validator, cert *x509.Certificate, hash, nonce []byte) error { + for _, ex := range cert.Extensions { + for _, validator := range validators { + if ex.Id.Equal(validator.OID()) { + userData, err := validator.Validate(ex.Value, nonce) + if err != nil { + return err + } + if !bytes.Equal(userData, hash) { + return errors.New("certificate hash does not match user data") + } + return nil + } + } + } + + return errors.New("certificate does not contain attestation document") +} + func hashPublicKey(pub any) ([]byte, error) { pubBytes, err := x509.MarshalPKIXPublicKey(pub) if err != nil { @@ -156,3 +204,85 @@ func hashPublicKey(pub any) ([]byte, error) { result := sha256.Sum256(pubBytes) return result[:], nil } + +// clientConnection holds state for client to server connections. +type clientConnection struct { + issuer Issuer + validators []Validator + clientNonce []byte + serverNonce []byte +} + +// verify the validity of an aTLS server certificate. +func (c *clientConnection) verify(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + cert, hash, err := processCertificate(rawCerts, verifiedChains) + if err != nil { + return err + } + + // get nonce send by server from cert extensions and save to connection state + for _, ex := range cert.Extensions { + if ex.Id.Equal(oid.ATLSNonce) { + c.serverNonce = ex.Value + } + } + + // don't perform verification of attestation document if no validators are set + if len(c.validators) == 0 { + return nil + } + + return verifyEmbeddedReport(c.validators, cert, hash, c.clientNonce) +} + +// getCertificate generates a client certificate for mutual aTLS connections. +func (c *clientConnection) getCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) { + if c.issuer == nil { + return nil, errors.New("unable to create certificate: no quote issuer available") + } + + // generate and hash key + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + + // create aTLS certificate using the server's nonce as read by clientConnection.verify + // we do not pass a nonce because + // 1. we already received a certificate from the server + // 2. we transmitted the client nonce as our server name in our client-hello message + return getCertificate(c.issuer, priv, &priv.PublicKey, c.serverNonce, nil) +} + +// serverConnection holds state for server to client connections. +type serverConnection struct { + issuer Issuer + validators []Validator + privKey *ecdsa.PrivateKey + nonce []byte +} + +// verify the validity of a clients aTLS certificate. +// Only needed for mutual aTLS. +func (c *serverConnection) verify(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + cert, hash, err := processCertificate(rawCerts, verifiedChains) + if err != nil { + return err + } + + return verifyEmbeddedReport(c.validators, cert, hash, c.nonce) +} + +// getCertificate generates a client certificate for aTLS connections. +// Can be used for mutual as well as basic aTLS. +func (c *serverConnection) getCertificate(chi *tls.ClientHelloInfo) (*tls.Certificate, error) { + // abuse ServerName as a channel to receive the nonce + clientNonce, err := base64.StdEncoding.DecodeString(chi.ServerName) + if err != nil { + return nil, err + } + + // create aTLS certificate using the nonce as extracted from the client-hello message + // we also embed the nonce generated for this connection in case of mutual aTLS + return getCertificate(c.issuer, c.privKey, &c.privKey.PublicKey, clientNonce, c.nonce) +} diff --git a/coordinator/atls/atls_test.go b/coordinator/atls/atls_test.go index 6d8afbe28..02c263e14 100644 --- a/coordinator/atls/atls_test.go +++ b/coordinator/atls/atls_test.go @@ -20,27 +20,102 @@ func TestTLSConfig(t *testing.T) { oid2 := fakeOID{1, 3, 9900, 2} testCases := map[string]struct { - issuer Issuer - validators []Validator - wantErr bool + clientIssuer Issuer + clientValidators []Validator + serverIssuer Issuer + serverValidators []Validator + wantErr bool }{ - "basic": { - issuer: fakeIssuer{fakeOID: oid1}, - validators: []Validator{fakeValidator{fakeOID: oid1}}, + "client->server basic": { + serverIssuer: fakeIssuer{fakeOID: oid1}, + clientValidators: []Validator{fakeValidator{fakeOID: oid1}}, }, - "multiple validators": { - issuer: fakeIssuer{fakeOID: oid2}, - validators: []Validator{fakeValidator{fakeOID: oid1}, fakeValidator{fakeOID: oid2}}, + "client->server multiple validators": { + serverIssuer: fakeIssuer{fakeOID: oid2}, + clientValidators: []Validator{fakeValidator{fakeOID: oid1}, fakeValidator{fakeOID: oid2}}, }, - "validate error": { - issuer: fakeIssuer{fakeOID: oid1}, - validators: []Validator{fakeValidator{fakeOID: oid1, err: errors.New("failed")}}, - wantErr: true, + "client->server validate error": { + serverIssuer: fakeIssuer{fakeOID: oid1}, + clientValidators: []Validator{fakeValidator{fakeOID: oid1, err: errors.New("failed")}}, + wantErr: true, }, - "unknown oid": { - issuer: fakeIssuer{fakeOID: oid1}, - validators: []Validator{fakeValidator{fakeOID: oid2}}, - wantErr: true, + "client->server unknown oid": { + serverIssuer: fakeIssuer{fakeOID: oid1}, + clientValidators: []Validator{fakeValidator{fakeOID: oid2}}, + wantErr: true, + }, + "client->server client cert is not verified": { + serverIssuer: fakeIssuer{fakeOID: oid1}, + clientIssuer: fakeIssuer{fakeOID: oid1}, + clientValidators: []Validator{fakeValidator{fakeOID: oid1}}, + }, + "server->client basic": { + serverIssuer: fakeIssuer{fakeOID: oid1}, + serverValidators: []Validator{fakeValidator{fakeOID: oid1}}, + clientIssuer: fakeIssuer{fakeOID: oid1}, + }, + "server->client multiple validators": { + serverIssuer: fakeIssuer{fakeOID: oid1}, + serverValidators: []Validator{fakeValidator{fakeOID: oid1}, fakeValidator{fakeOID: oid2}}, + clientIssuer: fakeIssuer{fakeOID: oid2}, + }, + "server->client validate error": { + serverIssuer: fakeIssuer{fakeOID: oid1}, + serverValidators: []Validator{fakeValidator{fakeOID: oid1, err: errors.New("failed")}}, + clientIssuer: fakeIssuer{fakeOID: oid1}, + wantErr: true, + }, + "server->client unknown oid": { + serverIssuer: fakeIssuer{fakeOID: oid2}, + serverValidators: []Validator{fakeValidator{fakeOID: oid2}}, + clientIssuer: fakeIssuer{fakeOID: oid1}, + wantErr: true, + }, + "mutual basic": { + serverIssuer: fakeIssuer{fakeOID: oid1}, + serverValidators: []Validator{fakeValidator{fakeOID: oid1}}, + clientIssuer: fakeIssuer{fakeOID: oid1}, + clientValidators: []Validator{fakeValidator{fakeOID: oid1}}, + }, + "mutual multiple validators": { + serverIssuer: fakeIssuer{fakeOID: oid2}, + serverValidators: []Validator{fakeValidator{fakeOID: oid1}, fakeValidator{fakeOID: oid2}}, + clientIssuer: fakeIssuer{fakeOID: oid2}, + clientValidators: []Validator{fakeValidator{fakeOID: oid1}, fakeValidator{fakeOID: oid2}}, + }, + "mutual fails if client sends no cert": { + serverIssuer: fakeIssuer{fakeOID: oid1}, + serverValidators: []Validator{fakeValidator{fakeOID: oid1}}, + clientValidators: []Validator{fakeValidator{fakeOID: oid1}}, + wantErr: true, + }, + "mutual validate error client side": { + serverIssuer: fakeIssuer{fakeOID: oid1}, + serverValidators: []Validator{fakeValidator{fakeOID: oid1}}, + clientIssuer: fakeIssuer{fakeOID: oid1}, + clientValidators: []Validator{fakeValidator{fakeOID: oid1, err: errors.New("failed")}}, + wantErr: true, + }, + "mutual validate error server side": { + serverIssuer: fakeIssuer{fakeOID: oid1}, + serverValidators: []Validator{fakeValidator{fakeOID: oid1, err: errors.New("failed")}}, + clientIssuer: fakeIssuer{fakeOID: oid1}, + clientValidators: []Validator{fakeValidator{fakeOID: oid1}}, + wantErr: true, + }, + "mutual unknown oid from client": { + serverIssuer: fakeIssuer{fakeOID: oid1}, + serverValidators: []Validator{fakeValidator{fakeOID: oid1}}, + clientIssuer: fakeIssuer{fakeOID: oid2}, + clientValidators: []Validator{fakeValidator{fakeOID: oid1}}, + wantErr: true, + }, + "mutual unknown oid from server": { + serverIssuer: fakeIssuer{fakeOID: oid2}, + serverValidators: []Validator{fakeValidator{fakeOID: oid1}}, + clientIssuer: fakeIssuer{fakeOID: oid1}, + clientValidators: []Validator{fakeValidator{fakeOID: oid1}}, + wantErr: true, }, } @@ -53,7 +128,7 @@ func TestTLSConfig(t *testing.T) { // Create server // - serverConfig, err := CreateAttestationServerTLSConfig(tc.issuer) + serverConfig, err := CreateAttestationServerTLSConfig(tc.serverIssuer, tc.serverValidators) require.NoError(err) server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -65,7 +140,7 @@ func TestTLSConfig(t *testing.T) { // Create client // - clientConfig, err := CreateAttestationClientTLSConfig(tc.validators) + clientConfig, err := CreateAttestationClientTLSConfig(tc.clientIssuer, tc.clientValidators) require.NoError(err) client := http.Client{Transport: &http.Transport{TLSClientConfig: clientConfig}} diff --git a/coordinator/cmd/coordinator/run.go b/coordinator/cmd/coordinator/run.go index 7cc24c024..d9d193168 100644 --- a/coordinator/cmd/coordinator/run.go +++ b/coordinator/cmd/coordinator/run.go @@ -36,7 +36,7 @@ func run(issuer core.QuoteIssuer, vpn core.VPN, openTPM vtpm.TPMOpenFunc, getPub defer zapLoggerCore.Sync() zapLoggerCore.Info("starting coordinator", zap.String("version", version)) - tlsConfig, err := atls.CreateAttestationServerTLSConfig(issuer) + tlsConfig, err := atls.CreateAttestationServerTLSConfig(issuer, nil) if err != nil { zapLoggerCore.Fatal("failed to create server TLS config", zap.Error(err)) } @@ -117,7 +117,7 @@ func tryJoinClusterOnStartup(getPublicIPAddr func() (string, error), metadata co // We create an client unverified connection, since the node does not need to verify the Coordinator. // ActivateAdditionalNodes triggers the Coordinator to call ActivateAsNode. This rpc lets the Coordinator verify the node. - tlsClientConfig, err := atls.CreateUnverifiedClientTLSConfig() + tlsClientConfig, err := atls.CreateAttestationClientTLSConfig(nil, nil) if err != nil { return fmt.Errorf("failed to create client TLS config: %w", err) } diff --git a/coordinator/coordinator_test.go b/coordinator/coordinator_test.go index 53a07f1c5..6f7574157 100644 --- a/coordinator/coordinator_test.go +++ b/coordinator/coordinator_test.go @@ -225,7 +225,7 @@ func spawnPeer(require *require.Assertions, logger *zap.Logger, netDialer *testd papi := pubapi.New(logger, cor, dialer, vapiServer, getPublicAddr, nil) - tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}) + tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}, nil) require.NoError(err) server := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) pubproto.RegisterAPIServer(server, papi) @@ -263,7 +263,7 @@ func activateCoordinator(require *require.Assertions, dialer netDialer, coordina } func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) { - tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}}) + tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) if err != nil { return nil, err } diff --git a/coordinator/core/legacy_test.go b/coordinator/core/legacy_test.go index c61b5b079..9c61bddc6 100644 --- a/coordinator/core/legacy_test.go +++ b/coordinator/core/legacy_test.go @@ -173,7 +173,7 @@ func (b *bufconnDialer) addListener(endpoint string, listener *bufconn.Listener) } func spawnNode(endpoint string, testNodeCore *pubapi.API, bufDialer *bufconnDialer) (*grpc.Server, error) { - tlsConfig, err := atls.CreateAttestationServerTLSConfig(&MockIssuer{}) + tlsConfig, err := atls.CreateAttestationServerTLSConfig(&MockIssuer{}, nil) if err != nil { return nil, err } diff --git a/coordinator/core/reinitialize_test.go b/coordinator/core/reinitialize_test.go index b9d9ba0c8..8d2098771 100644 --- a/coordinator/core/reinitialize_test.go +++ b/coordinator/core/reinitialize_test.go @@ -258,7 +258,7 @@ func TestGetInitialVPNPeers(t *testing.T) { } func newPubAPIServer() *grpc.Server { - tlsConfig, err := atls.CreateAttestationServerTLSConfig(&MockIssuer{}) + tlsConfig, err := atls.CreateAttestationServerTLSConfig(&MockIssuer{}, nil) if err != nil { panic(err) } diff --git a/coordinator/oid/oid.go b/coordinator/oid/oid.go index 861d9f319..423aee24c 100644 --- a/coordinator/oid/oid.go +++ b/coordinator/oid/oid.go @@ -4,6 +4,9 @@ import ( "encoding/asn1" ) +// ATLSNonce is the ASN.1 object identifier used to transmit a nonce from server to client. +var ATLSNonce = asn1.ObjectIdentifier{1, 3, 9900, 0, 1} + // Getter returns an ASN.1 Object Identifier. type Getter interface { OID() asn1.ObjectIdentifier diff --git a/coordinator/pubapi/coord_test.go b/coordinator/pubapi/coord_test.go index 3f1b878ac..a491b8f28 100644 --- a/coordinator/pubapi/coord_test.go +++ b/coordinator/pubapi/coord_test.go @@ -431,7 +431,7 @@ func (n *stubPeer) GetPeerVPNPublicKey(ctx context.Context, in *pubproto.GetPeer } func (n *stubPeer) newServer() *grpc.Server { - tlsConfig, err := atls.CreateAttestationServerTLSConfig(fakeIssuer{}) + tlsConfig, err := atls.CreateAttestationServerTLSConfig(fakeIssuer{}, nil) if err != nil { panic(err) } @@ -536,7 +536,7 @@ func TestRequestStateDiskKey(t *testing.T) { require.NoError(err) defer listener.Close() - tlsConfig, err := atls.CreateAttestationServerTLSConfig(issuer) + tlsConfig, err := atls.CreateAttestationServerTLSConfig(issuer, nil) require.NoError(err) s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) keyproto.RegisterAPIServer(s, stateDiskServer) diff --git a/coordinator/pubapi/node_test.go b/coordinator/pubapi/node_test.go index 8ec9a7132..fd9154c3c 100644 --- a/coordinator/pubapi/node_test.go +++ b/coordinator/pubapi/node_test.go @@ -162,7 +162,7 @@ func TestActivateAsNode(t *testing.T) { go vserver.Serve(netDialer.GetListener(net.JoinHostPort("10.118.0.1", vpnAPIPort))) defer vserver.GracefulStop() - tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}) + tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}, nil) require.NoError(err) pubserver := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) pubproto.RegisterAPIServer(pubserver, api) @@ -432,7 +432,7 @@ func activateNode(require *require.Assertions, dialer netDialer, messageSequence } func dialGRPC(ctx context.Context, dialer netDialer, target string) (*grpc.ClientConn, error) { - tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}}) + tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) if err != nil { return nil, err } diff --git a/coordinator/util/grpcutil/dialer.go b/coordinator/util/grpcutil/dialer.go index 76e9d37fa..60bafbbbc 100644 --- a/coordinator/util/grpcutil/dialer.go +++ b/coordinator/util/grpcutil/dialer.go @@ -26,7 +26,7 @@ func NewDialer(validator atls.Validator, netDialer NetDialer) *Dialer { // Dial creates a new grpc client connection to the given target using the atls validator. func (d *Dialer) Dial(ctx context.Context, target string) (*grpc.ClientConn, error) { - tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{d.validator}) + tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{d.validator}) if err != nil { return nil, err } @@ -48,7 +48,7 @@ func (d *Dialer) DialInsecure(ctx context.Context, target string) (*grpc.ClientC // DialNoVerify creates a new grpc client connection to the given target without verifying the server's attestation. func (d *Dialer) DialNoVerify(ctx context.Context, target string) (*grpc.ClientConn, error) { - tlsConfig, err := atls.CreateUnverifiedClientTLSConfig() + tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, nil) if err != nil { return nil, err } diff --git a/coordinator/util/grpcutil/dialer_test.go b/coordinator/util/grpcutil/dialer_test.go index f64f7778a..052ae083f 100644 --- a/coordinator/util/grpcutil/dialer_test.go +++ b/coordinator/util/grpcutil/dialer_test.go @@ -88,7 +88,7 @@ func TestDial(t *testing.T) { func newServer(tls bool) *grpc.Server { if tls { - tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}) + tlsConfig, err := atls.CreateAttestationServerTLSConfig(&core.MockIssuer{}, nil) if err != nil { panic(err) } diff --git a/hack/pcr-reader/main.go b/hack/pcr-reader/main.go index 42e5bcc2c..9418989ec 100644 --- a/hack/pcr-reader/main.go +++ b/hack/pcr-reader/main.go @@ -56,7 +56,7 @@ func main() { } attDocRaw := []byte{} - tlsConfig, err := atls.CreateUnverifiedClientTLSConfig() + tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, nil) if err != nil { log.Fatal(err) } diff --git a/state/keyservice/keyservice.go b/state/keyservice/keyservice.go index b095666f9..0ab8cc18d 100644 --- a/state/keyservice/keyservice.go +++ b/state/keyservice/keyservice.go @@ -63,7 +63,7 @@ func (a *KeyAPI) WaitForDecryptionKey(uuid, listenAddr string) ([]byte, error) { return nil, errors.New("received no disk UUID") } - tlsConfig, err := atls.CreateAttestationServerTLSConfig(a.issuer) + tlsConfig, err := atls.CreateAttestationServerTLSConfig(a.issuer, nil) if err != nil { return nil, err } @@ -95,7 +95,7 @@ func (a *KeyAPI) ResetKey() { func (a *KeyAPI) requestKeyLoop(uuid string, opts ...grpc.DialOption) error { // we do not perform attestation, since the restarting node does not need to care about notifying the correct Coordinator // if an incorrect key is pushed by a malicious actor, decrypting the disk will fail, and the node will not start - tlsClientConfig, err := atls.CreateUnverifiedClientTLSConfig() + tlsClientConfig, err := atls.CreateAttestationClientTLSConfig(nil, nil) if err != nil { return err } diff --git a/state/keyservice/keyservice_test.go b/state/keyservice/keyservice_test.go index 2ede90feb..dcc6a34e0 100644 --- a/state/keyservice/keyservice_test.go +++ b/state/keyservice/keyservice_test.go @@ -75,7 +75,7 @@ func TestRequestKeyLoop(t *testing.T) { listener := bufconn.Listen(1) defer listener.Close() - tlsConfig, err := atls.CreateAttestationServerTLSConfig(core.NewMockIssuer()) + tlsConfig, err := atls.CreateAttestationServerTLSConfig(core.NewMockIssuer(), nil) require.NoError(err) s := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig))) pubproto.RegisterAPIServer(s, tc.server) diff --git a/state/test/integration_test.go b/state/test/integration_test.go index 9061e8ff1..b6cc08059 100644 --- a/state/test/integration_test.go +++ b/state/test/integration_test.go @@ -91,7 +91,7 @@ func TestKeyAPI(t *testing.T) { // wait 2 seconds before sending the key time.Sleep(2 * time.Second) - clientCfg, err := atls.CreateUnverifiedClientTLSConfig() + clientCfg, err := atls.CreateAttestationClientTLSConfig(nil, nil) require.NoError(err) conn, err := grpc.Dial(apiAddr, grpc.WithTransportCredentials(credentials.NewTLS(clientCfg))) require.NoError(err) diff --git a/test/coordinator_integration_test.go b/test/coordinator_integration_test.go index 76dc0e72a..7a3d6d444 100644 --- a/test/coordinator_integration_test.go +++ b/test/coordinator_integration_test.go @@ -249,7 +249,7 @@ func TestMain(t *testing.T) { // helper methods func startCoordinator(ctx context.Context, coordinatorAddr string, ips []string) error { - tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}}) + tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) if err != nil { return err } @@ -299,7 +299,7 @@ func createTempDir() error { } func addNewCoordinatorToCoordinator(ctx context.Context, newCoordinatorAddr, oldCoordinatorAddr string) error { - tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}}) + tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) if err != nil { return err } @@ -322,7 +322,7 @@ func addNewCoordinatorToCoordinator(ctx context.Context, newCoordinatorAddr, old } func addNewNodesToCoordinator(ctx context.Context, coordinatorAddr string, ips []string) error { - tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}}) + tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) if err != nil { return err } @@ -545,7 +545,7 @@ func awaitPeerResponse(ctx context.Context, ip string, tlsConfig *tls.Config) er } func blockUntilUp(ctx context.Context, peerIPs []string) error { - tlsConfig, err := atls.CreateAttestationClientTLSConfig([]atls.Validator{&core.MockValidator{}}) + tlsConfig, err := atls.CreateAttestationClientTLSConfig(nil, []atls.Validator{&core.MockValidator{}}) if err != nil { return err }