diff --git a/changelog.d/7649.misc b/changelog.d/7649.misc new file mode 100644 index 000000000..8a26c8b3b --- /dev/null +++ b/changelog.d/7649.misc @@ -0,0 +1 @@ +Convert registration handler to async/await. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index af812dbda..51979ea43 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -16,8 +16,6 @@ """Contains functions for registering clients.""" import logging -from twisted.internet import defer - from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, LoginType from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError @@ -75,8 +73,9 @@ class RegistrationHandler(BaseHandler): self.session_lifetime = hs.config.session_lifetime - @defer.inlineCallbacks - def check_username(self, localpart, guest_access_token=None, assigned_user_id=None): + async def check_username( + self, localpart, guest_access_token=None, assigned_user_id=None + ): if types.contains_invalid_mxid_characters(localpart): raise SynapseError( 400, @@ -113,13 +112,13 @@ class RegistrationHandler(BaseHandler): Codes.INVALID_USERNAME, ) - users = yield self.store.get_users_by_id_case_insensitive(user_id) + users = await self.store.get_users_by_id_case_insensitive(user_id) if users: if not guest_access_token: raise SynapseError( 400, "User ID already taken.", errcode=Codes.USER_IN_USE ) - user_data = yield self.auth.get_user_by_access_token(guest_access_token) + user_data = await self.auth.get_user_by_access_token(guest_access_token) if not user_data["is_guest"] or user_data["user"].localpart != localpart: raise AuthError( 403, @@ -137,8 +136,7 @@ class RegistrationHandler(BaseHandler): except ValueError: pass - @defer.inlineCallbacks - def register_user( + async def register_user( self, localpart=None, password_hash=None, @@ -169,18 +167,18 @@ class RegistrationHandler(BaseHandler): by_admin (bool): True if this registration is being made via the admin api, otherwise False. Returns: - Deferred[str]: user_id + str: user_id Raises: SynapseError if there was a problem registering. """ - yield self.check_registration_ratelimit(address) + self.check_registration_ratelimit(address) # do not check_auth_blocking if the call is coming through the Admin API if not by_admin: - yield self.auth.check_auth_blocking(threepid=threepid) + await self.auth.check_auth_blocking(threepid=threepid) if localpart is not None: - yield self.check_username(localpart, guest_access_token=guest_access_token) + await self.check_username(localpart, guest_access_token=guest_access_token) was_guest = guest_access_token is not None @@ -194,7 +192,7 @@ class RegistrationHandler(BaseHandler): elif default_display_name is None: default_display_name = localpart - yield self.register_with_store( + await self.register_with_store( user_id=user_id, password_hash=password_hash, was_guest=was_guest, @@ -206,11 +204,9 @@ class RegistrationHandler(BaseHandler): ) if self.hs.config.user_directory_search_all_users: - profile = yield self.store.get_profileinfo(localpart) - yield defer.ensureDeferred( - self.user_directory_handler.handle_local_profile_change( - user_id, profile - ) + profile = await self.store.get_profileinfo(localpart) + await self.user_directory_handler.handle_local_profile_change( + user_id, profile ) else: @@ -222,14 +218,14 @@ class RegistrationHandler(BaseHandler): if fail_count > 10: raise SynapseError(500, "Unable to find a suitable guest user ID") - localpart = yield self._generate_user_id() + localpart = await self._generate_user_id() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() - yield self.check_user_id_not_appservice_exclusive(user_id) + self.check_user_id_not_appservice_exclusive(user_id) if default_display_name is None: default_display_name = localpart try: - yield self.register_with_store( + await self.register_with_store( user_id=user_id, password_hash=password_hash, make_guest=make_guest, @@ -252,7 +248,7 @@ class RegistrationHandler(BaseHandler): user_id, ) else: - yield defer.ensureDeferred(self._auto_join_rooms(user_id)) + await self._auto_join_rooms(user_id) else: logger.info( "Skipping auto-join for %s because consent is required at registration", @@ -270,7 +266,7 @@ class RegistrationHandler(BaseHandler): } # Bind email to new account - yield self._register_email_threepid(user_id, threepid_dict, None) + await self._register_email_threepid(user_id, threepid_dict, None) return user_id @@ -335,8 +331,7 @@ class RegistrationHandler(BaseHandler): """ await self._auto_join_rooms(user_id) - @defer.inlineCallbacks - def appservice_register(self, user_localpart, as_token): + async def appservice_register(self, user_localpart, as_token): user = UserID(user_localpart, self.hs.hostname) user_id = user.to_string() service = self.store.get_app_service_by_token(as_token) @@ -351,11 +346,9 @@ class RegistrationHandler(BaseHandler): service_id = service.id if service.is_exclusive_user(user_id) else None - yield self.check_user_id_not_appservice_exclusive( - user_id, allowed_appservice=service - ) + self.check_user_id_not_appservice_exclusive(user_id, allowed_appservice=service) - yield self.register_with_store( + await self.register_with_store( user_id=user_id, password_hash="", appservice_id=service_id, @@ -387,13 +380,12 @@ class RegistrationHandler(BaseHandler): errcode=Codes.EXCLUSIVE, ) - @defer.inlineCallbacks - def _generate_user_id(self): + async def _generate_user_id(self): if self._next_generated_user_id is None: - with (yield self._generate_user_id_linearizer.queue(())): + with await self._generate_user_id_linearizer.queue(()): if self._next_generated_user_id is None: self._next_generated_user_id = ( - yield self.store.find_next_generated_user_id_localpart() + await self.store.find_next_generated_user_id_localpart() ) id = self._next_generated_user_id @@ -496,8 +488,9 @@ class RegistrationHandler(BaseHandler): user_type=user_type, ) - @defer.inlineCallbacks - def register_device(self, user_id, device_id, initial_display_name, is_guest=False): + async def register_device( + self, user_id, device_id, initial_display_name, is_guest=False + ): """Register a device for a user and generate an access token. The access token will be limited by the homeserver's session_lifetime config. @@ -511,11 +504,11 @@ class RegistrationHandler(BaseHandler): is_guest (bool): Whether this is a guest account Returns: - defer.Deferred[tuple[str, str]]: Tuple of device ID and access token + tuple[str, str]: Tuple of device ID and access token """ if self.hs.config.worker_app: - r = yield self._register_device_client( + r = await self._register_device_client( user_id=user_id, device_id=device_id, initial_display_name=initial_display_name, @@ -531,7 +524,7 @@ class RegistrationHandler(BaseHandler): ) valid_until_ms = self.clock.time_msec() + self.session_lifetime - device_id = yield self.device_handler.check_device_registered( + device_id = await self.device_handler.check_device_registered( user_id, device_id, initial_display_name ) if is_guest: @@ -540,10 +533,8 @@ class RegistrationHandler(BaseHandler): user_id, ["guest = true"] ) else: - access_token = yield defer.ensureDeferred( - self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, valid_until_ms=valid_until_ms - ) + access_token = await self._auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, valid_until_ms=valid_until_ms ) return (device_id, access_token) @@ -594,8 +585,7 @@ class RegistrationHandler(BaseHandler): await self.store.user_set_consent_version(user_id, consent_version) await self.post_consent_actions(user_id) - @defer.inlineCallbacks - def _register_email_threepid(self, user_id, threepid, token): + async def _register_email_threepid(self, user_id, threepid, token): """Add an email address as a 3pid identifier Also adds an email pusher for the email address, if configured in the @@ -608,8 +598,6 @@ class RegistrationHandler(BaseHandler): threepid (object): m.login.email.identity auth response token (str|None): access_token for the user, or None if not logged in. - Returns: - defer.Deferred: """ reqd = ("medium", "address", "validated_at") if any(x not in threepid for x in reqd): @@ -617,13 +605,8 @@ class RegistrationHandler(BaseHandler): logger.info("Can't add incomplete 3pid") return - yield defer.ensureDeferred( - self._auth_handler.add_threepid( - user_id, - threepid["medium"], - threepid["address"], - threepid["validated_at"], - ) + await self._auth_handler.add_threepid( + user_id, threepid["medium"], threepid["address"], threepid["validated_at"], ) # And we add an email pusher for them by default, but only @@ -639,10 +622,10 @@ class RegistrationHandler(BaseHandler): # It would really make more sense for this to be passed # up when the access token is saved, but that's quite an # invasive change I'd rather do separately. - user_tuple = yield self.store.get_user_by_access_token(token) + user_tuple = await self.store.get_user_by_access_token(token) token_id = user_tuple["token_id"] - yield self.pusher_pool.add_pusher( + await self.pusher_pool.add_pusher( user_id=user_id, access_token=token_id, kind="email", @@ -654,8 +637,7 @@ class RegistrationHandler(BaseHandler): data={}, ) - @defer.inlineCallbacks - def _register_msisdn_threepid(self, user_id, threepid): + async def _register_msisdn_threepid(self, user_id, threepid): """Add a phone number as a 3pid identifier Must be called on master. @@ -663,8 +645,6 @@ class RegistrationHandler(BaseHandler): Args: user_id (str): id of user threepid (object): m.login.msisdn auth response - Returns: - defer.Deferred: """ try: assert_params_in_dict(threepid, ["medium", "address", "validated_at"]) @@ -675,11 +655,6 @@ class RegistrationHandler(BaseHandler): return None raise - yield defer.ensureDeferred( - self._auth_handler.add_threepid( - user_id, - threepid["medium"], - threepid["address"], - threepid["validated_at"], - ) + await self._auth_handler.add_threepid( + user_id, threepid["medium"], threepid["address"], threepid["validated_at"], ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index d678c0eb9..ecdf1ad69 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -128,8 +128,12 @@ class ModuleApi(object): Returns: Deferred[str]: user_id """ - return self._hs.get_registration_handler().register_user( - localpart=localpart, default_display_name=displayname, bind_emails=emails + return defer.ensureDeferred( + self._hs.get_registration_handler().register_user( + localpart=localpart, + default_display_name=displayname, + bind_emails=emails, + ) ) def register_device(self, user_id, device_id=None, initial_display_name=None):