Add a caching layer to .well-known responses (#4516)

Richard van der Hoff 2019-01-30 10:55:25 +00:00 committed by GitHub
parent 3f189c902e
commit bc5f6e1797
6 changed files with 493 additions and 10 deletions

changelog.d/4516.feature Normal file

@@ -0,0 +1 @@
Implement MSC1708 (.well-known routing for server-server federation)
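For orientation before the diff: the change makes the federation agent consult a cache before issuing the .well-known HTTP request, cache failures for a short period, and honour the server's Cache-Control/Expires headers (capped at 48 hours, with jitter added to the 24-hour default so cached entries don't all expire at once). Below is a rough, synchronous sketch of that flow; the plain-dict cache and the fetch() helper are illustrative stand-ins, not the commit's code, which is Twisted-based and uses the new TTLCache.

# Illustrative sketch only; the real logic lives in MatrixFederationAgent._get_well_known.
import json
import random
import time

WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60
WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600

_cache = {}  # server_name -> (result, expiry_time)


def get_well_known(server_name, fetch):
    """fetch(server_name) is assumed to return (status, body, cache_period_or_None)."""
    now = time.time()
    hit = _cache.get(server_name)
    if hit is not None and hit[1] > now:
        return hit[0]  # cache hit: no HTTP request at all

    try:
        status, body, cache_period = fetch(server_name)
        if status != 200:
            raise ValueError("non-200 response")
        result = json.loads(body)["m.server"].encode("ascii")
    except Exception:
        # cache the failure too, so an unreachable server isn't hammered
        _cache[server_name] = (None, now + WELL_KNOWN_INVALID_CACHE_PERIOD)
        return None

    if cache_period is None:
        # no Cache-Control/Expires: use the default, plus jitter to spread refetches
        cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
        cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
    else:
        cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)

    if cache_period > 0:
        _cache[server_name] = (result, now + cache_period)
    return result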

synapse/http/federation/matrix_federation_agent.py

@@ -14,6 +14,8 @@
 # limitations under the License.
 import json
 import logging
+import random
+import time

 import attr
 from netaddr import IPAddress
@@ -22,13 +24,29 @@ from zope.interface import implementer
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.web.client import URI, Agent, HTTPConnectionPool, readBody
+from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent

 from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.util.caches.ttlcache import TTLCache
 from synapse.util.logcontext import make_deferred_yieldable

+# period to cache .well-known results for by default
+WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
+
+# jitter to add to the .well-known default cache ttl
+WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60
+
+# period to cache failure to fetch .well-known for
+WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
+
+# cap for .well-known cache period
+WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
+
 logger = logging.getLogger(__name__)

+well_known_cache = TTLCache('well-known')
+

 @implementer(IAgent)
@@ -57,6 +75,7 @@ class MatrixFederationAgent(object):
         self, reactor, tls_client_options_factory,
         _well_known_tls_policy=None,
         _srv_resolver=None,
+        _well_known_cache=well_known_cache,
     ):
         self._reactor = reactor
         self._tls_client_options_factory = tls_client_options_factory
@@ -77,6 +96,8 @@
         _well_known_agent = Agent(self._reactor, pool=self._pool, **agent_args)
         self._well_known_agent = _well_known_agent

+        self._well_known_cache = _well_known_cache
+
     @defer.inlineCallbacks
     def request(self, method, uri, headers=None, bodyProducer=None):
         """
@@ -259,7 +280,14 @@
             Deferred[bytes|None]: either the new server name, from the .well-known, or
                 None if there was no .well-known file.
         """
-        # FIXME: add a cache
+        try:
+            cached = self._well_known_cache[server_name]
+            defer.returnValue(cached)
+        except KeyError:
+            pass
+
+        # TODO: should we linearise so that we don't end up doing two .well-known requests
+        # for the same server in parallel?

         uri = b"https://%s/.well-known/matrix/server" % (server_name, )
         uri_str = uri.decode("ascii")
@@ -270,12 +298,14 @@
             )
         except Exception as e:
             logger.info("Connection error fetching %s: %s", uri_str, e)
+            self._well_known_cache.set(server_name, None, WELL_KNOWN_INVALID_CACHE_PERIOD)
             defer.returnValue(None)

         body = yield make_deferred_yieldable(readBody(response))

         if response.code != 200:
             logger.info("Error response %i from %s", response.code, uri_str)
+            self._well_known_cache.set(server_name, None, WELL_KNOWN_INVALID_CACHE_PERIOD)
             defer.returnValue(None)

         try:
@@ -287,7 +317,63 @@
                 raise Exception("Missing key 'm.server'")
         except Exception as e:
             raise Exception("invalid .well-known response from %s: %s" % (uri_str, e,))
-        defer.returnValue(parsed_body["m.server"].encode("ascii"))
+
+        result = parsed_body["m.server"].encode("ascii")
+
+        cache_period = _cache_period_from_headers(
+            response.headers,
+            time_now=self._reactor.seconds,
+        )
+        if cache_period is None:
+            cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
+            # add some randomness to the TTL to avoid a stampeding herd every hour after
+            # startup
+            cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
+        else:
+            cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
+
+        if cache_period > 0:
+            self._well_known_cache.set(server_name, result, cache_period)
+
+        defer.returnValue(result)
+
+
+def _cache_period_from_headers(headers, time_now=time.time):
+    cache_controls = _parse_cache_control(headers)
+
+    if b'no-store' in cache_controls:
+        return 0
+
+    if b'max-age' in cache_controls:
+        try:
+            max_age = int(cache_controls[b'max-age'])
+            return max_age
+        except ValueError:
+            pass
+
+    expires = headers.getRawHeaders(b'expires')
+    if expires is not None:
+        try:
+            expires_date = stringToDatetime(expires[-1])
+            return expires_date - time_now()
+        except ValueError:
+            # RFC7234 says 'A cache recipient MUST interpret invalid date formats,
+            # especially the value "0", as representing a time in the past (i.e.,
+            # "already expired").
+            return 0

+    return None
+
+
+def _parse_cache_control(headers):
+    cache_controls = {}
+    for hdr in headers.getRawHeaders(b'cache-control', []):
+        for directive in hdr.split(b','):
+            splits = [x.strip() for x in directive.split(b'=', 1)]
+            k = splits[0].lower()
+            v = splits[1] if len(splits) > 1 else None
+            cache_controls[k] = v
+    return cache_controls
+
+
 @attr.s
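For a quick sense of what the two header helpers above return, here is a small illustration (not part of the commit; it assumes a synapse checkout so the private helpers are importable, the same way the tests further down import them):

from twisted.web.http_headers import Headers

from synapse.http.federation.matrix_federation_agent import (
    _cache_period_from_headers,
    _parse_cache_control,
)

# max-age wins when present
print(_cache_period_from_headers(Headers({b'Cache-Control': [b'public, max-age=3600']})))  # 3600

# no-store means "do not cache at all"
print(_cache_period_from_headers(Headers({b'Cache-Control': [b'no-store']})))  # 0

# no caching headers at all: None, so the caller falls back to the 24h default (plus jitter)
print(_cache_period_from_headers(Headers({})))  # None

# the low-level parser just maps directive -> value (None when there is no '=value')
print(_parse_cache_control(Headers({b'Cache-Control': [b'private, max-age=0']})))
# {b'private': None, b'max-age': b'0'}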

synapse/util/caches/ttlcache.py Normal file

@@ -0,0 +1,161 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import time

import attr
from sortedcontainers import SortedList

from synapse.util.caches import register_cache

logger = logging.getLogger(__name__)


SENTINEL = object()


class TTLCache(object):
    """A key/value cache implementation where each entry has its own TTL"""

    def __init__(self, cache_name, timer=time.time):
        # map from key to _CacheEntry
        self._data = {}

        # the _CacheEntries, sorted by expiry time
        self._expiry_list = SortedList()

        self._timer = timer

        self._metrics = register_cache("ttl", cache_name, self)

    def set(self, key, value, ttl):
        """Add/update an entry in the cache

        Args:
            key: key for this entry
            value: value for this entry
            ttl (float): TTL for this entry, in seconds
        """
        expiry = self._timer() + ttl

        self.expire()
        e = self._data.pop(key, SENTINEL)
        if e != SENTINEL:
            self._expiry_list.remove(e)

        entry = _CacheEntry(expiry_time=expiry, key=key, value=value)
        self._data[key] = entry
        self._expiry_list.add(entry)

    def get(self, key, default=SENTINEL):
        """Get a value from the cache

        Args:
            key: key to look up
            default: default value to return, if key is not found. If not set, and the
                key is not found, a KeyError will be raised

        Returns:
            value from the cache, or the default
        """
        self.expire()
        e = self._data.get(key, SENTINEL)
        if e == SENTINEL:
            self._metrics.inc_misses()
            if default == SENTINEL:
                raise KeyError(key)
            return default
        self._metrics.inc_hits()
        return e.value

    def get_with_expiry(self, key):
        """Get a value, and its expiry time, from the cache

        Args:
            key: key to look up

        Returns:
            Tuple[Any, float]: the value from the cache, and the expiry time

        Raises:
            KeyError if the entry is not found
        """
        self.expire()
        try:
            e = self._data[key]
        except KeyError:
            self._metrics.inc_misses()
            raise
        self._metrics.inc_hits()
        return e.value, e.expiry_time

    def pop(self, key, default=SENTINEL):
        """Remove a value from the cache

        If key is in the cache, remove it and return its value, else return default.
        If default is not given and key is not in the cache, a KeyError is raised.

        Args:
            key: key to look up
            default: default value to return, if key is not found. If not set, and the
                key is not found, a KeyError will be raised

        Returns:
            value from the cache, or the default
        """
        self.expire()
        e = self._data.pop(key, SENTINEL)
        if e == SENTINEL:
            self._metrics.inc_misses()
            if default == SENTINEL:
                raise KeyError(key)
            return default
        self._expiry_list.remove(e)
        self._metrics.inc_hits()
        return e.value

    def __getitem__(self, key):
        return self.get(key)

    def __delitem__(self, key):
        self.pop(key)

    def __contains__(self, key):
        return key in self._data

    def __len__(self):
        self.expire()
        return len(self._data)

    def expire(self):
        """Run the expiry on the cache. Any entries whose expiry times are due will
        be removed
        """
        now = self._timer()
        while self._expiry_list:
            first_entry = self._expiry_list[0]
            if first_entry.expiry_time - now > 0.0:
                break
            del self._data[first_entry.key]
            del self._expiry_list[0]


@attr.s(frozen=True, slots=True)
class _CacheEntry(object):
    """TTLCache entry"""
    # expiry_time is the first attribute, so that entries are sorted by expiry.
    expiry_time = attr.ib()
    key = attr.ib()
    value = attr.ib()
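As a usage note (not part of the commit): the cache takes an arbitrary timer so expiry can be driven from a fake clock, which is how the tests further down exercise it. A minimal sketch, assuming synapse is importable:

from synapse.util.caches.ttlcache import TTLCache

now = [100.0]
cache = TTLCache("example", timer=lambda: now[0])

cache.set(b"example.com", b"matrix.example.com", ttl=10)
print(cache[b"example.com"])                  # b'matrix.example.com'
print(cache.get_with_expiry(b"example.com"))  # (b'matrix.example.com', 110.0)

now[0] = 111.0                  # advance the fake clock past the entry's expiry time
print(b"example.com" in cache)  # still True: __contains__ does not trigger expiry
cache.expire()                  # expiry runs on get/set/len, or explicitly like this
print(b"example.com" in cache)  # False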

tests/http/federation/test_matrix_federation_agent.py

@@ -24,11 +24,16 @@ from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOptions
 from twisted.internet.protocol import Factory
 from twisted.protocols.tls import TLSMemoryBIOFactory
 from twisted.web.http import HTTPChannel
+from twisted.web.http_headers import Headers
 from twisted.web.iweb import IPolicyForHTTPS

 from synapse.crypto.context_factory import ClientTLSOptionsFactory
-from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.http.federation.matrix_federation_agent import (
+    MatrixFederationAgent,
+    _cache_period_from_headers,
+)
 from synapse.http.federation.srv_resolver import Server
+from synapse.util.caches.ttlcache import TTLCache
 from synapse.util.logcontext import LoggingContext

 from tests.http import ServerTLSContext
@@ -44,11 +49,14 @@ class MatrixFederationAgentTests(TestCase):
         self.mock_resolver = Mock()

+        self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds)
+
         self.agent = MatrixFederationAgent(
             reactor=self.reactor,
             tls_client_options_factory=ClientTLSOptionsFactory(None),
             _well_known_tls_policy=TrustingTLSPolicyForHTTPS(),
             _srv_resolver=self.mock_resolver,
+            _well_known_cache=self.well_known_cache,
         )

     def _make_connection(self, client_factory, expected_sni):
@@ -115,7 +123,9 @@
         finally:
             _check_logcontext(context)

-    def _handle_well_known_connection(self, client_factory, expected_sni, target_server):
+    def _handle_well_known_connection(
+        self, client_factory, expected_sni, target_server, response_headers={},
+    ):
         """Handle an outgoing HTTPs connection: wire it up to a server, check that the
         request is for a .well-known, and send the response.
@@ -124,6 +134,8 @@
             expected_sni (bytes): SNI that we expect the outgoing connection to send
             target_server (bytes): target server that we should redirect to in the
                 .well-known response.
+
+        Returns:
+            HTTPChannel: server impl
         """
         # make the connection for .well-known
         well_known_server = self._make_connection(
@@ -133,9 +145,10 @@
         # check the .well-known request and send a response
         self.assertEqual(len(well_known_server.requests), 1)
         request = well_known_server.requests[0]
-        self._send_well_known_response(request, target_server)
+        self._send_well_known_response(request, target_server, headers=response_headers)

-    def _send_well_known_response(self, request, target_server):
+        return well_known_server
+
+    def _send_well_known_response(self, request, target_server, headers={}):
         """Check that an incoming request looks like a valid .well-known request, and
         send back the response.
         """
@@ -146,6 +159,8 @@
             [b'testserv'],
         )
         # send back a response
+        for k, v in headers.items():
+            request.setHeader(k, v)
         request.write(b'{ "m.server": "%s" }' % (target_server,))
         request.finish()
@@ -448,6 +463,13 @@
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)

+        self.assertEqual(self.well_known_cache[b"testserv"], b"target-server")
+
+        # check the cache expires
+        self.reactor.pump((25 * 3600,))
+        self.well_known_cache.expire()
+        self.assertNotIn(b"testserv", self.well_known_cache)
+
     def test_get_hostname_srv(self):
         """
         Test the behaviour when there is a single SRV record
@@ -661,6 +683,126 @@
         self.reactor.pump((0.1,))
         self.successResultOf(test_d)

+    @defer.inlineCallbacks
+    def do_get_well_known(self, serv):
+        try:
+            result = yield self.agent._get_well_known(serv)
+            logger.info("Result from well-known fetch: %s", result)
+        except Exception as e:
+            logger.warning("Error fetching well-known: %s", e)
+            raise
+        defer.returnValue(result)
+
+    def test_well_known_cache(self):
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+        fetch_d = self.do_get_well_known(b'testserv')
+
+        # there should be an attempt to connect on port 443 for the .well-known
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(port, 443)
+
+        well_known_server = self._handle_well_known_connection(
+            client_factory,
+            expected_sni=b"testserv",
+            response_headers={b'Cache-Control': b'max-age=10'},
+            target_server=b"target-server",
+        )
+
+        r = self.successResultOf(fetch_d)
+        self.assertEqual(r, b'target-server')
+
+        # close the tcp connection
+        well_known_server.loseConnection()
+
+        # repeat the request: it should hit the cache
+        fetch_d = self.do_get_well_known(b'testserv')
+        r = self.successResultOf(fetch_d)
+        self.assertEqual(r, b'target-server')
+
+        # expire the cache
+        self.reactor.pump((10.0,))
+
+        # now it should connect again
+        fetch_d = self.do_get_well_known(b'testserv')
+
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+        self.assertEqual(host, '1.2.3.4')
+        self.assertEqual(port, 443)
+
+        self._handle_well_known_connection(
+            client_factory,
+            expected_sni=b"testserv",
+            target_server=b"other-server",
+        )
+
+        r = self.successResultOf(fetch_d)
+        self.assertEqual(r, b'other-server')
+
+
+class TestCachePeriodFromHeaders(TestCase):
+    def test_cache_control(self):
+        # uppercase
+        self.assertEqual(
+            _cache_period_from_headers(
+                Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}),
+            ), 100,
+        )
+
+        # missing value
+        self.assertIsNone(_cache_period_from_headers(
+            Headers({b'Cache-Control': [b'max-age=, bar']}),
+        ))
+
+        # hackernews: bogus due to semicolon
+        self.assertIsNone(_cache_period_from_headers(
+            Headers({b'Cache-Control': [b'private; max-age=0']}),
+        ))
+
+        # github
+        self.assertEqual(
+            _cache_period_from_headers(
+                Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}),
+            ), 0,
+        )
+
+        # google
+        self.assertEqual(
+            _cache_period_from_headers(
+                Headers({b'cache-control': [b'private, max-age=0']}),
+            ), 0,
+        )
+
+    def test_expires(self):
+        self.assertEqual(
+            _cache_period_from_headers(
+                Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}),
+                time_now=lambda: 1548833700
+            ), 33,
+        )
+
+        # cache-control overrides expires
+        self.assertEqual(
+            _cache_period_from_headers(
+                Headers({
+                    b'cache-control': [b'max-age=10'],
+                    b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']
+                }),
+                time_now=lambda: 1548833700
+            ), 10,
+        )
+
+        # invalid expires means immediate expiry
+        self.assertEqual(
+            _cache_period_from_headers(
+                Headers({b'Expires': [b'0']}),
+            ), 0,
+        )
+
+
 def _check_logcontext(context):
     current = LoggingContext.current_context()

tests/server.py

@@ -360,6 +360,7 @@
     """

    disconnecting = False
+    disconnected = False
     buffer = attr.ib(default=b'')
     producer = attr.ib(default=None)
@@ -370,14 +371,16 @@
         return None

     def loseConnection(self, reason=None):
-        logger.info("FakeTransport: loseConnection(%s)", reason)
         if not self.disconnecting:
+            logger.info("FakeTransport: loseConnection(%s)", reason)
             self.disconnecting = True
             if self._protocol:
                 self._protocol.connectionLost(reason)
+            self.disconnected = True

     def abortConnection(self):
-        self.disconnecting = True
+        logger.info("FakeTransport: abortConnection()")
+        self.loseConnection()

     def pauseProducing(self):
         if not self.producer:
@@ -416,9 +419,16 @@
                 # TLSMemoryBIOProtocol
                 return

+            if self.disconnected:
+                return
+            logger.info("%s->%s: %s", self._protocol, self.other, self.buffer)
+
             if getattr(self.other, "transport") is not None:
-                self.other.dataReceived(self.buffer)
-                self.buffer = b""
+                try:
+                    self.other.dataReceived(self.buffer)
+                    self.buffer = b""
+                except Exception as e:
+                    logger.warning("Exception writing to protocol: %s", e)
                 return

             self._reactor.callLater(0.0, _write)

tests/util/caches/test_ttlcache.py Normal file

@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from mock import Mock

from synapse.util.caches.ttlcache import TTLCache

from tests import unittest


class CacheTestCase(unittest.TestCase):
    def setUp(self):
        self.mock_timer = Mock(side_effect=lambda: 100.0)
        self.cache = TTLCache("test_cache", self.mock_timer)

    def test_get(self):
        """simple set/get tests"""
        self.cache.set('one', '1', 10)
        self.cache.set('two', '2', 20)
        self.cache.set('three', '3', 30)

        self.assertEqual(len(self.cache), 3)

        self.assertTrue('one' in self.cache)
        self.assertEqual(self.cache.get('one'), '1')
        self.assertEqual(self.cache['one'], '1')
        self.assertEqual(self.cache.get_with_expiry('one'), ('1', 110))
        self.assertEqual(self.cache._metrics.hits, 3)
        self.assertEqual(self.cache._metrics.misses, 0)

        self.cache.set('two', '2.5', 20)
        self.assertEqual(self.cache['two'], '2.5')
        self.assertEqual(self.cache._metrics.hits, 4)

        # non-existent-item tests
        self.assertEqual(self.cache.get('four', '4'), '4')
        self.assertIs(self.cache.get('four', None), None)

        with self.assertRaises(KeyError):
            self.cache['four']

        with self.assertRaises(KeyError):
            self.cache.get('four')

        with self.assertRaises(KeyError):
            self.cache.get_with_expiry('four')

        self.assertEqual(self.cache._metrics.hits, 4)
        self.assertEqual(self.cache._metrics.misses, 5)

    def test_expiry(self):
        self.cache.set('one', '1', 10)
        self.cache.set('two', '2', 20)
        self.cache.set('three', '3', 30)

        self.assertEqual(len(self.cache), 3)
        self.assertEqual(self.cache['one'], '1')
        self.assertEqual(self.cache['two'], '2')

        # enough for the first entry to expire, but not the rest
        self.mock_timer.side_effect = lambda: 110.0

        self.assertEqual(len(self.cache), 2)
        self.assertFalse('one' in self.cache)
        self.assertEqual(self.cache['two'], '2')
        self.assertEqual(self.cache['three'], '3')

        self.assertEqual(self.cache.get_with_expiry('two'), ('2', 120))

        self.assertEqual(self.cache._metrics.hits, 5)
        self.assertEqual(self.cache._metrics.misses, 0)