Merge remote-tracking branch 'origin/develop' into dbkr/notifications_api

This commit is contained in:
David Baker 2016-08-11 14:09:13 +01:00
commit b4ecf0b886
208 changed files with 9809 additions and 3566 deletions

View file

@ -47,6 +47,7 @@ from synapse.rest.client.v2_alpha import (
report_event,
openid,
notifications,
devices,
)
from synapse.http.server import JsonResource
@ -92,3 +93,4 @@ class ClientRestResource(JsonResource):
report_event.register_servlets(hs, client_resource)
openid.register_servlets(hs, client_resource)
notifications.register_servlets(hs, client_resource)
devices.register_servlets(hs, client_resource)

View file

@ -46,5 +46,82 @@ class WhoisRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret))
class PurgeMediaCacheRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/purge_media_cache")
def __init__(self, hs):
self.media_repository = hs.get_media_repository()
super(PurgeMediaCacheRestServlet, self).__init__(hs)
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
before_ts = request.args.get("before_ts", None)
if not before_ts:
raise SynapseError(400, "Missing 'before_ts' arg")
logger.info("before_ts: %r", before_ts[0])
try:
before_ts = int(before_ts[0])
except Exception:
raise SynapseError(400, "Invalid 'before_ts' arg")
ret = yield self.media_repository.delete_old_remote_media(before_ts)
defer.returnValue((200, ret))
class PurgeHistoryRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
)
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
yield self.handlers.message_handler.purge_history(room_id, event_id)
defer.returnValue((200, {}))
class DeactivateAccountRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
super(DeactivateAccountRestServlet, self).__init__(hs)
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
# FIXME: Theoretically there is a race here wherein user resets password
# using threepid.
yield self.store.user_delete_access_tokens(target_user_id)
yield self.store.user_delete_threepids(target_user_id)
yield self.store.user_set_password_hash(target_user_id, None)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)

View file

@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
"""
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()

View file

@ -45,30 +45,27 @@ class EventStreamRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Guest users must specify room_id param")
if "room_id" in request.args:
room_id = request.args["room_id"][0]
try:
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
try:
timeout = int(request.args["timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")
as_client_event = "raw" not in request.args
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
try:
timeout = int(request.args["timeout"][0])
except ValueError:
raise SynapseError(400, "timeout must be in milliseconds.")
chunk = yield handler.get_stream(
requester.user.to_string(),
pagin_config,
timeout=timeout,
as_client_event=as_client_event,
affect_presence=(not is_guest),
room_id=room_id,
is_guest=is_guest,
)
except:
logger.exception("Event stream failed")
raise
as_client_event = "raw" not in request.args
chunk = yield handler.get_stream(
requester.user.to_string(),
pagin_config,
timeout=timeout,
as_client_event=as_client_event,
affect_presence=(not is_guest),
room_id=room_id,
is_guest=is_guest,
)
defer.returnValue((200, chunk))

View file

@ -54,10 +54,8 @@ class LoginRestServlet(ClientV1RestServlet):
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler()
def on_GET(self, request):
flows = []
@ -108,17 +106,6 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.JWT_TYPE):
result = yield self.do_jwt_login(login_submission)
defer.returnValue(result)
# TODO Delete this after all CAS clients switch to token login instead
elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
uri = "%s/proxyValidate" % (self.cas_server_url,)
args = {
"ticket": login_submission["ticket"],
"service": login_submission["service"]
}
body = yield self.http_client.get_raw(uri, args)
result = yield self.do_cas_login(body)
defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
@ -143,16 +130,24 @@ class LoginRestServlet(ClientV1RestServlet):
user_id, self.hs.hostname
).to_string()
auth_handler = self.handlers.auth_handler
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
auth_handler = self.auth_handler
user_id = yield auth_handler.validate_password_login(
user_id=user_id,
password=login_submission["password"])
password=login_submission["password"],
)
device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
defer.returnValue((200, result))
@ -160,65 +155,27 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_token_login(self, login_submission):
token = login_submission['token']
auth_handler = self.handlers.auth_handler
auth_handler = self.auth_handler
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
device_id = yield self._register_device(user_id, login_submission)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
user_id, device_id,
login_submission.get("initial_device_display_name")
)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
"device_id": device_id,
}
defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
@defer.inlineCallbacks
def do_cas_login(self, cas_response_body):
user, attributes = self.parse_cas_response(cas_response_body)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
@defer.inlineCallbacks
def do_jwt_login(self, login_submission):
token = login_submission.get("token", None)
@ -243,19 +200,28 @@ class LoginRestServlet(ClientV1RestServlet):
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
device_id = yield self._register_device(
registered_user_id, login_submission
)
access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(
registered_user_id, device_id,
login_submission.get("initial_device_display_name")
)
)
result = {
"user_id": user_id, # may have changed
"user_id": registered_user_id,
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
# TODO: we should probably check that the register isn't going
# to fonx/change our user_id before registering the device
device_id = yield self._register_device(user_id, login_submission)
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
@ -267,32 +233,25 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
def _register_device(self, user_id, login_submission):
"""Register a device for a user.
return (user, attributes)
This is called after the user's credentials have been validated, but
before the access token has been issued.
Args:
(str) user_id: full canonical @user:id
(object) login_submission: dictionary supplied to /login call, from
which we pull device_id and initial_device_name
Returns:
defer.Deferred: (str) device_id
"""
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get(
"initial_device_display_name")
return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
class SAML2RestServlet(ClientV1RestServlet):
@ -338,18 +297,6 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"}))
# TODO Delete this after all CAS clients switch to token login instead
class CasRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas", releases=())
def __init__(self, hs):
super(CasRestServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
def on_GET(self, request):
return (200, {"serverUrl": self.cas_server_url})
class CasRedirectServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
@ -381,6 +328,7 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def on_GET(self, request):
@ -412,14 +360,14 @@ class CasTicketServlet(ClientV1RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists:
user_id, _ = (
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if not registered_user_id:
registered_user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user)
)
login_token = auth_handler.generate_short_term_login_token(user_id)
login_token = auth_handler.generate_short_term_login_token(registered_user_id)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)
@ -433,30 +381,39 @@ class CasTicketServlet(ClientV1RestServlet):
return urlparse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
user = None
attributes = None
try:
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
success = (root[0].tag.endswith("authenticationSuccess"))
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in
# attribute tags to the full URL of the namespace.
# We don't care about namespace here and it will always
# be encased in curly braces, so we remove them.
tag = attribute.tag
if "}" in tag:
tag = tag.split("}")[1]
attributes[tag] = attribute.text
if user is None:
raise Exception("CAS response does not contain user")
if attributes is None:
raise Exception("CAS response does not contain attributes")
except Exception:
logger.error("Error parsing CAS response", exc_info=1)
raise LoginError(401, "Invalid CAS response",
errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(401, "Unsuccessful CAS response",
errcode=Codes.UNAUTHORIZED)
return user, attributes
def register_servlets(hs, http_server):
@ -466,5 +423,3 @@ def register_servlets(hs, http_server):
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)

View file

@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
rawrules = yield self.store.get_push_rules_for_user(user_id)
rules = yield self.store.get_push_rules_for_user(user_id)
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
rules = format_push_rules_for_user(requester.user, rules)
path = request.postpath[1:]

View file

@ -17,7 +17,11 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.servlet import (
parse_json_object_from_request, parse_string, RestServlet
)
from synapse.http.server import finish_request
from synapse.api.errors import StoreError
from .base import ClientV1RestServlet, client_path_patterns
@ -136,6 +140,57 @@ class PushersSetRestServlet(ClientV1RestServlet):
return 200, {}
class PushersRemoveRestServlet(RestServlet):
"""
To allow pusher to be delete by clicking a link (ie. GET request)
"""
PATTERNS = client_path_patterns("/pushers/remove$")
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
super(RestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
self.auth = hs.get_v1auth()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user
app_id = parse_string(request, "app_id", required=True)
pushkey = parse_string(request, "pushkey", required=True)
pusher_pool = self.hs.get_pusherpool()
try:
yield pusher_pool.remove_pusher(
app_id=app_id,
pushkey=pushkey,
user_id=user.to_string(),
)
except StoreError as se:
if se.code != 404:
# This is fine: they're already unsubscribed
raise
self.notifier.on_new_replication_data()
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (
len(PushersRemoveRestServlet.SUCCESS_HTML),
))
request.write(PushersRemoveRestServlet.SUCCESS_HTML)
finish_request(request)
defer.returnValue(None)
def on_OPTIONS(self, _):
return 200, {}
def register_servlets(hs, http_server):
PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server)
PushersRemoveRestServlet(hs).register(http_server)

View file

@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(RegisterRestServlet, self).__init__(hs)
# sessions are stored as:
# self.sessions = {
@ -60,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage
self.sessions = {}
self.enable_registration = hs.config.enable_registration
self.auth_handler = hs.get_auth_handler()
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
@ -299,9 +304,10 @@ class RegisterRestServlet(ClientV1RestServlet):
user_localpart = register_json["user"].encode("utf-8")
handler = self.handlers.registration_handler
(user_id, token) = yield handler.appservice_register(
user_id = yield handler.appservice_register(
user_localpart, as_token
)
token = yield self.auth_handler.issue_access_token(user_id)
self._remove_session(session)
defer.returnValue({
"user_id": user_id,
@ -324,6 +330,14 @@ class RegisterRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Shared secret registration is not enabled")
user = register_json["user"].encode("utf-8")
password = register_json["password"].encode("utf-8")
admin = register_json.get("admin", None)
# Its important to check as we use null bytes as HMAC field separators
if "\x00" in user:
raise SynapseError(400, "Invalid user")
if "\x00" in password:
raise SynapseError(400, "Invalid password")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
@ -331,17 +345,21 @@ class RegisterRestServlet(ClientV1RestServlet):
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
msg=user,
digestmod=sha1,
).hexdigest()
password = register_json["password"].encode("utf-8")
)
want_mac.update(user)
want_mac.update("\x00")
want_mac.update(password)
want_mac.update("\x00")
want_mac.update("admin" if admin else "notadmin")
want_mac = want_mac.hexdigest()
if compare_digest(want_mac, got_mac):
handler = self.handlers.registration_handler
user_id, token = yield handler.register(
localpart=user,
password=password,
admin=bool(admin),
)
self._remove_session(session)
defer.returnValue({
@ -410,12 +428,15 @@ class CreateUserRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Failed to parse 'duration_seconds'")
if duration_seconds > self.direct_user_creation_max_duration:
duration_seconds = self.direct_user_creation_max_duration
password_hash = user_json["password_hash"].encode("utf-8") \
if user_json.get("password_hash") else None
handler = self.handlers.registration_handler
user_id, token = yield handler.get_or_create_user(
localpart=localpart,
displayname=displayname,
duration_seconds=duration_seconds
duration_in_ms=(duration_seconds * 1000),
password_hash=password_hash
)
defer.returnValue({

View file

@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns
from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import Filter
from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event
from synapse.http.servlet import parse_json_object_from_request
import logging
import urllib
import ujson as json
logger = logging.getLogger(__name__)
@ -72,8 +74,6 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def get_room_config(self, request):
user_supplied_config = parse_json_object_from_request(request)
# default visibility
user_supplied_config.setdefault("visibility", "public")
return user_supplied_config
def on_OPTIONS(self, request):
@ -279,8 +279,16 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
handler = self.handlers.room_list_handler
data = yield handler.get_public_room_list()
try:
yield self.auth.get_user_by_req(request)
except AuthError:
# This endpoint isn't authed, but its useful to know who's hitting
# it if they *do* supply an access token
pass
handler = self.hs.get_room_list_handler()
data = yield handler.get_aggregated_public_room_list()
defer.returnValue((200, data))
@ -321,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
request, default_limit=10,
)
as_client_event = "raw" not in request.args
filter_bytes = request.args.get("filter", None)
if filter_bytes:
filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
handler = self.handlers.message_handler
msgs = yield handler.get_messages(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
as_client_event=as_client_event
as_client_event=as_client_event,
event_filter=event_filter,
)
defer.returnValue((200, msgs))

View file

@ -25,7 +25,9 @@ import logging
logger = logging.getLogger(__name__)
def client_v2_patterns(path_regex, releases=(0,)):
def client_v2_patterns(path_regex, releases=(0,),
v2_alpha=True,
unstable=True):
"""Creates a regex compiled client path with the correct client path
prefix.
@ -35,9 +37,12 @@ def client_v2_patterns(path_regex, releases=(0,)):
Returns:
SRE_Pattern
"""
patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)]
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
patterns.append(re.compile("^" + unstable_prefix + path_regex))
patterns = []
if v2_alpha:
patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
if unstable:
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
patterns.append(re.compile("^" + new_prefix + path_regex))

View file

@ -28,14 +28,46 @@ import logging
logger = logging.getLogger(__name__)
class PasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
super(PasswordRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is None:
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class PasswordRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/password")
PATTERNS = client_v2_patterns("/account/password$")
def __init__(self, hs):
super(PasswordRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def on_POST(self, request):
@ -89,15 +121,90 @@ class PasswordRestServlet(RestServlet):
return 200, {}
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/deactivate$")
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
super(DeactivateAccountRestServlet, self).__init__()
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
authed, result, params, _ = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
user_id = None
requester = None
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
if user_id != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
# FIXME: Theoretically there is a race here wherein user resets password
# using threepid.
yield self.store.user_delete_access_tokens(user_id)
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
defer.returnValue((200, {}))
class ThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
self.hs = hs
super(ThreepidRequestTokenRestServlet, self).__init__()
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class ThreepidRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid")
PATTERNS = client_v2_patterns("/account/3pid$")
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def on_GET(self, request):
@ -157,5 +264,8 @@ class ThreepidRestServlet(RestServlet):
def register_servlets(hs, http_server):
PasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
ThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)

View file

@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
super(AuthRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
@defer.inlineCallbacks

View file

@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http import servlet
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
class DevicesRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(DevicesRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
devices = yield self.device_handler.get_devices_by_user(
requester.user.to_string()
)
defer.returnValue((200, {"devices": devices}))
class DeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
releases=[], v2_alpha=False)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(DeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
device = yield self.device_handler.get_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, device))
@defer.inlineCallbacks
def on_DELETE(self, request, device_id):
# XXX: it's not completely obvious we want to expose this endpoint.
# It allows the client to delete access tokens, which feels like a
# thing which merits extra auth. But if we want to do the interactive-
# auth dance, we should really make it possible to delete more than one
# device at a time.
requester = yield self.auth.get_user_by_req(request)
yield self.device_handler.delete_device(
requester.user.to_string(),
device_id,
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
body = servlet.parse_json_object_from_request(request)
yield self.device_handler.update_device(
requester.user.to_string(),
device_id,
body
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
DevicesRestServlet(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)

View file

@ -13,24 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import simplejson as json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
import synapse.api.errors
import synapse.server
import synapse.types
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
from canonicaljson import encode_canonical_json
from ._base import client_v2_patterns
import logging
import simplejson as json
logger = logging.getLogger(__name__)
class KeyUploadServlet(RestServlet):
"""
POST /keys/upload/<device_id> HTTP/1.1
POST /keys/upload HTTP/1.1
Content-Type: application/json
{
@ -53,23 +54,45 @@ class KeyUploadServlet(RestServlet):
},
}
"""
PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
releases=())
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(KeyUploadServlet, self).__init__()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_POST(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead.
body = parse_json_object_from_request(request)
if device_id is not None:
# passing the device_id here is deprecated; however, we allow it
# for now for compatibility with older clients.
if (requester.device_id is not None and
device_id != requester.device_id):
logger.warning("Client uploading keys for a different device "
"(logged in as %s, uploading for %s)",
requester.device_id, device_id)
else:
device_id = requester.device_id
if device_id is None:
raise synapse.api.errors.SynapseError(
400,
"To upload keys, you must pass device_id when authenticating"
)
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
@ -102,13 +125,12 @@ class KeyUploadServlet(RestServlet):
user_id, device_id, time_now, key_list
)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
@defer.inlineCallbacks
def on_GET(self, request, device_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
@ -162,17 +184,19 @@ class KeyQueryServlet(RestServlet):
)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
super(KeyQueryServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id):
yield self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
result = yield self.handle_request(body)
result = yield self.e2e_keys_handler.query_devices(body)
defer.returnValue(result)
@defer.inlineCallbacks
@ -181,45 +205,11 @@ class KeyQueryServlet(RestServlet):
auth_user_id = requester.user.to_string()
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
result = yield self.handle_request(
result = yield self.e2e_keys_handler.query_devices(
{"device_keys": {user_id: device_ids}}
)
defer.returnValue(result)
@defer.inlineCallbacks
def handle_request(self, body):
local_query = []
remote_queries = {}
for user_id, device_ids in body.get("device_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
else:
remote_queries.setdefault(user.domain, {})[user_id] = list(
device_ids
)
results = yield self.store.get_e2e_device_keys(local_query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.query_client_keys(
destination, {"device_keys": device_keys}
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
defer.returnValue((200, {"device_keys": json_result}))
class OneTimeKeyServlet(RestServlet):
"""

View file

@ -41,17 +41,59 @@ else:
logger = logging.getLogger(__name__)
class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register")
class RegisterRequestTokenRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register/email/requestToken$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(RegisterRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
class RegisterRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/register$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
super(RegisterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.auth_handler = hs.get_handlers().auth_handler
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_POST(self, request):
@ -70,10 +112,6 @@ class RegisterRestServlet(RestServlet):
"Do not understand membership kind: %s" % (kind,)
)
if '/register/email/requestToken' in request.path:
ret = yield self.onEmailTokenRequest(request)
defer.returnValue(ret)
body = parse_json_object_from_request(request)
# we do basic sanity checks here because the auth layer will store these
@ -104,11 +142,12 @@ class RegisterRestServlet(RestServlet):
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
if isinstance(body.get("user"), basestring):
desired_username = body["user"]
result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0]
)
desired_username = body.get("user", desired_username)
if isinstance(desired_username, basestring):
result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0], body
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
@ -117,7 +156,7 @@ class RegisterRestServlet(RestServlet):
# FIXME: Should we really be determining if this is shared secret
# auth based purely on the 'mac' key?
result = yield self._do_shared_secret_registration(
desired_username, desired_password, body["mac"]
desired_username, desired_password, body
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
@ -157,12 +196,12 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY]
]
authed, result, params, session_id = yield self.auth_handler.check_auth(
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not authed:
defer.returnValue((401, result))
defer.returnValue((401, auth_result))
return
if registered_user_id is not None:
@ -170,106 +209,58 @@ class RegisterRestServlet(RestServlet):
"Already registered user ID %r for this session",
registered_user_id
)
access_token = yield self.auth_handler.issue_access_token(registered_user_id)
refresh_token = yield self.auth_handler.issue_refresh_token(
registered_user_id
# don't re-register the email address
add_email = False
else:
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
raise SynapseError(400, "Missing password.",
Codes.MISSING_PARAM)
desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
guest_access_token=guest_access_token,
generate_token=False,
)
defer.returnValue((200, {
"user_id": registered_user_id,
"access_token": access_token,
"home_server": self.hs.hostname,
"refresh_token": refresh_token,
}))
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
)
desired_username = params.get("username", None)
new_password = params.get("password", None)
guest_access_token = params.get("guest_access_token", None)
add_email = True
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
guest_access_token=guest_access_token,
return_dict = yield self._create_registration_details(
registered_user_id, params
)
# remember that we've now registered that user account, and with what
# user ID (since the user may not have specified)
self.auth_handler.set_session_data(
session_id, "registered_user_id", user_id
)
if add_email and auth_result and LoginType.EMAIL_IDENTITY in auth_result:
threepid = auth_result[LoginType.EMAIL_IDENTITY]
yield self._register_email_threepid(
registered_user_id, threepid, return_dict["access_token"],
params.get("bind_email")
)
if result and LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid:
logger.info("Can't add incomplete 3pid")
else:
yield self.auth_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
# And we add an email pusher for them by default, but only
# if email notifications are enabled (so people don't start
# getting mail spam where they weren't before if email
# notifs are set up on a home server)
if (
self.hs.config.email_enable_notifs and
self.hs.config.email_notif_for_new_users
):
# Pull the ID of the access token back out of the db
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = yield self.store.get_user_by_access_token(
token
)
yield self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=user_tuple["token_id"],
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
device_display_name=threepid["address"],
pushkey=threepid["address"],
lang=None, # We don't know a user's language here
data={},
)
if 'bind_email' in params and params['bind_email']:
logger.info("bind_email specified: binding")
emailThreepid = result[LoginType.EMAIL_IDENTITY]
threepid_creds = emailThreepid['threepid_creds']
logger.debug("Binding emails %s to %s" % (
emailThreepid, user_id
))
yield self.identity_handler.bind_threepid(threepid_creds, user_id)
else:
logger.info("bind_email not specified: not binding email")
result = yield self._create_registration_details(user_id, token)
defer.returnValue((200, result))
defer.returnValue((200, return_dict))
def on_OPTIONS(self, _):
return 200, {}
@defer.inlineCallbacks
def _do_appservice_registration(self, username, as_token):
(user_id, token) = yield self.registration_handler.appservice_register(
def _do_appservice_registration(self, username, as_token, body):
user_id = yield self.registration_handler.appservice_register(
username, as_token
)
defer.returnValue((yield self._create_registration_details(user_id, token)))
defer.returnValue((yield self._create_registration_details(user_id, body)))
@defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, mac):
def _do_shared_secret_registration(self, username, password, body):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@ -277,7 +268,7 @@ class RegisterRestServlet(RestServlet):
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(mac)
got_mac = str(body["mac"])
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
@ -290,43 +281,132 @@ class RegisterRestServlet(RestServlet):
403, "HMAC incorrect",
)
(user_id, token) = yield self.registration_handler.register(
localpart=username, password=password
(user_id, _) = yield self.registration_handler.register(
localpart=username, password=password, generate_token=False,
)
defer.returnValue((yield self._create_registration_details(user_id, token)))
result = yield self._create_registration_details(user_id, body)
defer.returnValue(result)
@defer.inlineCallbacks
def _create_registration_details(self, user_id, token):
refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
def _register_email_threepid(self, user_id, threepid, token, bind_email):
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
HS config
Also optionally binds emails to the given user_id on the identity server
Args:
user_id (str): id of user
threepid (object): m.login.email.identity auth response
token (str): access_token for the user
bind_email (bool): true if the client requested the email to be
bound at the identity server
Returns:
defer.Deferred:
"""
reqd = ('medium', 'address', 'validated_at')
if any(x not in threepid for x in reqd):
logger.info("Can't add incomplete 3pid")
defer.returnValue()
yield self.auth_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
# And we add an email pusher for them by default, but only
# if email notifications are enabled (so people don't start
# getting mail spam where they weren't before if email
# notifs are set up on a home server)
if (self.hs.config.email_enable_notifs and
self.hs.config.email_notif_for_new_users):
# Pull the ID of the access token back out of the db
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = yield self.store.get_user_by_access_token(
token
)
token_id = user_tuple["token_id"]
yield self.hs.get_pusherpool().add_pusher(
user_id=user_id,
access_token=token_id,
kind="email",
app_id="m.email",
app_display_name="Email Notifications",
device_display_name=threepid["address"],
pushkey=threepid["address"],
lang=None, # We don't know a user's language here
data={},
)
if bind_email:
logger.info("bind_email specified: binding")
logger.debug("Binding emails %s to %s" % (
threepid, user_id
))
yield self.identity_handler.bind_threepid(
threepid['threepid_creds'], user_id
)
else:
logger.info("bind_email not specified: not binding email")
@defer.inlineCallbacks
def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token
and refresh_token.
Args:
(str) user_id: full canonical @user:id
(object) params: registration parameters, from which we pull
device_id and initial_device_name
Returns:
defer.Deferred: (object) dictionary for response from /register
"""
device_id = yield self._register_device(user_id, params)
access_token, refresh_token = (
yield self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id,
initial_display_name=params.get("initial_device_display_name")
)
)
defer.returnValue({
"user_id": user_id,
"access_token": token,
"access_token": access_token,
"home_server": self.hs.hostname,
"refresh_token": refresh_token,
"device_id": device_id,
})
@defer.inlineCallbacks
def onEmailTokenRequest(self, request):
body = parse_json_object_from_request(request)
def _register_device(self, user_id, params):
"""Register a device for a user.
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
This is called after the user's credentials have been validated, but
before the access token has been issued.
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
Args:
(str) user_id: full canonical @user:id
(object) params: registration parameters, from which we pull
device_id and initial_device_name
Returns:
defer.Deferred: (str) device_id
"""
# register the user's device
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
device_id = self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
return device_id
@defer.inlineCallbacks
def _do_guest_registration(self):
@ -336,7 +416,11 @@ class RegisterRestServlet(RestServlet):
generate_token=False,
make_guest=True
)
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
access_token = self.auth_handler.generate_access_token(
user_id, ["guest = true"]
)
# XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
# so long as we don't return a refresh_token here.
defer.returnValue((200, {
"user_id": user_id,
"access_token": access_token,
@ -345,4 +429,5 @@ class RegisterRestServlet(RestServlet):
def register_servlets(hs, http_server):
RegisterRequestTokenRestServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)

View file

@ -38,10 +38,14 @@ class TokenRefreshRestServlet(RestServlet):
body = parse_json_object_from_request(request)
try:
old_refresh_token = body["refresh_token"]
auth_handler = self.hs.get_handlers().auth_handler
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token)
new_access_token = yield auth_handler.issue_access_token(user_id)
auth_handler = self.hs.get_auth_handler()
refresh_result = yield self.store.exchange_refresh_token(
old_refresh_token, auth_handler.generate_refresh_token
)
(user_id, new_refresh_token, device_id) = refresh_result
new_access_token = yield auth_handler.issue_access_token(
user_id, device_id
)
defer.returnValue((200, {
"access_token": new_access_token,
"refresh_token": new_refresh_token,

View file

@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet):
def on_GET(self, request):
return (200, {
"versions": ["r0.0.1"]
"versions": [
"r0.0.1",
"r0.1.0",
"r0.2.0",
]
})

View file

@ -15,6 +15,7 @@
from synapse.http.server import request_handler, respond_with_json_bytes
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.api.errors import SynapseError, Codes
from synapse.crypto.keyring import KeyLookupError
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@ -210,9 +211,10 @@ class RemoteKey(Resource):
yield self.keyring.get_server_verify_key_v2_direct(
server_name, key_ids
)
except KeyLookupError as e:
logger.info("Failed to fetch key: %s", e)
except:
logger.exception("Failed to get key for %r", server_name)
pass
yield self.query_keys(
request, query, query_remote_on_cache_miss=False
)

View file

@ -15,14 +15,12 @@
from synapse.http.server import respond_with_json_bytes, finish_request
from synapse.util.stringutils import random_string
from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, Codes, cs_error
Codes, cs_error
)
from twisted.protocols.basic import FileSender
from twisted.web import server, resource
from twisted.internet import defer
import base64
import simplejson as json
@ -50,64 +48,10 @@ class ContentRepoResource(resource.Resource):
"""
isLeaf = True
def __init__(self, hs, directory, auth, external_addr):
def __init__(self, hs, directory):
resource.Resource.__init__(self)
self.hs = hs
self.directory = directory
self.auth = auth
self.external_addr = external_addr.rstrip('/')
self.max_upload_size = hs.config.max_upload_size
if not os.path.isdir(self.directory):
os.mkdir(self.directory)
logger.info("ContentRepoResource : Created %s directory.",
self.directory)
@defer.inlineCallbacks
def map_request_to_name(self, request):
# auth the user
requester = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user
prefix = base64.urlsafe_b64encode(
requester.user.to_string()
).replace('=', '')
# use a random string for the main portion
main_part = random_string(24)
# suffix with a file extension if we can make one. This is nice to
# provide a hint to clients on the file information. We will also reuse
# this info to spit back the content type to the client.
suffix = ""
if request.requestHeaders.hasHeader("Content-Type"):
content_type = request.requestHeaders.getRawHeaders(
"Content-Type")[0]
suffix = "." + base64.urlsafe_b64encode(content_type)
if (content_type.split("/")[0].lower() in
["image", "video", "audio"]):
file_ext = content_type.split("/")[-1]
# be a little paranoid and only allow a-z
file_ext = re.sub("[^a-z]", "", file_ext)
suffix += "." + file_ext
file_name = prefix + main_part + suffix
file_path = os.path.join(self.directory, file_name)
logger.info("User %s is uploading a file to path %s",
request.user.user_id.to_string(),
file_path)
# keep trying to make a non-clashing file, with a sensible max attempts
attempts = 0
while os.path.exists(file_path):
main_part = random_string(24)
file_name = prefix + main_part + suffix
file_path = os.path.join(self.directory, file_name)
attempts += 1
if attempts > 25: # really? Really?
raise SynapseError(500, "Unable to create file.")
defer.returnValue(file_path)
def render_GET(self, request):
# no auth here on purpose, to allow anyone to view, even across home
@ -155,58 +99,6 @@ class ContentRepoResource(resource.Resource):
return server.NOT_DONE_YET
def render_POST(self, request):
self._async_render(request)
return server.NOT_DONE_YET
def render_OPTIONS(self, request):
respond_with_json_bytes(request, 200, {}, send_cors=True)
return server.NOT_DONE_YET
@defer.inlineCallbacks
def _async_render(self, request):
try:
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
if content_length is None:
raise SynapseError(
msg="Request must specify a Content-Length", code=400
)
if int(content_length) > self.max_upload_size:
raise SynapseError(
msg="Upload request body is too large",
code=413,
)
fname = yield self.map_request_to_name(request)
# TODO I have a suspicious feeling this is just going to block
with open(fname, "wb") as f:
f.write(request.content.read())
# FIXME (erikj): These should use constants.
file_name = os.path.basename(fname)
# FIXME: we can't assume what the repo's public mounted path is
# ...plus self-signed SSL won't work to remote clients anyway
# ...and we can't assume that it's SSL anyway, as we might want to
# serve it via the non-SSL listener...
url = "%s/_matrix/content/%s" % (
self.external_addr, file_name
)
respond_with_json_bytes(request, 200,
json.dumps({"content_token": url}),
send_cors=True)
except CodeMessageException as e:
logger.exception(e)
respond_with_json_bytes(request, e.code,
json.dumps(cs_exception(e)))
except Exception as e:
logger.error("Failed to store file: %s" % e)
respond_with_json_bytes(
request,
500,
json.dumps({"error": "Internal server error"}),
send_cors=True)

View file

@ -65,3 +65,9 @@ class MediaFilePaths(object):
file_id[0:2], file_id[2:4], file_id[4:],
file_name
)
def remote_media_thumbnail_dir(self, server_name, file_id):
return os.path.join(
self.base_path, "remote_thumbnail", server_name,
file_id[0:2], file_id[2:4], file_id[4:],
)

View file

@ -26,14 +26,17 @@ from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string
from synapse.api.errors import SynapseError
from twisted.internet import defer, threads
from synapse.util.async import ObservableDeferred
from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import preserve_context_over_fn
import os
import errno
import shutil
import cgi
import logging
@ -42,8 +45,11 @@ import urlparse
logger = logging.getLogger(__name__)
UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
class MediaRepository(object):
def __init__(self, hs, filepaths):
def __init__(self, hs):
self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
@ -51,11 +57,28 @@ class MediaRepository(object):
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
self.filepaths = filepaths
self.downloads = {}
self.filepaths = MediaFilePaths(hs.config.media_store_path)
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer()
self.recently_accessed_remotes = set()
self.clock.looping_call(
self._update_recently_accessed_remotes,
UPDATE_RECENTLY_ACCESSED_REMOTES_TS
)
@defer.inlineCallbacks
def _update_recently_accessed_remotes(self):
media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
yield self.store.update_cached_last_access_time(
media, self.clock.time_msec()
)
@staticmethod
def _makedirs(filepath):
dirname = os.path.dirname(filepath)
@ -92,22 +115,12 @@ class MediaRepository(object):
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@defer.inlineCallbacks
def get_remote_media(self, server_name, media_id):
key = (server_name, media_id)
download = self.downloads.get(key)
if download is None:
download = self._get_remote_media_impl(server_name, media_id)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.downloads[key] = download
@download.addBoth
def callback(media_info):
del self.downloads[key]
return media_info
return download.observe()
with (yield self.remote_media_linearizer.queue(key)):
media_info = yield self._get_remote_media_impl(server_name, media_id)
defer.returnValue(media_info)
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
@ -118,6 +131,11 @@ class MediaRepository(object):
media_info = yield self._download_remote_file(
server_name, media_id
)
else:
self.recently_accessed_remotes.add((server_name, media_id))
yield self.store.update_cached_last_access_time(
[(server_name, media_id)], self.clock.time_msec()
)
defer.returnValue(media_info)
@defer.inlineCallbacks
@ -134,10 +152,15 @@ class MediaRepository(object):
request_path = "/".join((
"/_matrix/media/v1/download", server_name, media_id,
))
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
max_size=self.max_upload_size,
)
try:
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
max_size=self.max_upload_size,
)
except Exception as e:
logger.warn("Failed to fetch remoted media %r", e)
raise SynapseError(502, "Failed to fetch remoted media")
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
@ -410,6 +433,41 @@ class MediaRepository(object):
"height": m_height,
})
@defer.inlineCallbacks
def delete_old_remote_media(self, before_ts):
old_media = yield self.store.get_remote_media_before(before_ts)
deleted = 0
for media in old_media:
origin = media["media_origin"]
media_id = media["media_id"]
file_id = media["filesystem_id"]
key = (origin, media_id)
logger.info("Deleting: %r", key)
with (yield self.remote_media_linearizer.queue(key)):
full_path = self.filepaths.remote_media_filepath(origin, file_id)
try:
os.remove(full_path)
except OSError as e:
logger.warn("Failed to remove file: %r", full_path)
if e.errno == errno.ENOENT:
pass
else:
continue
thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
origin, file_id
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)
yield self.store.delete_remote_media(origin, media_id)
deleted += 1
defer.returnValue({"deleted": deleted})
class MediaRepositoryResource(Resource):
"""File uploading and downloading.
@ -458,9 +516,8 @@ class MediaRepositoryResource(Resource):
def __init__(self, hs):
Resource.__init__(self)
filepaths = MediaFilePaths(hs.config.media_store_path)
media_repo = MediaRepository(hs, filepaths)
media_repo = hs.get_media_repository()
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))

View file

@ -29,6 +29,8 @@ from synapse.http.server import (
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
from copy import deepcopy
import os
import re
import fnmatch
@ -252,7 +254,8 @@ class PreviewUrlResource(Resource):
og = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
og[tag.attrib['property']] = tag.attrib['content']
if 'content' in tag.attrib:
og[tag.attrib['property']] = tag.attrib['content']
# TODO: grab article: meta tags too, e.g.:
@ -279,7 +282,7 @@ class PreviewUrlResource(Resource):
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(images, key=lambda i: (
-1 * int(i.attrib['width']) * int(i.attrib['height'])
-1 * float(i.attrib['width']) * float(i.attrib['height'])
))
if not images:
images = tree.xpath("//img[@src]")
@ -287,9 +290,9 @@ class PreviewUrlResource(Resource):
og['og:image'] = images[0].attrib['src']
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url request
# itself and benefit from the same caching etc. But for now we just rely on the
# caching on the master request to speed things up.
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
self._rebase_url(og['og:image'], media_info['uri']), requester.user
@ -328,20 +331,24 @@ class PreviewUrlResource(Resource):
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | "
"ancestor::aside | ancestor::footer | "
"ancestor::script | ancestor::style)]" +
"[ancestor::body]")
text = ''
for text_node in text_nodes:
if len(text) < 500:
text += text_node + ' '
else:
break
text = re.sub(r'[\t ]+', ' ', text)
text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text)
text = text.strip()[:500]
og['og:description'] = text if text else None
# We don't just use XPATH here as that is slow on some machines.
# We clone `tree` as we modify it.
cloned_tree = deepcopy(tree.find("body"))
TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
for el in cloned_tree.iter(TAGS_TO_REMOVE):
el.getparent().remove(el)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r'\s+', '\n', el.text).strip()
for el in cloned_tree.iter()
if el.text and isinstance(el.tag, basestring) # Removes comments
)
og['og:description'] = summarize_paragraphs(text_nodes)
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
@ -449,3 +456,56 @@ class PreviewUrlResource(Resource):
content_type.startswith("application/xhtml")
):
return True
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
# Try to get a summary of between 200 and 500 words, respecting
# first paragraph and then word boundaries.
# TODO: Respect sentences?
description = ''
# Keep adding paragraphs until we get to the MIN_SIZE.
for text_node in text_nodes:
if len(description) < min_size:
text_node = re.sub(r'[\t \r\n]+', ' ', text_node)
description += text_node + '\n\n'
else:
break
description = description.strip()
description = re.sub(r'[\t ]+', ' ', description)
description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description)
# If the concatenation of paragraphs to get above MIN_SIZE
# took us over MAX_SIZE, then we need to truncate mid paragraph
if len(description) > max_size:
new_desc = ""
# This splits the paragraph into words, but keeping the
# (preceeding) whitespace intact so we can easily concat
# words back together.
for match in re.finditer("\s*\S+", description):
word = match.group()
# Keep adding words while the total length is less than
# MAX_SIZE.
if len(word) + len(new_desc) < max_size:
new_desc += word
else:
# At this point the next word *will* take us over
# MAX_SIZE, but we also want to ensure that its not
# a huge word. If it is add it anyway and we'll
# truncate later.
if len(new_desc) < min_size:
new_desc += word
break
# Double check that we're not over the limit
if len(new_desc) > max_size:
new_desc = new_desc[:max_size]
# We always add an ellipsis because at the very least
# we chopped mid paragraph.
description = new_desc.strip() + ""
return description if description else None