mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-09-20 05:34:37 -04:00
Merge remote-tracking branch 'origin/develop' into dbkr/notifications_api
This commit is contained in:
commit
b4ecf0b886
208 changed files with 9809 additions and 3566 deletions
|
@ -47,6 +47,7 @@ from synapse.rest.client.v2_alpha import (
|
|||
report_event,
|
||||
openid,
|
||||
notifications,
|
||||
devices,
|
||||
)
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
|
@ -92,3 +93,4 @@ class ClientRestResource(JsonResource):
|
|||
report_event.register_servlets(hs, client_resource)
|
||||
openid.register_servlets(hs, client_resource)
|
||||
notifications.register_servlets(hs, client_resource)
|
||||
devices.register_servlets(hs, client_resource)
|
||||
|
|
|
@ -46,5 +46,82 @@ class WhoisRestServlet(ClientV1RestServlet):
|
|||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class PurgeMediaCacheRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/admin/purge_media_cache")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.media_repository = hs.get_media_repository()
|
||||
super(PurgeMediaCacheRestServlet, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
before_ts = request.args.get("before_ts", None)
|
||||
if not before_ts:
|
||||
raise SynapseError(400, "Missing 'before_ts' arg")
|
||||
|
||||
logger.info("before_ts: %r", before_ts[0])
|
||||
|
||||
try:
|
||||
before_ts = int(before_ts[0])
|
||||
except Exception:
|
||||
raise SynapseError(400, "Invalid 'before_ts' arg")
|
||||
|
||||
ret = yield self.media_repository.delete_old_remote_media(before_ts)
|
||||
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class PurgeHistoryRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns(
|
||||
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
yield self.handlers.message_handler.purge_history(room_id, event_id)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class DeactivateAccountRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
super(DeactivateAccountRestServlet, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, target_user_id):
|
||||
UserID.from_string(target_user_id)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
# FIXME: Theoretically there is a race here wherein user resets password
|
||||
# using threepid.
|
||||
yield self.store.user_delete_access_tokens(target_user_id)
|
||||
yield self.store.user_delete_threepids(target_user_id)
|
||||
yield self.store.user_set_password_hash(target_user_id, None)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
WhoisRestServlet(hs).register(http_server)
|
||||
PurgeMediaCacheRestServlet(hs).register(http_server)
|
||||
DeactivateAccountRestServlet(hs).register(http_server)
|
||||
PurgeHistoryRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
|
|||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
self.hs = hs
|
||||
self.handlers = hs.get_handlers()
|
||||
self.builder_factory = hs.get_event_builder_factory()
|
||||
|
|
|
@ -45,30 +45,27 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||
raise SynapseError(400, "Guest users must specify room_id param")
|
||||
if "room_id" in request.args:
|
||||
room_id = request.args["room_id"][0]
|
||||
try:
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if "timeout" in request.args:
|
||||
try:
|
||||
timeout = int(request.args["timeout"][0])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "timeout must be in milliseconds.")
|
||||
|
||||
as_client_event = "raw" not in request.args
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
if "timeout" in request.args:
|
||||
try:
|
||||
timeout = int(request.args["timeout"][0])
|
||||
except ValueError:
|
||||
raise SynapseError(400, "timeout must be in milliseconds.")
|
||||
|
||||
chunk = yield handler.get_stream(
|
||||
requester.user.to_string(),
|
||||
pagin_config,
|
||||
timeout=timeout,
|
||||
as_client_event=as_client_event,
|
||||
affect_presence=(not is_guest),
|
||||
room_id=room_id,
|
||||
is_guest=is_guest,
|
||||
)
|
||||
except:
|
||||
logger.exception("Event stream failed")
|
||||
raise
|
||||
as_client_event = "raw" not in request.args
|
||||
|
||||
chunk = yield handler.get_stream(
|
||||
requester.user.to_string(),
|
||||
pagin_config,
|
||||
timeout=timeout,
|
||||
as_client_event=as_client_event,
|
||||
affect_presence=(not is_guest),
|
||||
room_id=room_id,
|
||||
is_guest=is_guest,
|
||||
)
|
||||
|
||||
defer.returnValue((200, chunk))
|
||||
|
||||
|
|
|
@ -54,10 +54,8 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
self.jwt_secret = hs.config.jwt_secret
|
||||
self.jwt_algorithm = hs.config.jwt_algorithm
|
||||
self.cas_enabled = hs.config.cas_enabled
|
||||
self.cas_server_url = hs.config.cas_server_url
|
||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||
self.servername = hs.config.server_name
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.auth_handler = self.hs.get_auth_handler()
|
||||
self.device_handler = self.hs.get_device_handler()
|
||||
|
||||
def on_GET(self, request):
|
||||
flows = []
|
||||
|
@ -108,17 +106,6 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
LoginRestServlet.JWT_TYPE):
|
||||
result = yield self.do_jwt_login(login_submission)
|
||||
defer.returnValue(result)
|
||||
# TODO Delete this after all CAS clients switch to token login instead
|
||||
elif self.cas_enabled and (login_submission["type"] ==
|
||||
LoginRestServlet.CAS_TYPE):
|
||||
uri = "%s/proxyValidate" % (self.cas_server_url,)
|
||||
args = {
|
||||
"ticket": login_submission["ticket"],
|
||||
"service": login_submission["service"]
|
||||
}
|
||||
body = yield self.http_client.get_raw(uri, args)
|
||||
result = yield self.do_cas_login(body)
|
||||
defer.returnValue(result)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
result = yield self.do_token_login(login_submission)
|
||||
defer.returnValue(result)
|
||||
|
@ -143,16 +130,24 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
user_id, self.hs.hostname
|
||||
).to_string()
|
||||
|
||||
auth_handler = self.handlers.auth_handler
|
||||
user_id, access_token, refresh_token = yield auth_handler.login_with_password(
|
||||
auth_handler = self.auth_handler
|
||||
user_id = yield auth_handler.validate_password_login(
|
||||
user_id=user_id,
|
||||
password=login_submission["password"])
|
||||
|
||||
password=login_submission["password"],
|
||||
)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -160,65 +155,27 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def do_token_login(self, login_submission):
|
||||
token = login_submission['token']
|
||||
auth_handler = self.handlers.auth_handler
|
||||
auth_handler = self.auth_handler
|
||||
user_id = (
|
||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
)
|
||||
user_id, access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
# TODO Delete this after all CAS clients switch to token login instead
|
||||
@defer.inlineCallbacks
|
||||
def do_cas_login(self, cas_response_body):
|
||||
user, attributes = self.parse_cas_response(cas_response_body)
|
||||
|
||||
for required_attribute, required_value in self.cas_required_attributes.items():
|
||||
# If required attribute was not in CAS Response - Forbidden
|
||||
if required_attribute not in attributes:
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
# Also need to check value
|
||||
if required_value is not None:
|
||||
actual_value = attributes[required_attribute]
|
||||
# If required attribute value does not match expected - Forbidden
|
||||
if required_value != actual_value:
|
||||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.handlers.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if user_exists:
|
||||
user_id, access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
|
||||
else:
|
||||
user_id, access_token = (
|
||||
yield self.handlers.registration_handler.register(localpart=user)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_jwt_login(self, login_submission):
|
||||
token = login_submission.get("token", None)
|
||||
|
@ -243,19 +200,28 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.handlers.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if user_exists:
|
||||
user_id, access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(user_id)
|
||||
auth_handler = self.auth_handler
|
||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||
if registered_user_id:
|
||||
device_id = yield self._register_device(
|
||||
registered_user_id, login_submission
|
||||
)
|
||||
access_token, refresh_token = (
|
||||
yield auth_handler.get_login_tuple_for_user_id(
|
||||
registered_user_id, device_id,
|
||||
login_submission.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
result = {
|
||||
"user_id": user_id, # may have changed
|
||||
"user_id": registered_user_id,
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
else:
|
||||
# TODO: we should probably check that the register isn't going
|
||||
# to fonx/change our user_id before registering the device
|
||||
device_id = yield self._register_device(user_id, login_submission)
|
||||
user_id, access_token = (
|
||||
yield self.handlers.registration_handler.register(localpart=user)
|
||||
)
|
||||
|
@ -267,32 +233,25 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
# TODO Delete this after all CAS clients switch to token login instead
|
||||
def parse_cas_response(self, cas_response_body):
|
||||
root = ET.fromstring(cas_response_body)
|
||||
if not root.tag.endswith("serviceResponse"):
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
if not root[0].tag.endswith("authenticationSuccess"):
|
||||
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
if child.tag.endswith("attributes"):
|
||||
attributes = {}
|
||||
for attribute in child:
|
||||
# ElementTree library expands the namespace in attribute tags
|
||||
# to the full URL of the namespace.
|
||||
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
|
||||
# We don't care about namespace here and it will always be encased in
|
||||
# curly braces, so we remove them.
|
||||
if "}" in attribute.tag:
|
||||
attributes[attribute.tag.split("}")[1]] = attribute.text
|
||||
else:
|
||||
attributes[attribute.tag] = attribute.text
|
||||
if user is None or attributes is None:
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
def _register_device(self, user_id, login_submission):
|
||||
"""Register a device for a user.
|
||||
|
||||
return (user, attributes)
|
||||
This is called after the user's credentials have been validated, but
|
||||
before the access token has been issued.
|
||||
|
||||
Args:
|
||||
(str) user_id: full canonical @user:id
|
||||
(object) login_submission: dictionary supplied to /login call, from
|
||||
which we pull device_id and initial_device_name
|
||||
Returns:
|
||||
defer.Deferred: (str) device_id
|
||||
"""
|
||||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get(
|
||||
"initial_device_display_name")
|
||||
return self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
|
||||
class SAML2RestServlet(ClientV1RestServlet):
|
||||
|
@ -338,18 +297,6 @@ class SAML2RestServlet(ClientV1RestServlet):
|
|||
defer.returnValue((200, {"status": "not_authenticated"}))
|
||||
|
||||
|
||||
# TODO Delete this after all CAS clients switch to token login instead
|
||||
class CasRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/login/cas", releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
super(CasRestServlet, self).__init__(hs)
|
||||
self.cas_server_url = hs.config.cas_server_url
|
||||
|
||||
def on_GET(self, request):
|
||||
return (200, {"serverUrl": self.cas_server_url})
|
||||
|
||||
|
||||
class CasRedirectServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
|
||||
|
||||
|
@ -381,6 +328,7 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||
self.cas_server_url = hs.config.cas_server_url
|
||||
self.cas_service_url = hs.config.cas_service_url
|
||||
self.cas_required_attributes = hs.config.cas_required_attributes
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
|
@ -412,14 +360,14 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
user_id = UserID.create(user, self.hs.hostname).to_string()
|
||||
auth_handler = self.handlers.auth_handler
|
||||
user_exists = yield auth_handler.does_user_exist(user_id)
|
||||
if not user_exists:
|
||||
user_id, _ = (
|
||||
auth_handler = self.auth_handler
|
||||
registered_user_id = yield auth_handler.check_user_exists(user_id)
|
||||
if not registered_user_id:
|
||||
registered_user_id, _ = (
|
||||
yield self.handlers.registration_handler.register(localpart=user)
|
||||
)
|
||||
|
||||
login_token = auth_handler.generate_short_term_login_token(user_id)
|
||||
login_token = auth_handler.generate_short_term_login_token(registered_user_id)
|
||||
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
|
||||
login_token)
|
||||
request.redirect(redirect_url)
|
||||
|
@ -433,30 +381,39 @@ class CasTicketServlet(ClientV1RestServlet):
|
|||
return urlparse.urlunparse(url_parts)
|
||||
|
||||
def parse_cas_response(self, cas_response_body):
|
||||
root = ET.fromstring(cas_response_body)
|
||||
if not root.tag.endswith("serviceResponse"):
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
if not root[0].tag.endswith("authenticationSuccess"):
|
||||
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
if child.tag.endswith("attributes"):
|
||||
attributes = {}
|
||||
for attribute in child:
|
||||
# ElementTree library expands the namespace in attribute tags
|
||||
# to the full URL of the namespace.
|
||||
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
|
||||
# We don't care about namespace here and it will always be encased in
|
||||
# curly braces, so we remove them.
|
||||
if "}" in attribute.tag:
|
||||
attributes[attribute.tag.split("}")[1]] = attribute.text
|
||||
else:
|
||||
attributes[attribute.tag] = attribute.text
|
||||
if user is None or attributes is None:
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
return (user, attributes)
|
||||
user = None
|
||||
attributes = None
|
||||
try:
|
||||
root = ET.fromstring(cas_response_body)
|
||||
if not root.tag.endswith("serviceResponse"):
|
||||
raise Exception("root of CAS response is not serviceResponse")
|
||||
success = (root[0].tag.endswith("authenticationSuccess"))
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
if child.tag.endswith("attributes"):
|
||||
attributes = {}
|
||||
for attribute in child:
|
||||
# ElementTree library expands the namespace in
|
||||
# attribute tags to the full URL of the namespace.
|
||||
# We don't care about namespace here and it will always
|
||||
# be encased in curly braces, so we remove them.
|
||||
tag = attribute.tag
|
||||
if "}" in tag:
|
||||
tag = tag.split("}")[1]
|
||||
attributes[tag] = attribute.text
|
||||
if user is None:
|
||||
raise Exception("CAS response does not contain user")
|
||||
if attributes is None:
|
||||
raise Exception("CAS response does not contain attributes")
|
||||
except Exception:
|
||||
logger.error("Error parsing CAS response", exc_info=1)
|
||||
raise LoginError(401, "Invalid CAS response",
|
||||
errcode=Codes.UNAUTHORIZED)
|
||||
if not success:
|
||||
raise LoginError(401, "Unsuccessful CAS response",
|
||||
errcode=Codes.UNAUTHORIZED)
|
||||
return user, attributes
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
|
@ -466,5 +423,3 @@ def register_servlets(hs, http_server):
|
|||
if hs.config.cas_enabled:
|
||||
CasRedirectServlet(hs).register(http_server)
|
||||
CasTicketServlet(hs).register(http_server)
|
||||
CasRestServlet(hs).register(http_server)
|
||||
# TODO PasswordResetRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
# we build up the full structure and then decide which bits of it
|
||||
# to send which means doing unnecessary work sometimes but is
|
||||
# is probably not going to make a whole lot of difference
|
||||
rawrules = yield self.store.get_push_rules_for_user(user_id)
|
||||
rules = yield self.store.get_push_rules_for_user(user_id)
|
||||
|
||||
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
|
||||
|
||||
rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
|
||||
rules = format_push_rules_for_user(requester.user, rules)
|
||||
|
||||
path = request.postpath[1:]
|
||||
|
||||
|
|
|
@ -17,7 +17,11 @@ from twisted.internet import defer
|
|||
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.push import PusherConfigException
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.http.servlet import (
|
||||
parse_json_object_from_request, parse_string, RestServlet
|
||||
)
|
||||
from synapse.http.server import finish_request
|
||||
from synapse.api.errors import StoreError
|
||||
|
||||
from .base import ClientV1RestServlet, client_path_patterns
|
||||
|
||||
|
@ -136,6 +140,57 @@ class PushersSetRestServlet(ClientV1RestServlet):
|
|||
return 200, {}
|
||||
|
||||
|
||||
class PushersRemoveRestServlet(RestServlet):
|
||||
"""
|
||||
To allow pusher to be delete by clicking a link (ie. GET request)
|
||||
"""
|
||||
PATTERNS = client_path_patterns("/pushers/remove$")
|
||||
SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.notifier = hs.get_notifier()
|
||||
self.auth = hs.get_v1auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
|
||||
user = requester.user
|
||||
|
||||
app_id = parse_string(request, "app_id", required=True)
|
||||
pushkey = parse_string(request, "pushkey", required=True)
|
||||
|
||||
pusher_pool = self.hs.get_pusherpool()
|
||||
|
||||
try:
|
||||
yield pusher_pool.remove_pusher(
|
||||
app_id=app_id,
|
||||
pushkey=pushkey,
|
||||
user_id=user.to_string(),
|
||||
)
|
||||
except StoreError as se:
|
||||
if se.code != 404:
|
||||
# This is fine: they're already unsubscribed
|
||||
raise
|
||||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Server", self.hs.version_string)
|
||||
request.setHeader(b"Content-Length", b"%d" % (
|
||||
len(PushersRemoveRestServlet.SUCCESS_HTML),
|
||||
))
|
||||
request.write(PushersRemoveRestServlet.SUCCESS_HTML)
|
||||
finish_request(request)
|
||||
defer.returnValue(None)
|
||||
|
||||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
PushersRestServlet(hs).register(http_server)
|
||||
PushersSetRestServlet(hs).register(http_server)
|
||||
PushersRemoveRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(RegisterRestServlet, self).__init__(hs)
|
||||
# sessions are stored as:
|
||||
# self.sessions = {
|
||||
|
@ -60,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
# TODO: persistent storage
|
||||
self.sessions = {}
|
||||
self.enable_registration = hs.config.enable_registration
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
def on_GET(self, request):
|
||||
if self.hs.config.enable_registration_captcha:
|
||||
|
@ -299,9 +304,10 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
user_localpart = register_json["user"].encode("utf-8")
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
(user_id, token) = yield handler.appservice_register(
|
||||
user_id = yield handler.appservice_register(
|
||||
user_localpart, as_token
|
||||
)
|
||||
token = yield self.auth_handler.issue_access_token(user_id)
|
||||
self._remove_session(session)
|
||||
defer.returnValue({
|
||||
"user_id": user_id,
|
||||
|
@ -324,6 +330,14 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
|
||||
user = register_json["user"].encode("utf-8")
|
||||
password = register_json["password"].encode("utf-8")
|
||||
admin = register_json.get("admin", None)
|
||||
|
||||
# Its important to check as we use null bytes as HMAC field separators
|
||||
if "\x00" in user:
|
||||
raise SynapseError(400, "Invalid user")
|
||||
if "\x00" in password:
|
||||
raise SynapseError(400, "Invalid password")
|
||||
|
||||
# str() because otherwise hmac complains that 'unicode' does not
|
||||
# have the buffer interface
|
||||
|
@ -331,17 +345,21 @@ class RegisterRestServlet(ClientV1RestServlet):
|
|||
|
||||
want_mac = hmac.new(
|
||||
key=self.hs.config.registration_shared_secret,
|
||||
msg=user,
|
||||
digestmod=sha1,
|
||||
).hexdigest()
|
||||
|
||||
password = register_json["password"].encode("utf-8")
|
||||
)
|
||||
want_mac.update(user)
|
||||
want_mac.update("\x00")
|
||||
want_mac.update(password)
|
||||
want_mac.update("\x00")
|
||||
want_mac.update("admin" if admin else "notadmin")
|
||||
want_mac = want_mac.hexdigest()
|
||||
|
||||
if compare_digest(want_mac, got_mac):
|
||||
handler = self.handlers.registration_handler
|
||||
user_id, token = yield handler.register(
|
||||
localpart=user,
|
||||
password=password,
|
||||
admin=bool(admin),
|
||||
)
|
||||
self._remove_session(session)
|
||||
defer.returnValue({
|
||||
|
@ -410,12 +428,15 @@ class CreateUserRestServlet(ClientV1RestServlet):
|
|||
raise SynapseError(400, "Failed to parse 'duration_seconds'")
|
||||
if duration_seconds > self.direct_user_creation_max_duration:
|
||||
duration_seconds = self.direct_user_creation_max_duration
|
||||
password_hash = user_json["password_hash"].encode("utf-8") \
|
||||
if user_json.get("password_hash") else None
|
||||
|
||||
handler = self.handlers.registration_handler
|
||||
user_id, token = yield handler.get_or_create_user(
|
||||
localpart=localpart,
|
||||
displayname=displayname,
|
||||
duration_seconds=duration_seconds
|
||||
duration_in_ms=(duration_seconds * 1000),
|
||||
password_hash=password_hash
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
|
|
|
@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns
|
|||
from synapse.api.errors import SynapseError, Codes, AuthError
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.types import UserID, RoomID, RoomAlias
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
import ujson as json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -72,8 +74,6 @@ class RoomCreateRestServlet(ClientV1RestServlet):
|
|||
|
||||
def get_room_config(self, request):
|
||||
user_supplied_config = parse_json_object_from_request(request)
|
||||
# default visibility
|
||||
user_supplied_config.setdefault("visibility", "public")
|
||||
return user_supplied_config
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
|
@ -279,8 +279,16 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
handler = self.handlers.room_list_handler
|
||||
data = yield handler.get_public_room_list()
|
||||
try:
|
||||
yield self.auth.get_user_by_req(request)
|
||||
except AuthError:
|
||||
# This endpoint isn't authed, but its useful to know who's hitting
|
||||
# it if they *do* supply an access token
|
||||
pass
|
||||
|
||||
handler = self.hs.get_room_list_handler()
|
||||
data = yield handler.get_aggregated_public_room_list()
|
||||
|
||||
defer.returnValue((200, data))
|
||||
|
||||
|
||||
|
@ -321,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
|||
request, default_limit=10,
|
||||
)
|
||||
as_client_event = "raw" not in request.args
|
||||
filter_bytes = request.args.get("filter", None)
|
||||
if filter_bytes:
|
||||
filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
|
||||
event_filter = Filter(json.loads(filter_json))
|
||||
else:
|
||||
event_filter = None
|
||||
handler = self.handlers.message_handler
|
||||
msgs = yield handler.get_messages(
|
||||
room_id=room_id,
|
||||
requester=requester,
|
||||
pagin_config=pagination_config,
|
||||
as_client_event=as_client_event
|
||||
as_client_event=as_client_event,
|
||||
event_filter=event_filter,
|
||||
)
|
||||
|
||||
defer.returnValue((200, msgs))
|
||||
|
|
|
@ -25,7 +25,9 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def client_v2_patterns(path_regex, releases=(0,)):
|
||||
def client_v2_patterns(path_regex, releases=(0,),
|
||||
v2_alpha=True,
|
||||
unstable=True):
|
||||
"""Creates a regex compiled client path with the correct client path
|
||||
prefix.
|
||||
|
||||
|
@ -35,9 +37,12 @@ def client_v2_patterns(path_regex, releases=(0,)):
|
|||
Returns:
|
||||
SRE_Pattern
|
||||
"""
|
||||
patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)]
|
||||
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
|
||||
patterns.append(re.compile("^" + unstable_prefix + path_regex))
|
||||
patterns = []
|
||||
if v2_alpha:
|
||||
patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
|
||||
if unstable:
|
||||
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
|
||||
patterns.append(re.compile("^" + unstable_prefix + path_regex))
|
||||
for release in releases:
|
||||
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
|
||||
patterns.append(re.compile("^" + new_prefix + path_regex))
|
||||
|
|
|
@ -28,14 +28,46 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PasswordRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PasswordRequestTokenRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if absent:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
|
||||
if existingUid is None:
|
||||
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
|
||||
|
||||
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class PasswordRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/password")
|
||||
PATTERNS = client_v2_patterns("/account/password$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PasswordRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
@ -89,15 +121,90 @@ class PasswordRestServlet(RestServlet):
|
|||
return 200, {}
|
||||
|
||||
|
||||
class DeactivateAccountRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/deactivate$")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
super(DeactivateAccountRestServlet, self).__init__()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
||||
[LoginType.PASSWORD],
|
||||
], body, self.hs.get_ip_from_request(request))
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
|
||||
user_id = None
|
||||
requester = None
|
||||
|
||||
if LoginType.PASSWORD in result:
|
||||
# if using password, they should also be logged in
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
if user_id != result[LoginType.PASSWORD]:
|
||||
raise LoginError(400, "", Codes.UNKNOWN)
|
||||
else:
|
||||
logger.error("Auth succeeded but no known type!", result.keys())
|
||||
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||
|
||||
# FIXME: Theoretically there is a race here wherein user resets password
|
||||
# using threepid.
|
||||
yield self.store.user_delete_access_tokens(user_id)
|
||||
yield self.store.user_delete_threepids(user_id)
|
||||
yield self.store.user_set_password_hash(user_id, None)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class ThreepidRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
super(ThreepidRequestTokenRestServlet, self).__init__()
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if absent:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
|
||||
if existingUid is not None:
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
|
||||
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class ThreepidRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/3pid")
|
||||
PATTERNS = client_v2_patterns("/account/3pid$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThreepidRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
|
@ -157,5 +264,8 @@ class ThreepidRestServlet(RestServlet):
|
|||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
PasswordRequestTokenRestServlet(hs).register(http_server)
|
||||
PasswordRestServlet(hs).register(http_server)
|
||||
DeactivateAccountRestServlet(hs).register(http_server)
|
||||
ThreepidRequestTokenRestServlet(hs).register(http_server)
|
||||
ThreepidRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -104,7 +104,7 @@ class AuthRestServlet(RestServlet):
|
|||
super(AuthRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
|
100
synapse/rest/client/v2_alpha/devices.py
Normal file
100
synapse/rest/client/v2_alpha/devices.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http import servlet
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DevicesRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(DevicesRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
devices = yield self.device_handler.get_devices_by_user(
|
||||
requester.user.to_string()
|
||||
)
|
||||
defer.returnValue((200, {"devices": devices}))
|
||||
|
||||
|
||||
class DeviceRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$",
|
||||
releases=[], v2_alpha=False)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(DeviceRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
device = yield self.device_handler.get_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
)
|
||||
defer.returnValue((200, device))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, device_id):
|
||||
# XXX: it's not completely obvious we want to expose this endpoint.
|
||||
# It allows the client to delete access tokens, which feels like a
|
||||
# thing which merits extra auth. But if we want to do the interactive-
|
||||
# auth dance, we should really make it possible to delete more than one
|
||||
# device at a time.
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
yield self.device_handler.delete_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
body = servlet.parse_json_object_from_request(request)
|
||||
yield self.device_handler.update_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
body
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
DevicesRestServlet(hs).register(http_server)
|
||||
DeviceRestServlet(hs).register(http_server)
|
|
@ -13,24 +13,25 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import simplejson as json
|
||||
from canonicaljson import encode_canonical_json
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.api.errors
|
||||
import synapse.server
|
||||
import synapse.types
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.types import UserID
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
import logging
|
||||
import simplejson as json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class KeyUploadServlet(RestServlet):
|
||||
"""
|
||||
POST /keys/upload/<device_id> HTTP/1.1
|
||||
POST /keys/upload HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
|
@ -53,23 +54,45 @@ class KeyUploadServlet(RestServlet):
|
|||
},
|
||||
}
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
|
||||
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$",
|
||||
releases=())
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(KeyUploadServlet, self).__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
user_id = requester.user.to_string()
|
||||
# TODO: Check that the device_id matches that in the authentication
|
||||
# or derive the device_id from the authentication instead.
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
if device_id is not None:
|
||||
# passing the device_id here is deprecated; however, we allow it
|
||||
# for now for compatibility with older clients.
|
||||
if (requester.device_id is not None and
|
||||
device_id != requester.device_id):
|
||||
logger.warning("Client uploading keys for a different device "
|
||||
"(logged in as %s, uploading for %s)",
|
||||
requester.device_id, device_id)
|
||||
else:
|
||||
device_id = requester.device_id
|
||||
|
||||
if device_id is None:
|
||||
raise synapse.api.errors.SynapseError(
|
||||
400,
|
||||
"To upload keys, you must pass device_id when authenticating"
|
||||
)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
# TODO: Validate the JSON to make sure it has the right keys.
|
||||
|
@ -102,13 +125,12 @@ class KeyUploadServlet(RestServlet):
|
|||
user_id, device_id, time_now, key_list
|
||||
)
|
||||
|
||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
# the device should have been registered already, but it may have been
|
||||
# deleted due to a race with a DELETE request. Or we may be using an
|
||||
# old access_token without an associated device_id. Either way, we
|
||||
# need to double-check the device is registered to avoid ending up with
|
||||
# keys without a corresponding device.
|
||||
self.device_handler.check_device_registered(user_id, device_id)
|
||||
|
||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||
|
@ -162,17 +184,19 @@ class KeyQueryServlet(RestServlet):
|
|||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
super(KeyQueryServlet, self).__init__()
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.is_mine = hs.is_mine
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id, device_id):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
body = parse_json_object_from_request(request)
|
||||
result = yield self.handle_request(body)
|
||||
result = yield self.e2e_keys_handler.query_devices(body)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -181,45 +205,11 @@ class KeyQueryServlet(RestServlet):
|
|||
auth_user_id = requester.user.to_string()
|
||||
user_id = user_id if user_id else auth_user_id
|
||||
device_ids = [device_id] if device_id else []
|
||||
result = yield self.handle_request(
|
||||
result = yield self.e2e_keys_handler.query_devices(
|
||||
{"device_keys": {user_id: device_ids}}
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_request(self, body):
|
||||
local_query = []
|
||||
remote_queries = {}
|
||||
for user_id, device_ids in body.get("device_keys", {}).items():
|
||||
user = UserID.from_string(user_id)
|
||||
if self.is_mine(user):
|
||||
if not device_ids:
|
||||
local_query.append((user_id, None))
|
||||
else:
|
||||
for device_id in device_ids:
|
||||
local_query.append((user_id, device_id))
|
||||
else:
|
||||
remote_queries.setdefault(user.domain, {})[user_id] = list(
|
||||
device_ids
|
||||
)
|
||||
results = yield self.store.get_e2e_device_keys(local_query)
|
||||
|
||||
json_result = {}
|
||||
for user_id, device_keys in results.items():
|
||||
for device_id, json_bytes in device_keys.items():
|
||||
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
||||
json_bytes
|
||||
)
|
||||
|
||||
for destination, device_keys in remote_queries.items():
|
||||
remote_result = yield self.federation.query_client_keys(
|
||||
destination, {"device_keys": device_keys}
|
||||
)
|
||||
for user_id, keys in remote_result["device_keys"].items():
|
||||
if user_id in device_keys:
|
||||
json_result[user_id] = keys
|
||||
defer.returnValue((200, {"device_keys": json_result}))
|
||||
|
||||
|
||||
class OneTimeKeyServlet(RestServlet):
|
||||
"""
|
||||
|
|
|
@ -41,17 +41,59 @@ else:
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegisterRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/register")
|
||||
class RegisterRequestTokenRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/register/email/requestToken$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(RegisterRequestTokenRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
|
||||
if len(absent) > 0:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
|
||||
if existingUid is not None:
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
|
||||
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class RegisterRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/register$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(RegisterRestServlet, self).__init__()
|
||||
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
@ -70,10 +112,6 @@ class RegisterRestServlet(RestServlet):
|
|||
"Do not understand membership kind: %s" % (kind,)
|
||||
)
|
||||
|
||||
if '/register/email/requestToken' in request.path:
|
||||
ret = yield self.onEmailTokenRequest(request)
|
||||
defer.returnValue(ret)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
# we do basic sanity checks here because the auth layer will store these
|
||||
|
@ -104,11 +142,12 @@ class RegisterRestServlet(RestServlet):
|
|||
# Set the desired user according to the AS API (which uses the
|
||||
# 'user' key not 'username'). Since this is a new addition, we'll
|
||||
# fallback to 'username' if they gave one.
|
||||
if isinstance(body.get("user"), basestring):
|
||||
desired_username = body["user"]
|
||||
result = yield self._do_appservice_registration(
|
||||
desired_username, request.args["access_token"][0]
|
||||
)
|
||||
desired_username = body.get("user", desired_username)
|
||||
|
||||
if isinstance(desired_username, basestring):
|
||||
result = yield self._do_appservice_registration(
|
||||
desired_username, request.args["access_token"][0], body
|
||||
)
|
||||
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||
return
|
||||
|
||||
|
@ -117,7 +156,7 @@ class RegisterRestServlet(RestServlet):
|
|||
# FIXME: Should we really be determining if this is shared secret
|
||||
# auth based purely on the 'mac' key?
|
||||
result = yield self._do_shared_secret_registration(
|
||||
desired_username, desired_password, body["mac"]
|
||||
desired_username, desired_password, body
|
||||
)
|
||||
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||
return
|
||||
|
@ -157,12 +196,12 @@ class RegisterRestServlet(RestServlet):
|
|||
[LoginType.EMAIL_IDENTITY]
|
||||
]
|
||||
|
||||
authed, result, params, session_id = yield self.auth_handler.check_auth(
|
||||
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||
flows, body, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
defer.returnValue((401, auth_result))
|
||||
return
|
||||
|
||||
if registered_user_id is not None:
|
||||
|
@ -170,106 +209,58 @@ class RegisterRestServlet(RestServlet):
|
|||
"Already registered user ID %r for this session",
|
||||
registered_user_id
|
||||
)
|
||||
access_token = yield self.auth_handler.issue_access_token(registered_user_id)
|
||||
refresh_token = yield self.auth_handler.issue_refresh_token(
|
||||
registered_user_id
|
||||
# don't re-register the email address
|
||||
add_email = False
|
||||
else:
|
||||
# NB: This may be from the auth handler and NOT from the POST
|
||||
if 'password' not in params:
|
||||
raise SynapseError(400, "Missing password.",
|
||||
Codes.MISSING_PARAM)
|
||||
|
||||
desired_username = params.get("username", None)
|
||||
new_password = params.get("password", None)
|
||||
guest_access_token = params.get("guest_access_token", None)
|
||||
|
||||
(registered_user_id, _) = yield self.registration_handler.register(
|
||||
localpart=desired_username,
|
||||
password=new_password,
|
||||
guest_access_token=guest_access_token,
|
||||
generate_token=False,
|
||||
)
|
||||
defer.returnValue((200, {
|
||||
"user_id": registered_user_id,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"refresh_token": refresh_token,
|
||||
}))
|
||||
|
||||
# NB: This may be from the auth handler and NOT from the POST
|
||||
if 'password' not in params:
|
||||
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
|
||||
# remember that we've now registered that user account, and with
|
||||
# what user ID (since the user may not have specified)
|
||||
self.auth_handler.set_session_data(
|
||||
session_id, "registered_user_id", registered_user_id
|
||||
)
|
||||
|
||||
desired_username = params.get("username", None)
|
||||
new_password = params.get("password", None)
|
||||
guest_access_token = params.get("guest_access_token", None)
|
||||
add_email = True
|
||||
|
||||
(user_id, token) = yield self.registration_handler.register(
|
||||
localpart=desired_username,
|
||||
password=new_password,
|
||||
guest_access_token=guest_access_token,
|
||||
return_dict = yield self._create_registration_details(
|
||||
registered_user_id, params
|
||||
)
|
||||
|
||||
# remember that we've now registered that user account, and with what
|
||||
# user ID (since the user may not have specified)
|
||||
self.auth_handler.set_session_data(
|
||||
session_id, "registered_user_id", user_id
|
||||
)
|
||||
if add_email and auth_result and LoginType.EMAIL_IDENTITY in auth_result:
|
||||
threepid = auth_result[LoginType.EMAIL_IDENTITY]
|
||||
yield self._register_email_threepid(
|
||||
registered_user_id, threepid, return_dict["access_token"],
|
||||
params.get("bind_email")
|
||||
)
|
||||
|
||||
if result and LoginType.EMAIL_IDENTITY in result:
|
||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||
|
||||
for reqd in ['medium', 'address', 'validated_at']:
|
||||
if reqd not in threepid:
|
||||
logger.info("Can't add incomplete 3pid")
|
||||
else:
|
||||
yield self.auth_handler.add_threepid(
|
||||
user_id,
|
||||
threepid['medium'],
|
||||
threepid['address'],
|
||||
threepid['validated_at'],
|
||||
)
|
||||
|
||||
# And we add an email pusher for them by default, but only
|
||||
# if email notifications are enabled (so people don't start
|
||||
# getting mail spam where they weren't before if email
|
||||
# notifs are set up on a home server)
|
||||
if (
|
||||
self.hs.config.email_enable_notifs and
|
||||
self.hs.config.email_notif_for_new_users
|
||||
):
|
||||
# Pull the ID of the access token back out of the db
|
||||
# It would really make more sense for this to be passed
|
||||
# up when the access token is saved, but that's quite an
|
||||
# invasive change I'd rather do separately.
|
||||
user_tuple = yield self.store.get_user_by_access_token(
|
||||
token
|
||||
)
|
||||
|
||||
yield self.hs.get_pusherpool().add_pusher(
|
||||
user_id=user_id,
|
||||
access_token=user_tuple["token_id"],
|
||||
kind="email",
|
||||
app_id="m.email",
|
||||
app_display_name="Email Notifications",
|
||||
device_display_name=threepid["address"],
|
||||
pushkey=threepid["address"],
|
||||
lang=None, # We don't know a user's language here
|
||||
data={},
|
||||
)
|
||||
|
||||
if 'bind_email' in params and params['bind_email']:
|
||||
logger.info("bind_email specified: binding")
|
||||
|
||||
emailThreepid = result[LoginType.EMAIL_IDENTITY]
|
||||
threepid_creds = emailThreepid['threepid_creds']
|
||||
logger.debug("Binding emails %s to %s" % (
|
||||
emailThreepid, user_id
|
||||
))
|
||||
yield self.identity_handler.bind_threepid(threepid_creds, user_id)
|
||||
else:
|
||||
logger.info("bind_email not specified: not binding email")
|
||||
|
||||
result = yield self._create_registration_details(user_id, token)
|
||||
defer.returnValue((200, result))
|
||||
defer.returnValue((200, return_dict))
|
||||
|
||||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_appservice_registration(self, username, as_token):
|
||||
(user_id, token) = yield self.registration_handler.appservice_register(
|
||||
def _do_appservice_registration(self, username, as_token, body):
|
||||
user_id = yield self.registration_handler.appservice_register(
|
||||
username, as_token
|
||||
)
|
||||
defer.returnValue((yield self._create_registration_details(user_id, token)))
|
||||
defer.returnValue((yield self._create_registration_details(user_id, body)))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_shared_secret_registration(self, username, password, mac):
|
||||
def _do_shared_secret_registration(self, username, password, body):
|
||||
if not self.hs.config.registration_shared_secret:
|
||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
|
||||
|
@ -277,7 +268,7 @@ class RegisterRestServlet(RestServlet):
|
|||
|
||||
# str() because otherwise hmac complains that 'unicode' does not
|
||||
# have the buffer interface
|
||||
got_mac = str(mac)
|
||||
got_mac = str(body["mac"])
|
||||
|
||||
want_mac = hmac.new(
|
||||
key=self.hs.config.registration_shared_secret,
|
||||
|
@ -290,43 +281,132 @@ class RegisterRestServlet(RestServlet):
|
|||
403, "HMAC incorrect",
|
||||
)
|
||||
|
||||
(user_id, token) = yield self.registration_handler.register(
|
||||
localpart=username, password=password
|
||||
(user_id, _) = yield self.registration_handler.register(
|
||||
localpart=username, password=password, generate_token=False,
|
||||
)
|
||||
defer.returnValue((yield self._create_registration_details(user_id, token)))
|
||||
|
||||
result = yield self._create_registration_details(user_id, body)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_registration_details(self, user_id, token):
|
||||
refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
|
||||
def _register_email_threepid(self, user_id, threepid, token, bind_email):
|
||||
"""Add an email address as a 3pid identifier
|
||||
|
||||
Also adds an email pusher for the email address, if configured in the
|
||||
HS config
|
||||
|
||||
Also optionally binds emails to the given user_id on the identity server
|
||||
|
||||
Args:
|
||||
user_id (str): id of user
|
||||
threepid (object): m.login.email.identity auth response
|
||||
token (str): access_token for the user
|
||||
bind_email (bool): true if the client requested the email to be
|
||||
bound at the identity server
|
||||
Returns:
|
||||
defer.Deferred:
|
||||
"""
|
||||
reqd = ('medium', 'address', 'validated_at')
|
||||
if any(x not in threepid for x in reqd):
|
||||
logger.info("Can't add incomplete 3pid")
|
||||
defer.returnValue()
|
||||
|
||||
yield self.auth_handler.add_threepid(
|
||||
user_id,
|
||||
threepid['medium'],
|
||||
threepid['address'],
|
||||
threepid['validated_at'],
|
||||
)
|
||||
|
||||
# And we add an email pusher for them by default, but only
|
||||
# if email notifications are enabled (so people don't start
|
||||
# getting mail spam where they weren't before if email
|
||||
# notifs are set up on a home server)
|
||||
if (self.hs.config.email_enable_notifs and
|
||||
self.hs.config.email_notif_for_new_users):
|
||||
# Pull the ID of the access token back out of the db
|
||||
# It would really make more sense for this to be passed
|
||||
# up when the access token is saved, but that's quite an
|
||||
# invasive change I'd rather do separately.
|
||||
user_tuple = yield self.store.get_user_by_access_token(
|
||||
token
|
||||
)
|
||||
token_id = user_tuple["token_id"]
|
||||
|
||||
yield self.hs.get_pusherpool().add_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="email",
|
||||
app_id="m.email",
|
||||
app_display_name="Email Notifications",
|
||||
device_display_name=threepid["address"],
|
||||
pushkey=threepid["address"],
|
||||
lang=None, # We don't know a user's language here
|
||||
data={},
|
||||
)
|
||||
|
||||
if bind_email:
|
||||
logger.info("bind_email specified: binding")
|
||||
logger.debug("Binding emails %s to %s" % (
|
||||
threepid, user_id
|
||||
))
|
||||
yield self.identity_handler.bind_threepid(
|
||||
threepid['threepid_creds'], user_id
|
||||
)
|
||||
else:
|
||||
logger.info("bind_email not specified: not binding email")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_registration_details(self, user_id, params):
|
||||
"""Complete registration of newly-registered user
|
||||
|
||||
Allocates device_id if one was not given; also creates access_token
|
||||
and refresh_token.
|
||||
|
||||
Args:
|
||||
(str) user_id: full canonical @user:id
|
||||
(object) params: registration parameters, from which we pull
|
||||
device_id and initial_device_name
|
||||
Returns:
|
||||
defer.Deferred: (object) dictionary for response from /register
|
||||
"""
|
||||
device_id = yield self._register_device(user_id, params)
|
||||
|
||||
access_token, refresh_token = (
|
||||
yield self.auth_handler.get_login_tuple_for_user_id(
|
||||
user_id, device_id=device_id,
|
||||
initial_display_name=params.get("initial_device_display_name")
|
||||
)
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
"refresh_token": refresh_token,
|
||||
"device_id": device_id,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def onEmailTokenRequest(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
def _register_device(self, user_id, params):
|
||||
"""Register a device for a user.
|
||||
|
||||
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||
absent = []
|
||||
for k in required:
|
||||
if k not in body:
|
||||
absent.append(k)
|
||||
This is called after the user's credentials have been validated, but
|
||||
before the access token has been issued.
|
||||
|
||||
if len(absent) > 0:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
Args:
|
||||
(str) user_id: full canonical @user:id
|
||||
(object) params: registration parameters, from which we pull
|
||||
device_id and initial_device_name
|
||||
Returns:
|
||||
defer.Deferred: (str) device_id
|
||||
"""
|
||||
# register the user's device
|
||||
device_id = params.get("device_id")
|
||||
initial_display_name = params.get("initial_device_display_name")
|
||||
device_id = self.device_handler.check_device_registered(
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
if existingUid is not None:
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
|
||||
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||
defer.returnValue((200, ret))
|
||||
return device_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_guest_registration(self):
|
||||
|
@ -336,7 +416,11 @@ class RegisterRestServlet(RestServlet):
|
|||
generate_token=False,
|
||||
make_guest=True
|
||||
)
|
||||
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
|
||||
access_token = self.auth_handler.generate_access_token(
|
||||
user_id, ["guest = true"]
|
||||
)
|
||||
# XXX the "guest" caveat is not copied by /tokenrefresh. That's ok
|
||||
# so long as we don't return a refresh_token here.
|
||||
defer.returnValue((200, {
|
||||
"user_id": user_id,
|
||||
"access_token": access_token,
|
||||
|
@ -345,4 +429,5 @@ class RegisterRestServlet(RestServlet):
|
|||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RegisterRequestTokenRestServlet(hs).register(http_server)
|
||||
RegisterRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -38,10 +38,14 @@ class TokenRefreshRestServlet(RestServlet):
|
|||
body = parse_json_object_from_request(request)
|
||||
try:
|
||||
old_refresh_token = body["refresh_token"]
|
||||
auth_handler = self.hs.get_handlers().auth_handler
|
||||
(user_id, new_refresh_token) = yield self.store.exchange_refresh_token(
|
||||
old_refresh_token, auth_handler.generate_refresh_token)
|
||||
new_access_token = yield auth_handler.issue_access_token(user_id)
|
||||
auth_handler = self.hs.get_auth_handler()
|
||||
refresh_result = yield self.store.exchange_refresh_token(
|
||||
old_refresh_token, auth_handler.generate_refresh_token
|
||||
)
|
||||
(user_id, new_refresh_token, device_id) = refresh_result
|
||||
new_access_token = yield auth_handler.issue_access_token(
|
||||
user_id, device_id
|
||||
)
|
||||
defer.returnValue((200, {
|
||||
"access_token": new_access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
|
|
|
@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet):
|
|||
|
||||
def on_GET(self, request):
|
||||
return (200, {
|
||||
"versions": ["r0.0.1"]
|
||||
"versions": [
|
||||
"r0.0.1",
|
||||
"r0.1.0",
|
||||
"r0.2.0",
|
||||
]
|
||||
})
|
||||
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
from synapse.http.server import request_handler, respond_with_json_bytes
|
||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.crypto.keyring import KeyLookupError
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
@ -210,9 +211,10 @@ class RemoteKey(Resource):
|
|||
yield self.keyring.get_server_verify_key_v2_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
except KeyLookupError as e:
|
||||
logger.info("Failed to fetch key: %s", e)
|
||||
except:
|
||||
logger.exception("Failed to get key for %r", server_name)
|
||||
pass
|
||||
yield self.query_keys(
|
||||
request, query, query_remote_on_cache_miss=False
|
||||
)
|
||||
|
|
|
@ -15,14 +15,12 @@
|
|||
|
||||
from synapse.http.server import respond_with_json_bytes, finish_request
|
||||
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
||||
cs_exception, SynapseError, CodeMessageException, Codes, cs_error
|
||||
Codes, cs_error
|
||||
)
|
||||
|
||||
from twisted.protocols.basic import FileSender
|
||||
from twisted.web import server, resource
|
||||
from twisted.internet import defer
|
||||
|
||||
import base64
|
||||
import simplejson as json
|
||||
|
@ -50,64 +48,10 @@ class ContentRepoResource(resource.Resource):
|
|||
"""
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, directory, auth, external_addr):
|
||||
def __init__(self, hs, directory):
|
||||
resource.Resource.__init__(self)
|
||||
self.hs = hs
|
||||
self.directory = directory
|
||||
self.auth = auth
|
||||
self.external_addr = external_addr.rstrip('/')
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
|
||||
if not os.path.isdir(self.directory):
|
||||
os.mkdir(self.directory)
|
||||
logger.info("ContentRepoResource : Created %s directory.",
|
||||
self.directory)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def map_request_to_name(self, request):
|
||||
# auth the user
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
# namespace all file uploads on the user
|
||||
prefix = base64.urlsafe_b64encode(
|
||||
requester.user.to_string()
|
||||
).replace('=', '')
|
||||
|
||||
# use a random string for the main portion
|
||||
main_part = random_string(24)
|
||||
|
||||
# suffix with a file extension if we can make one. This is nice to
|
||||
# provide a hint to clients on the file information. We will also reuse
|
||||
# this info to spit back the content type to the client.
|
||||
suffix = ""
|
||||
if request.requestHeaders.hasHeader("Content-Type"):
|
||||
content_type = request.requestHeaders.getRawHeaders(
|
||||
"Content-Type")[0]
|
||||
suffix = "." + base64.urlsafe_b64encode(content_type)
|
||||
if (content_type.split("/")[0].lower() in
|
||||
["image", "video", "audio"]):
|
||||
file_ext = content_type.split("/")[-1]
|
||||
# be a little paranoid and only allow a-z
|
||||
file_ext = re.sub("[^a-z]", "", file_ext)
|
||||
suffix += "." + file_ext
|
||||
|
||||
file_name = prefix + main_part + suffix
|
||||
file_path = os.path.join(self.directory, file_name)
|
||||
logger.info("User %s is uploading a file to path %s",
|
||||
request.user.user_id.to_string(),
|
||||
file_path)
|
||||
|
||||
# keep trying to make a non-clashing file, with a sensible max attempts
|
||||
attempts = 0
|
||||
while os.path.exists(file_path):
|
||||
main_part = random_string(24)
|
||||
file_name = prefix + main_part + suffix
|
||||
file_path = os.path.join(self.directory, file_name)
|
||||
attempts += 1
|
||||
if attempts > 25: # really? Really?
|
||||
raise SynapseError(500, "Unable to create file.")
|
||||
|
||||
defer.returnValue(file_path)
|
||||
|
||||
def render_GET(self, request):
|
||||
# no auth here on purpose, to allow anyone to view, even across home
|
||||
|
@ -155,58 +99,6 @@ class ContentRepoResource(resource.Resource):
|
|||
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
def render_POST(self, request):
|
||||
self._async_render(request)
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
respond_with_json_bytes(request, 200, {}, send_cors=True)
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _async_render(self, request):
|
||||
try:
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
content_length = request.getHeader("Content-Length")
|
||||
if content_length is None:
|
||||
raise SynapseError(
|
||||
msg="Request must specify a Content-Length", code=400
|
||||
)
|
||||
if int(content_length) > self.max_upload_size:
|
||||
raise SynapseError(
|
||||
msg="Upload request body is too large",
|
||||
code=413,
|
||||
)
|
||||
|
||||
fname = yield self.map_request_to_name(request)
|
||||
|
||||
# TODO I have a suspicious feeling this is just going to block
|
||||
with open(fname, "wb") as f:
|
||||
f.write(request.content.read())
|
||||
|
||||
# FIXME (erikj): These should use constants.
|
||||
file_name = os.path.basename(fname)
|
||||
# FIXME: we can't assume what the repo's public mounted path is
|
||||
# ...plus self-signed SSL won't work to remote clients anyway
|
||||
# ...and we can't assume that it's SSL anyway, as we might want to
|
||||
# serve it via the non-SSL listener...
|
||||
url = "%s/_matrix/content/%s" % (
|
||||
self.external_addr, file_name
|
||||
)
|
||||
|
||||
respond_with_json_bytes(request, 200,
|
||||
json.dumps({"content_token": url}),
|
||||
send_cors=True)
|
||||
|
||||
except CodeMessageException as e:
|
||||
logger.exception(e)
|
||||
respond_with_json_bytes(request, e.code,
|
||||
json.dumps(cs_exception(e)))
|
||||
except Exception as e:
|
||||
logger.error("Failed to store file: %s" % e)
|
||||
respond_with_json_bytes(
|
||||
request,
|
||||
500,
|
||||
json.dumps({"error": "Internal server error"}),
|
||||
send_cors=True)
|
||||
|
|
|
@ -65,3 +65,9 @@ class MediaFilePaths(object):
|
|||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
file_name
|
||||
)
|
||||
|
||||
def remote_media_thumbnail_dir(self, server_name, file_id):
|
||||
return os.path.join(
|
||||
self.base_path, "remote_thumbnail", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
)
|
||||
|
|
|
@ -26,14 +26,17 @@ from .thumbnailer import Thumbnailer
|
|||
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from twisted.internet import defer, threads
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.stringutils import is_ascii
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
|
||||
import os
|
||||
import errno
|
||||
import shutil
|
||||
|
||||
import cgi
|
||||
import logging
|
||||
|
@ -42,8 +45,11 @@ import urlparse
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
|
||||
|
||||
|
||||
class MediaRepository(object):
|
||||
def __init__(self, hs, filepaths):
|
||||
def __init__(self, hs):
|
||||
self.auth = hs.get_auth()
|
||||
self.client = MatrixFederationHttpClient(hs)
|
||||
self.clock = hs.get_clock()
|
||||
|
@ -51,11 +57,28 @@ class MediaRepository(object):
|
|||
self.store = hs.get_datastore()
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.max_image_pixels = hs.config.max_image_pixels
|
||||
self.filepaths = filepaths
|
||||
self.downloads = {}
|
||||
self.filepaths = MediaFilePaths(hs.config.media_store_path)
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||
|
||||
self.remote_media_linearizer = Linearizer()
|
||||
|
||||
self.recently_accessed_remotes = set()
|
||||
|
||||
self.clock.looping_call(
|
||||
self._update_recently_accessed_remotes,
|
||||
UPDATE_RECENTLY_ACCESSED_REMOTES_TS
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_recently_accessed_remotes(self):
|
||||
media = self.recently_accessed_remotes
|
||||
self.recently_accessed_remotes = set()
|
||||
|
||||
yield self.store.update_cached_last_access_time(
|
||||
media, self.clock.time_msec()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _makedirs(filepath):
|
||||
dirname = os.path.dirname(filepath)
|
||||
|
@ -92,22 +115,12 @@ class MediaRepository(object):
|
|||
|
||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_remote_media(self, server_name, media_id):
|
||||
key = (server_name, media_id)
|
||||
download = self.downloads.get(key)
|
||||
if download is None:
|
||||
download = self._get_remote_media_impl(server_name, media_id)
|
||||
download = ObservableDeferred(
|
||||
download,
|
||||
consumeErrors=True
|
||||
)
|
||||
self.downloads[key] = download
|
||||
|
||||
@download.addBoth
|
||||
def callback(media_info):
|
||||
del self.downloads[key]
|
||||
return media_info
|
||||
return download.observe()
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
media_info = yield self._get_remote_media_impl(server_name, media_id)
|
||||
defer.returnValue(media_info)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_remote_media_impl(self, server_name, media_id):
|
||||
|
@ -118,6 +131,11 @@ class MediaRepository(object):
|
|||
media_info = yield self._download_remote_file(
|
||||
server_name, media_id
|
||||
)
|
||||
else:
|
||||
self.recently_accessed_remotes.add((server_name, media_id))
|
||||
yield self.store.update_cached_last_access_time(
|
||||
[(server_name, media_id)], self.clock.time_msec()
|
||||
)
|
||||
defer.returnValue(media_info)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -134,10 +152,15 @@ class MediaRepository(object):
|
|||
request_path = "/".join((
|
||||
"/_matrix/media/v1/download", server_name, media_id,
|
||||
))
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size,
|
||||
)
|
||||
try:
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warn("Failed to fetch remoted media %r", e)
|
||||
raise SynapseError(502, "Failed to fetch remoted media")
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
|
@ -410,6 +433,41 @@ class MediaRepository(object):
|
|||
"height": m_height,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_old_remote_media(self, before_ts):
|
||||
old_media = yield self.store.get_remote_media_before(before_ts)
|
||||
|
||||
deleted = 0
|
||||
|
||||
for media in old_media:
|
||||
origin = media["media_origin"]
|
||||
media_id = media["media_id"]
|
||||
file_id = media["filesystem_id"]
|
||||
key = (origin, media_id)
|
||||
|
||||
logger.info("Deleting: %r", key)
|
||||
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
full_path = self.filepaths.remote_media_filepath(origin, file_id)
|
||||
try:
|
||||
os.remove(full_path)
|
||||
except OSError as e:
|
||||
logger.warn("Failed to remove file: %r", full_path)
|
||||
if e.errno == errno.ENOENT:
|
||||
pass
|
||||
else:
|
||||
continue
|
||||
|
||||
thumbnail_dir = self.filepaths.remote_media_thumbnail_dir(
|
||||
origin, file_id
|
||||
)
|
||||
shutil.rmtree(thumbnail_dir, ignore_errors=True)
|
||||
|
||||
yield self.store.delete_remote_media(origin, media_id)
|
||||
deleted += 1
|
||||
|
||||
defer.returnValue({"deleted": deleted})
|
||||
|
||||
|
||||
class MediaRepositoryResource(Resource):
|
||||
"""File uploading and downloading.
|
||||
|
@ -458,9 +516,8 @@ class MediaRepositoryResource(Resource):
|
|||
|
||||
def __init__(self, hs):
|
||||
Resource.__init__(self)
|
||||
filepaths = MediaFilePaths(hs.config.media_store_path)
|
||||
|
||||
media_repo = MediaRepository(hs, filepaths)
|
||||
media_repo = hs.get_media_repository()
|
||||
|
||||
self.putChild("upload", UploadResource(hs, media_repo))
|
||||
self.putChild("download", DownloadResource(hs, media_repo))
|
||||
|
|
|
@ -29,6 +29,8 @@ from synapse.http.server import (
|
|||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.stringutils import is_ascii
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import os
|
||||
import re
|
||||
import fnmatch
|
||||
|
@ -252,7 +254,8 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
og = {}
|
||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||
og[tag.attrib['property']] = tag.attrib['content']
|
||||
if 'content' in tag.attrib:
|
||||
og[tag.attrib['property']] = tag.attrib['content']
|
||||
|
||||
# TODO: grab article: meta tags too, e.g.:
|
||||
|
||||
|
@ -279,7 +282,7 @@ class PreviewUrlResource(Resource):
|
|||
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||
images = sorted(images, key=lambda i: (
|
||||
-1 * int(i.attrib['width']) * int(i.attrib['height'])
|
||||
-1 * float(i.attrib['width']) * float(i.attrib['height'])
|
||||
))
|
||||
if not images:
|
||||
images = tree.xpath("//img[@src]")
|
||||
|
@ -287,9 +290,9 @@ class PreviewUrlResource(Resource):
|
|||
og['og:image'] = images[0].attrib['src']
|
||||
|
||||
# pre-cache the image for posterity
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url request
|
||||
# itself and benefit from the same caching etc. But for now we just rely on the
|
||||
# caching on the master request to speed things up.
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
||||
# request itself and benefit from the same caching etc. But for now we
|
||||
# just rely on the caching on the master request to speed things up.
|
||||
if 'og:image' in og and og['og:image']:
|
||||
image_info = yield self._download_url(
|
||||
self._rebase_url(og['og:image'], media_info['uri']), requester.user
|
||||
|
@ -328,20 +331,24 @@ class PreviewUrlResource(Resource):
|
|||
# ...or if they are within a <script/> or <style/> tag.
|
||||
# This is a very very very coarse approximation to a plain text
|
||||
# render of the page.
|
||||
text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | "
|
||||
"ancestor::aside | ancestor::footer | "
|
||||
"ancestor::script | ancestor::style)]" +
|
||||
"[ancestor::body]")
|
||||
text = ''
|
||||
for text_node in text_nodes:
|
||||
if len(text) < 500:
|
||||
text += text_node + ' '
|
||||
else:
|
||||
break
|
||||
text = re.sub(r'[\t ]+', ' ', text)
|
||||
text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text)
|
||||
text = text.strip()[:500]
|
||||
og['og:description'] = text if text else None
|
||||
|
||||
# We don't just use XPATH here as that is slow on some machines.
|
||||
|
||||
# We clone `tree` as we modify it.
|
||||
cloned_tree = deepcopy(tree.find("body"))
|
||||
|
||||
TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",)
|
||||
for el in cloned_tree.iter(TAGS_TO_REMOVE):
|
||||
el.getparent().remove(el)
|
||||
|
||||
# Split all the text nodes into paragraphs (by splitting on new
|
||||
# lines)
|
||||
text_nodes = (
|
||||
re.sub(r'\s+', '\n', el.text).strip()
|
||||
for el in cloned_tree.iter()
|
||||
if el.text and isinstance(el.tag, basestring) # Removes comments
|
||||
)
|
||||
og['og:description'] = summarize_paragraphs(text_nodes)
|
||||
|
||||
# TODO: delete the url downloads to stop diskfilling,
|
||||
# as we only ever cared about its OG
|
||||
|
@ -449,3 +456,56 @@ class PreviewUrlResource(Resource):
|
|||
content_type.startswith("application/xhtml")
|
||||
):
|
||||
return True
|
||||
|
||||
|
||||
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
|
||||
# Try to get a summary of between 200 and 500 words, respecting
|
||||
# first paragraph and then word boundaries.
|
||||
# TODO: Respect sentences?
|
||||
|
||||
description = ''
|
||||
|
||||
# Keep adding paragraphs until we get to the MIN_SIZE.
|
||||
for text_node in text_nodes:
|
||||
if len(description) < min_size:
|
||||
text_node = re.sub(r'[\t \r\n]+', ' ', text_node)
|
||||
description += text_node + '\n\n'
|
||||
else:
|
||||
break
|
||||
|
||||
description = description.strip()
|
||||
description = re.sub(r'[\t ]+', ' ', description)
|
||||
description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description)
|
||||
|
||||
# If the concatenation of paragraphs to get above MIN_SIZE
|
||||
# took us over MAX_SIZE, then we need to truncate mid paragraph
|
||||
if len(description) > max_size:
|
||||
new_desc = ""
|
||||
|
||||
# This splits the paragraph into words, but keeping the
|
||||
# (preceeding) whitespace intact so we can easily concat
|
||||
# words back together.
|
||||
for match in re.finditer("\s*\S+", description):
|
||||
word = match.group()
|
||||
|
||||
# Keep adding words while the total length is less than
|
||||
# MAX_SIZE.
|
||||
if len(word) + len(new_desc) < max_size:
|
||||
new_desc += word
|
||||
else:
|
||||
# At this point the next word *will* take us over
|
||||
# MAX_SIZE, but we also want to ensure that its not
|
||||
# a huge word. If it is add it anyway and we'll
|
||||
# truncate later.
|
||||
if len(new_desc) < min_size:
|
||||
new_desc += word
|
||||
break
|
||||
|
||||
# Double check that we're not over the limit
|
||||
if len(new_desc) > max_size:
|
||||
new_desc = new_desc[:max_size]
|
||||
|
||||
# We always add an ellipsis because at the very least
|
||||
# we chopped mid paragraph.
|
||||
description = new_desc.strip() + "…"
|
||||
return description if description else None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue