diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 90f96209f..e83adc833 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -88,9 +88,13 @@ class BaseHandler(object): current_state = yield self.store.get_events( context.current_state_ids.values() ) - current_state = current_state.values() else: - current_state = yield self.store.get_current_state(event.room_id) + current_state = yield self.state_handler.get_current_state( + event.room_id + ) + + current_state = current_state.values() + logger.info("maybe_kick_guest_users %r", current_state) yield self.kick_guest_users(current_state) diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py index 64f18bbb3..b3f3bf748 100644 --- a/synapse/replication/slave/storage/events.py +++ b/synapse/replication/slave/storage/events.py @@ -76,9 +76,6 @@ class SlavedEventStore(BaseSlavedStore): get_latest_event_ids_in_room = EventFederationStore.__dict__[ "get_latest_event_ids_in_room" ] - _get_current_state_for_key = StateStore.__dict__[ - "_get_current_state_for_key" - ] get_invited_rooms_for_user = RoomMemberStore.__dict__[ "get_invited_rooms_for_user" ] @@ -115,8 +112,6 @@ class SlavedEventStore(BaseSlavedStore): ) get_event = DataStore.get_event.__func__ get_events = DataStore.get_events.__func__ - get_current_state = DataStore.get_current_state.__func__ - get_current_state_for_key = DataStore.get_current_state_for_key.__func__ get_rooms_for_user_where_membership_is = ( DataStore.get_rooms_for_user_where_membership_is.__func__ ) @@ -248,7 +243,6 @@ class SlavedEventStore(BaseSlavedStore): def invalidate_caches_for_event(self, event, backfilled, reset_state): if reset_state: - self._get_current_state_for_key.invalidate_all() self.get_rooms_for_user.invalidate_all() self.get_users_in_room.invalidate((event.room_id,)) @@ -289,7 +283,3 @@ class SlavedEventStore(BaseSlavedStore): if (not event.internal_metadata.is_invite_from_remote() and event.internal_metadata.is_outlier()): return - - self._get_current_state_for_key.invalidate(( - event.room_id, event.type, event.state_key - )) diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 6160949f3..599db4c9f 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -284,71 +284,37 @@ class EventsStore(SQLBaseStore): new_forward_extremeties = {} current_state_for_room = {} if not backfilled: - # Work out the new "current state" for each room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - events_by_room = {} - for event, context in chunk: - events_by_room.setdefault(event.room_id, []).append( - (event, context) - ) - - for room_id, ev_ctx_rm in events_by_room.items(): - # Work out new extremities by recursively adding and removing - # the new events. - latest_event_ids = yield self.get_latest_event_ids_in_room( - room_id - ) - new_latest_event_ids = yield self._calculate_new_extremeties( - room_id, [ev for ev, _ in ev_ctx_rm] - ) - - if new_latest_event_ids == set(latest_event_ids): - # No change in extremities, so no change in state - continue - - new_forward_extremeties[room_id] = new_latest_event_ids - - # Now we need to work out the different state sets for - # each state extremities - state_sets = [] - missing_event_ids = [] - was_updated = False - for event_id in new_latest_event_ids: - # First search in the list of new events we're adding, - # and then use the current state from that - for ev, ctx in ev_ctx_rm: - if event_id == ev.event_id: - if ctx.current_state_ids is None: - raise Exception("Unknown current state") - state_sets.append(ctx.current_state_ids) - if ctx.delta_ids or hasattr(ev, "state_key"): - was_updated = True - break - else: - # If we couldn't find it, then we'll need to pull - # the state from the database - was_updated = True - missing_event_ids.append(event_id) - - if missing_event_ids: - # Now pull out the state for any missing events from DB - event_to_groups = yield self._get_state_group_for_events( - missing_event_ids, + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) ) - groups = set(event_to_groups.values()) - group_to_state = yield self._get_state_for_groups(groups) - - state_sets.extend(group_to_state.values()) - - if not new_latest_event_ids or was_updated: - current_state_for_room[room_id] = yield resolve_events( - state_sets, - state_map_factory=lambda ev_ids: self.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ), + for room_id, ev_ctx_rm in events_by_room.items(): + # Work out new extremities by recursively adding and removing + # the new events. + latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id ) + new_latest_event_ids = yield self._calculate_new_extremeties( + room_id, [ev for ev, _ in ev_ctx_rm] + ) + + if new_latest_event_ids == set(latest_event_ids): + # No change in extremities, so no change in state + continue + + new_forward_extremeties[room_id] = new_latest_event_ids + + state = yield self._calculate_state_delta( + room_id, ev_ctx_rm, new_latest_event_ids + ) + if state: + current_state_for_room[room_id] = state yield self.runInteraction( "persist_events", @@ -405,6 +371,91 @@ class EventsStore(SQLBaseStore): defer.returnValue(new_latest_event_ids) + @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + 2-tuple (to_delete, to_insert) where both are state dicts, i.e. + (type, state_key) -> event_id. `to_delete` are the entreis to + first be deleted from current_state_events, `to_insert` are entries + to insert. + May return None if there are no changes to be applied. + """ + # Now we need to work out the different state sets for + # each state extremities + state_sets = [] + missing_event_ids = [] + was_updated = False + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding, + # and then use the current state from that + for ev, ctx in events_context: + if event_id == ev.event_id: + if ctx.current_state_ids is None: + raise Exception("Unknown current state") + state_sets.append(ctx.current_state_ids) + if ctx.delta_ids or hasattr(ev, "state_key"): + was_updated = True + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + was_updated = True + missing_event_ids.append(event_id) + + if missing_event_ids: + # Now pull out the state for any missing events from DB + event_to_groups = yield self._get_state_group_for_events( + missing_event_ids, + ) + + groups = set(event_to_groups.values()) + group_to_state = yield self._get_state_for_groups(groups) + + state_sets.extend(group_to_state.values()) + + if not new_latest_event_ids: + current_state = {} + elif was_updated: + current_state = yield resolve_events( + state_sets, + state_map_factory=lambda ev_ids: self.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ), + ) + else: + return + + existing_state_rows = yield self._simple_select_list( + table="current_state_events", + keyvalues={"room_id": room_id}, + retcols=["event_id", "type", "state_key"], + desc="_calculate_state_delta", + ) + + existing_events = set(row["event_id"] for row in existing_state_rows) + new_events = set(ev_id for ev_id in current_state.itervalues()) + changed_events = existing_events ^ new_events + + if not changed_events: + return + + to_delete = { + (row["type"], row["state_key"]): row["event_id"] + for row in existing_state_rows + if row["event_id"] in changed_events + } + events_to_insert = (new_events - existing_events) + to_insert = { + key: ev_id for key, ev_id in current_state.iteritems() + if ev_id in events_to_insert + } + + defer.returnValue((to_delete, to_insert)) + @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False, @@ -475,38 +526,55 @@ class EventsStore(SQLBaseStore): database before insertion. This is useful when retrying due to IntegrityError. """ max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - for room_id, current_state in current_state_for_room.iteritems(): - txn.call_after(self._get_current_state_for_key.invalidate_all) - txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, (room_id,)) + for room_id, current_state_tuple in current_state_for_room.iteritems(): + to_delete, to_insert = current_state_tuple + txn.executemany( + "DELETE FROM current_state_events WHERE event_id = ?", + [(ev_id,) for ev_id in to_delete.itervalues()], + ) - # Add an entry to the current_state_resets table to record the point - # where we clobbered the current state - self._simple_insert_txn( - txn, - table="current_state_resets", - values={"event_stream_ordering": max_stream_order} - ) + self._simple_insert_many_txn( + txn, + table="current_state_events", + values=[ + { + "event_id": ev_id, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + } + for key, ev_id in to_insert.iteritems() + ], + ) - self._simple_delete_txn( - txn, - table="current_state_events", - keyvalues={"room_id": room_id}, - ) + # Invalidate the various caches - self._simple_insert_many_txn( - txn, - table="current_state_events", - values=[ - { - "event_id": ev_id, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - } - for key, ev_id in current_state.iteritems() - ], - ) + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. + # We find out which membership events we may have deleted + # and which we have added, then we invlidate the caches for all + # those users. + members_changed = set( + state_key for ev_type, state_key in to_delete.iterkeys() + if ev_type == EventTypes.Member + ) + members_changed.update( + state_key for ev_type, state_key in to_insert.iterkeys() + if ev_type == EventTypes.Member + ) + + for member in members_changed: + txn.call_after(self.get_rooms_for_user.invalidate, (member,)) + + txn.call_after(self.get_users_in_room.invalidate, (room_id,)) + + # Add an entry to the current_state_resets table to record the point + # where we clobbered the current state + self._simple_insert_txn( + txn, + table="current_state_resets", + values={"event_stream_ordering": max_stream_order} + ) for room_id, new_extrem in new_forward_extremeties.items(): self._simple_delete_txn( diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7d34dd03b..d1d653327 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -232,58 +232,6 @@ class StateStore(SQLBaseStore): return count - @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key=""): - if event_type and state_key is not None: - result = yield self.get_current_state_for_key( - room_id, event_type, state_key - ) - defer.returnValue(result) - - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events" - " WHERE room_id = ? " - ) - - if event_type and state_key is not None: - sql += " AND type = ? AND state_key = ? " - args = (room_id, event_type, state_key) - elif event_type: - sql += " AND type = ?" - args = (room_id, event_type) - else: - args = (room_id, ) - - txn.execute(sql, args) - results = txn.fetchall() - - return [r[0] for r in results] - - event_ids = yield self.runInteraction("get_current_state", f) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) - - @defer.inlineCallbacks - def get_current_state_for_key(self, room_id, event_type, state_key): - event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) - - @cached(num_args=3) - def _get_current_state_for_key(self, room_id, event_type, state_key): - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events" - " WHERE room_id = ? AND type = ? AND state_key = ?" - ) - - args = (room_id, event_type, state_key) - txn.execute(sql, args) - results = txn.fetchall() - return [r[0] for r in results] - return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=100000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index 38fedfe69..6acb8ab75 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -118,35 +118,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id] ) - @defer.inlineCallbacks - def test_get_current_state(self): - # Create the room. - yield self.persist(type="m.room.create", key="", creator=USER_ID) - yield self.replicate() - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), [] - ) - - # Join the room. - join1 = yield self.persist( - type="m.room.member", key=USER_ID, membership="join", - ) - yield self.replicate() - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), - [join1] - ) - - # Add some other user to the room. - join2 = yield self.persist( - type="m.room.member", key=USER_ID_2, membership="join", - ) - yield self.replicate() - yield self.check( - "get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2), - [join2] - ) - @defer.inlineCallbacks def test_redactions(self): yield self.persist(type="m.room.create", key="", creator=USER_ID)