Convert account data, device inbox, and censor events databases to async/await (#8063)

This commit is contained in:
Patrick Cloke 2020-08-12 09:29:06 -04:00 committed by GitHub
parent a3a59bab7b
commit d68e10f308
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 99 additions and 87 deletions

View file

@ -16,8 +16,6 @@
import logging
from typing import List, Tuple
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
@ -31,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
def get_new_messages_for_device(
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
):
async def get_new_messages_for_device(
self,
user_id: str,
device_id: str,
last_stream_id: int,
current_stream_id: int,
limit: int = 100,
) -> Tuple[List[dict], int]:
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
current_stream_id(int): The current position of the to device
user_id: The recipient user_id.
device_id: The recipient device_id.
last_stream_id: The last stream ID checked.
current_stream_id: The current position of the to device
message stream.
limit: The maximum number of messages to retrieve.
Returns:
Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to.
A list of messages for the device and where in the stream the messages got to.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
return defer.succeed(([], current_stream_id))
return ([], current_stream_id)
def get_new_messages_for_device_txn(txn):
sql = (
@ -69,20 +74,22 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id
return messages, stream_pos
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn
)
@trace
@defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
async def delete_messages_for_device(
self, user_id: str, device_id: str, up_to_stream_id: int
) -> int:
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
up_to_stream_id(int): Where to delete messages up to.
user_id: The recipient user_id.
device_id: The recipient device_id.
up_to_stream_id: Where to delete messages up to.
Returns:
A deferred that resolves to the number of messages deleted.
The number of messages deleted.
"""
# If we have cached the last stream id we've deleted up to, we can
# check if there is likely to be anything that needs deleting
@ -109,7 +116,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
count = yield self.db_pool.runInteraction(
count = await self.db_pool.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
@ -128,9 +135,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return count
@trace
def get_new_device_msgs_for_remote(
async def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit
):
) -> Tuple[List[dict], int]:
"""
Args:
destination(str): The name of the remote server.
@ -139,8 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
current_stream_id(int|long): The current position of the device
message stream.
Returns:
Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to.
A list of messages for the device and where in the stream the messages got to.
"""
set_tag("destination", destination)
@ -153,11 +159,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
if not has_changed or last_stream_id == current_stream_id:
log_kv({"message": "No new messages in stream"})
return defer.succeed(([], current_stream_id))
return ([], current_stream_id)
if limit <= 0:
# This can happen if we run out of room for EDUs in the transaction.
return defer.succeed(([], last_stream_id))
return ([], last_stream_id)
@trace
def get_new_messages_for_remote_destination_txn(txn):
@ -178,7 +184,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id
return messages, stream_pos
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
@ -290,16 +296,15 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
async def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.db_pool.runWithConnection(reindex_txn)
await self.db_pool.runWithConnection(reindex_txn)
yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1
@ -320,21 +325,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
)
@trace
@defer.inlineCallbacks
def add_messages_to_device_inbox(
self, local_messages_by_user_then_device, remote_messages_by_destination
):
async def add_messages_to_device_inbox(
self,
local_messages_by_user_then_device: dict,
remote_messages_by_destination: dict,
) -> int:
"""Used to send messages from this server.
Args:
sender_user_id(str): The ID of the user sending these messages.
local_messages_by_user_and_device(dict):
local_messages_by_user_and_device:
Dictionary of user_id to device_id to message.
remote_messages_by_destination(dict):
remote_messages_by_destination:
Dictionary of destination server_name to the EDU JSON to send.
Returns:
A deferred stream_id that resolves when the messages have been
inserted.
The new stream_id.
"""
def add_messages_txn(txn, now_ms, stream_id):
@ -359,7 +364,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
for user_id in local_messages_by_user_then_device.keys():
@ -371,10 +376,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
return self._device_inbox_id_gen.get_current_token()
@defer.inlineCallbacks
def add_messages_from_remote_to_device_inbox(
self, origin, message_id, local_messages_by_user_then_device
):
async def add_messages_from_remote_to_device_inbox(
self, origin: str, message_id: str, local_messages_by_user_then_device: dict
) -> int:
def add_messages_txn(txn, now_ms, stream_id):
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
@ -409,7 +413,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
now_ms,