Run black on the rest of the storage module (#4996)

This commit is contained in:
Amber Brown 2019-04-03 20:07:29 +11:00 committed by Richard van der Hoff
parent 3039d61baf
commit 7efd1d87c2
42 changed files with 2129 additions and 2453 deletions

View file

@ -59,9 +59,9 @@ _TOPOLOGICAL_TOKEN = "topological"
# Used as return values for pagination APIs
_EventDictReturn = namedtuple("_EventDictReturn", (
"event_id", "topological_ordering", "stream_ordering",
))
_EventDictReturn = namedtuple(
"_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
)
def lower_bound(token, engine, inclusive=False):
@ -74,13 +74,20 @@ def lower_bound(token, engine, inclusive=False):
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) <%s (%s,%s))" % (
token.topological, token.stream, inclusive,
"topological_ordering", "stream_ordering",
token.topological,
token.stream,
inclusive,
"topological_ordering",
"stream_ordering",
)
return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
token.stream, inclusive, "stream_ordering",
token.topological,
"topological_ordering",
token.topological,
"topological_ordering",
token.stream,
inclusive,
"stream_ordering",
)
@ -94,13 +101,20 @@ def upper_bound(token, engine, inclusive=True):
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) >%s (%s,%s))" % (
token.topological, token.stream, inclusive,
"topological_ordering", "stream_ordering",
token.topological,
token.stream,
inclusive,
"topological_ordering",
"stream_ordering",
)
return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
token.stream, inclusive, "stream_ordering",
token.topological,
"topological_ordering",
token.topological,
"topological_ordering",
token.stream,
inclusive,
"stream_ordering",
)
@ -116,9 +130,7 @@ def filter_to_clause(event_filter):
args = []
if event_filter.types:
clauses.append(
"(%s)" % " OR ".join("type = ?" for _ in event_filter.types)
)
clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
args.extend(event_filter.types)
for typ in event_filter.not_types:
@ -126,9 +138,7 @@ def filter_to_clause(event_filter):
args.append(typ)
if event_filter.senders:
clauses.append(
"(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)
)
clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
args.extend(event_filter.senders)
for sender in event_filter.not_senders:
@ -136,9 +146,7 @@ def filter_to_clause(event_filter):
args.append(sender)
if event_filter.rooms:
clauses.append(
"(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)
)
clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
args.extend(event_filter.rooms)
for room_id in event_filter.not_rooms:
@ -165,17 +173,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
events_max = self.get_room_max_stream_ordering()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
db_conn,
"events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", min_event_val,
"EventsRoomStreamChangeCache",
min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max,
"MembershipStreamChangeCache", events_max
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
@ -189,8 +199,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
raise NotImplementedError()
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
order='DESC'):
def get_room_events_stream_for_rooms(
self, room_ids, from_key, to_key, limit=0, order='DESC'
):
"""Get new room events in stream ordering since `from_key`.
Args:
@ -221,14 +232,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
res = yield make_deferred_yieldable(defer.gatherResults([
run_in_background(
self.get_room_events_stream_for_room,
room_id, from_key, to_key, limit, order=order,
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
res = yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_room_events_stream_for_room,
room_id,
from_key,
to_key,
limit,
order=order,
)
for room_id in rm_ids
],
consumeErrors=True,
)
for room_id in rm_ids
], consumeErrors=True))
)
results.update(dict(zip(rm_ids, res)))
defer.returnValue(results)
@ -243,13 +263,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
return set(
room_id for room_id in room_ids
room_id
for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
)
@defer.inlineCallbacks
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
order='DESC'):
def get_room_events_stream_for_room(
self, room_id, from_key, to_key, limit=0, order='DESC'
):
"""Get new room events in stream ordering since `from_key`.
@ -297,10 +319,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
ret = yield self._get_events(
[r.event_id for r in rows],
get_prev_content=True
)
ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
self._set_before_and_after(ret, rows, topo_order=from_id is None)
@ -340,7 +359,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" AND e.stream_ordering > ? AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
)
txn.execute(sql, (user_id, from_id, to_id,))
txn.execute(sql, (user_id, from_id, to_id))
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
@ -348,10 +367,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_membership_changes_for_user", f)
ret = yield self._get_events(
[r.event_id for r in rows],
get_prev_content=True
)
ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
self._set_before_and_after(ret, rows, topo_order=False)
@ -374,13 +390,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
rows, token = yield self.get_recent_event_ids_for_room(
room_id, limit, end_token,
room_id, limit, end_token
)
logger.debug("stream before")
events = yield self._get_events(
[r.event_id for r in rows],
get_prev_content=True
[r.event_id for r in rows], get_prev_content=True
)
logger.debug("stream after")
@ -410,8 +425,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
rows, token = yield self.runInteraction(
"get_recent_event_ids_for_room", self._paginate_room_events_txn,
room_id, from_token=end_token, limit=limit,
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
from_token=end_token,
limit=limit,
)
# We want to return the results in ascending order.
@ -430,6 +448,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Deferred[(int, int, str)]:
(stream ordering, topological ordering, event_id)
"""
def _f(txn):
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
@ -439,12 +458,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" ORDER BY stream_ordering"
" LIMIT 1"
)
txn.execute(sql, (room_id, stream_ordering, ))
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
return self.runInteraction(
"get_room_event_after_stream_ordering", _f,
)
return self.runInteraction("get_room_event_after_stream_ordering", _f)
@defer.inlineCallbacks
def get_room_events_max_id(self, room_id=None):
@ -459,8 +476,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue("s%d" % (token,))
else:
topo = yield self.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn,
room_id,
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
defer.returnValue("t%d-%d" % (topo, token))
@ -474,9 +490,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
A deferred "s%d" stream token.
"""
return self._simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="stream_ordering",
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
).addCallback(lambda row: "s%d" % (row,))
def get_topological_token_for_event(self, event_id):
@ -493,8 +507,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
).addCallback(lambda row: "t%d-%d" % (
row["topological_ordering"], row["stream_ordering"],)
).addCallback(
lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
)
def get_max_topological_token(self, room_id, stream_key):
@ -503,17 +517,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" WHERE room_id = ? AND stream_ordering < ?"
)
return self._execute(
"get_max_topological_token", None,
sql, room_id, stream_key,
).addCallback(
lambda r: r[0][0] if r else 0
)
"get_max_topological_token", None, sql, room_id, stream_key
).addCallback(lambda r: r[0][0] if r else 0)
def _get_max_topological_txn(self, txn, room_id):
txn.execute(
"SELECT MAX(topological_ordering) FROM events"
" WHERE room_id = ?",
(room_id,)
"SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
(room_id,),
)
rows = txn.fetchall()
@ -540,14 +550,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))
internal.order = (
int(topo) if topo else 0,
int(stream),
)
internal.order = (int(topo) if topo else 0, int(stream))
@defer.inlineCallbacks
def get_events_around(
self, room_id, event_id, before_limit, after_limit, event_filter=None,
self, room_id, event_id, before_limit, after_limit, event_filter=None
):
"""Retrieve events and pagination tokens around a given event in a
room.
@ -564,29 +571,34 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
results = yield self.runInteraction(
"get_events_around", self._get_events_around_txn,
room_id, event_id, before_limit, after_limit, event_filter,
"get_events_around",
self._get_events_around_txn,
room_id,
event_id,
before_limit,
after_limit,
event_filter,
)
events_before = yield self._get_events(
[e for e in results["before"]["event_ids"]],
get_prev_content=True
[e for e in results["before"]["event_ids"]], get_prev_content=True
)
events_after = yield self._get_events(
[e for e in results["after"]["event_ids"]],
get_prev_content=True
[e for e in results["after"]["event_ids"]], get_prev_content=True
)
defer.returnValue({
"events_before": events_before,
"events_after": events_after,
"start": results["before"]["token"],
"end": results["after"]["token"],
})
defer.returnValue(
{
"events_before": events_before,
"events_after": events_after,
"start": results["before"]["token"],
"end": results["after"]["token"],
}
)
def _get_events_around_txn(
self, txn, room_id, event_id, before_limit, after_limit, event_filter,
self, txn, room_id, event_id, before_limit, after_limit, event_filter
):
"""Retrieves event_ids and pagination tokens around a given event in a
room.
@ -605,46 +617,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = self._simple_select_one_txn(
txn,
"events",
keyvalues={
"event_id": event_id,
"room_id": room_id,
},
keyvalues={"event_id": event_id, "room_id": room_id},
retcols=["stream_ordering", "topological_ordering"],
)
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
results["topological_ordering"] - 1,
results["stream_ordering"],
results["topological_ordering"] - 1, results["stream_ordering"]
)
after_token = RoomStreamToken(
results["topological_ordering"],
results["stream_ordering"],
results["topological_ordering"], results["stream_ordering"]
)
rows, start_token = self._paginate_room_events_txn(
txn, room_id, before_token, direction='b', limit=before_limit,
txn,
room_id,
before_token,
direction='b',
limit=before_limit,
event_filter=event_filter,
)
events_before = [r.event_id for r in rows]
rows, end_token = self._paginate_room_events_txn(
txn, room_id, after_token, direction='f', limit=after_limit,
txn,
room_id,
after_token,
direction='f',
limit=after_limit,
event_filter=event_filter,
)
events_after = [r.event_id for r in rows]
return {
"before": {
"event_ids": events_before,
"token": start_token,
},
"after": {
"event_ids": events_after,
"token": end_token,
},
"before": {"event_ids": events_before, "token": start_token},
"after": {"event_ids": events_after, "token": end_token},
}
@defer.inlineCallbacks
@ -685,7 +694,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn,
"get_all_new_events_stream", get_all_new_events_stream_txn
)
events = yield self._get_events(event_ids)
@ -697,7 +706,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
table="federation_stream_position",
retcol="stream_id",
keyvalues={"type": typ},
desc="get_federation_out_pos"
desc="get_federation_out_pos",
)
def update_federation_out_pos(self, typ, stream_id):
@ -711,8 +720,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def has_room_changed_since(self, room_id, stream_id):
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None,
direction='b', limit=-1, event_filter=None):
def _paginate_room_events_txn(
self,
txn,
room_id,
from_token,
to_token=None,
direction='b',
limit=-1,
event_filter=None,
):
"""Returns list of events before or after a given token.
Args:
@ -741,22 +758,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args = [False, room_id]
if direction == 'b':
order = "DESC"
bounds = upper_bound(
from_token, self.database_engine
)
bounds = upper_bound(from_token, self.database_engine)
if to_token:
bounds = "%s AND %s" % (bounds, lower_bound(
to_token, self.database_engine
))
bounds = "%s AND %s" % (
bounds,
lower_bound(to_token, self.database_engine),
)
else:
order = "ASC"
bounds = lower_bound(
from_token, self.database_engine
)
bounds = lower_bound(from_token, self.database_engine)
if to_token:
bounds = "%s AND %s" % (bounds, upper_bound(
to_token, self.database_engine
))
bounds = "%s AND %s" % (
bounds,
upper_bound(to_token, self.database_engine),
)
filter_clause, filter_args = filter_to_clause(event_filter)
@ -772,10 +787,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s LIMIT ?"
) % {
"bounds": bounds,
"order": order,
}
) % {"bounds": bounds, "order": order}
txn.execute(sql, args)
@ -796,11 +808,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
return rows, str(next_token),
return rows, str(next_token)
@defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1, event_filter=None):
def paginate_room_events(
self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None
):
"""Returns list of events before or after a given token.
Args:
@ -826,13 +839,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
to_key = RoomStreamToken.parse(to_key)
rows, token = yield self.runInteraction(
"paginate_room_events", self._paginate_room_events_txn,
room_id, from_key, to_key, direction, limit, event_filter,
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
from_key,
to_key,
direction,
limit,
event_filter,
)
events = yield self._get_events(
[r.event_id for r in rows],
get_prev_content=True
[r.event_id for r in rows], get_prev_content=True
)
self._set_before_and_after(events, rows)