diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index f2b3aceb4..67ad3dfd3 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -29,6 +29,7 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred from synapse.events import FrozenEvent +from synapse.types import get_domain_from_id import synapse.metrics from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination @@ -63,6 +64,7 @@ class FederationClient(FederationBase): self._clock.looping_call( self._clear_tried_cache, 60 * 1000, ) + self.state = hs.get_state_handler() def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" @@ -811,7 +813,8 @@ class FederationClient(FederationBase): if len(signed_events) >= limit: defer.returnValue(signed_events) - servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + servers = set(get_domain_from_id(u) for u in users) servers = set(servers) servers.discard(self.server_name) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 4bea7f2b1..14352985e 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -19,7 +19,7 @@ from ._base import BaseHandler from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError from synapse.api.constants import EventTypes -from synapse.types import RoomAlias, UserID +from synapse.types import RoomAlias, UserID, get_domain_from_id import logging import string @@ -55,7 +55,8 @@ class DirectoryHandler(BaseHandler): # TODO(erikj): Add transactions. # TODO(erikj): Check if there is a current association. if not servers: - servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + servers = set(get_domain_from_id(u) for u in users) if not servers: raise SynapseError(400, "Failed to get server list") @@ -193,7 +194,8 @@ class DirectoryHandler(BaseHandler): Codes.NOT_FOUND ) - extra_servers = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + extra_servers = set(get_domain_from_id(u) for u in users) servers = set(extra_servers) | set(servers) # If this server is in the list of servers, return it first. diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 3a3a1257d..d3685fb12 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -47,6 +47,7 @@ class EventStreamHandler(BaseHandler): self.clock = hs.get_clock() self.notifier = hs.get_notifier() + self.state = hs.get_state_handler() @defer.inlineCallbacks @log_function @@ -90,7 +91,7 @@ class EventStreamHandler(BaseHandler): # Send down presence. if event.state_key == auth_user_id: # Send down presence for everyone in the room. - users = yield self.store.get_users_in_room(event.room_id) + users = yield self.state.get_current_user_in_room(event.room_id) states = yield presence_handler.get_states( users, as_event=True, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 6a1fe76c8..73752b2f8 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -88,6 +88,8 @@ class PresenceHandler(object): self.notifier = hs.get_notifier() self.federation = hs.get_replication_layer() + self.state = hs.get_state_handler() + self.federation.register_edu_handler( "m.presence", self.incoming_presence ) @@ -532,7 +534,9 @@ class PresenceHandler(object): if not local_states: continue - hosts = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + hosts = set(get_domain_from_id(u) for u in users) + for host in hosts: hosts_to_states.setdefault(host, []).extend(local_states) @@ -725,13 +729,13 @@ class PresenceHandler(object): # don't need to send to local clients here, as that is done as part # of the event stream/sync. # TODO: Only send to servers not already in the room. + user_ids = yield self.state.get_current_user_in_room(room_id) if self.is_mine(user): state = yield self.current_state_for_user(user.to_string()) - hosts = yield self.store.get_joined_hosts_for_room(room_id) + hosts = set(get_domain_from_id(u) for u in user_ids) self._push_to_remotes({host: (state,) for host in hosts}) else: - user_ids = yield self.store.get_users_in_room(room_id) user_ids = filter(self.is_mine_id, user_ids) states = yield self.current_state_for_users(user_ids) @@ -955,6 +959,7 @@ class PresenceEventSource(object): self.get_presence_handler = hs.get_presence_handler self.clock = hs.get_clock() self.store = hs.get_datastore() + self.state = hs.get_state_handler() @defer.inlineCallbacks @log_function @@ -1017,7 +1022,7 @@ class PresenceEventSource(object): user_ids_to_check = set() for room_id in room_ids: - users = yield self.store.get_users_in_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) user_ids_to_check.update(users) user_ids_to_check.update(friends) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index e62722d78..726f7308d 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -18,6 +18,7 @@ from ._base import BaseHandler from twisted.internet import defer from synapse.util.logcontext import PreserveLoggingContext +from synapse.types import get_domain_from_id import logging @@ -37,6 +38,7 @@ class ReceiptsHandler(BaseHandler): "m.receipt", self._received_remote_receipt ) self.clock = self.hs.get_clock() + self.state = hs.get_state_handler() @defer.inlineCallbacks def received_client_receipt(self, room_id, receipt_type, user_id, @@ -133,7 +135,8 @@ class ReceiptsHandler(BaseHandler): event_ids = receipt["event_ids"] data = receipt["data"] - remotedomains = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + remotedomains = set(get_domain_from_id(u) for u in users) remotedomains = remotedomains.copy() remotedomains.discard(self.server_name) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 5b9bce7f9..91934b0c8 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -142,6 +142,7 @@ class SyncHandler(object): self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.response_cache = ResponseCache(hs) + self.state = hs.get_state_handler() def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, full_state=False): @@ -670,7 +671,7 @@ class SyncHandler(object): extra_users_ids = set(newly_joined_users) for room_id in newly_joined_rooms: - users = yield self.store.get_users_in_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) extra_users_ids.update(users) extra_users_ids.discard(user.to_string()) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 46181984c..0b530b903 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -20,7 +20,7 @@ from synapse.util.logcontext import ( PreserveLoggingContext, preserve_fn, preserve_context_over_deferred, ) from synapse.util.metrics import Measure -from synapse.types import UserID +from synapse.types import UserID, get_domain_from_id import logging @@ -42,6 +42,7 @@ class TypingHandler(object): self.auth = hs.get_auth() self.is_mine_id = hs.is_mine_id self.notifier = hs.get_notifier() + self.state = hs.get_state_handler() self.clock = hs.get_clock() @@ -166,7 +167,8 @@ class TypingHandler(object): @defer.inlineCallbacks def _push_update(self, room_id, user_id, typing): - domains = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + domains = set(get_domain_from_id(u) for u in users) deferreds = [] for domain in domains: @@ -199,7 +201,8 @@ class TypingHandler(object): # Check that the string is a valid user id UserID.from_string(user_id) - domains = yield self.store.get_joined_hosts_for_room(room_id) + users = yield self.state.get_current_user_in_room(room_id) + domains = set(get_domain_from_id(u) for u in users) if self.server_name in domains: self._push_update_local( diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 8d49beaec..51cb21ee9 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -87,7 +87,7 @@ class BulkPushRuleEvaluator: ) room_members = yield self.store.get_joined_users_from_context( - event.room_id, context, + event.room_id, context.state_group, context.current_state_ids ) evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 65e982a0c..b14ed4d2d 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -123,6 +123,11 @@ class SlavedEventStore(BaseSlavedStore): get_state_groups_ids = DataStore.get_state_groups_ids.__func__ get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__ get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__ + get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__ + _get_joined_users_from_context = ( + RoomMemberStore.__dict__["_get_joined_users_from_context"] + ) + get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__ get_room_events_stream_for_rooms = ( DataStore.get_room_events_stream_for_rooms.__func__ @@ -216,7 +221,6 @@ class SlavedEventStore(BaseSlavedStore): self._get_current_state_for_key.invalidate_all() self.get_rooms_for_user.invalidate_all() self.get_users_in_room.invalidate((event.room_id,)) - # self.get_joined_hosts_for_room.invalidate((event.room_id,)) self._invalidate_get_event_cache(event.event_id) @@ -240,7 +244,6 @@ class SlavedEventStore(BaseSlavedStore): if event.type == EventTypes.Member: self.get_rooms_for_user.invalidate((event.state_key,)) - # self.get_joined_hosts_for_room.invalidate((event.room_id,)) self.get_users_in_room.invalidate((event.room_id,)) self._membership_stream_cache.entity_has_changed( event.state_key, event.internal_metadata.stream_ordering diff --git a/synapse/state.py b/synapse/state.py index 78461215c..daec983dc 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -124,6 +124,15 @@ class StateHandler(object): defer.returnValue(state) + @defer.inlineCallbacks + def get_current_user_in_room(self, room_id): + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) + group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids) + joined_users = yield self.store.get_joined_users_from_context( + room_id, group, state_ids + ) + defer.returnValue(joined_users) + @defer.inlineCallbacks def compute_event_context(self, event, old_state=None): """ Fills out the context with the `current state` of the graph. The diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 57e500528..5cbe8c597 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -393,7 +393,6 @@ class EventsStore(SQLBaseStore): txn.call_after(self._get_current_state_for_key.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) - txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) # Add an entry to the current_state_resets table to record the point # where we clobbered the current state diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 5f15200c2..cab166083 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -56,7 +56,6 @@ class RoomMemberStore(SQLBaseStore): for event in events: txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) - txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after( self._membership_stream_cache.entity_has_changed, @@ -238,11 +237,6 @@ class RoomMemberStore(SQLBaseStore): return results - @cachedInlineCallbacks(max_entries=5000) - def get_joined_hosts_for_room(self, room_id): - user_ids = yield self.get_users_in_room(room_id) - defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids)) - def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): where_clause = "c.room_id = ?" where_values = [room_id] @@ -360,8 +354,7 @@ class RoomMemberStore(SQLBaseStore): desc="who_forgot" ) - def get_joined_users_from_context(self, room_id, context): - state_group = context.state_group + def get_joined_users_from_context(self, room_id, state_group, state_ids): if not state_group: # If state_group is None it means it has yet to be assigned a # state group, i.e. we need to make sure that calls with a state_group @@ -370,7 +363,7 @@ class RoomMemberStore(SQLBaseStore): state_group = object() return self._get_joined_users_from_context( - room_id, state_group, context.current_state_ids + room_id, state_group, state_ids ) @cachedInlineCallbacks(num_args=2, cache_context=True) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ab9899b7d..b2957eef9 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -62,6 +62,7 @@ class TypingNotificationsTestCase(unittest.TestCase): self.on_new_event = mock_notifier.on_new_event self.auth = Mock(spec=[]) + self.state_handler = Mock() hs = yield setup_test_homeserver( "test", @@ -75,6 +76,7 @@ class TypingNotificationsTestCase(unittest.TestCase): "set_received_txn_response", "get_destination_retry_timings", ]), + state_handler=self.state_handler, handlers=None, notifier=mock_notifier, resource_for_client=Mock(), @@ -113,6 +115,10 @@ class TypingNotificationsTestCase(unittest.TestCase): return set(member.domain for member in self.room_members) self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room + def get_current_user_in_room(room_id): + return set(str(u) for u in self.room_members) + self.state_handler.get_current_user_in_room = get_current_user_in_room + self.auth.check_joined_room = check_joined_room # Some local users to test with diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 27b2b3d12..1be7d932f 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -78,44 +78,3 @@ class RoomMemberStoreTestCase(unittest.TestCase): ) )] ) - - @defer.inlineCallbacks - def test_room_hosts(self): - yield self.inject_room_member(self.room, self.u_alice, Membership.JOIN) - - self.assertEquals( - {"test"}, - (yield self.store.get_joined_hosts_for_room(self.room.to_string())) - ) - - # Should still have just one host after second join from it - yield self.inject_room_member(self.room, self.u_bob, Membership.JOIN) - - self.assertEquals( - {"test"}, - (yield self.store.get_joined_hosts_for_room(self.room.to_string())) - ) - - # Should now have two hosts after join from other host - yield self.inject_room_member(self.room, self.u_charlie, Membership.JOIN) - - self.assertEquals( - {"test", "elsewhere"}, - (yield self.store.get_joined_hosts_for_room(self.room.to_string())) - ) - - # Should still have both hosts - yield self.inject_room_member(self.room, self.u_alice, Membership.LEAVE) - - self.assertEquals( - {"test", "elsewhere"}, - (yield self.store.get_joined_hosts_for_room(self.room.to_string())) - ) - - # Should have only one host after other leaves - yield self.inject_room_member(self.room, self.u_charlie, Membership.LEAVE) - - self.assertEquals( - {"test"}, - (yield self.store.get_joined_hosts_for_room(self.room.to_string())) - )