mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-11-13 16:40:56 -05:00
Replace context.current_state with context.current_state_ids
This commit is contained in:
parent
17f4f14df7
commit
a3dc1e9cbe
15 changed files with 435 additions and 270 deletions
|
|
@ -65,33 +65,21 @@ class BaseHandler(object):
|
|||
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
||||
)
|
||||
|
||||
def is_host_in_room(self, current_state):
|
||||
room_members = [
|
||||
(state_key, event.membership)
|
||||
for ((event_type, state_key), event) in current_state.items()
|
||||
if event_type == EventTypes.Member
|
||||
]
|
||||
if len(room_members) == 0:
|
||||
# Have we just created the room, and is this about to be the very
|
||||
# first member event?
|
||||
create_event = current_state.get(("m.room.create", ""))
|
||||
if create_event:
|
||||
return True
|
||||
for (state_key, membership) in room_members:
|
||||
if (
|
||||
self.hs.is_mine_id(state_key)
|
||||
and membership == Membership.JOIN
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_kick_guest_users(self, event, current_state):
|
||||
def maybe_kick_guest_users(self, event, context=None):
|
||||
# Technically this function invalidates current_state by changing it.
|
||||
# Hopefully this isn't that important to the caller.
|
||||
if event.type == EventTypes.GuestAccess:
|
||||
guest_access = event.content.get("guest_access", "forbidden")
|
||||
if guest_access != "can_join":
|
||||
if context:
|
||||
current_state = yield self.store.get_events(
|
||||
context.current_state_ids.values()
|
||||
)
|
||||
current_state = current_state.values()
|
||||
else:
|
||||
current_state = yield self.store.get_current_state(event.room_id)
|
||||
logger.info("maybe_kick_guest_users %r", current_state)
|
||||
yield self.kick_guest_users(current_state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
|
|||
|
|
@ -217,11 +217,21 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
if event.type == EventTypes.Member:
|
||||
if event.membership == Membership.JOIN:
|
||||
prev_state = context.current_state.get((event.type, event.state_key))
|
||||
if not prev_state or prev_state.membership != Membership.JOIN:
|
||||
# Only fire user_joined_room if the user has acutally
|
||||
# joined the room. Don't bother if the user is just
|
||||
# changing their profile info.
|
||||
# Only fire user_joined_room if the user has acutally
|
||||
# joined the room. Don't bother if the user is just
|
||||
# changing their profile info.
|
||||
newly_joined = True
|
||||
prev_state_id = context.current_state_ids.get(
|
||||
(event.type, event.state_key)
|
||||
)
|
||||
if prev_state_id:
|
||||
prev_state = yield self.store.get_event(
|
||||
prev_state_id, allow_none=True,
|
||||
)
|
||||
if prev_state and prev_state.membership == Membership.JOIN:
|
||||
newly_joined = False
|
||||
|
||||
if newly_joined:
|
||||
user = UserID.from_string(event.state_key)
|
||||
yield user_joined_room(self.distributor, user, event.room_id)
|
||||
|
||||
|
|
@ -734,7 +744,7 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||
# when we get the event back in `on_send_join_request`
|
||||
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
||||
yield self.auth.check_from_context(event, context, do_sig_check=False)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
|
|
@ -782,18 +792,11 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
new_pdu = event
|
||||
|
||||
destinations = set()
|
||||
|
||||
for k, s in context.current_state.items():
|
||||
try:
|
||||
if k[0] == EventTypes.Member:
|
||||
if s.content["membership"] == Membership.JOIN:
|
||||
destinations.add(get_domain_from_id(s.state_key))
|
||||
except:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
destinations = yield message_handler.get_joined_hosts_for_room_from_state(
|
||||
context
|
||||
)
|
||||
destinations = set(destinations)
|
||||
destinations.discard(origin)
|
||||
|
||||
logger.debug(
|
||||
|
|
@ -804,13 +807,15 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
self.replication_layer.send_pdu(new_pdu, destinations)
|
||||
|
||||
state_ids = [e.event_id for e in context.current_state.values()]
|
||||
state_ids = context.current_state_ids.values()
|
||||
auth_chain = yield self.store.get_auth_chain(set(
|
||||
[event.event_id] + state_ids
|
||||
))
|
||||
|
||||
state = yield self.store.get_events(context.current_state_ids.values())
|
||||
|
||||
defer.returnValue({
|
||||
"state": context.current_state.values(),
|
||||
"state": state.values(),
|
||||
"auth_chain": auth_chain,
|
||||
})
|
||||
|
||||
|
|
@ -966,7 +971,7 @@ class FederationHandler(BaseHandler):
|
|||
try:
|
||||
# The remote hasn't signed it yet, obviously. We'll do the full checks
|
||||
# when we get the event back in `on_send_leave_request`
|
||||
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
|
||||
yield self.auth.check_from_context(event, context, do_sig_check=False)
|
||||
except AuthError as e:
|
||||
logger.warn("Failed to create new leave %r because %s", event, e)
|
||||
raise e
|
||||
|
|
@ -1010,18 +1015,11 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
new_pdu = event
|
||||
|
||||
destinations = set()
|
||||
|
||||
for k, s in context.current_state.items():
|
||||
try:
|
||||
if k[0] == EventTypes.Member:
|
||||
if s.content["membership"] == Membership.LEAVE:
|
||||
destinations.add(get_domain_from_id(s.state_key))
|
||||
except:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
destinations = yield message_handler.get_joined_hosts_for_room_from_state(
|
||||
context
|
||||
)
|
||||
destinations = set(destinations)
|
||||
destinations.discard(origin)
|
||||
|
||||
logger.debug(
|
||||
|
|
@ -1306,7 +1304,13 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
if not auth_events:
|
||||
auth_events = context.current_state
|
||||
auth_events_ids = yield self.auth.compute_auth_events(
|
||||
event, context.current_state_ids, for_verification=True,
|
||||
)
|
||||
auth_events = yield self.store.get_events(auth_events_ids)
|
||||
auth_events = {
|
||||
(e.type, e.state_key): e for e in auth_events.values()
|
||||
}
|
||||
|
||||
# This is a hack to fix some old rooms where the initial join event
|
||||
# didn't reference the create event in its auth events.
|
||||
|
|
@ -1332,8 +1336,7 @@ class FederationHandler(BaseHandler):
|
|||
context.rejected = RejectedReason.AUTH_ERROR
|
||||
|
||||
if event.type == EventTypes.GuestAccess:
|
||||
full_context = yield self.store.get_current_state(room_id=event.room_id)
|
||||
yield self.maybe_kick_guest_users(event, full_context)
|
||||
yield self.maybe_kick_guest_users(event)
|
||||
|
||||
defer.returnValue(context)
|
||||
|
||||
|
|
@ -1504,7 +1507,9 @@ class FederationHandler(BaseHandler):
|
|||
current_state = set(e.event_id for e in auth_events.values())
|
||||
different_auth = event_auth_events - current_state
|
||||
|
||||
context.current_state.update(auth_events)
|
||||
context.current_state_ids.update({
|
||||
k: a.event_id for k, a in auth_events.items()
|
||||
})
|
||||
context.state_group = None
|
||||
|
||||
if different_auth and not event.internal_metadata.is_outlier():
|
||||
|
|
@ -1526,8 +1531,8 @@ class FederationHandler(BaseHandler):
|
|||
|
||||
if do_resolution:
|
||||
# 1. Get what we think is the auth chain.
|
||||
auth_ids = self.auth.compute_auth_events(
|
||||
event, context.current_state
|
||||
auth_ids = yield self.auth.compute_auth_events(
|
||||
event, context.current_state_ids
|
||||
)
|
||||
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
|
||||
|
||||
|
|
@ -1583,7 +1588,9 @@ class FederationHandler(BaseHandler):
|
|||
# 4. Look at rejects and their proofs.
|
||||
# TODO.
|
||||
|
||||
context.current_state.update(auth_events)
|
||||
context.current_state_ids.update({
|
||||
k: a.event_id for k, a in auth_events.items()
|
||||
})
|
||||
context.state_group = None
|
||||
|
||||
try:
|
||||
|
|
@ -1770,12 +1777,12 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
try:
|
||||
self.auth.check(event, context.current_state)
|
||||
yield self.auth.check_from_context(event, context)
|
||||
except AuthError as e:
|
||||
logger.warn("Denying new third party invite %r because %s", event, e)
|
||||
raise e
|
||||
|
||||
yield self._check_signature(event, auth_events=context.current_state)
|
||||
yield self._check_signature(event, context)
|
||||
member_handler = self.hs.get_handlers().room_member_handler
|
||||
yield member_handler.send_membership_event(None, event, context)
|
||||
else:
|
||||
|
|
@ -1801,11 +1808,11 @@ class FederationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
try:
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
self.auth.check_from_context(event, context)
|
||||
except AuthError as e:
|
||||
logger.warn("Denying third party invite %r because %s", event, e)
|
||||
raise e
|
||||
yield self._check_signature(event, auth_events=context.current_state)
|
||||
yield self._check_signature(event, context)
|
||||
|
||||
returned_invite = yield self.send_invite(origin, event)
|
||||
# TODO: Make sure the signatures actually are correct.
|
||||
|
|
@ -1819,7 +1826,12 @@ class FederationHandler(BaseHandler):
|
|||
EventTypes.ThirdPartyInvite,
|
||||
event.content["third_party_invite"]["signed"]["token"]
|
||||
)
|
||||
original_invite = context.current_state.get(key)
|
||||
original_invite = None
|
||||
original_invite_id = context.current_state_ids.get(key)
|
||||
if original_invite_id:
|
||||
original_invite = yield self.store.get_event(
|
||||
original_invite_id, allow_none=True
|
||||
)
|
||||
if not original_invite:
|
||||
logger.info(
|
||||
"Could not find invite event for third_party_invite - "
|
||||
|
|
@ -1836,13 +1848,13 @@ class FederationHandler(BaseHandler):
|
|||
defer.returnValue((event, context))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_signature(self, event, auth_events):
|
||||
def _check_signature(self, event, context):
|
||||
"""
|
||||
Checks that the signature in the event is consistent with its invite.
|
||||
|
||||
Args:
|
||||
event (Event): The m.room.member event to check
|
||||
auth_events (dict<(event type, state_key), event>):
|
||||
context (EventContext):
|
||||
|
||||
Raises:
|
||||
AuthError: if signature didn't match any keys, or key has been
|
||||
|
|
@ -1853,10 +1865,14 @@ class FederationHandler(BaseHandler):
|
|||
signed = event.content["third_party_invite"]["signed"]
|
||||
token = signed["token"]
|
||||
|
||||
invite_event = auth_events.get(
|
||||
invite_event_id = context.current_state_ids.get(
|
||||
(EventTypes.ThirdPartyInvite, token,)
|
||||
)
|
||||
|
||||
invite_event = None
|
||||
if invite_event_id:
|
||||
invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
|
||||
|
||||
if not invite_event:
|
||||
raise AuthError(403, "Could not find invite")
|
||||
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLo
|
|||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
|
@ -248,7 +249,7 @@ class MessageHandler(BaseHandler):
|
|||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||
|
||||
if event.is_state():
|
||||
prev_state = self.deduplicate_state_event(event, context)
|
||||
prev_state = yield self.deduplicate_state_event(event, context)
|
||||
if prev_state is not None:
|
||||
defer.returnValue(prev_state)
|
||||
|
||||
|
|
@ -263,6 +264,7 @@ class MessageHandler(BaseHandler):
|
|||
presence = self.hs.get_presence_handler()
|
||||
yield presence.bump_presence_active_time(user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deduplicate_state_event(self, event, context):
|
||||
"""
|
||||
Checks whether event is in the latest resolved state in context.
|
||||
|
|
@ -270,13 +272,17 @@ class MessageHandler(BaseHandler):
|
|||
If so, returns the version of the event in context.
|
||||
Otherwise, returns None.
|
||||
"""
|
||||
prev_event = context.current_state.get((event.type, event.state_key))
|
||||
prev_event_id = context.current_state_ids.get((event.type, event.state_key))
|
||||
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
||||
if not prev_event:
|
||||
return
|
||||
|
||||
if prev_event and event.user_id == prev_event.user_id:
|
||||
prev_content = encode_canonical_json(prev_event.content)
|
||||
next_content = encode_canonical_json(event.content)
|
||||
if prev_content == next_content:
|
||||
return prev_event
|
||||
return None
|
||||
defer.returnValue(prev_event)
|
||||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_and_send_nonmember_event(
|
||||
|
|
@ -803,7 +809,7 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
logger.debug(
|
||||
"Created event %s with current state: %s",
|
||||
event.event_id, context.current_state,
|
||||
event.event_id, context.current_state_ids,
|
||||
)
|
||||
|
||||
defer.returnValue(
|
||||
|
|
@ -826,12 +832,12 @@ class MessageHandler(BaseHandler):
|
|||
self.ratelimit(requester)
|
||||
|
||||
try:
|
||||
self.auth.check(event, auth_events=context.current_state)
|
||||
yield self.auth.check_from_context(event, context)
|
||||
except AuthError as err:
|
||||
logger.warn("Denying new event %r because %s", event, err)
|
||||
raise err
|
||||
|
||||
yield self.maybe_kick_guest_users(event, context.current_state.values())
|
||||
yield self.maybe_kick_guest_users(event, context)
|
||||
|
||||
if event.type == EventTypes.CanonicalAlias:
|
||||
# Check the alias is acually valid (at this time at least)
|
||||
|
|
@ -859,6 +865,15 @@ class MessageHandler(BaseHandler):
|
|||
e.sender == event.sender
|
||||
)
|
||||
|
||||
state_to_include_ids = [
|
||||
e_id
|
||||
for k, e_id in context.current_state_ids.items()
|
||||
if k[0] in self.hs.config.room_invite_state_types
|
||||
or k[0] == EventTypes.Member and k[1] == event.sender
|
||||
]
|
||||
|
||||
state_to_include = yield self.store.get_events(state_to_include_ids)
|
||||
|
||||
event.unsigned["invite_room_state"] = [
|
||||
{
|
||||
"type": e.type,
|
||||
|
|
@ -866,9 +881,7 @@ class MessageHandler(BaseHandler):
|
|||
"content": e.content,
|
||||
"sender": e.sender,
|
||||
}
|
||||
for k, e in context.current_state.items()
|
||||
if e.type in self.hs.config.room_invite_state_types
|
||||
or is_inviter_member_event(e)
|
||||
for e in state_to_include.values()
|
||||
]
|
||||
|
||||
invitee = UserID.from_string(event.state_key)
|
||||
|
|
@ -890,7 +903,14 @@ class MessageHandler(BaseHandler):
|
|||
)
|
||||
|
||||
if event.type == EventTypes.Redaction:
|
||||
if self.auth.check_redaction(event, auth_events=context.current_state):
|
||||
auth_events_ids = yield self.auth.compute_auth_events(
|
||||
event, context.current_state_ids, for_verification=True,
|
||||
)
|
||||
auth_events = yield self.store.get_events(auth_events_ids)
|
||||
auth_events = {
|
||||
(e.type, e.state_key): e for e in auth_events.values()
|
||||
}
|
||||
if self.auth.check_redaction(event, auth_events=auth_events):
|
||||
original_event = yield self.store.get_event(
|
||||
event.redacts,
|
||||
check_redacted=False,
|
||||
|
|
@ -904,7 +924,7 @@ class MessageHandler(BaseHandler):
|
|||
"You don't have permission to redact events"
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Create and context.current_state:
|
||||
if event.type == EventTypes.Create and context.current_state_ids:
|
||||
raise AuthError(
|
||||
403,
|
||||
"Changing the room create event is forbidden",
|
||||
|
|
@ -925,16 +945,7 @@ class MessageHandler(BaseHandler):
|
|||
event_stream_id, max_stream_id
|
||||
)
|
||||
|
||||
destinations = set()
|
||||
for k, s in context.current_state.items():
|
||||
try:
|
||||
if k[0] == EventTypes.Member:
|
||||
if s.content["membership"] == Membership.JOIN:
|
||||
destinations.add(get_domain_from_id(s.state_key))
|
||||
except SynapseError:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", s.event_id
|
||||
)
|
||||
destinations = yield self.get_joined_hosts_for_room_from_state(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _notify():
|
||||
|
|
@ -952,3 +963,39 @@ class MessageHandler(BaseHandler):
|
|||
preserve_fn(federation_handler.handle_new_event)(
|
||||
event, destinations=destinations,
|
||||
)
|
||||
|
||||
def get_joined_hosts_for_room_from_state(self, context):
|
||||
state_group = context.state_group
|
||||
if not state_group:
|
||||
# If state_group is None it means it has yet to be assigned a
|
||||
# state group, i.e. we need to make sure that calls with a state_group
|
||||
# of None don't hit previous cached calls with a None state_group.
|
||||
# To do this we set the state_group to a new object as object() != object()
|
||||
state_group = object()
|
||||
|
||||
return self._get_joined_hosts_for_room_from_state(
|
||||
state_group, context.current_state_ids
|
||||
)
|
||||
|
||||
@cachedInlineCallbacks(num_args=1, cache_context=True)
|
||||
def _get_joined_hosts_for_room_from_state(self, state_group, current_state_ids,
|
||||
cache_context):
|
||||
|
||||
# Don't bother getting state for people on the same HS
|
||||
current_state = yield self.store.get_events([
|
||||
e_id for key, e_id in current_state_ids.items()
|
||||
if key[0] == EventTypes.Member and not self.hs.is_mine_id(key[1])
|
||||
])
|
||||
|
||||
destinations = set()
|
||||
for e in current_state.itervalues():
|
||||
try:
|
||||
if e.type == EventTypes.Member:
|
||||
if e.content["membership"] == Membership.JOIN:
|
||||
destinations.add(get_domain_from_id(e.state_key))
|
||||
except SynapseError:
|
||||
logger.warn(
|
||||
"Failed to get destination from event %s", e.event_id
|
||||
)
|
||||
|
||||
defer.returnValue(destinations)
|
||||
|
|
|
|||
|
|
@ -93,20 +93,26 @@ class RoomMemberHandler(BaseHandler):
|
|||
ratelimit=ratelimit,
|
||||
)
|
||||
|
||||
prev_member_event = context.current_state.get(
|
||||
prev_member_event_id = context.current_state_ids.get(
|
||||
(EventTypes.Member, target.to_string()),
|
||||
None
|
||||
)
|
||||
|
||||
if event.membership == Membership.JOIN:
|
||||
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
|
||||
# Only fire user_joined_room if the user has acutally joined the
|
||||
# room. Don't bother if the user is just changing their profile
|
||||
# info.
|
||||
# Only fire user_joined_room if the user has acutally joined the
|
||||
# room. Don't bother if the user is just changing their profile
|
||||
# info.
|
||||
newly_joined = True
|
||||
if prev_member_event_id:
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
if newly_joined:
|
||||
yield user_joined_room(self.distributor, target, room_id)
|
||||
elif event.membership == Membership.LEAVE:
|
||||
if prev_member_event and prev_member_event.membership == Membership.JOIN:
|
||||
user_left_room(self.distributor, target, room_id)
|
||||
if prev_member_event_id:
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
if prev_member_event.membership == Membership.JOIN:
|
||||
user_left_room(self.distributor, target, room_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remote_join(self, remote_room_hosts, room_id, user, content):
|
||||
|
|
@ -195,29 +201,32 @@ class RoomMemberHandler(BaseHandler):
|
|||
remote_room_hosts = []
|
||||
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
current_state = yield self.state_handler.get_current_state(
|
||||
current_state_ids = yield self.state_handler.get_current_state_ids(
|
||||
room_id, latest_event_ids=latest_event_ids,
|
||||
)
|
||||
|
||||
old_state = current_state.get((EventTypes.Member, target.to_string()))
|
||||
old_membership = old_state.content.get("membership") if old_state else None
|
||||
if action == "unban" and old_membership != "ban":
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Cannot unban user who was not banned (membership=%s)" % old_membership,
|
||||
errcode=Codes.BAD_STATE
|
||||
)
|
||||
if old_membership == "ban" and action != "unban":
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Cannot %s user who was banned" % (action,),
|
||||
errcode=Codes.BAD_STATE
|
||||
)
|
||||
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
|
||||
if old_state_id:
|
||||
old_state = yield self.store.get_event(old_state_id, allow_none=True)
|
||||
old_membership = old_state.content.get("membership") if old_state else None
|
||||
if action == "unban" and old_membership != "ban":
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Cannot unban user who was not banned"
|
||||
" (membership=%s)" % old_membership,
|
||||
errcode=Codes.BAD_STATE
|
||||
)
|
||||
if old_membership == "ban" and action != "unban":
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Cannot %s user who was banned" % (action,),
|
||||
errcode=Codes.BAD_STATE
|
||||
)
|
||||
|
||||
is_host_in_room = self.is_host_in_room(current_state)
|
||||
is_host_in_room = yield self._is_host_in_room(current_state_ids)
|
||||
|
||||
if effective_membership_state == Membership.JOIN:
|
||||
if requester.is_guest and not self._can_guest_join(current_state):
|
||||
if requester.is_guest and not self._can_guest_join(current_state_ids):
|
||||
# This should be an auth check, but guests are a local concept,
|
||||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
|
@ -326,15 +335,17 @@ class RoomMemberHandler(BaseHandler):
|
|||
requester = synapse.types.create_requester(target_user)
|
||||
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
prev_event = message_handler.deduplicate_state_event(event, context)
|
||||
prev_event = yield message_handler.deduplicate_state_event(event, context)
|
||||
if prev_event is not None:
|
||||
return
|
||||
|
||||
if event.membership == Membership.JOIN:
|
||||
if requester.is_guest and not self._can_guest_join(context.current_state):
|
||||
# This should be an auth check, but guests are a local concept,
|
||||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
if requester.is_guest:
|
||||
guest_can_join = yield self._can_guest_join(context.current_state_ids)
|
||||
if not guest_can_join:
|
||||
# This should be an auth check, but guests are a local concept,
|
||||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
yield message_handler.handle_new_client_event(
|
||||
requester,
|
||||
|
|
@ -344,27 +355,39 @@ class RoomMemberHandler(BaseHandler):
|
|||
ratelimit=ratelimit,
|
||||
)
|
||||
|
||||
prev_member_event = context.current_state.get(
|
||||
(EventTypes.Member, target_user.to_string()),
|
||||
prev_member_event_id = context.current_state_ids.get(
|
||||
(EventTypes.Member, event.state_key),
|
||||
None
|
||||
)
|
||||
|
||||
if event.membership == Membership.JOIN:
|
||||
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
|
||||
# Only fire user_joined_room if the user has acutally joined the
|
||||
# room. Don't bother if the user is just changing their profile
|
||||
# info.
|
||||
# Only fire user_joined_room if the user has acutally joined the
|
||||
# room. Don't bother if the user is just changing their profile
|
||||
# info.
|
||||
newly_joined = True
|
||||
if prev_member_event_id:
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
if newly_joined:
|
||||
yield user_joined_room(self.distributor, target_user, room_id)
|
||||
elif event.membership == Membership.LEAVE:
|
||||
if prev_member_event and prev_member_event.membership == Membership.JOIN:
|
||||
user_left_room(self.distributor, target_user, room_id)
|
||||
if prev_member_event_id:
|
||||
prev_member_event = yield self.store.get_event(prev_member_event_id)
|
||||
if prev_member_event.membership == Membership.JOIN:
|
||||
user_left_room(self.distributor, target_user, room_id)
|
||||
|
||||
def _can_guest_join(self, current_state):
|
||||
@defer.inlineCallbacks
|
||||
def _can_guest_join(self, current_state_ids):
|
||||
"""
|
||||
Returns whether a guest can join a room based on its current state.
|
||||
"""
|
||||
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
|
||||
return (
|
||||
guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
|
||||
if not guest_access_id:
|
||||
defer.returnValue(False)
|
||||
|
||||
guest_access = yield self.store.get_event(guest_access_id)
|
||||
|
||||
defer.returnValue(
|
||||
guest_access
|
||||
and guest_access.content
|
||||
and "guest_access" in guest_access.content
|
||||
|
|
@ -683,3 +706,24 @@ class RoomMemberHandler(BaseHandler):
|
|||
|
||||
if membership:
|
||||
yield self.store.forget(user_id, room_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _is_host_in_room(self, current_state_ids):
|
||||
# Have we just created the room, and is this about to be the very
|
||||
# first member event?
|
||||
create_event_id = current_state_ids.get(("m.room.create", ""))
|
||||
if len(current_state_ids) == 1 and create_event_id:
|
||||
defer.returnValue(self.hs.is_mine_id(create_event_id))
|
||||
|
||||
for (etype, state_key), event_id in current_state_ids.items():
|
||||
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
|
||||
continue
|
||||
|
||||
event = yield self.store.get_event(event_id, allow_none=True)
|
||||
if not event:
|
||||
continue
|
||||
|
||||
if event.membership == Membership.JOIN:
|
||||
defer.returnValue(True)
|
||||
|
||||
defer.returnValue(False)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue