diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index fd0939722..7a57a69bd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -24,7 +24,7 @@ from synapse.push.action_generator import ActionGenerator from synapse.types import ( UserID, RoomAlias, RoomStreamToken, ) -from synapse.util.async import run_on_reactor, ReadWriteLock +from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter from synapse.util.logcontext import preserve_fn from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client @@ -50,6 +50,10 @@ class MessageHandler(BaseHandler): self.pagination_lock = ReadWriteLock() + # We arbitrarily limit concurrent event creation for a room to 5. + # This is to stop us from diverging history *too* much. + self.limiter = Limiter(max_count=5) + @defer.inlineCallbacks def purge_history(self, room_id, event_id): event = yield self.store.get_event(event_id) @@ -191,36 +195,38 @@ class MessageHandler(BaseHandler): """ builder = self.event_builder_factory.new(event_dict) - self.validator.validate_new(builder) + with (yield self.limiter.queue(builder.room_id)): + self.validator.validate_new(builder) - if builder.type == EventTypes.Member: - membership = builder.content.get("membership", None) - target = UserID.from_string(builder.state_key) + if builder.type == EventTypes.Member: + membership = builder.content.get("membership", None) + target = UserID.from_string(builder.state_key) - if membership in {Membership.JOIN, Membership.INVITE}: - # If event doesn't include a display name, add one. - profile = self.hs.get_handlers().profile_handler - content = builder.content + if membership in {Membership.JOIN, Membership.INVITE}: + # If event doesn't include a display name, add one. + profile = self.hs.get_handlers().profile_handler + content = builder.content - try: - content["displayname"] = yield profile.get_displayname(target) - content["avatar_url"] = yield profile.get_avatar_url(target) - except Exception as e: - logger.info( - "Failed to get profile information for %r: %s", - target, e - ) + try: + content["displayname"] = yield profile.get_displayname(target) + content["avatar_url"] = yield profile.get_avatar_url(target) + except Exception as e: + logger.info( + "Failed to get profile information for %r: %s", + target, e + ) - if token_id is not None: - builder.internal_metadata.token_id = token_id + if token_id is not None: + builder.internal_metadata.token_id = token_id - if txn_id is not None: - builder.internal_metadata.txn_id = txn_id + if txn_id is not None: + builder.internal_metadata.txn_id = txn_id + + event, context = yield self._create_new_client_event( + builder=builder, + prev_event_ids=prev_event_ids, + ) - event, context = yield self._create_new_client_event( - builder=builder, - prev_event_ids=prev_event_ids, - ) defer.returnValue((event, context)) @defer.inlineCallbacks diff --git a/synapse/util/async.py b/synapse/util/async.py index 347fb1e38..16ed183d4 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -197,6 +197,64 @@ class Linearizer(object): defer.returnValue(_ctx_manager()) +class Limiter(object): + """Limits concurrent access to resources based on a key. Useful to ensure + only a few thing happen at a time on a given resource. + + Example: + + with (yield limiter.queue("test_key")): + # do some work. + + """ + def __init__(self, max_count): + """ + Args: + max_count(int): The maximum number of concurrent access + """ + self.max_count = max_count + + # key_to_defer is a map from the key to a 2 element list where + # the first element is the number of things executing + # the second element is a list of deferreds for the things blocked from + # executing. + self.key_to_defer = {} + + @defer.inlineCallbacks + def queue(self, key): + entry = self.key_to_defer.setdefault(key, [0, []]) + + # If the number of things executing is greater than the maximum + # then add a deferred to the list of blocked items + # When on of the things currently executing finishes it will callback + # this item so that it can continue executing. + if entry[0] >= self.max_count: + new_defer = defer.Deferred() + entry[1].append(new_defer) + with PreserveLoggingContext(): + yield new_defer + + entry[0] += 1 + + @contextmanager + def _ctx_manager(): + try: + yield + finally: + # We've finished executing so check if there are any things + # blocked waiting to execute and start one of them + entry[0] -= 1 + try: + entry[1].pop(0).callback(None) + except IndexError: + # If nothing else is executing for this key then remove it + # from the map + if entry[0] == 0: + self.key_to_defer.pop(key, None) + + defer.returnValue(_ctx_manager()) + + class ReadWriteLock(object): """A deferred style read write lock. diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py new file mode 100644 index 000000000..9c795d9fd --- /dev/null +++ b/tests/util/test_limiter.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +# Copyright 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from tests import unittest + +from twisted.internet import defer + +from synapse.util.async import Limiter + + +class LimiterTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def test_limiter(self): + limiter = Limiter(3) + + key = object() + + d1 = limiter.queue(key) + cm1 = yield d1 + + d2 = limiter.queue(key) + cm2 = yield d2 + + d3 = limiter.queue(key) + cm3 = yield d3 + + d4 = limiter.queue(key) + self.assertFalse(d4.called) + + d5 = limiter.queue(key) + self.assertFalse(d5.called) + + with cm1: + self.assertFalse(d4.called) + self.assertFalse(d5.called) + + self.assertTrue(d4.called) + self.assertFalse(d5.called) + + with cm3: + self.assertFalse(d5.called) + + self.assertTrue(d5.called) + + with cm2: + pass + + with (yield d4): + pass + + with (yield d5): + pass + + d6 = limiter.queue(key) + with (yield d6): + pass