mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-08-01 15:36:04 -04:00
Update black, and run auto formatting over the codebase (#9381)
- Update black version to the latest
- Run black auto formatting over the codebase
- Run autoformatting according to [`docs/code_style.md
`](80d6dc9783/docs/code_style.md
)
- Update `code_style.md` docs around installing black to use the correct version
This commit is contained in:
parent
5636e597c3
commit
0a00b7ff14
271 changed files with 2802 additions and 1713 deletions
|
@ -43,8 +43,7 @@ __all__ = ["Databases", "DataStore"]
|
|||
|
||||
|
||||
class Storage:
|
||||
"""The high level interfaces for talking to various storage layers.
|
||||
"""
|
||||
"""The high level interfaces for talking to various storage layers."""
|
||||
|
||||
def __init__(self, hs: "HomeServer", stores: Databases):
|
||||
# We include the main data store here mainly so that we don't have to
|
||||
|
|
|
@ -77,7 +77,7 @@ class BackgroundUpdatePerformance:
|
|||
|
||||
|
||||
class BackgroundUpdater:
|
||||
""" Background updates are updates to the database that run in the
|
||||
"""Background updates are updates to the database that run in the
|
||||
background. Each update processes a batch of data at once. We attempt to
|
||||
limit the impact of each update by monitoring how long each batch takes to
|
||||
process and autotuning the batch size.
|
||||
|
@ -158,8 +158,7 @@ class BackgroundUpdater:
|
|||
return False
|
||||
|
||||
async def has_completed_background_update(self, update_name: str) -> bool:
|
||||
"""Check if the given background update has finished running.
|
||||
"""
|
||||
"""Check if the given background update has finished running."""
|
||||
if self._all_done:
|
||||
return True
|
||||
|
||||
|
@ -198,7 +197,8 @@ class BackgroundUpdater:
|
|||
|
||||
if not self._current_background_update:
|
||||
all_pending_updates = await self.db_pool.runInteraction(
|
||||
"background_updates", get_background_updates_txn,
|
||||
"background_updates",
|
||||
get_background_updates_txn,
|
||||
)
|
||||
if not all_pending_updates:
|
||||
# no work left to do
|
||||
|
|
|
@ -85,8 +85,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
|
|||
def make_pool(
|
||||
reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
|
||||
) -> adbapi.ConnectionPool:
|
||||
"""Get the connection pool for the database.
|
||||
"""
|
||||
"""Get the connection pool for the database."""
|
||||
|
||||
# By default enable `cp_reconnect`. We need to fiddle with db_args in case
|
||||
# someone has explicitly set `cp_reconnect`.
|
||||
|
@ -432,8 +431,7 @@ class DatabasePool:
|
|||
)
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Is the database pool currently running
|
||||
"""
|
||||
"""Is the database pool currently running"""
|
||||
return self._db_pool.running
|
||||
|
||||
async def _check_safe_to_upsert(self) -> None:
|
||||
|
@ -546,7 +544,11 @@ class DatabasePool:
|
|||
# This can happen if the database disappears mid
|
||||
# transaction.
|
||||
transaction_logger.warning(
|
||||
"[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
|
||||
"[TXN OPERROR] {%s} %s %d/%d",
|
||||
name,
|
||||
e,
|
||||
i,
|
||||
N,
|
||||
)
|
||||
if i < N:
|
||||
i += 1
|
||||
|
@ -567,7 +569,9 @@ class DatabasePool:
|
|||
conn.rollback()
|
||||
except self.engine.module.Error as e1:
|
||||
transaction_logger.warning(
|
||||
"[TXN EROLL] {%s} %s", name, e1,
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name,
|
||||
e1,
|
||||
)
|
||||
continue
|
||||
raise
|
||||
|
@ -1406,7 +1410,10 @@ class DatabasePool:
|
|||
|
||||
@staticmethod
|
||||
def simple_select_onecol_txn(
|
||||
txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: str,
|
||||
) -> List[Any]:
|
||||
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
|
||||
|
||||
|
@ -1716,7 +1723,11 @@ class DatabasePool:
|
|||
desc: description of the transaction, for logging and metrics
|
||||
"""
|
||||
await self.runInteraction(
|
||||
desc, self.simple_delete_one_txn, table, keyvalues, db_autocommit=True,
|
||||
desc,
|
||||
self.simple_delete_one_txn,
|
||||
table,
|
||||
keyvalues,
|
||||
db_autocommit=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -56,7 +56,10 @@ class Databases:
|
|||
database_config.databases,
|
||||
)
|
||||
prepare_database(
|
||||
db_conn, engine, hs.config, databases=database_config.databases,
|
||||
db_conn,
|
||||
engine,
|
||||
hs.config,
|
||||
databases=database_config.databases,
|
||||
)
|
||||
|
||||
database = DatabasePool(hs, database_config, engine)
|
||||
|
|
|
@ -73,8 +73,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
|
|||
return self.services_cache
|
||||
|
||||
def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
|
||||
"""Check if the user is one associated with an app service (exclusively)
|
||||
"""
|
||||
"""Check if the user is one associated with an app service (exclusively)"""
|
||||
if self.exclusive_user_regex:
|
||||
return bool(self.exclusive_user_regex.match(user_id))
|
||||
else:
|
||||
|
|
|
@ -280,8 +280,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
|
|||
return batch_size
|
||||
|
||||
async def _devices_last_seen_update(self, progress, batch_size):
|
||||
"""Background update to insert last seen info into devices table
|
||||
"""
|
||||
"""Background update to insert last seen info into devices table"""
|
||||
|
||||
last_user_id = progress.get("last_user_id", "")
|
||||
last_device_id = progress.get("last_device_id", "")
|
||||
|
@ -363,8 +362,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
|
|||
|
||||
@wrap_as_background_process("prune_old_user_ips")
|
||||
async def _prune_old_user_ips(self):
|
||||
"""Removes entries in user IPs older than the configured period.
|
||||
"""
|
||||
"""Removes entries in user IPs older than the configured period."""
|
||||
|
||||
if self.user_ips_max_age is None:
|
||||
# Nothing to do
|
||||
|
@ -565,7 +563,11 @@ class ClientIpStore(ClientIpWorkerStore):
|
|||
results = {}
|
||||
|
||||
for key in self._batch_row_update:
|
||||
uid, access_token, ip, = key
|
||||
(
|
||||
uid,
|
||||
access_token,
|
||||
ip,
|
||||
) = key
|
||||
if uid == user_id:
|
||||
user_agent, _, last_seen = self._batch_row_update[key]
|
||||
results[(access_token, ip)] = (user_agent, last_seen)
|
||||
|
|
|
@ -315,7 +315,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
|
||||
# make sure we go through the devices in stream order
|
||||
device_ids = sorted(
|
||||
user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
|
||||
user_devices.keys(),
|
||||
key=lambda i: query_map[(user_id, i)][0],
|
||||
)
|
||||
|
||||
for device_id in device_ids:
|
||||
|
@ -366,8 +367,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
async def mark_as_sent_devices_by_remote(
|
||||
self, destination: str, stream_id: int
|
||||
) -> None:
|
||||
"""Mark that updates have successfully been sent to the destination.
|
||||
"""
|
||||
"""Mark that updates have successfully been sent to the destination."""
|
||||
await self.db_pool.runInteraction(
|
||||
"mark_as_sent_devices_by_remote",
|
||||
self._mark_as_sent_devices_by_remote_txn,
|
||||
|
@ -681,7 +681,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
return results
|
||||
|
||||
async def get_user_ids_requiring_device_list_resync(
|
||||
self, user_ids: Optional[Collection[str]] = None,
|
||||
self,
|
||||
user_ids: Optional[Collection[str]] = None,
|
||||
) -> Set[str]:
|
||||
"""Given a list of remote users return the list of users that we
|
||||
should resync the device lists for. If None is given instead of a list,
|
||||
|
@ -721,8 +722,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
|
||||
"""Mark that we no longer track device lists for remote user.
|
||||
"""
|
||||
"""Mark that we no longer track device lists for remote user."""
|
||||
|
||||
def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
|
||||
self.db_pool.simple_delete_txn(
|
||||
|
@ -902,7 +902,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
logger.info("Pruned %d device list outbound pokes", count)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"_prune_old_outbound_device_pokes", _prune_txn,
|
||||
"_prune_old_outbound_device_pokes",
|
||||
_prune_txn,
|
||||
)
|
||||
|
||||
|
||||
|
@ -943,7 +944,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||
|
||||
# clear out duplicate device list outbound pokes
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
|
||||
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
|
||||
self._remove_duplicate_outbound_pokes,
|
||||
)
|
||||
|
||||
# a pair of background updates that were added during the 1.14 release cycle,
|
||||
|
@ -1004,17 +1006,23 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||
row = None
|
||||
for row in rows:
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
|
||||
txn,
|
||||
"device_lists_outbound_pokes",
|
||||
{x: row[x] for x in KEY_COLS},
|
||||
)
|
||||
|
||||
row["sent"] = False
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn, "device_lists_outbound_pokes", row,
|
||||
txn,
|
||||
"device_lists_outbound_pokes",
|
||||
row,
|
||||
)
|
||||
|
||||
if row:
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
|
||||
txn,
|
||||
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
|
||||
{"last_row": row},
|
||||
)
|
||||
|
||||
return len(rows)
|
||||
|
@ -1286,7 +1294,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
# we've done a full resync, so we remove the entry that says we need
|
||||
# to resync
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
|
||||
txn,
|
||||
table="device_lists_remote_resync",
|
||||
keyvalues={"user_id": user_id},
|
||||
)
|
||||
|
||||
async def add_device_change_to_streams(
|
||||
|
@ -1336,7 +1346,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
stream_ids: List[str],
|
||||
):
|
||||
txn.call_after(
|
||||
self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
|
||||
self._device_list_stream_cache.entity_has_changed,
|
||||
user_id,
|
||||
stream_ids[-1],
|
||||
)
|
||||
|
||||
min_stream_id = stream_ids[0]
|
||||
|
|
|
@ -85,7 +85,7 @@ class DirectoryStore(DirectoryWorkerStore):
|
|||
servers: Iterable[str],
|
||||
creator: Optional[str] = None,
|
||||
) -> None:
|
||||
""" Creates an association between a room alias and room_id/servers
|
||||
"""Creates an association between a room alias and room_id/servers
|
||||
|
||||
Args:
|
||||
room_alias: The alias to create.
|
||||
|
@ -160,7 +160,10 @@ class DirectoryStore(DirectoryWorkerStore):
|
|||
return room_id
|
||||
|
||||
async def update_aliases_for_room(
|
||||
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
|
||||
self,
|
||||
old_room_id: str,
|
||||
new_room_id: str,
|
||||
creator: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Repoint all of the aliases for a given room, to a different room.
|
||||
|
||||
|
|
|
@ -361,7 +361,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
async def count_e2e_one_time_keys(
|
||||
self, user_id: str, device_id: str
|
||||
) -> Dict[str, int]:
|
||||
""" Count the number of one time keys the server has for a device
|
||||
"""Count the number of one time keys the server has for a device
|
||||
Returns:
|
||||
A mapping from algorithm to number of keys for that algorithm.
|
||||
"""
|
||||
|
@ -494,7 +494,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
)
|
||||
|
||||
def _get_bare_e2e_cross_signing_keys_bulk_txn(
|
||||
self, txn: Connection, user_ids: List[str],
|
||||
self,
|
||||
txn: Connection,
|
||||
user_ids: List[str],
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
"""Returns the cross-signing keys for a set of users. The output of this
|
||||
function should be passed to _get_e2e_cross_signing_signatures_txn if
|
||||
|
@ -556,7 +558,10 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
|
|||
return result
|
||||
|
||||
def _get_e2e_cross_signing_signatures_txn(
|
||||
self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
|
||||
self,
|
||||
txn: Connection,
|
||||
keys: Dict[str, Dict[str, dict]],
|
||||
from_user_id: str,
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
"""Returns the cross-signing signatures made by a user on a set of keys.
|
||||
|
||||
|
|
|
@ -71,7 +71,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
return await self.get_events_as_list(event_ids)
|
||||
|
||||
async def get_auth_chain_ids(
|
||||
self, event_ids: Collection[str], include_given: bool = False,
|
||||
self,
|
||||
event_ids: Collection[str],
|
||||
include_given: bool = False,
|
||||
) -> List[str]:
|
||||
"""Get auth events for given event_ids. The events *must* be state events.
|
||||
|
||||
|
@ -273,7 +275,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
# origin chain.
|
||||
if origin_sequence_number <= chains.get(origin_chain_id, 0):
|
||||
chains[target_chain_id] = max(
|
||||
target_sequence_number, chains.get(target_chain_id, 0),
|
||||
target_sequence_number,
|
||||
chains.get(target_chain_id, 0),
|
||||
)
|
||||
|
||||
seen_chains.add(target_chain_id)
|
||||
|
@ -632,8 +635,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
)
|
||||
|
||||
async def get_min_depth(self, room_id: str) -> int:
|
||||
"""For the given room, get the minimum depth we have seen for it.
|
||||
"""
|
||||
"""For the given room, get the minimum depth we have seen for it."""
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_min_depth", self._get_min_depth_interaction, room_id
|
||||
)
|
||||
|
@ -858,12 +860,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"_delete_old_forward_extrem_cache", _delete_old_forward_extrem_cache_txn,
|
||||
"_delete_old_forward_extrem_cache",
|
||||
_delete_old_forward_extrem_cache_txn,
|
||||
)
|
||||
|
||||
|
||||
class EventFederationStore(EventFederationWorkerStore):
|
||||
""" Responsible for storing and serving up the various graphs associated
|
||||
"""Responsible for storing and serving up the various graphs associated
|
||||
with an event. Including the main event graph and the auth chains for an
|
||||
event.
|
||||
|
||||
|
|
|
@ -54,8 +54,7 @@ def _serialize_action(actions, is_highlight):
|
|||
|
||||
|
||||
def _deserialize_action(actions, is_highlight):
|
||||
"""Custom deserializer for actions. This allows us to "compress" common actions
|
||||
"""
|
||||
"""Custom deserializer for actions. This allows us to "compress" common actions"""
|
||||
if actions:
|
||||
return db_to_json(actions)
|
||||
|
||||
|
@ -91,7 +90,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
|
||||
@cached(num_args=3, tree=True, max_entries=5000)
|
||||
async def get_unread_event_push_actions_by_room_for_user(
|
||||
self, room_id: str, user_id: str, last_read_event_id: Optional[str],
|
||||
self,
|
||||
room_id: str,
|
||||
user_id: str,
|
||||
last_read_event_id: Optional[str],
|
||||
) -> Dict[str, int]:
|
||||
"""Get the notification count, the highlight count and the unread message count
|
||||
for a given user in a given room after the given read receipt.
|
||||
|
@ -120,13 +122,19 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
def _get_unread_counts_by_receipt_txn(
|
||||
self, txn, room_id, user_id, last_read_event_id,
|
||||
self,
|
||||
txn,
|
||||
room_id,
|
||||
user_id,
|
||||
last_read_event_id,
|
||||
):
|
||||
stream_ordering = None
|
||||
|
||||
if last_read_event_id is not None:
|
||||
stream_ordering = self.get_stream_id_for_event_txn(
|
||||
txn, last_read_event_id, allow_none=True,
|
||||
txn,
|
||||
last_read_event_id,
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if stream_ordering is None:
|
||||
|
|
|
@ -399,7 +399,9 @@ class PersistEventsStore:
|
|||
self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
|
||||
|
||||
def _persist_event_auth_chain_txn(
|
||||
self, txn: LoggingTransaction, events: List[EventBase],
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events: List[EventBase],
|
||||
) -> None:
|
||||
|
||||
# We only care about state events, so this if there are no state events.
|
||||
|
@ -470,7 +472,11 @@ class PersistEventsStore:
|
|||
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
|
||||
|
||||
self._add_chain_cover_index(
|
||||
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
|
||||
txn,
|
||||
self.db_pool,
|
||||
event_to_room_id,
|
||||
event_to_types,
|
||||
event_to_auth_chain,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
@ -517,7 +523,10 @@ class PersistEventsStore:
|
|||
# simple_select_many, but this case happens rarely and almost always
|
||||
# with a single row.)
|
||||
auth_events = db_pool.simple_select_onecol_txn(
|
||||
txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
|
||||
txn,
|
||||
"event_auth",
|
||||
keyvalues={"event_id": event_id},
|
||||
retcol="auth_id",
|
||||
)
|
||||
|
||||
events_to_calc_chain_id_for.add(event_id)
|
||||
|
@ -550,7 +559,9 @@ class PersistEventsStore:
|
|||
WHERE
|
||||
"""
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "event_id", missing_auth_chains,
|
||||
txn.database_engine,
|
||||
"event_id",
|
||||
missing_auth_chains,
|
||||
)
|
||||
txn.execute(sql + clause, args)
|
||||
|
||||
|
@ -704,7 +715,8 @@ class PersistEventsStore:
|
|||
if chain_map[a_id][0] != chain_id
|
||||
}
|
||||
for start_auth_id, end_auth_id in itertools.permutations(
|
||||
event_to_auth_chain.get(event_id, []), r=2,
|
||||
event_to_auth_chain.get(event_id, []),
|
||||
r=2,
|
||||
):
|
||||
if chain_links.exists_path_from(
|
||||
chain_map[start_auth_id], chain_map[end_auth_id]
|
||||
|
@ -888,8 +900,7 @@ class PersistEventsStore:
|
|||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
):
|
||||
"""Persist the mapping from transaction IDs to event IDs (if defined).
|
||||
"""
|
||||
"""Persist the mapping from transaction IDs to event IDs (if defined)."""
|
||||
|
||||
to_insert = []
|
||||
for event, _ in events_and_contexts:
|
||||
|
@ -909,7 +920,9 @@ class PersistEventsStore:
|
|||
|
||||
if to_insert:
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn, table="event_txn_id", values=to_insert,
|
||||
txn,
|
||||
table="event_txn_id",
|
||||
values=to_insert,
|
||||
)
|
||||
|
||||
def _update_current_state_txn(
|
||||
|
@ -941,7 +954,9 @@ class PersistEventsStore:
|
|||
txn.execute(sql, (stream_id, self._instance_name, room_id))
|
||||
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="current_state_events", keyvalues={"room_id": room_id},
|
||||
txn,
|
||||
table="current_state_events",
|
||||
keyvalues={"room_id": room_id},
|
||||
)
|
||||
else:
|
||||
# We're still in the room, so we update the current state as normal.
|
||||
|
@ -1608,8 +1623,7 @@ class PersistEventsStore:
|
|||
)
|
||||
|
||||
def _store_room_members_txn(self, txn, events, backfilled):
|
||||
"""Store a room member in the database.
|
||||
"""
|
||||
"""Store a room member in the database."""
|
||||
|
||||
def str_or_none(val: Any) -> Optional[str]:
|
||||
return val if isinstance(val, str) else None
|
||||
|
@ -2001,8 +2015,7 @@ class PersistEventsStore:
|
|||
|
||||
@attr.s(slots=True)
|
||||
class _LinkMap:
|
||||
"""A helper type for tracking links between chains.
|
||||
"""
|
||||
"""A helper type for tracking links between chains."""
|
||||
|
||||
# Stores the set of links as nested maps: source chain ID -> target chain ID
|
||||
# -> source sequence number -> target sequence number.
|
||||
|
@ -2108,7 +2121,9 @@ class _LinkMap:
|
|||
yield (src_chain, src_seq, target_chain, target_seq)
|
||||
|
||||
def exists_path_from(
|
||||
self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
|
||||
self,
|
||||
src_tuple: Tuple[int, int],
|
||||
target_tuple: Tuple[int, int],
|
||||
) -> bool:
|
||||
"""Checks if there is a path between the source chain ID/sequence and
|
||||
target chain ID/sequence.
|
||||
|
|
|
@ -32,8 +32,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class _CalculateChainCover:
|
||||
"""Return value for _calculate_chain_cover_txn.
|
||||
"""
|
||||
"""Return value for _calculate_chain_cover_txn."""
|
||||
|
||||
# The last room_id/depth/stream processed.
|
||||
room_id = attr.ib(type=str)
|
||||
|
@ -127,11 +126,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"rejected_events_metadata", self._rejected_events_metadata,
|
||||
"rejected_events_metadata",
|
||||
self._rejected_events_metadata,
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"chain_cover", self._chain_cover_index,
|
||||
"chain_cover",
|
||||
self._chain_cover_index,
|
||||
)
|
||||
|
||||
async def _background_reindex_fields_sender(self, progress, batch_size):
|
||||
|
@ -462,8 +463,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
return num_handled
|
||||
|
||||
async def _redactions_received_ts(self, progress, batch_size):
|
||||
"""Handles filling out the `received_ts` column in redactions.
|
||||
"""
|
||||
"""Handles filling out the `received_ts` column in redactions."""
|
||||
last_event_id = progress.get("last_event_id", "")
|
||||
|
||||
def _redactions_received_ts_txn(txn):
|
||||
|
@ -518,8 +518,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
return count
|
||||
|
||||
async def _event_fix_redactions_bytes(self, progress, batch_size):
|
||||
"""Undoes hex encoded censored redacted event JSON.
|
||||
"""
|
||||
"""Undoes hex encoded censored redacted event JSON."""
|
||||
|
||||
def _event_fix_redactions_bytes_txn(txn):
|
||||
# This update is quite fast due to new index.
|
||||
|
@ -642,7 +641,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (last_event_id, batch_size,))
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
last_event_id,
|
||||
batch_size,
|
||||
),
|
||||
)
|
||||
|
||||
return [(row[0], row[1], db_to_json(row[2]), row[3], row[4]) for row in txn] # type: ignore
|
||||
|
||||
|
@ -910,7 +915,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
|
|||
# Annoyingly we need to gut wrench into the persit event store so that
|
||||
# we can reuse the function to calculate the chain cover for rooms.
|
||||
PersistEventsStore._add_chain_cover_index(
|
||||
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
|
||||
txn,
|
||||
self.db_pool,
|
||||
event_to_room_id,
|
||||
event_to_types,
|
||||
event_to_auth_chain,
|
||||
)
|
||||
|
||||
return _CalculateChainCover(
|
||||
|
|
|
@ -71,7 +71,9 @@ class EventForwardExtremitiesStore(SQLBaseStore):
|
|||
if txn.rowcount > 0:
|
||||
# Invalidate the cache
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_latest_event_ids_in_room, (room_id,),
|
||||
txn,
|
||||
self.get_latest_event_ids_in_room,
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
return txn.rowcount
|
||||
|
@ -97,5 +99,6 @@ class EventForwardExtremitiesStore(SQLBaseStore):
|
|||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
|
||||
"get_forward_extremities_for_room",
|
||||
get_forward_extremities_for_room_txn,
|
||||
)
|
||||
|
|
|
@ -120,7 +120,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
# SQLite).
|
||||
if hs.get_instance_name() in hs.config.worker.writers.events:
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering",
|
||||
db_conn,
|
||||
"events",
|
||||
"stream_ordering",
|
||||
)
|
||||
self._backfill_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
|
@ -140,7 +142,8 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
if hs.config.run_background_tasks:
|
||||
# We periodically clean out old transaction ID mappings
|
||||
self._clock.looping_call(
|
||||
self._cleanup_old_transaction_ids, 5 * 60 * 1000,
|
||||
self._cleanup_old_transaction_ids,
|
||||
5 * 60 * 1000,
|
||||
)
|
||||
|
||||
self._get_event_cache = LruCache(
|
||||
|
@ -1325,8 +1328,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
return rows, to_token, True
|
||||
|
||||
async def is_event_after(self, event_id1, event_id2):
|
||||
"""Returns True if event_id1 is after event_id2 in the stream
|
||||
"""
|
||||
"""Returns True if event_id1 is after event_id2 in the stream"""
|
||||
to_1, so_1 = await self.get_event_ordering(event_id1)
|
||||
to_2, so_2 = await self.get_event_ordering(event_id2)
|
||||
return (to_1, so_1) > (to_2, so_2)
|
||||
|
@ -1428,8 +1430,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
|
||||
@wrap_as_background_process("_cleanup_old_transaction_ids")
|
||||
async def _cleanup_old_transaction_ids(self):
|
||||
"""Cleans out transaction id mappings older than 24hrs.
|
||||
"""
|
||||
"""Cleans out transaction id mappings older than 24hrs."""
|
||||
|
||||
def _cleanup_old_transaction_ids_txn(txn):
|
||||
sql = """
|
||||
|
@ -1440,5 +1441,6 @@ class EventsWorkerStore(SQLBaseStore):
|
|||
txn.execute(sql, (one_day_ago,))
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"_cleanup_old_transaction_ids", _cleanup_old_transaction_ids_txn,
|
||||
"_cleanup_old_transaction_ids",
|
||||
_cleanup_old_transaction_ids_txn,
|
||||
)
|
||||
|
|
|
@ -123,7 +123,9 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
async def get_rooms_for_summary_by_category(
|
||||
self, group_id: str, include_private: bool = False,
|
||||
self,
|
||||
group_id: str,
|
||||
include_private: bool = False,
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
"""Get the rooms and categories that should be included in a summary request
|
||||
|
||||
|
@ -368,8 +370,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||
async def is_user_invited_to_local_group(
|
||||
self, group_id: str, user_id: str
|
||||
) -> Optional[bool]:
|
||||
"""Has the group server invited a user?
|
||||
"""
|
||||
"""Has the group server invited a user?"""
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="group_invites",
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
|
@ -427,8 +428,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
|
||||
"""Get all groups a user is publicising
|
||||
"""
|
||||
"""Get all groups a user is publicising"""
|
||||
return await self.db_pool.simple_select_onecol(
|
||||
table="local_group_membership",
|
||||
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
|
||||
|
@ -437,8 +437,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
async def get_attestations_need_renewals(self, valid_until_ms):
|
||||
"""Get all attestations that need to be renewed until givent time
|
||||
"""
|
||||
"""Get all attestations that need to be renewed until givent time"""
|
||||
|
||||
def _get_attestations_need_renewals_txn(txn):
|
||||
sql = """
|
||||
|
@ -781,8 +780,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||
profile: Optional[JsonDict],
|
||||
is_public: Optional[bool],
|
||||
) -> None:
|
||||
"""Add/update room category for group
|
||||
"""
|
||||
"""Add/update room category for group"""
|
||||
insertion_values = {}
|
||||
update_values = {"category_id": category_id} # This cannot be empty
|
||||
|
||||
|
@ -818,8 +816,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||
profile: Optional[JsonDict],
|
||||
is_public: Optional[bool],
|
||||
) -> None:
|
||||
"""Add/remove user role
|
||||
"""
|
||||
"""Add/remove user role"""
|
||||
insertion_values = {}
|
||||
update_values = {"role_id": role_id} # This cannot be empty
|
||||
|
||||
|
@ -1012,8 +1009,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||
)
|
||||
|
||||
async def add_group_invite(self, group_id: str, user_id: str) -> None:
|
||||
"""Record that the group server has invited a user
|
||||
"""
|
||||
"""Record that the group server has invited a user"""
|
||||
await self.db_pool.simple_insert(
|
||||
table="group_invites",
|
||||
values={"group_id": group_id, "user_id": user_id},
|
||||
|
@ -1156,8 +1152,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||
async def update_group_publicity(
|
||||
self, group_id: str, user_id: str, publicise: bool
|
||||
) -> None:
|
||||
"""Update whether the user is publicising their membership of the group
|
||||
"""
|
||||
"""Update whether the user is publicising their membership of the group"""
|
||||
await self.db_pool.simple_update_one(
|
||||
table="local_group_membership",
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
|
@ -1300,8 +1295,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||
async def update_attestation_renewal(
|
||||
self, group_id: str, user_id: str, attestation: dict
|
||||
) -> None:
|
||||
"""Update an attestation that we have renewed
|
||||
"""
|
||||
"""Update an attestation that we have renewed"""
|
||||
await self.db_pool.simple_update_one(
|
||||
table="group_attestations_renewals",
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
|
@ -1312,8 +1306,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
|||
async def update_remote_attestion(
|
||||
self, group_id: str, user_id: str, attestation: dict
|
||||
) -> None:
|
||||
"""Update an attestation that a remote has renewed
|
||||
"""
|
||||
"""Update an attestation that a remote has renewed"""
|
||||
await self.db_pool.simple_update_one(
|
||||
table="group_attestations_remote",
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
|
|
|
@ -33,8 +33,7 @@ db_binary_type = memoryview
|
|||
|
||||
|
||||
class KeyStore(SQLBaseStore):
|
||||
"""Persistence for signature verification keys
|
||||
"""
|
||||
"""Persistence for signature verification keys"""
|
||||
|
||||
@cached()
|
||||
def _get_server_verify_key(self, server_name_and_key_id):
|
||||
|
|
|
@ -169,7 +169,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
)
|
||||
|
||||
async def get_local_media_before(
|
||||
self, before_ts: int, size_gt: int, keep_profiles: bool,
|
||||
self,
|
||||
before_ts: int,
|
||||
size_gt: int,
|
||||
keep_profiles: bool,
|
||||
) -> List[str]:
|
||||
|
||||
# to find files that have never been accessed (last_access_ts IS NULL)
|
||||
|
@ -454,10 +457,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
|||
)
|
||||
|
||||
async def get_remote_media_thumbnail(
|
||||
self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
|
||||
self,
|
||||
origin: str,
|
||||
media_id: str,
|
||||
t_width: int,
|
||||
t_height: int,
|
||||
t_type: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch the thumbnail info of given width, height and type.
|
||||
"""
|
||||
"""Fetch the thumbnail info of given width, height and type."""
|
||||
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="remote_media_cache_thumbnails",
|
||||
|
|
|
@ -130,7 +130,9 @@ class PresenceStore(SQLBaseStore):
|
|||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
|
||||
cached_method_name="_get_presence_for_user",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
)
|
||||
async def get_presence_for_users(self, user_ids):
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
|
|
|
@ -118,8 +118,7 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
async def is_subscribed_remote_profile_for_user(self, user_id):
|
||||
"""Check whether we are interested in a remote user's profile.
|
||||
"""
|
||||
"""Check whether we are interested in a remote user's profile."""
|
||||
res = await self.db_pool.simple_select_one_onecol(
|
||||
table="group_users",
|
||||
keyvalues={"user_id": user_id},
|
||||
|
@ -145,8 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
|
|||
async def get_remote_profile_cache_entries_that_expire(
|
||||
self, last_checked: int
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Get all users who haven't been checked since `last_checked`
|
||||
"""
|
||||
"""Get all users who haven't been checked since `last_checked`"""
|
||||
|
||||
def _get_remote_profile_cache_entries_that_expire_txn(txn):
|
||||
sql = """
|
||||
|
|
|
@ -168,7 +168,9 @@ class PushRulesWorkerStore(
|
|||
)
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
|
||||
cached_method_name="get_push_rules_for_user",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
)
|
||||
async def bulk_get_push_rules(self, user_ids):
|
||||
if not user_ids:
|
||||
|
@ -195,7 +197,9 @@ class PushRulesWorkerStore(
|
|||
use_new_defaults = user_id in self._users_new_default_push_rules
|
||||
|
||||
results[user_id] = _load_rules(
|
||||
rules, enabled_map_by_user.get(user_id, {}), use_new_defaults,
|
||||
rules,
|
||||
enabled_map_by_user.get(user_id, {}),
|
||||
use_new_defaults,
|
||||
)
|
||||
|
||||
return results
|
||||
|
|
|
@ -179,7 +179,9 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
|
||||
cached_method_name="get_if_user_has_pusher",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
)
|
||||
async def get_if_users_have_pushers(
|
||||
self, user_ids: Iterable[str]
|
||||
|
@ -263,7 +265,8 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
params_by_room = {}
|
||||
for row in res:
|
||||
params_by_room[row["room_id"]] = ThrottleParams(
|
||||
row["last_sent_ts"], row["throttle_ms"],
|
||||
row["last_sent_ts"],
|
||||
row["throttle_ms"],
|
||||
)
|
||||
|
||||
return params_by_room
|
||||
|
|
|
@ -208,8 +208,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
async def _get_linearized_receipts_for_room(
|
||||
self, room_id: str, to_key: int, from_key: Optional[int] = None
|
||||
) -> List[dict]:
|
||||
"""See get_linearized_receipts_for_room
|
||||
"""
|
||||
"""See get_linearized_receipts_for_room"""
|
||||
|
||||
def f(txn):
|
||||
if from_key:
|
||||
|
@ -304,7 +303,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||
}
|
||||
return results
|
||||
|
||||
@cached(num_args=2,)
|
||||
@cached(
|
||||
num_args=2,
|
||||
)
|
||||
async def get_linearized_receipts_for_all_rooms(
|
||||
self, to_key: int, from_key: Optional[int] = None
|
||||
) -> Dict[str, JsonDict]:
|
||||
|
|
|
@ -79,13 +79,16 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
|||
# call `find_max_generated_user_id_localpart` each time, which is
|
||||
# expensive if there are many entries.
|
||||
self._user_id_seq = build_sequence_generator(
|
||||
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
|
||||
database.engine,
|
||||
find_max_generated_user_id_localpart,
|
||||
"user_id_seq",
|
||||
)
|
||||
|
||||
self._account_validity = hs.config.account_validity
|
||||
if hs.config.run_background_tasks and self._account_validity.enabled:
|
||||
self._clock.call_later(
|
||||
0.0, self._set_expiration_date_when_missing,
|
||||
0.0,
|
||||
self._set_expiration_date_when_missing,
|
||||
)
|
||||
|
||||
# Create a background job for culling expired 3PID validity tokens
|
||||
|
|
|
@ -193,8 +193,7 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
async def get_room_count(self) -> int:
|
||||
"""Retrieve the total number of rooms.
|
||||
"""
|
||||
"""Retrieve the total number of rooms."""
|
||||
|
||||
def f(txn):
|
||||
sql = "SELECT count(*) FROM rooms"
|
||||
|
@ -517,7 +516,8 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
return rooms, room_count[0]
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_rooms_paginate", _get_rooms_paginate_txn,
|
||||
"get_rooms_paginate",
|
||||
_get_rooms_paginate_txn,
|
||||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
|
@ -578,7 +578,8 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
ret = await self.db_pool.runInteraction(
|
||||
"get_retention_policy_for_room", get_retention_policy_for_room_txn,
|
||||
"get_retention_policy_for_room",
|
||||
get_retention_policy_for_room_txn,
|
||||
)
|
||||
|
||||
# If we don't know this room ID, ret will be None, in this case return the default
|
||||
|
@ -707,7 +708,10 @@ class RoomWorkerStore(SQLBaseStore):
|
|||
return local_media_mxcs, remote_media_mxcs
|
||||
|
||||
async def quarantine_media_by_id(
|
||||
self, server_name: str, media_id: str, quarantined_by: str,
|
||||
self,
|
||||
server_name: str,
|
||||
media_id: str,
|
||||
quarantined_by: str,
|
||||
) -> int:
|
||||
"""quarantines a single local or remote media id
|
||||
|
||||
|
@ -961,7 +965,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||
self.config = hs.config
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"insert_room_retention", self._background_insert_retention,
|
||||
"insert_room_retention",
|
||||
self._background_insert_retention,
|
||||
)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
|
@ -1033,7 +1038,8 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
|
|||
return False
|
||||
|
||||
end = await self.db_pool.runInteraction(
|
||||
"insert_room_retention", _background_insert_retention_txn,
|
||||
"insert_room_retention",
|
||||
_background_insert_retention_txn,
|
||||
)
|
||||
|
||||
if end:
|
||||
|
@ -1588,7 +1594,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
|||
LIMIT ?
|
||||
OFFSET ?
|
||||
""".format(
|
||||
where_clause=where_clause, order=order,
|
||||
where_clause=where_clause,
|
||||
order=order,
|
||||
)
|
||||
|
||||
args += [limit, start]
|
||||
|
|
|
@ -70,10 +70,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
):
|
||||
self._known_servers_count = 1
|
||||
self.hs.get_clock().looping_call(
|
||||
self._count_known_servers, 60 * 1000,
|
||||
self._count_known_servers,
|
||||
60 * 1000,
|
||||
)
|
||||
self.hs.get_clock().call_later(
|
||||
1000, self._count_known_servers,
|
||||
1000,
|
||||
self._count_known_servers,
|
||||
)
|
||||
LaterGauge(
|
||||
"synapse_federation_known_servers",
|
||||
|
@ -174,7 +176,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
|
||||
@cached(max_entries=100000)
|
||||
async def get_room_summary(self, room_id: str) -> Dict[str, MemberSummary]:
|
||||
""" Get the details of a room roughly suitable for use by the room
|
||||
"""Get the details of a room roughly suitable for use by the room
|
||||
summary extension to /sync. Useful when lazy loading room members.
|
||||
Args:
|
||||
room_id: The room ID to query
|
||||
|
@ -488,8 +490,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
async def get_users_who_share_room_with_user(
|
||||
self, user_id: str, cache_context: _CacheContext
|
||||
) -> Set[str]:
|
||||
"""Returns the set of users who share a room with `user_id`
|
||||
"""
|
||||
"""Returns the set of users who share a room with `user_id`"""
|
||||
room_ids = await self.get_rooms_for_user(
|
||||
user_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
@ -618,7 +619,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
|
||||
cached_method_name="_get_joined_profile_from_event_id",
|
||||
list_name="event_ids",
|
||||
)
|
||||
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
||||
"""For given set of member event_ids check if they point to a join
|
||||
|
@ -802,8 +804,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||
async def get_membership_from_event_ids(
|
||||
self, member_event_ids: Iterable[str]
|
||||
) -> List[dict]:
|
||||
"""Get user_id and membership of a set of event IDs.
|
||||
"""
|
||||
"""Get user_id and membership of a set of event IDs."""
|
||||
|
||||
return await self.db_pool.simple_select_many_batch(
|
||||
table="room_memberships",
|
||||
|
|
|
@ -23,5 +23,6 @@ def run_create(cur, database_engine, *args, **kwargs):
|
|||
|
||||
def run_upgrade(cur, database_engine, *args, **kwargs):
|
||||
cur.execute(
|
||||
"UPDATE remote_media_cache SET last_access_ts = ?", (int(time.time() * 1000),),
|
||||
"UPDATE remote_media_cache SET last_access_ts = ?",
|
||||
(int(time.time() * 1000),),
|
||||
)
|
||||
|
|
|
@ -52,8 +52,7 @@ class _GetStateGroupDelta(
|
|||
|
||||
# this inherits from EventsWorkerStore because it calls self.get_events
|
||||
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
"""The parts of StateGroupStore that can be called from workers.
|
||||
"""
|
||||
"""The parts of StateGroupStore that can be called from workers."""
|
||||
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
@ -276,8 +275,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||
num_args=1,
|
||||
)
|
||||
async def _get_state_group_for_events(self, event_ids):
|
||||
"""Returns mapping event_id -> state_group
|
||||
"""
|
||||
"""Returns mapping event_id -> state_group"""
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="event_to_state_groups",
|
||||
column="event_id",
|
||||
|
@ -338,7 +336,8 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
|||
columns=["state_group"],
|
||||
)
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms,
|
||||
self.DELETE_CURRENT_STATE_UPDATE_NAME,
|
||||
self._background_remove_left_rooms,
|
||||
)
|
||||
|
||||
async def _background_remove_left_rooms(self, progress, batch_size):
|
||||
|
@ -487,7 +486,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
|||
|
||||
|
||||
class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
|
||||
""" Keeps track of the state at a given event.
|
||||
"""Keeps track of the state at a given event.
|
||||
|
||||
This is done by the concept of `state groups`. Every event is a assigned
|
||||
a state group (identified by an arbitrary string), which references a
|
||||
|
|
|
@ -1001,7 +1001,9 @@ class StatsStore(StateDeltasStore):
|
|||
ORDER BY {order_by_column} {order}
|
||||
LIMIT ? OFFSET ?
|
||||
""".format(
|
||||
sql_base=sql_base, order_by_column=order_by_column, order=order,
|
||||
sql_base=sql_base,
|
||||
order_by_column=order_by_column,
|
||||
order=order,
|
||||
)
|
||||
|
||||
args += [limit, start]
|
||||
|
|
|
@ -565,7 +565,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
AND e.stream_ordering > ? AND e.stream_ordering <= ?
|
||||
ORDER BY e.stream_ordering ASC
|
||||
"""
|
||||
txn.execute(sql, (user_id, min_from_id, max_to_id,))
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
user_id,
|
||||
min_from_id,
|
||||
max_to_id,
|
||||
),
|
||||
)
|
||||
|
||||
rows = [
|
||||
_EventDictReturn(event_id, None, stream_ordering)
|
||||
|
@ -695,7 +702,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
return "t%d-%d" % (topo, token)
|
||||
|
||||
def get_stream_id_for_event_txn(
|
||||
self, txn: LoggingTransaction, event_id: str, allow_none=False,
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
event_id: str,
|
||||
allow_none=False,
|
||||
) -> int:
|
||||
return self.db_pool.simple_select_one_onecol_txn(
|
||||
txn=txn,
|
||||
|
@ -706,8 +716,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
)
|
||||
|
||||
async def get_position_for_event(self, event_id: str) -> PersistedEventPosition:
|
||||
"""Get the persisted position for an event
|
||||
"""
|
||||
"""Get the persisted position for an event"""
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="events",
|
||||
keyvalues={"event_id": event_id},
|
||||
|
@ -897,19 +906,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
) -> Tuple[int, List[EventBase]]:
|
||||
"""Get all new events
|
||||
|
||||
Returns all events with from_id < stream_ordering <= current_id.
|
||||
Returns all events with from_id < stream_ordering <= current_id.
|
||||
|
||||
Args:
|
||||
from_id: the stream_ordering of the last event we processed
|
||||
current_id: the stream_ordering of the most recently processed event
|
||||
limit: the maximum number of events to return
|
||||
Args:
|
||||
from_id: the stream_ordering of the last event we processed
|
||||
current_id: the stream_ordering of the most recently processed event
|
||||
limit: the maximum number of events to return
|
||||
|
||||
Returns:
|
||||
A tuple of (next_id, events), where `next_id` is the next value to
|
||||
pass as `from_id` (it will either be the stream_ordering of the
|
||||
last returned event, or, if fewer than `limit` events were found,
|
||||
the `current_id`).
|
||||
"""
|
||||
Returns:
|
||||
A tuple of (next_id, events), where `next_id` is the next value to
|
||||
pass as `from_id` (it will either be the stream_ordering of the
|
||||
last returned event, or, if fewer than `limit` events were found,
|
||||
the `current_id`).
|
||||
"""
|
||||
|
||||
def get_all_new_events_stream_txn(txn):
|
||||
sql = (
|
||||
|
@ -1238,8 +1247,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
|
|||
|
||||
@cached()
|
||||
async def get_id_for_instance(self, instance_name: str) -> int:
|
||||
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance.
|
||||
"""
|
||||
"""Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
|
||||
|
||||
def _get_id_for_instance_txn(txn):
|
||||
instance_id = self.db_pool.simple_select_one_onecol_txn(
|
||||
|
|
|
@ -64,8 +64,7 @@ class TransactionWorkerStore(SQLBaseStore):
|
|||
|
||||
|
||||
class TransactionStore(TransactionWorkerStore):
|
||||
"""A collection of queries for handling PDUs.
|
||||
"""
|
||||
"""A collection of queries for handling PDUs."""
|
||||
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
@ -299,7 +298,10 @@ class TransactionStore(TransactionWorkerStore):
|
|||
)
|
||||
|
||||
async def store_destination_rooms_entries(
|
||||
self, destinations: Iterable[str], room_id: str, stream_ordering: int,
|
||||
self,
|
||||
destinations: Iterable[str],
|
||||
room_id: str,
|
||||
stream_ordering: int,
|
||||
) -> None:
|
||||
"""
|
||||
Updates or creates `destination_rooms` entries in batch for a single event.
|
||||
|
@ -394,7 +396,9 @@ class TransactionStore(TransactionWorkerStore):
|
|||
)
|
||||
|
||||
async def get_catch_up_room_event_ids(
|
||||
self, destination: str, last_successful_stream_ordering: int,
|
||||
self,
|
||||
destination: str,
|
||||
last_successful_stream_ordering: int,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Returns at most 50 event IDs and their corresponding stream_orderings
|
||||
|
@ -418,7 +422,9 @@ class TransactionStore(TransactionWorkerStore):
|
|||
|
||||
@staticmethod
|
||||
def _get_catch_up_room_event_ids_txn(
|
||||
txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
|
||||
txn: LoggingTransaction,
|
||||
destination: str,
|
||||
last_successful_stream_ordering: int,
|
||||
) -> List[str]:
|
||||
q = """
|
||||
SELECT event_id FROM destination_rooms
|
||||
|
@ -429,7 +435,8 @@ class TransactionStore(TransactionWorkerStore):
|
|||
LIMIT 50
|
||||
"""
|
||||
txn.execute(
|
||||
q, (destination, last_successful_stream_ordering),
|
||||
q,
|
||||
(destination, last_successful_stream_ordering),
|
||||
)
|
||||
event_ids = [row[0] for row in txn]
|
||||
return event_ids
|
||||
|
|
|
@ -44,7 +44,11 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
"""
|
||||
|
||||
async def create_ui_auth_session(
|
||||
self, clientdict: JsonDict, uri: str, method: str, description: str,
|
||||
self,
|
||||
clientdict: JsonDict,
|
||||
uri: str,
|
||||
method: str,
|
||||
description: str,
|
||||
) -> UIAuthSessionData:
|
||||
"""
|
||||
Creates a new user interactive authentication session.
|
||||
|
@ -123,7 +127,10 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
return UIAuthSessionData(session_id, **result)
|
||||
|
||||
async def mark_ui_auth_stage_complete(
|
||||
self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
|
||||
self,
|
||||
session_id: str,
|
||||
stage_type: str,
|
||||
result: Union[str, bool, JsonDict],
|
||||
):
|
||||
"""
|
||||
Mark a session stage as completed.
|
||||
|
@ -261,10 +268,12 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
return serverdict.get(key, default)
|
||||
|
||||
async def add_user_agent_ip_to_ui_auth_session(
|
||||
self, session_id: str, user_agent: str, ip: str,
|
||||
self,
|
||||
session_id: str,
|
||||
user_agent: str,
|
||||
ip: str,
|
||||
):
|
||||
"""Add the given user agent / IP to the tracking table
|
||||
"""
|
||||
"""Add the given user agent / IP to the tracking table"""
|
||||
await self.db_pool.simple_upsert(
|
||||
table="ui_auth_sessions_ips",
|
||||
keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
|
||||
|
@ -273,7 +282,8 @@ class UIAuthWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
async def get_user_agents_ips_to_ui_auth_session(
|
||||
self, session_id: str,
|
||||
self,
|
||||
session_id: str,
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Get the given user agents / IPs used during the ui auth process
|
||||
|
||||
|
|
|
@ -336,8 +336,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
return len(users_to_work_on)
|
||||
|
||||
async def is_room_world_readable_or_publicly_joinable(self, room_id):
|
||||
"""Check if the room is either world_readable or publically joinable
|
||||
"""
|
||||
"""Check if the room is either world_readable or publically joinable"""
|
||||
|
||||
# Create a state filter that only queries join and history state event
|
||||
types_to_filter = (
|
||||
|
@ -516,8 +515,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
|||
)
|
||||
|
||||
async def delete_all_from_user_dir(self) -> None:
|
||||
"""Delete the entire user directory
|
||||
"""
|
||||
"""Delete the entire user directory"""
|
||||
|
||||
def _delete_all_from_user_dir_txn(txn):
|
||||
txn.execute("DELETE FROM user_directory")
|
||||
|
|
|
@ -48,8 +48,7 @@ class _GetStateGroupDelta(
|
|||
|
||||
|
||||
class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||
"""A data store for fetching/storing state groups.
|
||||
"""
|
||||
"""A data store for fetching/storing state groups."""
|
||||
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
@ -89,7 +88,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
|||
50000,
|
||||
)
|
||||
self._state_group_members_cache = DictionaryCache(
|
||||
"*stateGroupMembersCache*", 500000,
|
||||
"*stateGroupMembersCache*",
|
||||
500000,
|
||||
)
|
||||
|
||||
def get_max_state_group_txn(txn: Cursor):
|
||||
|
|
|
@ -94,14 +94,12 @@ class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
|
|||
@property
|
||||
@abc.abstractmethod
|
||||
def server_version(self) -> str:
|
||||
"""Gets a string giving the server version. For example: '3.22.0'
|
||||
"""
|
||||
"""Gets a string giving the server version. For example: '3.22.0'"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def in_transaction(self, conn: Connection) -> bool:
|
||||
"""Whether the connection is currently in a transaction.
|
||||
"""
|
||||
"""Whether the connection is currently in a transaction."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
|
|
|
@ -138,8 +138,7 @@ class PostgresEngine(BaseDatabaseEngine):
|
|||
|
||||
@property
|
||||
def supports_using_any_list(self):
|
||||
"""Do we support using `a = ANY(?)` and passing a list
|
||||
"""
|
||||
"""Do we support using `a = ANY(?)` and passing a list"""
|
||||
return True
|
||||
|
||||
def is_deadlock(self, error):
|
||||
|
|
|
@ -29,7 +29,10 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
|
|||
super().__init__(database_module, database_config)
|
||||
|
||||
database = database_config.get("args", {}).get("database")
|
||||
self._is_in_memory = database in (None, ":memory:",)
|
||||
self._is_in_memory = database in (
|
||||
None,
|
||||
":memory:",
|
||||
)
|
||||
|
||||
if platform.python_implementation() == "PyPy":
|
||||
# pypy's sqlite3 module doesn't handle bytearrays, convert them
|
||||
|
@ -63,8 +66,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
|
|||
|
||||
@property
|
||||
def supports_using_any_list(self):
|
||||
"""Do we support using `a = ANY(?)` and passing a list
|
||||
"""
|
||||
"""Do we support using `a = ANY(?)` and passing a list"""
|
||||
return False
|
||||
|
||||
def check_database(self, db_conn, allow_outdated_version: bool = False):
|
||||
|
|
|
@ -411,8 +411,8 @@ class EventsPersistenceStorage:
|
|||
)
|
||||
|
||||
for room_id, ev_ctx_rm in events_by_room.items():
|
||||
latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
|
||||
room_id
|
||||
latest_event_ids = (
|
||||
await self.main_store.get_latest_event_ids_in_room(room_id)
|
||||
)
|
||||
new_latest_event_ids = await self._calculate_new_extremities(
|
||||
room_id, ev_ctx_rm, latest_event_ids
|
||||
|
@ -889,7 +889,8 @@ class EventsPersistenceStorage:
|
|||
continue
|
||||
|
||||
logger.debug(
|
||||
"Not dropping as too new and not in new_senders: %s", new_senders,
|
||||
"Not dropping as too new and not in new_senders: %s",
|
||||
new_senders,
|
||||
)
|
||||
|
||||
return new_latest_event_ids
|
||||
|
@ -1004,7 +1005,10 @@ class EventsPersistenceStorage:
|
|||
|
||||
remote_event_ids = [
|
||||
event_id
|
||||
for (typ, state_key,), event_id in current_state.items()
|
||||
for (
|
||||
typ,
|
||||
state_key,
|
||||
), event_id in current_state.items()
|
||||
if typ == EventTypes.Member and not self.is_mine_id(state_key)
|
||||
]
|
||||
rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
|
||||
|
|
|
@ -425,7 +425,10 @@ def _upgrade_existing_database(
|
|||
# We don't support using the same file name in the same delta version.
|
||||
raise PrepareDatabaseException(
|
||||
"Found multiple delta files with the same name in v%d: %s"
|
||||
% (v, duplicates,)
|
||||
% (
|
||||
v,
|
||||
duplicates,
|
||||
)
|
||||
)
|
||||
|
||||
# We sort to ensure that we apply the delta files in a consistent
|
||||
|
@ -532,7 +535,8 @@ def _apply_module_schema_files(
|
|||
names_and_streams: the names and streams of schemas to be applied
|
||||
"""
|
||||
cur.execute(
|
||||
"SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
|
||||
"SELECT file FROM applied_module_schemas WHERE module_name = ?",
|
||||
(modname,),
|
||||
)
|
||||
applied_deltas = {d for d, in cur}
|
||||
for (name, stream) in names_and_streams:
|
||||
|
|
|
@ -26,15 +26,13 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class PurgeEventsStorage:
|
||||
"""High level interface for purging rooms and event history.
|
||||
"""
|
||||
"""High level interface for purging rooms and event history."""
|
||||
|
||||
def __init__(self, hs: "HomeServer", stores: Databases):
|
||||
self.stores = stores
|
||||
|
||||
async def purge_room(self, room_id: str) -> None:
|
||||
"""Deletes all record of a room
|
||||
"""
|
||||
"""Deletes all record of a room"""
|
||||
|
||||
state_groups_to_delete = await self.stores.main.purge_room(room_id)
|
||||
await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
|
||||
|
|
|
@ -340,8 +340,7 @@ class StateFilter:
|
|||
|
||||
|
||||
class StateGroupStorage:
|
||||
"""High level interface to fetching state for event.
|
||||
"""
|
||||
"""High level interface to fetching state for event."""
|
||||
|
||||
def __init__(self, hs: "HomeServer", stores: "Databases"):
|
||||
self.stores = stores
|
||||
|
@ -400,7 +399,7 @@ class StateGroupStorage:
|
|||
async def get_state_groups(
|
||||
self, room_id: str, event_ids: Iterable[str]
|
||||
) -> Dict[int, List[EventBase]]:
|
||||
""" Get the state groups for the given list of event_ids
|
||||
"""Get the state groups for the given list of event_ids
|
||||
|
||||
Args:
|
||||
room_id: ID of the room for these events.
|
||||
|
|
|
@ -277,7 +277,9 @@ class MultiWriterIdGenerator:
|
|||
self._load_current_ids(db_conn, tables)
|
||||
|
||||
def _load_current_ids(
|
||||
self, db_conn, tables: List[Tuple[str, str, str]],
|
||||
self,
|
||||
db_conn,
|
||||
tables: List[Tuple[str, str, str]],
|
||||
):
|
||||
cur = db_conn.cursor(txn_name="_load_current_ids")
|
||||
|
||||
|
@ -364,7 +366,10 @@ class MultiWriterIdGenerator:
|
|||
rows.sort()
|
||||
|
||||
with self._lock:
|
||||
for (instance, stream_id,) in rows:
|
||||
for (
|
||||
instance,
|
||||
stream_id,
|
||||
) in rows:
|
||||
stream_id = self._return_factor * stream_id
|
||||
self._add_persisted_position(stream_id)
|
||||
|
||||
|
@ -481,8 +486,7 @@ class MultiWriterIdGenerator:
|
|||
return self.get_persisted_upto_position()
|
||||
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
"""Returns the position of the given writer.
|
||||
"""
|
||||
"""Returns the position of the given writer."""
|
||||
|
||||
# If we don't have an entry for the given instance name, we assume it's a
|
||||
# new writer.
|
||||
|
@ -581,8 +585,7 @@ class MultiWriterIdGenerator:
|
|||
break
|
||||
|
||||
def _update_stream_positions_table_txn(self, txn: Cursor):
|
||||
"""Update the `stream_positions` table with newly persisted position.
|
||||
"""
|
||||
"""Update the `stream_positions` table with newly persisted position."""
|
||||
|
||||
if not self._writers:
|
||||
return
|
||||
|
@ -622,8 +625,7 @@ class _AsyncCtxManagerWrapper:
|
|||
|
||||
@attr.s(slots=True)
|
||||
class _MultiWriterCtxManager:
|
||||
"""Async context manager returned by MultiWriterIdGenerator
|
||||
"""
|
||||
"""Async context manager returned by MultiWriterIdGenerator"""
|
||||
|
||||
id_gen = attr.ib(type=MultiWriterIdGenerator)
|
||||
multiple_ids = attr.ib(type=Optional[int], default=None)
|
||||
|
|
|
@ -124,8 +124,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
|
|||
stream_name: Optional[str] = None,
|
||||
positive: bool = True,
|
||||
):
|
||||
"""See SequenceGenerator.check_consistency for docstring.
|
||||
"""
|
||||
"""See SequenceGenerator.check_consistency for docstring."""
|
||||
|
||||
txn = db_conn.cursor(txn_name="sequence.check_consistency")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue