Add some type hints to datastore (#12485)

This commit is contained in:
Dirk Klimpel 2022-04-27 14:05:00 +02:00 committed by GitHub
parent 63ba9ba38b
commit b76f1a4d5f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 188 additions and 84 deletions

View file

@ -14,7 +14,17 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
from typing import (
TYPE_CHECKING,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
cast,
)
from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
@ -118,7 +128,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
prefilled_cache=device_outbox_prefill,
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
def process_replication_rows(
self,
stream_name: str,
instance_name: str,
token: int,
rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
) -> None:
if stream_name == ToDeviceStream.NAME:
# If replication is happening than postgres must be being used.
assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
@ -134,7 +150,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
return super().process_replication_rows(stream_name, instance_name, token, rows)
def get_to_device_stream_token(self):
def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token()
async def get_messages_for_user_devices(
@ -301,7 +317,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
if not user_ids_to_query:
return {}, to_stream_id
def get_device_messages_txn(txn: LoggingTransaction):
def get_device_messages_txn(
txn: LoggingTransaction,
) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
# Build a query to select messages from any of the given devices that
# are between the given stream id bounds.
@ -428,7 +446,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": "No changes in cache since last check"})
return 0
def delete_messages_for_device_txn(txn):
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
@ -455,15 +473,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
@trace
async def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit
) -> Tuple[List[dict], int]:
self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
) -> Tuple[List[JsonDict], int]:
"""
Args:
destination(str): The name of the remote server.
last_stream_id(int|long): The last position of the device message stream
destination: The name of the remote server.
last_stream_id: The last position of the device message stream
that the server sent up to.
current_stream_id(int|long): The current position of the device
message stream.
current_stream_id: The current position of the device message stream.
Returns:
A list of messages for the device and where in the stream the messages got to.
"""
@ -485,7 +502,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return [], last_stream_id
@trace
def get_new_messages_for_remote_destination_txn(txn):
def get_new_messages_for_remote_destination_txn(
txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]:
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
@ -527,7 +546,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
up_to_stream_id: Where to delete messages up to.
"""
def delete_messages_for_remote_destination_txn(txn):
def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
@ -566,7 +585,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
if last_id == current_id:
return [], current_id, False
def get_all_new_device_messages_txn(txn):
def get_all_new_device_messages_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
@ -607,8 +628,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
@trace
async def add_messages_to_device_inbox(
self,
local_messages_by_user_then_device: dict,
remote_messages_by_destination: dict,
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
remote_messages_by_destination: Dict[str, JsonDict],
) -> int:
"""Used to send messages from this server.
@ -624,7 +645,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
assert self._can_write_to_device
def add_messages_txn(txn, now_ms, stream_id):
def add_messages_txn(
txn: LoggingTransaction, now_ms: int, stream_id: int
) -> None:
# Add the local messages directly to the local inbox.
self._add_messages_to_local_device_inbox_txn(
txn, stream_id, local_messages_by_user_then_device
@ -677,11 +700,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return self._device_inbox_id_gen.get_current_token()
async def add_messages_from_remote_to_device_inbox(
self, origin: str, message_id: str, local_messages_by_user_then_device: dict
self,
origin: str,
message_id: str,
local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
) -> int:
assert self._can_write_to_device
def add_messages_txn(txn, now_ms, stream_id):
def add_messages_txn(
txn: LoggingTransaction, now_ms: int, stream_id: int
) -> None:
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
@ -727,8 +755,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return stream_id
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
self,
txn: LoggingTransaction,
stream_id: int,
messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
) -> None:
assert self._can_write_to_device
local_by_user_then_device = {}
@ -840,8 +871,10 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self._remove_dead_devices_from_device_inbox,
)
async def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
async def _background_drop_index_device_inbox(
self, progress: JsonDict, batch_size: int
) -> int:
def reindex_txn(conn: LoggingDatabaseConnection) -> None:
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()