Convert tags and metrics databases to async/await (#8062)

This commit is contained in:
Patrick Cloke 2020-08-11 17:21:20 -04:00 committed by GitHub
parent a0acdfa9e9
commit 04faa0bfa9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 64 additions and 65 deletions

1
changelog.d/8062.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -15,8 +15,6 @@
import typing import typing
from collections import Counter from collections import Counter
from twisted.internet import defer
from synapse.metrics import BucketCollector from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
@ -69,8 +67,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
res = await self.db_pool.runInteraction("read_forward_extremities", fetch) res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
self._current_forward_extremities_amount = Counter([x[0] for x in res]) self._current_forward_extremities_amount = Counter([x[0] for x in res])
@defer.inlineCallbacks async def count_daily_messages(self):
def count_daily_messages(self):
""" """
Returns an estimate of the number of messages sent in the last day. Returns an estimate of the number of messages sent in the last day.
@ -88,11 +85,9 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone() (count,) = txn.fetchone()
return count return count
ret = yield self.db_pool.runInteraction("count_messages", _count_messages) return await self.db_pool.runInteraction("count_messages", _count_messages)
return ret
@defer.inlineCallbacks async def count_daily_sent_messages(self):
def count_daily_sent_messages(self):
def _count_messages(txn): def _count_messages(txn):
# This is good enough as if you have silly characters in your own # This is good enough as if you have silly characters in your own
# hostname then thats your own fault. # hostname then thats your own fault.
@ -109,13 +104,11 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone() (count,) = txn.fetchone()
return count return count
ret = yield self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"count_daily_sent_messages", _count_messages "count_daily_sent_messages", _count_messages
) )
return ret
@defer.inlineCallbacks async def count_daily_active_rooms(self):
def count_daily_active_rooms(self):
def _count(txn): def _count(txn):
sql = """ sql = """
SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
@ -126,5 +119,4 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
(count,) = txn.fetchone() (count,) = txn.fetchone()
return count return count
ret = yield self.db_pool.runInteraction("count_daily_active_rooms", _count) return await self.db_pool.runInteraction("count_daily_active_rooms", _count)
return ret

View File

@ -15,14 +15,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Tuple from typing import Dict, List, Tuple
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import db_to_json from synapse.storage._base import db_to_json
from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,30 +29,26 @@ logger = logging.getLogger(__name__)
class TagsWorkerStore(AccountDataWorkerStore): class TagsWorkerStore(AccountDataWorkerStore):
@cached() @cached()
def get_tags_for_user(self, user_id): async def get_tags_for_user(self, user_id: str) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for a user. """Get all the tags for a user.
Args: Args:
user_id(str): The user to get the tags for. user_id: The user to get the tags for.
Returns: Returns:
A deferred dict mapping from room_id strings to dicts mapping from A mapping from room_id strings to dicts mapping from tag strings to
tag strings to tag content. tag content.
""" """
deferred = self.db_pool.simple_select_list( rows = await self.db_pool.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"] "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
) )
@deferred.addCallback tags_by_room = {}
def tags_by_room(rows): for row in rows:
tags_by_room = {} room_tags = tags_by_room.setdefault(row["room_id"], {})
for row in rows: room_tags[row["tag"]] = db_to_json(row["content"])
room_tags = tags_by_room.setdefault(row["room_id"], {}) return tags_by_room
room_tags[row["tag"]] = db_to_json(row["content"])
return tags_by_room
return deferred
async def get_all_updated_tags( async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int self, instance_name: str, last_id: int, current_id: int, limit: int
@ -127,17 +122,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
return results, upto_token, limited return results, upto_token, limited
@defer.inlineCallbacks async def get_updated_tags(
def get_updated_tags(self, user_id, stream_id): self, user_id: str, stream_id: int
) -> Dict[str, List[str]]:
"""Get all the tags for the rooms where the tags have changed since the """Get all the tags for the rooms where the tags have changed since the
given version given version
Args: Args:
user_id(str): The user to get the tags for. user_id(str): The user to get the tags for.
stream_id(int): The earliest update to get for the user. stream_id(int): The earliest update to get for the user.
Returns: Returns:
A deferred dict mapping from room_id strings to lists of tag A mapping from room_id strings to lists of tag strings for all the
strings for all the rooms that changed since the stream_id token. rooms that changed since the stream_id token.
""" """
def get_updated_tags_txn(txn): def get_updated_tags_txn(txn):
@ -155,47 +152,53 @@ class TagsWorkerStore(AccountDataWorkerStore):
if not changed: if not changed:
return {} return {}
room_ids = yield self.db_pool.runInteraction( room_ids = await self.db_pool.runInteraction(
"get_updated_tags", get_updated_tags_txn "get_updated_tags", get_updated_tags_txn
) )
results = {} results = {}
if room_ids: if room_ids:
tags_by_room = yield self.get_tags_for_user(user_id) tags_by_room = await self.get_tags_for_user(user_id)
for room_id in room_ids: for room_id in room_ids:
results[room_id] = tags_by_room.get(room_id, {}) results[room_id] = tags_by_room.get(room_id, {})
return results return results
def get_tags_for_room(self, user_id, room_id): async def get_tags_for_room(
self, user_id: str, room_id: str
) -> Dict[str, JsonDict]:
"""Get all the tags for the given room """Get all the tags for the given room
Args: Args:
user_id(str): The user to get tags for user_id: The user to get tags for
room_id(str): The room to get tags for room_id: The room to get tags for
Returns: Returns:
A deferred list of string tags. A mapping of tags to tag content.
""" """
return self.db_pool.simple_select_list( rows = await self.db_pool.simple_select_list(
table="room_tags", table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id}, keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"), retcols=("tag", "content"),
desc="get_tags_for_room", desc="get_tags_for_room",
).addCallback(
lambda rows: {row["tag"]: db_to_json(row["content"]) for row in rows}
) )
return {row["tag"]: db_to_json(row["content"]) for row in rows}
class TagsStore(TagsWorkerStore): class TagsStore(TagsWorkerStore):
@defer.inlineCallbacks async def add_tag_to_room(
def add_tag_to_room(self, user_id, room_id, tag, content): self, user_id: str, room_id: str, tag: str, content: JsonDict
) -> int:
"""Add a tag to a room for a user. """Add a tag to a room for a user.
Args: Args:
user_id(str): The user to add a tag for. user_id: The user to add a tag for.
room_id(str): The room to add a tag for. room_id: The room to add a tag for.
tag(str): The tag name to add. tag: The tag name to add.
content(dict): A json object to associate with the tag. content: A json object to associate with the tag.
Returns: Returns:
A deferred that completes once the tag has been added. The next account data ID.
""" """
content_json = json.dumps(content) content_json = json.dumps(content)
@ -209,18 +212,17 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id: with self._account_data_id_gen.get_next() as next_id:
yield self.db_pool.runInteraction("add_tag", add_tag_txn, next_id) await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_current_token() return self._account_data_id_gen.get_current_token()
return result
@defer.inlineCallbacks async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
def remove_tag_from_room(self, user_id, room_id, tag):
"""Remove a tag from a room for a user. """Remove a tag from a room for a user.
Returns: Returns:
A deferred that completes once the tag has been removed The next account data ID.
""" """
def remove_tag_txn(txn, next_id): def remove_tag_txn(txn, next_id):
@ -232,21 +234,22 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id: with self._account_data_id_gen.get_next() as next_id:
yield self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id) await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_current_token() return self._account_data_id_gen.get_current_token()
return result
def _update_revision_txn(self, txn, user_id, room_id, next_id): def _update_revision_txn(
self, txn, user_id: str, room_id: str, next_id: int
) -> None:
"""Update the latest revision of the tags for the given user and room. """Update the latest revision of the tags for the given user and room.
Args: Args:
txn: The database cursor txn: The database cursor
user_id(str): The ID of the user. user_id: The ID of the user.
room_id(str): The ID of the room. room_id: The ID of the room.
next_id(int): The the revision to advance to. next_id: The the revision to advance to.
""" """
txn.call_after( txn.call_after(

View File

@ -27,6 +27,7 @@ from synapse.server_notices.resource_limits_server_notices import (
) )
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config from tests.unittest import override_config
from tests.utils import default_config from tests.utils import default_config
@ -79,7 +80,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value=defer.succeed("!something:localhost") return_value=defer.succeed("!something:localhost")
) )
self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=defer.succeed({})) self._rlsn._store.get_tags_for_room = Mock(
side_effect=lambda user_id, room_id: make_awaitable({})
)
@override_config({"hs_disabled": True}) @override_config({"hs_disabled": True})
def test_maybe_send_server_notice_disabled_hs(self): def test_maybe_send_server_notice_disabled_hs(self):