Limit the length of state chains

This commit is contained in:
Erik Johnston 2016-09-02 10:41:38 +01:00
parent 9e25443db8
commit 598317927c
2 changed files with 106 additions and 43 deletions

View file

@ -25,6 +25,9 @@ import logging
logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
class StateStore(SQLBaseStore):
""" Keeps track of the state at a given event.
@ -104,7 +107,6 @@ class StateStore(SQLBaseStore):
state_groups[event.event_id] = context.state_group
if self._have_persisted_state_group_txn(txn, context.state_group):
logger.info("Already persisted state_group: %r", context.state_group)
continue
state_event_ids = dict(context.current_state_ids)
@ -120,29 +122,48 @@ class StateStore(SQLBaseStore):
)
if context.prev_group:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": context.state_group,
"prev_state_group": context.prev_group,
},
potential_hops = self._count_state_group_hops_txn(
txn, context.prev_group
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
if potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.delta_ids.items()
],
)
"prev_state_group": context.prev_group,
},
)
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.delta_ids.items()
],
)
else:
self._simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
{
"state_group": context.state_group,
"room_id": event.room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
for key, state_id in context.current_state_ids.items()
],
)
else:
self._simple_insert_many_txn(
txn,
@ -171,6 +192,41 @@ class StateStore(SQLBaseStore):
],
)
def _count_state_group_hops_txn(self, txn, state_group):
if isinstance(self.database_engine, PostgresEngine):
sql = ("""
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
""")
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
if event_type and state_key is not None: