Run black on the rest of the storage module (#4996)

This commit is contained in:
Amber Brown 2019-04-03 20:07:29 +11:00 committed by Richard van der Hoff
parent 3039d61baf
commit 7efd1d87c2
42 changed files with 2129 additions and 2453 deletions

View file

@ -40,10 +40,13 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
class _GetStateGroupDelta(
namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
):
"""Return type of get_state_group_delta that implements __len__, which lets
us use the itrable flag when caching
"""
__slots__ = []
def __len__(self):
@ -70,10 +73,7 @@ class StateFilter(object):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
self.types = {
k: v for k, v in iteritems(self.types)
if v is not None
}
self.types = {k: v for k, v in iteritems(self.types) if v is not None}
@staticmethod
def all():
@ -130,10 +130,7 @@ class StateFilter(object):
Returns:
StateFilter
"""
return StateFilter(
types={EventTypes.Member: set(members)},
include_others=True,
)
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
def return_expanded(self):
"""Creates a new StateFilter where type wild cards have been removed
@ -243,9 +240,7 @@ class StateFilter(object):
if where_clause:
where_clause += " OR "
where_clause += "type NOT IN (%s)" % (
",".join(["?"] * len(self.types)),
)
where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
where_args.extend(self.types)
return where_clause, where_args
@ -305,12 +300,8 @@ class StateFilter(object):
bool
"""
return (
self.include_others
or any(
state_keys is None
for state_keys in itervalues(self.types)
)
return self.include_others or any(
state_keys is None for state_keys in itervalues(self.types)
)
def concrete_types(self):
@ -406,11 +397,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
self._state_group_cache = DictionaryCache(
"*stateGroupCache*",
# TODO: this hasn't been tuned yet
50000 * get_cache_factor_for("stateGroupCache")
50000 * get_cache_factor_for("stateGroupCache"),
)
self._state_group_members_cache = DictionaryCache(
"*stateGroupMembersCache*",
500000 * get_cache_factor_for("stateGroupMembersCache")
500000 * get_cache_factor_for("stateGroupMembersCache"),
)
@defer.inlineCallbacks
@ -488,22 +479,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
deferred: dict of (type, state_key) -> event_id
"""
def _get_current_state_ids_txn(txn):
txn.execute(
"""SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
""",
(room_id,)
(room_id,),
)
return {
(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
}
return self.runInteraction(
"get_current_state_ids",
_get_current_state_ids_txn,
)
return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
# FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
@ -544,8 +533,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
return self.runInteraction(
"get_filtered_current_state_ids",
_get_filtered_current_state_ids_txn,
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
@defer.inlineCallbacks
@ -559,9 +547,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Deferred[str|None]: The canonical alias, if any
"""
state = yield self.get_filtered_current_state_ids(room_id, StateFilter.from_types(
[(EventTypes.CanonicalAlias, "")]
))
state = yield self.get_filtered_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
)
event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
@ -581,13 +569,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
(prev_group, delta_ids), where both may be None.
"""
def _get_state_group_delta_txn(txn):
prev_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={
"state_group": state_group,
},
keyvalues={"state_group": state_group},
retcol="prev_state_group",
allow_none=True,
)
@ -598,20 +585,16 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
delta_ids = self._simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={
"state_group": state_group,
},
retcols=("type", "state_key", "event_id",)
keyvalues={"state_group": state_group},
retcols=("type", "state_key", "event_id"),
)
return _GetStateGroupDelta(prev_group, {
(row["type"], row["state_key"]): row["event_id"]
for row in delta_ids
})
return self.runInteraction(
"get_state_group_delta",
_get_state_group_delta_txn,
)
return _GetStateGroupDelta(
prev_group,
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
)
return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
@ -628,9 +611,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event_ids:
defer.returnValue({})
event_to_groups = yield self._get_state_group_for_events(
event_ids,
)
event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups)
@ -666,19 +647,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
state_event_map = yield self.get_events(
[
ev_id for group_ids in itervalues(group_to_ids)
ev_id
for group_ids in itervalues(group_to_ids)
for ev_id in itervalues(group_ids)
],
get_prev_content=False
get_prev_content=False,
)
defer.returnValue({
group: [
state_event_map[v] for v in itervalues(event_id_map)
if v in state_event_map
]
for group, event_id_map in iteritems(group_to_ids)
})
defer.returnValue(
{
group: [
state_event_map[v]
for v in itervalues(event_id_map)
if v in state_event_map
]
for group, event_id_map in iteritems(group_to_ids)
}
)
@defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, state_filter):
@ -695,18 +680,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
results = {}
chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn, chunk, state_filter,
self._get_state_groups_from_groups_txn,
chunk,
state_filter,
)
results.update(res)
defer.returnValue(results)
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all(),
self, txn, groups, state_filter=StateFilter.all()
):
results = {group: {} for group in groups}
@ -776,7 +763,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause,
args
args,
)
results[group].update(
((typ, state_key), event_id)
@ -791,8 +778,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
max_entries_returned is not None and
len(results[group]) == max_entries_returned
max_entries_returned is not None
and len(results[group]) == max_entries_returned
):
break
@ -819,16 +806,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
)
event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, state_filter)
state_event_map = yield self.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
get_prev_content=False
get_prev_content=False,
)
event_to_state = {
@ -856,9 +841,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
)
event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, state_filter)
@ -906,16 +889,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_state_group_for_event(self, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={
"event_id": event_id,
},
keyvalues={"event_id": event_id},
retcol="state_group",
allow_none=True,
desc="_get_state_group_for_event",
)
@cachedList(cached_method_name="_get_state_group_for_event",
list_name="event_ids", num_args=1, inlineCallbacks=True)
@cachedList(
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
inlineCallbacks=True,
)
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
@ -924,7 +909,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=("event_id", "state_group",),
retcols=("event_id", "state_group"),
desc="_get_state_group_for_events",
)
@ -989,15 +974,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# Now we look them up in the member and non-member caches
non_member_state, incomplete_groups_nm, = (
yield self._get_state_for_groups_using_cache(
groups, self._state_group_cache,
state_filter=non_member_filter,
groups, self._state_group_cache, state_filter=non_member_filter
)
)
member_state, incomplete_groups_m, = (
yield self._get_state_for_groups_using_cache(
groups, self._state_group_members_cache,
state_filter=member_filter,
groups, self._state_group_members_cache, state_filter=member_filter
)
)
@ -1019,8 +1002,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
db_state_filter = state_filter.return_expanded()
group_to_state_dict = yield self._get_state_groups_from_groups(
list(incomplete_groups),
state_filter=db_state_filter,
list(incomplete_groups), state_filter=db_state_filter
)
# Now lets update the caches
@ -1040,9 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue(state)
def _get_state_for_groups_using_cache(
self, groups, cache, state_filter,
):
def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
@ -1074,8 +1054,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results, incomplete_groups
def _insert_into_cache(self, group_to_state_dict, state_filter,
cache_seq_num_members, cache_seq_num_non_members):
def _insert_into_cache(
self,
group_to_state_dict,
state_filter,
cache_seq_num_members,
cache_seq_num_non_members,
):
"""Inserts results from querying the database into the relevant cache.
Args:
@ -1132,8 +1117,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
fetched_keys=non_member_types,
)
def store_state_group(self, event_id, room_id, prev_group, delta_ids,
current_state_ids):
def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
"""Store a new set of state, returning a newly assigned state group.
Args:
@ -1149,6 +1135,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
Deferred[int]: The state group ID
"""
def _store_state_group_txn(txn):
if current_state_ids is None:
# AFAIK, this can never happen
@ -1159,11 +1146,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": state_group,
"room_id": room_id,
"event_id": event_id,
},
values={"id": state_group, "room_id": room_id, "event_id": event_id},
)
# We persist as a delta if we can, while also ensuring the chain
@ -1182,17 +1165,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
% (prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
txn, prev_group
)
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": state_group,
"prev_state_group": prev_group,
},
values={"state_group": state_group, "prev_state_group": prev_group},
)
self._simple_insert_many_txn(
@ -1264,7 +1242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = ("""
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
@ -1272,7 +1250,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
""")
"""
txn.execute(sql, (state_group,))
row = txn.fetchone()
@ -1331,8 +1309,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
self._background_deduplicate_state,
)
self.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
@ -1366,18 +1343,14 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn,
table="event_to_state_groups",
values=[
{
"state_group": state_group_id,
"event_id": event_id,
}
{"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups)
],
)
for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
self._get_state_group_for_event.prefill,
(event_id,), state_group_id
self._get_state_group_for_event.prefill, (event_id,), state_group_id
)
@defer.inlineCallbacks
@ -1395,7 +1368,8 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
if max_group is None:
rows = yield self._execute(
"_background_deduplicate_state", None,
"_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups",
)
max_group = rows[0][0]
@ -1408,7 +1382,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
" WHERE ? < id AND id <= ?"
" ORDER BY id ASC"
" LIMIT 1",
(new_last_state_group, max_group,)
(new_last_state_group, max_group),
)
row = txn.fetchone()
if row:
@ -1420,7 +1394,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn.execute(
"SELECT state_group FROM state_group_edges"
" WHERE state_group = ?",
(state_group,)
(state_group,),
)
# If we reach a point where we've already started inserting
@ -1431,27 +1405,25 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn.execute(
"SELECT coalesce(max(id), 0) FROM state_groups"
" WHERE id < ? AND room_id = ?",
(state_group, room_id,)
(state_group, room_id),
)
prev_group, = txn.fetchone()
new_last_state_group = state_group
if prev_group:
potential_hops = self._count_state_group_hops_txn(
txn, prev_group
)
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if potential_hops >= MAX_STATE_DELTA_HOPS:
# We want to ensure chains are at most this long,#
# otherwise read performance degrades.
continue
prev_state = self._get_state_groups_from_groups_txn(
txn, [prev_group],
txn, [prev_group]
)
prev_state = prev_state[prev_group]
curr_state = self._get_state_groups_from_groups_txn(
txn, [state_group],
txn, [state_group]
)
curr_state = curr_state[state_group]
@ -1460,16 +1432,15 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
# of keys
delta_state = {
key: value for key, value in iteritems(curr_state)
key: value
for key, value in iteritems(curr_state)
if prev_state.get(key, None) != value
}
self._simple_delete_txn(
txn,
table="state_group_edges",
keyvalues={
"state_group": state_group,
}
keyvalues={"state_group": state_group},
)
self._simple_insert_txn(
@ -1478,15 +1449,13 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
values={
"state_group": state_group,
"prev_state_group": prev_group,
}
},
)
self._simple_delete_txn(
txn,
table="state_groups_state",
keyvalues={
"state_group": state_group,
}
keyvalues={"state_group": state_group},
)
self._simple_insert_many_txn(
@ -1521,7 +1490,9 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
)
if finished:
yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME)
yield self._end_background_update(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
)
defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
@ -1538,9 +1509,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
"CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
)
txn.execute(
"DROP INDEX IF EXISTS state_groups_state_id"
)
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
finally:
conn.set_session(autocommit=False)
else:
@ -1549,9 +1518,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
"CREATE INDEX state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
)
txn.execute(
"DROP INDEX IF EXISTS state_groups_state_id"
)
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
yield self.runWithConnection(reindex_txn)