Convert additional databases to async/await part 3 (#8201)

This commit is contained in:
Patrick Cloke 2020-09-01 11:04:17 -04:00 committed by GitHub
parent 7d103a594e
commit 37db6252b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 121 additions and 87 deletions

1
changelog.d/8201.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -433,7 +433,7 @@ class BackgroundUpdater(object):
"background_updates", keyvalues={"update_name": update_name}
)
def _background_update_progress(self, update_name: str, progress: dict):
async def _background_update_progress(self, update_name: str, progress: dict):
"""Update the progress of a background update
Args:
@ -441,7 +441,7 @@ class BackgroundUpdater(object):
progress: The progress of the update.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"background_update_progress",
self._background_update_progress_txn,
update_name,

View File

@ -16,9 +16,7 @@
import abc
import logging
from typing import List, Optional, Tuple
from twisted.internet import defer
from typing import Dict, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
@ -58,14 +56,16 @@ class AccountDataWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cached()
def get_account_data_for_user(self, user_id):
async def get_account_data_for_user(
self, user_id: str
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a user.
Args:
user_id(str): The user to get the account_data for.
user_id: The user to get the account_data for.
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
A 2-tuple of a dict of global account_data and a dict mapping from
room_id string to per room account_data dicts.
"""
def get_account_data_for_user_txn(txn):
@ -94,7 +94,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return global_account_data, by_room
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
@ -120,14 +120,16 @@ class AccountDataWorkerStore(SQLBaseStore):
return None
@cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
async def get_account_data_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
"""Get all the client account_data for a user for a room.
Args:
user_id(str): The user to get the account_data for.
room_id(str): The room to get the account_data for.
user_id: The user to get the account_data for.
room_id: The room to get the account_data for.
Returns:
A deferred dict of the room account_data
A dict of the room account_data
"""
def get_account_data_for_room_txn(txn):
@ -142,21 +144,22 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@cached(num_args=3, max_entries=5000)
def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type):
async def get_account_data_for_room_and_type(
self, user_id: str, room_id: str, account_data_type: str
) -> Optional[JsonDict]:
"""Get the client account_data of given type for a user for a room.
Args:
user_id(str): The user to get the account_data for.
room_id(str): The room to get the account_data for.
account_data_type (str): The account data type to get.
user_id: The user to get the account_data for.
room_id: The room to get the account_data for.
account_data_type: The account data type to get.
Returns:
A deferred of the room account_data for that type, or None if
there isn't any set.
The room account_data for that type, or None if there isn't any set.
"""
def get_account_data_for_room_and_type_txn(txn):
@ -174,7 +177,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return db_to_json(content_json) if content_json else None
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@ -238,12 +241,14 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_updated_room_account_data", get_updated_room_account_data_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):
async def get_updated_account_data_for_user(
self, user_id: str, stream_id: int
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
"""Get all the client account_data for a that's changed for a user
Args:
user_id(str): The user to get the account_data for.
stream_id(int): The point in the stream since which to get updates
user_id: The user to get the account_data for.
stream_id: The point in the stream since which to get updates
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
@ -277,9 +282,9 @@ class AccountDataWorkerStore(SQLBaseStore):
user_id, int(stream_id)
)
if not changed:
return defer.succeed(({}, {}))
return ({}, {})
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@ -416,7 +421,7 @@ class AccountDataStore(AccountDataWorkerStore):
return self._account_data_id_gen.get_current_token()
def _update_max_stream_id(self, next_id: int):
async def _update_max_stream_id(self, next_id: int) -> None:
"""Update the max stream_id
Args:
@ -435,4 +440,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
return self.db_pool.runInteraction("update_account_data_max_stream_id", _update)
await self.db_pool.runInteraction("update_account_data_max_stream_id", _update)

View File

@ -34,13 +34,15 @@ if TYPE_CHECKING:
class EndToEndKeyWorkerStore(SQLBaseStore):
def get_e2e_device_keys_for_federation_query(self, user_id: str):
async def get_e2e_device_keys_for_federation_query(
self, user_id: str
) -> Tuple[int, List[JsonDict]]:
"""Get all devices (with any device keys) for a user
Returns:
Deferred which resolves to (stream_id, devices)
(stream_id, devices)
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_e2e_device_keys_for_federation_query",
self._get_e2e_device_keys_for_federation_query_txn,
user_id,
@ -292,10 +294,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
def count_e2e_one_time_keys(self, user_id, device_id):
async def count_e2e_one_time_keys(
self, user_id: str, device_id: str
) -> Dict[str, int]:
""" Count the number of one time keys the server has for a device
Returns:
Dict mapping from algorithm to number of keys for that algorithm.
A mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
@ -310,7 +314,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result[algorithm] = key_count
return result
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
@ -348,7 +352,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
list_name="user_ids",
num_args=1,
)
def _get_bare_e2e_cross_signing_keys_bulk(
async def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
@ -356,16 +360,15 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
the signatures for the calling user need to be fetched.
Args:
user_ids (list[str]): the users whose keys are being requested
user_ids: the users whose keys are being requested
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. If a user's cross-signing keys were not found, either
their user ID will not be in the dict, or their user ID will map
to None.
A mapping from user ID to key type to key data. If a user's cross-signing
keys were not found, either their user ID will not be in the dict, or
their user ID will map to None.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
@ -588,7 +591,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
async def set_e2e_device_keys(
self, user_id: str, device_id: str, time_now: int, device_keys: JsonDict
) -> bool:
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
@ -624,12 +629,21 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"message": "Device keys stored."})
return True
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
) -> Dict[str, Dict[str, Dict[str, bytes]]]:
"""Take a list of one time keys out of the database.
Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
"""
@trace
def _claim_e2e_one_time_keys(txn):
@ -665,11 +679,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
return result
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
def delete_e2e_keys_by_device(self, user_id, device_id):
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
def delete_e2e_keys_by_device_txn(txn):
log_kv(
{
@ -692,7 +706,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@ -93,7 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="mark_local_media_as_safe",
)
def get_url_cache(self, url, ts):
async def get_url_cache(self, url: str, ts: int) -> Optional[Dict[str, Any]]:
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
)
return self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
async def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
@ -237,12 +237,17 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_cached_remote_media",
)
def update_cached_last_access_time(self, local_media, remote_media, time_ms):
async def update_cached_last_access_time(
self,
local_media: Iterable[str],
remote_media: Iterable[Tuple[str, str]],
time_ms: int,
):
"""Updates the last access time of the given media
Args:
local_media (iterable[str]): Set of media_ids
remote_media (iterable[(str, str)]): Set of (server_name, media_id)
local_media: Set of media_ids
remote_media: Set of (server_name, media_id)
time_ms: Current time in milliseconds
"""
@ -267,7 +272,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
)
@ -325,7 +330,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_remote_media_before", self.db_pool.cursor_to_dict, sql, before_ts
)
def delete_remote_media(self, media_origin, media_id):
async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
def delete_remote_media_txn(txn):
self.db_pool.simple_delete_txn(
txn,
@ -338,11 +343,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_remote_media", delete_remote_media_txn
)
def get_expired_url_cache(self, now_ts):
async def get_expired_url_cache(self, now_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository_url_cache"
" WHERE expires_ts < ?"
@ -354,7 +359,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_expired_url_cache", _get_expired_url_cache_txn
)
@ -371,7 +376,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"delete_url_cache", _delete_url_cache_txn
)
def get_url_cache_media_before(self, before_ts):
async def get_url_cache_media_before(self, before_ts: int) -> List[str]:
sql = (
"SELECT media_id FROM local_media_repository"
" WHERE created_ts < ? AND url_cache IS NOT NULL"
@ -383,7 +388,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)

View File

@ -16,9 +16,10 @@
import logging
import re
from collections import namedtuple
from typing import List, Optional
from typing import List, Optional, Set
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@ -620,7 +621,9 @@ class SearchStore(SearchBackgroundUpdateStore):
"count": count,
}
def _find_highlights_in_postgres(self, search_query, events):
async def _find_highlights_in_postgres(
self, search_query: str, events: List[EventBase]
) -> Set[str]:
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
@ -628,11 +631,11 @@ class SearchStore(SearchBackgroundUpdateStore):
highlight the matching parts.
Args:
search_query (str)
events (list): A list of events
search_query
events: A list of events
Returns:
deferred : A set of strings.
A set of strings.
"""
def f(txn):
@ -685,7 +688,7 @@ class SearchStore(SearchBackgroundUpdateStore):
return highlight_words
return self.db_pool.runInteraction("_find_highlights", f)
return await self.db_pool.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):

View File

@ -15,7 +15,7 @@
import logging
import re
from typing import Any, Dict, Optional
from typing import Any, Dict, Iterable, Optional, Tuple
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@ -365,7 +365,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
return False
def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
async def update_profile_in_user_dir(
self, user_id: str, display_name: str, avatar_url: str
) -> None:
"""
Update or add a user's profile in the user directory.
"""
@ -458,17 +460,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
def add_users_who_share_private_room(self, room_id, user_id_tuples):
async def add_users_who_share_private_room(
self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
room_id (str)
user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
room_id
user_id_tuples: iterable of 2-tuple of user IDs.
"""
def _add_users_who_share_room_txn(txn):
@ -484,17 +488,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
def add_users_in_public_rooms(self, room_id, user_ids):
async def add_users_in_public_rooms(
self, room_id: str, user_ids: Iterable[str]
) -> None:
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
Args:
room_id (str)
user_ids (list[str])
room_id
user_ids
"""
def _add_users_in_public_rooms_txn(txn):
@ -508,11 +514,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
value_values=None,
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
def delete_all_from_user_dir(self):
async def delete_all_from_user_dir(self) -> None:
"""Delete the entire user directory
"""
@ -523,7 +529,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
txn.execute("DELETE FROM users_who_share_private_rooms")
txn.call_after(self.get_user_in_directory.invalidate_all)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_all_from_user_dir", _delete_all_from_user_dir_txn
)
@ -555,7 +561,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(UserDirectoryStore, self).__init__(database, db_conn, hs)
def remove_from_user_dir(self, user_id):
async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
self.db_pool.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
@ -578,7 +584,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"remove_from_user_dir", _remove_from_user_dir_txn
)
@ -605,14 +611,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
return user_ids
def remove_user_who_share_room(self, user_id, room_id):
async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
"""
Deletes entries in the users_who_share_*_rooms table. The first
user should be a local user.
Args:
user_id (str)
room_id (str)
user_id
room_id
"""
def _remove_user_who_share_room_txn(txn):
@ -632,7 +638,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
keyvalues={"user_id": user_id, "room_id": room_id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)