diff --git a/CHANGES.rst b/CHANGES.rst index 2a46af52a..6659c6671 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,47 @@ +Changes in synapse v0.20.0-rc1 (2017-03-30) +=========================================== + +Features: + +* Add delete_devices API (PR #1993) +* Add phone number registration/login support (PR #1994, #2055) + + +Changes: + +* Use JSONSchema for validation of filters. Thanks @pik! (PR #1783) +* Reread log config on SIGHUP (PR #1982) +* Speed up public room list (PR #1989) +* Add helpful texts to logger config options (PR #1990) +* Minor ``/sync`` performance improvements. (PR #2002, #2013, #2022) +* Add some debug to help diagnose weird federation issue (PR #2035) +* Correctly limit retries for all federation requests (PR #2050, #2061) +* Don't lock table when persisting new one time keys (PR #2053) +* Reduce some CPU work on DB threads (PR #2054) +* Cache hosts in room (PR #2060) +* Batch sending of device list pokes (PR #2063) +* Speed up persist event path in certain edge cases (PR #2070) + + +Bug fixes: + +* Fix bug where current_state_events renamed to current_state_ids (PR #1849) +* Fix routing loop when fetching remote media (PR #1992) +* Fix current_state_events table to not lie (PR #1996) +* Fix CAS login to handle PartialDownloadError (PR #1997) +* Fix assertion to stop transaction queue getting wedged (PR #2010) +* Fix presence to fallback to last_active_ts if it beats the last sync time. + Thanks @Half-Shot! (PR #2014) +* Fix bug when federation received a PDU while a room join is in progress (PR + #2016) +* Fix resetting state on rejected events (PR #2025) +* Fix installation issues in readme. Thanks @ricco386 (PR #2037) +* Fix caching of remote servers' signature keys (PR #2042) +* Fix some leaking log context (PR #2048, #2049, #2057, #2058) +* Fix rejection of invites not reaching sync (PR #2056) + + + Changes in synapse v0.19.3 (2017-03-20) ======================================= diff --git a/README.rst b/README.rst index b9c854ad4..759197f5f 100644 --- a/README.rst +++ b/README.rst @@ -108,10 +108,10 @@ Installing prerequisites on ArchLinux:: sudo pacman -S base-devel python2 python-pip \ python-setuptools python-virtualenv sqlite3 -Installing prerequisites on CentOS 7:: +Installing prerequisites on CentOS 7 or Fedora 25:: sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ - lcms2-devel libwebp-devel tcl-devel tk-devel \ + lcms2-devel libwebp-devel tcl-devel tk-devel redhat-rpm-config \ python-virtualenv libffi-devel openssl-devel sudo yum groupinstall "Development Tools" diff --git a/docs/log_contexts.rst b/docs/log_contexts.rst index 8d04a973d..eb1784e70 100644 --- a/docs/log_contexts.rst +++ b/docs/log_contexts.rst @@ -204,9 +204,14 @@ That doesn't follow the rules, but we can fix it by wrapping it with This technique works equally for external functions which return deferreds, or deferreds we have made ourselves. -XXX: think this is what ``preserve_context_over_deferred`` is supposed to do, -though it is broken, in that it only restores the logcontext for the duration -of the callbacks, which doesn't comply with the logcontext rules. +You can also use ``logcontext.make_deferred_yieldable``, which just does the +boilerplate for you, so the above could be written: + +.. code:: python + + def sleep(seconds): + return logcontext.make_deferred_yieldable(get_sleep_deferred(seconds)) + Fire-and-forget --------------- diff --git a/docs/postgres.rst b/docs/postgres.rst index 402ff9a4d..b592801e9 100644 --- a/docs/postgres.rst +++ b/docs/postgres.rst @@ -112,9 +112,9 @@ script one last time, e.g. if the SQLite database is at ``homeserver.db`` run:: synapse_port_db --sqlite-database homeserver.db \ - --postgres-config database_config.yaml + --postgres-config homeserver-postgres.yaml Once that has completed, change the synapse config to point at the PostgreSQL -database configuration file using the ``database_config`` parameter (see -`Synapse Config`_) and restart synapse. Synapse should now be running against +database configuration file ``homeserver-postgres.yaml`` (i.e. rename it to +``homeserver.yaml``) and restart synapse. Synapse should now be running against PostgreSQL. diff --git a/docs/turn-howto.rst b/docs/turn-howto.rst index 04c010071..e48628ce6 100644 --- a/docs/turn-howto.rst +++ b/docs/turn-howto.rst @@ -50,14 +50,37 @@ You may be able to setup coturn via your package manager, or set it up manually pwgen -s 64 1 - 5. Ensure youe firewall allows traffic into the TURN server on - the ports you've configured it to listen on (remember to allow - both TCP and UDP if you've enabled both). + 5. Consider your security settings. TURN lets users request a relay + which will connect to arbitrary IP addresses and ports. At the least + we recommend: - 6. If you've configured coturn to support TLS/DTLS, generate or + # VoIP traffic is all UDP. There is no reason to let users connect to arbitrary TCP endpoints via the relay. + no-tcp-relay + + # don't let the relay ever try to connect to private IP address ranges within your network (if any) + # given the turn server is likely behind your firewall, remember to include any privileged public IPs too. + denied-peer-ip=10.0.0.0-10.255.255.255 + denied-peer-ip=192.168.0.0-192.168.255.255 + denied-peer-ip=172.16.0.0-172.31.255.255 + + # special case the turn server itself so that client->TURN->TURN->client flows work + allowed-peer-ip=10.0.0.1 + + # consider whether you want to limit the quota of relayed streams per user (or total) to avoid risk of DoS. + user-quota=12 # 4 streams per video call, so 12 streams = 3 simultaneous relayed calls per user. + total-quota=1200 + + Ideally coturn should refuse to relay traffic which isn't SRTP; + see https://github.com/matrix-org/synapse/issues/2009 + + 6. Ensure your firewall allows traffic into the TURN server on + the ports you've configured it to listen on (remember to allow + both TCP and UDP TURN traffic) + + 7. If you've configured coturn to support TLS/DTLS, generate or import your private key and certificate. - 7. Start the turn server:: + 8. Start the turn server:: bin/turnserver -o @@ -83,12 +106,19 @@ Your home server configuration file needs the following extra keys: to refresh credentials. The TURN REST API specification recommends one day (86400000). + 4. "turn_allow_guests": Whether to allow guest users to use the TURN + server. This is enabled by default, as otherwise VoIP will not + work reliably for guests. However, it does introduce a security risk + as it lets guests connect to arbitrary endpoints without having gone + through a CAPTCHA or similar to register a real account. + As an example, here is the relevant section of the config file for matrix.org:: turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ] turn_shared_secret: n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons turn_user_lifetime: 86400000 + turn_allow_guests: True Now, restart synapse:: diff --git a/scripts-dev/nuke-room-from-db.sh b/scripts-dev/nuke-room-from-db.sh index 58c036c89..1201d176c 100755 --- a/scripts-dev/nuke-room-from-db.sh +++ b/scripts-dev/nuke-room-from-db.sh @@ -9,16 +9,39 @@ ROOMID="$1" sqlite3 homeserver.db < regex for push. See _glob_matches +regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) +register_cache("regex_push_cache", regex_cache) + + def _glob_matches(glob, value, word_boundary=False): """Tests if value matches glob. @@ -137,46 +143,63 @@ def _glob_matches(glob, value, word_boundary=False): Returns: bool """ + try: - if IS_GLOB.search(glob): - r = re.escape(glob) - - r = r.replace(r'\*', '.*?') - r = r.replace(r'\?', '.') - - # handle [abc], [a-z] and [!a-z] style ranges. - r = GLOB_REGEX.sub( - lambda x: ( - '[%s%s]' % ( - x.group(1) and '^' or '', - x.group(2).replace(r'\\\-', '-') - ) - ), - r, - ) - if word_boundary: - r = r"\b%s\b" % (r,) - r = _compile_regex(r) - - return r.search(value) - else: - r = r + "$" - r = _compile_regex(r) - - return r.match(value) - elif word_boundary: - r = re.escape(glob) - r = r"\b%s\b" % (r,) - r = _compile_regex(r) - - return r.search(value) - else: - return value.lower() == glob.lower() + r = regex_cache.get((glob, word_boundary), None) + if not r: + r = _glob_to_re(glob, word_boundary) + regex_cache[(glob, word_boundary)] = r + return r.search(value) except re.error: logger.warn("Failed to parse glob to regex: %r", glob) return False +def _glob_to_re(glob, word_boundary): + """Generates regex for a given glob. + + Args: + glob (string) + word_boundary (bool): Whether to match against word boundaries or entire + string. Defaults to False. + + Returns: + regex object + """ + if IS_GLOB.search(glob): + r = re.escape(glob) + + r = r.replace(r'\*', '.*?') + r = r.replace(r'\?', '.') + + # handle [abc], [a-z] and [!a-z] style ranges. + r = GLOB_REGEX.sub( + lambda x: ( + '[%s%s]' % ( + x.group(1) and '^' or '', + x.group(2).replace(r'\\\-', '-') + ) + ), + r, + ) + if word_boundary: + r = r"\b%s\b" % (r,) + + return re.compile(r, flags=re.IGNORECASE) + else: + r = "^" + r + "$" + + return re.compile(r, flags=re.IGNORECASE) + elif word_boundary: + r = re.escape(glob) + r = r"\b%s\b" % (r,) + + return re.compile(r, flags=re.IGNORECASE) + else: + r = "^" + re.escape(glob) + "$" + return re.compile(r, flags=re.IGNORECASE) + + def _flatten_dict(d, prefix=[], result={}): for key, value in d.items(): if isinstance(value, basestring): @@ -185,16 +208,3 @@ def _flatten_dict(d, prefix=[], result={}): _flatten_dict(value, prefix=(prefix + [key]), result=result) return result - - -regex_cache = LruCache(5000) - - -def _compile_regex(regex_str): - r = regex_cache.get(regex_str, None) - if r: - return r - - r = re.compile(regex_str, flags=re.IGNORECASE) - regex_cache[regex_str] = r - return r diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 287df94b4..6835f54e9 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -17,15 +17,12 @@ from twisted.internet import defer from synapse.push.presentable_names import ( calculate_room_name, name_from_member_event ) -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred @defer.inlineCallbacks def get_badge_count(store, user_id): - invites, joins = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(store.get_invited_rooms_for_user)(user_id), - preserve_fn(store.get_rooms_for_user)(user_id), - ], consumeErrors=True)) + invites = yield store.get_invited_rooms_for_user(user_id) + joins = yield store.get_rooms_for_user(user_id) my_receipts_by_room = yield store.get_receipts_for_user( user_id, "m.read", diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 03141c623..c43b30b73 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,10 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req( + request, + self.hs.config.turn_allow_guests + ) turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index 31f94bc6e..6fceb23e2 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -36,7 +36,7 @@ class ThirdPartyProtocolsServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols() defer.returnValue((200, protocols)) @@ -54,7 +54,7 @@ class ThirdPartyProtocolServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) protocols = yield self.appservice_handler.get_3pe_protocols( only_protocol=protocol, @@ -77,7 +77,7 @@ class ThirdPartyUserServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop("access_token", None) @@ -101,7 +101,7 @@ class ThirdPartyLocationServlet(RestServlet): @defer.inlineCallbacks def on_GET(self, request, protocol): - yield self.auth.get_user_by_req(request) + yield self.auth.get_user_by_req(request, allow_guest=True) fields = request.args fields.pop("access_token", None) diff --git a/synapse/types.py b/synapse/types.py index 9666f9d73..c87ed813b 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -216,9 +216,7 @@ class StreamToken( return self def copy_and_replace(self, key, new_value): - d = self._asdict() - d[key] = new_value - return StreamToken(**d) + return self._replace(**{key: new_value}) StreamToken.START = StreamToken( diff --git a/synapse/util/async.py b/synapse/util/async.py index 35380bf8e..1453faf0e 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -89,6 +89,11 @@ class ObservableDeferred(object): deferred.addCallbacks(callback, errback) def observe(self): + """Observe the underlying deferred. + + Can return either a deferred if the underlying deferred is still pending + (or has failed), or the actual value. Callers may need to use maybeDeferred. + """ if not self._result: d = defer.Deferred() @@ -101,7 +106,7 @@ class ObservableDeferred(object): return d else: success, res = self._result - return defer.succeed(res) if success else defer.fail(res) + return res if success else defer.fail(res) def observers(self): return self._observers diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 19595df42..9d0d0be1f 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -15,12 +15,9 @@ import logging from synapse.util.async import ObservableDeferred -from synapse.util import unwrapFirstError +from synapse.util import unwrapFirstError, logcontext from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry -from synapse.util.logcontext import ( - PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn -) from . import DEBUG_CACHES, register_cache @@ -227,8 +224,20 @@ class _CacheDescriptorBase(object): ) self.num_args = num_args + + # list of the names of the args used as the cache key self.arg_names = all_args[1:num_args + 1] + # self.arg_defaults is a map of arg name to its default value for each + # argument that has a default value + if arg_spec.defaults: + self.arg_defaults = dict(zip( + all_args[-len(arg_spec.defaults):], + arg_spec.defaults + )) + else: + self.arg_defaults = {} + if "cache_context" in self.arg_names: raise Exception( "cache_context arg cannot be included among the cache keys" @@ -292,18 +301,31 @@ class CacheDescriptor(_CacheDescriptorBase): iterable=self.iterable, ) + def get_cache_key(args, kwargs): + """Given some args/kwargs return a generator that resolves into + the cache_key. + + We loop through each arg name, looking up if its in the `kwargs`, + otherwise using the next argument in `args`. If there are no more + args then we try looking the arg name up in the defaults + """ + pos = 0 + for nm in self.arg_names: + if nm in kwargs: + yield kwargs[nm] + elif pos < len(args): + yield args[pos] + pos += 1 + else: + yield self.arg_defaults[nm] + @functools.wraps(self.orig) def wrapped(*args, **kwargs): # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated invalidate_callback = kwargs.pop("on_invalidate", None) - # Add temp cache_context so inspect.getcallargs doesn't explode - if self.add_cache_context: - kwargs["cache_context"] = None - - arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs) - cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names) + cache_key = tuple(get_cache_key(args, kwargs)) # Add our own `cache_context` to argument list if the wrapped function # has asked for one @@ -328,11 +350,9 @@ class CacheDescriptor(_CacheDescriptorBase): defer.returnValue(cached_result) observer.addCallback(check_result) - return preserve_context_over_deferred(observer) except KeyError: ret = defer.maybeDeferred( - preserve_context_over_fn, - self.function_to_call, + logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs ) @@ -342,10 +362,14 @@ class CacheDescriptor(_CacheDescriptorBase): ret.addErrback(onErr) - ret = ObservableDeferred(ret, consumeErrors=True) - cache.set(cache_key, ret, callback=invalidate_callback) + result_d = ObservableDeferred(ret, consumeErrors=True) + cache.set(cache_key, result_d, callback=invalidate_callback) + observer = result_d.observe() - return preserve_context_over_deferred(ret.observe()) + if isinstance(observer, defer.Deferred): + return logcontext.make_deferred_yieldable(observer) + else: + return observer wrapped.invalidate = cache.invalidate wrapped.invalidate_all = cache.invalidate_all @@ -362,7 +386,11 @@ class CacheListDescriptor(_CacheDescriptorBase): """Wraps an existing cache to support bulk fetching of keys. Given a list of keys it looks in the cache to find any hits, then passes - the list of missing keys to the wrapped fucntion. + the list of missing keys to the wrapped function. + + Once wrapped, the function returns either a Deferred which resolves to + the list of results, or (if all results were cached), just the list of + results. """ def __init__(self, orig, cached_method_name, list_name, num_args=None, @@ -433,8 +461,7 @@ class CacheListDescriptor(_CacheDescriptorBase): args_to_call[self.list_name] = missing ret_d = defer.maybeDeferred( - preserve_context_over_fn, - self.function_to_call, + logcontext.preserve_fn(self.function_to_call), **args_to_call ) @@ -443,8 +470,7 @@ class CacheListDescriptor(_CacheDescriptorBase): # We need to create deferreds for each arg in the list so that # we can insert the new deferred into the cache. for arg in missing: - with PreserveLoggingContext(): - observer = ret_d.observe() + observer = ret_d.observe() observer.addCallback(lambda r, arg: r.get(arg, None), arg) observer = ObservableDeferred(observer) @@ -471,7 +497,7 @@ class CacheListDescriptor(_CacheDescriptorBase): results.update(res) return results - return preserve_context_over_deferred(defer.gatherResults( + return logcontext.make_deferred_yieldable(defer.gatherResults( cached_defers.values(), consumeErrors=True, ).addCallback(update_results_dict).addErrback( diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py index ff67b1d79..990216145 100644 --- a/synapse/util/logcontext.py +++ b/synapse/util/logcontext.py @@ -310,6 +310,10 @@ def preserve_context_over_fn(fn, *args, **kwargs): def preserve_context_over_deferred(deferred, context=None): """Given a deferred wrap it such that any callbacks added later to it will be invoked with the current context. + + Deprecated: this almost certainly doesn't do want you want, ie make + the deferred follow the synapse logcontext rules: try + ``make_deferred_yieldable`` instead. """ if context is None: context = LoggingContext.current_context() @@ -330,12 +334,8 @@ def preserve_fn(f): LoggingContext.set_current_context(LoggingContext.sentinel) return result - # XXX: why is this here rather than inside g? surely we want to preserve - # the context from the time the function was called, not when it was - # wrapped? - current = LoggingContext.current_context() - def g(*args, **kwargs): + current = LoggingContext.current_context() res = f(*args, **kwargs) if isinstance(res, defer.Deferred) and not res.called: # The function will have reset the context before returning, so @@ -359,6 +359,25 @@ def preserve_fn(f): return g +@defer.inlineCallbacks +def make_deferred_yieldable(deferred): + """Given a deferred, make it follow the Synapse logcontext rules: + + If the deferred has completed (or is not actually a Deferred), essentially + does nothing (just returns another completed deferred with the + result/failure). + + If the deferred has not yet completed, resets the logcontext before + returning a deferred. Then, when the deferred completes, restores the + current logcontext before running callbacks/errbacks. + + (This is more-or-less the opposite operation to preserve_fn.) + """ + with PreserveLoggingContext(): + r = yield deferred + defer.returnValue(r) + + # modules to ignore in `logcontext_tracer` _to_ignore = [ "synapse.util.logcontext", diff --git a/synapse/visibility.py b/synapse/visibility.py index 31659156a..c4dd9ae2c 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state): events ([synapse.events.EventBase]): list of events to filter """ forgotten = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(store.who_forgot_in_room)( + defer.maybeDeferred( + preserve_fn(store.who_forgot_in_room), room_id, ) for room_id in frozenset(e.room_id for e in events) diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index aa8cc5055..7586ea905 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -19,10 +19,12 @@ from twisted.internet import defer from mock import Mock from tests import unittest +import re + def _regex(regex, exclusive=True): return { - "regex": regex, + "regex": re.compile(regex), "exclusive": exclusive } diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 8361dd8ce..281eb1625 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -199,7 +199,7 @@ class CacheDecoratorTestCase(unittest.TestCase): a.func.prefill(("foo",), ObservableDeferred(d)) - self.assertEquals(a.func("foo").result, d.result) + self.assertEquals(a.func("foo"), d.result) self.assertEquals(callcount[0], 0) @defer.inlineCallbacks diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 419281054..3f14ab503 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -12,11 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging + import mock +from synapse.api.errors import SynapseError +from synapse.util import async +from synapse.util import logcontext from twisted.internet import defer from synapse.util.caches import descriptors from tests import unittest +logger = logging.getLogger(__name__) + class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks @@ -84,3 +91,125 @@ class DescriptorTestCase(unittest.TestCase): r = yield obj.fn(2, 5) self.assertEqual(r, 'chips') obj.mock.assert_not_called() + + def test_cache_logcontexts(self): + """Check that logcontexts are set and restored correctly when + using the cache.""" + + complete_lookup = defer.Deferred() + + class Cls(object): + @descriptors.cached() + def fn(self, arg1): + @defer.inlineCallbacks + def inner_fn(): + with logcontext.PreserveLoggingContext(): + yield complete_lookup + defer.returnValue(1) + + return inner_fn() + + @defer.inlineCallbacks + def do_lookup(): + with logcontext.LoggingContext() as c1: + c1.name = "c1" + r = yield obj.fn(1) + self.assertEqual(logcontext.LoggingContext.current_context(), + c1) + defer.returnValue(r) + + def check_result(r): + self.assertEqual(r, 1) + + obj = Cls() + + # set off a deferred which will do a cache lookup + d1 = do_lookup() + self.assertEqual(logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel) + d1.addCallback(check_result) + + # and another + d2 = do_lookup() + self.assertEqual(logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel) + d2.addCallback(check_result) + + # let the lookup complete + complete_lookup.callback(None) + + return defer.gatherResults([d1, d2]) + + def test_cache_logcontexts_with_exception(self): + """Check that the cache sets and restores logcontexts correctly when + the lookup function throws an exception""" + + class Cls(object): + @descriptors.cached() + def fn(self, arg1): + @defer.inlineCallbacks + def inner_fn(): + yield async.run_on_reactor() + raise SynapseError(400, "blah") + + return inner_fn() + + @defer.inlineCallbacks + def do_lookup(): + with logcontext.LoggingContext() as c1: + c1.name = "c1" + try: + yield obj.fn(1) + self.fail("No exception thrown") + except SynapseError: + pass + + self.assertEqual(logcontext.LoggingContext.current_context(), + c1) + + obj = Cls() + + # set off a deferred which will do a cache lookup + d1 = do_lookup() + self.assertEqual(logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel) + + return d1 + + @defer.inlineCallbacks + def test_cache_default_args(self): + class Cls(object): + def __init__(self): + self.mock = mock.Mock() + + @descriptors.cached() + def fn(self, arg1, arg2=2, arg3=3): + return self.mock(arg1, arg2, arg3) + + obj = Cls() + + obj.mock.return_value = 'fish' + r = yield obj.fn(1, 2, 3) + self.assertEqual(r, 'fish') + obj.mock.assert_called_once_with(1, 2, 3) + obj.mock.reset_mock() + + # a call with same params shouldn't call the mock again + r = yield obj.fn(1, 2) + self.assertEqual(r, 'fish') + obj.mock.assert_not_called() + obj.mock.reset_mock() + + # a call with different params should call the mock again + obj.mock.return_value = 'chips' + r = yield obj.fn(2, 3) + self.assertEqual(r, 'chips') + obj.mock.assert_called_once_with(2, 3, 3) + obj.mock.reset_mock() + + # the two values should now be cached + r = yield obj.fn(1, 2) + self.assertEqual(r, 'fish') + r = yield obj.fn(2, 3) + self.assertEqual(r, 'chips') + obj.mock.assert_not_called() diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py index 7e289715b..d3a8630c2 100644 --- a/tests/util/test_snapshot_cache.py +++ b/tests/util/test_snapshot_cache.py @@ -53,7 +53,9 @@ class SnapshotCacheTestCase(unittest.TestCase): # before the cache expires returns a resolved deferred. get_result_at_11 = self.cache.get(11, "key") self.assertIsNotNone(get_result_at_11) - self.assertTrue(get_result_at_11.called) + if isinstance(get_result_at_11, Deferred): + # The cache may return the actual result rather than a deferred + self.assertTrue(get_result_at_11.called) # Check that getting the key after the deferred has resolved # after the cache expires returns None