Do not consider events by ignored users for bundled aggregations (#12235)

Consider the requester's ignored users when calculating the
bundled aggregations.

See #12285 / 4df10d3214
for corresponding changes for the `/relations` endpoint.
This commit is contained in:
Patrick Cloke 2022-04-11 10:09:57 -04:00 committed by GitHub
parent 3cdf5a1386
commit 772bad2562
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 427 additions and 45 deletions

1
changelog.d/12235.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where events from ignored users were still considered for bundled aggregations.

1
changelog.d/12338.bugfix Normal file
View File

@ -0,0 +1 @@
Fix a long-standing bug where events from ignored users were still considered for bundled aggregations.

View File

@ -1 +0,0 @@
Refactor relations code to remove an unnecessary class.

View File

@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, Optional
from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Tuple,
)
import attr
from frozendict import frozendict
@ -20,7 +29,8 @@ from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StreamToken
from synapse.storage.databases.main.relations import _RelatedEvent
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@ -115,6 +125,9 @@ class RelationsHandler:
if event is None:
raise SynapseError(404, "Unknown parent event.")
# Note that ignored users are not passed into get_relations_for_event
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
related_events, next_token = await self._main_store.get_relations_for_event(
event_id=event_id,
event=event,
@ -128,7 +141,9 @@ class RelationsHandler:
to_token=to_token,
)
events = await self._main_store.get_events_as_list(related_events)
events = await self._main_store.get_events_as_list(
[e.event_id for e in related_events]
)
events = await filter_events_for_client(
self._storage, user_id, events, is_peeking=(member_event_id is None)
@ -162,8 +177,87 @@ class RelationsHandler:
return return_value
async def get_relations_for_event(
self,
event_id: str,
event: EventBase,
room_id: str,
relation_type: str,
ignored_users: FrozenSet[str] = frozenset(),
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of events which relate to an event, ordered by topological ordering.
Args:
event_id: Fetch events that relate to this event ID.
event: The matching EventBase to event_id.
room_id: The room the event belongs to.
relation_type: The type of relation.
ignored_users: The users ignored by the requesting user.
Returns:
List of event IDs that match relations requested. The rows are of
the form `{"event_id": "..."}`.
"""
# Call the underlying storage method, which is cached.
related_events, next_token = await self._main_store.get_relations_for_event(
event_id, event, room_id, relation_type, direction="f"
)
# Filter out ignored users and convert to the expected format.
related_events = [
event for event in related_events if event.sender not in ignored_users
]
return related_events, next_token
async def get_annotations_for_event(
self,
event_id: str,
room_id: str,
limit: int = 5,
ignored_users: FrozenSet[str] = frozenset(),
) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
This is used e.g. to get the what and how many reactions have happend
on an event.
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
ignored_users: The users ignored by the requesting user.
Returns:
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
"""
# Get the base results for all users.
full_results = await self._main_store.get_aggregation_groups_for_event(
event_id, room_id, limit
)
# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_aggregation_groups_for_users(
event_id, room_id, limit, ignored_users
)
filtered_results = []
for result in full_results:
key = (result["type"], result["key"])
if key in ignored_results:
result = result.copy()
result["count"] -= ignored_results[key]
if result["count"] <= 0:
continue
filtered_results.append(result)
return filtered_results
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
self, event: EventBase, ignored_users: FrozenSet[str]
) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
@ -171,7 +265,7 @@ class RelationsHandler:
Args:
event: The event to calculate bundled aggregations for.
user_id: The user requesting the bundled aggregations.
ignored_users: The users ignored by the requesting user.
Returns:
The bundled aggregations for an event, if bundled aggregations are
@ -194,18 +288,22 @@ class RelationsHandler:
# while others need more processing during serialization.
aggregations = BundledAggregations()
annotations = await self._main_store.get_aggregation_groups_for_event(
event_id, room_id
annotations = await self.get_annotations_for_event(
event_id, room_id, ignored_users=ignored_users
)
if annotations:
aggregations.annotations = {"chunk": annotations}
references, next_token = await self._main_store.get_relations_for_event(
event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
references, next_token = await self.get_relations_for_event(
event_id,
event,
room_id,
RelationTypes.REFERENCE,
ignored_users=ignored_users,
)
if references:
aggregations.references = {
"chunk": [{"event_id": event_id} for event_id in references]
"chunk": [{"event_id": event.event_id} for event in references]
}
if next_token:
@ -216,6 +314,99 @@ class RelationsHandler:
# Store the bundled aggregations in the event metadata for later use.
return aggregations
async def get_threads_for_events(
self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
) -> Dict[str, _ThreadAggregation]:
"""Get the bundled aggregations for threads for the requested events.
Args:
event_ids: Events to get aggregations for threads.
user_id: The user requesting the bundled aggregations.
ignored_users: The users ignored by the requesting user.
Returns:
A dictionary mapping event ID to the thread information.
May not contain a value for all requested event IDs.
"""
user = UserID.from_string(user_id)
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(event_ids)
# Only fetch participated for a limited selection based on what had
# summaries.
thread_event_ids = [
event_id for event_id, summary in summaries.items() if summary
]
participated = await self._main_store.get_threads_participated(
thread_event_ids, user_id
)
# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_threaded_messages_per_user(
thread_event_ids, ignored_users
)
# A map of event ID to the thread aggregation.
results = {}
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
# Subtract off the count of any ignored users.
for ignored_user in ignored_users:
thread_count -= ignored_results.get((event_id, ignored_user), 0)
# This is gnarly, but if the latest event is from an ignored user,
# attempt to find one that isn't from an ignored user.
if latest_thread_event.sender in ignored_users:
room_id = latest_thread_event.room_id
# If the root event is not found, something went wrong, do
# not include a summary of the thread.
event = await self._event_handler.get_event(user, room_id, event_id)
if event is None:
continue
potential_events, _ = await self.get_relations_for_event(
event_id,
event,
room_id,
RelationTypes.THREAD,
ignored_users,
)
# If all found events are from ignored users, do not include
# a summary of the thread.
if not potential_events:
continue
# The *last* event returned is the one that is cared about.
event = await self._event_handler.get_event(
user, room_id, potential_events[-1].event_id
)
# It is unexpected that the event will not exist.
if event is None:
logger.warning(
"Unable to fetch latest event in a thread with event ID: %s",
potential_events[-1].event_id,
)
continue
latest_thread_event = event
results[event_id] = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
return results
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
@ -239,13 +430,21 @@ class RelationsHandler:
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
# Fetch any ignored users of the requesting user.
ignored_users = await self._main_store.ignored_users(user_id)
# Fetch other relations per event.
for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
event_result = await self._get_bundled_aggregation_for_event(
event, ignored_users
)
if event_result:
results[event.event_id] = event_result
# Fetch any edits (but not for redacted events).
#
# Note that there is no use in limiting edits by ignored users since the
# parent event should be ignored in the first place if the user is ignored.
edits = await self._main_store.get_applicable_edits(
[
event_id
@ -256,25 +455,10 @@ class RelationsHandler:
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._main_store.get_threads_participated(
[event_id for event_id, summary in summaries.items() if summary], user_id
threads = await self.get_threads_for_events(
events_by_id.keys(), user_id, ignored_users
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
for event_id, thread in threads.items():
results.setdefault(event_id, BundledAggregations()).thread = thread
return results

View File

@ -17,6 +17,7 @@ from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
@ -26,6 +27,8 @@ from typing import (
cast,
)
import attr
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
@ -46,6 +49,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _RelatedEvent:
"""
Contains enough information about a related event in order to properly filter
events from ignored users.
"""
# The event ID of the related event.
event_id: str
# The sender of the related event.
sender: str
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@ -70,7 +86,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> Tuple[List[str], Optional[StreamToken]]:
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
@ -88,7 +104,7 @@ class RelationsWorkerStore(SQLBaseStore):
Returns:
A tuple of:
A list of related event IDs
A list of related event IDs & their senders.
The next stream token, if one exists.
"""
@ -131,7 +147,7 @@ class RelationsWorkerStore(SQLBaseStore):
order = "ASC"
sql = """
SELECT event_id, relation_type, topological_ordering, stream_ordering
SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
@ -145,7 +161,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
) -> Tuple[List[str], Optional[StreamToken]]:
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
@ -155,9 +171,9 @@ class RelationsWorkerStore(SQLBaseStore):
# Do not include edits for redacted events as they leak event
# content.
if not is_redacted or row[1] != RelationTypes.REPLACE:
events.append(row[0])
last_topo_id = row[2]
last_stream_id = row[3]
events.append(_RelatedEvent(row[0], row[2]))
last_topo_id = row[3]
last_stream_id = row[4]
# If there are more events, generate the next pagination key.
next_token = None
@ -267,7 +283,7 @@ class RelationsWorkerStore(SQLBaseStore):
`type`, `key` and `count` fields.
"""
where_args = [
args = [
event_id,
room_id,
RelationTypes.ANNOTATION,
@ -287,7 +303,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_aggregation_groups_for_event_txn(
txn: LoggingTransaction,
) -> List[JsonDict]:
txn.execute(sql, where_args)
txn.execute(sql, args)
return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
@ -295,6 +311,63 @@ class RelationsWorkerStore(SQLBaseStore):
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
)
async def get_aggregation_groups_for_users(
self,
event_id: str,
room_id: str,
limit: int,
users: FrozenSet[str] = frozenset(),
) -> Dict[Tuple[str, str], int]:
"""Fetch the partial aggregations for an event for specific users.
This is used, in conjunction with get_aggregation_groups_for_event, to
remove information from the results for ignored users.
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
limit: Only fetch the `limit` groups.
users: The users to fetch information for.
Returns:
A map of (event type, aggregation key) to a count of users.
"""
if not users:
return {}
args: List[Union[str, int]] = [
event_id,
room_id,
RelationTypes.ANNOTATION,
]
users_sql, users_args = make_in_list_sql_clause(
self.database_engine, "sender", users
)
args.extend(users_args)
sql = f"""
SELECT type, aggregation_key, COUNT(DISTINCT sender)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
GROUP BY relation_type, type, aggregation_key
ORDER BY COUNT(*) DESC
LIMIT ?
"""
def _get_aggregation_groups_for_users_txn(
txn: LoggingTransaction,
) -> Dict[Tuple[str, str], int]:
txn.execute(sql, args + [limit])
return {(row[0], row[1]): row[2] for row in txn}
return await self.db_pool.runInteraction(
"get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
)
@cached()
def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
raise NotImplementedError()
@ -521,6 +594,67 @@ class RelationsWorkerStore(SQLBaseStore):
return summaries
async def get_threaded_messages_per_user(
self,
event_ids: Collection[str],
users: FrozenSet[str] = frozenset(),
) -> Dict[Tuple[str, str], int]:
"""Get the number of threaded replies for a set of users.
This is used, in conjunction with get_thread_summaries, to calculate an
accurate count of the replies to a thread by subtracting ignored users.
Args:
event_ids: The events to check for threaded replies.
users: The user to calculate the count of their replies.
Returns:
A map of the (event_id, sender) to the count of their replies.
"""
if not users:
return {}
# Fetch the number of threaded replies.
sql = """
SELECT parent.event_id, child.sender, COUNT(child.event_id) FROM events AS child
INNER JOIN event_relations USING (event_id)
INNER JOIN events AS parent ON
parent.event_id = relates_to_id
AND parent.room_id = child.room_id
WHERE
%s
AND %s
AND %s
GROUP BY parent.event_id, child.sender
"""
def _get_threaded_messages_per_user_txn(
txn: LoggingTransaction,
) -> Dict[Tuple[str, str], int]:
users_sql, users_args = make_in_list_sql_clause(
self.database_engine, "child.sender", users
)
events_clause, events_args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
if self._msc3440_enabled:
relations_clause = "(relation_type = ? OR relation_type = ?)"
relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD]
else:
relations_clause = "relation_type = ?"
relations_args = [RelationTypes.THREAD]
txn.execute(
sql % (users_sql, events_clause, relations_clause),
users_args + events_args + relations_args,
)
return {(row[0], row[1]): row[2] for row in txn}
return await self.db_pool.runInteraction(
"get_threaded_messages_per_user", _get_threaded_messages_per_user_txn
)
@cached()
def get_thread_participated(self, event_id: str, user_id: str) -> bool:
raise NotImplementedError()

View File

@ -1137,16 +1137,27 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
"""Relations sent from an ignored user should be ignored."""
def _test_ignored_user(
self, allowed_event_ids: List[str], ignored_event_ids: List[str]
) -> None:
self,
relation_type: str,
allowed_event_ids: List[str],
ignored_event_ids: List[str],
) -> Tuple[JsonDict, JsonDict]:
"""
Fetch the relations and ensure they're all there, then ignore user2, and
repeat.
Returns:
A tuple of two JSON dictionaries, each are bundled aggregations, the
first is from before the user is ignored, and the second is after.
"""
# Get the relations.
event_ids = self._get_related_events()
self.assertCountEqual(event_ids, allowed_event_ids + ignored_event_ids)
# And the bundled aggregations.
before_aggregations = self._get_bundled_aggregations()
self.assertIn(relation_type, before_aggregations)
# Ignore user2 and re-do the requests.
self.get_success(
self.store.add_account_data_for_user(
@ -1160,6 +1171,12 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
event_ids = self._get_related_events()
self.assertCountEqual(event_ids, allowed_event_ids)
# And the bundled aggregations.
after_aggregations = self._get_bundled_aggregations()
self.assertIn(relation_type, after_aggregations)
return before_aggregations[relation_type], after_aggregations[relation_type]
def test_annotation(self) -> None:
"""Annotations should ignore"""
# Send 2 from us, 2 from the to be ignored user.
@ -1184,7 +1201,26 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
)
ignored_event_ids.append(channel.json_body["event_id"])
self._test_ignored_user(allowed_event_ids, ignored_event_ids)
before_aggregations, after_aggregations = self._test_ignored_user(
RelationTypes.ANNOTATION, allowed_event_ids, ignored_event_ids
)
self.assertCountEqual(
before_aggregations["chunk"],
[
{"type": "m.reaction", "key": "a", "count": 2},
{"type": "m.reaction", "key": "b", "count": 1},
{"type": "m.reaction", "key": "c", "count": 1},
],
)
self.assertCountEqual(
after_aggregations["chunk"],
[
{"type": "m.reaction", "key": "a", "count": 1},
{"type": "m.reaction", "key": "b", "count": 1},
],
)
def test_reference(self) -> None:
"""Annotations should ignore"""
@ -1196,7 +1232,18 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
)
ignored_event_ids = [channel.json_body["event_id"]]
self._test_ignored_user(allowed_event_ids, ignored_event_ids)
before_aggregations, after_aggregations = self._test_ignored_user(
RelationTypes.REFERENCE, allowed_event_ids, ignored_event_ids
)
self.assertCountEqual(
[e["event_id"] for e in before_aggregations["chunk"]],
allowed_event_ids + ignored_event_ids,
)
self.assertCountEqual(
[e["event_id"] for e in after_aggregations["chunk"]], allowed_event_ids
)
def test_thread(self) -> None:
"""Annotations should ignore"""
@ -1208,7 +1255,23 @@ class RelationIgnoredUserTestCase(BaseRelationsTestCase):
)
ignored_event_ids = [channel.json_body["event_id"]]
self._test_ignored_user(allowed_event_ids, ignored_event_ids)
before_aggregations, after_aggregations = self._test_ignored_user(
RelationTypes.THREAD, allowed_event_ids, ignored_event_ids
)
self.assertEqual(before_aggregations["count"], 2)
self.assertTrue(before_aggregations["current_user_participated"])
# The latest thread event has some fields that don't matter.
self.assertEqual(
before_aggregations["latest_event"]["event_id"], ignored_event_ids[0]
)
self.assertEqual(after_aggregations["count"], 1)
self.assertTrue(after_aggregations["current_user_participated"])
# The latest thread event has some fields that don't matter.
self.assertEqual(
after_aggregations["latest_event"]["event_id"], allowed_event_ids[0]
)
class RelationRedactionTestCase(BaseRelationsTestCase):