mirror of
				https://git.anonymousland.org/anonymousland/synapse.git
				synced 2025-10-31 08:48:51 -04:00 
			
		
		
		
	Convert synapse.api to async/await (#8031)
This commit is contained in:
		
							parent
							
								
									c36228c403
								
							
						
					
					
						commit
						d4a7829b12
					
				
					 22 changed files with 171 additions and 159 deletions
				
			
		|  | @ -13,12 +13,11 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| import logging | ||||
| from typing import Optional | ||||
| from typing import List, Optional, Tuple | ||||
| 
 | ||||
| import pymacaroons | ||||
| from netaddr import IPAddress | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| from twisted.web.server import Request | ||||
| 
 | ||||
| import synapse.types | ||||
|  | @ -80,13 +79,14 @@ class Auth(object): | |||
|         self._track_appservice_user_ips = hs.config.track_appservice_user_ips | ||||
|         self._macaroon_secret_key = hs.config.macaroon_secret_key | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def check_from_context(self, room_version: str, event, context, do_sig_check=True): | ||||
|         prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids()) | ||||
|         auth_events_ids = yield self.compute_auth_events( | ||||
|     async def check_from_context( | ||||
|         self, room_version: str, event, context, do_sig_check=True | ||||
|     ): | ||||
|         prev_state_ids = await context.get_prev_state_ids() | ||||
|         auth_events_ids = self.compute_auth_events( | ||||
|             event, prev_state_ids, for_verification=True | ||||
|         ) | ||||
|         auth_events = yield self.store.get_events(auth_events_ids) | ||||
|         auth_events = await self.store.get_events(auth_events_ids) | ||||
|         auth_events = {(e.type, e.state_key): e for e in auth_events.values()} | ||||
| 
 | ||||
|         room_version_obj = KNOWN_ROOM_VERSIONS[room_version] | ||||
|  | @ -94,14 +94,13 @@ class Auth(object): | |||
|             room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def check_user_in_room( | ||||
|     async def check_user_in_room( | ||||
|         self, | ||||
|         room_id: str, | ||||
|         user_id: str, | ||||
|         current_state: Optional[StateMap[EventBase]] = None, | ||||
|         allow_departed_users: bool = False, | ||||
|     ): | ||||
|     ) -> EventBase: | ||||
|         """Check if the user is in the room, or was at some point. | ||||
|         Args: | ||||
|             room_id: The room to check. | ||||
|  | @ -119,37 +118,35 @@ class Auth(object): | |||
|         Raises: | ||||
|             AuthError if the user is/was not in the room. | ||||
|         Returns: | ||||
|             Deferred[Optional[EventBase]]: | ||||
|                 Membership event for the user if the user was in the | ||||
|                 room. This will be the join event if they are currently joined to | ||||
|                 the room. This will be the leave event if they have left the room. | ||||
|             Membership event for the user if the user was in the | ||||
|             room. This will be the join event if they are currently joined to | ||||
|             the room. This will be the leave event if they have left the room. | ||||
|         """ | ||||
|         if current_state: | ||||
|             member = current_state.get((EventTypes.Member, user_id), None) | ||||
|         else: | ||||
|             member = yield defer.ensureDeferred( | ||||
|                 self.state.get_current_state( | ||||
|                     room_id=room_id, event_type=EventTypes.Member, state_key=user_id | ||||
|                 ) | ||||
|             member = await self.state.get_current_state( | ||||
|                 room_id=room_id, event_type=EventTypes.Member, state_key=user_id | ||||
|             ) | ||||
|         membership = member.membership if member else None | ||||
| 
 | ||||
|         if membership == Membership.JOIN: | ||||
|             return member | ||||
|         if member: | ||||
|             membership = member.membership | ||||
| 
 | ||||
|         # XXX this looks totally bogus. Why do we not allow users who have been banned, | ||||
|         # or those who were members previously and have been re-invited? | ||||
|         if allow_departed_users and membership == Membership.LEAVE: | ||||
|             forgot = yield self.store.did_forget(user_id, room_id) | ||||
|             if not forgot: | ||||
|             if membership == Membership.JOIN: | ||||
|                 return member | ||||
| 
 | ||||
|             # XXX this looks totally bogus. Why do we not allow users who have been banned, | ||||
|             # or those who were members previously and have been re-invited? | ||||
|             if allow_departed_users and membership == Membership.LEAVE: | ||||
|                 forgot = await self.store.did_forget(user_id, room_id) | ||||
|                 if not forgot: | ||||
|                     return member | ||||
| 
 | ||||
|         raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def check_host_in_room(self, room_id, host): | ||||
|     async def check_host_in_room(self, room_id, host): | ||||
|         with Measure(self.clock, "check_host_in_room"): | ||||
|             latest_event_ids = yield self.store.is_host_joined(room_id, host) | ||||
|             latest_event_ids = await self.store.is_host_joined(room_id, host) | ||||
|             return latest_event_ids | ||||
| 
 | ||||
|     def can_federate(self, event, auth_events): | ||||
|  | @ -160,14 +157,13 @@ class Auth(object): | |||
|     def get_public_keys(self, invite_event): | ||||
|         return event_auth.get_public_keys(invite_event) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_user_by_req( | ||||
|     async def get_user_by_req( | ||||
|         self, | ||||
|         request: Request, | ||||
|         allow_guest: bool = False, | ||||
|         rights: str = "access", | ||||
|         allow_expired: bool = False, | ||||
|     ): | ||||
|     ) -> synapse.types.Requester: | ||||
|         """ Get a registered user's ID. | ||||
| 
 | ||||
|         Args: | ||||
|  | @ -180,7 +176,7 @@ class Auth(object): | |||
|                 /login will deliver access tokens regardless of expiration. | ||||
| 
 | ||||
|         Returns: | ||||
|             defer.Deferred: resolves to a `synapse.types.Requester` object | ||||
|             Resolves to the requester | ||||
|         Raises: | ||||
|             InvalidClientCredentialsError if no user by that token exists or the token | ||||
|                 is invalid. | ||||
|  | @ -194,14 +190,14 @@ class Auth(object): | |||
| 
 | ||||
|             access_token = self.get_access_token_from_request(request) | ||||
| 
 | ||||
|             user_id, app_service = yield self._get_appservice_user_id(request) | ||||
|             user_id, app_service = await self._get_appservice_user_id(request) | ||||
|             if user_id: | ||||
|                 request.authenticated_entity = user_id | ||||
|                 opentracing.set_tag("authenticated_entity", user_id) | ||||
|                 opentracing.set_tag("appservice_id", app_service.id) | ||||
| 
 | ||||
|                 if ip_addr and self._track_appservice_user_ips: | ||||
|                     yield self.store.insert_client_ip( | ||||
|                     await self.store.insert_client_ip( | ||||
|                         user_id=user_id, | ||||
|                         access_token=access_token, | ||||
|                         ip=ip_addr, | ||||
|  | @ -211,7 +207,7 @@ class Auth(object): | |||
| 
 | ||||
|                 return synapse.types.create_requester(user_id, app_service=app_service) | ||||
| 
 | ||||
|             user_info = yield self.get_user_by_access_token( | ||||
|             user_info = await self.get_user_by_access_token( | ||||
|                 access_token, rights, allow_expired=allow_expired | ||||
|             ) | ||||
|             user = user_info["user"] | ||||
|  | @ -221,7 +217,7 @@ class Auth(object): | |||
|             # Deny the request if the user account has expired. | ||||
|             if self._account_validity.enabled and not allow_expired: | ||||
|                 user_id = user.to_string() | ||||
|                 expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) | ||||
|                 expiration_ts = await self.store.get_expiration_ts_for_user(user_id) | ||||
|                 if ( | ||||
|                     expiration_ts is not None | ||||
|                     and self.clock.time_msec() >= expiration_ts | ||||
|  | @ -235,7 +231,7 @@ class Auth(object): | |||
|             device_id = user_info.get("device_id") | ||||
| 
 | ||||
|             if user and access_token and ip_addr: | ||||
|                 yield self.store.insert_client_ip( | ||||
|                 await self.store.insert_client_ip( | ||||
|                     user_id=user.to_string(), | ||||
|                     access_token=access_token, | ||||
|                     ip=ip_addr, | ||||
|  | @ -261,8 +257,7 @@ class Auth(object): | |||
|         except KeyError: | ||||
|             raise MissingClientTokenError() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _get_appservice_user_id(self, request): | ||||
|     async def _get_appservice_user_id(self, request): | ||||
|         app_service = self.store.get_app_service_by_token( | ||||
|             self.get_access_token_from_request(request) | ||||
|         ) | ||||
|  | @ -283,14 +278,13 @@ class Auth(object): | |||
| 
 | ||||
|         if not app_service.is_interested_in_user(user_id): | ||||
|             raise AuthError(403, "Application service cannot masquerade as this user.") | ||||
|         if not (yield self.store.get_user_by_id(user_id)): | ||||
|         if not (await self.store.get_user_by_id(user_id)): | ||||
|             raise AuthError(403, "Application service has not registered this user") | ||||
|         return user_id, app_service | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_user_by_access_token( | ||||
|     async def get_user_by_access_token( | ||||
|         self, token: str, rights: str = "access", allow_expired: bool = False, | ||||
|     ): | ||||
|     ) -> dict: | ||||
|         """ Validate access token and get user_id from it | ||||
| 
 | ||||
|         Args: | ||||
|  | @ -300,7 +294,7 @@ class Auth(object): | |||
|             allow_expired: If False, raises an InvalidClientTokenError | ||||
|                 if the token is expired | ||||
|         Returns: | ||||
|             Deferred[dict]: dict that includes: | ||||
|             dict that includes: | ||||
|                `user` (UserID) | ||||
|                `is_guest` (bool) | ||||
|                `token_id` (int|None): access token id. May be None if guest | ||||
|  | @ -314,7 +308,7 @@ class Auth(object): | |||
| 
 | ||||
|         if rights == "access": | ||||
|             # first look in the database | ||||
|             r = yield self._look_up_user_by_access_token(token) | ||||
|             r = await self._look_up_user_by_access_token(token) | ||||
|             if r: | ||||
|                 valid_until_ms = r["valid_until_ms"] | ||||
|                 if ( | ||||
|  | @ -352,7 +346,7 @@ class Auth(object): | |||
|                 # It would of course be much easier to store guest access | ||||
|                 # tokens in the database as well, but that would break existing | ||||
|                 # guest tokens. | ||||
|                 stored_user = yield self.store.get_user_by_id(user_id) | ||||
|                 stored_user = await self.store.get_user_by_id(user_id) | ||||
|                 if not stored_user: | ||||
|                     raise InvalidClientTokenError("Unknown user_id %s" % user_id) | ||||
|                 if not stored_user["is_guest"]: | ||||
|  | @ -482,9 +476,8 @@ class Auth(object): | |||
|         now = self.hs.get_clock().time_msec() | ||||
|         return now < expiry | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _look_up_user_by_access_token(self, token): | ||||
|         ret = yield self.store.get_user_by_access_token(token) | ||||
|     async def _look_up_user_by_access_token(self, token): | ||||
|         ret = await self.store.get_user_by_access_token(token) | ||||
|         if not ret: | ||||
|             return None | ||||
| 
 | ||||
|  | @ -507,7 +500,7 @@ class Auth(object): | |||
|             logger.warning("Unrecognised appservice access token.") | ||||
|             raise InvalidClientTokenError() | ||||
|         request.authenticated_entity = service.sender | ||||
|         return defer.succeed(service) | ||||
|         return service | ||||
| 
 | ||||
|     async def is_server_admin(self, user: UserID) -> bool: | ||||
|         """ Check if the given user is a local server admin. | ||||
|  | @ -522,7 +515,7 @@ class Auth(object): | |||
| 
 | ||||
|     def compute_auth_events( | ||||
|         self, event, current_state_ids: StateMap[str], for_verification: bool = False, | ||||
|     ): | ||||
|     ) -> List[str]: | ||||
|         """Given an event and current state return the list of event IDs used | ||||
|         to auth an event. | ||||
| 
 | ||||
|  | @ -530,11 +523,11 @@ class Auth(object): | |||
|         should be added to the event's `auth_events`. | ||||
| 
 | ||||
|         Returns: | ||||
|             defer.Deferred(list[str]): List of event IDs. | ||||
|             List of event IDs. | ||||
|         """ | ||||
| 
 | ||||
|         if event.type == EventTypes.Create: | ||||
|             return defer.succeed([]) | ||||
|             return [] | ||||
| 
 | ||||
|         # Currently we ignore the `for_verification` flag even though there are | ||||
|         # some situations where we can drop particular auth events when adding | ||||
|  | @ -553,7 +546,7 @@ class Auth(object): | |||
|             if auth_ev_id: | ||||
|                 auth_ids.append(auth_ev_id) | ||||
| 
 | ||||
|         return defer.succeed(auth_ids) | ||||
|         return auth_ids | ||||
| 
 | ||||
|     async def check_can_change_room_list(self, room_id: str, user: UserID): | ||||
|         """Determine whether the user is allowed to edit the room's entry in the | ||||
|  | @ -636,10 +629,9 @@ class Auth(object): | |||
| 
 | ||||
|             return query_params[0].decode("ascii") | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def check_user_in_room_or_world_readable( | ||||
|     async def check_user_in_room_or_world_readable( | ||||
|         self, room_id: str, user_id: str, allow_departed_users: bool = False | ||||
|     ): | ||||
|     ) -> Tuple[str, Optional[str]]: | ||||
|         """Checks that the user is or was in the room or the room is world | ||||
|         readable. If it isn't then an exception is raised. | ||||
| 
 | ||||
|  | @ -650,10 +642,9 @@ class Auth(object): | |||
|                 members but have now departed | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[tuple[str, str|None]]: Resolves to the current membership of | ||||
|                 the user in the room and the membership event ID of the user. If | ||||
|                 the user is not in the room and never has been, then | ||||
|                 `(Membership.JOIN, None)` is returned. | ||||
|             Resolves to the current membership of the user in the room and the | ||||
|             membership event ID of the user. If the user is not in the room and | ||||
|             never has been, then `(Membership.JOIN, None)` is returned. | ||||
|         """ | ||||
| 
 | ||||
|         try: | ||||
|  | @ -662,15 +653,13 @@ class Auth(object): | |||
|             #  * The user is a non-guest user, and was ever in the room | ||||
|             #  * The user is a guest user, and has joined the room | ||||
|             # else it will throw. | ||||
|             member_event = yield self.check_user_in_room( | ||||
|             member_event = await self.check_user_in_room( | ||||
|                 room_id, user_id, allow_departed_users=allow_departed_users | ||||
|             ) | ||||
|             return member_event.membership, member_event.event_id | ||||
|         except AuthError: | ||||
|             visibility = yield defer.ensureDeferred( | ||||
|                 self.state.get_current_state( | ||||
|                     room_id, EventTypes.RoomHistoryVisibility, "" | ||||
|                 ) | ||||
|             visibility = await self.state.get_current_state( | ||||
|                 room_id, EventTypes.RoomHistoryVisibility, "" | ||||
|             ) | ||||
|             if ( | ||||
|                 visibility | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Patrick Cloke
						Patrick Cloke