mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-10-01 11:49:51 -04:00
Convert misc database code to async (#8087)
This commit is contained in:
parent
7bdf9828d5
commit
894dae74fe
1
changelog.d/8087.misc
Normal file
1
changelog.d/8087.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert various parts of the codebase to async/await.
|
@ -18,8 +18,6 @@ from typing import Optional
|
|||||||
|
|
||||||
from canonicaljson import json
|
from canonicaljson import json
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
|
||||||
from . import engines
|
from . import engines
|
||||||
@ -308,9 +306,8 @@ class BackgroundUpdater(object):
|
|||||||
update_name (str): Name of update
|
update_name (str): Name of update
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def noop_update(progress, batch_size):
|
||||||
def noop_update(progress, batch_size):
|
await self._end_background_update(update_name)
|
||||||
yield self._end_background_update(update_name)
|
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
self.register_background_update_handler(update_name, noop_update)
|
self.register_background_update_handler(update_name, noop_update)
|
||||||
@ -409,12 +406,11 @@ class BackgroundUpdater(object):
|
|||||||
else:
|
else:
|
||||||
runner = create_index_sqlite
|
runner = create_index_sqlite
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def updater(progress, batch_size):
|
||||||
def updater(progress, batch_size):
|
|
||||||
if runner is not None:
|
if runner is not None:
|
||||||
logger.info("Adding index %s to %s", index_name, table)
|
logger.info("Adding index %s to %s", index_name, table)
|
||||||
yield self.db_pool.runWithConnection(runner)
|
await self.db_pool.runWithConnection(runner)
|
||||||
yield self._end_background_update(update_name)
|
await self._end_background_update(update_name)
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
self.register_background_update_handler(update_name, updater)
|
self.register_background_update_handler(update_name, updater)
|
||||||
|
@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
@cachedList(
|
@cachedList(
|
||||||
cached_method_name="get_device_list_last_stream_id_for_remote",
|
cached_method_name="get_device_list_last_stream_id_for_remote",
|
||||||
list_name="user_ids",
|
list_name="user_ids",
|
||||||
inlineCallbacks=True,
|
|
||||||
)
|
)
|
||||||
def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
|
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
|
||||||
rows = yield self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="device_lists_remote_extremeties",
|
table="device_lists_remote_extremeties",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
|
@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
|||||||
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
|
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
|||||||
self._rotate_delay = 3
|
self._rotate_delay = 3
|
||||||
self._rotate_count = 10000
|
self._rotate_count = 10000
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
|
@cached(num_args=3, tree=True, max_entries=5000)
|
||||||
def get_unread_event_push_actions_by_room_for_user(
|
async def get_unread_event_push_actions_by_room_for_user(
|
||||||
self, room_id, user_id, last_read_event_id
|
self, room_id, user_id, last_read_event_id
|
||||||
):
|
):
|
||||||
ret = yield self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_unread_event_push_actions_by_room",
|
"get_unread_event_push_actions_by_room",
|
||||||
self._get_unread_counts_by_receipt_txn,
|
self._get_unread_counts_by_receipt_txn,
|
||||||
room_id,
|
room_id,
|
||||||
user_id,
|
user_id,
|
||||||
last_read_event_id,
|
last_read_event_id,
|
||||||
)
|
)
|
||||||
return ret
|
|
||||||
|
|
||||||
def _get_unread_counts_by_receipt_txn(
|
def _get_unread_counts_by_receipt_txn(
|
||||||
self, txn, room_id, user_id, last_read_event_id
|
self, txn, room_id, user_id, last_read_event_id
|
||||||
|
@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(
|
@cachedList(
|
||||||
cached_method_name="_get_presence_for_user",
|
cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
|
||||||
list_name="user_ids",
|
|
||||||
num_args=1,
|
|
||||||
inlineCallbacks=True,
|
|
||||||
)
|
)
|
||||||
def get_presence_for_users(self, user_ids):
|
async def get_presence_for_users(self, user_ids):
|
||||||
rows = yield self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="presence_stream",
|
table="presence_stream",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
|
@ -170,18 +170,15 @@ class PushRulesWorkerStore(
|
|||||||
)
|
)
|
||||||
|
|
||||||
@cachedList(
|
@cachedList(
|
||||||
cached_method_name="get_push_rules_for_user",
|
cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
|
||||||
list_name="user_ids",
|
|
||||||
num_args=1,
|
|
||||||
inlineCallbacks=True,
|
|
||||||
)
|
)
|
||||||
def bulk_get_push_rules(self, user_ids):
|
async def bulk_get_push_rules(self, user_ids):
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
results = {user_id: [] for user_id in user_ids}
|
results = {user_id: [] for user_id in user_ids}
|
||||||
|
|
||||||
rows = yield self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="push_rules",
|
table="push_rules",
|
||||||
column="user_name",
|
column="user_name",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
@ -194,7 +191,7 @@ class PushRulesWorkerStore(
|
|||||||
for row in rows:
|
for row in rows:
|
||||||
results.setdefault(row["user_name"], []).append(row)
|
results.setdefault(row["user_name"], []).append(row)
|
||||||
|
|
||||||
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
|
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
|
||||||
|
|
||||||
for user_id, rules in results.items():
|
for user_id, rules in results.items():
|
||||||
use_new_defaults = user_id in self._users_new_default_push_rules
|
use_new_defaults = user_id in self._users_new_default_push_rules
|
||||||
@ -260,15 +257,14 @@ class PushRulesWorkerStore(
|
|||||||
cached_method_name="get_push_rules_enabled_for_user",
|
cached_method_name="get_push_rules_enabled_for_user",
|
||||||
list_name="user_ids",
|
list_name="user_ids",
|
||||||
num_args=1,
|
num_args=1,
|
||||||
inlineCallbacks=True,
|
|
||||||
)
|
)
|
||||||
def bulk_get_push_rules_enabled(self, user_ids):
|
async def bulk_get_push_rules_enabled(self, user_ids):
|
||||||
if not user_ids:
|
if not user_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
results = {user_id: {} for user_id in user_ids}
|
results = {user_id: {} for user_id in user_ids}
|
||||||
|
|
||||||
rows = yield self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="push_rules_enable",
|
table="push_rules_enable",
|
||||||
column="user_name",
|
column="user_name",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
|
@ -170,13 +170,10 @@ class PusherWorkerStore(SQLBaseStore):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(
|
@cachedList(
|
||||||
cached_method_name="get_if_user_has_pusher",
|
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
|
||||||
list_name="user_ids",
|
|
||||||
num_args=1,
|
|
||||||
inlineCallbacks=True,
|
|
||||||
)
|
)
|
||||||
def get_if_users_have_pushers(self, user_ids):
|
async def get_if_users_have_pushers(self, user_ids):
|
||||||
rows = yield self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="pushers",
|
table="pushers",
|
||||||
column="user_name",
|
column="user_name",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
|
@ -212,9 +212,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||||||
cached_method_name="_get_linearized_receipts_for_room",
|
cached_method_name="_get_linearized_receipts_for_room",
|
||||||
list_name="room_ids",
|
list_name="room_ids",
|
||||||
num_args=3,
|
num_args=3,
|
||||||
inlineCallbacks=True,
|
|
||||||
)
|
)
|
||||||
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||||
if not room_ids:
|
if not room_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@ -243,7 +242,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return self.db_pool.cursor_to_dict(txn)
|
return self.db_pool.cursor_to_dict(txn)
|
||||||
|
|
||||||
txn_results = yield self.db_pool.runInteraction(
|
txn_results = await self.db_pool.runInteraction(
|
||||||
"_get_linearized_receipts_for_rooms", f
|
"_get_linearized_receipts_for_rooms", f
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,8 +17,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
|
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
lambda: self._known_servers_count,
|
lambda: self._known_servers_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _count_known_servers(self):
|
||||||
def _count_known_servers(self):
|
|
||||||
"""
|
"""
|
||||||
Count the servers that this server knows about.
|
Count the servers that this server knows about.
|
||||||
|
|
||||||
@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
txn.execute(query)
|
txn.execute(query)
|
||||||
return list(txn)[0][0]
|
return list(txn)[0][0]
|
||||||
|
|
||||||
count = yield self.db_pool.runInteraction("get_known_servers", _transact)
|
count = await self.db_pool.runInteraction("get_known_servers", _transact)
|
||||||
|
|
||||||
# We always know about ourselves, even if we have nothing in
|
# We always know about ourselves, even if we have nothing in
|
||||||
# room_memberships (for example, the server is new).
|
# room_memberships (for example, the server is new).
|
||||||
@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@cachedList(
|
@cachedList(
|
||||||
cached_method_name="_get_joined_profile_from_event_id",
|
cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
|
||||||
list_name="event_ids",
|
|
||||||
inlineCallbacks=True,
|
|
||||||
)
|
)
|
||||||
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
||||||
"""For given set of member event_ids check if they point to a join
|
"""For given set of member event_ids check if they point to a join
|
||||||
event and if so return the associated user and profile info.
|
event and if so return the associated user and profile info.
|
||||||
|
|
||||||
@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
|||||||
event_ids: The member event IDs to lookup
|
event_ids: The member event IDs to lookup
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
|
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
|
||||||
to `user_id` and ProfileInfo (or None if not join event).
|
to `user_id` and ProfileInfo (or None if not join event).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
rows = yield self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="room_memberships",
|
table="room_memberships",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=event_ids,
|
iterable=event_ids,
|
||||||
|
@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||||||
cached_method_name="_get_state_group_for_event",
|
cached_method_name="_get_state_group_for_event",
|
||||||
list_name="event_ids",
|
list_name="event_ids",
|
||||||
num_args=1,
|
num_args=1,
|
||||||
inlineCallbacks=True,
|
|
||||||
)
|
)
|
||||||
def _get_state_group_for_events(self, event_ids):
|
async def _get_state_group_for_events(self, event_ids):
|
||||||
"""Returns mapping event_id -> state_group
|
"""Returns mapping event_id -> state_group
|
||||||
"""
|
"""
|
||||||
rows = yield self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="event_to_state_groups",
|
table="event_to_state_groups",
|
||||||
column="event_id",
|
column="event_id",
|
||||||
iterable=event_ids,
|
iterable=event_ids,
|
||||||
|
@ -38,10 +38,8 @@ class UserErasureWorkerStore(SQLBaseStore):
|
|||||||
desc="is_user_erased",
|
desc="is_user_erased",
|
||||||
).addCallback(operator.truth)
|
).addCallback(operator.truth)
|
||||||
|
|
||||||
@cachedList(
|
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
|
||||||
cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
|
async def are_users_erased(self, user_ids):
|
||||||
)
|
|
||||||
def are_users_erased(self, user_ids):
|
|
||||||
"""
|
"""
|
||||||
Checks which users in a list have requested erasure
|
Checks which users in a list have requested erasure
|
||||||
|
|
||||||
@ -49,14 +47,14 @@ class UserErasureWorkerStore(SQLBaseStore):
|
|||||||
user_ids (iterable[str]): full user id to check
|
user_ids (iterable[str]): full user id to check
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, bool]]:
|
dict[str, bool]:
|
||||||
for each user, whether the user has requested erasure.
|
for each user, whether the user has requested erasure.
|
||||||
"""
|
"""
|
||||||
# this serves the dual purpose of (a) making sure we can do len and
|
# this serves the dual purpose of (a) making sure we can do len and
|
||||||
# iterate it multiple times, and (b) avoiding duplicates.
|
# iterate it multiple times, and (b) avoiding duplicates.
|
||||||
user_ids = tuple(set(user_ids))
|
user_ids = tuple(set(user_ids))
|
||||||
|
|
||||||
rows = yield self.db_pool.simple_select_many_batch(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
table="erased_users",
|
table="erased_users",
|
||||||
column="user_id",
|
column="user_id",
|
||||||
iterable=user_ids,
|
iterable=user_ids,
|
||||||
@ -65,8 +63,7 @@ class UserErasureWorkerStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
erased_users = {row["user_id"] for row in rows}
|
erased_users = {row["user_id"] for row in rows}
|
||||||
|
|
||||||
res = {u: u in erased_users for u in user_ids}
|
return {u: u in erased_users for u in user_ids}
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
class UserErasureStore(UserErasureWorkerStore):
|
class UserErasureStore(UserErasureWorkerStore):
|
||||||
|
Loading…
Reference in New Issue
Block a user