From 04faa0bfa960d9f0dc60e9cf4ec270221249b7ca Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 11 Aug 2020 17:21:20 -0400 Subject: [PATCH] Convert tags and metrics databases to async/await (#8062) --- changelog.d/8062.misc | 1 + synapse/storage/databases/main/metrics.py | 20 +--- synapse/storage/databases/main/tags.py | 103 +++++++++--------- .../test_resource_limits_server_notices.py | 5 +- 4 files changed, 64 insertions(+), 65 deletions(-) create mode 100644 changelog.d/8062.misc diff --git a/changelog.d/8062.misc b/changelog.d/8062.misc new file mode 100644 index 000000000..dfe4c0317 --- /dev/null +++ b/changelog.d/8062.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index baa7a5092..686052bd8 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -15,8 +15,6 @@ import typing from collections import Counter -from twisted.internet import defer - from synapse.metrics import BucketCollector from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import SQLBaseStore @@ -69,8 +67,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): res = await self.db_pool.runInteraction("read_forward_extremities", fetch) self._current_forward_extremities_amount = Counter([x[0] for x in res]) - @defer.inlineCallbacks - def count_daily_messages(self): + async def count_daily_messages(self): """ Returns an estimate of the number of messages sent in the last day. @@ -88,11 +85,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db_pool.runInteraction("count_messages", _count_messages) - return ret + return await self.db_pool.runInteraction("count_messages", _count_messages) - @defer.inlineCallbacks - def count_daily_sent_messages(self): + async def count_daily_sent_messages(self): def _count_messages(txn): # This is good enough as if you have silly characters in your own # hostname then thats your own fault. @@ -109,13 +104,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_daily_sent_messages", _count_messages ) - return ret - @defer.inlineCallbacks - def count_daily_active_rooms(self): + async def count_daily_active_rooms(self): def _count(txn): sql = """ SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events @@ -126,5 +119,4 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): (count,) = txn.fetchone() return count - ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count) - return ret + return await self.db_pool.runInteraction("count_daily_active_rooms", _count) diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index eedd2d96c..e4e0a0c43 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -15,14 +15,13 @@ # limitations under the License. import logging -from typing import List, Tuple +from typing import Dict, List, Tuple from canonicaljson import json -from twisted.internet import defer - from synapse.storage._base import db_to_json from synapse.storage.databases.main.account_data import AccountDataWorkerStore +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached logger = logging.getLogger(__name__) @@ -30,30 +29,26 @@ logger = logging.getLogger(__name__) class TagsWorkerStore(AccountDataWorkerStore): @cached() - def get_tags_for_user(self, user_id): + async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]: """Get all the tags for a user. Args: - user_id(str): The user to get the tags for. + user_id: The user to get the tags for. Returns: - A deferred dict mapping from room_id strings to dicts mapping from - tag strings to tag content. + A mapping from room_id strings to dicts mapping from tag strings to + tag content. """ - deferred = self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] ) - @deferred.addCallback - def tags_by_room(rows): - tags_by_room = {} - for row in rows: - room_tags = tags_by_room.setdefault(row["room_id"], {}) - room_tags[row["tag"]] = db_to_json(row["content"]) - return tags_by_room - - return deferred + tags_by_room = {} + for row in rows: + room_tags = tags_by_room.setdefault(row["room_id"], {}) + room_tags[row["tag"]] = db_to_json(row["content"]) + return tags_by_room async def get_all_updated_tags( self, instance_name: str, last_id: int, current_id: int, limit: int @@ -127,17 +122,19 @@ class TagsWorkerStore(AccountDataWorkerStore): return results, upto_token, limited - @defer.inlineCallbacks - def get_updated_tags(self, user_id, stream_id): + async def get_updated_tags( + self, user_id: str, stream_id: int + ) -> Dict[str, List[str]]: """Get all the tags for the rooms where the tags have changed since the given version Args: user_id(str): The user to get the tags for. stream_id(int): The earliest update to get for the user. + Returns: - A deferred dict mapping from room_id strings to lists of tag - strings for all the rooms that changed since the stream_id token. + A mapping from room_id strings to lists of tag strings for all the + rooms that changed since the stream_id token. """ def get_updated_tags_txn(txn): @@ -155,47 +152,53 @@ class TagsWorkerStore(AccountDataWorkerStore): if not changed: return {} - room_ids = yield self.db_pool.runInteraction( + room_ids = await self.db_pool.runInteraction( "get_updated_tags", get_updated_tags_txn ) results = {} if room_ids: - tags_by_room = yield self.get_tags_for_user(user_id) + tags_by_room = await self.get_tags_for_user(user_id) for room_id in room_ids: results[room_id] = tags_by_room.get(room_id, {}) return results - def get_tags_for_room(self, user_id, room_id): + async def get_tags_for_room( + self, user_id: str, room_id: str + ) -> Dict[str, JsonDict]: """Get all the tags for the given room + Args: - user_id(str): The user to get tags for - room_id(str): The room to get tags for + user_id: The user to get tags for + room_id: The room to get tags for + Returns: - A deferred list of string tags. + A mapping of tags to tag content. """ - return self.db_pool.simple_select_list( + rows = await self.db_pool.simple_select_list( table="room_tags", keyvalues={"user_id": user_id, "room_id": room_id}, retcols=("tag", "content"), desc="get_tags_for_room", - ).addCallback( - lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows} ) + return {row["tag"]: db_to_json(row["content"]) for row in rows} class TagsStore(TagsWorkerStore): - @defer.inlineCallbacks - def add_tag_to_room(self, user_id, room_id, tag, content): + async def add_tag_to_room( + self, user_id: str, room_id: str, tag: str, content: JsonDict + ) -> int: """Add a tag to a room for a user. + Args: - user_id(str): The user to add a tag for. - room_id(str): The room to add a tag for. - tag(str): The tag name to add. - content(dict): A json object to associate with the tag. + user_id: The user to add a tag for. + room_id: The room to add a tag for. + tag: The tag name to add. + content: A json object to associate with the tag. + Returns: - A deferred that completes once the tag has been added. + The next account data ID. """ content_json = json.dumps(content) @@ -209,18 +212,17 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) + await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - @defer.inlineCallbacks - def remove_tag_from_room(self, user_id, room_id, tag): + async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int: """Remove a tag from a room for a user. + Returns: - A deferred that completes once the tag has been removed + The next account data ID. """ def remove_tag_txn(txn, next_id): @@ -232,21 +234,22 @@ class TagsStore(TagsWorkerStore): self._update_revision_txn(txn, user_id, room_id, next_id) with self._account_data_id_gen.get_next() as next_id: - yield self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) + await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) self.get_tags_for_user.invalidate((user_id,)) - result = self._account_data_id_gen.get_current_token() - return result + return self._account_data_id_gen.get_current_token() - def _update_revision_txn(self, txn, user_id, room_id, next_id): + def _update_revision_txn( + self, txn, user_id: str, room_id: str, next_id: int + ) -> None: """Update the latest revision of the tags for the given user and room. Args: txn: The database cursor - user_id(str): The ID of the user. - room_id(str): The ID of the room. - next_id(int): The the revision to advance to. + user_id: The ID of the user. + room_id: The ID of the room. + next_id: The the revision to advance to. """ txn.call_after( diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 3f88abe3d..2858d1355 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -27,6 +27,7 @@ from synapse.server_notices.resource_limits_server_notices import ( ) from tests import unittest +from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import default_config @@ -79,7 +80,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return_value=defer.succeed("!something:localhost") ) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) - self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({})) + self._rlsn._store.get_tags_for_room = Mock( + side_effect=lambda user_id, room_id: make_awaitable({}) + ) @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self):