Check for space membership during a remote join of a restricted room. (#9763)

When receiving a /send_join request for a room with join rules set to 'restricted',
check if the user is a member of the spaces defined in the 'allow' key of the join
rules.
    
This only applies to an experimental room version, as defined in MSC3083.
This commit is contained in:
Patrick Cloke 2021-04-14 12:32:20 -04:00 committed by GitHub
parent 05e8c70c05
commit cc51aaaa7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 238 additions and 131 deletions

1
changelog.d/9763.feature Normal file
View File

@ -0,0 +1 @@
Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.

1
changelog.d/9800.feature Normal file
View File

@ -0,0 +1 @@
Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.

View File

@ -0,0 +1,82 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.room_versions import RoomVersion
from synapse.types import StateMap
if TYPE_CHECKING:
from synapse.server import HomeServer
class EventAuthHandler:
def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastore()
async def can_join_without_invite(
self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
) -> bool:
"""
Check whether a user can join a room without an invite.
When joining a room with restricted joined rules (as defined in MSC3083),
the membership of spaces must be checked during join.
Args:
state_ids: The state of the room as it currently is.
room_version: The room version of the room being joined.
user_id: The user joining the room.
Returns:
True if the user can join the room, false otherwise.
"""
# This only applies to room versions which support the new join rule.
if not room_version.msc3083_join_rules:
return True
# If there's no join rule, then it defaults to public (so this doesn't apply).
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
if not join_rules_event_id:
return True
# If the join rule is not restricted, this doesn't apply.
join_rules_event = await self._store.get_event(join_rules_event_id)
if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
return True
# If allowed is of the wrong form, then only allow invited users.
allowed_spaces = join_rules_event.content.get("allow", [])
if not isinstance(allowed_spaces, list):
return False
# Get the list of joined rooms and see if there's an overlap.
joined_rooms = await self._store.get_rooms_for_user(user_id)
# Pull out the other room IDs, invalid data gets filtered.
for space in allowed_spaces:
if not isinstance(space, dict):
continue
space_id = space.get("space")
if not isinstance(space_id, str):
continue
# The user was joined to one of the spaces specified, they can join
# this room!
if space_id in joined_rooms:
return True
# The user was not in any of the required spaces.
return False

View File

@ -103,7 +103,7 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True) @attr.s(slots=True)
class _NewEventInfo: class _NewEventInfo:
"""Holds information about a received event, ready for passing to _handle_new_events """Holds information about a received event, ready for passing to _auth_and_persist_events
Attributes: Attributes:
event: the received event event: the received event
@ -146,6 +146,7 @@ class FederationHandler(BaseHandler):
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.event_auth_handler = hs.get_event_auth_handler()
self._message_handler = hs.get_message_handler() self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config self.config = hs.config
@ -807,7 +808,10 @@ class FederationHandler(BaseHandler):
logger.debug("Processing event: %s", event) logger.debug("Processing event: %s", event)
try: try:
await self._handle_new_event(origin, event, state=state) context = await self.state_handler.compute_event_context(
event, old_state=state
)
await self._auth_and_persist_event(origin, event, context, state=state)
except AuthError as e: except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
@ -1010,7 +1014,9 @@ class FederationHandler(BaseHandler):
) )
if ev_infos: if ev_infos:
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) await self._auth_and_persist_events(
dest, room_id, ev_infos, backfilled=True
)
# Step 2: Persist the rest of the events in the chunk one by one # Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth) events.sort(key=lambda e: e.depth)
@ -1023,10 +1029,12 @@ class FederationHandler(BaseHandler):
# non-outliers # non-outliers
assert not event.internal_metadata.is_outlier() assert not event.internal_metadata.is_outlier()
context = await self.state_handler.compute_event_context(event)
# We store these one at a time since each event depends on the # We store these one at a time since each event depends on the
# previous to work out the state. # previous to work out the state.
# TODO: We can probably do something more clever here. # TODO: We can probably do something more clever here.
await self._handle_new_event(dest, event, backfilled=True) await self._auth_and_persist_event(dest, event, context, backfilled=True)
return events return events
@ -1360,7 +1368,7 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth)) event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events( await self._auth_and_persist_events(
destination, destination,
room_id, room_id,
event_infos, event_infos,
@ -1666,16 +1674,47 @@ class FederationHandler(BaseHandler):
# would introduce the danger of backwards-compatibility problems. # would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin event.internal_metadata.send_on_behalf_of = origin
context = await self._handle_new_event(origin, event) # Calculate the event context.
context = await self.state_handler.compute_event_context(event)
# Get the state before the new event.
prev_state_ids = await context.get_prev_state_ids()
# Check if the user is already in the room or invited to the room.
user_id = event.state_key
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
newly_joined = True
is_invite = False
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
is_invite = prev_member_event.membership == Membership.INVITE
# If the member is not already in the room, and not invited, check if
# they should be allowed access via membership in a space.
if (
newly_joined
and not is_invite
and not await self.event_auth_handler.can_join_without_invite(
prev_state_ids,
event.room_version,
user_id,
)
):
raise SynapseError(
400,
"You do not belong to any of the required spaces to join this room.",
)
# Persist the event.
await self._auth_and_persist_event(origin, event, context)
logger.debug( logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s", "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s",
event.event_id, event.event_id,
event.signatures, event.signatures,
) )
prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values()) state_ids = list(prev_state_ids.values())
auth_chain = await self.store.get_auth_chain(event.room_id, state_ids) auth_chain = await self.store.get_auth_chain(event.room_id, state_ids)
@ -1878,10 +1917,11 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False event.internal_metadata.outlier = False
await self._handle_new_event(origin, event) context = await self.state_handler.compute_event_context(event)
await self._auth_and_persist_event(origin, event, context)
logger.debug( logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s", "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s",
event.event_id, event.event_id,
event.signatures, event.signatures,
) )
@ -1989,16 +2029,44 @@ class FederationHandler(BaseHandler):
async def get_min_depth_for_context(self, context: str) -> int: async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context) return await self.store.get_min_depth(context)
async def _handle_new_event( async def _auth_and_persist_event(
self, self,
origin: str, origin: str,
event: EventBase, event: EventBase,
context: EventContext,
state: Optional[Iterable[EventBase]] = None, state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None, auth_events: Optional[MutableStateMap[EventBase]] = None,
backfilled: bool = False, backfilled: bool = False,
) -> EventContext: ) -> None:
context = await self._prep_event( """
origin, event, state=state, auth_events=auth_events, backfilled=backfilled Process an event by performing auth checks and then persisting to the database.
Args:
origin: The host the event originates from.
event: The event itself.
context:
The event context.
NB that this function potentially modifies it.
state:
The state events used to auth the event. If this is not provided
the current state events will be used.
auth_events:
Map from (event_type, state_key) to event
Normally, our calculated auth_events based on the state of the room
at the event's position in the DAG, though occasionally (eg if the
event is an outlier), may be the auth events claimed by the remote
server.
backfilled: True if the event was backfilled.
"""
context = await self._check_event_auth(
origin,
event,
context,
state=state,
auth_events=auth_events,
backfilled=backfilled,
) )
try: try:
@ -2020,9 +2088,7 @@ class FederationHandler(BaseHandler):
) )
raise raise
return context async def _auth_and_persist_events(
async def _handle_new_events(
self, self,
origin: str, origin: str,
room_id: str, room_id: str,
@ -2040,9 +2106,13 @@ class FederationHandler(BaseHandler):
async def prep(ev_info: _NewEventInfo): async def prep(ev_info: _NewEventInfo):
event = ev_info.event event = ev_info.event
with nested_logging_context(suffix=event.event_id): with nested_logging_context(suffix=event.event_id):
res = await self._prep_event( res = await self.state_handler.compute_event_context(
event, old_state=ev_info.state
)
res = await self._check_event_auth(
origin, origin,
event, event,
res,
state=ev_info.state, state=ev_info.state,
auth_events=ev_info.auth_events, auth_events=ev_info.auth_events,
backfilled=backfilled, backfilled=backfilled,
@ -2177,49 +2247,6 @@ class FederationHandler(BaseHandler):
room_id, [(event, new_event_context)] room_id, [(event, new_event_context)]
) )
async def _prep_event(
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
context = await self.state_handler.compute_event_context(event, old_state=state)
if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = await self.store.get_event(
event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
context = await self.do_auth(origin, event, context, auth_events=auth_events)
if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled)
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)
return context
async def _check_for_soft_fail( async def _check_for_soft_fail(
self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
) -> None: ) -> None:
@ -2330,19 +2357,28 @@ class FederationHandler(BaseHandler):
return missing_events return missing_events
async def do_auth( async def _check_event_auth(
self, self,
origin: str, origin: str,
event: EventBase, event: EventBase,
context: EventContext, context: EventContext,
auth_events: MutableStateMap[EventBase], state: Optional[Iterable[EventBase]],
auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext: ) -> EventContext:
""" """
Checks whether an event should be rejected (for failing auth checks).
Args: Args:
origin: origin: The host the event originates from.
event: event: The event itself.
context: context:
The event context.
NB that this function potentially modifies it.
state:
The state events used to auth the event. If this is not provided
the current state events will be used.
auth_events: auth_events:
Map from (event_type, state_key) to event Map from (event_type, state_key) to event
@ -2352,12 +2388,32 @@ class FederationHandler(BaseHandler):
server. server.
Also NB that this function adds entries to it. Also NB that this function adds entries to it.
backfilled: True if the event was backfilled.
Returns: Returns:
updated context object The updated context object.
""" """
room_version = await self.store.get_room_version_id(event.room_id) room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = await self.store.get_event(
event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
try: try:
context = await self._update_auth_events_and_context_for_auth( context = await self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events origin, event, context, auth_events
@ -2379,6 +2435,17 @@ class FederationHandler(BaseHandler):
logger.warning("Failed auth resolution for %r because %s", event, e) logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled)
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)
return context return context
async def _update_auth_events_and_context_for_auth( async def _update_auth_events_and_context_for_auth(
@ -2388,7 +2455,7 @@ class FederationHandler(BaseHandler):
context: EventContext, context: EventContext,
auth_events: MutableStateMap[EventBase], auth_events: MutableStateMap[EventBase],
) -> EventContext: ) -> EventContext:
"""Helper for do_auth. See there for docs. """Helper for _check_event_auth. See there for docs.
Checks whether a given event has the expected auth events. If it Checks whether a given event has the expected auth events. If it
doesn't then we talk to the remote server to compare state to see if doesn't then we talk to the remote server to compare state to see if
@ -2468,9 +2535,14 @@ class FederationHandler(BaseHandler):
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
logger.debug( logger.debug(
"do_auth %s missing_auth: %s", event.event_id, e.event_id "_check_event_auth %s missing_auth: %s",
event.event_id,
e.event_id,
)
context = await self.state_handler.compute_event_context(e)
await self._auth_and_persist_event(
origin, e, context, auth_events=auth
) )
await self._handle_new_event(origin, e, auth_events=auth)
if e.event_id in event_auth_events: if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e auth_events[(e.type, e.state_key)] = e

View File

@ -19,7 +19,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from synapse import types from synapse import types
from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -28,7 +28,6 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
@ -64,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.account_data_handler = hs.get_account_data_handler() self.account_data_handler = hs.get_account_data_handler()
self.event_auth_handler = hs.get_event_auth_handler()
self.member_linearizer = Linearizer(name="member") self.member_linearizer = Linearizer(name="member")
@ -178,62 +178,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id) await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
async def _can_join_without_invite(
self, state_ids: StateMap[str], room_version: RoomVersion, user_id: str
) -> bool:
"""
Check whether a user can join a room without an invite.
When joining a room with restricted joined rules (as defined in MSC3083),
the membership of spaces must be checked during join.
Args:
state_ids: The state of the room as it currently is.
room_version: The room version of the room being joined.
user_id: The user joining the room.
Returns:
True if the user can join the room, false otherwise.
"""
# This only applies to room versions which support the new join rule.
if not room_version.msc3083_join_rules:
return True
# If there's no join rule, then it defaults to public (so this doesn't apply).
join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
if not join_rules_event_id:
return True
# If the join rule is not restricted, this doesn't apply.
join_rules_event = await self.store.get_event(join_rules_event_id)
if join_rules_event.content.get("join_rule") != JoinRules.MSC3083_RESTRICTED:
return True
# If allowed is of the wrong form, then only allow invited users.
allowed_spaces = join_rules_event.content.get("allow", [])
if not isinstance(allowed_spaces, list):
return False
# Get the list of joined rooms and see if there's an overlap.
joined_rooms = await self.store.get_rooms_for_user(user_id)
# Pull out the other room IDs, invalid data gets filtered.
for space in allowed_spaces:
if not isinstance(space, dict):
continue
space_id = space.get("space")
if not isinstance(space_id, str):
continue
# The user was joined to one of the spaces specified, they can join
# this room!
if space_id in joined_rooms:
return True
# The user was not in any of the required spaces.
return False
async def _local_membership_update( async def _local_membership_update(
self, self,
requester: Requester, requester: Requester,
@ -302,7 +246,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
if ( if (
newly_joined newly_joined
and not user_is_invited and not user_is_invited
and not await self._can_join_without_invite( and not await self.event_auth_handler.can_join_without_invite(
prev_state_ids, event.room_version, user_id prev_state_ids, event.room_version, user_id
) )
): ):

View File

@ -77,6 +77,7 @@ from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.directory import DirectoryHandler from synapse.handlers.directory import DirectoryHandler
from synapse.handlers.e2e_keys import E2eKeysHandler from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
from synapse.handlers.event_auth import EventAuthHandler
from synapse.handlers.events import EventHandler, EventStreamHandler from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.handlers.federation import FederationHandler from synapse.handlers.federation import FederationHandler
from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler
@ -749,6 +750,10 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_space_summary_handler(self) -> SpaceSummaryHandler: def get_space_summary_handler(self) -> SpaceSummaryHandler:
return SpaceSummaryHandler(self) return SpaceSummaryHandler(self)
@cache_in_self
def get_event_auth_handler(self) -> EventAuthHandler:
return EventAuthHandler(self)
@cache_in_self @cache_in_self
def get_external_cache(self) -> ExternalCache: def get_external_cache(self) -> ExternalCache:
return ExternalCache(self) return ExternalCache(self)

View File

@ -75,8 +75,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
) )
self.handler = self.homeserver.get_federation_handler() self.handler = self.homeserver.get_federation_handler()
self.handler.do_auth = lambda origin, event, context, auth_events: succeed( self.handler._check_event_auth = (
context lambda origin, event, context, state, auth_events, backfilled: succeed(
context
)
) )
self.client = self.homeserver.get_federation_client() self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(