diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a390a1b8b..e62722d78 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -29,6 +29,8 @@ class ReceiptsHandler(BaseHandler): def __init__(self, hs): super(ReceiptsHandler, self).__init__(hs) + self.server_name = hs.config.server_name + self.store = hs.get_datastore() self.hs = hs self.federation = hs.get_replication_layer() self.federation.register_edu_handler( @@ -131,12 +133,9 @@ class ReceiptsHandler(BaseHandler): event_ids = receipt["event_ids"] data = receipt["data"] - remotedomains = set() - - rm_handler = self.hs.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into( - room_id, localusers=None, remotedomains=remotedomains - ) + remotedomains = yield self.store.get_joined_hosts_for_room(room_id) + remotedomains = remotedomains.copy() + remotedomains.discard(self.server_name) logger.debug("Sending receipt to: %r", remotedomains) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b44e52a51..b785a8fa9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -55,35 +55,6 @@ class RoomMemberHandler(BaseHandler): self.distributor.declare("user_joined_room") self.distributor.declare("user_left_room") - @defer.inlineCallbacks - def get_room_members(self, room_id): - users = yield self.store.get_users_in_room(room_id) - - defer.returnValue([UserID.from_string(u) for u in users]) - - @defer.inlineCallbacks - def fetch_room_distributions_into(self, room_id, localusers=None, - remotedomains=None, ignore_user=None): - """Fetch the distribution of a room, adding elements to either - 'localusers' or 'remotedomains', which should be a set() if supplied. - If ignore_user is set, ignore that user. - - This function returns nothing; its result is performed by the - side-effect on the two passed sets. This allows easy accumulation of - member lists of multiple rooms at once if required. - """ - members = yield self.get_room_members(room_id) - for member in members: - if ignore_user is not None and member == ignore_user: - continue - - if self.hs.is_mine(member): - if localusers is not None: - localusers.add(member) - else: - if remotedomains is not None: - remotedomains.add(member.domain) - @defer.inlineCallbacks def _local_membership_update( self, requester, target, room_id, membership, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 8ce27f49e..92eff534d 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -39,7 +39,8 @@ class TypingNotificationHandler(BaseHandler): def __init__(self, hs): super(TypingNotificationHandler, self).__init__(hs) - self.homeserver = hs + self.store = hs.get_datastore() + self.server_name = hs.config.server_name self.clock = hs.get_clock() @@ -157,32 +158,26 @@ class TypingNotificationHandler(BaseHandler): @defer.inlineCallbacks def _push_update(self, room_id, user, typing): - localusers = set() - remotedomains = set() - - rm_handler = self.homeserver.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into( - room_id, localusers=localusers, remotedomains=remotedomains - ) - - if localusers: - self._push_update_local( - room_id=room_id, - user=user, - typing=typing - ) + domains = yield self.store.get_joined_hosts_for_room(room_id) deferreds = [] - for domain in remotedomains: - deferreds.append(self.federation.send_edu( - destination=domain, - edu_type="m.typing", - content={ - "room_id": room_id, - "user_id": user.to_string(), - "typing": typing, - }, - )) + for domain in domains: + if domain == self.server_name: + self._push_update_local( + room_id=room_id, + user=user, + typing=typing + ) + else: + deferreds.append(self.federation.send_edu( + destination=domain, + edu_type="m.typing", + content={ + "room_id": room_id, + "user_id": user.to_string(), + "typing": typing, + }, + )) yield defer.DeferredList(deferreds, consumeErrors=True) @@ -191,14 +186,9 @@ class TypingNotificationHandler(BaseHandler): room_id = content["room_id"] user = UserID.from_string(content["user_id"]) - localusers = set() + domains = yield self.store.get_joined_hosts_for_room(room_id) - rm_handler = self.homeserver.get_handlers().room_member_handler - yield rm_handler.fetch_room_distributions_into( - room_id, localusers=localusers - ) - - if localusers: + if self.server_name in domains: self._push_update_local( room_id=room_id, user=user, diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 3955e7f5b..d38ca37d6 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -71,6 +71,7 @@ class TypingNotificationsTestCase(unittest.TestCase): self.auth = Mock(spec=[]) hs = yield setup_test_homeserver( + "test", auth=self.auth, clock=self.clock, datastore=Mock(spec=[ @@ -110,56 +111,16 @@ class TypingNotificationsTestCase(unittest.TestCase): self.room_id = "a-room" - # Mock the RoomMemberHandler - hs.handlers.room_member_handler = Mock(spec=[]) - self.room_member_handler = hs.handlers.room_member_handler - self.room_members = [] - def get_rooms_for_user(user): - if user in self.room_members: - return defer.succeed([self.room_id]) - else: - return defer.succeed([]) - self.room_member_handler.get_rooms_for_user = get_rooms_for_user - - def get_room_members(room_id): - if room_id == self.room_id: - return defer.succeed(self.room_members) - else: - return defer.succeed([]) - self.room_member_handler.get_room_members = get_room_members - - def get_joined_rooms_for_user(user): - if user in self.room_members: - return defer.succeed([self.room_id]) - else: - return defer.succeed([]) - self.room_member_handler.get_joined_rooms_for_user = get_joined_rooms_for_user - - @defer.inlineCallbacks - def fetch_room_distributions_into( - room_id, localusers=None, remotedomains=None, ignore_user=None - ): - members = yield get_room_members(room_id) - for member in members: - if ignore_user is not None and member == ignore_user: - continue - - if hs.is_mine(member): - if localusers is not None: - localusers.add(member) - else: - if remotedomains is not None: - remotedomains.add(member.domain) - self.room_member_handler.fetch_room_distributions_into = ( - fetch_room_distributions_into - ) - def check_joined_room(room_id, user_id): if user_id not in [u.to_string() for u in self.room_members]: raise AuthError(401, "User is not in the room") + def get_joined_hosts_for_room(room_id): + return set(member.domain for member in self.room_members) + self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room + self.auth.check_joined_room = check_joined_room # Some local users to test with diff --git a/tests/utils.py b/tests/utils.py index 9d7978a64..59d985b5f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -50,7 +50,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.enable_registration = True config.macaroon_secret_key = "not even a little secret" config.expire_access_token = False - config.server_name = "server.under.test" + config.server_name = name config.trusted_third_party_id_servers = [] config.room_invite_state_types = []