Don't pull out the full state when storing state (#13274)

This commit is contained in:
Erik Johnston 2022-07-15 13:59:45 +01:00 committed by GitHub
parent 3343035a06
commit 0731e0829c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 130 additions and 69 deletions

View file

@ -298,12 +298,18 @@ class StateHandler:
state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids
state_ids_before_event = None
# We make sure that we have a state group assigned to the state.
if entry.state_group is None:
state_ids_before_event = await entry.get_state(
self._state_storage_controller, StateFilter.all()
)
# store_state_group requires us to have either a previous state group
# (with deltas) or the complete state map. So, if we don't have a
# previous state group, load the complete state map now.
if state_group_before_event_prev_group is None:
state_ids_before_event = await entry.get_state(
self._state_storage_controller, StateFilter.all()
)
state_group_before_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
@ -316,7 +322,6 @@ class StateHandler:
entry.state_group = state_group_before_event
else:
state_group_before_event = entry.state_group
state_ids_before_event = None
#
# now if it's not a state event, we're done
@ -336,19 +341,20 @@ class StateHandler:
#
# otherwise, we'll need to create a new state group for after the event
#
if state_ids_before_event is None:
state_ids_before_event = await entry.get_state(
self._state_storage_controller, StateFilter.all()
)
key = (event.type, event.state_key)
if key in state_ids_before_event:
replaces = state_ids_before_event[key]
if replaces != event.event_id:
event.unsigned["replaces_state"] = replaces
state_ids_after_event = dict(state_ids_before_event)
state_ids_after_event[key] = event.event_id
if state_ids_before_event is not None:
replaces = state_ids_before_event.get(key)
else:
replaces_state_map = await entry.get_state(
self._state_storage_controller, StateFilter.from_types([key])
)
replaces = replaces_state_map.get(key)
if replaces and replaces != event.event_id:
event.unsigned["replaces_state"] = replaces
delta_ids = {key: event.event_id}
state_group_after_event = (
@ -357,7 +363,7 @@ class StateHandler:
event.room_id,
prev_group=state_group_before_event,
delta_ids=delta_ids,
current_state_ids=state_ids_after_event,
current_state_ids=None,
)
)