From 724e11d62057a77aa9b43fdd803b6fcd1cbc183b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 27 May 2022 07:44:10 -0400 Subject: [PATCH] Clean-up some receipts code (#12888) * Properly marks private methods as private. * Adds missing docstrings. * Rework inline methods. --- changelog.d/12888.misc | 1 + synapse/storage/databases/main/receipts.py | 89 ++++++++++++---------- 2 files changed, 48 insertions(+), 42 deletions(-) create mode 100644 changelog.d/12888.misc diff --git a/changelog.d/12888.misc b/changelog.d/12888.misc new file mode 100644 index 000000000..8ed2ea65b --- /dev/null +++ b/changelog.d/12888.misc @@ -0,0 +1 @@ +Refactor receipt linearization code. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index f74aa1e3f..21e954ccc 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -597,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return super().process_replication_rows(stream_name, instance_name, token, rows) - def insert_linearized_receipt_txn( + def _insert_linearized_receipt_txn( self, txn: LoggingTransaction, room_id: str, @@ -686,6 +686,44 @@ class ReceiptsWorkerStore(SQLBaseStore): return rx_ts + def _graph_to_linear( + self, txn: LoggingTransaction, room_id: str, event_ids: List[str] + ) -> str: + """ + Generate a linearized event from a list of events (i.e. a list of forward + extremities in the room). + + This should allow for calculation of the correct read receipt even if + servers have different event ordering. + + Args: + txn: The transaction + room_id: The room ID the events are in. + event_ids: The list of event IDs to linearize. + + Returns: + The linearized event ID. + """ + # TODO: Make this better. + clause, args = make_in_list_sql_clause( + self.database_engine, "event_id", event_ids + ) + + sql = """ + SELECT event_id WHERE room_id = ? AND stream_ordering IN ( + SELECT max(stream_ordering) WHERE %s + ) + """ % ( + clause, + ) + + txn.execute(sql, [room_id] + list(args)) + rows = txn.fetchall() + if rows: + return rows[0][0] + else: + raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) + async def insert_receipt( self, room_id: str, @@ -712,35 +750,14 @@ class ReceiptsWorkerStore(SQLBaseStore): linearized_event_id = event_ids[0] else: # we need to points in graph -> linearized form. - # TODO: Make this better. - def graph_to_linear(txn: LoggingTransaction) -> str: - clause, args = make_in_list_sql_clause( - self.database_engine, "event_id", event_ids - ) - - sql = """ - SELECT event_id WHERE room_id = ? AND stream_ordering IN ( - SELECT max(stream_ordering) WHERE %s - ) - """ % ( - clause, - ) - - txn.execute(sql, [room_id] + list(args)) - rows = txn.fetchall() - if rows: - return rows[0][0] - else: - raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - linearized_event_id = await self.db_pool.runInteraction( - "insert_receipt_conv", graph_to_linear + "insert_receipt_conv", self._graph_to_linear, room_id, event_ids ) async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_ts = await self.db_pool.runInteraction( "insert_linearized_receipt", - self.insert_linearized_receipt_txn, + self._insert_linearized_receipt_txn, room_id, receipt_type, user_id, @@ -761,25 +778,9 @@ class ReceiptsWorkerStore(SQLBaseStore): now - event_ts, ) - await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) - - max_persisted_id = self._receipts_id_gen.get_current_token() - - return stream_id, max_persisted_id - - async def insert_graph_receipt( - self, - room_id: str, - receipt_type: str, - user_id: str, - event_ids: List[str], - data: JsonDict, - ) -> None: - assert self._can_write_to_receipts - await self.db_pool.runInteraction( "insert_graph_receipt", - self.insert_graph_receipt_txn, + self._insert_graph_receipt_txn, room_id, receipt_type, user_id, @@ -787,7 +788,11 @@ class ReceiptsWorkerStore(SQLBaseStore): data, ) - def insert_graph_receipt_txn( + max_persisted_id = self._receipts_id_gen.get_current_token() + + return stream_id, max_persisted_id + + def _insert_graph_receipt_txn( self, txn: LoggingTransaction, room_id: str,