mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Merge pull request #6482 from matrix-org/erikj/port_rest_v1
Port rest/v1 to async/await
This commit is contained in:
commit
af5d0ebc72
1
changelog.d/6482.misc
Normal file
1
changelog.d/6482.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Port synapse.rest.client.v1 to async/await.
|
@ -151,7 +151,7 @@ class SynchrotronPresence(object):
|
|||||||
|
|
||||||
def set_state(self, user, state, ignore_status_msg=False):
|
def set_state(self, user, state, ignore_status_msg=False):
|
||||||
# TODO Hows this supposed to work?
|
# TODO Hows this supposed to work?
|
||||||
pass
|
return defer.succeed(None)
|
||||||
|
|
||||||
get_states = __func__(PresenceHandler.get_states)
|
get_states = __func__(PresenceHandler.get_states)
|
||||||
get_state = __func__(PresenceHandler.get_state)
|
get_state = __func__(PresenceHandler.get_state)
|
||||||
|
@ -16,8 +16,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
Codes,
|
Codes,
|
||||||
@ -47,17 +45,15 @@ class ClientDirectoryServer(RestServlet):
|
|||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request, room_alias):
|
||||||
def on_GET(self, request, room_alias):
|
|
||||||
room_alias = RoomAlias.from_string(room_alias)
|
room_alias = RoomAlias.from_string(room_alias)
|
||||||
|
|
||||||
dir_handler = self.handlers.directory_handler
|
dir_handler = self.handlers.directory_handler
|
||||||
res = yield dir_handler.get_association(room_alias)
|
res = await dir_handler.get_association(room_alias)
|
||||||
|
|
||||||
return 200, res
|
return 200, res
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_PUT(self, request, room_alias):
|
||||||
def on_PUT(self, request, room_alias):
|
|
||||||
room_alias = RoomAlias.from_string(room_alias)
|
room_alias = RoomAlias.from_string(room_alias)
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
@ -77,26 +73,25 @@ class ClientDirectoryServer(RestServlet):
|
|||||||
|
|
||||||
# TODO(erikj): Check types.
|
# TODO(erikj): Check types.
|
||||||
|
|
||||||
room = yield self.store.get_room(room_id)
|
room = await self.store.get_room(room_id)
|
||||||
if room is None:
|
if room is None:
|
||||||
raise SynapseError(400, "Room does not exist")
|
raise SynapseError(400, "Room does not exist")
|
||||||
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
yield self.handlers.directory_handler.create_association(
|
await self.handlers.directory_handler.create_association(
|
||||||
requester, room_alias, room_id, servers
|
requester, room_alias, room_id, servers
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_DELETE(self, request, room_alias):
|
||||||
def on_DELETE(self, request, room_alias):
|
|
||||||
dir_handler = self.handlers.directory_handler
|
dir_handler = self.handlers.directory_handler
|
||||||
|
|
||||||
try:
|
try:
|
||||||
service = yield self.auth.get_appservice_by_req(request)
|
service = await self.auth.get_appservice_by_req(request)
|
||||||
room_alias = RoomAlias.from_string(room_alias)
|
room_alias = RoomAlias.from_string(room_alias)
|
||||||
yield dir_handler.delete_appservice_association(service, room_alias)
|
await dir_handler.delete_appservice_association(service, room_alias)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Application service at %s deleted alias %s",
|
"Application service at %s deleted alias %s",
|
||||||
service.url,
|
service.url,
|
||||||
@ -107,12 +102,12 @@ class ClientDirectoryServer(RestServlet):
|
|||||||
# fallback to default user behaviour if they aren't an AS
|
# fallback to default user behaviour if they aren't an AS
|
||||||
pass
|
pass
|
||||||
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
user = requester.user
|
user = requester.user
|
||||||
|
|
||||||
room_alias = RoomAlias.from_string(room_alias)
|
room_alias = RoomAlias.from_string(room_alias)
|
||||||
|
|
||||||
yield dir_handler.delete_association(requester, room_alias)
|
await dir_handler.delete_association(requester, room_alias)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
|
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
|
||||||
@ -130,32 +125,29 @@ class ClientDirectoryListServer(RestServlet):
|
|||||||
self.handlers = hs.get_handlers()
|
self.handlers = hs.get_handlers()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request, room_id):
|
||||||
def on_GET(self, request, room_id):
|
room = await self.store.get_room(room_id)
|
||||||
room = yield self.store.get_room(room_id)
|
|
||||||
if room is None:
|
if room is None:
|
||||||
raise NotFoundError("Unknown room")
|
raise NotFoundError("Unknown room")
|
||||||
|
|
||||||
return 200, {"visibility": "public" if room["is_public"] else "private"}
|
return 200, {"visibility": "public" if room["is_public"] else "private"}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_PUT(self, request, room_id):
|
||||||
def on_PUT(self, request, room_id):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
visibility = content.get("visibility", "public")
|
visibility = content.get("visibility", "public")
|
||||||
|
|
||||||
yield self.handlers.directory_handler.edit_published_room_list(
|
await self.handlers.directory_handler.edit_published_room_list(
|
||||||
requester, room_id, visibility
|
requester, room_id, visibility
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_DELETE(self, request, room_id):
|
||||||
def on_DELETE(self, request, room_id):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
|
|
||||||
yield self.handlers.directory_handler.edit_published_room_list(
|
await self.handlers.directory_handler.edit_published_room_list(
|
||||||
requester, room_id, "private"
|
requester, room_id, "private"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -181,15 +173,14 @@ class ClientAppserviceDirectoryListServer(RestServlet):
|
|||||||
def on_DELETE(self, request, network_id, room_id):
|
def on_DELETE(self, request, network_id, room_id):
|
||||||
return self._edit(request, network_id, room_id, "private")
|
return self._edit(request, network_id, room_id, "private")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _edit(self, request, network_id, room_id, visibility):
|
||||||
def _edit(self, request, network_id, room_id, visibility):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
if not requester.app_service:
|
if not requester.app_service:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, "Only appservices can edit the appservice published room list"
|
403, "Only appservices can edit the appservice published room list"
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.handlers.directory_handler.edit_published_appservice_room_list(
|
await self.handlers.directory_handler.edit_published_appservice_room_list(
|
||||||
requester.app_service.id, network_id, room_id, visibility
|
requester.app_service.id, network_id, room_id, visibility
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,8 +16,6 @@
|
|||||||
"""This module contains REST servlets to do with event streaming, /events."""
|
"""This module contains REST servlets to do with event streaming, /events."""
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.http.servlet import RestServlet
|
from synapse.http.servlet import RestServlet
|
||||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||||
@ -36,9 +34,8 @@ class EventStreamRestServlet(RestServlet):
|
|||||||
self.event_stream_handler = hs.get_event_stream_handler()
|
self.event_stream_handler = hs.get_event_stream_handler()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request):
|
||||||
def on_GET(self, request):
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
|
||||||
is_guest = requester.is_guest
|
is_guest = requester.is_guest
|
||||||
room_id = None
|
room_id = None
|
||||||
if is_guest:
|
if is_guest:
|
||||||
@ -57,7 +54,7 @@ class EventStreamRestServlet(RestServlet):
|
|||||||
|
|
||||||
as_client_event = b"raw" not in request.args
|
as_client_event = b"raw" not in request.args
|
||||||
|
|
||||||
chunk = yield self.event_stream_handler.get_stream(
|
chunk = await self.event_stream_handler.get_stream(
|
||||||
requester.user.to_string(),
|
requester.user.to_string(),
|
||||||
pagin_config,
|
pagin_config,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
@ -83,14 +80,13 @@ class EventRestServlet(RestServlet):
|
|||||||
self.event_handler = hs.get_event_handler()
|
self.event_handler = hs.get_event_handler()
|
||||||
self._event_serializer = hs.get_event_client_serializer()
|
self._event_serializer = hs.get_event_client_serializer()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request, event_id):
|
||||||
def on_GET(self, request, event_id):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
event = await self.event_handler.get_event(requester.user, None, event_id)
|
||||||
event = yield self.event_handler.get_event(requester.user, None, event_id)
|
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
if event:
|
if event:
|
||||||
event = yield self._event_serializer.serialize_event(event, time_now)
|
event = await self._event_serializer.serialize_event(event, time_now)
|
||||||
return 200, event
|
return 200, event
|
||||||
else:
|
else:
|
||||||
return 404, "Event not found."
|
return 404, "Event not found."
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.http.servlet import RestServlet, parse_boolean
|
from synapse.http.servlet import RestServlet, parse_boolean
|
||||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||||
@ -29,13 +28,12 @@ class InitialSyncRestServlet(RestServlet):
|
|||||||
self.initial_sync_handler = hs.get_initial_sync_handler()
|
self.initial_sync_handler = hs.get_initial_sync_handler()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request):
|
||||||
def on_GET(self, request):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
as_client_event = b"raw" not in request.args
|
as_client_event = b"raw" not in request.args
|
||||||
pagination_config = PaginationConfig.from_request(request)
|
pagination_config = PaginationConfig.from_request(request)
|
||||||
include_archived = parse_boolean(request, "archived", default=False)
|
include_archived = parse_boolean(request, "archived", default=False)
|
||||||
content = yield self.initial_sync_handler.snapshot_all_rooms(
|
content = await self.initial_sync_handler.snapshot_all_rooms(
|
||||||
user_id=requester.user.to_string(),
|
user_id=requester.user.to_string(),
|
||||||
pagin_config=pagination_config,
|
pagin_config=pagination_config,
|
||||||
as_client_event=as_client_event,
|
as_client_event=as_client_event,
|
||||||
|
@ -18,7 +18,6 @@ import xml.etree.ElementTree as ET
|
|||||||
|
|
||||||
from six.moves import urllib
|
from six.moves import urllib
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from twisted.web.client import PartialDownloadError
|
from twisted.web.client import PartialDownloadError
|
||||||
|
|
||||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||||
@ -130,8 +129,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
def on_OPTIONS(self, request):
|
def on_OPTIONS(self, request):
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_POST(self, request):
|
||||||
def on_POST(self, request):
|
|
||||||
self._address_ratelimiter.ratelimit(
|
self._address_ratelimiter.ratelimit(
|
||||||
request.getClientIP(),
|
request.getClientIP(),
|
||||||
time_now_s=self.hs.clock.time(),
|
time_now_s=self.hs.clock.time(),
|
||||||
@ -145,11 +143,11 @@ class LoginRestServlet(RestServlet):
|
|||||||
if self.jwt_enabled and (
|
if self.jwt_enabled and (
|
||||||
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||||
):
|
):
|
||||||
result = yield self.do_jwt_login(login_submission)
|
result = await self.do_jwt_login(login_submission)
|
||||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||||
result = yield self.do_token_login(login_submission)
|
result = await self.do_token_login(login_submission)
|
||||||
else:
|
else:
|
||||||
result = yield self._do_other_login(login_submission)
|
result = await self._do_other_login(login_submission)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise SynapseError(400, "Missing JSON keys.")
|
raise SynapseError(400, "Missing JSON keys.")
|
||||||
|
|
||||||
@ -158,8 +156,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
result["well_known"] = well_known_data
|
result["well_known"] = well_known_data
|
||||||
return 200, result
|
return 200, result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _do_other_login(self, login_submission):
|
||||||
def _do_other_login(self, login_submission):
|
|
||||||
"""Handle non-token/saml/jwt logins
|
"""Handle non-token/saml/jwt logins
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -219,20 +216,20 @@ class LoginRestServlet(RestServlet):
|
|||||||
(
|
(
|
||||||
canonical_user_id,
|
canonical_user_id,
|
||||||
callback_3pid,
|
callback_3pid,
|
||||||
) = yield self.auth_handler.check_password_provider_3pid(
|
) = await self.auth_handler.check_password_provider_3pid(
|
||||||
medium, address, login_submission["password"]
|
medium, address, login_submission["password"]
|
||||||
)
|
)
|
||||||
if canonical_user_id:
|
if canonical_user_id:
|
||||||
# Authentication through password provider and 3pid succeeded
|
# Authentication through password provider and 3pid succeeded
|
||||||
|
|
||||||
result = yield self._complete_login(
|
result = await self._complete_login(
|
||||||
canonical_user_id, login_submission, callback_3pid
|
canonical_user_id, login_submission, callback_3pid
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# No password providers were able to handle this 3pid
|
# No password providers were able to handle this 3pid
|
||||||
# Check local store
|
# Check local store
|
||||||
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
medium, address
|
medium, address
|
||||||
)
|
)
|
||||||
if not user_id:
|
if not user_id:
|
||||||
@ -280,7 +277,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
canonical_user_id, callback = yield self.auth_handler.validate_login(
|
canonical_user_id, callback = await self.auth_handler.validate_login(
|
||||||
identifier["user"], login_submission
|
identifier["user"], login_submission
|
||||||
)
|
)
|
||||||
except LoginError:
|
except LoginError:
|
||||||
@ -297,13 +294,12 @@ class LoginRestServlet(RestServlet):
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
result = yield self._complete_login(
|
result = await self._complete_login(
|
||||||
canonical_user_id, login_submission, callback
|
canonical_user_id, login_submission, callback
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _complete_login(
|
||||||
def _complete_login(
|
|
||||||
self, user_id, login_submission, callback=None, create_non_existant_users=False
|
self, user_id, login_submission, callback=None, create_non_existant_users=False
|
||||||
):
|
):
|
||||||
"""Called when we've successfully authed the user and now need to
|
"""Called when we've successfully authed the user and now need to
|
||||||
@ -337,15 +333,15 @@ class LoginRestServlet(RestServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if create_non_existant_users:
|
if create_non_existant_users:
|
||||||
user_id = yield self.auth_handler.check_user_exists(user_id)
|
user_id = await self.auth_handler.check_user_exists(user_id)
|
||||||
if not user_id:
|
if not user_id:
|
||||||
user_id = yield self.registration_handler.register_user(
|
user_id = await self.registration_handler.register_user(
|
||||||
localpart=UserID.from_string(user_id).localpart
|
localpart=UserID.from_string(user_id).localpart
|
||||||
)
|
)
|
||||||
|
|
||||||
device_id = login_submission.get("device_id")
|
device_id = login_submission.get("device_id")
|
||||||
initial_display_name = login_submission.get("initial_device_display_name")
|
initial_display_name = login_submission.get("initial_device_display_name")
|
||||||
device_id, access_token = yield self.registration_handler.register_device(
|
device_id, access_token = await self.registration_handler.register_device(
|
||||||
user_id, device_id, initial_display_name
|
user_id, device_id, initial_display_name
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -357,23 +353,21 @@ class LoginRestServlet(RestServlet):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
yield callback(result)
|
await callback(result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def do_token_login(self, login_submission):
|
||||||
def do_token_login(self, login_submission):
|
|
||||||
token = login_submission["token"]
|
token = login_submission["token"]
|
||||||
auth_handler = self.auth_handler
|
auth_handler = self.auth_handler
|
||||||
user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id(
|
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||||
token
|
token
|
||||||
)
|
)
|
||||||
|
|
||||||
result = yield self._complete_login(user_id, login_submission)
|
result = await self._complete_login(user_id, login_submission)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def do_jwt_login(self, login_submission):
|
||||||
def do_jwt_login(self, login_submission):
|
|
||||||
token = login_submission.get("token", None)
|
token = login_submission.get("token", None)
|
||||||
if token is None:
|
if token is None:
|
||||||
raise LoginError(
|
raise LoginError(
|
||||||
@ -397,7 +391,7 @@ class LoginRestServlet(RestServlet):
|
|||||||
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
user_id = UserID(user, self.hs.hostname).to_string()
|
user_id = UserID(user, self.hs.hostname).to_string()
|
||||||
result = yield self._complete_login(
|
result = await self._complete_login(
|
||||||
user_id, login_submission, create_non_existant_users=True
|
user_id, login_submission, create_non_existant_users=True
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
@ -460,8 +454,7 @@ class CasTicketServlet(RestServlet):
|
|||||||
self._sso_auth_handler = SSOAuthHandler(hs)
|
self._sso_auth_handler = SSOAuthHandler(hs)
|
||||||
self._http_client = hs.get_proxied_http_client()
|
self._http_client = hs.get_proxied_http_client()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request):
|
||||||
def on_GET(self, request):
|
|
||||||
client_redirect_url = parse_string(request, "redirectUrl", required=True)
|
client_redirect_url = parse_string(request, "redirectUrl", required=True)
|
||||||
uri = self.cas_server_url + "/proxyValidate"
|
uri = self.cas_server_url + "/proxyValidate"
|
||||||
args = {
|
args = {
|
||||||
@ -469,12 +462,12 @@ class CasTicketServlet(RestServlet):
|
|||||||
"service": self.cas_service_url,
|
"service": self.cas_service_url,
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
body = yield self._http_client.get_raw(uri, args)
|
body = await self._http_client.get_raw(uri, args)
|
||||||
except PartialDownloadError as pde:
|
except PartialDownloadError as pde:
|
||||||
# Twisted raises this error if the connection is closed,
|
# Twisted raises this error if the connection is closed,
|
||||||
# even if that's being used old-http style to signal end-of-data
|
# even if that's being used old-http style to signal end-of-data
|
||||||
body = pde.response
|
body = pde.response
|
||||||
result = yield self.handle_cas_response(request, body, client_redirect_url)
|
result = await self.handle_cas_response(request, body, client_redirect_url)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
|
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
|
||||||
@ -555,8 +548,7 @@ class SSOAuthHandler(object):
|
|||||||
self._registration_handler = hs.get_registration_handler()
|
self._registration_handler = hs.get_registration_handler()
|
||||||
self._macaroon_gen = hs.get_macaroon_generator()
|
self._macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_successful_auth(
|
||||||
def on_successful_auth(
|
|
||||||
self, username, request, client_redirect_url, user_display_name=None
|
self, username, request, client_redirect_url, user_display_name=None
|
||||||
):
|
):
|
||||||
"""Called once the user has successfully authenticated with the SSO.
|
"""Called once the user has successfully authenticated with the SSO.
|
||||||
@ -582,9 +574,9 @@ class SSOAuthHandler(object):
|
|||||||
"""
|
"""
|
||||||
localpart = map_username_to_mxid_localpart(username)
|
localpart = map_username_to_mxid_localpart(username)
|
||||||
user_id = UserID(localpart, self._hostname).to_string()
|
user_id = UserID(localpart, self._hostname).to_string()
|
||||||
registered_user_id = yield self._auth_handler.check_user_exists(user_id)
|
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
||||||
if not registered_user_id:
|
if not registered_user_id:
|
||||||
registered_user_id = yield self._registration_handler.register_user(
|
registered_user_id = await self._registration_handler.register_user(
|
||||||
localpart=localpart, default_display_name=user_display_name
|
localpart=localpart, default_display_name=user_display_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,8 +15,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.http.servlet import RestServlet
|
from synapse.http.servlet import RestServlet
|
||||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||||
|
|
||||||
@ -35,17 +33,16 @@ class LogoutRestServlet(RestServlet):
|
|||||||
def on_OPTIONS(self, request):
|
def on_OPTIONS(self, request):
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_POST(self, request):
|
||||||
def on_POST(self, request):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
|
|
||||||
if requester.device_id is None:
|
if requester.device_id is None:
|
||||||
# the acccess token wasn't associated with a device.
|
# the acccess token wasn't associated with a device.
|
||||||
# Just delete the access token
|
# Just delete the access token
|
||||||
access_token = self.auth.get_access_token_from_request(request)
|
access_token = self.auth.get_access_token_from_request(request)
|
||||||
yield self._auth_handler.delete_access_token(access_token)
|
await self._auth_handler.delete_access_token(access_token)
|
||||||
else:
|
else:
|
||||||
yield self._device_handler.delete_device(
|
await self._device_handler.delete_device(
|
||||||
requester.user.to_string(), requester.device_id
|
requester.user.to_string(), requester.device_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -64,17 +61,16 @@ class LogoutAllRestServlet(RestServlet):
|
|||||||
def on_OPTIONS(self, request):
|
def on_OPTIONS(self, request):
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_POST(self, request):
|
||||||
def on_POST(self, request):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
# first delete all of the user's devices
|
# first delete all of the user's devices
|
||||||
yield self._device_handler.delete_all_devices_for_user(user_id)
|
await self._device_handler.delete_all_devices_for_user(user_id)
|
||||||
|
|
||||||
# .. and then delete any access tokens which weren't associated with
|
# .. and then delete any access tokens which weren't associated with
|
||||||
# devices.
|
# devices.
|
||||||
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
await self._auth_handler.delete_access_tokens_for_user(user_id)
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,8 +19,6 @@ import logging
|
|||||||
|
|
||||||
from six import string_types
|
from six import string_types
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import AuthError, SynapseError
|
from synapse.api.errors import AuthError, SynapseError
|
||||||
from synapse.handlers.presence import format_user_presence_state
|
from synapse.handlers.presence import format_user_presence_state
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
@ -40,27 +38,25 @@ class PresenceStatusRestServlet(RestServlet):
|
|||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request, user_id):
|
||||||
def on_GET(self, request, user_id):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
if requester.user != user:
|
if requester.user != user:
|
||||||
allowed = yield self.presence_handler.is_visible(
|
allowed = await self.presence_handler.is_visible(
|
||||||
observed_user=user, observer_user=requester.user
|
observed_user=user, observer_user=requester.user
|
||||||
)
|
)
|
||||||
|
|
||||||
if not allowed:
|
if not allowed:
|
||||||
raise AuthError(403, "You are not allowed to see their presence.")
|
raise AuthError(403, "You are not allowed to see their presence.")
|
||||||
|
|
||||||
state = yield self.presence_handler.get_state(target_user=user)
|
state = await self.presence_handler.get_state(target_user=user)
|
||||||
state = format_user_presence_state(state, self.clock.time_msec())
|
state = format_user_presence_state(state, self.clock.time_msec())
|
||||||
|
|
||||||
return 200, state
|
return 200, state
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_PUT(self, request, user_id):
|
||||||
def on_PUT(self, request, user_id):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
if requester.user != user:
|
if requester.user != user:
|
||||||
@ -86,7 +82,7 @@ class PresenceStatusRestServlet(RestServlet):
|
|||||||
raise SynapseError(400, "Unable to parse state")
|
raise SynapseError(400, "Unable to parse state")
|
||||||
|
|
||||||
if self.hs.config.use_presence:
|
if self.hs.config.use_presence:
|
||||||
yield self.presence_handler.set_state(user, state)
|
await self.presence_handler.set_state(user, state)
|
||||||
|
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
|
@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
""" This module contains REST servlets to do with profile: /profile/<paths> """
|
""" This module contains REST servlets to do with profile: /profile/<paths> """
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||||
@ -30,19 +29,18 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
|||||||
self.profile_handler = hs.get_profile_handler()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request, user_id):
|
||||||
def on_GET(self, request, user_id):
|
|
||||||
requester_user = None
|
requester_user = None
|
||||||
|
|
||||||
if self.hs.config.require_auth_for_profile_requests:
|
if self.hs.config.require_auth_for_profile_requests:
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user = requester.user
|
requester_user = requester.user
|
||||||
|
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
yield self.profile_handler.check_profile_query_allowed(user, requester_user)
|
await self.profile_handler.check_profile_query_allowed(user, requester_user)
|
||||||
|
|
||||||
displayname = yield self.profile_handler.get_displayname(user)
|
displayname = await self.profile_handler.get_displayname(user)
|
||||||
|
|
||||||
ret = {}
|
ret = {}
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
@ -50,11 +48,10 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
|||||||
|
|
||||||
return 200, ret
|
return 200, ret
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_PUT(self, request, user_id):
|
||||||
def on_PUT(self, request, user_id):
|
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
is_admin = await self.auth.is_server_admin(requester.user)
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
|
|
||||||
@ -63,7 +60,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
|||||||
except Exception:
|
except Exception:
|
||||||
return 400, "Unable to parse name"
|
return 400, "Unable to parse name"
|
||||||
|
|
||||||
yield self.profile_handler.set_displayname(user, requester, new_name, is_admin)
|
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
|
||||||
|
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
@ -80,19 +77,18 @@ class ProfileAvatarURLRestServlet(RestServlet):
|
|||||||
self.profile_handler = hs.get_profile_handler()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request, user_id):
|
||||||
def on_GET(self, request, user_id):
|
|
||||||
requester_user = None
|
requester_user = None
|
||||||
|
|
||||||
if self.hs.config.require_auth_for_profile_requests:
|
if self.hs.config.require_auth_for_profile_requests:
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user = requester.user
|
requester_user = requester.user
|
||||||
|
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
yield self.profile_handler.check_profile_query_allowed(user, requester_user)
|
await self.profile_handler.check_profile_query_allowed(user, requester_user)
|
||||||
|
|
||||||
avatar_url = yield self.profile_handler.get_avatar_url(user)
|
avatar_url = await self.profile_handler.get_avatar_url(user)
|
||||||
|
|
||||||
ret = {}
|
ret = {}
|
||||||
if avatar_url is not None:
|
if avatar_url is not None:
|
||||||
@ -100,11 +96,10 @@ class ProfileAvatarURLRestServlet(RestServlet):
|
|||||||
|
|
||||||
return 200, ret
|
return 200, ret
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_PUT(self, request, user_id):
|
||||||
def on_PUT(self, request, user_id):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
is_admin = await self.auth.is_server_admin(requester.user)
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
try:
|
try:
|
||||||
@ -112,7 +107,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
|
|||||||
except Exception:
|
except Exception:
|
||||||
return 400, "Unable to parse name"
|
return 400, "Unable to parse name"
|
||||||
|
|
||||||
yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
|
await self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
|
||||||
|
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
@ -129,20 +124,19 @@ class ProfileRestServlet(RestServlet):
|
|||||||
self.profile_handler = hs.get_profile_handler()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request, user_id):
|
||||||
def on_GET(self, request, user_id):
|
|
||||||
requester_user = None
|
requester_user = None
|
||||||
|
|
||||||
if self.hs.config.require_auth_for_profile_requests:
|
if self.hs.config.require_auth_for_profile_requests:
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester_user = requester.user
|
requester_user = requester.user
|
||||||
|
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
yield self.profile_handler.check_profile_query_allowed(user, requester_user)
|
await self.profile_handler.check_profile_query_allowed(user, requester_user)
|
||||||
|
|
||||||
displayname = yield self.profile_handler.get_displayname(user)
|
displayname = await self.profile_handler.get_displayname(user)
|
||||||
avatar_url = yield self.profile_handler.get_avatar_url(user)
|
avatar_url = await self.profile_handler.get_avatar_url(user)
|
||||||
|
|
||||||
ret = {}
|
ret = {}
|
||||||
if displayname is not None:
|
if displayname is not None:
|
||||||
|
@ -13,7 +13,6 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
NotFoundError,
|
NotFoundError,
|
||||||
@ -46,8 +45,7 @@ class PushRuleRestServlet(RestServlet):
|
|||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self._is_worker = hs.config.worker_app is not None
|
self._is_worker = hs.config.worker_app is not None
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_PUT(self, request, path):
|
||||||
def on_PUT(self, request, path):
|
|
||||||
if self._is_worker:
|
if self._is_worker:
|
||||||
raise Exception("Cannot handle PUT /push_rules on worker")
|
raise Exception("Cannot handle PUT /push_rules on worker")
|
||||||
|
|
||||||
@ -57,7 +55,7 @@ class PushRuleRestServlet(RestServlet):
|
|||||||
except InvalidRuleException as e:
|
except InvalidRuleException as e:
|
||||||
raise SynapseError(400, str(e))
|
raise SynapseError(400, str(e))
|
||||||
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
|
if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
|
||||||
raise SynapseError(400, "rule_id may not contain slashes")
|
raise SynapseError(400, "rule_id may not contain slashes")
|
||||||
@ -67,7 +65,7 @@ class PushRuleRestServlet(RestServlet):
|
|||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
if "attr" in spec:
|
if "attr" in spec:
|
||||||
yield self.set_rule_attr(user_id, spec, content)
|
await self.set_rule_attr(user_id, spec, content)
|
||||||
self.notify_user(user_id)
|
self.notify_user(user_id)
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
@ -91,7 +89,7 @@ class PushRuleRestServlet(RestServlet):
|
|||||||
after = _namespaced_rule_id(spec, after)
|
after = _namespaced_rule_id(spec, after)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.store.add_push_rule(
|
await self.store.add_push_rule(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
rule_id=_namespaced_rule_id_from_spec(spec),
|
rule_id=_namespaced_rule_id_from_spec(spec),
|
||||||
priority_class=priority_class,
|
priority_class=priority_class,
|
||||||
@ -108,20 +106,19 @@ class PushRuleRestServlet(RestServlet):
|
|||||||
|
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_DELETE(self, request, path):
|
||||||
def on_DELETE(self, request, path):
|
|
||||||
if self._is_worker:
|
if self._is_worker:
|
||||||
raise Exception("Cannot handle DELETE /push_rules on worker")
|
raise Exception("Cannot handle DELETE /push_rules on worker")
|
||||||
|
|
||||||
spec = _rule_spec_from_path([x for x in path.split("/")])
|
spec = _rule_spec_from_path([x for x in path.split("/")])
|
||||||
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.store.delete_push_rule(user_id, namespaced_rule_id)
|
await self.store.delete_push_rule(user_id, namespaced_rule_id)
|
||||||
self.notify_user(user_id)
|
self.notify_user(user_id)
|
||||||
return 200, {}
|
return 200, {}
|
||||||
except StoreError as e:
|
except StoreError as e:
|
||||||
@ -130,15 +127,14 @@ class PushRuleRestServlet(RestServlet):
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request, path):
|
||||||
def on_GET(self, request, path):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
# we build up the full structure and then decide which bits of it
|
# we build up the full structure and then decide which bits of it
|
||||||
# to send which means doing unnecessary work sometimes but is
|
# to send which means doing unnecessary work sometimes but is
|
||||||
# is probably not going to make a whole lot of difference
|
# is probably not going to make a whole lot of difference
|
||||||
rules = yield self.store.get_push_rules_for_user(user_id)
|
rules = await self.store.get_push_rules_for_user(user_id)
|
||||||
|
|
||||||
rules = format_push_rules_for_user(requester.user, rules)
|
rules = format_push_rules_for_user(requester.user, rules)
|
||||||
|
|
||||||
|
@ -15,8 +15,6 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||||
from synapse.http.server import finish_request
|
from synapse.http.server import finish_request
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
@ -39,12 +37,11 @@ class PushersRestServlet(RestServlet):
|
|||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request):
|
||||||
def on_GET(self, request):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
user = requester.user
|
user = requester.user
|
||||||
|
|
||||||
pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
|
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
|
||||||
|
|
||||||
allowed_keys = [
|
allowed_keys = [
|
||||||
"app_display_name",
|
"app_display_name",
|
||||||
@ -78,9 +75,8 @@ class PushersSetRestServlet(RestServlet):
|
|||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.pusher_pool = self.hs.get_pusherpool()
|
self.pusher_pool = self.hs.get_pusherpool()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_POST(self, request):
|
||||||
def on_POST(self, request):
|
requester = await self.auth.get_user_by_req(request)
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
user = requester.user
|
user = requester.user
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
@ -91,7 +87,7 @@ class PushersSetRestServlet(RestServlet):
|
|||||||
and "kind" in content
|
and "kind" in content
|
||||||
and content["kind"] is None
|
and content["kind"] is None
|
||||||
):
|
):
|
||||||
yield self.pusher_pool.remove_pusher(
|
await self.pusher_pool.remove_pusher(
|
||||||
content["app_id"], content["pushkey"], user_id=user.to_string()
|
content["app_id"], content["pushkey"], user_id=user.to_string()
|
||||||
)
|
)
|
||||||
return 200, {}
|
return 200, {}
|
||||||
@ -117,14 +113,14 @@ class PushersSetRestServlet(RestServlet):
|
|||||||
append = content["append"]
|
append = content["append"]
|
||||||
|
|
||||||
if not append:
|
if not append:
|
||||||
yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
|
await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
|
||||||
app_id=content["app_id"],
|
app_id=content["app_id"],
|
||||||
pushkey=content["pushkey"],
|
pushkey=content["pushkey"],
|
||||||
not_user_id=user.to_string(),
|
not_user_id=user.to_string(),
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.pusher_pool.add_pusher(
|
await self.pusher_pool.add_pusher(
|
||||||
user_id=user.to_string(),
|
user_id=user.to_string(),
|
||||||
access_token=requester.access_token_id,
|
access_token=requester.access_token_id,
|
||||||
kind=content["kind"],
|
kind=content["kind"],
|
||||||
@ -164,16 +160,15 @@ class PushersRemoveRestServlet(RestServlet):
|
|||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.pusher_pool = self.hs.get_pusherpool()
|
self.pusher_pool = self.hs.get_pusherpool()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request):
|
||||||
def on_GET(self, request):
|
requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
|
||||||
requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
|
|
||||||
user = requester.user
|
user = requester.user
|
||||||
|
|
||||||
app_id = parse_string(request, "app_id", required=True)
|
app_id = parse_string(request, "app_id", required=True)
|
||||||
pushkey = parse_string(request, "pushkey", required=True)
|
pushkey = parse_string(request, "pushkey", required=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield self.pusher_pool.remove_pusher(
|
await self.pusher_pool.remove_pusher(
|
||||||
app_id=app_id, pushkey=pushkey, user_id=user.to_string()
|
app_id=app_id, pushkey=pushkey, user_id=user.to_string()
|
||||||
)
|
)
|
||||||
except StoreError as se:
|
except StoreError as se:
|
||||||
|
@ -17,8 +17,6 @@ import base64
|
|||||||
import hashlib
|
import hashlib
|
||||||
import hmac
|
import hmac
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.http.servlet import RestServlet
|
from synapse.http.servlet import RestServlet
|
||||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||||
|
|
||||||
@ -31,9 +29,8 @@ class VoipRestServlet(RestServlet):
|
|||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def on_GET(self, request):
|
||||||
def on_GET(self, request):
|
requester = await self.auth.get_user_by_req(
|
||||||
requester = yield self.auth.get_user_by_req(
|
|
||||||
request, self.hs.config.turn_allow_guests
|
request, self.hs.config.turn_allow_guests
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,6 +15,8 @@
|
|||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.rest.client.v1 import presence
|
from synapse.rest.client.v1 import presence
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
@ -36,6 +38,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
hs.presence_handler = Mock()
|
hs.presence_handler = Mock()
|
||||||
|
hs.presence_handler.set_state.return_value = defer.succeed(None)
|
||||||
|
|
||||||
return hs
|
return hs
|
||||||
|
|
||||||
|
@ -52,6 +52,14 @@ class MockHandlerProfileTestCase(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.mock_handler.get_displayname.return_value = defer.succeed(Mock())
|
||||||
|
self.mock_handler.set_displayname.return_value = defer.succeed(Mock())
|
||||||
|
self.mock_handler.get_avatar_url.return_value = defer.succeed(Mock())
|
||||||
|
self.mock_handler.set_avatar_url.return_value = defer.succeed(Mock())
|
||||||
|
self.mock_handler.check_profile_query_allowed.return_value = defer.succeed(
|
||||||
|
Mock()
|
||||||
|
)
|
||||||
|
|
||||||
hs = yield setup_test_homeserver(
|
hs = yield setup_test_homeserver(
|
||||||
self.addCleanup,
|
self.addCleanup,
|
||||||
"test",
|
"test",
|
||||||
@ -63,7 +71,7 @@ class MockHandlerProfileTestCase(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _get_user_by_req(request=None, allow_guest=False):
|
def _get_user_by_req(request=None, allow_guest=False):
|
||||||
return synapse.types.create_requester(myid)
|
return defer.succeed(synapse.types.create_requester(myid))
|
||||||
|
|
||||||
hs.get_auth().get_user_by_req = _get_user_by_req
|
hs.get_auth().get_user_by_req = _get_user_by_req
|
||||||
|
|
||||||
|
@ -461,7 +461,9 @@ class MockHttpResource(HttpServer):
|
|||||||
try:
|
try:
|
||||||
args = [urlparse.unquote(u) for u in matcher.groups()]
|
args = [urlparse.unquote(u) for u in matcher.groups()]
|
||||||
|
|
||||||
(code, response) = yield func(mock_request, *args)
|
(code, response) = yield defer.ensureDeferred(
|
||||||
|
func(mock_request, *args)
|
||||||
|
)
|
||||||
return code, response
|
return code, response
|
||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
return (e.code, cs_error(e.msg, code=e.errcode))
|
return (e.code, cs_error(e.msg, code=e.errcode))
|
||||||
|
Loading…
Reference in New Issue
Block a user