Run black on the rest of the storage module (#4996)

This commit is contained in:
Amber Brown 2019-04-03 20:07:29 +11:00 committed by Richard van der Hoff
parent 3039d61baf
commit 7efd1d87c2
42 changed files with 2129 additions and 2453 deletions

View file

@ -30,13 +30,11 @@ logger = logging.getLogger(__name__)
OpsLevel = collections.namedtuple(
"OpsLevel",
("ban_level", "kick_level", "redact_level",)
"OpsLevel", ("ban_level", "kick_level", "redact_level")
)
RatelimitOverride = collections.namedtuple(
"RatelimitOverride",
("messages_per_second", "burst_count",)
"RatelimitOverride", ("messages_per_second", "burst_count")
)
@ -60,9 +58,7 @@ class RoomWorkerStore(SQLBaseStore):
def get_public_room_ids(self):
return self._simple_select_onecol(
table="rooms",
keyvalues={
"is_public": True,
},
keyvalues={"is_public": True},
retcol="room_id",
desc="get_public_room_ids",
)
@ -79,11 +75,11 @@ class RoomWorkerStore(SQLBaseStore):
return self.runInteraction(
"get_public_room_ids_at_stream_id",
self.get_public_room_ids_at_stream_id_txn,
stream_id, network_tuple=network_tuple
stream_id,
network_tuple=network_tuple,
)
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
network_tuple):
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple):
return {
rm
for rm, vis in self.get_published_at_stream_id_txn(
@ -96,7 +92,7 @@ class RoomWorkerStore(SQLBaseStore):
if network_tuple:
# We want to get from a particular list. No aggregation required.
sql = ("""
sql = """
SELECT room_id, visibility FROM public_room_list_stream
INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id
@ -104,25 +100,22 @@ class RoomWorkerStore(SQLBaseStore):
WHERE stream_id <= ? %s
GROUP BY room_id
) grouped USING (room_id, stream_id)
""")
"""
if network_tuple.appservice_id is not None:
txn.execute(
sql % ("AND appservice_id = ? AND network_id = ?",),
(stream_id, network_tuple.appservice_id, network_tuple.network_id,)
(stream_id, network_tuple.appservice_id, network_tuple.network_id),
)
else:
txn.execute(
sql % ("AND appservice_id IS NULL",),
(stream_id,)
)
txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,))
return dict(txn)
else:
# We want to get from all lists, so we need to aggregate the results
logger.info("Executing full list")
sql = ("""
sql = """
SELECT room_id, visibility
FROM public_room_list_stream
INNER JOIN (
@ -133,12 +126,9 @@ class RoomWorkerStore(SQLBaseStore):
WHERE stream_id <= ?
GROUP BY room_id, appservice_id, network_id
) grouped USING (room_id, stream_id)
""")
"""
txn.execute(
sql,
(stream_id,)
)
txn.execute(sql, (stream_id,))
results = {}
# A room is visible if its visible on any list.
@ -147,8 +137,7 @@ class RoomWorkerStore(SQLBaseStore):
return results
def get_public_room_changes(self, prev_stream_id, new_stream_id,
network_tuple):
def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple):
def get_public_room_changes_txn(txn):
then_rooms = self.get_public_room_ids_at_stream_id_txn(
txn, prev_stream_id, network_tuple
@ -158,9 +147,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, new_stream_id, network_tuple
)
now_rooms_visible = set(
rm for rm, vis in now_rooms_dict.items() if vis
)
now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
now_rooms_not_visible = set(
rm for rm, vis in now_rooms_dict.items() if not vis
)
@ -178,9 +165,7 @@ class RoomWorkerStore(SQLBaseStore):
def is_room_blocked(self, room_id):
return self._simple_select_one_onecol(
table="blocked_rooms",
keyvalues={
"room_id": room_id,
},
keyvalues={"room_id": room_id},
retcol="1",
allow_none=True,
desc="is_room_blocked",
@ -208,16 +193,17 @@ class RoomWorkerStore(SQLBaseStore):
)
if row:
defer.returnValue(RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
))
defer.returnValue(
RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
)
)
else:
defer.returnValue(None)
class RoomStore(RoomWorkerStore, SearchStore):
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
"""Stores a room.
@ -231,6 +217,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
StoreError if the room could not be stored.
"""
try:
def store_room_txn(txn, next_id):
self._simple_insert_txn(
txn,
@ -249,13 +236,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
"stream_id": next_id,
"room_id": room_id,
"visibility": is_public,
}
},
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
"store_room_txn",
store_room_txn, next_id,
)
yield self.runInteraction("store_room_txn", store_room_txn, next_id)
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
@ -297,19 +282,19 @@ class RoomStore(RoomWorkerStore, SearchStore):
"visibility": is_public,
"appservice_id": None,
"network_id": None,
}
},
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
"set_room_is_public",
set_room_is_public_txn, next_id,
"set_room_is_public", set_room_is_public_txn, next_id
)
self.hs.get_notifier().on_new_replication_data()
@defer.inlineCallbacks
def set_room_is_public_appservice(self, room_id, appservice_id, network_id,
is_public):
def set_room_is_public_appservice(
self, room_id, appservice_id, network_id, is_public
):
"""Edit the appservice/network specific public room list.
Each appservice can have a number of published room lists associated
@ -324,6 +309,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
is_public (bool): Whether to publish or unpublish the room from the
list.
"""
def set_room_is_public_appservice_txn(txn, next_id):
if is_public:
try:
@ -333,7 +319,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
values={
"appservice_id": appservice_id,
"network_id": network_id,
"room_id": room_id
"room_id": room_id,
},
)
except self.database_engine.module.IntegrityError:
@ -346,7 +332,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
keyvalues={
"appservice_id": appservice_id,
"network_id": network_id,
"room_id": room_id
"room_id": room_id,
},
)
@ -377,13 +363,14 @@ class RoomStore(RoomWorkerStore, SearchStore):
"visibility": is_public,
"appservice_id": appservice_id,
"network_id": network_id,
}
},
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn, next_id,
set_room_is_public_appservice_txn,
next_id,
)
self.hs.get_notifier().on_new_replication_data()
@ -397,9 +384,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
return self.runInteraction(
"get_rooms", f
)
return self.runInteraction("get_rooms", f)
def _store_room_topic_txn(self, txn, event):
if hasattr(event, "content") and "topic" in event.content:
@ -414,7 +399,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
)
self.store_event_search_txn(
txn, event, "content.topic", event.content["topic"],
txn, event, "content.topic", event.content["topic"]
)
def _store_room_name_txn(self, txn, event):
@ -426,17 +411,17 @@ class RoomStore(RoomWorkerStore, SearchStore):
"event_id": event.event_id,
"room_id": event.room_id,
"name": event.content["name"],
}
},
)
self.store_event_search_txn(
txn, event, "content.name", event.content["name"],
txn, event, "content.name", event.content["name"]
)
def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content:
self.store_event_search_txn(
txn, event, "content.body", event.content["body"],
txn, event, "content.body", event.content["body"]
)
def _store_history_visibility_txn(self, txn, event):
@ -452,14 +437,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
" (event_id, room_id, %(key)s)"
" VALUES (?, ?, ?)" % {"key": key}
)
txn.execute(sql, (
event.event_id,
event.room_id,
event.content[key]
))
txn.execute(sql, (event.event_id, event.room_id, event.content[key]))
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
def add_event_report(
self, room_id, event_id, user_id, reason, content, received_ts
):
next_id = self._event_reports_id_gen.get_next()
return self._simple_insert(
table="event_reports",
@ -472,7 +454,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
"reason": reason,
"content": json.dumps(content),
},
desc="add_event_report"
desc="add_event_report",
)
def get_current_public_room_stream_id(self):
@ -480,23 +462,21 @@ class RoomStore(RoomWorkerStore, SearchStore):
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = ("""
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
""")
"""
txn.execute(sql, (prev_id, current_id, limit,))
txn.execute(sql, (prev_id, current_id, limit))
return txn.fetchall()
if prev_id == current_id:
return defer.succeed([])
return self.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
@ -511,19 +491,16 @@ class RoomStore(RoomWorkerStore, SearchStore):
"""
yield self._simple_upsert(
table="blocked_rooms",
keyvalues={
"room_id": room_id,
},
keyvalues={"room_id": room_id},
values={},
insertion_values={
"user_id": user_id,
},
insertion_values={"user_id": user_id},
desc="block_room",
)
yield self.runInteraction(
"block_room_invalidation",
self._invalidate_cache_and_stream,
self.is_room_blocked, (room_id,),
self.is_room_blocked,
(room_id,),
)
def get_media_mxcs_in_room(self, room_id):
@ -536,6 +513,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
def _get_media_mxcs_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
@ -548,23 +526,28 @@ class RoomStore(RoomWorkerStore, SearchStore):
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
return local_media_mxcs, remote_media_mxcs
return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
# Now update all the tables to set the quarantined_by flag
txn.executemany("""
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""", ((quarantined_by, media_id) for media_id in local_mxcs))
""",
((quarantined_by, media_id) for media_id in local_mxcs),
)
txn.executemany(
"""
@ -575,7 +558,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs
)
),
)
total_media_quarantined += len(local_mxcs)
@ -584,8 +567,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
return total_media_quarantined
return self.runInteraction(
"quarantine_media_in_room",
_quarantine_media_in_room_txn,
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
def _get_media_mxcs_in_room_txn(self, txn, room_id):