Update black, and run auto formatting over the codebase (#9381)

- Update black version to the latest
- Run black auto formatting over the codebase
  - Run autoformatting according to [`docs/code_style.md`](80d6dc9783/docs/code_style.md)
- Update `code_style.md` docs around installing black to use the correct version
Eric Eastwood 2021-02-16 16:32:34 -06:00 committed by GitHub
parent 5636e597c3
commit 0a00b7ff14
271 changed files with 2802 additions and 1713 deletions
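Almost all of the churn in this diff is mechanical and comes from two behaviours of the newer black: a one-line docstring whose closing `"""` sat on its own line is collapsed onto a single line, and any bracketed construct that ends in a trailing comma is exploded to one element per line (the "magic trailing comma"). A minimal sketch of both rewrites, using hypothetical stand-in names rather than Synapse's real code, and assuming the update targets black 20.8b1 (the latest release at the time; the excerpt itself does not name the version):

```python
# Before this black update, both forms below were left as written:
#
#     def room_count():
#         """Retrieve the total number of rooms.
#         """
#
#     run_interaction(
#         "room_count", room_count,
#     )
#
# The newer black rewrites them into the forms that follow.


def room_count():
    """Retrieve the total number of rooms."""  # closing quotes pulled up
    return 0


def run_interaction(desc, func):
    # Hypothetical stand-in for DatabasePool.runInteraction.
    return func()


# The trailing comma after `room_count` pins the one-argument-per-line layout:
result = run_interaction(
    "room_count",
    room_count,
)
```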


@ -43,8 +43,7 @@ __all__ = ["Databases", "DataStore"]
class Storage:
"""The high level interfaces for talking to various storage layers.
"""
"""The high level interfaces for talking to various storage layers."""
def __init__(self, hs: "HomeServer", stores: Databases):
# We include the main data store here mainly so that we don't have to


@ -77,7 +77,7 @@ class BackgroundUpdatePerformance:
class BackgroundUpdater:
""" Background updates are updates to the database that run in the
"""Background updates are updates to the database that run in the
background. Each update processes a batch of data at once. We attempt to
limit the impact of each update by monitoring how long each batch takes to
process and autotuning the batch size.
@ -158,8 +158,7 @@ class BackgroundUpdater:
return False
async def has_completed_background_update(self, update_name: str) -> bool:
"""Check if the given background update has finished running.
"""
"""Check if the given background update has finished running."""
if self._all_done:
return True
@ -198,7 +197,8 @@ class BackgroundUpdater:
if not self._current_background_update:
all_pending_updates = await self.db_pool.runInteraction(
"background_updates", get_background_updates_txn,
"background_updates",
get_background_updates_txn,
)
if not all_pending_updates:
# no work left to do


@ -85,8 +85,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
def make_pool(
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
"""Get the connection pool for the database.
"""
"""Get the connection pool for the database."""
# By default enable `cp_reconnect`. We need to fiddle with db_args in case
# someone has explicitly set `cp_reconnect`.
@ -432,8 +431,7 @@ class DatabasePool:
)
def is_running(self) -> bool:
"""Is the database pool currently running
"""
"""Is the database pool currently running"""
return self._db_pool.running
async def _check_safe_to_upsert(self) -> None:
@ -546,7 +544,11 @@ class DatabasePool:
# This can happen if the database disappears mid
# transaction.
transaction_logger.warning(
"[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
"[TXN OPERROR] {%s} %s %d/%d",
name,
e,
i,
N,
)
if i < N:
i += 1
@ -567,7 +569,9 @@ class DatabasePool:
conn.rollback()
except self.engine.module.Error as e1:
transaction_logger.warning(
"[TXN EROLL] {%s} %s", name, e1,
"[TXN EROLL] {%s} %s",
name,
e1,
)
continue
raise
@ -1406,7 +1410,10 @@ class DatabasePool:
@staticmethod
def simple_select_onecol_txn(
txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcol: str,
) -> List[Any]:
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
@ -1716,7 +1723,11 @@ class DatabasePool:
desc: description of the transaction, for logging and metrics
"""
await self.runInteraction(
desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
desc,
self.simple_delete_one_txn,
table,
keyvalues,
db_autocommit=True,
)
@staticmethod


@ -56,7 +56,10 @@ class Databases:
database_config.databases,
)
prepare_database(
db_conn, engine, hs.config, databases=database_config.databases,
db_conn,
engine,
hs.config,
databases=database_config.databases,
)
database = DatabasePool(hs, database_config, engine)


@ -73,8 +73,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
return self.services_cache
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
"""Check if the user is one associated with an app service (exclusively)
"""
"""Check if the user is one associated with an app service (exclusively)"""
if self.exclusive_user_regex:
return bool(self.exclusive_user_regex.match(user_id))
else:


@ -280,8 +280,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return batch_size
async def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
"""Background update to insert last seen info into devices table"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
@ -363,8 +362,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self):
"""Removes entries in user IPs older than the configured period.
"""
"""Removes entries in user IPs older than the configured period."""
if self.user_ips_max_age is None:
# Nothing to do
@ -565,7 +563,11 @@ class ClientIpStore(ClientIpWorkerStore):
results = {}
for key in self._batch_row_update:
uid, access_token, ip, = key
(
uid,
access_token,
ip,
) = key
if uid == user_id:
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
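The five-line unpack just above is the trailing-comma rule applied to an assignment target: `uid, access_token, ip, = key` carries a redundant comma, so black wraps the target in parentheses and explodes it (the `(None, ":memory:",)` tuple in the sqlite engine later in this diff gets the same treatment). Dropping the comma keeps the unpack on one line, as this hypothetical snippet shows:

```python
key = ("@user:example.org", "access_token", "10.0.0.1")

# Trailing comma in the target: black formats one name per line.
(
    uid,
    access_token,
    ip,
) = key

# Without the trailing comma, the unpack stays on one line.
uid, access_token, ip = key
```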


@ -315,7 +315,8 @@ class DeviceWorkerStore(SQLBaseStore):
# make sure we go through the devices in stream order
device_ids = sorted(
user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
user_devices.keys(),
key=lambda i: query_map[(user_id, i)][0],
)
for device_id in device_ids:
@ -366,8 +367,7 @@ class DeviceWorkerStore(SQLBaseStore):
async def mark_as_sent_devices_by_remote(
self, destination: str, stream_id: int
) -> None:
"""Mark that updates have successfully been sent to the destination.
"""
"""Mark that updates have successfully been sent to the destination."""
await self.db_pool.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
@ -681,7 +681,8 @@ class DeviceWorkerStore(SQLBaseStore):
return results
async def get_user_ids_requiring_device_list_resync(
self, user_ids: Optional[Collection[str]] = None,
self,
user_ids: Optional[Collection[str]] = None,
) -> Set[str]:
"""Given a list of remote users return the list of users that we
should resync the device lists for. If None is given instead of a list,
@ -721,8 +722,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user.
"""
"""Mark that we no longer track device lists for remote user."""
def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
self.db_pool.simple_delete_txn(
@ -902,7 +902,8 @@ class DeviceWorkerStore(SQLBaseStore):
logger.info("Pruned %d device list outbound pokes", count)
await self.db_pool.runInteraction(
"_prune_old_outbound_device_pokes", _prune_txn,
"_prune_old_outbound_device_pokes",
_prune_txn,
)
@ -943,7 +944,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
# clear out duplicate device list outbound pokes
self.db_pool.updates.register_background_update_handler(
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
self._remove_duplicate_outbound_pokes,
)
# a pair of background updates that were added during the 1.14 release cycle,
@ -1004,17 +1006,23 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
row = None
for row in rows:
self.db_pool.simple_delete_txn(
txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
txn,
"device_lists_outbound_pokes",
{x: row[x] for x in KEY_COLS},
)
row["sent"] = False
self.db_pool.simple_insert_txn(
txn, "device_lists_outbound_pokes", row,
txn,
"device_lists_outbound_pokes",
row,
)
if row:
self.db_pool.updates._background_update_progress_txn(
txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
txn,
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
{"last_row": row},
)
return len(rows)
@ -1286,7 +1294,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# we've done a full resync, so we remove the entry that says we need
# to resync
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
txn,
table="device_lists_remote_resync",
keyvalues={"user_id": user_id},
)
async def add_device_change_to_streams(
@ -1336,7 +1346,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_ids: List[str],
):
txn.call_after(
self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
self._device_list_stream_cache.entity_has_changed,
user_id,
stream_ids[-1],
)
min_stream_id = stream_ids[0]


@ -85,7 +85,7 @@ class DirectoryStore(DirectoryWorkerStore):
servers: Iterable[str],
creator: Optional[str] = None,
) -> None:
""" Creates an association between a room alias and room_id/servers
"""Creates an association between a room alias and room_id/servers
Args:
room_alias: The alias to create.
@ -160,7 +160,10 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
async def update_aliases_for_room(
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
self,
old_room_id: str,
new_room_id: str,
creator: Optional[str] = None,
) -> None:
"""Repoint all of the aliases for a given room, to a different room.


@ -361,7 +361,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
async def count_e2e_one_time_keys(
self, user_id: str, device_id: str
) -> Dict[str, int]:
""" Count the number of one time keys the server has for a device
"""Count the number of one time keys the server has for a device
Returns:
A mapping from algorithm to number of keys for that algorithm.
"""
@ -494,7 +494,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
)
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self, txn: Connection, user_ids: List[str],
self,
txn: Connection,
user_ids: List[str],
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
@ -556,7 +558,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
return result
def _get_e2e_cross_signing_signatures_txn(
self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
self,
txn: Connection,
keys: Dict[str, Dict[str, dict]],
from_user_id: str,
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing signatures made by a user on a set of keys.


@ -71,7 +71,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return await self.get_events_as_list(event_ids)
async def get_auth_chain_ids(
self, event_ids: Collection[str], include_given: bool = False,
self,
event_ids: Collection[str],
include_given: bool = False,
) -> List[str]:
"""Get auth events for given event_ids. The events *must* be state events.
@ -273,7 +275,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# origin chain.
if origin_sequence_number <= chains.get(origin_chain_id, 0):
chains[target_chain_id] = max(
target_sequence_number, chains.get(target_chain_id, 0),
target_sequence_number,
chains.get(target_chain_id, 0),
)
seen_chains.add(target_chain_id)
@ -632,8 +635,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
async def get_min_depth(self, room_id: str) -> int:
"""For the given room, get the minimum depth we have seen for it.
"""
"""For the given room, get the minimum depth we have seen for it."""
return await self.db_pool.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
@ -858,12 +860,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
await self.db_pool.runInteraction(
"_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
"_delete_old_forward_extrem_cache",
_delete_old_forward_extrem_cache_txn,
)
class EventFederationStore(EventFederationWorkerStore):
""" Responsible for storing and serving up the various graphs associated
"""Responsible for storing and serving up the various graphs associated
with an event. Including the main event graph and the auth chains for an
event.


@ -54,8 +54,7 @@ def _serialize_action(actions, is_highlight):
def _deserialize_action(actions, is_highlight):
"""Custom deserializer for actions. This allows us to "compress" common actions
"""
"""Custom deserializer for actions. This allows us to "compress" common actions"""
if actions:
return db_to_json(actions)
@ -91,7 +90,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
self, room_id: str, user_id: str, last_read_event_id: Optional[str],
self,
room_id: str,
user_id: str,
last_read_event_id: Optional[str],
) -> Dict[str, int]:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt.
@ -120,13 +122,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id,
self,
txn,
room_id,
user_id,
last_read_event_id,
):
stream_ordering = None
if last_read_event_id is not None:
stream_ordering = self.get_stream_id_for_event_txn(
txn, last_read_event_id, allow_none=True,
txn,
last_read_event_id,
allow_none=True,
)
if stream_ordering is None:


@ -399,7 +399,9 @@ class PersistEventsStore:
self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
def _persist_event_auth_chain_txn(
self, txn: LoggingTransaction, events: List[EventBase],
self,
txn: LoggingTransaction,
events: List[EventBase],
) -> None:
# We only care about state events, so this if there are no state events.
@ -470,7 +472,11 @@ class PersistEventsStore:
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
self._add_chain_cover_index(
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
txn,
self.db_pool,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
@classmethod
@ -517,7 +523,10 @@ class PersistEventsStore:
# simple_select_many, but this case happens rarely and almost always
# with a single row.)
auth_events = db_pool.simple_select_onecol_txn(
txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
txn,
"event_auth",
keyvalues={"event_id": event_id},
retcol="auth_id",
)
events_to_calc_chain_id_for.add(event_id)
@ -550,7 +559,9 @@ class PersistEventsStore:
WHERE
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", missing_auth_chains,
txn.database_engine,
"event_id",
missing_auth_chains,
)
txn.execute(sql + clause, args)
@ -704,7 +715,8 @@ class PersistEventsStore:
if chain_map[a_id][0] != chain_id
}
for start_auth_id, end_auth_id in itertools.permutations(
event_to_auth_chain.get(event_id, []), r=2,
event_to_auth_chain.get(event_id, []),
r=2,
):
if chain_links.exists_path_from(
chain_map[start_auth_id], chain_map[end_auth_id]
@ -888,8 +900,7 @@ class PersistEventsStore:
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
):
"""Persist the mapping from transaction IDs to event IDs (if defined).
"""
"""Persist the mapping from transaction IDs to event IDs (if defined)."""
to_insert = []
for event, _ in events_and_contexts:
@ -909,7 +920,9 @@ class PersistEventsStore:
if to_insert:
self.db_pool.simple_insert_many_txn(
txn, table="event_txn_id", values=to_insert,
txn,
table="event_txn_id",
values=to_insert,
)
def _update_current_state_txn(
@ -941,7 +954,9 @@ class PersistEventsStore:
txn.execute(sql, (stream_id, self._instance_name, room_id))
self.db_pool.simple_delete_txn(
txn, table="current_state_events", keyvalues={"room_id": room_id},
txn,
table="current_state_events",
keyvalues={"room_id": room_id},
)
else:
# We're still in the room, so we update the current state as normal.
@ -1608,8 +1623,7 @@ class PersistEventsStore:
)
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
"""Store a room member in the database."""
def str_or_none(val: Any) -> Optional[str]:
return val if isinstance(val, str) else None
@ -2001,8 +2015,7 @@ class PersistEventsStore:
@attr.s(slots=True)
class _LinkMap:
"""A helper type for tracking links between chains.
"""
"""A helper type for tracking links between chains."""
# Stores the set of links as nested maps: source chain ID -> target chain ID
# -> source sequence number -> target sequence number.
@ -2108,7 +2121,9 @@ class _LinkMap:
yield (src_chain, src_seq, target_chain, target_seq)
def exists_path_from(
self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
self,
src_tuple: Tuple[int, int],
target_tuple: Tuple[int, int],
) -> bool:
"""Checks if there is a path between the source chain ID/sequence and
target chain ID/sequence.


@ -32,8 +32,7 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True)
class _CalculateChainCover:
"""Return value for _calculate_chain_cover_txn.
"""
"""Return value for _calculate_chain_cover_txn."""
# The last room_id/depth/stream processed.
room_id = attr.ib(type=str)
@ -127,11 +126,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
self.db_pool.updates.register_background_update_handler(
"rejected_events_metadata", self._rejected_events_metadata,
"rejected_events_metadata",
self._rejected_events_metadata,
)
self.db_pool.updates.register_background_update_handler(
"chain_cover", self._chain_cover_index,
"chain_cover",
self._chain_cover_index,
)
async def _background_reindex_fields_sender(self, progress, batch_size):
@ -462,8 +463,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return num_handled
async def _redactions_received_ts(self, progress, batch_size):
"""Handles filling out the `received_ts` column in redactions.
"""
"""Handles filling out the `received_ts` column in redactions."""
last_event_id = progress.get("last_event_id", "")
def _redactions_received_ts_txn(txn):
@ -518,8 +518,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return count
async def _event_fix_redactions_bytes(self, progress, batch_size):
"""Undoes hex encoded censored redacted event JSON.
"""
"""Undoes hex encoded censored redacted event JSON."""
def _event_fix_redactions_bytes_txn(txn):
# This update is quite fast due to new index.
@ -642,7 +641,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_event_id, batch_size,))
txn.execute(
sql,
(
last_event_id,
batch_size,
),
)
return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
@ -910,7 +915,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
# Annoyingly we need to gut wrench into the persit event store so that
# we can reuse the function to calculate the chain cover for rooms.
PersistEventsStore._add_chain_cover_index(
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
txn,
self.db_pool,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
return _CalculateChainCover(


@ -71,7 +71,9 @@ class EventForwardExtremitiesStore(SQLBaseStore):
if txn.rowcount > 0:
# Invalidate the cache
self._invalidate_cache_and_stream(
txn, self.get_latest_event_ids_in_room, (room_id,),
txn,
self.get_latest_event_ids_in_room,
(room_id,),
)
return txn.rowcount
@ -97,5 +99,6 @@ class EventForwardExtremitiesStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
return await self.db_pool.runInteraction(
"get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
"get_forward_extremities_for_room",
get_forward_extremities_for_room_txn,
)


@ -120,7 +120,9 @@ class EventsWorkerStore(SQLBaseStore):
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
db_conn,
"events",
"stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
@ -140,7 +142,8 @@ class EventsWorkerStore(SQLBaseStore):
if hs.config.run_background_tasks:
# We periodically clean out old transaction ID mappings
self._clock.looping_call(
self._cleanup_old_transaction_ids, 5 * 60 * 1000,
self._cleanup_old_transaction_ids,
5 * 60 * 1000,
)
self._get_event_cache = LruCache(
@ -1325,8 +1328,7 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
"""Returns True if event_id1 is after event_id2 in the stream"""
to_1, so_1 = await self.get_event_ordering(event_id1)
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@ -1428,8 +1430,7 @@ class EventsWorkerStore(SQLBaseStore):
@wrap_as_background_process("_cleanup_old_transaction_ids")
async def _cleanup_old_transaction_ids(self):
"""Cleans out transaction id mappings older than 24hrs.
"""
"""Cleans out transaction id mappings older than 24hrs."""
def _cleanup_old_transaction_ids_txn(txn):
sql = """
@ -1440,5 +1441,6 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (one_day_ago,))
return await self.db_pool.runInteraction(
"_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
"_cleanup_old_transaction_ids",
_cleanup_old_transaction_ids_txn,
)


@ -123,7 +123,9 @@ class GroupServerWorkerStore(SQLBaseStore):
)
async def get_rooms_for_summary_by_category(
self, group_id: str, include_private: bool = False,
self,
group_id: str,
include_private: bool = False,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""Get the rooms and categories that should be included in a summary request
@ -368,8 +370,7 @@ class GroupServerWorkerStore(SQLBaseStore):
async def is_user_invited_to_local_group(
self, group_id: str, user_id: str
) -> Optional[bool]:
"""Has the group server invited a user?
"""
"""Has the group server invited a user?"""
return await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
@ -427,8 +428,7 @@ class GroupServerWorkerStore(SQLBaseStore):
)
async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
"""Get all groups a user is publicising
"""
"""Get all groups a user is publicising"""
return await self.db_pool.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
@ -437,8 +437,7 @@ class GroupServerWorkerStore(SQLBaseStore):
)
async def get_attestations_need_renewals(self, valid_until_ms):
"""Get all attestations that need to be renewed until givent time
"""
"""Get all attestations that need to be renewed until givent time"""
def _get_attestations_need_renewals_txn(txn):
sql = """
@ -781,8 +780,7 @@ class GroupServerStore(GroupServerWorkerStore):
profile: Optional[JsonDict],
is_public: Optional[bool],
) -> None:
"""Add/update room category for group
"""
"""Add/update room category for group"""
insertion_values = {}
update_values = {"category_id": category_id} # This cannot be empty
@ -818,8 +816,7 @@ class GroupServerStore(GroupServerWorkerStore):
profile: Optional[JsonDict],
is_public: Optional[bool],
) -> None:
"""Add/remove user role
"""
"""Add/remove user role"""
insertion_values = {}
update_values = {"role_id": role_id} # This cannot be empty
@ -1012,8 +1009,7 @@ class GroupServerStore(GroupServerWorkerStore):
)
async def add_group_invite(self, group_id: str, user_id: str) -> None:
"""Record that the group server has invited a user
"""
"""Record that the group server has invited a user"""
await self.db_pool.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
@ -1156,8 +1152,7 @@ class GroupServerStore(GroupServerWorkerStore):
async def update_group_publicity(
self, group_id: str, user_id: str, publicise: bool
) -> None:
"""Update whether the user is publicising their membership of the group
"""
"""Update whether the user is publicising their membership of the group"""
await self.db_pool.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
@ -1300,8 +1295,7 @@ class GroupServerStore(GroupServerWorkerStore):
async def update_attestation_renewal(
self, group_id: str, user_id: str, attestation: dict
) -> None:
"""Update an attestation that we have renewed
"""
"""Update an attestation that we have renewed"""
await self.db_pool.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
@ -1312,8 +1306,7 @@ class GroupServerStore(GroupServerWorkerStore):
async def update_remote_attestion(
self, group_id: str, user_id: str, attestation: dict
) -> None:
"""Update an attestation that a remote has renewed
"""
"""Update an attestation that a remote has renewed"""
await self.db_pool.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},


@ -33,8 +33,7 @@ db_binary_type = memoryview
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys
"""
"""Persistence for signature verification keys"""
@cached()
def _get_server_verify_key(self, server_name_and_key_id):


@ -169,7 +169,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_local_media_before(
self, before_ts: int, size_gt: int, keep_profiles: bool,
self,
before_ts: int,
size_gt: int,
keep_profiles: bool,
) -> List[str]:
# to find files that have never been accessed (last_access_ts IS NULL)
@ -454,10 +457,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_remote_media_thumbnail(
self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
self,
origin: str,
media_id: str,
t_width: int,
t_height: int,
t_type: str,
) -> Optional[Dict[str, Any]]:
"""Fetch the thumbnail info of given width, height and type.
"""
"""Fetch the thumbnail info of given width, height and type."""
return await self.db_pool.simple_select_one(
table="remote_media_cache_thumbnails",


@ -130,7 +130,9 @@ class PresenceStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
cached_method_name="_get_presence_for_user",
list_name="user_ids",
num_args=1,
)
async def get_presence_for_users(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(


@ -118,8 +118,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile.
"""
"""Check whether we are interested in a remote user's profile."""
res = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
@ -145,8 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked`
"""
"""Get all users who haven't been checked since `last_checked`"""
def _get_remote_profile_cache_entries_that_expire_txn(txn):
sql = """


@ -168,7 +168,9 @@ class PushRulesWorkerStore(
)
@cachedList(
cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
cached_method_name="get_push_rules_for_user",
list_name="user_ids",
num_args=1,
)
async def bulk_get_push_rules(self, user_ids):
if not user_ids:
@ -195,7 +197,9 @@ class PushRulesWorkerStore(
use_new_defaults = user_id in self._users_new_default_push_rules
results[user_id] = _load_rules(
rules, enabled_map_by_user.get(user_id, {}), use_new_defaults,
rules,
enabled_map_by_user.get(user_id, {}),
use_new_defaults,
)
return results


@ -179,7 +179,9 @@ class PusherWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
cached_method_name="get_if_user_has_pusher",
list_name="user_ids",
num_args=1,
)
async def get_if_users_have_pushers(
self, user_ids: Iterable[str]
@ -263,7 +265,8 @@ class PusherWorkerStore(SQLBaseStore):
params_by_room = {}
for row in res:
params_by_room[row["room_id"]] = ThrottleParams(
row["last_sent_ts"], row["throttle_ms"],
row["last_sent_ts"],
row["throttle_ms"],
)
return params_by_room


@ -208,8 +208,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
"""See get_linearized_receipts_for_room
"""
"""See get_linearized_receipts_for_room"""
def f(txn):
if from_key:
@ -304,7 +303,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
}
return results
@cached(num_args=2,)
@cached(
num_args=2,
)
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> Dict[str, JsonDict]:
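The `@cached(num_args=2,)` hunk above shows the trailing-comma rule at its least flattering: the comma alone forces the three-line decorator, so the one-line form could have been kept by deleting it. A hypothetical sketch with a stand-in decorator (Synapse's real `@cached` is more involved):

```python
def cached(num_args=1):
    # Stand-in decorator; illustrative only.
    def wrap(func):
        return func

    return wrap


@cached(num_args=2)  # no trailing comma: black keeps this on one line
async def get_receipts_compact(to_key, from_key=None):
    ...


@cached(
    num_args=2,
)  # trailing comma: black pins the exploded layout
async def get_receipts_exploded(to_key, from_key=None):
    ...
```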


@ -79,13 +79,16 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# call `find_max_generated_user_id_localpart` each time, which is
# expensive if there are many entries.
self._user_id_seq = build_sequence_generator(
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
database.engine,
find_max_generated_user_id_localpart,
"user_id_seq",
)
self._account_validity = hs.config.account_validity
if hs.config.run_background_tasks and self._account_validity.enabled:
self._clock.call_later(
0.0, self._set_expiration_date_when_missing,
0.0,
self._set_expiration_date_when_missing,
)
# Create a background job for culling expired 3PID validity tokens


@ -193,8 +193,7 @@ class RoomWorkerStore(SQLBaseStore):
)
async def get_room_count(self) -> int:
"""Retrieve the total number of rooms.
"""
"""Retrieve the total number of rooms."""
def f(txn):
sql = "SELECT count(*) FROM rooms"
@ -517,7 +516,8 @@ class RoomWorkerStore(SQLBaseStore):
return rooms, room_count[0]
return await self.db_pool.runInteraction(
"get_rooms_paginate", _get_rooms_paginate_txn,
"get_rooms_paginate",
_get_rooms_paginate_txn,
)
@cached(max_entries=10000)
@ -578,7 +578,8 @@ class RoomWorkerStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
ret = await self.db_pool.runInteraction(
"get_retention_policy_for_room", get_retention_policy_for_room_txn,
"get_retention_policy_for_room",
get_retention_policy_for_room_txn,
)
# If we don't know this room ID, ret will be None, in this case return the default
@ -707,7 +708,10 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
async def quarantine_media_by_id(
self, server_name: str, media_id: str, quarantined_by: str,
self,
server_name: str,
media_id: str,
quarantined_by: str,
) -> int:
"""quarantines a single local or remote media id
@ -961,7 +965,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self.config = hs.config
self.db_pool.updates.register_background_update_handler(
"insert_room_retention", self._background_insert_retention,
"insert_room_retention",
self._background_insert_retention,
)
self.db_pool.updates.register_background_update_handler(
@ -1033,7 +1038,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return False
end = await self.db_pool.runInteraction(
"insert_room_retention", _background_insert_retention_txn,
"insert_room_retention",
_background_insert_retention_txn,
)
if end:
@ -1588,7 +1594,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
LIMIT ?
OFFSET ?
""".format(
where_clause=where_clause, order=order,
where_clause=where_clause,
order=order,
)
args += [limit, start]


@ -70,10 +70,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
):
self._known_servers_count = 1
self.hs.get_clock().looping_call(
self._count_known_servers, 60 * 1000,
self._count_known_servers,
60 * 1000,
)
self.hs.get_clock().call_later(
1000, self._count_known_servers,
1000,
self._count_known_servers,
)
LaterGauge(
"synapse_federation_known_servers",
@ -174,7 +176,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=100000)
async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
""" Get the details of a room roughly suitable for use by the room
"""Get the details of a room roughly suitable for use by the room
summary extension to /sync. Useful when lazy loading room members.
Args:
room_id: The room ID to query
@ -488,8 +490,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_users_who_share_room_with_user(
self, user_id: str, cache_context: _CacheContext
) -> Set[str]:
"""Returns the set of users who share a room with `user_id`
"""
"""Returns the set of users who share a room with `user_id`"""
room_ids = await self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate
)
@ -618,7 +619,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
)
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
@ -802,8 +804,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_membership_from_event_ids(
self, member_event_ids: Iterable[str]
) -> List[dict]:
"""Get user_id and membership of a set of event IDs.
"""
"""Get user_id and membership of a set of event IDs."""
return await self.db_pool.simple_select_many_batch(
table="room_memberships",


@ -23,5 +23,6 @@ def run_create(cur, database_engine, *args, **kwargs):
def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute(
"UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
"UPDATE remote_media_cache SET last_access_ts = ?",
(int(time.time() * 1000),),
)


@ -52,8 +52,7 @@ class _GetStateGroupDelta(
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.
"""
"""The parts of StateGroupStore that can be called from workers."""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@ -276,8 +275,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
num_args=1,
)
async def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
"""Returns mapping event_id -> state_group"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
@ -338,7 +336,8 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
columns=["state_group"],
)
self.db_pool.updates.register_background_update_handler(
self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms,
self.DELETE_CURRENT_STATE_UPDATE_NAME,
self._background_remove_left_rooms,
)
async def _background_remove_left_rooms(self, progress, batch_size):
@ -487,7 +486,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
""" Keeps track of the state at a given event.
"""Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a


@ -1001,7 +1001,9 @@ class StatsStore(StateDeltasStore):
ORDER BY {order_by_column} {order}
LIMIT ? OFFSET ?
""".format(
sql_base=sql_base, order_by_column=order_by_column, order=order,
sql_base=sql_base,
order_by_column=order_by_column,
order=order,
)
args += [limit, start]


@ -565,7 +565,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
AND e.stream_ordering > ? AND e.stream_ordering <= ?
ORDER BY e.stream_ordering ASC
"""
txn.execute(sql, (user_id, min_from_id, max_to_id,))
txn.execute(
sql,
(
user_id,
min_from_id,
max_to_id,
),
)
rows = [
_EventDictReturn(event_id, None, stream_ordering)
@ -695,7 +702,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
return "t%d-%d" % (topo, token)
def get_stream_id_for_event_txn(
self, txn: LoggingTransaction, event_id: str, allow_none=False,
self,
txn: LoggingTransaction,
event_id: str,
allow_none=False,
) -> int:
return self.db_pool.simple_select_one_onecol_txn(
txn=txn,
@ -706,8 +716,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
)
async def get_position_for_event(self, event_id: str) -> PersistedEventPosition:
"""Get the persisted position for an event
"""
"""Get the persisted position for an event"""
row = await self.db_pool.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
@ -897,19 +906,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
) -> Tuple[int, List[EventBase]]:
"""Get all new events
Returns all events with from_id < stream_ordering <= current_id.
Returns all events with from_id < stream_ordering <= current_id.
Args:
from_id: the stream_ordering of the last event we processed
current_id: the stream_ordering of the most recently processed event
limit: the maximum number of events to return
Args:
from_id: the stream_ordering of the last event we processed
current_id: the stream_ordering of the most recently processed event
limit: the maximum number of events to return
Returns:
A tuple of (next_id, events), where `next_id` is the next value to
pass as `from_id` (it will either be the stream_ordering of the
last returned event, or, if fewer than `limit` events were found,
the `current_id`).
"""
Returns:
A tuple of (next_id, events), where `next_id` is the next value to
pass as `from_id` (it will either be the stream_ordering of the
last returned event, or, if fewer than `limit` events were found,
the `current_id`).
"""
def get_all_new_events_stream_txn(txn):
sql = (
@ -1238,8 +1247,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
@cached()
async def get_id_for_instance(self, instance_name: str) -> int:
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance.
"""
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
def _get_id_for_instance_txn(txn):
instance_id = self.db_pool.simple_select_one_onecol_txn(
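The `get_all_new_events` hunk above looks like a duplicated docstring, but the change is whitespace-only: the newer black also re-indents docstring bodies so they sit flush with the opening quotes. Roughly (stand-in functions with a trimmed docstring):

```python
def get_all_new_events_before():
    """Get all new events

        Returns all events with from_id < stream_ordering <= current_id.
        """


def get_all_new_events_after():
    """Get all new events

    Returns all events with from_id < stream_ordering <= current_id.
    """
```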


@ -64,8 +64,7 @@ class TransactionWorkerStore(SQLBaseStore):
class TransactionStore(TransactionWorkerStore):
"""A collection of queries for handling PDUs.
"""
"""A collection of queries for handling PDUs."""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@ -299,7 +298,10 @@ class TransactionStore(TransactionWorkerStore):
)
async def store_destination_rooms_entries(
self, destinations: Iterable[str], room_id: str, stream_ordering: int,
self,
destinations: Iterable[str],
room_id: str,
stream_ordering: int,
) -> None:
"""
Updates or creates `destination_rooms` entries in batch for a single event.
@ -394,7 +396,9 @@ class TransactionStore(TransactionWorkerStore):
)
async def get_catch_up_room_event_ids(
self, destination: str, last_successful_stream_ordering: int,
self,
destination: str,
last_successful_stream_ordering: int,
) -> List[str]:
"""
Returns at most 50 event IDs and their corresponding stream_orderings
@ -418,7 +422,9 @@ class TransactionStore(TransactionWorkerStore):
@staticmethod
def _get_catch_up_room_event_ids_txn(
txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
txn: LoggingTransaction,
destination: str,
last_successful_stream_ordering: int,
) -> List[str]:
q = """
SELECT event_id FROM destination_rooms
@ -429,7 +435,8 @@ class TransactionStore(TransactionWorkerStore):
LIMIT 50
"""
txn.execute(
q, (destination, last_successful_stream_ordering),
q,
(destination, last_successful_stream_ordering),
)
event_ids = [row[0] for row in txn]
return event_ids


@ -44,7 +44,11 @@ class UIAuthWorkerStore(SQLBaseStore):
"""
async def create_ui_auth_session(
self, clientdict: JsonDict, uri: str, method: str, description: str,
self,
clientdict: JsonDict,
uri: str,
method: str,
description: str,
) -> UIAuthSessionData:
"""
Creates a new user interactive authentication session.
@ -123,7 +127,10 @@ class UIAuthWorkerStore(SQLBaseStore):
return UIAuthSessionData(session_id, **result)
async def mark_ui_auth_stage_complete(
self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
self,
session_id: str,
stage_type: str,
result: Union[str, bool, JsonDict],
):
"""
Mark a session stage as completed.
@ -261,10 +268,12 @@ class UIAuthWorkerStore(SQLBaseStore):
return serverdict.get(key, default)
async def add_user_agent_ip_to_ui_auth_session(
self, session_id: str, user_agent: str, ip: str,
self,
session_id: str,
user_agent: str,
ip: str,
):
"""Add the given user agent / IP to the tracking table
"""
"""Add the given user agent / IP to the tracking table"""
await self.db_pool.simple_upsert(
table="ui_auth_sessions_ips",
keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
@ -273,7 +282,8 @@ class UIAuthWorkerStore(SQLBaseStore):
)
async def get_user_agents_ips_to_ui_auth_session(
self, session_id: str,
self,
session_id: str,
) -> List[Tuple[str, str]]:
"""Get the given user agents / IPs used during the ui auth process


@ -336,8 +336,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return len(users_to_work_on)
async def is_room_world_readable_or_publicly_joinable(self, room_id):
"""Check if the room is either world_readable or publically joinable
"""
"""Check if the room is either world_readable or publically joinable"""
# Create a state filter that only queries join and history state event
types_to_filter = (
@ -516,8 +515,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory
"""
"""Delete the entire user directory"""
def _delete_all_from_user_dir_txn(txn):
txn.execute("DELETE FROM user_directory")


@ -48,8 +48,7 @@ class _GetStateGroupDelta(
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""A data store for fetching/storing state groups.
"""
"""A data store for fetching/storing state groups."""
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@ -89,7 +88,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
50000,
)
self._state_group_members_cache = DictionaryCache(
"*stateGroupMembersCache*", 500000,
"*stateGroupMembersCache*",
500000,
)
def get_max_state_group_txn(txn: Cursor):


@ -94,14 +94,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
@property
@abc.abstractmethod
def server_version(self) -> str:
"""Gets a string giving the server version. For example: '3.22.0'
"""
"""Gets a string giving the server version. For example: '3.22.0'"""
...
@abc.abstractmethod
def in_transaction(self, conn: Connection) -> bool:
"""Whether the connection is currently in a transaction.
"""
"""Whether the connection is currently in a transaction."""
...
@abc.abstractmethod


@ -138,8 +138,7 @@ class PostgresEngine(BaseDatabaseEngine):
@property
def supports_using_any_list(self):
"""Do we support using `a = ANY(?)` and passing a list
"""
"""Do we support using `a = ANY(?)` and passing a list"""
return True
def is_deadlock(self, error):


@ -29,7 +29,10 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
super().__init__(database_module, database_config)
database = database_config.get("args", {}).get("database")
self._is_in_memory = database in (None, ":memory:",)
self._is_in_memory = database in (
None,
":memory:",
)
if platform.python_implementation() == "PyPy":
# pypy's sqlite3 module doesn't handle bytearrays, convert them
@ -63,8 +66,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
@property
def supports_using_any_list(self):
"""Do we support using `a = ANY(?)` and passing a list
"""
"""Do we support using `a = ANY(?)` and passing a list"""
return False
def check_database(self, db_conn, allow_outdated_version: bool = False):


@ -411,8 +411,8 @@ class EventsPersistenceStorage:
)
for room_id, ev_ctx_rm in events_by_room.items():
latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
room_id
latest_event_ids = (
await self.main_store.get_latest_event_ids_in_room(room_id)
)
new_latest_event_ids = await self._calculate_new_extremities(
room_id, ev_ctx_rm, latest_event_ids
@ -889,7 +889,8 @@ class EventsPersistenceStorage:
continue
logger.debug(
"Not dropping as too new and not in new_senders: %s", new_senders,
"Not dropping as too new and not in new_senders: %s",
new_senders,
)
return new_latest_event_ids
@ -1004,7 +1005,10 @@ class EventsPersistenceStorage:
remote_event_ids = [
event_id
for (typ, state_key,), event_id in current_state.items()
for (
typ,
state_key,
), event_id in current_state.items()
if typ == EventTypes.Member and not self.is_mine_id(state_key)
]
rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)


@ -425,7 +425,10 @@ def _upgrade_existing_database(
# We don't support using the same file name in the same delta version.
raise PrepareDatabaseException(
"Found multiple delta files with the same name in v%d: %s"
% (v, duplicates,)
% (
v,
duplicates,
)
)
# We sort to ensure that we apply the delta files in a consistent
@ -532,7 +535,8 @@ def _apply_module_schema_files(
names_and_streams: the names and streams of schemas to be applied
"""
cur.execute(
"SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
"SELECT file FROM applied_module_schemas WHERE module_name = ?",
(modname,),
)
applied_deltas = {d for d, in cur}
for (name, stream) in names_and_streams:


@ -26,15 +26,13 @@ logger = logging.getLogger(__name__)
class PurgeEventsStorage:
"""High level interface for purging rooms and event history.
"""
"""High level interface for purging rooms and event history."""
def __init__(self, hs: "HomeServer", stores: Databases):
self.stores = stores
async def purge_room(self, room_id: str) -> None:
"""Deletes all record of a room
"""
"""Deletes all record of a room"""
state_groups_to_delete = await self.stores.main.purge_room(room_id)
await self.stores.state.purge_room_state(room_id, state_groups_to_delete)


@ -340,8 +340,7 @@ class StateFilter:
class StateGroupStorage:
"""High level interface to fetching state for event.
"""
"""High level interface to fetching state for event."""
def __init__(self, hs: "HomeServer", stores: "Databases"):
self.stores = stores
@ -400,7 +399,7 @@ class StateGroupStorage:
async def get_state_groups(
self, room_id: str, event_ids: Iterable[str]
) -> Dict[int, List[EventBase]]:
""" Get the state groups for the given list of event_ids
"""Get the state groups for the given list of event_ids
Args:
room_id: ID of the room for these events.


@ -277,7 +277,9 @@ class MultiWriterIdGenerator:
self._load_current_ids(db_conn, tables)
def _load_current_ids(
self, db_conn, tables: List[Tuple[str, str, str]],
self,
db_conn,
tables: List[Tuple[str, str, str]],
):
cur = db_conn.cursor(txn_name="_load_current_ids")
@ -364,7 +366,10 @@ class MultiWriterIdGenerator:
rows.sort()
with self._lock:
for (instance, stream_id,) in rows:
for (
instance,
stream_id,
) in rows:
stream_id = self._return_factor * stream_id
self._add_persisted_position(stream_id)
@ -481,8 +486,7 @@ class MultiWriterIdGenerator:
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
"""
"""Returns the position of the given writer."""
# If we don't have an entry for the given instance name, we assume it's a
# new writer.
@ -581,8 +585,7 @@ class MultiWriterIdGenerator:
break
def _update_stream_positions_table_txn(self, txn: Cursor):
"""Update the `stream_positions` table with newly persisted position.
"""
"""Update the `stream_positions` table with newly persisted position."""
if not self._writers:
return
@ -622,8 +625,7 @@ class _AsyncCtxManagerWrapper:
@attr.s(slots=True)
class _MultiWriterCtxManager:
"""Async context manager returned by MultiWriterIdGenerator
"""
"""Async context manager returned by MultiWriterIdGenerator"""
id_gen = attr.ib(type=MultiWriterIdGenerator)
multiple_ids = attr.ib(type=Optional[int], default=None)


@ -124,8 +124,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
stream_name: Optional[str] = None,
positive: bool = True,
):
"""See SequenceGenerator.check_consistency for docstring.
"""
"""See SequenceGenerator.check_consistency for docstring."""
txn = db_conn.cursor(txn_name="sequence.check_consistency")