Accept & store thread IDs for receipts (implement MSC3771). (#13782)

Updates the `/receipts` endpoint and receipt EDU handler to parse a
`thread_id` from the body and insert it into the database.
Patrick Cloke 2022-09-23 10:33:28 -04:00 committed by GitHub
parent 03c2bfb7f8
commit efd108b45d
17 changed files with 173 additions and 41 deletions
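
For context on the wire format (an editorial aside, not part of this diff): MSC3771 has the client supply the thread ID in the body of the receipt request. A threaded read receipt, with placeholder identifiers, looks roughly like:

    # POST /_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_id}
    # JSON body, per MSC3771:
    body = {"thread_id": "$thread_root_event_id"}  # or "main" for the unthreaded timeline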


@@ -540,7 +540,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
     async def get_all_updated_receipts(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+    ) -> Tuple[
+        List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], int, bool
+    ]:
         """Get updates for receipts replication stream.

         Args:
@@ -567,9 +569,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
         def get_all_updated_receipts_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[Tuple[int, list]], int, bool]:
+        ) -> Tuple[
+            List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
+            int,
+            bool,
+        ]:
             sql = """
-                SELECT stream_id, room_id, receipt_type, user_id, event_id, data
+                SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
                 FROM receipts_linearized
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
@@ -578,8 +584,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id, limit))
             updates = cast(
-                List[Tuple[int, list]],
-                [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
+                List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
+                [(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn],
             )

             limited = False
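
An editorial aside on the reshaping above: each row of the widened SELECT now has seven columns, so the slice grows from r[1:5] to r[1:6] and the JSON-encoded data moves to index 6. A minimal, self-contained sketch (using json.loads where the diff uses Synapse's db_to_json helper):

    import json

    # One raw row from the new SELECT; index 5 is thread_id, index 6 is data.
    row = (5, "!room:x", "m.read", "@user:x", "$event:x", "$thread:x", '{"ts": 1}')
    update = (row[0], row[1:6] + (json.loads(row[6]),))
    assert update[1][4] == "$thread:x"  # the Optional[str] thread_id now rides along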
@@ -631,6 +637,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_type: str,
         user_id: str,
         event_id: str,
+        thread_id: Optional[str],
         data: JsonDict,
         stream_id: int,
     ) -> Optional[int]:
@@ -657,12 +664,27 @@ class ReceiptsWorkerStore(SQLBaseStore):
         # We don't want to clobber receipts for more recent events, so we
         # have to compare orderings of existing receipts
         if stream_ordering is not None:
-            sql = (
-                "SELECT stream_ordering, event_id FROM events"
-                " INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
-                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
-            )
-            txn.execute(sql, (room_id, receipt_type, user_id))
+            if thread_id is None:
+                thread_clause = "r.thread_id IS NULL"
+                thread_args: Tuple[str, ...] = ()
+            else:
+                thread_clause = "r.thread_id = ?"
+                thread_args = (thread_id,)
+
+            sql = f"""
+            SELECT stream_ordering, event_id FROM events
+            INNER JOIN receipts_linearized AS r USING (event_id, room_id)
+            WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND {thread_clause}
+            """
+            txn.execute(
+                sql,
+                (
+                    room_id,
+                    receipt_type,
+                    user_id,
+                )
+                + thread_args,
+            )

             for so, eid in txn:
                 if int(so) >= stream_ordering:
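
The branch on thread_id exists because SQL equality never matches NULL: a query parameterised with "= ?" cannot find rows whose thread_id column is NULL, so that case needs an explicit IS NULL clause. A quick standalone demonstration (not from the patch):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE r (user_id TEXT, thread_id TEXT)")
    conn.execute("INSERT INTO r VALUES ('@u:x', NULL)")
    # "thread_id = NULL" matches nothing; "thread_id IS NULL" matches the row.
    assert conn.execute("SELECT count(*) FROM r WHERE thread_id = NULL").fetchone()[0] == 0
    assert conn.execute("SELECT count(*) FROM r WHERE thread_id IS NULL").fetchone()[0] == 1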
@@ -682,21 +704,28 @@ class ReceiptsWorkerStore(SQLBaseStore):
             self._receipts_stream_cache.entity_has_changed, room_id, stream_id
         )

+        keyvalues = {
+            "room_id": room_id,
+            "receipt_type": receipt_type,
+            "user_id": user_id,
+        }
+        where_clause = ""
+        if thread_id is None:
+            where_clause = "thread_id IS NULL"
+        else:
+            keyvalues["thread_id"] = thread_id
         self.db_pool.simple_upsert_txn(
             txn,
             table="receipts_linearized",
-            keyvalues={
-                "room_id": room_id,
-                "receipt_type": receipt_type,
-                "user_id": user_id,
-            },
+            keyvalues=keyvalues,
             values={
                 "stream_id": stream_id,
                 "event_id": event_id,
                 "event_stream_ordering": stream_ordering,
                 "data": json_encoder.encode(data),
-                "thread_id": None,
+                "thread_id": thread_id,
             },
+            where_clause=where_clause,
             # receipts_linearized has a unique constraint on
             # (user_id, room_id, receipt_type), so no need to lock
             lock=False,
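
The same NULL caveat shapes this upsert: a non-NULL thread_id can take part in the key columns, but a NULL one cannot be matched by equality, so it is pushed into an extra where_clause instead. simple_upsert_txn's real implementation lives in Synapse's DatabasePool; the sketch below only illustrates how an emulated upsert might honour these inputs:

    def emulated_upsert(txn, table, keyvalues, values, where_clause=""):
        # Illustration only: UPDATE any row matching the key columns (plus
        # the extra clause), and INSERT a fresh row when nothing matched.
        clauses = [f"{k} = ?" for k in keyvalues]
        if where_clause:
            clauses.append(where_clause)  # e.g. "thread_id IS NULL"
        txn.execute(
            f"UPDATE {table} SET "
            + ", ".join(f"{k} = ?" for k in values)
            + " WHERE "
            + " AND ".join(clauses),
            tuple(values.values()) + tuple(keyvalues.values()),
        )
        if txn.rowcount == 0:
            cols = list(keyvalues) + list(values)
            txn.execute(
                f"INSERT INTO {table} ({', '.join(cols)}) VALUES ({', '.join(['?'] * len(cols))})",
                tuple(keyvalues.values()) + tuple(values.values()),
            )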
@@ -748,6 +777,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_type: str,
         user_id: str,
         event_ids: List[str],
+        thread_id: Optional[str],
         data: dict,
     ) -> Optional[Tuple[int, int]]:
         """Insert a receipt, either from local client or remote server.
@@ -780,6 +810,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 receipt_type,
                 user_id,
                 linearized_event_id,
+                thread_id,
                 data,
                 stream_id=stream_id,
                 # Read committed is actually beneficial here because we check for a receipt with
@@ -794,7 +825,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
         now = self._clock.time_msec()
         logger.debug(
-            "RR for event %s in %s (%i ms old)",
+            "Receipt %s for event %s in %s (%i ms old)",
+            receipt_type,
             linearized_event_id,
             room_id,
             now - event_ts,
         )
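
A side note on the log tweak: the old message hard-coded "RR" (read receipt), which was misleading once other receipt types flowed through this path; the new one interpolates the type. With placeholder values:

    import logging

    logging.basicConfig(level=logging.DEBUG)
    logging.debug("Receipt %s for event %s in %s (%i ms old)", "m.read", "$e:x", "!r:x", 42)
    # -> DEBUG:root:Receipt m.read for event $e:x in !r:x (42 ms old)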
@@ -807,6 +839,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             receipt_type,
             user_id,
             event_ids,
+            thread_id,
             data,
         )
@@ -821,6 +854,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_type: str,
         user_id: str,
         event_ids: List[str],
+        thread_id: Optional[str],
         data: JsonDict,
     ) -> None:
         assert self._can_write_to_receipts
@@ -832,19 +866,26 @@ class ReceiptsWorkerStore(SQLBaseStore):
         # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))

+        keyvalues = {
+            "room_id": room_id,
+            "receipt_type": receipt_type,
+            "user_id": user_id,
+        }
+        where_clause = ""
+        if thread_id is None:
+            where_clause = "thread_id IS NULL"
+        else:
+            keyvalues["thread_id"] = thread_id
         self.db_pool.simple_upsert_txn(
             txn,
             table="receipts_graph",
-            keyvalues={
-                "room_id": room_id,
-                "receipt_type": receipt_type,
-                "user_id": user_id,
-            },
+            keyvalues=keyvalues,
             values={
                 "event_ids": json_encoder.encode(event_ids),
                 "data": json_encoder.encode(data),
-                "thread_id": None,
+                "thread_id": thread_id,
             },
+            where_clause=where_clause,
             # receipts_graph has a unique constraint on
             # (user_id, room_id, receipt_type), so no need to lock
             lock=False,
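
With both write paths updated, a caller passes the parsed thread ID straight through. A hypothetical call site (the store variable and identifier values are placeholders, but the parameters match the new insert_receipt signature above):

    await store.insert_receipt(
        "!room:example.org",      # room_id
        "m.read",                 # receipt_type
        "@alice:example.org",     # user_id
        ["$event:example.org"],   # event_ids
        "$thread:example.org",    # thread_id (None for an unthreaded receipt)
        {"ts": 1663939200000},    # data
    )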