Shuffle things around to make unit tests work

This commit is contained in:
Erik Johnston 2016-09-22 11:08:12 +01:00
parent 1168cbd54d
commit a61e4522b5
2 changed files with 16 additions and 15 deletions

View File

@ -91,23 +91,24 @@ class Auth(object):
if not hasattr(event, "room_id"): if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event) raise AuthError(500, "Event has no room_id: %s" % event)
sender_domain = get_domain_from_id(event.sender) if do_sig_check:
event_id_domain = get_domain_from_id(event.event_id) sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = ( is_invite_via_3pid = (
event.type == EventTypes.Member event.type == EventTypes.Member
and event.membership == Membership.INVITE and event.membership == Membership.INVITE
and "third_party_invite" in event.content and "third_party_invite" in event.content
) )
# Check the sender's domain has signed the event # Check the sender's domain has signed the event
if do_sig_check and not event.signatures.get(sender_domain): if not event.signatures.get(sender_domain):
if not is_invite_via_3pid: if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server") raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event # Check the event_id's domain has signed the event
if do_sig_check and not event.signatures.get(event_id_domain): if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server") raise AuthError(403, "Event not signed by sending server")
if auth_events is None: if auth_events is None:
# Oh, we don't know what the state of the room was, so we # Oh, we don't know what the state of the room was, so we

View File

@ -56,7 +56,7 @@ def get_domain_from_id(string):
try: try:
return string.split(":", 1)[1] return string.split(":", 1)[1]
except IndexError: except IndexError:
raise SynapseError(400, "Invalid ID: %r", string) raise SynapseError(400, "Invalid ID: %r" % (string,))
class DomainSpecificString( class DomainSpecificString(