Convert additional databases to async/await (#8199)

Patrick Cloke 2020-09-01 09:21:48 -04:00 committed by GitHub
parent 5bf8e5f55b
commit 54f8d73c00
7 changed files with 147 additions and 137 deletions

changelog.d/8199.misc (new file)

@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
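Every hunk in this commit follows the same shape: a method that returned the `Deferred` produced by `db_pool.runInteraction` becomes an `async def` with explicit type annotations, the call gains an `await`, and the docstring's `Deferred[...]` return type collapses into a plain description. A minimal runnable sketch of that pattern, assuming a toy `DBPool` stand-in (Synapse's real pool runs the function inside a database transaction on a worker thread):

```python
import asyncio
from typing import Any, Callable


class DBPool:
    """Stand-in for Synapse's database pool (an assumption for this sketch)."""

    async def runInteraction(self, desc: str, func: Callable[..., Any], *args: Any) -> Any:
        # Synapse runs `func` in a thread inside a DB transaction; calling it
        # directly keeps the sketch runnable without a database.
        return func(*args)


class Store:
    def __init__(self) -> None:
        self.db_pool = DBPool()

    # Before: a plain method returning the Deferred from runInteraction:
    #     def count_daily_users(self):
    #         return self.db_pool.runInteraction("count_daily_users", self._count_users)

    # After: an async method that awaits the interaction and is fully typed.
    async def count_daily_users(self) -> int:
        return await self.db_pool.runInteraction("count_daily_users", self._count_users)

    def _count_users(self) -> int:
        return 42  # placeholder for the real SQL count


print(asyncio.run(Store().count_daily_users()))  # -> 42
```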

View File

@@ -18,7 +18,7 @@
 import calendar
 import logging
 import time
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple

 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
@@ -294,16 +294,16 @@ class DataStore(
         return [UserPresenceState(**row) for row in rows]

-    def count_daily_users(self):
+    async def count_daily_users(self) -> int:
         """
         Counts the number of users who used this homeserver in the last 24 hours.
         """
         yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_daily_users", self._count_users, yesterday
         )

-    def count_monthly_users(self):
+    async def count_monthly_users(self) -> int:
         """
         Counts the number of users who used this homeserver in the last 30 days.
         Note this method is intended for phonehome metrics only and is different
@@ -311,7 +311,7 @@ class DataStore(
         amongst other things, includes a 3 day grace period before a user counts.
         """
         thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "count_monthly_users", self._count_users, thirty_days_ago
         )
@@ -330,15 +330,15 @@ class DataStore(
             (count,) = txn.fetchone()
             return count

-    def count_r30_users(self):
+    async def count_r30_users(self) -> Dict[str, int]:
         """
         Counts the number of 30 day retained users, defined as:-
          * Users who have created their accounts more than 30 days ago
          * Where last seen at most 30 days ago
          * Where account creation and last_seen are > 30 days apart
-        Returns counts globaly for a given user as well as breaking
-        by platform
+        Returns:
+            A mapping of counts globally as well as broken out by platform.
         """

         def _count_r30_users(txn):
@@ -411,7 +411,7 @@ class DataStore(
             return results

-        return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
+        return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)

     def _get_start_of_day(self):
         """
@@ -421,7 +421,7 @@ class DataStore(
         today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
         return today_start * 1000

-    def generate_user_daily_visits(self):
+    async def generate_user_daily_visits(self) -> None:
         """
         Generates daily visit data for use in cohort/ retention analysis
         """
@@ -476,7 +476,7 @@ class DataStore(
             # frequently
             self._last_user_visit_update = now

-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "generate_user_daily_visits", _generate_user_daily_visits
         )
@@ -500,22 +500,28 @@ class DataStore(
             desc="get_users",
         )

-    def get_users_paginate(
-        self, start, limit, user_id=None, name=None, guests=True, deactivated=False
-    ):
+    async def get_users_paginate(
+        self,
+        start: int,
+        limit: int,
+        user_id: Optional[str] = None,
+        name: Optional[str] = None,
+        guests: bool = True,
+        deactivated: bool = False,
+    ) -> Tuple[List[Dict[str, Any]], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
         total number of users matching the filter criteria.

         Args:
-            start (int): start number to begin the query from
-            limit (int): number of rows to retrieve
-            user_id (string): search for user_id. ignored if name is not None
-            name (string): search for local part of user_id or display name
-            guests (bool): whether to in include guest users
-            deactivated (bool): whether to include deactivated users
+            start: start number to begin the query from
+            limit: number of rows to retrieve
+            user_id: search for user_id. ignored if name is not None
+            name: search for local part of user_id or display name
+            guests: whether to include guest users
+            deactivated: whether to include deactivated users
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]], int
+            A tuple of a list of mappings from user to information and a count
+            of total users.
         """

         def get_users_paginate_txn(txn):
@@ -558,7 +564,7 @@ class DataStore(
             users = self.db_pool.cursor_to_dict(txn)
             return users, count

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_paginate_txn", get_users_paginate_txn
         )
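As an illustration of what the new signature buys callers, here is a hypothetical consumer of the converted `get_users_paginate`; the `store` wiring is an assumption, but the `(users, total)` tuple unpacking matches the annotated return type above:

```python
# Hypothetical admin-API-style caller; `store` is assumed to be a
# DataStore instance. The method now returns (users, total) directly
# rather than a Deferred resolving to them.
async def print_user_page(store) -> None:
    users, total = await store.get_users_paginate(start=0, limit=50, guests=False)
    print(f"showing {len(users)} of {total} users")
    for user in users:
        print(user["name"])
```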

View File

@@ -313,9 +313,9 @@ class DeviceWorkerStore(SQLBaseStore):

         return results

-    def _get_last_device_update_for_remote_user(
+    async def _get_last_device_update_for_remote_user(
         self, destination: str, user_id: str, from_stream_id: int
-    ):
+    ) -> int:
         def f(txn):
             prev_sent_id_sql = """
                 SELECT coalesce(max(stream_id), 0) as stream_id
@@ -326,12 +326,16 @@ class DeviceWorkerStore(SQLBaseStore):
             rows = txn.fetchall()
             return rows[0][0]

-        return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
+        return await self.db_pool.runInteraction(
+            "get_last_device_update_for_remote_user", f
+        )

-    def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
+    async def mark_as_sent_devices_by_remote(
+        self, destination: str, stream_id: int
+    ) -> None:
         """Mark that updates have successfully been sent to the destination.
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "mark_as_sent_devices_by_remote",
             self._mark_as_sent_devices_by_remote_txn,
             destination,
@@ -684,7 +688,7 @@ class DeviceWorkerStore(SQLBaseStore):
             desc="make_remote_user_device_cache_as_stale",
         )

-    def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
+    async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
         """Mark that we no longer track device lists for remote user.
         """
@@ -698,7 +702,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
             )

-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "mark_remote_user_device_list_as_unsubscribed",
             _mark_remote_user_device_list_as_unsubscribed_txn,
         )
@@ -959,9 +963,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             desc="update_device",
         )

-    def update_remote_device_list_cache_entry(
+    async def update_remote_device_list_cache_entry(
         self, user_id: str, device_id: str, content: JsonDict, stream_id: int
-    ):
+    ) -> None:
         """Updates a single device in the cache of a remote user's device list.

         Note: assumes that we are the only thread that can be updating this user's
@@ -972,11 +976,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             device_id: ID of device being updated
             content: new data on this device
             stream_id: the version of the device list
-
-        Returns:
-            Deferred[None]
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_remote_device_list_cache_entry",
             self._update_remote_device_list_cache_entry_txn,
             user_id,
@@ -1028,9 +1029,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             lock=False,
         )

-    def update_remote_device_list_cache(
+    async def update_remote_device_list_cache(
         self, user_id: str, devices: List[dict], stream_id: int
-    ):
+    ) -> None:
         """Replace the entire cache of the remote user's devices.

         Note: assumes that we are the only thread that can be updating this user's
@@ -1040,11 +1041,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: User to update device list for
             devices: list of device objects supplied over federation
             stream_id: the version of the device list
-
-        Returns:
-            Deferred[None]
         """
-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "update_remote_device_list_cache",
             self._update_remote_device_list_cache_txn,
             user_id,
@@ -1054,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):

     def _update_remote_device_list_cache_txn(
         self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
-    ):
+    ) -> None:
         self.db_pool.simple_delete_txn(
             txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
         )
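Note that the `Deferred[None]` methods in this file drop their `return` statements entirely: the interaction is awaited inside the method and nothing is propagated to the caller. A sketch of the caller-side difference (names mirror the diff; the surrounding wiring is assumed):

```python
# Sketch only: `store` is assumed to expose the converted method.
async def forget_remote_user(store, user_id: str) -> None:
    # Before the conversion, callers handled a Deferred, e.g.
    #     yield store.mark_remote_user_device_list_as_unsubscribed(user_id)
    # Afterwards it is an ordinary awaitable that returns None.
    await store.mark_remote_user_device_list_as_unsubscribed(user_id)
```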

View File

@@ -823,20 +823,24 @@ class EventsWorkerStore(SQLBaseStore):

         return event_dict

-    def _maybe_redact_event_row(self, original_ev, redactions, event_map):
+    def _maybe_redact_event_row(
+        self,
+        original_ev: EventBase,
+        redactions: Iterable[str],
+        event_map: Dict[str, EventBase],
+    ) -> Optional[EventBase]:
         """Given an event object and a list of possible redacting event ids,
         determine whether to honour any of those redactions and if so return a redacted
         event.

         Args:
-            original_ev (EventBase):
-            redactions (iterable[str]): list of event ids of potential redaction events
-            event_map (dict[str, EventBase]): other events which have been fetched, in
-                which we can look up the redaaction events. Map from event id to event.
+            original_ev: The original event.
+            redactions: list of event ids of potential redaction events
+            event_map: other events which have been fetched, in which we can
+                look up the redaction events. Map from event id to event.

         Returns:
-            Deferred[EventBase|None]: if the event should be redacted, a pruned
-                event object. Otherwise, None.
+            If the event should be redacted, a pruned event object. Otherwise, None.
         """
         if original_ev.type == "m.room.create":
             # we choose to ignore redactions of m.room.create events.
@@ -946,17 +950,17 @@ class EventsWorkerStore(SQLBaseStore):
             row = txn.fetchone()
             return row[0] if row else 0

-    def get_current_state_event_counts(self, room_id):
+    async def get_current_state_event_counts(self, room_id: str) -> int:
         """
         Gets the current number of state events in a room.

         Args:
-            room_id (str)
+            room_id: The room ID to query.

         Returns:
-            Deferred[int]
+            The current number of state events.
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_current_state_event_counts",
             self._get_current_state_event_counts_txn,
             room_id,
@@ -991,7 +995,9 @@ class EventsWorkerStore(SQLBaseStore):
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()

-    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+    async def get_all_new_forward_event_rows(
+        self, last_id: int, current_id: int, limit: int
+    ) -> List[Tuple]:
         """Returns new events, for the Events replication stream

         Args:
@@ -999,7 +1005,7 @@ class EventsWorkerStore(SQLBaseStore):
             current_id: the maximum stream_id to return up to
             limit: the maximum number of rows to return

-        Returns: Deferred[List[Tuple]]
+        Returns:
             a list of events stream rows. Each tuple consists of a stream id as
             the first element, followed by fields suitable for casting into an
             EventsStreamRow.
@@ -1020,18 +1026,20 @@ class EventsWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id, limit))
             return txn.fetchall()

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_all_new_forward_event_rows", get_all_new_forward_event_rows
         )
-    def get_ex_outlier_stream_rows(self, last_id, current_id):
+    async def get_ex_outlier_stream_rows(
+        self, last_id: int, current_id: int
+    ) -> List[Tuple]:
         """Returns de-outliered events, for the Events replication stream

         Args:
             last_id: the last stream_id from the previous batch.
             current_id: the maximum stream_id to return up to

-        Returns: Deferred[List[Tuple]]
+        Returns:
             a list of events stream rows. Each tuple consists of a stream id as
             the first element, followed by fields suitable for casting into an
             EventsStreamRow.
@@ -1054,7 +1062,7 @@ class EventsWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id))
             return txn.fetchall()

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
         )
@@ -1226,11 +1234,11 @@ class EventsWorkerStore(SQLBaseStore):

         return (int(res["topological_ordering"]), int(res["stream_ordering"]))

-    def get_next_event_to_expire(self):
+    async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
         """Retrieve the entry with the lowest expiry timestamp in the event_expiry
         table, or None if there's no more event to expire.

-        Returns: Deferred[Optional[Tuple[str, int]]]
+        Returns:
             A tuple containing the event ID as its first element and an expiry timestamp
             as its second one, if there's at least one row in the event_expiry table.
             None otherwise.
@@ -1246,6 +1254,6 @@ class EventsWorkerStore(SQLBaseStore):
             return txn.fetchone()

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
         )
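Throughout these hunks only the outer method becomes async; the inner transaction function handed to `runInteraction` stays synchronous, since it runs blocking SQL inside a transaction. A self-contained sketch of that split, using `sqlite3` plus an executor as stand-ins for Synapse's connection pool (both stand-ins are assumptions, not the real API):

```python
import asyncio
import sqlite3

# check_same_thread=False because in this sketch an executor thread,
# not the main thread, runs the transaction function.
conn = sqlite3.connect(":memory:", check_same_thread=False)
conn.execute("CREATE TABLE events (room_id TEXT)")
conn.execute("INSERT INTO events VALUES ('!room:example.org')")


async def get_event_count(room_id: str) -> int:
    # The transaction function stays synchronous, exactly as in the diff.
    def _txn(cur: sqlite3.Cursor) -> int:
        cur.execute("SELECT count(*) FROM events WHERE room_id = ?", (room_id,))
        (count,) = cur.fetchone()
        return count

    # Stand-in for `await self.db_pool.runInteraction("...", _txn)`:
    # push the blocking work off the event loop onto a thread.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _txn, conn.cursor())


print(asyncio.run(get_event_count("!room:example.org")))  # -> 1
```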

View File

@@ -14,7 +14,7 @@
 # limitations under the License.

 import logging
-from typing import Any, Tuple
+from typing import Any, List, Set, Tuple

 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore
@@ -25,25 +25,24 @@ logger = logging.getLogger(__name__)

 class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
-    def purge_history(self, room_id, token, delete_local_events):
+    async def purge_history(
+        self, room_id: str, token: str, delete_local_events: bool
+    ) -> Set[int]:
         """Deletes room history before a certain point

         Args:
-            room_id (str):
-
-            token (str): A topological token to delete events before
-
-            delete_local_events (bool):
+            room_id:
+            token: A topological token to delete events before
+            delete_local_events:
                 if True, we will delete local events as well as remote ones
                 (instead of just marking them as outliers and deleting their
                 state groups).

         Returns:
-            Deferred[set[int]]: The set of state groups that are referenced by
-                deleted events.
+            The set of state groups that are referenced by deleted events.
         """

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "purge_history",
             self._purge_history_txn,
             room_id,
@@ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):

         return referenced_state_groups

-    def purge_room(self, room_id):
+    async def purge_room(self, room_id: str) -> List[int]:
         """Deletes all record of a room

         Args:
-            room_id (str)
+            room_id
         Returns:
-            Deferred[List[int]]: The list of state groups to delete.
+            The list of state groups to delete.
         """
-
-        return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
+        return await self.db_pool.runInteraction(
+            "purge_room", self._purge_room_txn, room_id
+        )

     def _purge_room_txn(self, txn, room_id):
         # First we fetch all the state groups that should be deleted, before
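A hedged sketch of how the converted purge methods might now be driven, for example from an admin request handler; the `store` object and token value are assumptions for illustration:

```python
# Illustrative only: purge_history now returns the referenced state
# groups directly instead of a Deferred resolving to them.
async def purge_before_token(store, room_id: str, token: str) -> None:
    state_groups = await store.purge_history(room_id, token, delete_local_events=False)
    print(f"{len(state_groups)} state groups referenced by deleted events")
```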

View File

@@ -276,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
             }

         return results

-    def get_users_sent_receipts_between(self, last_id: int, current_id: int):
+    async def get_users_sent_receipts_between(
+        self, last_id: int, current_id: int
+    ) -> List[str]:
         """Get all users who sent receipts between `last_id` exclusive and
         `current_id` inclusive.

         Returns:
-            Deferred[List[str]]
+            The list of users.
         """
         if last_id == current_id:
@@ -296,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             return [r[0] for r in txn]

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
         )
@@ -553,8 +555,10 @@ class ReceiptsStore(ReceiptsWorkerStore):

         return stream_id, max_persisted_id

-    def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
-        return self.db_pool.runInteraction(
+    async def insert_graph_receipt(
+        self, room_id, receipt_type, user_id, event_ids, data
+    ):
+        return await self.db_pool.runInteraction(
             "insert_graph_receipt",
             self.insert_graph_receipt_txn,
             room_id,
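Since `get_users_sent_receipts_between` treats `last_id` as exclusive and `current_id` as inclusive, a poller can resume safely from the last id it has already processed. A sketch of such a consumer (the `store` object and the id bookkeeping are assumptions):

```python
# Replication-style poller sketch; not Synapse's actual stream code.
async def poll_receipt_senders(store, last_id: int, current_id: int) -> int:
    if last_id < current_id:
        users = await store.get_users_sent_receipts_between(last_id, current_id)
        for user_id in users:
            print("receipt from", user_id)
    return current_id  # becomes `last_id` for the next poll
```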

View File

@@ -34,38 +34,33 @@ logger = logging.getLogger(__name__)

 class RelationsWorkerStore(SQLBaseStore):
     @cached(tree=True)
-    def get_relations_for_event(
+    async def get_relations_for_event(
         self,
-        event_id,
-        relation_type=None,
-        event_type=None,
-        aggregation_key=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
+        event_id: str,
+        relation_type: Optional[str] = None,
+        event_type: Optional[str] = None,
+        aggregation_key: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[RelationPaginationToken] = None,
+        to_token: Optional[RelationPaginationToken] = None,
+    ) -> PaginationChunk:
         """Get a list of relations for an event, ordered by topological ordering.

         Args:
-            event_id (str): Fetch events that relate to this event ID.
-            relation_type (str|None): Only fetch events with this relation
-                type, if given.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            aggregation_key (str|None): Only fetch events with this aggregation
-                key, if given.
-            limit (int): Only fetch the most recent `limit` events.
-            direction (str): Whether to fetch the most recent first (`"b"`) or
-                the oldest first (`"f"`).
-            from_token (RelationPaginationToken|None): Fetch rows from the given
-                token, or from the start if None.
-            to_token (RelationPaginationToken|None): Fetch rows up to the given
-                token, or up to the end if None.
+            event_id: Fetch events that relate to this event ID.
+            relation_type: Only fetch events with this relation type, if given.
+            event_type: Only fetch events with this event type, if given.
+            aggregation_key: Only fetch events with this aggregation key, if given.
+            limit: Only fetch the most recent `limit` events.
+            direction: Whether to fetch the most recent first (`"b"`) or the
+                oldest first (`"f"`).
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.

         Returns:
-            Deferred[PaginationChunk]: List of event IDs that match relations
-                requested. The rows are of the form `{"event_id": "..."}`.
+            List of event IDs that match relations requested. The rows are of
+            the form `{"event_id": "..."}`.
         """

         where_clause = ["relates_to_id = ?"]
@@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore):
             chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
         )

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_recent_references_for_event", _get_recent_references_for_event_txn
         )

     @cached(tree=True)
-    def get_aggregation_groups_for_event(
+    async def get_aggregation_groups_for_event(
         self,
-        event_id,
-        event_type=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
+        event_id: str,
+        event_type: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[AggregationPaginationToken] = None,
+        to_token: Optional[AggregationPaginationToken] = None,
+    ) -> PaginationChunk:
         """Get a list of annotations on the event, grouped by event type and
         aggregation key, sorted by count.
@@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore):
         on an event.

         Args:
-            event_id (str): Fetch events that relate to this event ID.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            limit (int): Only fetch the `limit` groups.
-            direction (str): Whether to fetch the highest count first (`"b"`) or
+            event_id: Fetch events that relate to this event ID.
+            event_type: Only fetch events with this event type, if given.
+            limit: Only fetch the `limit` groups.
+            direction: Whether to fetch the highest count first (`"b"`) or
                 the lowest count first (`"f"`).
-            from_token (AggregationPaginationToken|None): Fetch rows from the
-                given token, or from the start if None.
-            to_token (AggregationPaginationToken|None): Fetch rows up to the
-                given token, or up to the end if None.
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.

         Returns:
-            Deferred[PaginationChunk]: List of groups of annotations that
-                match. Each row is a dict with `type`, `key` and `count` fields.
+            List of groups of annotations that match. Each row is a dict with
+            `type`, `key` and `count` fields.
         """

         where_clause = ["relates_to_id = ?", "relation_type = ?"]
@@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
             chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
         )

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
         )
@@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore):

         return await self.get_event(edit_id, allow_none=True)

-    def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+    async def has_user_annotated_event(
+        self, parent_id: str, event_type: str, aggregation_key: str, sender: str
+    ) -> bool:
         """Check if a user has already annotated an event with the same key
         (e.g. already liked an event).

         Args:
-            parent_id (str): The event being annotated
-            event_type (str): The event type of the annotation
-            aggregation_key (str): The aggregation key of the annotation
-            sender (str): The sender of the annotation
+            parent_id: The event being annotated
+            event_type: The event type of the annotation
+            aggregation_key: The aggregation key of the annotation
+            sender: The sender of the annotation

         Returns:
-            Deferred[bool]
+            True if the event is already annotated.
         """

         sql = """
@@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore):

             return bool(txn.fetchone())

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
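`@cached(tree=True)` keeps decorating these methods after the conversion. As a much-simplified illustration of how a cache decorator composes with an async method, here is an assumption-laden sketch (Synapse's real decorator also handles invalidation, cache trees, and sharing of in-flight results, none of which is shown):

```python
import asyncio
import functools


def cached(func):
    """Toy cache decorator: stores the awaited result, not the coroutine."""
    cache = {}

    @functools.wraps(func)
    async def wrapper(self, *args):
        if args not in cache:
            cache[args] = await func(self, *args)
        return cache[args]

    return wrapper


class Store:
    @cached
    async def get_relations_for_event(self, event_id: str) -> list:
        print("cache miss for", event_id)
        return [{"event_id": event_id}]


async def main() -> None:
    store = Store()
    await store.get_relations_for_event("$abc")  # miss: prints once
    await store.get_relations_for_event("$abc")  # hit: served from cache


asyncio.run(main())
```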