Add some type hints to datastore (#12485)

This commit is contained in:
Dirk Klimpel 2022-04-27 14:05:00 +02:00 committed by GitHub
parent 63ba9ba38b
commit b76f1a4d5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 188 additions and 84 deletions

1
changelog.d/12485.misc Normal file
View File

@ -0,0 +1 @@
Add some type hints to datastore.

View File

@ -15,12 +15,17 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import ( from synapse.storage.util.id_generators import (
IdGenerator, IdGenerator,
MultiWriterIdGenerator, MultiWriterIdGenerator,
@ -266,7 +271,9 @@ class DataStore(
A tuple of a list of mappings from user to information and a count of total users. A tuple of a list of mappings from user to information and a count of total users.
""" """
def get_users_paginate_txn(txn): def get_users_paginate_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
filters = [] filters = []
args = [self.hs.config.server.server_name] args = [self.hs.config.server.server_name]
@ -301,7 +308,7 @@ class DataStore(
""" """
sql = "SELECT COUNT(*) as total_users " + sql_base sql = "SELECT COUNT(*) as total_users " + sql_base
txn.execute(sql, args) txn.execute(sql, args)
count = txn.fetchone()[0] count = cast(Tuple[int], txn.fetchone())[0]
sql = f""" sql = f"""
SELECT name, user_type, is_guest, admin, deactivated, shadow_banned, SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
@ -338,7 +345,9 @@ class DataStore(
) )
def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig): def check_database_before_upgrade(
cur: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
) -> None:
"""Called before upgrading an existing database to check that it is broadly sane """Called before upgrading an existing database to check that it is broadly sane
compared with the configuration. compared with the configuration.
""" """

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast
from synapse.appservice import ( from synapse.appservice import (
ApplicationService, ApplicationService,
@ -83,7 +83,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
txn.execute( txn.execute(
"SELECT COALESCE(max(txn_id), 0) FROM application_services_txns" "SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
) )
return txn.fetchone()[0] # type: ignore return cast(Tuple[int], txn.fetchone())[0]
self._as_txn_seq_gen = build_sequence_generator( self._as_txn_seq_gen = build_sequence_generator(
db_conn, db_conn,

View File

@ -14,7 +14,17 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
cast,
)
from synapse.logging import issue9533_logger from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.logging.opentracing import log_kv, set_tag, trace
@ -118,7 +128,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
prefilled_cache=device_outbox_prefill, prefilled_cache=device_outbox_prefill,
) )
def process_replication_rows(self, stream_name, instance_name, token, rows): def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
) -> None:
if stream_name == ToDeviceStream.NAME: if stream_name == ToDeviceStream.NAME:
# If replication is happening than postgres must be being used. # If replication is happening than postgres must be being used.
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator) assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
@ -134,7 +150,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
) )
return super().process_replication_rows(stream_name, instance_name, token, rows) return super().process_replication_rows(stream_name, instance_name, token, rows)
def get_to_device_stream_token(self): def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token() return self._device_inbox_id_gen.get_current_token()
async def get_messages_for_user_devices( async def get_messages_for_user_devices(
@ -301,7 +317,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
if not user_ids_to_query: if not user_ids_to_query:
return {}, to_stream_id return {}, to_stream_id
def get_device_messages_txn(txn: LoggingTransaction): def get_device_messages_txn(
txn: LoggingTransaction,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
# Build a query to select messages from any of the given devices that # Build a query to select messages from any of the given devices that
# are between the given stream id bounds. # are between the given stream id bounds.
@ -428,7 +446,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": "No changes in cache since last check"}) log_kv({"message": "No changes in cache since last check"})
return 0 return 0
def delete_messages_for_device_txn(txn): def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
sql = ( sql = (
"DELETE FROM device_inbox" "DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?" " WHERE user_id = ? AND device_id = ?"
@ -455,15 +473,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
@trace @trace
async def get_new_device_msgs_for_remote( async def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
) -> Tuple[List[dict], int]: ) -> Tuple[List[JsonDict], int]:
""" """
Args: Args:
destination(str): The name of the remote server. destination: The name of the remote server.
last_stream_id(int|long): The last position of the device message stream last_stream_id: The last position of the device message stream
that the server sent up to. that the server sent up to.
current_stream_id(int|long): The current position of the device current_stream_id: The current position of the device message stream.
message stream.
Returns: Returns:
A list of messages for the device and where in the stream the messages got to. A list of messages for the device and where in the stream the messages got to.
""" """
@ -485,7 +502,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return [], last_stream_id return [], last_stream_id
@trace @trace
def get_new_messages_for_remote_destination_txn(txn): def get_new_messages_for_remote_destination_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
sql = ( sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox" "SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?" " WHERE destination = ?"
@ -527,7 +546,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
up_to_stream_id: Where to delete messages up to. up_to_stream_id: Where to delete messages up to.
""" """
def delete_messages_for_remote_destination_txn(txn): def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
sql = ( sql = (
"DELETE FROM device_federation_outbox" "DELETE FROM device_federation_outbox"
" WHERE destination = ?" " WHERE destination = ?"
@ -566,7 +585,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
if last_id == current_id: if last_id == current_id:
return [], current_id, False return [], current_id, False
def get_all_new_device_messages_txn(txn): def get_all_new_device_messages_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We limit like this as we might have multiple rows per stream_id, and # We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id # we want to make sure we always get all entries for any stream_id
# we return. # we return.
@ -607,8 +628,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
@trace @trace
async def add_messages_to_device_inbox( async def add_messages_to_device_inbox(
self, self,
local_messages_by_user_then_device: dict, local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
remote_messages_by_destination: dict, remote_messages_by_destination: Dict[str, JsonDict],
) -> int: ) -> int:
"""Used to send messages from this server. """Used to send messages from this server.
@ -624,7 +645,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
assert self._can_write_to_device assert self._can_write_to_device
def add_messages_txn(txn, now_ms, stream_id): def add_messages_txn(
txn: LoggingTransaction, now_ms: int, stream_id: int
) -> None:
# Add the local messages directly to the local inbox. # Add the local messages directly to the local inbox.
self._add_messages_to_local_device_inbox_txn( self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device txn, stream_id, local_messages_by_user_then_device
@ -677,11 +700,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return self._device_inbox_id_gen.get_current_token() return self._device_inbox_id_gen.get_current_token()
async def add_messages_from_remote_to_device_inbox( async def add_messages_from_remote_to_device_inbox(
self, origin: str, message_id: str, local_messages_by_user_then_device: dict self,
origin: str,
message_id: str,
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
) -> int: ) -> int:
assert self._can_write_to_device assert self._can_write_to_device
def add_messages_txn(txn, now_ms, stream_id): def add_messages_txn(
txn: LoggingTransaction, now_ms: int, stream_id: int
) -> None:
# Check if we've already inserted a matching message_id for that # Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our # origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message. # acknowledgement from the first time we received the message.
@ -727,8 +755,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return stream_id return stream_id
def _add_messages_to_local_device_inbox_txn( def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device self,
): txn: LoggingTransaction,
stream_id: int,
messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
) -> None:
assert self._can_write_to_device assert self._can_write_to_device
local_by_user_then_device = {} local_by_user_then_device = {}
@ -840,8 +871,10 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self._remove_dead_devices_from_device_inbox, self._remove_dead_devices_from_device_inbox,
) )
async def _background_drop_index_device_inbox(self, progress, batch_size): async def _background_drop_index_device_inbox(
def reindex_txn(conn): self, progress: JsonDict, batch_size: int
) -> int:
def reindex_txn(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor() txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close() txn.close()

View File

@ -25,6 +25,7 @@ from typing import (
Optional, Optional,
Set, Set,
Tuple, Tuple,
cast,
) )
from synapse.api.errors import Codes, StoreError from synapse.api.errors import Codes, StoreError
@ -136,7 +137,9 @@ class DeviceWorkerStore(SQLBaseStore):
Number of devices of this users. Number of devices of this users.
""" """
def count_devices_by_users_txn(txn, user_ids): def count_devices_by_users_txn(
txn: LoggingTransaction, user_ids: List[str]
) -> int:
sql = """ sql = """
SELECT count(*) SELECT count(*)
FROM devices FROM devices
@ -149,7 +152,7 @@ class DeviceWorkerStore(SQLBaseStore):
) )
txn.execute(sql + clause, args) txn.execute(sql + clause, args)
return txn.fetchone()[0] return cast(Tuple[int], txn.fetchone())[0]
if not user_ids: if not user_ids:
return 0 return 0
@ -468,7 +471,7 @@ class DeviceWorkerStore(SQLBaseStore):
""" """
txn.execute(sql, (destination, from_stream_id, now_stream_id, limit)) txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
return list(txn) return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall())
async def _get_device_update_edus_by_remote( async def _get_device_update_edus_by_remote(
self, self,
@ -549,7 +552,7 @@ class DeviceWorkerStore(SQLBaseStore):
async def _get_last_device_update_for_remote_user( async def _get_last_device_update_for_remote_user(
self, destination: str, user_id: str, from_stream_id: int self, destination: str, user_id: str, from_stream_id: int
) -> int: ) -> int:
def f(txn): def f(txn: LoggingTransaction) -> int:
prev_sent_id_sql = """ prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success FROM device_lists_outbound_last_success
@ -767,7 +770,7 @@ class DeviceWorkerStore(SQLBaseStore):
if not user_ids_to_check: if not user_ids_to_check:
return set() return set()
def _get_users_whose_devices_changed_txn(txn): def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
changes = set() changes = set()
stream_id_where_clause = "stream_id > ?" stream_id_where_clause = "stream_id > ?"
@ -966,7 +969,9 @@ class DeviceWorkerStore(SQLBaseStore):
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
"""Mark that we no longer track device lists for remote user.""" """Mark that we no longer track device lists for remote user."""
def _mark_remote_user_device_list_as_unsubscribed_txn(txn): def _mark_remote_user_device_list_as_unsubscribed_txn(
txn: LoggingTransaction,
) -> None:
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
@ -1004,7 +1009,7 @@ class DeviceWorkerStore(SQLBaseStore):
) )
def _store_dehydrated_device_txn( def _store_dehydrated_device_txn(
self, txn, user_id: str, device_id: str, device_data: str self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
) -> Optional[str]: ) -> Optional[str]:
old_device_id = self.db_pool.simple_select_one_onecol_txn( old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, txn,
@ -1081,7 +1086,7 @@ class DeviceWorkerStore(SQLBaseStore):
""" """
yesterday = self._clock.time_msec() - prune_age yesterday = self._clock.time_msec() - prune_age
def _prune_txn(txn): def _prune_txn(txn: LoggingTransaction) -> None:
# look for (user, destination) pairs which have an update older than # look for (user, destination) pairs which have an update older than
# the cutoff. # the cutoff.
# #
@ -1204,8 +1209,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
"drop_device_lists_outbound_last_success_non_unique_idx", "drop_device_lists_outbound_last_success_non_unique_idx",
) )
async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size): async def _drop_device_list_streams_non_unique_indexes(
def f(conn): self, progress: JsonDict, batch_size: int
) -> int:
def f(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor() txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
@ -1217,7 +1224,9 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
) )
return 1 return 1
async def _remove_duplicate_outbound_pokes(self, progress, batch_size): async def _remove_duplicate_outbound_pokes(
self, progress: JsonDict, batch_size: int
) -> int:
# for some reason, we have accumulated duplicate entries in # for some reason, we have accumulated duplicate entries in
# device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
# efficient. # efficient.
@ -1230,7 +1239,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
{"stream_id": 0, "destination": "", "user_id": "", "device_id": ""}, {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
) )
def _txn(txn): def _txn(txn: LoggingTransaction) -> int:
clause, args = make_tuple_comparison_clause( clause, args = make_tuple_comparison_clause(
[(x, last_row[x]) for x in KEY_COLS] [(x, last_row[x]) for x in KEY_COLS]
) )
@ -1602,7 +1611,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context = get_active_span_text_map() context = get_active_span_text_map()
def add_device_changes_txn(txn, stream_ids): def add_device_changes_txn(
txn: LoggingTransaction, stream_ids: List[int]
) -> None:
self._add_device_change_to_stream_txn( self._add_device_change_to_stream_txn(
txn, txn,
user_id, user_id,
@ -1635,8 +1646,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn: LoggingTransaction, txn: LoggingTransaction,
user_id: str, user_id: str,
device_ids: Collection[str], device_ids: Collection[str],
stream_ids: List[str], stream_ids: List[int],
): ) -> None:
txn.call_after( txn.call_after(
self._device_list_stream_cache.entity_has_changed, self._device_list_stream_cache.entity_has_changed,
user_id, user_id,
@ -1720,7 +1731,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str, user_id: str,
device_ids: Iterable[str], device_ids: Iterable[str],
room_ids: Collection[str], room_ids: Collection[str],
stream_ids: List[str], stream_ids: List[int],
context: Dict[str, str], context: Dict[str, str],
) -> None: ) -> None:
"""Record the user in the room has updated their device.""" """Record the user in the room has updated their device."""
@ -1775,7 +1786,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
LIMIT ? LIMIT ?
""" """
def get_uncoverted_outbound_room_pokes_txn(txn): def get_uncoverted_outbound_room_pokes_txn(
txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
txn.execute(sql, (limit,)) txn.execute(sql, (limit,))
return [ return [
@ -1808,7 +1821,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Marks the associated row in `device_lists_changes_in_room` as handled. Marks the associated row in `device_lists_changes_in_room` as handled.
""" """
def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]): def add_device_list_outbound_pokes_txn(
txn: LoggingTransaction, stream_ids: List[int]
) -> None:
if hosts: if hosts:
self._add_device_outbound_poke_to_stream_txn( self._add_device_outbound_poke_to_stream_txn(
txn, txn,

View File

@ -522,7 +522,9 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups", desc="get_joined_groups",
) )
async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]: async def get_all_groups_for_user(
self, user_id: str, now_token: int
) -> List[JsonDict]:
def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]: def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
sql = """ sql = """
SELECT group_id, type, membership, u.content SELECT group_id, type, membership, u.content

View File

@ -15,11 +15,12 @@
import itertools import itertools
import logging import logging
from typing import Dict, Iterable, List, Optional, Tuple from typing import Any, Dict, Iterable, List, Optional, Tuple
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -35,7 +36,9 @@ class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys""" """Persistence for signature verification keys"""
@cached() @cached()
def _get_server_verify_key(self, server_name_and_key_id): def _get_server_verify_key(
self, server_name_and_key_id: Tuple[str, str]
) -> FetchKeyResult:
raise NotImplementedError() raise NotImplementedError()
@cachedList( @cachedList(
@ -179,19 +182,21 @@ class KeyStore(SQLBaseStore):
async def get_server_keys_json( async def get_server_keys_json(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]] self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]: ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
"""Retrieve the key json for a list of server_keys and key ids. """Retrieve the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list. that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently The JSON is returned as a byte array so that it can be efficiently
used in an HTTP response. used in an HTTP response.
Args: Args:
server_keys (list): List of (server_name, key_id, source) triplets. server_keys: List of (server_name, key_id, source) triplets.
Returns: Returns:
A mapping from (server_name, key_id, source) triplets to a list of dicts A mapping from (server_name, key_id, source) triplets to a list of dicts
""" """
def _get_server_keys_json_txn(txn): def _get_server_keys_json_txn(
txn: LoggingTransaction,
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
results = {} results = {}
for server_name, key_id, from_server in server_keys: for server_name, key_id, from_server in server_keys:
keyvalues = {"server_name": server_name} keyvalues = {"server_name": server_name}

View File

@ -388,7 +388,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
async def store_url_cache( async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts self,
url: str,
response_code: int,
etag: Optional[str],
expires_ts: int,
og: Optional[str],
media_id: str,
download_ts: int,
) -> None: ) -> None:
await self.db_pool.simple_insert( await self.db_pool.simple_insert(
"local_media_repository_url_cache", "local_media_repository_url_cache",
@ -441,7 +448,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
) )
async def get_cached_remote_media( async def get_cached_remote_media(
self, origin, media_id: str self, origin: str, media_id: str
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one( return await self.db_pool.simple_select_one(
"remote_media_cache", "remote_media_cache",
@ -608,7 +615,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
) )
async def delete_remote_media(self, media_origin: str, media_id: str) -> None: async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
def delete_remote_media_txn(txn): def delete_remote_media_txn(txn: LoggingTransaction) -> None:
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
"remote_media_cache", "remote_media_cache",

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast
from synapse.api.presence import PresenceState, UserPresenceState from synapse.api.presence import PresenceState, UserPresenceState
from synapse.replication.tcp.streams import PresenceStream from synapse.replication.tcp.streams import PresenceStream
@ -103,7 +103,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
prefilled_cache=presence_cache_prefill, prefilled_cache=presence_cache_prefill,
) )
async def update_presence(self, presence_states) -> Tuple[int, int]: async def update_presence(
self, presence_states: List[UserPresenceState]
) -> Tuple[int, int]:
assert self._can_persist_presence assert self._can_persist_presence
stream_ordering_manager = self._presence_id_gen.get_next_mult( stream_ordering_manager = self._presence_id_gen.get_next_mult(
@ -121,7 +123,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
return stream_orderings[-1], self._presence_id_gen.get_current_token() return stream_orderings[-1], self._presence_id_gen.get_current_token()
def _update_presence_txn( def _update_presence_txn(
self, txn: LoggingTransaction, stream_orderings, presence_states self,
txn: LoggingTransaction,
stream_orderings: List[int],
presence_states: List[UserPresenceState],
) -> None: ) -> None:
for stream_id, state in zip(stream_orderings, presence_states): for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after( txn.call_after(
@ -405,7 +410,13 @@ class PresenceStore(PresenceBackgroundUpdateStore):
self._presence_on_startup = [] self._presence_on_startup = []
return active_on_startup return active_on_startup
def process_replication_rows(self, stream_name, instance_name, token, rows) -> None: def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[Any],
) -> None:
if stream_name == PresenceStream.NAME: if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(instance_name, token) self._presence_id_gen.advance(instance_name, token)
for row in rows: for row in rows:

View File

@ -14,11 +14,25 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Tuple,
cast,
)
from synapse.push import PusherConfig, ThrottleParams from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
@ -117,7 +131,7 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(ret) return self._decode_pushers_rows(ret)
async def get_all_pushers(self) -> Iterator[PusherConfig]: async def get_all_pushers(self) -> Iterator[PusherConfig]:
def get_pushers(txn): def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
txn.execute("SELECT * FROM pushers") txn.execute("SELECT * FROM pushers")
rows = self.db_pool.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
@ -152,7 +166,9 @@ class PusherWorkerStore(SQLBaseStore):
if last_id == current_id: if last_id == current_id:
return [], current_id, False return [], current_id, False
def get_all_updated_pushers_rows_txn(txn): def get_all_updated_pushers_rows_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
sql = """ sql = """
SELECT id, user_name, app_id, pushkey SELECT id, user_name, app_id, pushkey
FROM pushers FROM pushers
@ -160,10 +176,13 @@ class PusherWorkerStore(SQLBaseStore):
ORDER BY id ASC LIMIT ? ORDER BY id ASC LIMIT ?
""" """
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
updates = [ updates = cast(
(stream_id, (user_name, app_id, pushkey, False)) List[Tuple[int, tuple]],
for stream_id, user_name, app_id, pushkey in txn [
] (stream_id, (user_name, app_id, pushkey, False))
for stream_id, user_name, app_id, pushkey in txn
],
)
sql = """ sql = """
SELECT stream_id, user_id, app_id, pushkey SELECT stream_id, user_id, app_id, pushkey
@ -192,12 +211,12 @@ class PusherWorkerStore(SQLBaseStore):
) )
@cached(num_args=1, max_entries=15000) @cached(num_args=1, max_entries=15000)
async def get_if_user_has_pusher(self, user_id: str): async def get_if_user_has_pusher(self, user_id: str) -> None:
# This only exists for the cachedList decorator # This only exists for the cachedList decorator
raise NotImplementedError() raise NotImplementedError()
async def update_pusher_last_stream_ordering( async def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering self, app_id: str, pushkey: str, user_id: str, last_stream_ordering: int
) -> None: ) -> None:
await self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
"pushers", "pushers",
@ -291,7 +310,7 @@ class PusherWorkerStore(SQLBaseStore):
last_user = progress.get("last_user", "") last_user = progress.get("last_user", "")
def _delete_pushers(txn) -> int: def _delete_pushers(txn: LoggingTransaction) -> int:
sql = """ sql = """
SELECT name FROM users SELECT name FROM users
@ -339,7 +358,7 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0) last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn) -> int: def _delete_pushers(txn: LoggingTransaction) -> int:
sql = """ sql = """
SELECT p.id, access_token FROM pushers AS p SELECT p.id, access_token FROM pushers AS p
@ -396,7 +415,7 @@ class PusherWorkerStore(SQLBaseStore):
last_pusher = progress.get("last_pusher", 0) last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn) -> int: def _delete_pushers(txn: LoggingTransaction) -> int:
sql = """ sql = """
SELECT p.id, p.user_name, p.app_id, p.pushkey SELECT p.id, p.user_name, p.app_id, p.pushkey
@ -502,7 +521,7 @@ class PusherStore(PusherWorkerStore):
async def delete_pusher_by_app_id_pushkey_user_id( async def delete_pusher_by_app_id_pushkey_user_id(
self, app_id: str, pushkey: str, user_id: str self, app_id: str, pushkey: str, user_id: str
) -> None: ) -> None:
def delete_pusher_txn(txn, stream_id): def delete_pusher_txn(txn: LoggingTransaction, stream_id: int) -> None:
self._invalidate_cache_and_stream( # type: ignore[attr-defined] self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,) txn, self.get_if_user_has_pusher, (user_id,)
) )
@ -547,7 +566,7 @@ class PusherStore(PusherWorkerStore):
# account. # account.
pushers = list(await self.get_pushers_by_user_id(user_id)) pushers = list(await self.get_pushers_by_user_id(user_id))
def delete_pushers_txn(txn, stream_ids): def delete_pushers_txn(txn: LoggingTransaction, stream_ids: List[int]) -> None:
self._invalidate_cache_and_stream( # type: ignore[attr-defined] self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_if_user_has_pusher, (user_id,) txn, self.get_if_user_has_pusher, (user_id,)
) )

View File

@ -370,10 +370,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _update_state_for_partial_state_event_txn( def _update_state_for_partial_state_event_txn(
self, self,
txn, txn: LoggingTransaction,
event: EventBase, event: EventBase,
context: EventContext, context: EventContext,
): ) -> None:
# we shouldn't have any outliers here # we shouldn't have any outliers here
assert not event.internal_metadata.is_outlier() assert not event.internal_metadata.is_outlier()

View File

@ -131,7 +131,7 @@ class UIAuthWorkerStore(SQLBaseStore):
session_id: str, session_id: str,
stage_type: str, stage_type: str,
result: Union[str, bool, JsonDict], result: Union[str, bool, JsonDict],
): ) -> None:
""" """
Mark a session stage as completed. Mark a session stage as completed.
@ -200,7 +200,9 @@ class UIAuthWorkerStore(SQLBaseStore):
desc="set_ui_auth_client_dict", desc="set_ui_auth_client_dict",
) )
async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): async def set_ui_auth_session_data(
self, session_id: str, key: str, value: Any
) -> None:
""" """
Store a key-value pair into the sessions data associated with this Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by request. This data is stored server-side and cannot be modified by
@ -223,7 +225,7 @@ class UIAuthWorkerStore(SQLBaseStore):
def _set_ui_auth_session_data_txn( def _set_ui_auth_session_data_txn(
self, txn: LoggingTransaction, session_id: str, key: str, value: Any self, txn: LoggingTransaction, session_id: str, key: str, value: Any
): ) -> None:
# Get the current value. # Get the current value.
result = cast( result = cast(
Dict[str, Any], Dict[str, Any],
@ -275,7 +277,7 @@ class UIAuthWorkerStore(SQLBaseStore):
session_id: str, session_id: str,
user_agent: str, user_agent: str,
ip: str, ip: str,
): ) -> None:
"""Add the given user agent / IP to the tracking table""" """Add the given user agent / IP to the tracking table"""
await self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
table="ui_auth_sessions_ips", table="ui_auth_sessions_ips",
@ -318,7 +320,7 @@ class UIAuthWorkerStore(SQLBaseStore):
def _delete_old_ui_auth_sessions_txn( def _delete_old_ui_auth_sessions_txn(
self, txn: LoggingTransaction, expiration_time: int self, txn: LoggingTransaction, expiration_time: int
): ) -> None:
# Get the expired sessions. # Get the expired sessions.
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
txn.execute(sql, [expiration_time]) txn.execute(sql, [expiration_time])