mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-06-12 13:12:52 -04:00
Do not consider events by ignored users for bundled aggregations (#12235)
Consider the requester's ignored users when calculating the
bundled aggregations.
See #12285 / 4df10d3214
for corresponding changes for the `/relations` endpoint.
This commit is contained in:
parent
3cdf5a1386
commit
772bad2562
6 changed files with 427 additions and 45 deletions
|
@ -17,6 +17,7 @@ from typing import (
|
|||
TYPE_CHECKING,
|
||||
Collection,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
|
@ -26,6 +27,8 @@ from typing import (
|
|||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.api.constants import RelationTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
|
@ -46,6 +49,19 @@ if TYPE_CHECKING:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class _RelatedEvent:
|
||||
"""
|
||||
Contains enough information about a related event in order to properly filter
|
||||
events from ignored users.
|
||||
"""
|
||||
|
||||
# The event ID of the related event.
|
||||
event_id: str
|
||||
# The sender of the related event.
|
||||
sender: str
|
||||
|
||||
|
||||
class RelationsWorkerStore(SQLBaseStore):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -70,7 +86,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
direction: str = "b",
|
||||
from_token: Optional[StreamToken] = None,
|
||||
to_token: Optional[StreamToken] = None,
|
||||
) -> Tuple[List[str], Optional[StreamToken]]:
|
||||
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
|
||||
"""Get a list of relations for an event, ordered by topological ordering.
|
||||
|
||||
Args:
|
||||
|
@ -88,7 +104,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
Returns:
|
||||
A tuple of:
|
||||
A list of related event IDs
|
||||
A list of related event IDs & their senders.
|
||||
|
||||
The next stream token, if one exists.
|
||||
"""
|
||||
|
@ -131,7 +147,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
order = "ASC"
|
||||
|
||||
sql = """
|
||||
SELECT event_id, relation_type, topological_ordering, stream_ordering
|
||||
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
|
||||
FROM event_relations
|
||||
INNER JOIN events USING (event_id)
|
||||
WHERE %s
|
||||
|
@ -145,7 +161,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
def _get_recent_references_for_event_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[str], Optional[StreamToken]]:
|
||||
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
|
||||
txn.execute(sql, where_args + [limit + 1])
|
||||
|
||||
last_topo_id = None
|
||||
|
@ -155,9 +171,9 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
# Do not include edits for redacted events as they leak event
|
||||
# content.
|
||||
if not is_redacted or row[1] != RelationTypes.REPLACE:
|
||||
events.append(row[0])
|
||||
last_topo_id = row[2]
|
||||
last_stream_id = row[3]
|
||||
events.append(_RelatedEvent(row[0], row[2]))
|
||||
last_topo_id = row[3]
|
||||
last_stream_id = row[4]
|
||||
|
||||
# If there are more events, generate the next pagination key.
|
||||
next_token = None
|
||||
|
@ -267,7 +283,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
`type`, `key` and `count` fields.
|
||||
"""
|
||||
|
||||
where_args = [
|
||||
args = [
|
||||
event_id,
|
||||
room_id,
|
||||
RelationTypes.ANNOTATION,
|
||||
|
@ -287,7 +303,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
def _get_aggregation_groups_for_event_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[JsonDict]:
|
||||
txn.execute(sql, where_args)
|
||||
txn.execute(sql, args)
|
||||
|
||||
return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
|
||||
|
||||
|
@ -295,6 +311,63 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
|
||||
)
|
||||
|
||||
async def get_aggregation_groups_for_users(
|
||||
self,
|
||||
event_id: str,
|
||||
room_id: str,
|
||||
limit: int,
|
||||
users: FrozenSet[str] = frozenset(),
|
||||
) -> Dict[Tuple[str, str], int]:
|
||||
"""Fetch the partial aggregations for an event for specific users.
|
||||
|
||||
This is used, in conjunction with get_aggregation_groups_for_event, to
|
||||
remove information from the results for ignored users.
|
||||
|
||||
Args:
|
||||
event_id: Fetch events that relate to this event ID.
|
||||
room_id: The room the event belongs to.
|
||||
limit: Only fetch the `limit` groups.
|
||||
users: The users to fetch information for.
|
||||
|
||||
Returns:
|
||||
A map of (event type, aggregation key) to a count of users.
|
||||
"""
|
||||
|
||||
if not users:
|
||||
return {}
|
||||
|
||||
args: List[Union[str, int]] = [
|
||||
event_id,
|
||||
room_id,
|
||||
RelationTypes.ANNOTATION,
|
||||
]
|
||||
|
||||
users_sql, users_args = make_in_list_sql_clause(
|
||||
self.database_engine, "sender", users
|
||||
)
|
||||
args.extend(users_args)
|
||||
|
||||
sql = f"""
|
||||
SELECT type, aggregation_key, COUNT(DISTINCT sender)
|
||||
FROM event_relations
|
||||
INNER JOIN events USING (event_id)
|
||||
WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
|
||||
GROUP BY relation_type, type, aggregation_key
|
||||
ORDER BY COUNT(*) DESC
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
def _get_aggregation_groups_for_users_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Dict[Tuple[str, str], int]:
|
||||
txn.execute(sql, args + [limit])
|
||||
|
||||
return {(row[0], row[1]): row[2] for row in txn}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
|
||||
raise NotImplementedError()
|
||||
|
@ -521,6 +594,67 @@ class RelationsWorkerStore(SQLBaseStore):
|
|||
|
||||
return summaries
|
||||
|
||||
async def get_threaded_messages_per_user(
|
||||
self,
|
||||
event_ids: Collection[str],
|
||||
users: FrozenSet[str] = frozenset(),
|
||||
) -> Dict[Tuple[str, str], int]:
|
||||
"""Get the number of threaded replies for a set of users.
|
||||
|
||||
This is used, in conjunction with get_thread_summaries, to calculate an
|
||||
accurate count of the replies to a thread by subtracting ignored users.
|
||||
|
||||
Args:
|
||||
event_ids: The events to check for threaded replies.
|
||||
users: The user to calculate the count of their replies.
|
||||
|
||||
Returns:
|
||||
A map of the (event_id, sender) to the count of their replies.
|
||||
"""
|
||||
if not users:
|
||||
return {}
|
||||
|
||||
# Fetch the number of threaded replies.
|
||||
sql = """
|
||||
SELECT parent.event_id, child.sender, COUNT(child.event_id) FROM events AS child
|
||||
INNER JOIN event_relations USING (event_id)
|
||||
INNER JOIN events AS parent ON
|
||||
parent.event_id = relates_to_id
|
||||
AND parent.room_id = child.room_id
|
||||
WHERE
|
||||
%s
|
||||
AND %s
|
||||
AND %s
|
||||
GROUP BY parent.event_id, child.sender
|
||||
"""
|
||||
|
||||
def _get_threaded_messages_per_user_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Dict[Tuple[str, str], int]:
|
||||
users_sql, users_args = make_in_list_sql_clause(
|
||||
self.database_engine, "child.sender", users
|
||||
)
|
||||
events_clause, events_args = make_in_list_sql_clause(
|
||||
txn.database_engine, "relates_to_id", event_ids
|
||||
)
|
||||
|
||||
if self._msc3440_enabled:
|
||||
relations_clause = "(relation_type = ? OR relation_type = ?)"
|
||||
relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD]
|
||||
else:
|
||||
relations_clause = "relation_type = ?"
|
||||
relations_args = [RelationTypes.THREAD]
|
||||
|
||||
txn.execute(
|
||||
sql % (users_sql, events_clause, relations_clause),
|
||||
users_args + events_args + relations_args,
|
||||
)
|
||||
return {(row[0], row[1]): row[2] for row in txn}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_threaded_messages_per_user", _get_threaded_messages_per_user_txn
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_thread_participated(self, event_id: str, user_id: str) -> bool:
|
||||
raise NotImplementedError()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue