Convert calls of async database methods to async (#8166)

This commit is contained in:
Patrick Cloke 2020-08-27 13:38:41 -04:00 committed by GitHub
parent c9fa696ea2
commit 9b7ac03af3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 114 additions and 84 deletions

1
changelog.d/8166.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -21,7 +21,9 @@ These actions are mostly only used by the :py:mod:`.replication` module.
import logging import logging
from synapse.federation.units import Transaction
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.types import JsonDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,15 +51,15 @@ class TransactionActions(object):
return self.store.get_received_txn_response(transaction.transaction_id, origin) return self.store.get_received_txn_response(transaction.transaction_id, origin)
@log_function @log_function
def set_response(self, origin, transaction, code, response): async def set_response(
self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None:
""" Persist how we responded to a transaction. """ Persist how we responded to a transaction.
Returns:
Deferred
""" """
if not transaction.transaction_id: transaction_id = transaction.transaction_id # type: ignore
if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id") raise RuntimeError("Cannot persist a transaction with no transaction_id")
return self.store.set_received_txn_response( await self.store.set_received_txn_response(
transaction.transaction_id, origin, code, response transaction_id, origin, code, response
) )

View File

@ -107,9 +107,7 @@ class Transaction(JsonEncodedObject):
if "edus" in kwargs and not kwargs["edus"]: if "edus" in kwargs and not kwargs["edus"]:
del kwargs["edus"] del kwargs["edus"]
super(Transaction, self).__init__( super().__init__(transaction_id=transaction_id, pdus=pdus, **kwargs)
transaction_id=transaction_id, pdus=pdus, **kwargs
)
@staticmethod @staticmethod
def create_new(pdus, **kwargs): def create_new(pdus, **kwargs):

View File

@ -161,16 +161,14 @@ class ApplicationServiceTransactionWorkerStore(
return result.get("state") return result.get("state")
return None return None
def set_appservice_state(self, service, state): async def set_appservice_state(self, service, state) -> None:
"""Set the application service state. """Set the application service state.
Args: Args:
service(ApplicationService): The service whose state to set. service(ApplicationService): The service whose state to set.
state(ApplicationServiceState): The connectivity state to apply. state(ApplicationServiceState): The connectivity state to apply.
Returns:
An Awaitable which resolves when the state was set successfully.
""" """
return self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
"application_services_state", {"as_id": service.id}, {"state": state} "application_services_state", {"as_id": service.id}, {"state": state}
) )

View File

@ -716,11 +716,11 @@ class DeviceWorkerStore(SQLBaseStore):
return {row["user_id"] for row in rows} return {row["user_id"] for row in rows}
def mark_remote_user_device_cache_as_stale(self, user_id: str): async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
"""Records that the server has reason to believe the cache of the devices """Records that the server has reason to believe the cache of the devices
for the remote users is out of date. for the remote users is out of date.
""" """
return self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
table="device_lists_remote_resync", table="device_lists_remote_resync",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
values={}, values={},

View File

@ -742,7 +742,13 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_room_from_summary", desc="remove_room_from_summary",
) )
def upsert_group_category(self, group_id, category_id, profile, is_public): async def upsert_group_category(
self,
group_id: str,
category_id: str,
profile: Optional[JsonDict],
is_public: Optional[bool],
) -> None:
"""Add/update room category for group """Add/update room category for group
""" """
insertion_values = {} insertion_values = {}
@ -758,7 +764,7 @@ class GroupServerStore(GroupServerWorkerStore):
else: else:
update_values["is_public"] = is_public update_values["is_public"] = is_public
return self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
table="group_room_categories", table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id}, keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values, values=update_values,
@ -773,7 +779,13 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_group_category", desc="remove_group_category",
) )
def upsert_group_role(self, group_id, role_id, profile, is_public): async def upsert_group_role(
self,
group_id: str,
role_id: str,
profile: Optional[JsonDict],
is_public: Optional[bool],
) -> None:
"""Add/remove user role """Add/remove user role
""" """
insertion_values = {} insertion_values = {}
@ -789,7 +801,7 @@ class GroupServerStore(GroupServerWorkerStore):
else: else:
update_values["is_public"] = is_public update_values["is_public"] = is_public
return self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
table="group_roles", table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id}, keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values, values=update_values,
@ -938,10 +950,10 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_user_from_summary", desc="remove_user_from_summary",
) )
def add_group_invite(self, group_id, user_id): async def add_group_invite(self, group_id: str, user_id: str) -> None:
"""Record that the group server has invited a user """Record that the group server has invited a user
""" """
return self.db_pool.simple_insert( await self.db_pool.simple_insert(
table="group_invites", table="group_invites",
values={"group_id": group_id, "user_id": user_id}, values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite", desc="add_group_invite",
@ -1044,8 +1056,10 @@ class GroupServerStore(GroupServerWorkerStore):
"remove_user_from_group", _remove_user_from_group_txn "remove_user_from_group", _remove_user_from_group_txn
) )
def add_room_to_group(self, group_id, room_id, is_public): async def add_room_to_group(
return self.db_pool.simple_insert( self, group_id: str, room_id: str, is_public: bool
) -> None:
await self.db_pool.simple_insert(
table="group_rooms", table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public}, values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group", desc="add_room_to_group",

View File

@ -140,22 +140,28 @@ class KeyStore(SQLBaseStore):
for i in invalidations: for i in invalidations:
invalidate((i,)) invalidate((i,))
def store_server_keys_json( async def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes self,
): server_name: str,
key_id: str,
from_server: str,
ts_now_ms: int,
ts_expires_ms: int,
key_json_bytes: bytes,
) -> None:
"""Stores the JSON bytes for a set of keys from a server """Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the server, and by this server. Updates the value for the
(server_name, key_id, from_server) triplet if one already existed. (server_name, key_id, from_server) triplet if one already existed.
Args: Args:
server_name (str): The name of the server. server_name: The name of the server.
key_id (str): The identifer of the key this JSON is for. key_id: The identifer of the key this JSON is for.
from_server (str): The server this JSON was fetched from. from_server: The server this JSON was fetched from.
ts_now_ms (int): The time now in milliseconds. ts_now_ms: The time now in milliseconds.
ts_valid_until_ms (int): The time when this json stops being valid. ts_valid_until_ms: The time when this json stops being valid.
key_json (bytes): The encoded JSON. key_json_bytes: The encoded JSON.
""" """
return self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
table="server_keys_json", table="server_keys_json",
keyvalues={ keyvalues={
"server_name": server_name, "server_name": server_name,

View File

@ -60,7 +60,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_local_media", desc="get_local_media",
) )
def store_local_media( async def store_local_media(
self, self,
media_id, media_id,
media_type, media_type,
@ -69,8 +69,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
media_length, media_length,
user_id, user_id,
url_cache=None, url_cache=None,
): ) -> None:
return self.db_pool.simple_insert( await self.db_pool.simple_insert(
"local_media_repository", "local_media_repository",
{ {
"media_id": media_id, "media_id": media_id,
@ -141,10 +141,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn) return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
def store_url_cache( async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts self, url, response_code, etag, expires_ts, og, media_id, download_ts
): ):
return self.db_pool.simple_insert( await self.db_pool.simple_insert(
"local_media_repository_url_cache", "local_media_repository_url_cache",
{ {
"url": url, "url": url,
@ -172,7 +172,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_local_media_thumbnails", desc="get_local_media_thumbnails",
) )
def store_local_thumbnail( async def store_local_thumbnail(
self, self,
media_id, media_id,
thumbnail_width, thumbnail_width,
@ -181,7 +181,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method, thumbnail_method,
thumbnail_length, thumbnail_length,
): ):
return self.db_pool.simple_insert( await self.db_pool.simple_insert(
"local_media_repository_thumbnails", "local_media_repository_thumbnails",
{ {
"media_id": media_id, "media_id": media_id,
@ -212,7 +212,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_cached_remote_media", desc="get_cached_remote_media",
) )
def store_cached_remote_media( async def store_cached_remote_media(
self, self,
origin, origin,
media_id, media_id,
@ -222,7 +222,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
upload_name, upload_name,
filesystem_id, filesystem_id,
): ):
return self.db_pool.simple_insert( await self.db_pool.simple_insert(
"remote_media_cache", "remote_media_cache",
{ {
"media_origin": origin, "media_origin": origin,
@ -288,7 +288,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="get_remote_media_thumbnails", desc="get_remote_media_thumbnails",
) )
def store_remote_media_thumbnail( async def store_remote_media_thumbnail(
self, self,
origin, origin,
media_id, media_id,
@ -299,7 +299,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method, thumbnail_method,
thumbnail_length, thumbnail_length,
): ):
return self.db_pool.simple_insert( await self.db_pool.simple_insert(
"remote_media_cache_thumbnails", "remote_media_cache_thumbnails",
{ {
"media_origin": origin, "media_origin": origin,

View File

@ -2,8 +2,10 @@ from synapse.storage._base import SQLBaseStore
class OpenIdStore(SQLBaseStore): class OpenIdStore(SQLBaseStore):
def insert_open_id_token(self, token, ts_valid_until_ms, user_id): async def insert_open_id_token(
return self.db_pool.simple_insert( self, token: str, ts_valid_until_ms: int, user_id: str
) -> None:
await self.db_pool.simple_insert(
table="open_id_tokens", table="open_id_tokens",
values={ values={
"token": token, "token": token,

View File

@ -66,8 +66,8 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache", desc="get_from_remote_profile_cache",
) )
def create_profile(self, user_localpart): async def create_profile(self, user_localpart: str) -> None:
return self.db_pool.simple_insert( await self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile" table="profiles", values={"user_id": user_localpart}, desc="create_profile"
) )
@ -93,13 +93,15 @@ class ProfileWorkerStore(SQLBaseStore):
class ProfileStore(ProfileWorkerStore): class ProfileStore(ProfileWorkerStore):
def add_remote_profile_cache(self, user_id, displayname, avatar_url): async def add_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> None:
"""Ensure we are caching the remote user's profiles. """Ensure we are caching the remote user's profiles.
This should only be called when `is_subscribed_remote_profile_for_user` This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user. would return true for the user.
""" """
return self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
table="remote_profile_cache", table="remote_profile_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
values={ values={

View File

@ -17,7 +17,7 @@
import logging import logging
import re import re
from typing import Any, Awaitable, Dict, List, Optional from typing import Any, Dict, List, Optional
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@ -549,23 +549,22 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="user_delete_threepids", desc="user_delete_threepids",
) )
def add_user_bound_threepid(self, user_id, medium, address, id_server): async def add_user_bound_threepid(
self, user_id: str, medium: str, address: str, id_server: str
):
"""The server proxied a bind request to the given identity server on """The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid. asks us to unbind the threepid.
Args: Args:
user_id (str) user_id
medium (str) medium
address (str) address
id_server (str) id_server
Returns:
Awaitable
""" """
# We need to use an upsert, in case they user had already bound the # We need to use an upsert, in case they user had already bound the
# threepid # threepid
return self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
table="user_threepid_id_server", table="user_threepid_id_server",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
@ -1083,9 +1082,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
def record_user_external_id( async def record_user_external_id(
self, auth_provider: str, external_id: str, user_id: str self, auth_provider: str, external_id: str, user_id: str
) -> Awaitable: ) -> None:
"""Record a mapping from an external user id to a mxid """Record a mapping from an external user id to a mxid
Args: Args:
@ -1093,7 +1092,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
external_id: id on that system external_id: id on that system
user_id: complete mxid that it is mapped to user_id: complete mxid that it is mapped to
""" """
return self.db_pool.simple_insert( await self.db_pool.simple_insert(
table="user_external_ids", table="user_external_ids",
values={ values={
"auth_provider": auth_provider, "auth_provider": auth_provider,
@ -1237,12 +1236,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
return res if res else False return res if res else False
def add_user_pending_deactivation(self, user_id): async def add_user_pending_deactivation(self, user_id: str) -> None:
""" """
Adds a user to the table of users who need to be parted from all the rooms they're Adds a user to the table of users who need to be parted from all the rooms they're
in in
""" """
return self.db_pool.simple_insert( await self.db_pool.simple_insert(
"users_pending_deactivation", "users_pending_deactivation",
values={"user_id": user_id}, values={"user_id": user_id},
desc="add_user_pending_deactivation", desc="add_user_pending_deactivation",

View File

@ -27,7 +27,7 @@ from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchStore from synapse.storage.databases.main.search import SearchStore
from synapse.types import ThirdPartyInstanceID from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@ -1296,11 +1296,17 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return self.db_pool.runInteraction("get_rooms", f) return self.db_pool.runInteraction("get_rooms", f)
def add_event_report( async def add_event_report(
self, room_id, event_id, user_id, reason, content, received_ts self,
): room_id: str,
event_id: str,
user_id: str,
reason: str,
content: JsonDict,
received_ts: int,
) -> None:
next_id = self._event_reports_id_gen.get_next() next_id = self._event_reports_id_gen.get_next()
return self.db_pool.simple_insert( await self.db_pool.simple_insert(
table="event_reports", table="event_reports",
values={ values={
"id": next_id, "id": next_id,

View File

@ -16,7 +16,7 @@
import logging import logging
from itertools import chain from itertools import chain
from typing import Tuple from typing import Any, Dict, Tuple
from twisted.internet.defer import DeferredLock from twisted.internet.defer import DeferredLock
@ -222,11 +222,11 @@ class StatsStore(StateDeltasStore):
desc="stats_incremental_position", desc="stats_incremental_position",
) )
def update_room_state(self, room_id, fields): async def update_room_state(self, room_id: str, fields: Dict[str, Any]) -> None:
""" """
Args: Args:
room_id (str) room_id
fields (dict[str:Any]) fields
""" """
# For whatever reason some of the fields may contain null bytes, which # For whatever reason some of the fields may contain null bytes, which
@ -244,7 +244,7 @@ class StatsStore(StateDeltasStore):
if field and "\0" in field: if field and "\0" in field:
fields[col] = None fields[col] = None
return self.db_pool.simple_upsert( await self.db_pool.simple_upsert(
table="room_stats_state", table="room_stats_state",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
values=fields, values=fields,

View File

@ -21,6 +21,7 @@ from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.types import JsonDict
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
db_binary_type = memoryview db_binary_type = memoryview
@ -98,20 +99,21 @@ class TransactionStore(SQLBaseStore):
else: else:
return None return None
def set_received_txn_response(self, transaction_id, origin, code, response_dict): async def set_received_txn_response(
"""Persist the response we returened for an incoming transaction, and self, transaction_id: str, origin: str, code: int, response_dict: JsonDict
) -> None:
"""Persist the response we returned for an incoming transaction, and
should return for subsequent transactions with the same transaction_id should return for subsequent transactions with the same transaction_id
and origin. and origin.
Args: Args:
txn transaction_id: The incoming transaction ID.
transaction_id (str) origin: The origin server.
origin (str) code: The response code.
code (int) response_dict: The response, to be encoded into JSON.
response_json (str)
""" """
return self.db_pool.simple_insert( await self.db_pool.simple_insert(
table="received_transactions", table="received_transactions",
values={ values={
"transaction_id": transaction_id, "transaction_id": transaction_id,