Add a filter_event_for_clients_with_state function (#13222)

This commit is contained in:
Erik Johnston 2022-07-11 14:14:09 +01:00 committed by GitHub
parent a113011794
commit e610128c50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 411 additions and 149 deletions

1
changelog.d/13222.misc Normal file
View File

@ -0,0 +1 @@
Improve memory usage of calculating push actions for events in large rooms.

View File

@ -39,6 +39,7 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.room import RoomWorkerStore from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.types import StateMap from synapse.types import StateMap
from synapse.util import SYNAPSE_VERSION from synapse.util import SYNAPSE_VERSION
@ -60,7 +61,17 @@ class AdminCmdSlavedStore(
BaseSlavedStore, BaseSlavedStore,
RoomWorkerStore, RoomWorkerStore,
): ):
pass def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# Annoyingly `filter_events_for_client` assumes that this exists. We
# should refactor it to take a `Clock` directly.
self.clock = hs.get_clock()
class AdminCmdServer(HomeServer): class AdminCmdServer(HomeServer):

View File

@ -13,16 +13,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from enum import Enum, auto
from typing import Collection, Dict, FrozenSet, List, Optional, Tuple from typing import Collection, Dict, FrozenSet, List, Optional, Tuple
import attr
from typing_extensions import Final from typing_extensions import Final
from synapse.api.constants import EventTypes, HistoryVisibility, Membership from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.storage.controllers import StorageControllers from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RetentionPolicy, StateMap, get_domain_from_id from synapse.types import RetentionPolicy, StateMap, get_domain_from_id
from synapse.util import Clock
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -102,9 +107,176 @@ async def filter_events_for_client(
] = await storage.main.get_retention_policy_for_room(room_id) ] = await storage.main.get_retention_policy_for_room(room_id)
def allowed(event: EventBase) -> Optional[EventBase]: def allowed(event: EventBase) -> Optional[EventBase]:
return _check_client_allowed_to_see_event(
user_id=user_id,
event=event,
clock=storage.main.clock,
filter_send_to_client=filter_send_to_client,
sender_ignored=event.sender in ignore_list,
always_include_ids=always_include_ids,
retention_policy=retention_policies[room_id],
state=event_id_to_state.get(event.event_id),
is_peeking=is_peeking,
sender_erased=erased_senders.get(event.sender, False),
)
# Check each event: gives an iterable of None or (a potentially modified)
# EventBase.
filtered_events = map(allowed, events)
# Turn it into a list and remove None entries before returning.
return [ev for ev in filtered_events if ev]
async def filter_event_for_clients_with_state(
store: DataStore,
user_ids: Collection[str],
event: EventBase,
context: EventContext,
is_peeking: bool = False,
filter_send_to_client: bool = True,
) -> Collection[str]:
""" """
Checks to see if an event is visible to the users in the list at the time of
the event.
Note: This does *not* check if the sender of the event was erased.
Args: Args:
event: event to check store: databases
user_ids: user_ids to be checked
event: the event to be checked
context: EventContext for the event to be checked
is_peeking: Whether the users are peeking into the room, ie not
currently joined
filter_send_to_client: Whether we're checking an event that's going to be
sent to a client. This might not always be the case since this function can
also be called to check whether a user can see the state at a given point.
Returns:
Collection of user IDs for whom the event is visible
"""
# None of the users should see the event if it is soft_failed
if event.internal_metadata.is_soft_failed():
return []
# Make a set for all user IDs that haven't been filtered out by a check.
allowed_user_ids = set(user_ids)
# Only run some checks if these events aren't about to be sent to clients. This is
# because, if this is not the case, we're probably only checking if the users can
# see events in the room at that point in the DAG, and that shouldn't be decided
# on those checks.
if filter_send_to_client:
ignored_by = await store.ignored_by(event.sender)
retention_policy = await store.get_retention_policy_for_room(event.room_id)
for user_id in user_ids:
if (
_check_filter_send_to_client(
event,
store.clock,
retention_policy,
sender_ignored=user_id in ignored_by,
)
== _CheckFilter.DENIED
):
allowed_user_ids.discard(user_id)
if event.internal_metadata.outlier:
# Normally these can't be seen by clients, but we make an exception for
# for out-of-band membership events (eg, incoming invites, or rejections of
# said invite) for the user themselves.
if event.type == EventTypes.Member and event.state_key in allowed_user_ids:
logger.debug("Returning out-of-band-membership event %s", event)
return {event.state_key}
return set()
# First we get just the history visibility in case its shared/world-readable
# room.
visibility_state_map = await _get_state_map(
store, event, context, StateFilter.from_types([_HISTORY_VIS_KEY])
)
visibility = get_effective_room_visibility_from_state(visibility_state_map)
if (
_check_history_visibility(event, visibility, is_peeking=is_peeking)
== _CheckVisibility.ALLOWED
):
return allowed_user_ids
# The history visibility isn't lax, so we now need to fetch the membership
# events of all the users.
filter_list = []
for user_id in allowed_user_ids:
filter_list.append((EventTypes.Member, user_id))
filter_list.append((EventTypes.RoomHistoryVisibility, ""))
state_filter = StateFilter.from_types(filter_list)
state_map = await _get_state_map(store, event, context, state_filter)
# Now we check whether the membership allows each user to see the event.
return {
user_id
for user_id in allowed_user_ids
if _check_membership(user_id, event, visibility, state_map, is_peeking).allowed
}
async def _get_state_map(
store: DataStore, event: EventBase, context: EventContext, state_filter: StateFilter
) -> StateMap[EventBase]:
"""Helper function for getting a `StateMap[EventBase]` from an `EventContext`"""
state_map = await context.get_prev_state_ids(state_filter)
# Use events rather than event ids as content from the events are needed in
# _check_visibility
event_map = await store.get_events(state_map.values(), get_prev_content=False)
updated_state_map = {}
for state_key, event_id in state_map.items():
state_event = event_map.get(event_id)
if state_event:
updated_state_map[state_key] = state_event
if event.is_state():
current_state_key = (event.type, event.state_key)
# Add current event to updated_state_map, we need to do this here as it
# may not have been persisted to the db yet
updated_state_map[current_state_key] = event
return updated_state_map
def _check_client_allowed_to_see_event(
user_id: str,
event: EventBase,
clock: Clock,
filter_send_to_client: bool,
is_peeking: bool,
always_include_ids: FrozenSet[str],
sender_ignored: bool,
retention_policy: RetentionPolicy,
state: Optional[StateMap[EventBase]],
sender_erased: bool,
) -> Optional[EventBase]:
"""Check with the given user is allowed to see the given event
See `filter_events_for_client` for details about args
Args:
user_id
event
clock
filter_send_to_client
is_peeking
always_include_ids
sender_ignored: Whether the user is ignoring the event sender
retention_policy: The retention policy of the room
state: The state at the event, unless its an outlier
sender_erased: Whether the event sender has been marked as "erased"
Returns: Returns:
None if the user cannot see this event at all None if the user cannot see this event at all
@ -119,30 +291,10 @@ async def filter_events_for_client(
# see events in the room at that point in the DAG, and that shouldn't be decided # see events in the room at that point in the DAG, and that shouldn't be decided
# on those checks. # on those checks.
if filter_send_to_client: if filter_send_to_client:
if event.type == EventTypes.Dummy: if (
return None _check_filter_send_to_client(event, clock, retention_policy, sender_ignored)
== _CheckFilter.DENIED
if not event.is_state() and event.sender in ignore_list: ):
return None
# Until MSC2261 has landed we can't redact malicious alias events, so for
# now we temporarily filter out m.room.aliases entirely to mitigate
# abuse, while we spec a better solution to advertising aliases
# on rooms.
if event.type == EventTypes.Aliases:
return None
# Don't try to apply the room's retention policy if the event is a state
# event, as MSC1763 states that retention is only considered for non-state
# events.
if not event.is_state():
retention_policy = retention_policies[event.room_id]
max_lifetime = retention_policy.max_lifetime
if max_lifetime is not None:
oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime
if event.origin_server_ts < oldest_allowed_ts:
return None return None
if event.event_id in always_include_ids: if event.event_id in always_include_ids:
@ -159,28 +311,58 @@ async def filter_events_for_client(
return None return None
state = event_id_to_state[event.event_id] if state is None:
raise Exception("Missing state for non-outlier event")
# get the room_visibility at the time of the event. # get the room_visibility at the time of the event.
visibility = get_effective_room_visibility_from_state(state) visibility = get_effective_room_visibility_from_state(state)
# Always allow history visibility events on boundaries. This is done # Check if the room has lax history visibility, allowing us to skip
# by setting the effective visibility to the least restrictive # membership checks.
# of the old vs new. #
if event.type == EventTypes.RoomHistoryVisibility: # We can only do this check if the sender has *not* been erased, as if they
prev_content = event.unsigned.get("prev_content", {}) # have we need to check the user's membership.
prev_visibility = prev_content.get("history_visibility", None) if (
not sender_erased
and _check_history_visibility(event, visibility, is_peeking)
== _CheckVisibility.ALLOWED
):
return event
if prev_visibility not in VISIBILITY_PRIORITY: membership_result = _check_membership(user_id, event, visibility, state, is_peeking)
prev_visibility = HistoryVisibility.SHARED if not membership_result.allowed:
return None
new_priority = VISIBILITY_PRIORITY.index(visibility) # If the sender has been erased and the user was not joined at the time, we
old_priority = VISIBILITY_PRIORITY.index(prev_visibility) # must only return the redacted form.
if old_priority < new_priority: if sender_erased and not membership_result.joined:
visibility = prev_visibility event = prune_event(event)
# likewise, if the event is the user's own membership event, use return event
# the 'most joined' membership
@attr.s(frozen=True, slots=True, auto_attribs=True)
class _CheckMembershipReturn:
"Return value of _check_membership"
allowed: bool
joined: bool
def _check_membership(
user_id: str,
event: EventBase,
visibility: str,
state: StateMap[EventBase],
is_peeking: bool,
) -> _CheckMembershipReturn:
"""Check whether the user can see the event due to their membership
Returns:
True if they can, False if they can't, plus the membership of the user
at the event.
"""
# If the event is the user's own membership event, use the 'most joined'
# membership
membership = None membership = None
if event.type == EventTypes.Member and event.state_key == user_id: if event.type == EventTypes.Member and event.state_key == user_id:
membership = event.content.get("membership", None) membership = event.content.get("membership", None)
@ -200,7 +382,7 @@ async def filter_events_for_client(
if membership == "leave" and ( if membership == "leave" and (
prev_membership == "join" or prev_membership == "invite" prev_membership == "join" or prev_membership == "invite"
): ):
return event return _CheckMembershipReturn(True, membership == Membership.JOIN)
new_priority = MEMBERSHIP_PRIORITY.index(membership) new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
@ -216,19 +398,19 @@ async def filter_events_for_client(
# if the user was a member of the room at the time of the event, # if the user was a member of the room at the time of the event,
# they can see it. # they can see it.
if membership == Membership.JOIN: if membership == Membership.JOIN:
return event return _CheckMembershipReturn(True, True)
# otherwise, it depends on the room visibility. # otherwise, it depends on the room visibility.
if visibility == HistoryVisibility.JOINED: if visibility == HistoryVisibility.JOINED:
# we weren't a member at the time of the event, so we can't # we weren't a member at the time of the event, so we can't
# see this event. # see this event.
return None return _CheckMembershipReturn(False, False)
elif visibility == HistoryVisibility.INVITED: elif visibility == HistoryVisibility.INVITED:
# user can also see the event if they were *invited* at the time # user can also see the event if they were *invited* at the time
# of the event. # of the event.
return event if membership == Membership.INVITE else None return _CheckMembershipReturn(membership == Membership.INVITE, False)
elif visibility == HistoryVisibility.SHARED and is_peeking: elif visibility == HistoryVisibility.SHARED and is_peeking:
# if the visibility is shared, users cannot see the event unless # if the visibility is shared, users cannot see the event unless
@ -239,28 +421,96 @@ async def filter_events_for_client(
# ideally we would share history up to the point they left. But # ideally we would share history up to the point they left. But
# we don't know when they left. We just treat it as though they # we don't know when they left. We just treat it as though they
# never joined, and restrict access. # never joined, and restrict access.
return None return _CheckMembershipReturn(False, False)
# the visibility is either shared or world_readable, and the user was # The visibility is either shared or world_readable, and the user was
# not a member at the time. We allow it, provided the original sender # not a member at the time. We allow it.
# has not requested their data to be erased, in which case, we return return _CheckMembershipReturn(True, False)
# a redacted version.
if erased_senders[event.sender]:
return prune_event(event)
return event
# Check each event: gives an iterable of None or (a potentially modified) class _CheckFilter(Enum):
# EventBase. MAYBE_ALLOWED = auto()
filtered_events = map(allowed, events) DENIED = auto()
# Turn it into a list and remove None entries before returning.
return [ev for ev in filtered_events if ev] def _check_filter_send_to_client(
event: EventBase,
clock: Clock,
retention_policy: RetentionPolicy,
sender_ignored: bool,
) -> _CheckFilter:
"""Apply checks for sending events to client
Returns:
True if might be allowed to be sent to clients, False if definitely not.
"""
if event.type == EventTypes.Dummy:
return _CheckFilter.DENIED
if not event.is_state() and sender_ignored:
return _CheckFilter.DENIED
# Until MSC2261 has landed we can't redact malicious alias events, so for
# now we temporarily filter out m.room.aliases entirely to mitigate
# abuse, while we spec a better solution to advertising aliases
# on rooms.
if event.type == EventTypes.Aliases:
return _CheckFilter.DENIED
# Don't try to apply the room's retention policy if the event is a state
# event, as MSC1763 states that retention is only considered for non-state
# events.
if not event.is_state():
max_lifetime = retention_policy.max_lifetime
if max_lifetime is not None:
oldest_allowed_ts = clock.time_msec() - max_lifetime
if event.origin_server_ts < oldest_allowed_ts:
return _CheckFilter.DENIED
return _CheckFilter.MAYBE_ALLOWED
class _CheckVisibility(Enum):
ALLOWED = auto()
MAYBE_DENIED = auto()
def _check_history_visibility(
event: EventBase, visibility: str, is_peeking: bool
) -> _CheckVisibility:
"""Check if event is allowed to be seen due to lax history visibility.
Returns:
True if user can definitely see the event, False if maybe not.
"""
# Always allow history visibility events on boundaries. This is done
# by setting the effective visibility to the least restrictive
# of the old vs new.
if event.type == EventTypes.RoomHistoryVisibility:
prev_content = event.unsigned.get("prev_content", {})
prev_visibility = prev_content.get("history_visibility", None)
if prev_visibility not in VISIBILITY_PRIORITY:
prev_visibility = HistoryVisibility.SHARED
new_priority = VISIBILITY_PRIORITY.index(visibility)
old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
if old_priority < new_priority:
visibility = prev_visibility
if visibility == HistoryVisibility.SHARED and not is_peeking:
return _CheckVisibility.ALLOWED
elif visibility == HistoryVisibility.WORLD_READABLE:
return _CheckVisibility.ALLOWED
return _CheckVisibility.MAYBE_DENIED
def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str: def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str:
"""Get the actual history vis, from a state map including the history_visibility event """Get the actual history vis, from a state map including the history_visibility event
Handles missing and invalid history visibility events. Handles missing and invalid history visibility events.
""" """
visibility_event = state.get(_HISTORY_VIS_KEY, None) visibility_event = state.get(_HISTORY_VIS_KEY, None)