Add an API for listing threads in a room. (#13394)

Implement the /threads endpoint from MSC3856.

This is currently unstable and behind an experimental configuration
flag.

It includes a background update to backfill data; results from
the /threads endpoint will be partial until that finishes.
This commit is contained in:
Patrick Cloke 2022-10-13 08:02:11 -04:00 committed by GitHub
parent b6baa46db0
commit 3bbe532abb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 522 additions and 6 deletions

View file

@ -14,6 +14,7 @@
import logging
from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
@ -29,17 +30,46 @@ from typing import (
import attr
from synapse.api.constants import MAIN_TIMELINE, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadsNextBatch:
    """Pagination position used when listing threads.

    The textual form is the two orderings joined by an underscore, i.e.
    ``<topological_ordering>_<stream_ordering>``.
    """

    topological_ordering: int
    stream_ordering: int

    def __str__(self) -> str:
        """Render the token in its textual form."""
        return "{}_{}".format(self.topological_ordering, self.stream_ordering)

    @classmethod
    def from_string(cls, string: str) -> "ThreadsNextBatch":
        """
        Parse a ThreadsNextBatch from its textual representation.

        Raises:
            SynapseError: with code 400 if the string cannot be parsed.
        """
        try:
            # Any failure (non-integer piece, wrong number of pieces, ...)
            # is reported to the caller as an invalid token.
            parts = [int(piece) for piece in string.split("_")]
            return cls(*parts)
        except Exception:
            raise SynapseError(400, "Invalid threads token")
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _RelatedEvent:
"""
@ -56,6 +86,76 @@ class _RelatedEvent:
class RelationsWorkerStore(SQLBaseStore):
def __init__(
    self,
    database: DatabasePool,
    db_conn: LoggingDatabaseConnection,
    hs: "HomeServer",
):
    """Initialise the store and register the thread backfill handler."""
    super().__init__(database, db_conn, hs)

    # `_backfill_threads` handles the "threads_backfill" background update.
    background_updates = self.db_pool.updates
    background_updates.register_background_update_handler(
        "threads_backfill", self._backfill_threads
    )
async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int:
    """Backfill the threads table from existing thread relations.

    This is the handler for the "threads_backfill" background update.

    Args:
        progress: The stored progress of this background update. The
            "last_thread_id" entry, if present, is the last thread ID that
            was processed by a previous batch.
        batch_size: The maximum number of threads to process in this batch.

    Returns:
        The number of threads processed in this batch; 0 once there is
        nothing left to do (at which point the background update is ended).
    """

    def threads_backfill_txn(txn: LoggingTransaction) -> int:
        last_thread_id = progress.get("last_thread_id", "")

        # Get the latest event in each thread by topo ordering / stream ordering.
        #
        # Note that the MAX(event_id) is needed to abide by the rules of group by,
        # but doesn't actually do anything since there should only be a single event
        # ID per topo/stream ordering pair.
        #
        # NOTE(review): the three MAX() aggregates are evaluated independently,
        # so this relies on the latest reply having the maximum of each column
        # -- confirm this holds for the data being backfilled.
        sql = f"""
            SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id)
            FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE
                relates_to_id > ? AND
                relation_type = '{RelationTypes.THREAD}'
            GROUP BY room_id, relates_to_id
            ORDER BY relates_to_id
            LIMIT ?
        """
        txn.execute(sql, (last_thread_id, batch_size))

        # No more rows to process.
        rows = txn.fetchall()
        if not rows:
            return 0

        # Insert the rows into the threads table. If a matching thread already exists,
        # assume it is from a newer event.
        sql = """
            INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id)
            VALUES %s
            ON CONFLICT (room_id, thread_id)
            DO NOTHING
        """
        if isinstance(txn.database_engine, PostgresEngine):
            txn.execute_values(sql % ("?",), rows, fetch=False)
        else:
            # The row placeholders must be parenthesised: "VALUES ?, ?, ?, ?, ?"
            # is not a valid row constructor.
            txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows)

        # Mark the progress.
        self.db_pool.updates._background_update_progress_txn(
            txn, "threads_backfill", {"last_thread_id": rows[-1][1]}
        )

        # Report the number of threads handled. txn.rowcount is not used
        # here as it describes whichever statement ran last on the
        # transaction (presumably the progress update above), not the
        # size of the batch.
        return len(rows)

    result = await self.db_pool.runInteraction(
        "threads_backfill", threads_backfill_txn
    )

    # An empty batch means every thread has been processed.
    if not result:
        await self.db_pool.updates._end_background_update("threads_backfill")

    return result
@cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
@ -776,6 +876,70 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
@cached(tree=True)
async def get_threads(
    self,
    room_id: str,
    limit: int = 5,
    from_token: Optional[ThreadsNextBatch] = None,
) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
    """Get a list of thread IDs, ordered by topological ordering of their
    latest reply.

    Args:
        room_id: The room the event belongs to.
        limit: Only fetch the most recent `limit` threads.
        from_token: Fetch rows from a previous next_batch, or from the start if None.

    Returns:
        A tuple of:
            A list of thread root event IDs.

            The next_batch, if one exists.
    """
    # Generate the pagination clause, if necessary.
    #
    # Results are ordered by (topological_ordering, stream_ordering)
    # descending, so the next page is every thread whose pair sorts strictly
    # before the token's pair: an earlier topological ordering, or the same
    # topological ordering with an earlier stream ordering. Requiring both
    # columns to be smaller independently would skip threads that have an
    # earlier topological ordering but a later stream ordering.
    pagination_clause = ""
    pagination_args: tuple = ()
    if from_token is not None:
        pagination_clause = (
            "AND (topological_ordering < ?"
            " OR (topological_ordering = ? AND stream_ordering < ?))"
        )
        pagination_args = (
            from_token.topological_ordering,
            from_token.topological_ordering,
            from_token.stream_ordering,
        )

    sql = f"""
        SELECT thread_id, topological_ordering, stream_ordering
        FROM threads
        WHERE
            room_id = ?
            {pagination_clause}
        ORDER BY topological_ordering DESC, stream_ordering DESC
        LIMIT ?
    """

    def _get_threads_txn(
        txn: LoggingTransaction,
    ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
        # Fetch one row more than requested to learn whether another page exists.
        txn.execute(sql, (room_id, *pagination_args, limit + 1))

        rows = cast(List[Tuple[str, int, int]], txn.fetchall())
        thread_ids = [r[0] for r in rows]

        # If there are more events, generate the next pagination key from the
        # last thread which will be returned.
        next_token = None
        if len(thread_ids) > limit:
            # rows[-1] is the extra look-ahead row, so rows[-2] is the last
            # thread actually returned to the caller.
            _, last_topo_id, last_stream_id = rows[-2]
            next_token = ThreadsNextBatch(last_topo_id, last_stream_id)

        return thread_ids[:limit], next_token

    return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
@cached()
async def get_thread_id(self, event_id: str) -> str:
"""