Clean-up some receipts code (#12888)

* Properly marks private methods as private.
* Adds missing docstrings.
* Rework inline methods.
This commit is contained in:
Patrick Cloke 2022-05-27 07:44:10 -04:00 committed by GitHub
parent c52abc1cfd
commit 724e11d620
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 42 deletions

1
changelog.d/12888.misc Normal file
View File

@ -0,0 +1 @@
Refactor receipt linearization code.

View File

@ -597,7 +597,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows) return super().process_replication_rows(stream_name, instance_name, token, rows)
def insert_linearized_receipt_txn( def _insert_linearized_receipt_txn(
self, self,
txn: LoggingTransaction, txn: LoggingTransaction,
room_id: str, room_id: str,
@ -686,6 +686,44 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rx_ts return rx_ts
def _graph_to_linear(
self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
) -> str:
"""
Generate a linearized event from a list of events (i.e. a list of forward
extremities in the room).
This should allow for calculation of the correct read receipt even if
servers have different event ordering.
Args:
txn: The transaction
room_id: The room ID the events are in.
event_ids: The list of event IDs to linearize.
Returns:
The linearized event ID.
"""
# TODO: Make this better.
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)
sql = """
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
SELECT max(stream_ordering) WHERE %s
)
""" % (
clause,
)
txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
async def insert_receipt( async def insert_receipt(
self, self,
room_id: str, room_id: str,
@ -712,35 +750,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
linearized_event_id = event_ids[0] linearized_event_id = event_ids[0]
else: else:
# we need to points in graph -> linearized form. # we need to points in graph -> linearized form.
# TODO: Make this better.
def graph_to_linear(txn: LoggingTransaction) -> str:
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)
sql = """
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
SELECT max(stream_ordering) WHERE %s
)
""" % (
clause,
)
txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
linearized_event_id = await self.db_pool.runInteraction( linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear "insert_receipt_conv", self._graph_to_linear, room_id, event_ids
) )
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction( event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt", "insert_linearized_receipt",
self.insert_linearized_receipt_txn, self._insert_linearized_receipt_txn,
room_id, room_id,
receipt_type, receipt_type,
user_id, user_id,
@ -761,25 +778,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
now - event_ts, now - event_ts,
) )
await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token()
return stream_id, max_persisted_id
async def insert_graph_receipt(
self,
room_id: str,
receipt_type: str,
user_id: str,
event_ids: List[str],
data: JsonDict,
) -> None:
assert self._can_write_to_receipts
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"insert_graph_receipt", "insert_graph_receipt",
self.insert_graph_receipt_txn, self._insert_graph_receipt_txn,
room_id, room_id,
receipt_type, receipt_type,
user_id, user_id,
@ -787,7 +788,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
data, data,
) )
def insert_graph_receipt_txn( max_persisted_id = self._receipts_id_gen.get_current_token()
return stream_id, max_persisted_id
def _insert_graph_receipt_txn(
self, self,
txn: LoggingTransaction, txn: LoggingTransaction,
room_id: str, room_id: str,