Add an API for listing threads in a room. (#13394)
Implement the /threads endpoint from MSC3856. This is currently unstable and behind an experimental configuration flag. It includes a background update to backfill existing data; results from the /threads endpoint will be partial until that finishes.
parent b6baa46db0
commit 3bbe532abb

10 changed files with 522 additions and 6 deletions

The hunks below are from synapse/storage/databases/main/relations.py, one of the ten changed files.
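For orientation (not part of the commit), a minimal sketch of calling the new endpoint while it sits behind the unstable prefix. The org.matrix.msc3856 prefix comes from the MSC; the server, room, token, and response-field names here are assumptions for illustration:

import requests

HOMESERVER = "https://matrix.example.org"  # hypothetical homeserver
ROOM_ID = "!abc123:example.org"  # hypothetical room
ACCESS_TOKEN = "hunter2"  # hypothetical token

# While MSC3856 is unstable, the endpoint is served under the MSC's
# unstable prefix instead of a stable /_matrix/client/v1 path.
resp = requests.get(
    f"{HOMESERVER}/_matrix/client/unstable/org.matrix.msc3856"
    f"/rooms/{ROOM_ID}/threads",
    headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
    params={"limit": 5},
)
resp.raise_for_status()
body = resp.json()
# Expected shape per the MSC: thread root events in "chunk", plus an
# opaque "next_batch" token when more pages exist.
print(len(body["chunk"]), body.get("next_batch"))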
@@ -14,6 +14,7 @@

import logging
from typing import (
    TYPE_CHECKING,
    Collection,
    Dict,
    FrozenSet,
@@ -29,17 +30,46 @@ from typing import (

import attr

from synapse.api.constants import MAIN_TIMELINE, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadsNextBatch:
    topological_ordering: int
    stream_ordering: int

    def __str__(self) -> str:
        return f"{self.topological_ordering}_{self.stream_ordering}"

    @classmethod
    def from_string(cls, string: str) -> "ThreadsNextBatch":
        """
        Creates a ThreadsNextBatch from its textual representation.
        """
        try:
            keys = (int(s) for s in string.split("_"))
            return cls(*keys)
        except Exception:
            raise SynapseError(400, "Invalid threads token")


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _RelatedEvent:
    """
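ThreadsNextBatch doubles as the opaque next_batch token: two integers joined by an underscore. A quick round-trip check, illustrative only and not part of the diff:

batch = ThreadsNextBatch(topological_ordering=10, stream_ordering=3)
assert str(batch) == "10_3"
assert ThreadsNextBatch.from_string("10_3") == batch  # attrs supplies __eq__

# Malformed input (non-integers, wrong arity) is caught by the broad
# except and surfaced to the client as a 400.
try:
    ThreadsNextBatch.from_string("not_a_token")
except SynapseError as e:
    print(e.code)  # 400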
@@ -56,6 +86,76 @@ class _RelatedEvent:


class RelationsWorkerStore(SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_update_handler(
            "threads_backfill", self._backfill_threads
        )

    async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int:
        """Backfill the threads table."""

        def threads_backfill_txn(txn: LoggingTransaction) -> int:
            last_thread_id = progress.get("last_thread_id", "")

            # Get the latest event in each thread by topo ordering / stream ordering.
            #
            # Note that the MAX(event_id) is needed to abide by the rules of group by,
            # but doesn't actually do anything since there should only be a single event
            # ID per topo/stream ordering pair.
            sql = f"""
            SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id)
            FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE
                relates_to_id > ? AND
                relation_type = '{RelationTypes.THREAD}'
            GROUP BY room_id, relates_to_id
            ORDER BY relates_to_id
            LIMIT ?
            """
            txn.execute(sql, (last_thread_id, batch_size))

            # No more rows to process.
            rows = txn.fetchall()
            if not rows:
                return 0

            # Insert the rows into the threads table. If a matching thread already exists,
            # assume it is from a newer event.
            sql = """
            INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id)
            VALUES %s
            ON CONFLICT (room_id, thread_id)
            DO NOTHING
            """
            if isinstance(txn.database_engine, PostgresEngine):
                txn.execute_values(sql % ("?",), rows, fetch=False)
            else:
                txn.execute_batch(sql % ("?, ?, ?, ?, ?",), rows)

            # Mark the progress.
            self.db_pool.updates._background_update_progress_txn(
                txn, "threads_backfill", {"last_thread_id": rows[-1][1]}
            )

            return txn.rowcount

        result = await self.db_pool.runInteraction(
            "threads_backfill", threads_backfill_txn
        )

        if not result:
            await self.db_pool.updates._end_background_update("threads_backfill")

        return result

    @cached(uncached_args=("event",), tree=True)
    async def get_relations_for_event(
        self,
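The backfill above is a resumable keyset scan: progress is keyed on relates_to_id, each run resumes strictly after the last thread ID persisted, and returning 0 lets the caller deregister the update. The same pattern in isolation, as a sketch with assumed helper callables rather than Synapse APIs:

from typing import Callable, List, Tuple

def drain_keyset_scan(
    fetch_batch: Callable[[str, int], List[Tuple[str, str]]],
    save_progress: Callable[[str], None],
    batch_size: int = 100,
) -> None:
    """Drain a keyset-paginated query, persisting the last key seen."""
    last_id = ""
    while True:
        # Rows are ordered by key and strictly greater than last_id, so an
        # interrupted scan resumes from the saved key on the next run.
        rows = fetch_batch(last_id, batch_size)
        if not rows:
            return  # nothing left; the background update can be ended
        last_id = rows[-1][1]  # the key column, like relates_to_id above
        save_progress(last_id)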
@@ -776,6 +876,70 @@ class RelationsWorkerStore(SQLBaseStore):

            "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
        )

    @cached(tree=True)
    async def get_threads(
        self,
        room_id: str,
        limit: int = 5,
        from_token: Optional[ThreadsNextBatch] = None,
    ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
        """Get a list of thread IDs, ordered by topological ordering of their
        latest reply.

        Args:
            room_id: The room the event belongs to.
            limit: Only fetch the most recent `limit` threads.
            from_token: Fetch rows from a previous next_batch, or from the start if None.

        Returns:
            A tuple of:
                A list of thread root event IDs.

                The next_batch, if one exists.
        """
        # Generate the pagination clause, if necessary.
        #
        # Find any threads where the latest reply is equal / before the last
        # thread's topo ordering and earlier in stream ordering.
        pagination_clause = ""
        pagination_args: tuple = ()
        if from_token:
            pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?"
            pagination_args = (
                from_token.topological_ordering,
                from_token.stream_ordering,
            )

        sql = f"""
            SELECT thread_id, topological_ordering, stream_ordering
            FROM threads
            WHERE
                room_id = ?
                {pagination_clause}
            ORDER BY topological_ordering DESC, stream_ordering DESC
            LIMIT ?
        """

        def _get_threads_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
            txn.execute(sql, (room_id, *pagination_args, limit + 1))

            rows = cast(List[Tuple[str, int, int]], txn.fetchall())
            thread_ids = [r[0] for r in rows]

            # If there are more events, generate the next pagination key from the
            # last thread which will be returned.
            next_token = None
            if len(thread_ids) > limit:
                last_topo_id = rows[-2][1]
                last_stream_id = rows[-2][2]
                next_token = ThreadsNextBatch(last_topo_id, last_stream_id)

            return thread_ids[:limit], next_token

        return await self.db_pool.runInteraction("get_threads", _get_threads_txn)

    @cached()
    async def get_thread_id(self, event_id: str) -> str:
        """
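get_threads over-fetches limit + 1 rows so it can detect a further page without a second query; the surplus row is dropped, and rows[-2], the last row actually returned, seeds the next token. A sketch of a caller draining every page, assuming a store object exposing this method:

from typing import List, Optional

async def all_thread_roots(store: "RelationsWorkerStore", room_id: str) -> List[str]:
    """Follow next_batch tokens until the room's threads are exhausted."""
    thread_ids: List[str] = []
    token: Optional[ThreadsNextBatch] = None
    while True:
        page, token = await store.get_threads(room_id, limit=50, from_token=token)
        thread_ids.extend(page)
        if token is None:
            return thread_ids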