Merge implementation of /join by alias or ID

This code is kind of rough (passing the remote servers down a long
chain), but is a step towards improvement.
This commit is contained in:
Daniel Wagner-Hall 2016-02-15 15:39:16 +00:00
parent dbeed36dec
commit e71095801f
5 changed files with 72 additions and 71 deletions

View File

@ -188,9 +188,12 @@ class BaseHandler(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_users=[]): def handle_new_client_event(self, event, context, ratelimit=True, extra_users=[]):
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if ratelimit:
self.ratelimit(event.sender)
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
yield self.maybe_kick_guest_users(event, context.current_state.values()) yield self.maybe_kick_guest_users(event, context.current_state.values())

View File

@ -216,7 +216,7 @@ class MessageHandler(BaseHandler):
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks
def send_event(self, event, context, ratelimit=True, is_guest=False): def send_event(self, event, context, ratelimit=True, is_guest=False, room_hosts=None):
""" """
Persists and notifies local clients and federation of an event. Persists and notifies local clients and federation of an event.
@ -230,9 +230,6 @@ class MessageHandler(BaseHandler):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if ratelimit:
self.ratelimit(event.sender)
if event.is_state(): if event.is_state():
prev_state = context.current_state.get((event.type, event.state_key)) prev_state = context.current_state.get((event.type, event.state_key))
if prev_state and event.user_id == prev_state.user_id: if prev_state and event.user_id == prev_state.user_id:
@ -245,11 +242,18 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(event, context, is_guest=is_guest) yield member_handler.send_membership_event(
event,
context,
is_guest=is_guest,
ratelimit=ratelimit,
room_hosts=room_hosts
)
else: else:
yield self.handle_new_client_event( yield self.handle_new_client_event(
event=event, event=event,
context=context, context=context,
ratelimit=ratelimit,
) )
if event.type == EventTypes.Message: if event.type == EventTypes.Message:
@ -259,7 +263,8 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True, def create_and_send_event(self, event_dict, ratelimit=True,
token_id=None, txn_id=None, is_guest=False): token_id=None, txn_id=None, is_guest=False,
room_hosts=None):
""" """
Creates an event, then sends it. Creates an event, then sends it.
@ -274,7 +279,8 @@ class MessageHandler(BaseHandler):
event, event,
context, context,
ratelimit=ratelimit, ratelimit=ratelimit,
is_guest=is_guest is_guest=is_guest,
room_hosts=room_hosts,
) )
defer.returnValue(event) defer.returnValue(event)

View File

@ -455,7 +455,9 @@ class RoomMemberHandler(BaseHandler):
yield self.forget(requester.user, room_id) yield self.forget(requester.user, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def send_membership_event(self, event, context, is_guest=False, room_hosts=None): def send_membership_event(
self, event, context, is_guest=False, room_hosts=None, ratelimit=True
):
""" Change the membership status of a user in a room. """ Change the membership status of a user in a room.
Args: Args:
@ -527,8 +529,17 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def join_room_alias(self, requester, room_alias, content={}): def lookup_room_alias(self, room_alias):
joinee = requester.user """
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
Returns:
The room ID as a RoomID object.
Raises:
SynapseError if room alias could not be found.
"""
directory_handler = self.hs.get_handlers().directory_handler directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias) mapping = yield directory_handler.get_association(room_alias)
@ -540,28 +551,7 @@ class RoomMemberHandler(BaseHandler):
if not hosts: if not hosts:
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
# If event doesn't include a display name, add one. defer.returnValue((RoomID.from_string(room_id), hosts))
yield collect_presencelike_data(self.distributor, joinee, content)
content.update({"membership": Membership.JOIN})
builder = self.event_builder_factory.new({
"type": EventTypes.Member,
"state_key": joinee.to_string(),
"room_id": room_id,
"sender": joinee.to_string(),
"membership": Membership.JOIN,
"content": content,
})
event, context = yield self._create_new_client_event(builder)
yield self.send_membership_event(
event,
context,
is_guest=requester.is_guest,
room_hosts=hosts
)
defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_join(self, event, context, room_hosts=None): def _do_join(self, event, context, room_hosts=None):

View File

@ -229,28 +229,19 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
allow_guest=True, allow_guest=True,
) )
# the identifier could be a room alias or a room id. Try one then the if RoomID.is_valid(room_identifier):
# other if it fails to parse, without swallowing other valid room_id = room_identifier
# SynapseErrors. room_hosts = None
elif RoomAlias.is_valid(room_identifier):
identifier = None
is_room_alias = False
try:
identifier = RoomAlias.from_string(room_identifier)
is_room_alias = True
except SynapseError:
identifier = RoomID.from_string(room_identifier)
# TODO: Support for specifying the home server to join with?
if is_room_alias:
handler = self.handlers.room_member_handler handler = self.handlers.room_member_handler
ret_dict = yield handler.join_room_alias( room_alias = RoomAlias.from_string(room_identifier)
requester, room_id, room_hosts = yield handler.lookup_room_alias(room_alias)
identifier, room_id = room_id.to_string()
) else:
defer.returnValue((200, ret_dict)) raise SynapseError(400, "%s was not legal room ID or room alias" % (
else: # room id room_identifier,
))
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
content = {"membership": Membership.JOIN} content = {"membership": Membership.JOIN}
if requester.is_guest: if requester.is_guest:
@ -259,16 +250,19 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
"content": content, "content": content,
"room_id": identifier.to_string(), "room_id": room_id,
"sender": requester.user.to_string(), "sender": requester.user.to_string(),
"state_key": requester.user.to_string(), "state_key": requester.user.to_string(),
"membership": Membership.JOIN, # For backwards compatibility
}, },
token_id=requester.access_token_id, token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
is_guest=requester.is_guest, is_guest=requester.is_guest,
room_hosts=room_hosts,
) )
defer.returnValue((200, {"room_id": identifier.to_string()})) defer.returnValue((200, {"room_id": room_id}))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id): def on_PUT(self, request, room_identifier, txn_id):

View File

@ -73,6 +73,14 @@ class DomainSpecificString(
"""Return a string encoding the fields of the structure object.""" """Return a string encoding the fields of the structure object."""
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain) return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
@classmethod
def is_valid(cls, s):
try:
cls.from_string(s)
return True
except:
return False
__str__ = to_string __str__ = to_string
@classmethod @classmethod