Allow guests to register and call /events?room_id=

This follows the same flows-based flow as regular registration, but as
the only implemented flow has no requirements, it auto-succeeds. In the
future, other flows (e.g. captcha) may be required, so clients should
treat this like the regular registration flow choices.
This commit is contained in:
Daniel Wagner-Hall 2015-11-04 17:29:07 +00:00
parent f74f48e9e6
commit f522f50a08
33 changed files with 271 additions and 166 deletions

View File

@ -49,6 +49,7 @@ class Auth(object):
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self._KNOWN_CAVEAT_PREFIXES = set([ self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ", "gen = ",
"guest = ",
"type = ", "type = ",
"time < ", "time < ",
"user_id = ", "user_id = ",
@ -183,15 +184,11 @@ class Auth(object):
defer.returnValue(member) defer.returnValue(member)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_was_in_room(self, room_id, user_id, current_state=None): def check_user_was_in_room(self, room_id, user_id):
"""Check if the user was in the room at some point. """Check if the user was in the room at some point.
Args: Args:
room_id(str): The room to check. room_id(str): The room to check.
user_id(str): The user to check. user_id(str): The user to check.
current_state(dict): Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
Raises: Raises:
AuthError if the user was never in the room. AuthError if the user was never in the room.
Returns: Returns:
@ -199,17 +196,11 @@ class Auth(object):
room. This will be the join event if they are currently joined to room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room. the room. This will be the leave event if they have left the room.
""" """
if current_state: member = yield self.state.get_current_state(
member = current_state.get( room_id=room_id,
(EventTypes.Member, user_id), event_type=EventTypes.Member,
None state_key=user_id
) )
else:
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None membership = member.membership if member else None
if membership not in (Membership.JOIN, Membership.LEAVE): if membership not in (Membership.JOIN, Membership.LEAVE):
@ -497,7 +488,7 @@ class Auth(object):
return default return default
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_req(self, request): def get_user_by_req(self, request, allow_guest=False):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
@ -535,7 +526,7 @@ class Auth(object):
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue((UserID.from_string(user_id), "")) defer.returnValue((UserID.from_string(user_id), "", False))
return return
except KeyError: except KeyError:
pass # normal users won't have the user_id query parameter set. pass # normal users won't have the user_id query parameter set.
@ -543,6 +534,7 @@ class Auth(object):
user_info = yield self._get_user_by_access_token(access_token) user_info = yield self._get_user_by_access_token(access_token)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
is_guest = user_info["is_guest"]
ip_addr = self.hs.get_ip_from_request(request) ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders( user_agent = request.requestHeaders.getRawHeaders(
@ -557,9 +549,14 @@ class Auth(object):
user_agent=user_agent user_agent=user_agent
) )
if is_guest and not allow_guest:
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)
request.authenticated_entity = user.to_string() request.authenticated_entity = user.to_string()
defer.returnValue((user, token_id,)) defer.returnValue((user, token_id, is_guest,))
except KeyError: except KeyError:
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@ -592,31 +589,45 @@ class Auth(object):
self._validate_macaroon(macaroon) self._validate_macaroon(macaroon)
user_prefix = "user_id = " user_prefix = "user_id = "
user = None
guest = False
for caveat in macaroon.caveats: for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix): if caveat.caveat_id.startswith(user_prefix):
user = UserID.from_string(caveat.caveat_id[len(user_prefix):]) user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
# This codepath exists so that we can actually return a elif caveat.caveat_id == "guest = true":
# token ID, because we use token IDs in place of device guest = True
# identifiers throughout the codebase.
# TODO(daniel): Remove this fallback when device IDs are if user is None:
# properly implemented. raise AuthError(
ret = yield self._look_up_user_by_access_token(macaroon_str) self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
if ret["user"] != user: errcode=Codes.UNKNOWN_TOKEN
logger.error( )
"Macaroon user (%s) != DB user (%s)",
user, if guest:
ret["user"] ret = {
) "user": user,
raise AuthError( "is_guest": True,
self.TOKEN_NOT_FOUND_HTTP_STATUS, "token_id": None,
"User mismatch in macaroon", }
errcode=Codes.UNKNOWN_TOKEN else:
) # This codepath exists so that we can actually return a
defer.returnValue(ret) # token ID, because we use token IDs in place of device
raise AuthError( # identifiers throughout the codebase.
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon", # TODO(daniel): Remove this fallback when device IDs are
errcode=Codes.UNKNOWN_TOKEN # properly implemented.
) ret = yield self._look_up_user_by_access_token(macaroon_str)
if ret["user"] != user:
logger.error(
"Macaroon user (%s) != DB user (%s)",
user,
ret["user"]
)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"User mismatch in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
defer.returnValue(ret)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
@ -629,6 +640,7 @@ class Auth(object):
v.satisfy_exact("type = access") v.satisfy_exact("type = access")
v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(self._verify_expiry) v.satisfy_general(self._verify_expiry)
v.satisfy_exact("guest = true")
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
@ -666,6 +678,7 @@ class Auth(object):
user_info = { user_info = {
"user": UserID.from_string(ret.get("name")), "user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None), "token_id": ret.get("token_id", None),
"is_guest": False,
} }
defer.returnValue(user_info) defer.returnValue(user_info)

View File

@ -33,6 +33,7 @@ class Codes(object):
NOT_FOUND = "M_NOT_FOUND" NOT_FOUND = "M_NOT_FOUND"
MISSING_TOKEN = "M_MISSING_TOKEN" MISSING_TOKEN = "M_MISSING_TOKEN"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
GUEST_ACCESS_FORBIDDEN = "M_GUEST_ACCESS_FORBIDDEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID" CAPTCHA_INVALID = "M_CAPTCHA_INVALID"

View File

@ -34,6 +34,7 @@ class RegistrationConfig(Config):
self.registration_shared_secret = config.get("registration_shared_secret") self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key") self.macaroon_secret_key = config.get("macaroon_secret_key")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12) self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.allow_guest_access = config.get("allow_guest_access", False)
def default_config(self, **kwargs): def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50) registration_shared_secret = random_string_with_symbols(50)
@ -54,6 +55,11 @@ class RegistrationConfig(Config):
# Larger numbers increase the work factor needed to generate the hash. # Larger numbers increase the work factor needed to generate the hash.
# The default number of rounds is 12. # The default number of rounds is 12.
bcrypt_rounds: 12 bcrypt_rounds: 12
# Allows users to register as guests without a password/email/etc, and
# participate in rooms hosted on this server which have been made
# accessible to anonymous users.
allow_guest_access: False
""" % locals() """ % locals()
def add_arguments(self, parser): def add_arguments(self, parser):

View File

@ -47,37 +47,23 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events): def _filter_events_for_client(self, user_id, events, is_guest=False):
event_id_to_state = yield self.store.get_state_for_events( # Assumes that user has at some point joined the room if not is_guest.
frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state): def allowed(event, membership, visibility):
if event.type == EventTypes.RoomHistoryVisibility: if visibility == "world_readable":
return True return True
membership_ev = state.get((EventTypes.Member, user_id), None) if is_guest:
if membership_ev: return False
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN: if membership == Membership.JOIN:
return True return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None) if event.type == EventTypes.RoomHistoryVisibility:
if history: return not is_guest
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public": if visibility == "shared":
return True
elif visibility == "shared":
return True return True
elif visibility == "joined": elif visibility == "joined":
return membership == Membership.JOIN return membership == Membership.JOIN
@ -86,11 +72,44 @@ class BaseHandler(object):
return True return True
defer.returnValue([ event_id_to_state = yield self.store.get_state_for_events(
event frozenset(e.event_id for e in events),
for event in events types=(
if allowed(event, event_id_to_state[event.event_id]) (EventTypes.RoomHistoryVisibility, ""),
]) (EventTypes.Member, user_id),
)
)
events_to_return = []
for event in events:
state = event_id_to_state[event.event_id]
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
membership = membership_event.membership
else:
membership = None
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if visibility_event:
visibility = visibility_event.content.get("history_visibility", "shared")
else:
visibility = "shared"
should_include = allowed(event, membership, visibility)
if should_include:
events_to_return.append(event)
if is_guest and len(events_to_return) < len(events):
# This indicates that some events in the requested range were not
# visible to guest users. To be safe, we reject the entire request,
# so that we don't have to worry about interpreting visibility
# boundaries.
raise AuthError(403, "User %s does not have permission" % (
user_id
))
defer.returnValue(events_to_return)
def ratelimit(self, user_id): def ratelimit(self, user_id):
time_now = self.clock.time() time_now = self.clock.time()

View File

@ -372,12 +372,15 @@ class AuthHandler(BaseHandler):
yield self.store.add_refresh_token_to_user(user_id, refresh_token) yield self.store.add_refresh_token_to_user(user_id, refresh_token)
defer.returnValue(refresh_token) defer.returnValue(refresh_token)
def generate_access_token(self, user_id): def generate_access_token(self, user_id, extra_caveats=None):
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
expiry = now + (60 * 60 * 1000) expiry = now + (60 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize() return macaroon.serialize()
def generate_refresh_token(self, user_id): def generate_refresh_token(self, user_id):

View File

@ -71,20 +71,20 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None, def get_messages(self, user_id=None, room_id=None, pagin_config=None,
as_client_event=True): as_client_event=True, is_guest=False):
"""Get messages in a room. """Get messages in a room.
Args: Args:
user_id (str): The user requesting messages. user_id (str): The user requesting messages.
room_id (str): The room they want messages from. room_id (str): The room they want messages from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any. config rules to apply, if any.
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
is_guest (bool): Whether the requesting user is a guest (as opposed
to a fully registered user).
Returns: Returns:
dict: Pagination API results dict: Pagination API results
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
data_source = self.hs.get_event_sources().sources["room"] data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token: if pagin_config.from_token:
@ -107,23 +107,27 @@ class MessageHandler(BaseHandler):
source_config = pagin_config.get_source_config("room") source_config = pagin_config.get_source_config("room")
if member_event.membership == Membership.LEAVE: if not is_guest:
# If they have left the room then clamp the token to be before member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
# they left the room if member_event.membership == Membership.LEAVE:
leave_token = yield self.store.get_topological_token_for_event( # If they have left the room then clamp the token to be before
member_event.event_id # they left the room.
) # If they're a guest, we'll just 403 them if they're asking for
leave_token = RoomStreamToken.parse(leave_token) # events they can't see.
if leave_token.topological < room_token.topological: leave_token = yield self.store.get_topological_token_for_event(
source_config.from_key = str(leave_token) member_event.event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological:
source_config.from_key = str(leave_token)
if source_config.direction == "f": if source_config.direction == "f":
if source_config.to_key is None: if source_config.to_key is None:
source_config.to_key = str(leave_token)
else:
to_token = RoomStreamToken.parse(source_config.to_key)
if leave_token.topological < to_token.topological:
source_config.to_key = str(leave_token) source_config.to_key = str(leave_token)
else:
to_token = RoomStreamToken.parse(source_config.to_key)
if leave_token.topological < to_token.topological:
source_config.to_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill( yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological room_id, room_token.topological
@ -146,7 +150,7 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(), "end": next_token.to_string(),
}) })
events = yield self._filter_events_for_client(user_id, events) events = yield self._filter_events_for_client(user_id, events, is_guest=is_guest)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View File

@ -64,7 +64,7 @@ class RegistrationHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, localpart=None, password=None): def register(self, localpart=None, password=None, generate_token=True):
"""Registers a new client on the server. """Registers a new client on the server.
Args: Args:
@ -89,7 +89,9 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = self.auth_handler().generate_access_token(user_id) token = None
if generate_token:
token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -102,14 +104,14 @@ class RegistrationHandler(BaseHandler):
attempts = 0 attempts = 0
user_id = None user_id = None
token = None token = None
while not user_id and not token: while not user_id:
try: try:
localpart = self._generate_user_id() localpart = self._generate_user_id()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
if generate_token:
token = self.auth_handler().generate_access_token(user_id) token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,

View File

@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user) is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user: if not is_admin and target_user != auth_user:

View File

@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
try: try:
# try to auth as a user # try to auth as a user
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
try: try:
user_id = user.to_string() user_id = user.to_string()
yield dir_handler.create_association( yield dir_handler.create_association(
@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
# fallback to default user behaviour if they aren't an AS # fallback to default user behaviour if they aren't an AS
pass pass
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user) is_admin = yield self.auth.is_server_admin(user)
if not is_admin: if not is_admin:

View File

@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
try: try:
handler = self.handlers.event_stream_handler handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request) pagin_config = PaginationConfig.from_request(request)
@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, event_id): def on_GET(self, request, event_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id) event = yield handler.get_event(auth_user, event_id)

View File

@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler

View File

@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state( state = yield self.handlers.presence_handler.get_state(
@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = {} state = {}
@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):

View File

@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:
@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:

View File

@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
if '/' in spec['rule_id'] or '\\' in spec['rule_id']: if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes") raise SynapseError(400, "rule_id may not contain slashes")
@ -92,7 +92,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_DELETE(self, request): def on_DELETE(self, request):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
@ -109,7 +109,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is

View File

@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)

View File

@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request) room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None) info = yield self.make_room(room_config, auth_user, None)
@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key): def on_GET(self, request, room_id, event_type, state_key):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
@ -143,7 +143,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -175,7 +175,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None): def on_POST(self, request, room_id, event_type, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -220,7 +220,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None): def on_POST(self, request, room_identifier, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
# the identifier could be a room alias or a room id. Try one then the # the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid # other if it fails to parse, without swallowing other valid
@ -289,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler
events = yield handler.get_state_events( events = yield handler.get_state_events(
room_id=room_id, room_id=room_id,
@ -325,7 +325,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request( pagination_config = PaginationConfig.from_request(
request, default_limit=10, request, default_limit=10,
) )
@ -334,6 +334,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
msgs = yield handler.get_messages( msgs = yield handler.get_messages(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
is_guest=is_guest,
pagin_config=pagination_config, pagin_config=pagination_config,
as_client_event=as_client_event as_client_event=as_client_event
) )
@ -347,7 +348,7 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler
# Get all the current state for this room # Get all the current state for this room
events = yield handler.get_state_events( events = yield handler.get_state_events(
@ -363,7 +364,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync( content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id, room_id=room_id,
@ -443,7 +444,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action, txn_id=None): def on_POST(self, request, room_id, membership_action, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -524,7 +525,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None): def on_POST(self, request, room_id, event_id, txn_id=None):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -564,7 +565,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id) room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id)) target_user = UserID.from_string(urllib.unquote(user_id))
@ -597,7 +598,7 @@ class SearchRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)

View File

@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret turnSecret = self.hs.config.turn_shared_secret

View File

@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet):
if LoginType.PASSWORD in result: if LoginType.PASSWORD in result:
# if using password, they should also be logged in # if using password, they should also be logged in
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if auth_user.to_string() != result[LoginType.PASSWORD]: if auth_user.to_string() != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN) raise LoginError(400, "", Codes.UNKNOWN)
user_id = auth_user.to_string() user_id = auth_user.to_string()
@ -102,7 +102,7 @@ class ThreepidRestServlet(RestServlet):
def on_GET(self, request): def on_GET(self, request):
yield run_on_reactor() yield run_on_reactor()
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
threepids = yield self.hs.get_datastore().user_get_threepids( threepids = yield self.hs.get_datastore().user_get_threepids(
auth_user.to_string() auth_user.to_string()
@ -120,7 +120,7 @@ class ThreepidRestServlet(RestServlet):
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM) raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds'] threePidCreds = body['threePidCreds']
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds) threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)

View File

@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id): def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if target_user != auth_user: if target_user != auth_user:
raise AuthError(403, "Cannot get filters for other users") raise AuthError(403, "Cannot get filters for other users")
@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if target_user != auth_user: if target_user != auth_user:
raise AuthError(403, "Cannot create filters for other users") raise AuthError(403, "Cannot create filters for other users")

View File

@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, device_id): def on_POST(self, request, device_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string() user_id = auth_user.to_string()
# TODO: Check that the device_id matches that in the authentication # TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead. # or derive the device_id from the authentication instead.
@ -109,7 +109,7 @@ class KeyUploadServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, device_id): def on_GET(self, request, device_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string() user_id = auth_user.to_string()
result = yield self.store.count_e2e_one_time_keys(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
@ -181,7 +181,7 @@ class KeyQueryServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id): def on_GET(self, request, user_id, device_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string() auth_user_id = auth_user.to_string()
user_id = user_id if user_id else auth_user_id user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else [] device_ids = [device_id] if device_id else []

View File

@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id): def on_POST(self, request, room_id, receipt_type, event_id):
user, _ = yield self.auth.get_user_by_req(request) user, _, _ = yield self.auth.get_user_by_req(request)
if receipt_type != "m.read": if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'") raise SynapseError(400, "Receipt type must be 'm.read'")

View File

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_json_dict_from_request from ._base import client_v2_pattern, parse_json_dict_from_request
@ -55,6 +55,19 @@ class RegisterRestServlet(RestServlet):
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
kind = "user"
if "kind" in request.args:
kind = request.args["kind"][0]
if kind == "guest":
ret = yield self._do_guest_registration()
defer.returnValue(ret)
return
elif kind != "user":
raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind,)
)
if '/register/email/requestToken' in request.path: if '/register/email/requestToken' in request.path:
ret = yield self.onEmailTokenRequest(request) ret = yield self.onEmailTokenRequest(request)
defer.returnValue(ret) defer.returnValue(ret)
@ -236,6 +249,18 @@ class RegisterRestServlet(RestServlet):
ret = yield self.identity_handler.requestEmailToken(**body) ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret)) defer.returnValue((200, ret))
@defer.inlineCallbacks
def _do_guest_registration(self):
if not self.hs.config.allow_guest_access:
defer.returnValue((403, "Guest access is disabled"))
user_id, _ = yield self.registration_handler.register(generate_token=False)
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
defer.returnValue((200, {
"user_id": user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
}))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)

View File

@ -81,7 +81,7 @@ class SyncRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user, token_id = yield self.auth.get_user_by_req(request) user, token_id, _ = yield self.auth.get_user_by_req(request)
timeout = parse_integer(request, "timeout", default=0) timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since") since = parse_string(request, "since")

View File

@ -42,7 +42,7 @@ class TagListServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, room_id): def on_GET(self, request, user_id, room_id):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string(): if user_id != auth_user.to_string():
raise AuthError(403, "Cannot get tags for other users.") raise AuthError(403, "Cannot get tags for other users.")
@ -68,7 +68,7 @@ class TagServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, tag): def on_PUT(self, request, user_id, room_id, tag):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string(): if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")
@ -88,7 +88,7 @@ class TagServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, user_id, room_id, tag): def on_DELETE(self, request, user_id, room_id, tag):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string(): if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.") raise AuthError(403, "Cannot add tags for other users.")

View File

@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def map_request_to_name(self, request): def map_request_to_name(self, request):
# auth the user # auth the user
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user # namespace all file uploads on the user
prefix = base64.urlsafe_b64encode( prefix = base64.urlsafe_b64encode(

View File

@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource):
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_POST(self, request): def _async_render_POST(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request) auth_user, _, _ = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point # already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length") content_length = request.getHeader("Content-Length")

View File

@ -102,13 +102,14 @@ class RegistrationStore(SQLBaseStore):
400, "User ID already taken.", errcode=Codes.USER_IN_USE 400, "User ID already taken.", errcode=Codes.USER_IN_USE
) )
# it's possible for this to get a conflict, but only for a single user if token:
# since tokens are namespaced based on their user ID # it's possible for this to get a conflict, but only for a single user
txn.execute( # since tokens are namespaced based on their user ID
"INSERT INTO access_tokens(id, user_id, token)" txn.execute(
" VALUES (?,?,?)", "INSERT INTO access_tokens(id, user_id, token)"
(next_id, user_id, token,) " VALUES (?,?,?)",
) (next_id, user_id, token,)
)
def get_user_by_id(self, user_id): def get_user_by_id(self, user_id):
return self._simple_select_one( return self._simple_select_one(

View File

@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _) = yield self.auth.get_user_by_req(request) (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user) self.assertEquals(user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self): def test_get_user_by_req_user_bad_token(self):
@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={}) request = Mock(args={})
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _) = yield self.auth.get_user_by_req(request) (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), self.test_user) self.assertEquals(user.to_string(), self.test_user)
def test_get_user_by_req_appservice_bad_token(self): def test_get_user_by_req_appservice_bad_token(self):
@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
request.args["access_token"] = [self.test_token] request.args["access_token"] = [self.test_token]
request.args["user_id"] = [masquerading_user_id] request.args["user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = Mock(return_value=[""]) request.requestHeaders.getRawHeaders = Mock(return_value=[""])
(user, _) = yield self.auth.get_user_by_req(request) (user, _, _) = yield self.auth.get_user_by_req(request)
self.assertEquals(user.to_string(), masquerading_user_id) self.assertEquals(user.to_string(), masquerading_user_id)
def test_get_user_by_req_appservice_valid_token_bad_user_id(self): def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
@ -158,6 +158,25 @@ class AuthTestCase(unittest.TestCase):
user = user_info["user"] user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user) self.assertEqual(UserID.from_string(user_id), user)
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()
user_info = yield self.auth._get_user_from_macaroon(serialized)
user = user_info["user"]
is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user)
self.assertTrue(is_guest)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_from_macaroon_user_db_mismatch(self): def test_get_user_from_macaroon_user_db_mismatch(self):
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(

View File

@ -86,10 +86,11 @@ class PresenceStateTestCase(unittest.TestCase):
return defer.succeed([]) return defer.succeed([])
self.datastore.get_presence_list = get_presence_list self.datastore.get_presence_list = get_presence_list
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(myid), "user": UserID.from_string(myid),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -173,10 +174,11 @@ class PresenceListTestCase(unittest.TestCase):
) )
self.datastore.has_presence_state = has_presence_state self.datastore.has_presence_state = has_presence_state
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(myid), "user": UserID.from_string(myid),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.handlers.room_member_handler = Mock( hs.handlers.room_member_handler = Mock(
@ -291,8 +293,8 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.get_clock().time_msec.return_value = 1000000 hs.get_clock().time_msec.return_value = 1000000
def _get_user_by_req(req=None): def _get_user_by_req(req=None, allow_guest=False):
return (UserID.from_string(myid), "") return (UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req

View File

@ -52,8 +52,8 @@ class ProfileTestCase(unittest.TestCase):
replication_layer=Mock(), replication_layer=Mock(),
) )
def _get_user_by_req(request=None): def _get_user_by_req(request=None, allow_guest=False):
return (UserID.from_string(myid), "") return (UserID.from_string(myid), "", False)
hs.get_v1auth().get_user_by_req = _get_user_by_req hs.get_v1auth().get_user_by_req = _get_user_by_req

View File

@ -54,10 +54,11 @@ class RoomPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -439,10 +440,11 @@ class RoomsMemberListTestCase(RestTestCase):
self.auth_user_id = self.user_id self.auth_user_id = self.user_id
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -517,10 +519,11 @@ class RoomsCreateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -608,10 +611,11 @@ class RoomTopicTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -713,10 +717,11 @@ class RoomMemberStateTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -838,10 +843,11 @@ class RoomMessagesTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
@ -933,10 +939,11 @@ class RoomInitialSyncTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token

View File

@ -61,10 +61,11 @@ class RoomTypingTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock() hs.get_handlers().federation_handler = Mock()
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token

View File

@ -43,10 +43,11 @@ class V2AlphaRestTestCase(unittest.TestCase):
resource_for_federation=self.mock_resource, resource_for_federation=self.mock_resource,
) )
def _get_user_by_access_token(token=None): def _get_user_by_access_token(token=None, allow_guest=False):
return { return {
"user": UserID.from_string(self.USER_ID), "user": UserID.from_string(self.USER_ID),
"token_id": 1, "token_id": 1,
"is_guest": False,
} }
hs.get_auth()._get_user_by_access_token = _get_user_by_access_token hs.get_auth()._get_user_by_access_token = _get_user_by_access_token