Merge pull request #6482 from matrix-org/erikj/port_rest_v1

Port rest/v1 to async/await
This commit is contained in:
Erik Johnston 2019-12-05 16:40:06 +00:00 committed by GitHub
commit af5d0ebc72
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 134 additions and 169 deletions

1
changelog.d/6482.misc Normal file
View File

@ -0,0 +1 @@
Port synapse.rest.client.v1 to async/await.

View File

@ -151,7 +151,7 @@ class SynchrotronPresence(object):
def set_state(self, user, state, ignore_status_msg=False): def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work? # TODO Hows this supposed to work?
pass return defer.succeed(None)
get_states = __func__(PresenceHandler.get_states) get_states = __func__(PresenceHandler.get_states)
get_state = __func__(PresenceHandler.get_state) get_state = __func__(PresenceHandler.get_state)

View File

@ -16,8 +16,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -47,17 +45,15 @@ class ClientDirectoryServer(RestServlet):
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, room_alias):
def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
dir_handler = self.handlers.directory_handler dir_handler = self.handlers.directory_handler
res = yield dir_handler.get_association(room_alias) res = await dir_handler.get_association(room_alias)
return 200, res return 200, res
@defer.inlineCallbacks async def on_PUT(self, request, room_alias):
def on_PUT(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -77,26 +73,25 @@ class ClientDirectoryServer(RestServlet):
# TODO(erikj): Check types. # TODO(erikj): Check types.
room = yield self.store.get_room(room_id) room = await self.store.get_room(room_id)
if room is None: if room is None:
raise SynapseError(400, "Room does not exist") raise SynapseError(400, "Room does not exist")
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
yield self.handlers.directory_handler.create_association( await self.handlers.directory_handler.create_association(
requester, room_alias, room_id, servers requester, room_alias, room_id, servers
) )
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_DELETE(self, request, room_alias):
def on_DELETE(self, request, room_alias):
dir_handler = self.handlers.directory_handler dir_handler = self.handlers.directory_handler
try: try:
service = yield self.auth.get_appservice_by_req(request) service = await self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
yield dir_handler.delete_appservice_association(service, room_alias) await dir_handler.delete_appservice_association(service, room_alias)
logger.info( logger.info(
"Application service at %s deleted alias %s", "Application service at %s deleted alias %s",
service.url, service.url,
@ -107,12 +102,12 @@ class ClientDirectoryServer(RestServlet):
# fallback to default user behaviour if they aren't an AS # fallback to default user behaviour if they aren't an AS
pass pass
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user = requester.user user = requester.user
room_alias = RoomAlias.from_string(room_alias) room_alias = RoomAlias.from_string(room_alias)
yield dir_handler.delete_association(requester, room_alias) await dir_handler.delete_association(requester, room_alias)
logger.info( logger.info(
"User %s deleted alias %s", user.to_string(), room_alias.to_string() "User %s deleted alias %s", user.to_string(), room_alias.to_string()
@ -130,32 +125,29 @@ class ClientDirectoryListServer(RestServlet):
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, room_id):
def on_GET(self, request, room_id): room = await self.store.get_room(room_id)
room = yield self.store.get_room(room_id)
if room is None: if room is None:
raise NotFoundError("Unknown room") raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"} return 200, {"visibility": "public" if room["is_public"] else "private"}
@defer.inlineCallbacks async def on_PUT(self, request, room_id):
def on_PUT(self, request, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public") visibility = content.get("visibility", "public")
yield self.handlers.directory_handler.edit_published_room_list( await self.handlers.directory_handler.edit_published_room_list(
requester, room_id, visibility requester, room_id, visibility
) )
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_DELETE(self, request, room_id):
def on_DELETE(self, request, room_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
yield self.handlers.directory_handler.edit_published_room_list( await self.handlers.directory_handler.edit_published_room_list(
requester, room_id, "private" requester, room_id, "private"
) )
@ -181,15 +173,14 @@ class ClientAppserviceDirectoryListServer(RestServlet):
def on_DELETE(self, request, network_id, room_id): def on_DELETE(self, request, network_id, room_id):
return self._edit(request, network_id, room_id, "private") return self._edit(request, network_id, room_id, "private")
@defer.inlineCallbacks async def _edit(self, request, network_id, room_id, visibility):
def _edit(self, request, network_id, room_id, visibility): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if not requester.app_service: if not requester.app_service:
raise AuthError( raise AuthError(
403, "Only appservices can edit the appservice published room list" 403, "Only appservices can edit the appservice published room list"
) )
yield self.handlers.directory_handler.edit_published_appservice_room_list( await self.handlers.directory_handler.edit_published_appservice_room_list(
requester.app_service.id, network_id, room_id, visibility requester.app_service.id, network_id, room_id, visibility
) )

View File

@ -16,8 +16,6 @@
"""This module contains REST servlets to do with event streaming, /events.""" """This module contains REST servlets to do with event streaming, /events."""
import logging import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
@ -36,9 +34,8 @@ class EventStreamRestServlet(RestServlet):
self.event_stream_handler = hs.get_event_stream_handler() self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest is_guest = requester.is_guest
room_id = None room_id = None
if is_guest: if is_guest:
@ -57,7 +54,7 @@ class EventStreamRestServlet(RestServlet):
as_client_event = b"raw" not in request.args as_client_event = b"raw" not in request.args
chunk = yield self.event_stream_handler.get_stream( chunk = await self.event_stream_handler.get_stream(
requester.user.to_string(), requester.user.to_string(),
pagin_config, pagin_config,
timeout=timeout, timeout=timeout,
@ -83,14 +80,13 @@ class EventRestServlet(RestServlet):
self.event_handler = hs.get_event_handler() self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks async def on_GET(self, request, event_id):
def on_GET(self, request, event_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request) event = await self.event_handler.get_event(requester.user, None, event_id)
event = yield self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
if event: if event:
event = yield self._event_serializer.serialize_event(event, time_now) event = await self._event_serializer.serialize_event(event, time_now)
return 200, event return 200, event
else: else:
return 404, "Event not found." return 404, "Event not found."

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_boolean from synapse.http.servlet import RestServlet, parse_boolean
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
@ -29,13 +28,12 @@ class InitialSyncRestServlet(RestServlet):
self.initial_sync_handler = hs.get_initial_sync_handler() self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
as_client_event = b"raw" not in request.args as_client_event = b"raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
include_archived = parse_boolean(request, "archived", default=False) include_archived = parse_boolean(request, "archived", default=False)
content = yield self.initial_sync_handler.snapshot_all_rooms( content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
pagin_config=pagination_config, pagin_config=pagination_config,
as_client_event=as_client_event, as_client_event=as_client_event,

View File

@ -18,7 +18,6 @@ import xml.etree.ElementTree as ET
from six.moves import urllib from six.moves import urllib
from twisted.internet import defer
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.errors import Codes, LoginError, SynapseError
@ -130,8 +129,7 @@ class LoginRestServlet(RestServlet):
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request):
self._address_ratelimiter.ratelimit( self._address_ratelimiter.ratelimit(
request.getClientIP(), request.getClientIP(),
time_now_s=self.hs.clock.time(), time_now_s=self.hs.clock.time(),
@ -145,11 +143,11 @@ class LoginRestServlet(RestServlet):
if self.jwt_enabled and ( if self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE login_submission["type"] == LoginRestServlet.JWT_TYPE
): ):
result = yield self.do_jwt_login(login_submission) result = await self.do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE: elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission) result = await self.do_token_login(login_submission)
else: else:
result = yield self._do_other_login(login_submission) result = await self._do_other_login(login_submission)
except KeyError: except KeyError:
raise SynapseError(400, "Missing JSON keys.") raise SynapseError(400, "Missing JSON keys.")
@ -158,8 +156,7 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data result["well_known"] = well_known_data
return 200, result return 200, result
@defer.inlineCallbacks async def _do_other_login(self, login_submission):
def _do_other_login(self, login_submission):
"""Handle non-token/saml/jwt logins """Handle non-token/saml/jwt logins
Args: Args:
@ -219,20 +216,20 @@ class LoginRestServlet(RestServlet):
( (
canonical_user_id, canonical_user_id,
callback_3pid, callback_3pid,
) = yield self.auth_handler.check_password_provider_3pid( ) = await self.auth_handler.check_password_provider_3pid(
medium, address, login_submission["password"] medium, address, login_submission["password"]
) )
if canonical_user_id: if canonical_user_id:
# Authentication through password provider and 3pid succeeded # Authentication through password provider and 3pid succeeded
result = yield self._complete_login( result = await self._complete_login(
canonical_user_id, login_submission, callback_3pid canonical_user_id, login_submission, callback_3pid
) )
return result return result
# No password providers were able to handle this 3pid # No password providers were able to handle this 3pid
# Check local store # Check local store
user_id = yield self.hs.get_datastore().get_user_id_by_threepid( user_id = await self.hs.get_datastore().get_user_id_by_threepid(
medium, address medium, address
) )
if not user_id: if not user_id:
@ -280,7 +277,7 @@ class LoginRestServlet(RestServlet):
) )
try: try:
canonical_user_id, callback = yield self.auth_handler.validate_login( canonical_user_id, callback = await self.auth_handler.validate_login(
identifier["user"], login_submission identifier["user"], login_submission
) )
except LoginError: except LoginError:
@ -297,13 +294,12 @@ class LoginRestServlet(RestServlet):
) )
raise raise
result = yield self._complete_login( result = await self._complete_login(
canonical_user_id, login_submission, callback canonical_user_id, login_submission, callback
) )
return result return result
@defer.inlineCallbacks async def _complete_login(
def _complete_login(
self, user_id, login_submission, callback=None, create_non_existant_users=False self, user_id, login_submission, callback=None, create_non_existant_users=False
): ):
"""Called when we've successfully authed the user and now need to """Called when we've successfully authed the user and now need to
@ -337,15 +333,15 @@ class LoginRestServlet(RestServlet):
) )
if create_non_existant_users: if create_non_existant_users:
user_id = yield self.auth_handler.check_user_exists(user_id) user_id = await self.auth_handler.check_user_exists(user_id)
if not user_id: if not user_id:
user_id = yield self.registration_handler.register_user( user_id = await self.registration_handler.register_user(
localpart=UserID.from_string(user_id).localpart localpart=UserID.from_string(user_id).localpart
) )
device_id = login_submission.get("device_id") device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name") initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device( device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name user_id, device_id, initial_display_name
) )
@ -357,23 +353,21 @@ class LoginRestServlet(RestServlet):
} }
if callback is not None: if callback is not None:
yield callback(result) await callback(result)
return result return result
@defer.inlineCallbacks async def do_token_login(self, login_submission):
def do_token_login(self, login_submission):
token = login_submission["token"] token = login_submission["token"]
auth_handler = self.auth_handler auth_handler = self.auth_handler
user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id( user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
token token
) )
result = yield self._complete_login(user_id, login_submission) result = await self._complete_login(user_id, login_submission)
return result return result
@defer.inlineCallbacks async def do_jwt_login(self, login_submission):
def do_jwt_login(self, login_submission):
token = login_submission.get("token", None) token = login_submission.get("token", None)
if token is None: if token is None:
raise LoginError( raise LoginError(
@ -397,7 +391,7 @@ class LoginRestServlet(RestServlet):
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID(user, self.hs.hostname).to_string() user_id = UserID(user, self.hs.hostname).to_string()
result = yield self._complete_login( result = await self._complete_login(
user_id, login_submission, create_non_existant_users=True user_id, login_submission, create_non_existant_users=True
) )
return result return result
@ -460,8 +454,7 @@ class CasTicketServlet(RestServlet):
self._sso_auth_handler = SSOAuthHandler(hs) self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_proxied_http_client() self._http_client = hs.get_proxied_http_client()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request):
client_redirect_url = parse_string(request, "redirectUrl", required=True) client_redirect_url = parse_string(request, "redirectUrl", required=True)
uri = self.cas_server_url + "/proxyValidate" uri = self.cas_server_url + "/proxyValidate"
args = { args = {
@ -469,12 +462,12 @@ class CasTicketServlet(RestServlet):
"service": self.cas_service_url, "service": self.cas_service_url,
} }
try: try:
body = yield self._http_client.get_raw(uri, args) body = await self._http_client.get_raw(uri, args)
except PartialDownloadError as pde: except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed, # Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data # even if that's being used old-http style to signal end-of-data
body = pde.response body = pde.response
result = yield self.handle_cas_response(request, body, client_redirect_url) result = await self.handle_cas_response(request, body, client_redirect_url)
return result return result
def handle_cas_response(self, request, cas_response_body, client_redirect_url): def handle_cas_response(self, request, cas_response_body, client_redirect_url):
@ -555,8 +548,7 @@ class SSOAuthHandler(object):
self._registration_handler = hs.get_registration_handler() self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator() self._macaroon_gen = hs.get_macaroon_generator()
@defer.inlineCallbacks async def on_successful_auth(
def on_successful_auth(
self, username, request, client_redirect_url, user_display_name=None self, username, request, client_redirect_url, user_display_name=None
): ):
"""Called once the user has successfully authenticated with the SSO. """Called once the user has successfully authenticated with the SSO.
@ -582,9 +574,9 @@ class SSOAuthHandler(object):
""" """
localpart = map_username_to_mxid_localpart(username) localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string() user_id = UserID(localpart, self._hostname).to_string()
registered_user_id = yield self._auth_handler.check_user_exists(user_id) registered_user_id = await self._auth_handler.check_user_exists(user_id)
if not registered_user_id: if not registered_user_id:
registered_user_id = yield self._registration_handler.register_user( registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name localpart=localpart, default_display_name=user_display_name
) )

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
@ -35,17 +33,16 @@ class LogoutRestServlet(RestServlet):
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
if requester.device_id is None: if requester.device_id is None:
# the acccess token wasn't associated with a device. # the acccess token wasn't associated with a device.
# Just delete the access token # Just delete the access token
access_token = self.auth.get_access_token_from_request(request) access_token = self.auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token) await self._auth_handler.delete_access_token(access_token)
else: else:
yield self._device_handler.delete_device( await self._device_handler.delete_device(
requester.user.to_string(), requester.device_id requester.user.to_string(), requester.device_id
) )
@ -64,17 +61,16 @@ class LogoutAllRestServlet(RestServlet):
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
# first delete all of the user's devices # first delete all of the user's devices
yield self._device_handler.delete_all_devices_for_user(user_id) await self._device_handler.delete_all_devices_for_user(user_id)
# .. and then delete any access tokens which weren't associated with # .. and then delete any access tokens which weren't associated with
# devices. # devices.
yield self._auth_handler.delete_access_tokens_for_user(user_id) await self._auth_handler.delete_access_tokens_for_user(user_id)
return 200, {} return 200, {}

View File

@ -19,8 +19,6 @@ import logging
from six import string_types from six import string_types
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
@ -40,27 +38,25 @@ class PresenceStatusRestServlet(RestServlet):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if requester.user != user: if requester.user != user:
allowed = yield self.presence_handler.is_visible( allowed = await self.presence_handler.is_visible(
observed_user=user, observer_user=requester.user observed_user=user, observer_user=requester.user
) )
if not allowed: if not allowed:
raise AuthError(403, "You are not allowed to see their presence.") raise AuthError(403, "You are not allowed to see their presence.")
state = yield self.presence_handler.get_state(target_user=user) state = await self.presence_handler.get_state(target_user=user)
state = format_user_presence_state(state, self.clock.time_msec()) state = format_user_presence_state(state, self.clock.time_msec())
return 200, state return 200, state
@defer.inlineCallbacks async def on_PUT(self, request, user_id):
def on_PUT(self, request, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if requester.user != user: if requester.user != user:
@ -86,7 +82,7 @@ class PresenceStatusRestServlet(RestServlet):
raise SynapseError(400, "Unable to parse state") raise SynapseError(400, "Unable to parse state")
if self.hs.config.use_presence: if self.hs.config.use_presence:
yield self.presence_handler.set_state(user, state) await self.presence_handler.set_state(user, state)
return 200, {} return 200, {}

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """ """ This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
@ -30,19 +29,18 @@ class ProfileDisplaynameRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id):
requester_user = None requester_user = None
if self.hs.config.require_auth_for_profile_requests: if self.hs.config.require_auth_for_profile_requests:
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user = requester.user requester_user = requester.user
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
yield self.profile_handler.check_profile_query_allowed(user, requester_user) await self.profile_handler.check_profile_query_allowed(user, requester_user)
displayname = yield self.profile_handler.get_displayname(user) displayname = await self.profile_handler.get_displayname(user)
ret = {} ret = {}
if displayname is not None: if displayname is not None:
@ -50,11 +48,10 @@ class ProfileDisplaynameRestServlet(RestServlet):
return 200, ret return 200, ret
@defer.inlineCallbacks async def on_PUT(self, request, user_id):
def on_PUT(self, request, user_id): requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
is_admin = yield self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -63,7 +60,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
except Exception: except Exception:
return 400, "Unable to parse name" return 400, "Unable to parse name"
yield self.profile_handler.set_displayname(user, requester, new_name, is_admin) await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
return 200, {} return 200, {}
@ -80,19 +77,18 @@ class ProfileAvatarURLRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id):
requester_user = None requester_user = None
if self.hs.config.require_auth_for_profile_requests: if self.hs.config.require_auth_for_profile_requests:
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user = requester.user requester_user = requester.user
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
yield self.profile_handler.check_profile_query_allowed(user, requester_user) await self.profile_handler.check_profile_query_allowed(user, requester_user)
avatar_url = yield self.profile_handler.get_avatar_url(user) avatar_url = await self.profile_handler.get_avatar_url(user)
ret = {} ret = {}
if avatar_url is not None: if avatar_url is not None:
@ -100,11 +96,10 @@ class ProfileAvatarURLRestServlet(RestServlet):
return 200, ret return 200, ret
@defer.inlineCallbacks async def on_PUT(self, request, user_id):
def on_PUT(self, request, user_id): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
is_admin = yield self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
try: try:
@ -112,7 +107,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
except Exception: except Exception:
return 400, "Unable to parse name" return 400, "Unable to parse name"
yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin) await self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
return 200, {} return 200, {}
@ -129,20 +124,19 @@ class ProfileRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request, user_id):
def on_GET(self, request, user_id):
requester_user = None requester_user = None
if self.hs.config.require_auth_for_profile_requests: if self.hs.config.require_auth_for_profile_requests:
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
requester_user = requester.user requester_user = requester.user
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
yield self.profile_handler.check_profile_query_allowed(user, requester_user) await self.profile_handler.check_profile_query_allowed(user, requester_user)
displayname = yield self.profile_handler.get_displayname(user) displayname = await self.profile_handler.get_displayname(user)
avatar_url = yield self.profile_handler.get_avatar_url(user) avatar_url = await self.profile_handler.get_avatar_url(user)
ret = {} ret = {}
if displayname is not None: if displayname is not None:

View File

@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
NotFoundError, NotFoundError,
@ -46,8 +45,7 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None self._is_worker = hs.config.worker_app is not None
@defer.inlineCallbacks async def on_PUT(self, request, path):
def on_PUT(self, request, path):
if self._is_worker: if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker") raise Exception("Cannot handle PUT /push_rules on worker")
@ -57,7 +55,7 @@ class PushRuleRestServlet(RestServlet):
except InvalidRuleException as e: except InvalidRuleException as e:
raise SynapseError(400, str(e)) raise SynapseError(400, str(e))
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
if "/" in spec["rule_id"] or "\\" in spec["rule_id"]: if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
raise SynapseError(400, "rule_id may not contain slashes") raise SynapseError(400, "rule_id may not contain slashes")
@ -67,7 +65,7 @@ class PushRuleRestServlet(RestServlet):
user_id = requester.user.to_string() user_id = requester.user.to_string()
if "attr" in spec: if "attr" in spec:
yield self.set_rule_attr(user_id, spec, content) await self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id) self.notify_user(user_id)
return 200, {} return 200, {}
@ -91,7 +89,7 @@ class PushRuleRestServlet(RestServlet):
after = _namespaced_rule_id(spec, after) after = _namespaced_rule_id(spec, after)
try: try:
yield self.store.add_push_rule( await self.store.add_push_rule(
user_id=user_id, user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec), rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class, priority_class=priority_class,
@ -108,20 +106,19 @@ class PushRuleRestServlet(RestServlet):
return 200, {} return 200, {}
@defer.inlineCallbacks async def on_DELETE(self, request, path):
def on_DELETE(self, request, path):
if self._is_worker: if self._is_worker:
raise Exception("Cannot handle DELETE /push_rules on worker") raise Exception("Cannot handle DELETE /push_rules on worker")
spec = _rule_spec_from_path([x for x in path.split("/")]) spec = _rule_spec_from_path([x for x in path.split("/")])
requester = yield self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try: try:
yield self.store.delete_push_rule(user_id, namespaced_rule_id) await self.store.delete_push_rule(user_id, namespaced_rule_id)
self.notify_user(user_id) self.notify_user(user_id)
return 200, {} return 200, {}
except StoreError as e: except StoreError as e:
@ -130,15 +127,14 @@ class PushRuleRestServlet(RestServlet):
else: else:
raise raise
@defer.inlineCallbacks async def on_GET(self, request, path):
def on_GET(self, request, path): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string() user_id = requester.user.to_string()
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference # is probably not going to make a whole lot of difference
rules = yield self.store.get_push_rules_for_user(user_id) rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(requester.user, rules) rules = format_push_rules_for_user(requester.user, rules)

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import finish_request from synapse.http.server import finish_request
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -39,12 +37,11 @@ class PushersRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = requester.user user = requester.user
pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
allowed_keys = [ allowed_keys = [
"app_display_name", "app_display_name",
@ -78,9 +75,8 @@ class PushersSetRestServlet(RestServlet):
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool() self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks async def on_POST(self, request):
def on_POST(self, request): requester = await self.auth.get_user_by_req(request)
requester = yield self.auth.get_user_by_req(request)
user = requester.user user = requester.user
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -91,7 +87,7 @@ class PushersSetRestServlet(RestServlet):
and "kind" in content and "kind" in content
and content["kind"] is None and content["kind"] is None
): ):
yield self.pusher_pool.remove_pusher( await self.pusher_pool.remove_pusher(
content["app_id"], content["pushkey"], user_id=user.to_string() content["app_id"], content["pushkey"], user_id=user.to_string()
) )
return 200, {} return 200, {}
@ -117,14 +113,14 @@ class PushersSetRestServlet(RestServlet):
append = content["append"] append = content["append"]
if not append: if not append:
yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content["app_id"], app_id=content["app_id"],
pushkey=content["pushkey"], pushkey=content["pushkey"],
not_user_id=user.to_string(), not_user_id=user.to_string(),
) )
try: try:
yield self.pusher_pool.add_pusher( await self.pusher_pool.add_pusher(
user_id=user.to_string(), user_id=user.to_string(),
access_token=requester.access_token_id, access_token=requester.access_token_id,
kind=content["kind"], kind=content["kind"],
@ -164,16 +160,15 @@ class PushersRemoveRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool() self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user user = requester.user
app_id = parse_string(request, "app_id", required=True) app_id = parse_string(request, "app_id", required=True)
pushkey = parse_string(request, "pushkey", required=True) pushkey = parse_string(request, "pushkey", required=True)
try: try:
yield self.pusher_pool.remove_pusher( await self.pusher_pool.remove_pusher(
app_id=app_id, pushkey=pushkey, user_id=user.to_string() app_id=app_id, pushkey=pushkey, user_id=user.to_string()
) )
except StoreError as se: except StoreError as se:

View File

@ -17,8 +17,6 @@ import base64
import hashlib import hashlib
import hmac import hmac
from twisted.internet import defer
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha._base import client_patterns
@ -31,9 +29,8 @@ class VoipRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
@defer.inlineCallbacks async def on_GET(self, request):
def on_GET(self, request): requester = await self.auth.get_user_by_req(
requester = yield self.auth.get_user_by_req(
request, self.hs.config.turn_allow_guests request, self.hs.config.turn_allow_guests
) )

View File

@ -15,6 +15,8 @@
from mock import Mock from mock import Mock
from twisted.internet import defer
from synapse.rest.client.v1 import presence from synapse.rest.client.v1 import presence
from synapse.types import UserID from synapse.types import UserID
@ -36,6 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
) )
hs.presence_handler = Mock() hs.presence_handler = Mock()
hs.presence_handler.set_state.return_value = defer.succeed(None)
return hs return hs

View File

@ -52,6 +52,14 @@ class MockHandlerProfileTestCase(unittest.TestCase):
] ]
) )
self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
Mock()
)
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
self.addCleanup, self.addCleanup,
"test", "test",
@ -63,7 +71,7 @@ class MockHandlerProfileTestCase(unittest.TestCase):
) )
def _get_user_by_req(request=None, allow_guest=False): def _get_user_by_req(request=None, allow_guest=False):
return synapse.types.create_requester(myid) return defer.succeed(synapse.types.create_requester(myid))
hs.get_auth().get_user_by_req = _get_user_by_req hs.get_auth().get_user_by_req = _get_user_by_req

View File

@ -461,7 +461,9 @@ class MockHttpResource(HttpServer):
try: try:
args = [urlparse.unquote(u) for u in matcher.groups()] args = [urlparse.unquote(u) for u in matcher.groups()]
(code, response) = yield func(mock_request, *args) (code, response) = yield defer.ensureDeferred(
func(mock_request, *args)
)
return code, response return code, response
except CodeMessageException as e: except CodeMessageException as e:
return (e.code, cs_error(e.msg, code=e.errcode)) return (e.code, cs_error(e.msg, code=e.errcode))