Move rest APIs back under the rest directory

This commit is contained in:
Mark Haines 2015-01-22 16:10:07 +00:00
parent 1d2016b4a8
commit 97c68c508d
31 changed files with 33 additions and 19 deletions

14
synapse/rest/__init__.py Normal file
View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View file

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import (
room, events, register, login, profile, presence, initial_sync, directory,
voip, admin,
)
class RestServletFactory(object):
""" A factory for creating REST servlets.
These REST servlets represent the entire client-server REST API. Generally
speaking, they serve as wrappers around events and the handlers that
process them.
See synapse.events for information on synapse events.
"""
def __init__(self, hs):
client_resource = hs.get_resource_for_client()
# TODO(erikj): There *must* be a better way of doing this.
room.register_servlets(hs, client_resource)
events.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)
initial_sync.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource)

View file

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from base import RestServlet, client_path_pattern
import logging
logger = logging.getLogger(__name__)
class WhoisRestServlet(RestServlet):
PATTERN = client_path_pattern("/admin/whois/(?P<user_id>[^/]*)")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = self.hs.parse_userid(user_id)
auth_user = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user:
raise AuthError(403, "You are not a server admin")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only whois a local user")
ret = yield self.handlers.admin_handler.get_whois(target_user)
defer.returnValue((200, ret))
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)

View file

@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
from synapse.api.urls import CLIENT_PREFIX
from .transactions import HttpTransactionStore
import re
import logging
logger = logging.getLogger(__name__)
def client_path_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
return re.compile("^" + CLIENT_PREFIX + path_regex)
class RestServlet(object):
""" A Synapse REST Servlet.
An implementing class can either provide its own custom 'register' method,
or use the automatic pattern handling provided by the base class.
To use this latter, the implementing class instead provides a `PATTERN`
class attribute containing a pre-compiled regular expression. The automatic
register method will then use this method to register any of the following
instance methods associated with the corresponding HTTP method:
on_GET
on_PUT
on_POST
on_DELETE
on_OPTIONS
Automatically handles turning CodeMessageExceptions thrown by these methods
into the appropriate HTTP response.
"""
def __init__(self, hs):
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionStore()
def register(self, http_server):
""" Register this servlet with the given HTTP server. """
if hasattr(self, "PATTERN"):
pattern = self.PATTERN
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method)):
method_handler = getattr(self, "on_%s" % (method))
http_server.register_path(method, pattern, method_handler)
else:
raise NotImplementedError("RestServlet must register something.")

View file

@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError, Codes
from base import RestServlet, client_path_pattern
import json
import logging
logger = logging.getLogger(__name__)
def register_servlets(hs, http_server):
ClientDirectoryServer(hs).register(http_server)
class ClientDirectoryServer(RestServlet):
PATTERN = client_path_pattern("/directory/room/(?P<room_alias>[^/]*)$")
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
room_alias = self.hs.parse_roomalias(room_alias)
dir_handler = self.handlers.directory_handler
res = yield dir_handler.get_association(room_alias)
defer.returnValue((200, res))
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
if not "room_id" in content:
raise SynapseError(400, "Missing room_id key",
errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content)
room_alias = self.hs.parse_roomalias(room_alias)
logger.debug("Got room name: %s", room_alias.to_string())
room_id = content["room_id"]
servers = content["servers"] if "servers" in content else None
logger.debug("Got room_id: %s", room_id)
logger.debug("Got servers: %s", servers)
# TODO(erikj): Check types.
# TODO(erikj): Check that room exists
dir_handler = self.handlers.directory_handler
try:
user_id = user.to_string()
yield dir_handler.create_association(
user_id, room_alias, room_id, servers
)
yield dir_handler.send_room_alias_update_event(user_id, room_id)
except SynapseError as e:
raise e
except:
logger.exception("Failed to create association")
raise
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_DELETE(self, request, room_alias):
user = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user)
if not is_admin:
raise AuthError(403, "You need to be a server admin")
dir_handler = self.handlers.directory_handler
room_alias = self.hs.parse_roomalias(room_alias)
yield dir_handler.delete_association(
user.to_string(), room_alias
)
defer.returnValue((200, {}))
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

View file

@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains REST servlets to do with event streaming, /events."""
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig
from .base import RestServlet, client_path_pattern
import logging
logger = logging.getLogger(__name__)
class EventStreamRestServlet(RestServlet):
PATTERN = client_path_pattern("/events$")
DEFAULT_LONGPOLL_TIME_MS = 30000
@defer.inlineCallbacks
def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request)
try:
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
try:
timeout = int(request.args["timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")
as_client_event = "raw" not in request.args
chunk = yield handler.get_stream(
auth_user.to_string(), pagin_config, timeout=timeout,
as_client_event=as_client_event
)
except:
logger.exception("Event stream failed")
raise
defer.returnValue((200, chunk))
def on_OPTIONS(self, request):
return (200, {})
# TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(RestServlet):
PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$")
@defer.inlineCallbacks
def on_GET(self, request, event_id):
auth_user = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id)
if event:
defer.returnValue((200, self.hs.serialize_event(event)))
else:
defer.returnValue((404, "Event not found."))
def register_servlets(hs, http_server):
EventStreamRestServlet(hs).register(http_server)
EventRestServlet(hs).register(http_server)

View file

@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from base import RestServlet, client_path_pattern
# TODO: Needs unit testing
class InitialSyncRestServlet(RestServlet):
PATTERN = client_path_pattern("/initialSync$")
@defer.inlineCallbacks
def on_GET(self, request):
user = yield self.auth.get_user_by_req(request)
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
handler = self.handlers.message_handler
content = yield handler.snapshot_all_rooms(
user_id=user.to_string(),
pagin_config=pagination_config,
feedback=with_feedback,
as_client_event=as_client_event
)
defer.returnValue((200, content))
def register_servlets(hs, http_server):
InitialSyncRestServlet(hs).register(http_server)

View file

@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import UserID
from base import RestServlet, client_path_pattern
import json
class LoginRestServlet(RestServlet):
PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password"
def on_GET(self, request):
return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]})
def on_OPTIONS(self, request):
return (200, {})
@defer.inlineCallbacks
def on_POST(self, request):
login_submission = _parse_json(request)
try:
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
result = yield self.do_password_login(login_submission)
defer.returnValue(result)
else:
raise SynapseError(400, "Bad login type.")
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@defer.inlineCallbacks
def do_password_login(self, login_submission):
if not login_submission["user"].startswith('@'):
login_submission["user"] = UserID.create(
login_submission["user"], self.hs.hostname).to_string()
handler = self.handlers.login_handler
token = yield handler.login(
user=login_submission["user"],
password=login_submission["password"])
result = {
"user_id": login_submission["user"], # may have changed
"access_token": token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
class LoginFallbackRestServlet(RestServlet):
PATTERN = client_path_pattern("/login/fallback$")
def on_GET(self, request):
# TODO(kegan): This should be returning some HTML which is capable of
# hitting LoginRestServlet
return (200, {})
class PasswordResetRestServlet(RestServlet):
PATTERN = client_path_pattern("/login/reset")
@defer.inlineCallbacks
def on_POST(self, request):
reset_info = _parse_json(request)
try:
email = reset_info["email"]
user_id = reset_info["user_id"]
handler = self.handlers.login_handler
yield handler.reset_password(user_id, email)
# purposefully give no feedback to avoid people hammering different
# combinations.
defer.returnValue((200, {}))
except KeyError:
raise SynapseError(
400,
"Missing keys. Requires 'email' and 'user_id'."
)
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError:
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)

View file

@ -0,0 +1,145 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains REST servlets to do with presence: /presence/<paths>
"""
from twisted.internet import defer
from synapse.api.errors import SynapseError
from base import RestServlet, client_path_pattern
import json
import logging
logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(RestServlet):
PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
state = yield self.handlers.presence_handler.get_state(
target_user=user, auth_user=auth_user)
defer.returnValue((200, state))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
state = {}
try:
content = json.loads(request.content.read())
state["presence"] = content.pop("presence")
if "status_msg" in content:
state["status_msg"] = content.pop("status_msg")
if not isinstance(state["status_msg"], basestring):
raise SynapseError(400, "status_msg must be a string.")
if content:
raise KeyError()
except SynapseError as e:
raise e
except:
raise SynapseError(400, "Unable to parse state")
yield self.handlers.presence_handler.set_state(
target_user=user, auth_user=auth_user, state=state)
defer.returnValue((200, {}))
def on_OPTIONS(self, request):
return (200, {})
class PresenceListRestServlet(RestServlet):
PATTERN = client_path_pattern("/presence/list/(?P<user_id>[^/]*)")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
if auth_user != user:
raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.handlers.presence_handler.get_presence_list(
observer_user=user, accepted=True)
for p in presence:
observed_user = p.pop("observed_user")
p["user_id"] = observed_user.to_string()
defer.returnValue((200, presence))
@defer.inlineCallbacks
def on_POST(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
if auth_user != user:
raise SynapseError(
400, "Cannot modify another user's presence list")
try:
content = json.loads(request.content.read())
except:
logger.exception("JSON parse error")
raise SynapseError(400, "Unable to parse content")
if "invite" in content:
for u in content["invite"]:
if not isinstance(u, basestring):
raise SynapseError(400, "Bad invite value.")
if len(u) == 0:
continue
invited_user = self.hs.parse_userid(u)
yield self.handlers.presence_handler.send_invite(
observer_user=user, observed_user=invited_user
)
if "drop" in content:
for u in content["drop"]:
if not isinstance(u, basestring):
raise SynapseError(400, "Bad drop value.")
if len(u) == 0:
continue
dropped_user = self.hs.parse_userid(u)
yield self.handlers.presence_handler.drop(
observer_user=user, observed_user=dropped_user
)
defer.returnValue((200, {}))
def on_OPTIONS(self, request):
return (200, {})
def register_servlets(hs, http_server):
PresenceStatusRestServlet(hs).register(http_server)
PresenceListRestServlet(hs).register(http_server)

View file

@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer
from base import RestServlet, client_path_pattern
import json
class ProfileDisplaynameRestServlet(RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = self.hs.parse_userid(user_id)
displayname = yield self.handlers.profile_handler.get_displayname(
user,
)
defer.returnValue((200, {"displayname": displayname}))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
try:
content = json.loads(request.content.read())
new_name = content["displayname"]
except:
defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname(
user, auth_user, new_name)
defer.returnValue((200, {}))
def on_OPTIONS(self, request, user_id):
return (200, {})
class ProfileAvatarURLRestServlet(RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = self.hs.parse_userid(user_id)
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
user,
)
defer.returnValue((200, {"avatar_url": avatar_url}))
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id)
try:
content = json.loads(request.content.read())
new_name = content["avatar_url"]
except:
defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url(
user, auth_user, new_name)
defer.returnValue((200, {}))
def on_OPTIONS(self, request, user_id):
return (200, {})
class ProfileRestServlet(RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)")
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = self.hs.parse_userid(user_id)
displayname = yield self.handlers.profile_handler.get_displayname(
user,
)
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
user,
)
defer.returnValue((200, {
"displayname": displayname,
"avatar_url": avatar_url
}))
def register_servlets(hs, http_server):
ProfileDisplaynameRestServlet(hs).register(http_server)
ProfileAvatarURLRestServlet(hs).register(http_server)
ProfileRestServlet(hs).register(http_server)

View file

@ -0,0 +1,291 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains REST servlets to do with registration: /register"""
from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType
from base import RestServlet, client_path_pattern
import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from hashlib import sha1
import hmac
import json
import logging
import urllib
logger = logging.getLogger(__name__)
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
# because the timing attack is so obscured by all the other code here it's
# unlikely to make much difference
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
compare_digest = lambda a, b: a == b
class RegisterRestServlet(RestServlet):
"""Handles registration with the home server.
This servlet is in control of the registration flow; the registration
handler doesn't have a concept of multi-stages or sessions.
"""
PATTERN = client_path_pattern("/register$")
def __init__(self, hs):
super(RegisterRestServlet, self).__init__(hs)
# sessions are stored as:
# self.sessions = {
# "session_id" : { __session_dict__ }
# }
# TODO: persistent storage
self.sessions = {}
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
return (
200,
{"flows": [
{
"type": LoginType.RECAPTCHA,
"stages": [
LoginType.RECAPTCHA,
LoginType.EMAIL_IDENTITY,
LoginType.PASSWORD
]
},
{
"type": LoginType.RECAPTCHA,
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
}
]}
)
else:
return (
200,
{"flows": [
{
"type": LoginType.EMAIL_IDENTITY,
"stages": [
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
]
},
{
"type": LoginType.PASSWORD
}
]}
)
@defer.inlineCallbacks
def on_POST(self, request):
register_json = _parse_json(request)
session = (register_json["session"]
if "session" in register_json else None)
login_type = None
if "type" not in register_json:
raise SynapseError(400, "Missing 'type' key.")
try:
login_type = register_json["type"]
stages = {
LoginType.RECAPTCHA: self._do_recaptcha,
LoginType.PASSWORD: self._do_password,
LoginType.EMAIL_IDENTITY: self._do_email_identity
}
session_info = self._get_session_info(request, session)
logger.debug("%s : session info %s request info %s",
login_type, session_info, register_json)
response = yield stages[login_type](
request,
register_json,
session_info
)
if "access_token" not in response:
# isn't a final response
response["session"] = session_info["id"]
defer.returnValue((200, response))
except KeyError as e:
logger.exception(e)
raise SynapseError(400, "Missing JSON keys for login type %s." % (
login_type,
))
def on_OPTIONS(self, request):
return (200, {})
def _get_session_info(self, request, session_id):
if not session_id:
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
self.sessions[session_id] = {
"id": session_id,
LoginType.EMAIL_IDENTITY: False,
LoginType.RECAPTCHA: False
}
return self.sessions[session_id]
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)
self.sessions[session["id"]] = session
def _remove_session(self, session):
logger.debug("Removing session %s", session)
self.sessions.pop(session["id"])
@defer.inlineCallbacks
def _do_recaptcha(self, request, register_json, session):
if not self.hs.config.enable_registration_captcha:
raise SynapseError(400, "Captcha not required.")
yield self._check_recaptcha(request, register_json, session)
session[LoginType.RECAPTCHA] = True # mark captcha as done
self._save_session(session)
defer.returnValue({
"next": [LoginType.PASSWORD, LoginType.EMAIL_IDENTITY]
})
@defer.inlineCallbacks
def _check_recaptcha(self, request, register_json, session):
if ("captcha_bypass_hmac" in register_json and
self.hs.config.captcha_bypass_secret):
if "user" not in register_json:
raise SynapseError(400, "Captcha bypass needs 'user'")
want = hmac.new(
key=self.hs.config.captcha_bypass_secret,
msg=register_json["user"],
digestmod=sha1,
).hexdigest()
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got = str(register_json["captcha_bypass_hmac"])
if compare_digest(want, got):
session["user"] = register_json["user"]
defer.returnValue(None)
else:
raise SynapseError(
400, "Captcha bypass HMAC incorrect",
errcode=Codes.CAPTCHA_NEEDED
)
challenge = None
user_response = None
try:
challenge = register_json["challenge"]
user_response = register_json["response"]
except KeyError:
raise SynapseError(400, "Captcha response is required",
errcode=Codes.CAPTCHA_NEEDED)
ip_addr = self.hs.get_ip_from_request(request)
handler = self.handlers.registration_handler
yield handler.check_recaptcha(
ip_addr,
self.hs.config.recaptcha_private_key,
challenge,
user_response
)
@defer.inlineCallbacks
def _do_email_identity(self, request, register_json, session):
if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]):
raise SynapseError(400, "Captcha is required.")
threepidCreds = register_json['threepidCreds']
handler = self.handlers.registration_handler
logger.debug("Registering email. threepidcreds: %s" % (threepidCreds))
yield handler.register_email(threepidCreds)
session["threepidCreds"] = threepidCreds # store creds for next stage
session[LoginType.EMAIL_IDENTITY] = True # mark email as done
self._save_session(session)
defer.returnValue({
"next": LoginType.PASSWORD
})
@defer.inlineCallbacks
def _do_password(self, request, register_json, session):
yield run_on_reactor()
if (self.hs.config.enable_registration_captcha and
not session[LoginType.RECAPTCHA]):
# captcha should've been done by this stage!
raise SynapseError(400, "Captcha is required.")
if ("user" in session and "user" in register_json and
session["user"] != register_json["user"]):
raise SynapseError(
400, "Cannot change user ID during registration"
)
password = register_json["password"].encode("utf-8")
desired_user_id = (register_json["user"].encode("utf-8")
if "user" in register_json else None)
if (desired_user_id
and urllib.quote(desired_user_id) != desired_user_id):
raise SynapseError(
400,
"User ID must only contain characters which do not " +
"require URL encoding.")
handler = self.handlers.registration_handler
(user_id, token) = yield handler.register(
localpart=desired_user_id,
password=password
)
if session[LoginType.EMAIL_IDENTITY]:
logger.debug("Binding emails %s to %s" % (
session["threepidCreds"], user_id)
)
yield handler.bind_emails(user_id, session["threepidCreds"])
result = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
self._remove_session(session)
defer.returnValue(result)
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError:
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)

View file

@ -0,0 +1,559 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
from twisted.internet import defer
from base import RestServlet, client_path_pattern
from synapse.api.errors import SynapseError, Codes
from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
import json
import logging
import urllib
logger = logging.getLogger(__name__)
class RoomCreateRestServlet(RestServlet):
# No PATTERN; we have custom dispatch rules here
def register(self, http_server):
PATTERN = "/createRoom"
register_txn_path(self, PATTERN, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity
http_server.register_path("OPTIONS",
client_path_pattern("/rooms(?:/.*)?$"),
self.on_OPTIONS)
# define CORS for /createRoom[/txnid]
http_server.register_path("OPTIONS",
client_path_pattern("/createRoom(?:/.*)?$"),
self.on_OPTIONS)
@defer.inlineCallbacks
def on_PUT(self, request, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(request)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
@defer.inlineCallbacks
def on_POST(self, request):
auth_user = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None)
room_config.update(info)
defer.returnValue((200, info))
@defer.inlineCallbacks
def make_room(self, room_config, auth_user, room_id):
handler = self.handlers.room_creation_handler
info = yield handler.create_room(
user_id=auth_user.to_string(),
room_id=room_id,
config=room_config
)
defer.returnValue(info)
def get_room_config(self, request):
try:
user_supplied_config = json.loads(request.content.read())
if "visibility" not in user_supplied_config:
# default visibility
user_supplied_config["visibility"] = "public"
return user_supplied_config
except (ValueError, TypeError):
raise SynapseError(400, "Body must be JSON.",
errcode=Codes.BAD_JSON)
def on_OPTIONS(self, request):
return (200, {})
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(RestServlet):
def register(self, http_server):
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
# /room/$roomid/state/$eventtype/$statekey
state_key = ("/rooms/(?P<room_id>[^/]*)/state/"
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
http_server.register_path("GET",
client_path_pattern(state_key),
self.on_GET)
http_server.register_path("PUT",
client_path_pattern(state_key),
self.on_PUT)
http_server.register_path("GET",
client_path_pattern(no_state_key),
self.on_GET_no_state_key)
http_server.register_path("PUT",
client_path_pattern(no_state_key),
self.on_PUT_no_state_key)
def on_GET_no_state_key(self, request, room_id, event_type):
return self.on_GET(request, room_id, event_type, "")
def on_PUT_no_state_key(self, request, room_id, event_type):
return self.on_PUT(request, room_id, event_type, "")
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key):
user = yield self.auth.get_user_by_req(request)
msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data(
user_id=user.to_string(),
room_id=room_id,
event_type=event_type,
state_key=state_key,
)
if not data:
raise SynapseError(
404, "Event not found.", errcode=Codes.NOT_FOUND
)
defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
event_dict = {
"type": event_type,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
}
if state_key is not None:
event_dict["state_key"] = state_key
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(event_dict)
defer.returnValue((200, {}))
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(RestServlet):
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERN = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
register_txn_path(self, PATTERN, http_server, with_get=True)
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_type):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_event(
{
"type": event_type,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
}
)
defer.returnValue((200, {"event_id": event.event_id}))
def on_GET(self, request, room_id, event_type, txn_id):
return (200, "Not implemented")
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(request, room_id, event_type)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(RestServlet):
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
PATTERN = ("/join/(?P<room_identifier>[^/]*)")
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_identifier):
user = yield self.auth.get_user_by_req(request)
# the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid
# SynapseErrors.
identifier = None
is_room_alias = False
try:
identifier = self.hs.parse_roomalias(room_identifier)
is_room_alias = True
except SynapseError:
identifier = self.hs.parse_roomid(room_identifier)
# TODO: Support for specifying the home server to join with?
if is_room_alias:
handler = self.handlers.room_member_handler
ret_dict = yield handler.join_room_alias(user, identifier)
defer.returnValue((200, ret_dict))
else: # room id
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": {"membership": Membership.JOIN},
"room_id": identifier.to_string(),
"sender": user.to_string(),
"state_key": user.to_string(),
}
)
defer.returnValue((200, {"room_id": identifier.to_string()}))
@defer.inlineCallbacks
def on_PUT(self, request, room_identifier, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(request, room_identifier)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing
class PublicRoomListRestServlet(RestServlet):
PATTERN = client_path_pattern("/publicRooms$")
@defer.inlineCallbacks
def on_GET(self, request):
handler = self.handlers.room_list_handler
data = yield handler.get_public_room_list()
defer.returnValue((200, data))
# TODO: Needs unit testing
class RoomMemberListRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
user = yield self.auth.get_user_by_req(request)
handler = self.handlers.room_member_handler
members = yield handler.get_room_members_as_pagination_chunk(
room_id=room_id,
user_id=user.to_string())
for event in members["chunk"]:
# FIXME: should probably be state_key here, not user_id
target_user = self.hs.parse_userid(event["user_id"])
# Presence is an optional cache; don't fail if we can't fetch it
try:
presence_handler = self.handlers.presence_handler
presence_state = yield presence_handler.get_state(
target_user=target_user, auth_user=user
)
event["content"].update(presence_state)
except:
pass
defer.returnValue((200, members))
# TODO: Needs unit testing
class RoomMessageListRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(
request, default_limit=10,
)
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args
handler = self.handlers.message_handler
msgs = yield handler.get_messages(
room_id=room_id,
user_id=user.to_string(),
pagin_config=pagination_config,
feedback=with_feedback,
as_client_event=as_client_event
)
defer.returnValue((200, msgs))
# TODO: Needs unit testing
class RoomStateRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/state$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler
# Get all the current state for this room
events = yield handler.get_state_events(
room_id=room_id,
user_id=user.to_string(),
)
defer.returnValue((200, events))
# TODO: Needs unit testing
class RoomInitialSyncRestServlet(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/initialSync$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id,
user_id=user.to_string(),
pagin_config=pagination_config,
)
defer.returnValue((200, content))
class RoomTriggerBackfill(RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/backfill$")
@defer.inlineCallbacks
def on_GET(self, request, room_id):
remote_server = urllib.unquote(
request.args["remote"][0]
).decode("UTF-8")
limit = int(request.args["limit"][0])
handler = self.handlers.federation_handler
events = yield handler.backfill(remote_server, room_id, limit)
res = [self.hs.serialize_event(event) for event in events]
defer.returnValue((200, res))
# TODO: Needs unit testing
class RoomMembershipRestServlet(RestServlet):
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERN = ("/rooms/(?P<room_id>[^/]*)/"
"(?P<membership_action>join|invite|leave|ban|kick)")
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
# target user is you unless it is an invite
state_key = user.to_string()
if membership_action in ["invite", "ban", "kick"]:
if "user_id" not in content:
raise SynapseError(400, "Missing user_id key.")
state_key = content["user_id"]
if membership_action == "kick":
membership_action = "leave"
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(
{
"type": EventTypes.Member,
"content": {"membership": unicode(membership_action)},
"room_id": room_id,
"sender": user.to_string(),
"state_key": state_key,
}
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, membership_action, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(request, room_id, membership_action)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
class RoomRedactEventRestServlet(RestServlet):
def register(self, http_server):
PATTERN = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
event = yield msg_handler.create_and_send_event(
{
"type": EventTypes.Redaction,
"content": content,
"room_id": room_id,
"sender": user.to_string(),
"redacts": event_id,
}
)
defer.returnValue((200, {"event_id": event.event_id}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_id, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(request, room_id, event_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
class RoomTypingRestServlet(RestServlet):
PATTERN = client_path_pattern(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
)
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
auth_user = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id)
target_user = self.hs.parse_userid(urllib.unquote(user_id))
content = _parse_json(request)
typing_handler = self.handlers.typing_notification_handler
if content["typing"]:
yield typing_handler.started_typing(
target_user=target_user,
auth_user=auth_user,
room_id=room_id,
timeout=content.get("timeout", 30000),
)
else:
yield typing_handler.stopped_typing(
target_user=target_user,
auth_user=auth_user,
room_id=room_id,
)
defer.returnValue((200, {}))
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_txn_path(servlet, regex_string, http_server, with_get=False):
"""Registers a transaction-based path.
This registers two paths:
PUT regex_string/$txnid
POST regex_string
Args:
regex_string (str): The regex string to register. Must NOT have a
trailing $ as this string will be appended to.
http_server : The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs.
"""
http_server.register_path(
"POST",
client_path_pattern(regex_string + "$"),
servlet.on_POST
)
http_server.register_path(
"PUT",
client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"),
servlet.on_PUT
)
if with_get:
http_server.register_path(
"GET",
client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"),
servlet.on_GET
)
def register_servlets(hs, http_server):
RoomStateEventRestServlet(hs).register(http_server)
RoomCreateRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
RoomTriggerBackfill(hs).register(http_server)
RoomMembershipRestServlet(hs).register(http_server)
RoomSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server)
RoomInitialSyncRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server)

View file

@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains logic for storing HTTP PUT transactions. This is used
to ensure idempotency when performing PUTs using the REST API."""
import logging
logger = logging.getLogger(__name__)
# FIXME: elsewhere we use FooStore to indicate something in the storage layer...
class HttpTransactionStore(object):
def __init__(self):
# { key : (txn_id, response) }
self.transactions = {}
def get_response(self, key, txn_id):
"""Retrieve a response for this request.
Args:
key (str): A transaction-independent key for this request. Usually
this is a combination of the path (without the transaction id)
and the user's access token.
txn_id (str): The transaction ID for this request
Returns:
A tuple of (HTTP response code, response content) or None.
"""
try:
logger.debug("get_response Key: %s TxnId: %s", key, txn_id)
(last_txn_id, response) = self.transactions[key]
if txn_id == last_txn_id:
logger.info("get_response: Returning a response for %s", key)
return response
except KeyError:
pass
return None
def store_response(self, key, txn_id, response):
"""Stores an HTTP response tuple.
Args:
key (str): A transaction-independent key for this request. Usually
this is a combination of the path (without the transaction id)
and the user's access token.
txn_id (str): The transaction ID for this request.
response (tuple): A tuple of (HTTP response code, response content)
"""
logger.debug("store_response Key: %s TxnId: %s", key, txn_id)
self.transactions[key] = (txn_id, response)
def store_client_transaction(self, request, txn_id, response):
"""Stores the request/response pair of an HTTP transaction.
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
response (tuple): A tuple of (response code, response dict)
txn_id (str): The transaction ID for this request.
"""
self.store_response(self._get_key(request), txn_id, response)
def get_client_transaction(self, request, txn_id):
"""Retrieves a stored response if there was one.
Args:
request (twisted.web.http.Request): The twisted HTTP request. This
request must have the transaction ID as the last path segment.
txn_id (str): The transaction ID for this request.
Returns:
The response tuple.
Raises:
KeyError if the transaction was not found.
"""
response = self.get_response(self._get_key(request), txn_id)
if response is None:
raise KeyError("Transaction not found.")
return response
def _get_key(self, request):
token = request.args["access_token"][0]
path_without_txn_id = request.path.rsplit("/", 1)[0]
return path_without_txn_id + "/" + token

View file

@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from base import RestServlet, client_path_pattern
import hmac
import hashlib
import base64
class VoipRestServlet(RestServlet):
PATTERN = client_path_pattern("/voip/turnServer$")
@defer.inlineCallbacks
def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret
userLifetime = self.hs.config.turn_user_lifetime
if not turnUris or not turnSecret or not userLifetime:
defer.returnValue((200, {}))
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, auth_user.to_string())
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
# We need to use standard base64 encoding here, *not* syutil's
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
password = base64.b64encode(mac.digest())
defer.returnValue((200, {
'username': username,
'password': password,
'ttl': userLifetime / 1000,
'uris': turnUris,
}))
def on_OPTIONS(self, request):
return (200, {})
def register_servlets(hs, http_server):
VoipRestServlet(hs).register(http_server)

View file

View file

View file

@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json_bytes
from synapse.util.stringutils import random_string
from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, Codes, cs_error
)
from twisted.protocols.basic import FileSender
from twisted.web import server, resource
from twisted.internet import defer
import base64
import json
import logging
import os
import re
logger = logging.getLogger(__name__)
class ContentRepoResource(resource.Resource):
"""Provides file uploading and downloading.
Uploads are POSTed to wherever this Resource is linked to. This resource
returns a "content token" which can be used to GET this content again. The
token is typically a path, but it may not be. Tokens can expire, be
one-time uses, etc.
In this case, the token is a path to the file and contains 3 interesting
sections:
- User ID base64d (for namespacing content to each user)
- random 24 char string
- Content type base64d (so we can return it when clients GET it)
"""
isLeaf = True
def __init__(self, hs, directory, auth, external_addr):
resource.Resource.__init__(self)
self.hs = hs
self.directory = directory
self.auth = auth
self.external_addr = external_addr.rstrip('/')
self.max_upload_size = hs.config.max_upload_size
if not os.path.isdir(self.directory):
os.mkdir(self.directory)
logger.info("ContentRepoResource : Created %s directory.",
self.directory)
@defer.inlineCallbacks
def map_request_to_name(self, request):
# auth the user
auth_user = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user
prefix = base64.urlsafe_b64encode(
auth_user.to_string()
).replace('=', '')
# use a random string for the main portion
main_part = random_string(24)
# suffix with a file extension if we can make one. This is nice to
# provide a hint to clients on the file information. We will also reuse
# this info to spit back the content type to the client.
suffix = ""
if request.requestHeaders.hasHeader("Content-Type"):
content_type = request.requestHeaders.getRawHeaders(
"Content-Type")[0]
suffix = "." + base64.urlsafe_b64encode(content_type)
if (content_type.split("/")[0].lower() in
["image", "video", "audio"]):
file_ext = content_type.split("/")[-1]
# be a little paranoid and only allow a-z
file_ext = re.sub("[^a-z]", "", file_ext)
suffix += "." + file_ext
file_name = prefix + main_part + suffix
file_path = os.path.join(self.directory, file_name)
logger.info("User %s is uploading a file to path %s",
auth_user.to_string(),
file_path)
# keep trying to make a non-clashing file, with a sensible max attempts
attempts = 0
while os.path.exists(file_path):
main_part = random_string(24)
file_name = prefix + main_part + suffix
file_path = os.path.join(self.directory, file_name)
attempts += 1
if attempts > 25: # really? Really?
raise SynapseError(500, "Unable to create file.")
defer.returnValue(file_path)
def render_GET(self, request):
# no auth here on purpose, to allow anyone to view, even across home
# servers.
# TODO: A little crude here, we could do this better.
filename = request.path.split('/')[-1]
# be paranoid
filename = re.sub("[^0-9A-z.-_]", "", filename)
file_path = self.directory + "/" + filename
logger.debug("Searching for %s", file_path)
if os.path.isfile(file_path):
# filename has the content type
base64_contentype = filename.split(".")[1]
content_type = base64.urlsafe_b64decode(base64_contentype)
logger.info("Sending file %s", file_path)
f = open(file_path, 'rb')
request.setHeader('Content-Type', content_type)
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our matrix
# clients are smart enough to be happy with Cache-Control (right?)
request.setHeader(
"Cache-Control", "public,max-age=86400,s-maxage=86400"
)
d = FileSender().beginFileTransfer(f, request)
# after the file has been sent, clean up and finish the request
def cbFinished(ignored):
f.close()
request.finish()
d.addCallback(cbFinished)
else:
respond_with_json_bytes(
request,
404,
json.dumps(cs_error("Not found", code=Codes.NOT_FOUND)),
send_cors=True)
return server.NOT_DONE_YET
def render_POST(self, request):
self._async_render(request)
return server.NOT_DONE_YET
def render_OPTIONS(self, request):
respond_with_json_bytes(request, 200, {}, send_cors=True)
return server.NOT_DONE_YET
@defer.inlineCallbacks
def _async_render(self, request):
try:
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
if content_length is None:
raise SynapseError(
msg="Request must specify a Content-Length", code=400
)
if int(content_length) > self.max_upload_size:
raise SynapseError(
msg="Upload request body is too large",
code=413,
)
fname = yield self.map_request_to_name(request)
# TODO I have a suspicious feeling this is just going to block
with open(fname, "wb") as f:
f.write(request.content.read())
# FIXME (erikj): These should use constants.
file_name = os.path.basename(fname)
# FIXME: we can't assume what the repo's public mounted path is
# ...plus self-signed SSL won't work to remote clients anyway
# ...and we can't assume that it's SSL anyway, as we might want to
# serve it via the non-SSL listener...
url = "%s/_matrix/content/%s" % (
self.external_addr, file_name
)
respond_with_json_bytes(request, 200,
json.dumps({"content_token": url}),
send_cors=True)
except CodeMessageException as e:
logger.exception(e)
respond_with_json_bytes(request, e.code,
json.dumps(cs_exception(e)))
except Exception as e:
logger.error("Failed to store file: %s" % e)
respond_with_json_bytes(
request,
500,
json.dumps({"error": "Internal server error"}),
send_cors=True)

View file

@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import PIL.Image
# check for JPEG support.
try:
PIL.Image._getdecoder("rgb", "jpeg", None)
except IOError as e:
if str(e).startswith("decoder jpeg not available"):
raise Exception(
"FATAL: jpeg codec not supported. Install pillow correctly! "
" 'sudo apt-get install libjpeg-dev' then 'pip uninstall pillow &&"
" pip install pillow --user'"
)
except Exception:
# any other exception is fine
pass
# check for PNG support.
try:
PIL.Image._getdecoder("rgb", "zip", None)
except IOError as e:
if str(e).startswith("decoder zip not available"):
raise Exception(
"FATAL: zip codec not supported. Install pillow correctly! "
" 'sudo apt-get install libjpeg-dev' then 'pip uninstall pillow &&"
" pip install pillow --user'"
)
except Exception:
# any other exception is fine
pass

View file

@ -0,0 +1,378 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .thumbnailer import Thumbnailer
from synapse.http.server import respond_with_json
from synapse.util.stringutils import random_string
from synapse.api.errors import (
cs_exception, CodeMessageException, cs_error, Codes, SynapseError
)
from twisted.internet import defer
from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender
import os
import logging
logger = logging.getLogger(__name__)
class BaseMediaResource(Resource):
isLeaf = True
def __init__(self, hs, filepaths):
Resource.__init__(self)
self.auth = hs.get_auth()
self.client = hs.get_http_client()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
self.filepaths = filepaths
self.downloads = {}
@staticmethod
def catch_errors(request_handler):
@defer.inlineCallbacks
def wrapped_request_handler(self, request):
try:
yield request_handler(self, request)
except CodeMessageException as e:
logger.exception(e)
respond_with_json(
request, e.code, cs_exception(e), send_cors=True
)
except:
logger.exception(
"Failed handle request %s.%s on %r",
request_handler.__module__,
request_handler.__name__,
self,
)
respond_with_json(
request,
500,
{"error": "Internal server error"},
send_cors=True
)
return wrapped_request_handler
@staticmethod
def _parse_media_id(request):
try:
server_name, media_id = request.postpath
return (server_name, media_id)
except:
raise SynapseError(
404,
"Invalid media id token %r" % (request.postpath,),
Codes.UNKKOWN,
)
@staticmethod
def _parse_integer(request, arg_name, default=None):
try:
if default is None:
return int(request.args[arg_name][0])
else:
return int(request.args.get(arg_name, [default])[0])
except:
raise SynapseError(
400,
"Missing integer argument %r" % (arg_name,),
Codes.UNKNOWN,
)
@staticmethod
def _parse_string(request, arg_name, default=None):
try:
if default is None:
return request.args[arg_name][0]
else:
return request.args.get(arg_name, [default])[0]
except:
raise SynapseError(
400,
"Missing string argument %r" % (arg_name,),
Codes.UNKNOWN,
)
def _respond_404(self, request):
respond_with_json(
request, 404,
cs_error(
"Not found %r" % (request.postpath,),
code=Codes.NOT_FOUND,
),
send_cors=True
)
@staticmethod
def _makedirs(filepath):
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
os.makedirs(dirname)
def _get_remote_media(self, server_name, media_id):
key = (server_name, media_id)
download = self.downloads.get(key)
if download is None:
download = self._get_remote_media_impl(server_name, media_id)
self.downloads[key] = download
@download.addBoth
def callback(media_info):
del self.downloads[key]
return media_info
return download
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
if not media_info:
media_info = yield self._download_remote_file(
server_name, media_id
)
defer.returnValue(media_info)
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id):
file_id = random_string(24)
fname = self.filepaths.remote_media_filepath(
server_name, file_id
)
self._makedirs(fname)
try:
with open(fname, "wb") as f:
request_path = "/".join((
"/_matrix/media/v1/download", server_name, media_id,
))
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
max_size=self.max_upload_size,
)
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=None,
media_length=length,
filesystem_id=file_id,
)
except:
os.remove(fname)
raise
media_info = {
"media_type": media_type,
"media_length": length,
"upload_name": None,
"created_ts": time_now_ms,
"filesystem_id": file_id,
}
yield self._generate_remote_thumbnails(
server_name, media_id, media_info
)
defer.returnValue(media_info)
@defer.inlineCallbacks
def _respond_with_file(self, request, media_type, file_path,
file_size=None):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
)
if file_size is None:
stat = os.stat(file_path)
file_size = stat.st_size
request.setHeader(
b"Content-Length", b"%d" % (file_size,)
)
with open(file_path, "rb") as f:
yield FileSender().beginFileTransfer(f, request)
request.finish()
else:
self._respond_404(request)
def _get_thumbnail_requirements(self, media_type):
if media_type == "image/jpeg":
return (
(32, 32, "crop", "image/jpeg"),
(96, 96, "crop", "image/jpeg"),
(320, 240, "scale", "image/jpeg"),
(640, 480, "scale", "image/jpeg"),
)
elif (media_type == "image/png") or (media_type == "image/gif"):
return (
(32, 32, "crop", "image/png"),
(96, 96, "crop", "image/png"),
(320, 240, "scale", "image/png"),
(640, 480, "scale", "image/png"),
)
else:
return ()
@defer.inlineCallbacks
def _generate_local_thumbnails(self, media_id, media_info):
media_type = media_info["media_type"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
input_path = self.filepaths.local_media_filepath(media_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
defer.returnValue({
"width": m_width,
"height": m_height,
})
@defer.inlineCallbacks
def _generate_remote_thumbnails(self, server_name, media_id, media_info):
media_type = media_info["media_type"]
file_id = media_info["filesystem_id"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
defer.returnValue({
"width": m_width,
"height": m_height,
})

View file

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_resource import BaseMediaResource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class DownloadResource(BaseMediaResource):
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@BaseMediaResource.catch_errors
@defer.inlineCallbacks
def _async_render_GET(self, request):
try:
server_name, media_id = request.postpath
except:
self._respond_404(request)
return
if server_name == self.server_name:
yield self._respond_local_file(request, media_id)
else:
yield self._respond_remote_file(request, server_name, media_id)
@defer.inlineCallbacks
def _respond_local_file(self, request, media_id):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
self._respond_404(request)
return
media_type = media_info["media_type"]
media_length = media_info["media_length"]
file_path = self.filepaths.local_media_filepath(media_id)
yield self._respond_with_file(
request, media_type, file_path, media_length
)
@defer.inlineCallbacks
def _respond_remote_file(self, request, server_name, media_id):
media_info = yield self._get_remote_media(server_name, media_id)
media_type = media_info["media_type"]
media_length = media_info["media_length"]
filesystem_id = media_info["filesystem_id"]
file_path = self.filepaths.remote_media_filepath(
server_name, filesystem_id
)
yield self._respond_with_file(
request, media_type, file_path, media_length
)

View file

@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
class MediaFilePaths(object):
def __init__(self, base_path):
self.base_path = base_path
def default_thumbnail(self, default_top_level, default_sub_type, width,
height, content_type, method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
return os.path.join(
self.base_path, "default_thumbnails", default_top_level,
default_sub_type, file_name
)
def local_media_filepath(self, media_id):
return os.path.join(
self.base_path, "local_content",
media_id[0:2], media_id[2:4], media_id[4:]
)
def local_media_thumbnail(self, media_id, width, height, content_type,
method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
return os.path.join(
self.base_path, "local_thumbnails",
media_id[0:2], media_id[2:4], media_id[4:],
file_name
)
def remote_media_filepath(self, server_name, file_id):
return os.path.join(
self.base_path, "remote_content", server_name,
file_id[0:2], file_id[2:4], file_id[4:]
)
def remote_media_thumbnail(self, server_name, file_id, width, height,
content_type, method):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
self.base_path, "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:],
file_name
)

View file

@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .upload_resource import UploadResource
from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource
from .filepath import MediaFilePaths
from twisted.web.resource import Resource
import logging
logger = logging.getLogger(__name__)
class MediaRepositoryResource(Resource):
"""File uploading and downloading.
Uploads are POSTed to a resource which returns a token which is used to GET
the download::
=> POST /_matrix/media/v1/upload HTTP/1.1
Content-Type: <media-type>
<media>
<= HTTP/1.1 200 OK
Content-Type: application/json
{ "content_uri": "mxc://<server-name>/<media-id>" }
=> GET /_matrix/media/v1/download/<server-name>/<media-id> HTTP/1.1
<= HTTP/1.1 200 OK
Content-Type: <media-type>
Content-Disposition: attachment;filename=<upload-filename>
<media>
Clients can get thumbnails by supplying a desired width and height and
thumbnailing method::
=> GET /_matrix/media/v1/thumbnail/<server_name>
/<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1
<= HTTP/1.1 200 OK
Content-Type: image/jpeg or image/png
<thumbnail>
The thumbnail methods are "crop" and "scale". "scale" trys to return an
image where either the width or the height is smaller than the requested
size. The client should then scale and letterbox the image if it needs to
fit within a given rectangle. "crop" trys to return an image where the
width and height are close to the requested size and the aspect matches
the requested size. The client should scale the image if it needs to fit
within a given rectangle.
"""
def __init__(self, hs):
Resource.__init__(self)
filepaths = MediaFilePaths(hs.config.media_store_path)
self.putChild("upload", UploadResource(hs, filepaths))
self.putChild("download", DownloadResource(hs, filepaths))
self.putChild("thumbnail", ThumbnailResource(hs, filepaths))

View file

@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_resource import BaseMediaResource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class ThumbnailResource(BaseMediaResource):
isLeaf = True
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@BaseMediaResource.catch_errors
@defer.inlineCallbacks
def _async_render_GET(self, request):
server_name, media_id = self._parse_media_id(request)
width = self._parse_integer(request, "width")
height = self._parse_integer(request, "height")
method = self._parse_string(request, "method", "scale")
m_type = self._parse_string(request, "type", "image/png")
if server_name == self.server_name:
yield self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
)
else:
yield self._respond_remote_thumbnail(
request, server_name, media_id,
width, height, method, m_type
)
@defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height,
method, m_type):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
self._respond_404(request)
return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
t_width = thumbnail_info["thumbnail_width"]
t_height = thumbnail_info["thumbnail_height"]
t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"]
file_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method,
)
yield self._respond_with_file(request, t_type, file_path)
else:
yield self._respond_default_thumbnail(
request, media_info, width, height, method, m_type,
)
@defer.inlineCallbacks
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
height, method, m_type):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead.
media_info = yield self._get_remote_media(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
)
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
t_width = thumbnail_info["thumbnail_width"]
t_height = thumbnail_info["thumbnail_height"]
t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"]
file_id = thumbnail_info["filesystem_id"]
t_length = thumbnail_info["thumbnail_length"]
file_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method,
)
yield self._respond_with_file(request, t_type, file_path, t_length)
else:
yield self._respond_default_thumbnail(
request, media_info, width, height, method, m_type,
)
@defer.inlineCallbacks
def _respond_default_thumbnail(self, request, media_info, width, height,
method, m_type):
media_type = media_info["media_type"]
top_level_type = media_type.split("/")[0]
sub_type = media_type.split("/")[-1].split(";")[0]
thumbnail_infos = yield self.store.get_default_thumbnails(
top_level_type, sub_type,
)
if not thumbnail_infos:
thumbnail_infos = yield self.store.get_default_thumbnails(
top_level_type, "_default",
)
if not thumbnail_infos:
thumbnail_infos = yield self.store.get_default_thumbnails(
"_default", "_default",
)
if not thumbnail_infos:
self._respond_404(request)
return
thumbnail_info = self._select_thumbnail(
width, height, "crop", m_type, thumbnail_infos
)
t_width = thumbnail_info["thumbnail_width"]
t_height = thumbnail_info["thumbnail_height"]
t_type = thumbnail_info["thumbnail_type"]
t_method = thumbnail_info["thumbnail_method"]
t_length = thumbnail_info["thumbnail_length"]
file_path = self.filepaths.default_thumbnail(
top_level_type, sub_type, t_width, t_height, t_type, t_method,
)
yield self.respond_with_file(request, t_type, file_path, t_length)
def _select_thumbnail(self, desired_width, desired_height, desired_method,
desired_type, thumbnail_infos):
d_w = desired_width
d_h = desired_height
if desired_method.lower() == "crop":
info_list = []
for info in thumbnail_infos:
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
t_method = info["thumbnail_method"]
if t_method == "scale" or t_method == "crop":
aspect_quality = abs(d_w * t_h - d_h * t_w)
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
info_list.append((
aspect_quality, size_quality, type_quality,
length_quality, info
))
if info_list:
return min(info_list)[-1]
else:
info_list = []
info_list2 = []
for info in thumbnail_infos:
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
t_method = info["thumbnail_method"]
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
info_list.append((
size_quality, type_quality, length_quality, info
))
elif t_method == "scale":
info_list2.append((
size_quality, type_quality, length_quality, info
))
if info_list:
return min(info_list)[-1]
else:
return min(info_list2)[-1]

View file

@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import PIL.Image as Image
from io import BytesIO
class Thumbnailer(object):
FORMATS = {
"image/jpeg": "JPEG",
"image/png": "PNG",
}
def __init__(self, input_path):
self.image = Image.open(input_path)
self.width, self.height = self.image.size
def aspect(self, max_width, max_height):
"""Calculate the largest size that preserves aspect ratio which
fits within the given rectangle::
(w_in / h_in) = (w_out / h_out)
w_out = min(w_max, h_max * (w_in / h_in))
h_out = min(h_max, w_max * (h_in / w_in))
Args:
max_width: The largest possible width.
max_height: The larget possible height.
"""
if max_width * self.height < max_height * self.width:
return (max_width, (max_width * self.height) // self.width)
else:
return ((max_height * self.width) // self.height, max_height)
def scale(self, output_path, width, height, output_type):
"""Rescales the image to the given dimensions"""
scaled = self.image.resize((width, height), Image.ANTIALIAS)
return self.save_image(scaled, output_type, output_path)
def crop(self, output_path, width, height, output_type):
"""Rescales and crops the image to the given dimensions preserving
aspect::
(w_in / h_in) = (w_scaled / h_scaled)
w_scaled = max(w_out, h_out * (w_in / h_in))
h_scaled = max(h_out, w_out * (h_in / w_in))
Args:
max_width: The largest possible width.
max_height: The larget possible height.
"""
if width * self.height > height * self.width:
scaled_height = (width * self.height) // self.width
scaled_image = self.image.resize(
(width, scaled_height), Image.ANTIALIAS
)
crop_top = (scaled_height - height) // 2
crop_bottom = height + crop_top
cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
else:
scaled_width = (height * self.width) // self.height
scaled_image = self.image.resize(
(scaled_width, height), Image.ANTIALIAS
)
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self.save_image(cropped, output_type, output_path)
def save_image(self, output_image, output_type, output_path):
output_bytes_io = BytesIO()
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=70)
output_bytes = output_bytes_io.getvalue()
with open(output_path, "wb") as output_file:
output_file.write(output_bytes)
return len(output_bytes)

View file

@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json
from synapse.util.stringutils import random_string
from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException
)
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from .base_resource import BaseMediaResource
import logging
logger = logging.getLogger(__name__)
class UploadResource(BaseMediaResource):
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
def render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
@defer.inlineCallbacks
def _async_render_POST(self, request):
try:
auth_user = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
if content_length is None:
raise SynapseError(
msg="Request must specify a Content-Length", code=400
)
if int(content_length) > self.max_upload_size:
raise SynapseError(
msg="Upload request body is too large",
code=413,
)
headers = request.requestHeaders
if headers.hasHeader("Content-Type"):
media_type = headers.getRawHeaders("Content-Type")[0]
else:
raise SynapseError(
msg="Upload request missing 'Content-Type'",
code=400,
)
#if headers.hasHeader("Content-Disposition"):
# disposition = headers.getRawHeaders("Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
media_id = random_string(24)
fname = self.filepaths.local_media_filepath(media_id)
self._makedirs(fname)
# This shouldn't block for very long because the content will have
# already been uploaded at this point.
with open(fname, "wb") as f:
f.write(request.content.read())
yield self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=None,
media_length=content_length,
user_id=auth_user,
)
media_info = {
"media_type": media_type,
"media_length": content_length,
}
yield self._generate_local_thumbnails(media_id, media_info)
content_uri = "mxc://%s/%s" % (self.server_name, media_id)
respond_with_json(
request, 200, {"content_uri": content_uri}, send_cors=True
)
except CodeMessageException as e:
logger.exception(e)
respond_with_json(request, e.code, cs_exception(e), send_cors=True)
except:
logger.exception("Failed to store file")
respond_with_json(
request,
500,
{"error": "Internal server error"},
send_cors=True
)