Improve tests for get_unread_push_actions_for_user_in_range_*. (#13893)

* Adds a docstring.
* Reduces a small amount of duplicated code.
* Improves tests.
This commit is contained in:
Patrick Cloke 2022-09-26 14:28:12 -04:00 committed by GitHub
parent 58ab96747c
commit 2fae1a3f78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 97 additions and 30 deletions

View file

@ -559,7 +559,18 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def _get_receipts_by_room_txn(
self, txn: LoggingTransaction, user_id: str
) -> List[Tuple[str, int]]:
) -> Dict[str, int]:
"""
Generate a map of room ID to the latest stream ordering that has been
read by the given user.
Args:
txn:
user_id: The user to fetch receipts for.
Returns:
A map of room ID to stream ordering for all rooms the user has a receipt in.
"""
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
@ -580,7 +591,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
args.extend((user_id,))
txn.execute(sql, args)
return cast(List[Tuple[str, int]], txn.fetchall())
return {
room_id: latest_stream_ordering
for room_id, latest_stream_ordering in txn.fetchall()
}
async def get_unread_push_actions_for_user_in_range_for_http(
self,
@ -605,12 +619,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will have between 0~limit entries.
"""
receipts_by_room = dict(
await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
),
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
)
def get_push_actions_txn(
@ -679,12 +691,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will have between 0~limit entries.
"""
receipts_by_room = dict(
await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
),
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
)
def get_push_actions_txn(