Convert directory handler to async/await (#7727)

This commit is contained in:
Patrick Cloke 2020-06-22 07:18:00 -04:00 committed by GitHub
parent 91e886d615
commit e060bf4462
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 33 additions and 40 deletions

1
changelog.d/7727.misc Normal file
View File

@ -0,0 +1 @@
Convert directory handler to async/await.

View File

@ -17,8 +17,6 @@ import logging
import string import string
from typing import Iterable, List, Optional from typing import Iterable, List, Optional
from twisted.internet import defer
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
@ -55,8 +53,7 @@ class DirectoryHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks async def _create_association(
def _create_association(
self, self,
room_alias: RoomAlias, room_alias: RoomAlias,
room_id: str, room_id: str,
@ -76,13 +73,13 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Add transactions. # TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association. # TODO(erikj): Check if there is a current association.
if not servers: if not servers:
users = yield self.state.get_current_users_in_room(room_id) users = await self.state.get_current_users_in_room(room_id)
servers = {get_domain_from_id(u) for u in users} servers = {get_domain_from_id(u) for u in users}
if not servers: if not servers:
raise SynapseError(400, "Failed to get server list") raise SynapseError(400, "Failed to get server list")
yield self.store.create_room_alias_association( await self.store.create_room_alias_association(
room_alias, room_id, servers, creator=creator room_alias, room_id, servers, creator=creator
) )
@ -93,7 +90,7 @@ class DirectoryHandler(BaseHandler):
room_id: str, room_id: str,
servers: Optional[List[str]] = None, servers: Optional[List[str]] = None,
check_membership: bool = True, check_membership: bool = True,
): ) -> None:
"""Attempt to create a new alias """Attempt to create a new alias
Args: Args:
@ -103,9 +100,6 @@ class DirectoryHandler(BaseHandler):
servers: Iterable of servers that others servers should try and join via servers: Iterable of servers that others servers should try and join via
check_membership: Whether to check if the user is in the room check_membership: Whether to check if the user is in the room
before the alias can be set (if the server's config requires it). before the alias can be set (if the server's config requires it).
Returns:
Deferred
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
@ -148,7 +142,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule? # per alias creation rule?
raise SynapseError(403, "Not allowed to create alias") raise SynapseError(403, "Not allowed to create alias")
can_create = await self.can_modify_alias(room_alias, user_id=user_id) can_create = self.can_modify_alias(room_alias, user_id=user_id)
if not can_create: if not can_create:
raise AuthError( raise AuthError(
400, 400,
@ -158,7 +152,9 @@ class DirectoryHandler(BaseHandler):
await self._create_association(room_alias, room_id, servers, creator=user_id) await self._create_association(room_alias, room_id, servers, creator=user_id)
async def delete_association(self, requester: Requester, room_alias: RoomAlias): async def delete_association(
self, requester: Requester, room_alias: RoomAlias
) -> str:
"""Remove an alias from the directory """Remove an alias from the directory
(this is only meant for human users; AS users should call (this is only meant for human users; AS users should call
@ -169,7 +165,7 @@ class DirectoryHandler(BaseHandler):
room_alias room_alias
Returns: Returns:
Deferred[unicode]: room id that the alias used to point to room id that the alias used to point to
Raises: Raises:
NotFoundError: if the alias doesn't exist NotFoundError: if the alias doesn't exist
@ -191,7 +187,7 @@ class DirectoryHandler(BaseHandler):
if not can_delete: if not can_delete:
raise AuthError(403, "You don't have permission to delete the alias.") raise AuthError(403, "You don't have permission to delete the alias.")
can_delete = await self.can_modify_alias(room_alias, user_id=user_id) can_delete = self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete: if not can_delete:
raise SynapseError( raise SynapseError(
400, 400,
@ -208,8 +204,7 @@ class DirectoryHandler(BaseHandler):
return room_id return room_id
@defer.inlineCallbacks async def delete_appservice_association(
def delete_appservice_association(
self, service: ApplicationService, room_alias: RoomAlias self, service: ApplicationService, room_alias: RoomAlias
): ):
if not service.is_interested_in_alias(room_alias.to_string()): if not service.is_interested_in_alias(room_alias.to_string()):
@ -218,29 +213,27 @@ class DirectoryHandler(BaseHandler):
"This application service has not reserved this kind of alias", "This application service has not reserved this kind of alias",
errcode=Codes.EXCLUSIVE, errcode=Codes.EXCLUSIVE,
) )
yield self._delete_association(room_alias) await self._delete_association(room_alias)
@defer.inlineCallbacks async def _delete_association(self, room_alias: RoomAlias):
def _delete_association(self, room_alias: RoomAlias):
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
room_id = yield self.store.delete_room_alias(room_alias) room_id = await self.store.delete_room_alias(room_alias)
return room_id return room_id
@defer.inlineCallbacks async def get_association(self, room_alias: RoomAlias):
def get_association(self, room_alias: RoomAlias):
room_id = None room_id = None
if self.hs.is_mine(room_alias): if self.hs.is_mine(room_alias):
result = yield self.get_association_from_room_alias(room_alias) result = await self.get_association_from_room_alias(room_alias)
if result: if result:
room_id = result.room_id room_id = result.room_id
servers = result.servers servers = result.servers
else: else:
try: try:
result = yield self.federation.make_query( result = await self.federation.make_query(
destination=room_alias.domain, destination=room_alias.domain,
query_type="directory", query_type="directory",
args={"room_alias": room_alias.to_string()}, args={"room_alias": room_alias.to_string()},
@ -265,7 +258,7 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND, Codes.NOT_FOUND,
) )
users = yield self.state.get_current_users_in_room(room_id) users = await self.state.get_current_users_in_room(room_id)
extra_servers = {get_domain_from_id(u) for u in users} extra_servers = {get_domain_from_id(u) for u in users}
servers = set(extra_servers) | set(servers) servers = set(extra_servers) | set(servers)
@ -277,13 +270,12 @@ class DirectoryHandler(BaseHandler):
return {"room_id": room_id, "servers": servers} return {"room_id": room_id, "servers": servers}
@defer.inlineCallbacks async def on_directory_query(self, args):
def on_directory_query(self, args):
room_alias = RoomAlias.from_string(args["room_alias"]) room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room Alias is not hosted on this homeserver") raise SynapseError(400, "Room Alias is not hosted on this homeserver")
result = yield self.get_association_from_room_alias(room_alias) result = await self.get_association_from_room_alias(room_alias)
if result is not None: if result is not None:
return {"room_id": result.room_id, "servers": result.servers} return {"room_id": result.room_id, "servers": result.servers}
@ -344,16 +336,15 @@ class DirectoryHandler(BaseHandler):
ratelimit=False, ratelimit=False,
) )
@defer.inlineCallbacks async def get_association_from_room_alias(self, room_alias: RoomAlias):
def get_association_from_room_alias(self, room_alias: RoomAlias): result = await self.store.get_association_from_room_alias(room_alias)
result = yield self.store.get_association_from_room_alias(room_alias)
if not result: if not result:
# Query AS to see if it exists # Query AS to see if it exists
as_handler = self.appservice_handler as_handler = self.appservice_handler
result = yield as_handler.query_room_alias_exists(room_alias) result = await as_handler.query_room_alias_exists(room_alias)
return result return result
def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None): def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None) -> bool:
# Any application service "interested" in an alias they are regexing on # Any application service "interested" in an alias they are regexing on
# can modify the alias. # can modify the alias.
# Users can only modify the alias if ALL the interested services have # Users can only modify the alias if ALL the interested services have
@ -366,12 +357,12 @@ class DirectoryHandler(BaseHandler):
for service in interested_services: for service in interested_services:
if user_id == service.sender: if user_id == service.sender:
# this user IS the app service so they can do whatever they like # this user IS the app service so they can do whatever they like
return defer.succeed(True) return True
elif service.is_exclusive_alias(alias.to_string()): elif service.is_exclusive_alias(alias.to_string()):
# another service has an exclusive lock on this alias. # another service has an exclusive lock on this alias.
return defer.succeed(False) return False
# either no interested services, or no service with an exclusive lock # either no interested services, or no service with an exclusive lock
return defer.succeed(True) return True
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str): async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
"""Determine whether a user can delete an alias. """Determine whether a user can delete an alias.
@ -459,8 +450,7 @@ class DirectoryHandler(BaseHandler):
await self.store.set_room_is_public(room_id, making_public) await self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks async def edit_published_appservice_room_list(
def edit_published_appservice_room_list(
self, appservice_id: str, network_id: str, room_id: str, visibility: str self, appservice_id: str, network_id: str, room_id: str, visibility: str
): ):
"""Add or remove a room from the appservice/network specific public """Add or remove a room from the appservice/network specific public
@ -475,7 +465,7 @@ class DirectoryHandler(BaseHandler):
if visibility not in ["public", "private"]: if visibility not in ["public", "private"]:
raise SynapseError(400, "Invalid visibility setting") raise SynapseError(400, "Invalid visibility setting")
yield self.store.set_room_is_public_appservice( await self.store.set_room_is_public_appservice(
room_id, appservice_id, network_id, visibility == "public" room_id, appservice_id, network_id, visibility == "public"
) )

View File

@ -879,7 +879,9 @@ class EventCreationHandler(object):
""" """
room_alias = RoomAlias.from_string(room_alias_str) room_alias = RoomAlias.from_string(room_alias_str)
try: try:
mapping = yield directory_handler.get_association(room_alias) mapping = yield defer.ensureDeferred(
directory_handler.get_association(room_alias)
)
except SynapseError as e: except SynapseError as e:
# Turn M_NOT_FOUND errors into M_BAD_ALIAS errors. # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
if e.errcode == Codes.NOT_FOUND: if e.errcode == Codes.NOT_FOUND: