Move rest APIs back under the rest directory

This commit is contained in:
Mark Haines 2015-01-22 16:10:07 +00:00
parent 1d2016b4a8
commit 97c68c508d
31 changed files with 33 additions and 19 deletions

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import (
room, events, register, login, profile, presence, initial_sync, directory,
voip, admin,
)
class RestServletFactory(object):
""" A factory for creating REST servlets.
These REST servlets represent the entire client-server REST API. Generally
speaking, they serve as wrappers around events and the handlers that
process them.
See synapse.events for information on synapse events.
"""
def __init__(self, hs):
client_resource = hs.get_resource_for_client()
# TODO(erikj): There *must* be a better way of doing this.
room.register_servlets(hs, client_resource)
events.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)
initial_sync.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource)

View file

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from base import RestServlet, client_path_pattern
import logging
logger = logging.getLogger(__name__)
class WhoisRestServlet(RestServlet):
PATTERN = client_path_pattern("/admin/whois/(?P<user_id>[^/]*)")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = self.hs.parse_userid(user_id)
auth_user = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user:
raise AuthError(403, "You are not a server admin")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only whois a local user")
ret = yield self.handlers.admin_handler.get_whois(target_user)
defer.returnValue((200, ret))
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)

View file

@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
from synapse.api.urls import CLIENT_PREFIX
from .transactions import HttpTransactionStore
import re
import logging
logger = logging.getLogger(__name__)
def client_path_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
return re.compile("^" + CLIENT_PREFIX + path_regex)
class RestServlet(object):
""" A Synapse REST Servlet.
An implementing class can either provide its own custom 'register' method,
or use the automatic pattern handling provided by the base class.
To use this latter, the implementing class instead provides a `PATTERN`
class attribute containing a pre-compiled regular expression. The automatic
register method will then use this method to register any of the following
instance methods associated with the corresponding HTTP method:
on_GET
on_PUT
on_POST
on_DELETE
on_OPTIONS
Automatically handles turning CodeMessageExceptions thrown by these methods
into the appropriate HTTP response.
"""
def __init__(self, hs):
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionStore()
def register(self, http_server):
""" Register this servlet with the given HTTP server. """
if hasattr(self, "PATTERN"):
pattern = self.PATTERN
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method)):
method_handler = getattr(self, "on_%s" % (method))
http_server.register_path(method, pattern, method_handler)
else:
raise NotImplementedError("RestServlet must register something.")

View file

@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError, Codes
from base import RestServlet, client_path_pattern
import json
import logging
logger = logging.getLogger(__name__)
def register_servlets(hs, http_server):
ClientDirectoryServer(hs).register(http_server)
class ClientDirectoryServer(RestServlet):
PATTERN = client_path_pattern("/directory/room/(?P<room_alias>[^/]*)$")
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
room_alias = self.hs.parse_roomalias(room_alias)
dir_handler = self.handlers.directory_handler
res = yield dir_handler.get_association(room_alias)
defer.returnValue((200, res))
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
if not "room_id" in content:
raise SynapseError(400, "Missing room_id key",
errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content)
room_alias = self.hs.parse_roomalias(room_alias)
logger.debug("Got room name: %s", room_alias.to_string())
room_id = content["room_id"]
servers = content["servers"] if "servers" in content else None
logger.debug("Got room_id: %s", room_id)
logger.debug("Got servers: %s", servers)
# TODO(erikj): Check types.
# TODO(erikj): Check that room exists
dir_handler = self.handlers.directory_handler
try:
user_id = user.to_string()
yield dir_handler.create_association(
user_id, room_alias, room_id, servers
)
yield dir_handler.send_room_alias_update_event(user_id, room_id)
except SynapseError as e:
raise e
except:
logger.exception("Failed to create association")
raise
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_DELETE(self, request, room_alias):
user = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user)
if not is_admin:
raise AuthError(403, "You need to be a server admin")
dir_handler = self.handlers.directory_handler
room_alias = self.hs.parse_roomalias(room_alias)
yield dir_handler.delete_association(
user.to_string(), room_alias
)
defer.returnValue((200, {}))
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

View file

@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains REST servlets to do with event streaming, /events."""
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig
from .base import RestServlet, client_path_pattern
import logging
logger = logging.getLogger(__name__)
class EventStreamRestServlet(RestServlet):
PATTERN = client_path_pattern("/events$")
DEFAULT_LONGPOLL_TIME_MS = 30000
@defer.inlineCallbacks
def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request)
try:
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
try:
timeout = int(request.args["timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")
as_client_event = "raw" not in request.args
chunk = yield handler.get_stream(
auth_user.to_string(), pagin_config, timeout=timeout,
as_client_event=as_client_event
)
except:
logger.exception("Event stream failed")
raise
defer.returnValue((200, chunk))
def on_OPTIONS(self, request):
return (200, {})
# TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(RestServlet):
PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$")
@defer.inlineCallbacks
def on_GET(self, request, event_id):
auth_user = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id)
if event:
defer.returnValue((200, self.hs.serialize_event(event)))
else:
defer.returnValue((404, "Event not found."))
def register_servlets(hs, http_server):
EventStreamRestServlet(hs).register(http_server)
EventRestServlet(hs).register(http_server)

View file

@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from base import RestServlet, client_path_pattern
# TODO: Needs unit testing
class InitialSyncRestServlet(RestServlet):
PATTERN = client_path_pattern("/initialSync$")
@defer.inlineCallbacks
def on_GET(self, request):
user = yield self.auth.get_user_by_req(request)
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler
content = yield handler.snapshot_all_rooms(
user_id=user.to_string(),
pagin_config=pagination_config,
feedback=with_feedback,
as_client_event=as_client_event
)
defer.returnValue((200, content))
def register_servlets(hs, http_server):
InitialSyncRestServlet(hs).register(http_server)

View file

@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import UserID
from base import RestServlet, client_path_pattern
import json
class LoginRestServlet(RestServlet):
PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password"
def on_GET(self, request):
return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]})
def on_OPTIONS(self, request):
return (200, {})
@defer.inlineCallbacks
def on_POST(self, request):
login_submission = _parse_json(request)
try:
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
result = yield self.do_password_login(login_submission)
defer.returnValue(result)
else:
raise SynapseError(400, "Bad login type.")
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@defer.inlineCallbacks
def do_password_login(self, login_submission):
if not login_submission["user"].startswith('@'):
login_submission["user"] = UserID.create(
login_submission["user"], self.hs.hostname).to_string()
handler = self.handlers.login_handler
token = yield handler.login(
user=login_submission["user"],
password=login_submission["password"])
result = {
"user_id": login_submission["user"], # may have changed
"access_token": token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
class LoginFallbackRestServlet(RestServlet):
PATTERN = client_path_pattern("/login/fallback$")
def on_GET(self, request):
# TODO(kegan): This should be returning some HTML which is capable of
# hitting LoginRestServlet
return (200, {})
class PasswordResetRestServlet(RestServlet):
PATTERN = client_path_pattern("/login/reset")
@defer.inlineCallbacks
def on_POST(self, request):
reset_info = _parse_json(request)
try:
email = reset_info["email"]
user_id = reset_info["user_id"]
handler = self.handlers.login_handler
yield handler.reset_password(user_id, email)
# purposefully give no feedback to avoid people hammering different
# combinations.
defer.returnValue((200, {}))
except KeyError:
raise SynapseError(
400,
"Missing keys. Requires 'email' and 'user_id'."
)
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError:
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)

View file

@ -0,0 +1,145 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains REST servlets to do with presence: /presence/<paths>
"""
from twisted.internet import defer
from synapse.api.errors import SynapseError
from base import RestServlet, client_path_pattern
import json
import logging
logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(RestServlet):
PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
state = yield self.handlers.presence_handler.get_state(
target_user=user, auth_user=auth_user)
defer.returnValue((200, state))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
state = {}
try:
content = json.loads(request.content.read())
state["presence"] = content.pop("presence")
if "status_msg" in content:
state["status_msg"] = content.pop("status_msg")
if not isinstance(state["status_msg"], basestring):
raise SynapseError(400, "status_msg must be a string.")
if content:
raise KeyError()
except SynapseError as e:
raise e
except:
raise SynapseError(400, "Unable to parse state")
yield self.handlers.presence_handler.set_state(
target_user=user, auth_user=auth_user, state=state)
defer.returnValue((200, {}))
def on_OPTIONS(self, request):
return (200, {})
class PresenceListRestServlet(RestServlet):
PATTERN = client_path_pattern("/presence/list/(?P<user_id>[^/]*)")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
if auth_user != user:
raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.handlers.presence_handler.get_presence_list(
observer_user=user, accepted=True)
for p in presence:
observed_user = p.pop("observed_user")
p["user_id"] = observed_user.to_string()
defer.returnValue((200, presence))
@defer.inlineCallbacks
def on_POST(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
if auth_user != user:
raise SynapseError(
400, "Cannot modify another user's presence list")
try:
content = json.loads(request.content.read())
except:
logger.exception("JSON parse error")
raise SynapseError(400, "Unable to parse content")
if "invite" in content:
for u in content["invite"]:
if not isinstance(u, basestring):
raise SynapseError(400, "Bad invite value.")
if len(u) == 0:
continue
invited_user = self.hs.parse_userid(u)
yield self.handlers.presence_handler.send_invite(
observer_user=user, observed_user=invited_user
)
if "drop" in content:
for u in content["drop"]:
if not isinstance(u, basestring):
raise SynapseError(400, "Bad drop value.")
if len(u) == 0:
continue
dropped_user = self.hs.parse_userid(u)
yield self.handlers.presence_handler.drop(
observer_user=user, observed_user=dropped_user
)
defer.returnValue((200, {}))
def on_OPTIONS(self, request):
return (200, {})
def register_servlets(hs, http_server):
PresenceStatusRestServlet(hs).register(http_server)
PresenceListRestServlet(hs).register(http_server)

View file

@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer
from base import RestServlet, client_path_pattern
import json
class ProfileDisplaynameRestServlet(RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = self.hs.parse_userid(user_id)
displayname = yield self.handlers.profile_handler.get_displayname(
user,
)
defer.returnValue((200, {"displayname": displayname}))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
try:
content = json.loads(request.content.read())
new_name = content["displayname"]
except:
defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname(
user, auth_user, new_name)
defer.returnValue((200, {}))
def on_OPTIONS(self, request, user_id):
return (200, {})
class ProfileAvatarURLRestServlet(RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = self.hs.parse_userid(user_id)
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
user,
)
defer.returnValue((200, {"avatar_url": avatar_url}))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
try:
content = json.loads(request.content.read())
new_name = content["avatar_url"]
except:
defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url(
user, auth_user, new_name)
defer.returnValue((200, {}))
def on_OPTIONS(self, request, user_id):
return (200, {})
class ProfileRestServlet(RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = self.hs.parse_userid(user_id)
displayname = yield self.handlers.profile_handler.get_displayname(
user,
)
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
user,
)
defer.returnValue((200, {
"displayname": displayname,
"avatar_url": avatar_url
}))
def register_servlets(hs, http_server):
ProfileDisplaynameRestServlet(hs).register(http_server)
ProfileAvatarURLRestServlet(hs).register(http_server)
ProfileRestServlet(hs).register(http_server)

View file

@ -0,0 +1,291 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains REST servlets to do with registration: /register"""
from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType
from base import RestServlet, client_path_pattern
import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from hashlib import sha1
import hmac
import json
import logging
import urllib
logger = logging.getLogger(__name__)
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
# because the timing attack is so obscured by all the other code here it's
# unlikely to make much difference
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
compare_digest = lambda a, b: a == b
class RegisterRestServlet(RestServlet):
"""Handles registration with the home server.
This servlet is in control of the registration flow; the registration
handler doesn't have a concept of multi-stages or sessions.
"""
PATTERN = client_path_pattern("/register$")
def __init__(self, hs):
super(RegisterRestServlet, self).__init__(hs)
# sessions are stored as:
# self.sessions = {
# "session_id" : { __session_dict__ }
# }
# TODO: persistent storage
self.sessions = {}
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
return (
200,
{"flows": [
{
"type": LoginType.RECAPTCHA,
"stages": [
LoginType.RECAPTCHA,
LoginType.EMAIL_IDENTITY,
LoginType.PASSWORD
]
},
{
"type": LoginType.RECAPTCHA,
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
}
]}
)
else:
return (
200,
{"flows": [
{
"type": LoginType.EMAIL_IDENTITY,
"stages": [
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
]
},
{
"type": LoginType.PASSWORD
}
]}
)
@defer.inlineCallbacks
def on_POST(self, request):
register_json = _parse_json(request)
session = (register_json["session"]
if "session" in register_json else None)
login_type = None
if "type" not in register_json:
raise SynapseError(400, "Missing 'type' key.")
try:
login_type = register_json["type"]
stages = {
LoginType.RECAPTCHA: self._do_recaptcha,
LoginType.PASSWORD: self._do_password,
LoginType.EMAIL_IDENTITY: self._do_email_identity
}
session_info = self._get_session_info(request, session)
logger.debug("%s : session info %s request info %s",
login_type, session_info, register_json)
response = yield stages[login_type](
request,
register_json,
session_info
)
if "access_token" not in response:
# isn't a final response
response["session"] = session_info["id"]
defer.returnValue((200, response))
except KeyError as e:
logger.exception(e)
raise SynapseError(400, "Missing JSON keys for login type %s." % (
login_type,
))
def on_OPTIONS(self, request):
return (200, {})
def _get_session_info(self, request, session_id):
if not session_id:
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
self.sessions[session_id] = {
"id": session_id,
LoginType.EMAIL_IDENTITY: False,
LoginType.RECAPTCHA: False
}
return self.sessions[session_id]
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)
self.sessions[session["id"]] = session
def _remove_session(self, session):
logger.debug("Removing session %s", session)
self.sessions.pop(session["id"])
@defer.inlineCallbacks
def _do_recaptcha(self, request, register_json, session):
if not self.hs.config.enable_registration_captcha:
raise SynapseError(400, "Captcha not required.")
yield self._check_recaptcha(request, register_json, session)
session[LoginType.RECAPTCHA] = True # mark captcha as done
self._save_session(session)
defer.returnValue({
"next": [LoginType.PASSWORD, LoginType.EMAIL_IDENTITY]
})
@defer.inlineCallbacks
def _check_recaptcha(self, request, register_json, session):
if ("captcha_bypass_hmac" in register_json and
self.hs.config.captcha_bypass_secret):
if "user" not in register_json:
raise SynapseError(400, "Captcha bypass needs 'user'")
want = hmac.new(
key=self.hs.config.captcha_bypass_secret,
msg=register_json["user"],
digestmod=sha1,
).hexdigest()
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got = str(register_json["captcha_bypass_hmac"])
if compare_digest(want, got):
session["user"] = register_json["user"]
defer.returnValue(None)
else:
raise SynapseError(
400, "Captcha bypass HMAC incorrect",
errcode=Codes.CAPTCHA_NEEDED
)
challenge = None
user_response = None
try:
challenge = register_json["challenge"]
user_response = register_json["response"]
except KeyError:
raise SynapseError(400, "Captcha response is required",
errcode=Codes.CAPTCHA_NEEDED)
ip_addr = self.hs.get_ip_from_request(request)
handler = self.handlers.registration_handler
yield handler.check_recaptcha(
ip_addr,
self.hs.config.recaptcha_private_key,
challenge,
user_response
)
@defer.inlineCallbacks
def _do_email_identity(self, request, register_json, session):
if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]):
raise SynapseError(400, "Captcha is required.")
threepidCreds = register_json['threepidCreds']
handler = self.handlers.registration_handler
logger.debug("Registering email. threepidcreds: %s" % (threepidCreds))
yield handler.register_email(threepidCreds)
session["threepidCreds"] = threepidCreds # store creds for next stage
session[LoginType.EMAIL_IDENTITY] = True # mark email as done
self._save_session(session)
defer.returnValue({
"next": LoginType.PASSWORD
})
@defer.inlineCallbacks
def _do_password(self, request, register_json, session):
yield run_on_reactor()
if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]):
# captcha should've been done by this stage!
raise SynapseError(400, "Captcha is required.")
if ("user" in session and "user" in register_json and
session["user"] != register_json["user"]):
raise SynapseError(
400, "Cannot change user ID during registration"
)
password = register_json["password"].encode("utf-8")
desired_user_id = (register_json["user"].encode("utf-8")
if "user" in register_json else None)
if (desired_user_id
and urllib.quote(desired_user_id) != desired_user_id):
raise SynapseError(
400,
"User ID must only contain characters which do not " +
"require URL encoding.")
handler = self.handlers.registration_handler
(user_id, token) = yield handler.register(
localpart=desired_user_id,
password=password
)
if session[LoginType.EMAIL_IDENTITY]:
logger.debug("Binding emails %s to %s" % (
session["threepidCreds"], user_id)
)
yield handler.bind_emails(user_id, session["threepidCreds"])
result = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
self._remove_session(session)
defer.returnValue(result)
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError:
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)

View file

@ -0,0 +1,559 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
from twisted.internet import defer
from base import RestServlet, client_path_pattern
from synapse.api.errors import SynapseError, Codes
from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
import json
import logging
import urllib
logger = logging.getLogger(__name__)
class RoomCreateRestServlet(RestServlet):
# No PATTERN; we have custom dispatch rules here
def register(self, http_server):
PATTERN = "/createRoom"
register_txn_path(self, PATTERN, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity
http_server.register_path("OPTIONS",
client_path_pattern("/rooms(?:/.*)?$"),
self.on_OPTIONS)
# define CORS for /createRoom[/txnid]
http_server.register_path("OPTIONS",
client_path_pattern("/createRoom(?:/.*)?$"),
self.on_OPTIONS)
@defer.inlineCallbacks
def on_PUT(self, request, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(request)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
@defer.inlineCallbacks
def on_POST(self, request):
auth_user = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None)
room_config.update(info)
defer.returnValue((200, info))
@defer.inlineCallbacks
def make_room(self, room_config, auth_user, room_id):
handler = self.handlers.room_creation_handler
info = yield handler.create_room(
user_id=auth_user.to_string(),
room_id=room_id,
config=room_config
)
defer.returnValue(info)
def get_room_config(self, request):
try:
user_supplied_config = json.loads(request.content.read())
if "visibility" not in user_supplied_config:
# default visibility
user_supplied_config["visibility"] = "public"
return user_supplied_config
except (ValueError, TypeError):
raise SynapseError(400, "Body must be JSON.",
errcode=Codes.BAD_JSON)
def on_OPTIONS(self, request):
return (200, {})
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(RestServlet):
def register(self, http_server):
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
# /room/$roomid/state/$eventtype/$statekey
state_key = ("/rooms/(?P<room_id>[^/]*)/state/"
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
http_server.register_path("GET",
client_path_pattern(state_key),
self.on_GET)
http_server.register_path("PUT",
client_path_pattern(state_key),
self.on_PUT)
http_server.register_path("GET",
client_path_pattern(no_state_key),
self.on_GET_no_state_key)
http_server.register_path("PUT",
client_path_pattern(no_state_key),
self.on_PUT_no_state_key)
def on_GET_no_state_key(self, request, room_id, event_type):
return self.on_GET(request, room_id, event_type, "")
def on_PUT_no_state_key(self, request, room_id, event_type):
return self.on_PUT(request, room_id, event_type, "")
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key):
user = yield self.auth.get_user_by_req(request)
msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data(
user_id=user.to_string(),
room_id=room_id,
event_type=event_type,
state_key=state_key,
)
if not data:
raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND
)
defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
event_dict = {
"type": event_type,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
}
if state_key is not None:
event_dict["state_key"] = state_key
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(event_dict)
defer.returnValue((200, {}))
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(RestServlet):
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERN = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
register_txn_path(self, PATTERN, http_server, with_get=True)
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_type):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_event(
{
"type": event_type,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
}
)
defer.returnValue((200, {"event_id": event.event_id}))
def on_GET(self, request, room_id, event_type, txn_id):
return (200, "Not implemented")
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(request, room_id, event_type)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(RestServlet):
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
PATTERN = ("/join/(?P<room_identifier>[^/]*)")
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_identifier):
user = yield self.auth.get_user_by_req(request)
# the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid
# SynapseErrors.
identifier = None
is_room_alias = False
try:
identifier = self.hs.parse_roomalias(room_identifier)
is_room_alias = True
except SynapseError:
identifier = self.hs.parse_roomid(room_identifier)
# TODO: Support for specifying the home server to join with?
if is_room_alias:
handler = self.handlers.room_member_handler
ret_dict = yield handler.join_room_alias(user, identifier)
defer.returnValue((200, ret_dict))
else: # room id
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": {"membership": Membership.JOIN},
"room_id": identifier.to_string(),
"sender": user.to_string(),
"state_key": user.to_string(),
}
)
defer.returnValue((200, {"room_id": identifier.to_string()}))
@defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(request, room_identifier)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing
class PublicRoomListRestServlet(RestServlet):
PATTERN = client_path_pattern("/publicRooms$")
@defer.inlineCallbacks
def on_GET(self, request):
handler = self.handlers.room_list_handler
data = yield handler.get_public_room_list()
defer.returnValue((200, data))
# TODO: Needs unit testing
class RoomMemberListRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
user = yield self.auth.get_user_by_req(request)
handler = self.handlers.room_member_handler
members = yield handler.get_room_members_as_pagination_chunk(
room_id=room_id,
user_id=user.to_string())
for event in members["chunk"]:
# FIXME: should probably be state_key here, not user_id
target_user = self.hs.parse_userid(event["user_id"])
# Presence is an optional cache; don't fail if we can't fetch it
try:
presence_handler = self.handlers.presence_handler
presence_state = yield presence_handler.get_state(
target_user=target_user, auth_user=user
)
event["content"].update(presence_state)
except:
pass
defer.returnValue((200, members))
# TODO: Needs unit testing
class RoomMessageListRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(
request, default_limit=10,
)
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args
handler = self.handlers.message_handler
msgs = yield handler.get_messages(
room_id=room_id,
user_id=user.to_string(),
pagin_config=pagination_config,
feedback=with_feedback,
as_client_event=as_client_event
)
defer.returnValue((200, msgs))
# TODO: Needs unit testing
class RoomStateRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/state$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler
# Get all the current state for this room
events = yield handler.get_state_events(
room_id=room_id,
user_id=user.to_string(),
)
defer.returnValue((200, events))
# TODO: Needs unit testing
class RoomInitialSyncRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/initialSync$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id,
user_id=user.to_string(),
pagin_config=pagination_config,
)
defer.returnValue((200, content))
class RoomTriggerBackfill(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/backfill$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
remote_server = urllib.unquote(
request.args["remote"][0]
).decode("UTF-8")
limit = int(request.args["limit"][0])
handler = self.handlers.federation_handler
events = yield handler.backfill(remote_server, room_id, limit)
res = [self.hs.serialize_event(event) for event in events]
defer.returnValue((200, res))
# TODO: Needs unit testing
class RoomMembershipRestServlet(RestServlet):
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERN = ("/rooms/(?P<room_id>[^/]*)/"
"(?P<membership_action>join|invite|leave|ban|kick)")
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
# target user is you unless it is an invite
state_key = user.to_string()
if membership_action in ["invite", "ban", "kick"]:
if "user_id" not in content:
raise SynapseError(400, "Missing user_id key.")
state_key = content["user_id"]
if membership_action == "kick":
membership_action = "leave"
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": {"membership": unicode(membership_action)},
"room_id": room_id,
"sender": user.to_string(),
"state_key": state_key,
}
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(request, room_id, membership_action)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
class RoomRedactEventRestServlet(RestServlet):
def register(self, http_server):
PATTERN = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_event(
{
"type": EventTypes.Redaction,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
"redacts": event_id,
}
)
defer.returnValue((200, {"event_id": event.event_id}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_id, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(request, room_id, event_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
class RoomTypingRestServlet(RestServlet):
PATTERN = client_path_pattern(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
)
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
auth_user = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id)
target_user = self.hs.parse_userid(urllib.unquote(user_id))
content = _parse_json(request)
typing_handler = self.handlers.typing_notification_handler
if content["typing"]:
yield typing_handler.started_typing(
target_user=target_user,
auth_user=auth_user,
room_id=room_id,
timeout=content.get("timeout", 30000),
)
else:
yield typing_handler.stopped_typing(
target_user=target_user,
auth_user=auth_user,
room_id=room_id,
)
defer.returnValue((200, {}))
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_txn_path(servlet, regex_string, http_server, with_get=False):
"""Registers a transaction-based path.
This registers two paths:
PUT regex_string/$txnid
POST regex_string
Args:
regex_string (str): The regex string to register. Must NOT have a
trailing $ as this string will be appended to.
http_server : The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs.
"""
http_server.register_path(
"POST",
client_path_pattern(regex_string + "$"),
servlet.on_POST
)
http_server.register_path(
"PUT",
client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"),
servlet.on_PUT
)
if with_get:
http_server.register_path(
"GET",
client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"),
servlet.on_GET
)
def register_servlets(hs, http_server):
RoomStateEventRestServlet(hs).register(http_server)
RoomCreateRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
RoomTriggerBackfill(hs).register(http_server)
RoomMembershipRestServlet(hs).register(http_server)
RoomSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server)
RoomInitialSyncRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server)

View file

@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
logger = logging.getLogger(__name__)
# FIXME: elsewhere we use FooStore to indicate something in the storage layer...
class HttpTransactionStore(object):
def __init__(self):
# { key : (txn_id, response) }
self.transactions = {}
def get_response(self, key, txn_id):
"""Retrieve a response for this request.
Args:
key (str): A transaction-independent key for this request. Usually
this is a combination of the path (without the transaction id)
and the user's access token.
txn_id (str): The transaction ID for this request
Returns:
A tuple of (HTTP response code, response content) or None.
"""
try:
logger.debug("get_response Key: %s TxnId: %s", key, txn_id)
(last_txn_id, response) = self.transactions[key]
if txn_id == last_txn_id:
logger.info("get_response: Returning a response for %s", key)
return response
except KeyError:
pass
return None
def store_response(self, key, txn_id, response):
"""Stores an HTTP response tuple.
Args:
key (str): A transaction-independent key for this request. Usually
this is a combination of the path (without the transaction id)
and the user's access token.
txn_id (str): The transaction ID for this request.
response (tuple): A tuple of (HTTP response code, response content)
"""
logger.debug("store_response Key: %s TxnId: %s", key, txn_id)
self.transactions[key] = (txn_id, response)
def store_client_transaction(self, request, txn_id, response):
"""Stores the request/response pair of an HTTP transaction.
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
response (tuple): A tuple of (response code, response dict)
txn_id (str): The transaction ID for this request.
"""
self.store_response(self._get_key(request), txn_id, response)
def get_client_transaction(self, request, txn_id):
"""Retrieves a stored response if there was one.
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
txn_id (str): The transaction ID for this request.
Returns:
The response tuple.
Raises:
KeyError if the transaction was not found.
"""
response = self.get_response(self._get_key(request), txn_id)
if response is None:
raise KeyError("Transaction not found.")
return response
def _get_key(self, request):
token = request.args["access_token"][0]
path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token

View file

@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from base import RestServlet, client_path_pattern
import hmac
import hashlib
import base64
class VoipRestServlet(RestServlet):
PATTERN = client_path_pattern("/voip/turnServer$")
@defer.inlineCallbacks
def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret
userLifetime = self.hs.config.turn_user_lifetime
if not turnUris or not turnSecret or not userLifetime:
defer.returnValue((200, {}))
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, auth_user.to_string())
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
# We need to use standard base64 encoding here, *not* syutil's
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
password = base64.b64encode(mac.digest())
defer.returnValue((200, {
'username': username,
'password': password,
'ttl': userLifetime / 1000,
'uris': turnUris,
}))
def on_OPTIONS(self, request):
return (200, {})
def register_servlets(hs, http_server):
VoipRestServlet(hs).register(http_server)