mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-02 12:16:09 -04:00
Merge branch 'develop' into rav/saml2_client
This commit is contained in:
commit
a4daa899ec
478 changed files with 18927 additions and 11500 deletions
|
@ -66,6 +66,7 @@ class ClientRestResource(JsonResource):
|
|||
* /_matrix/client/unstable
|
||||
* etc
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
JsonResource.__init__(self, hs, canonical_json=False)
|
||||
self.register_servlets(self, hs)
|
||||
|
|
|
@ -61,7 +61,7 @@ def historical_admin_path_patterns(path_regex):
|
|||
"^/_synapse/admin/v1",
|
||||
"^/_matrix/client/api/v1/admin",
|
||||
"^/_matrix/client/unstable/admin",
|
||||
"^/_matrix/client/r0/admin"
|
||||
"^/_matrix/client/r0/admin",
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -88,12 +88,12 @@ class UsersRestServlet(RestServlet):
|
|||
|
||||
|
||||
class VersionServlet(RestServlet):
|
||||
PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"), )
|
||||
PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),)
|
||||
|
||||
def __init__(self, hs):
|
||||
self.res = {
|
||||
'server_version': get_version_string(synapse),
|
||||
'python_version': platform.python_version(),
|
||||
"server_version": get_version_string(synapse),
|
||||
"python_version": platform.python_version(),
|
||||
}
|
||||
|
||||
def on_GET(self, request):
|
||||
|
@ -107,6 +107,7 @@ class UserRegisterServlet(RestServlet):
|
|||
nonces (dict[str, int]): The nonces that we will accept. A dict of
|
||||
nonce to the time it was generated, in int seconds.
|
||||
"""
|
||||
|
||||
PATTERNS = historical_admin_path_patterns("/register")
|
||||
NONCE_TIMEOUT = 60
|
||||
|
||||
|
@ -146,28 +147,24 @@ class UserRegisterServlet(RestServlet):
|
|||
body = parse_json_object_from_request(request)
|
||||
|
||||
if "nonce" not in body:
|
||||
raise SynapseError(
|
||||
400, "nonce must be specified", errcode=Codes.BAD_JSON,
|
||||
)
|
||||
raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
|
||||
|
||||
nonce = body["nonce"]
|
||||
|
||||
if nonce not in self.nonces:
|
||||
raise SynapseError(
|
||||
400, "unrecognised nonce",
|
||||
)
|
||||
raise SynapseError(400, "unrecognised nonce")
|
||||
|
||||
# Delete the nonce, so it can't be reused, even if it's invalid
|
||||
del self.nonces[nonce]
|
||||
|
||||
if "username" not in body:
|
||||
raise SynapseError(
|
||||
400, "username must be specified", errcode=Codes.BAD_JSON,
|
||||
400, "username must be specified", errcode=Codes.BAD_JSON
|
||||
)
|
||||
else:
|
||||
if (
|
||||
not isinstance(body['username'], text_type)
|
||||
or len(body['username']) > 512
|
||||
not isinstance(body["username"], text_type)
|
||||
or len(body["username"]) > 512
|
||||
):
|
||||
raise SynapseError(400, "Invalid username")
|
||||
|
||||
|
@ -177,12 +174,12 @@ class UserRegisterServlet(RestServlet):
|
|||
|
||||
if "password" not in body:
|
||||
raise SynapseError(
|
||||
400, "password must be specified", errcode=Codes.BAD_JSON,
|
||||
400, "password must be specified", errcode=Codes.BAD_JSON
|
||||
)
|
||||
else:
|
||||
if (
|
||||
not isinstance(body['password'], text_type)
|
||||
or len(body['password']) > 512
|
||||
not isinstance(body["password"], text_type)
|
||||
or len(body["password"]) > 512
|
||||
):
|
||||
raise SynapseError(400, "Invalid password")
|
||||
|
||||
|
@ -202,7 +199,7 @@ class UserRegisterServlet(RestServlet):
|
|||
key=self.hs.config.registration_shared_secret.encode(),
|
||||
digestmod=hashlib.sha1,
|
||||
)
|
||||
want_mac.update(nonce.encode('utf8'))
|
||||
want_mac.update(nonce.encode("utf8"))
|
||||
want_mac.update(b"\x00")
|
||||
want_mac.update(username)
|
||||
want_mac.update(b"\x00")
|
||||
|
@ -211,13 +208,10 @@ class UserRegisterServlet(RestServlet):
|
|||
want_mac.update(b"admin" if admin else b"notadmin")
|
||||
if user_type:
|
||||
want_mac.update(b"\x00")
|
||||
want_mac.update(user_type.encode('utf8'))
|
||||
want_mac.update(user_type.encode("utf8"))
|
||||
want_mac = want_mac.hexdigest()
|
||||
|
||||
if not hmac.compare_digest(
|
||||
want_mac.encode('ascii'),
|
||||
got_mac.encode('ascii')
|
||||
):
|
||||
if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
|
||||
raise SynapseError(403, "HMAC incorrect")
|
||||
|
||||
# Reuse the parts of RegisterRestServlet to reduce code duplication
|
||||
|
@ -226,7 +220,7 @@ class UserRegisterServlet(RestServlet):
|
|||
register = RegisterRestServlet(self.hs)
|
||||
|
||||
(user_id, _) = yield register.registration_handler.register(
|
||||
localpart=body['username'].lower(),
|
||||
localpart=body["username"].lower(),
|
||||
password=body["password"],
|
||||
admin=bool(admin),
|
||||
generate_token=False,
|
||||
|
@ -308,7 +302,7 @@ class PurgeHistoryRestServlet(RestServlet):
|
|||
# user can provide an event_id in the URL or the request body, or can
|
||||
# provide a timestamp in the request body.
|
||||
if event_id is None:
|
||||
event_id = body.get('purge_up_to_event_id')
|
||||
event_id = body.get("purge_up_to_event_id")
|
||||
|
||||
if event_id is not None:
|
||||
event = yield self.store.get_event(event_id)
|
||||
|
@ -318,44 +312,39 @@ class PurgeHistoryRestServlet(RestServlet):
|
|||
|
||||
token = yield self.store.get_topological_token_for_event(event_id)
|
||||
|
||||
logger.info(
|
||||
"[purge] purging up to token %s (event_id %s)",
|
||||
token, event_id,
|
||||
)
|
||||
elif 'purge_up_to_ts' in body:
|
||||
ts = body['purge_up_to_ts']
|
||||
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
|
||||
elif "purge_up_to_ts" in body:
|
||||
ts = body["purge_up_to_ts"]
|
||||
if not isinstance(ts, int):
|
||||
raise SynapseError(
|
||||
400, "purge_up_to_ts must be an int",
|
||||
errcode=Codes.BAD_JSON,
|
||||
400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON
|
||||
)
|
||||
|
||||
stream_ordering = (
|
||||
yield self.store.find_first_stream_ordering_after_ts(ts)
|
||||
)
|
||||
stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts))
|
||||
|
||||
r = (
|
||||
yield self.store.get_room_event_after_stream_ordering(
|
||||
room_id, stream_ordering,
|
||||
room_id, stream_ordering
|
||||
)
|
||||
)
|
||||
if not r:
|
||||
logger.warn(
|
||||
"[purge] purging events not possible: No event found "
|
||||
"(received_ts %i => stream_ordering %i)",
|
||||
ts, stream_ordering,
|
||||
ts,
|
||||
stream_ordering,
|
||||
)
|
||||
raise SynapseError(
|
||||
404,
|
||||
"there is no event to be purged",
|
||||
errcode=Codes.NOT_FOUND,
|
||||
404, "there is no event to be purged", errcode=Codes.NOT_FOUND
|
||||
)
|
||||
(stream, topo, _event_id) = r
|
||||
token = "t%d-%d" % (topo, stream)
|
||||
logger.info(
|
||||
"[purge] purging up to token %s (received_ts %i => "
|
||||
"stream_ordering %i)",
|
||||
token, ts, stream_ordering,
|
||||
token,
|
||||
ts,
|
||||
stream_ordering,
|
||||
)
|
||||
else:
|
||||
raise SynapseError(
|
||||
|
@ -365,13 +354,10 @@ class PurgeHistoryRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
purge_id = yield self.pagination_handler.start_purge_history(
|
||||
room_id, token,
|
||||
delete_local_events=delete_local_events,
|
||||
room_id, token, delete_local_events=delete_local_events
|
||||
)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"purge_id": purge_id,
|
||||
}))
|
||||
defer.returnValue((200, {"purge_id": purge_id}))
|
||||
|
||||
|
||||
class PurgeHistoryStatusRestServlet(RestServlet):
|
||||
|
@ -421,16 +407,14 @@ class DeactivateAccountRestServlet(RestServlet):
|
|||
UserID.from_string(target_user_id)
|
||||
|
||||
result = yield self._deactivate_account_handler.deactivate_account(
|
||||
target_user_id, erase,
|
||||
target_user_id, erase
|
||||
)
|
||||
if result:
|
||||
id_server_unbind_result = "success"
|
||||
else:
|
||||
id_server_unbind_result = "no-support"
|
||||
|
||||
defer.returnValue((200, {
|
||||
"id_server_unbind_result": id_server_unbind_result,
|
||||
}))
|
||||
defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result}))
|
||||
|
||||
|
||||
class ShutdownRoomRestServlet(RestServlet):
|
||||
|
@ -439,6 +423,7 @@ class ShutdownRoomRestServlet(RestServlet):
|
|||
to a new room created by `new_room_user_id` and kicked users will be auto
|
||||
joined to the new room.
|
||||
"""
|
||||
|
||||
PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
|
||||
|
||||
DEFAULT_MESSAGE = (
|
||||
|
@ -474,9 +459,7 @@ class ShutdownRoomRestServlet(RestServlet):
|
|||
config={
|
||||
"preset": "public_chat",
|
||||
"name": room_name,
|
||||
"power_level_content_override": {
|
||||
"users_default": -10,
|
||||
},
|
||||
"power_level_content_override": {"users_default": -10},
|
||||
},
|
||||
ratelimit=False,
|
||||
)
|
||||
|
@ -485,8 +468,7 @@ class ShutdownRoomRestServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
logger.info(
|
||||
"Shutting down room %r, joining to new room: %r",
|
||||
room_id, new_room_id,
|
||||
"Shutting down room %r, joining to new room: %r", room_id, new_room_id
|
||||
)
|
||||
|
||||
# This will work even if the room is already blocked, but that is
|
||||
|
@ -529,7 +511,7 @@ class ShutdownRoomRestServlet(RestServlet):
|
|||
kicked_users.append(user_id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to leave old room and join new room for %r", user_id,
|
||||
"Failed to leave old room and join new room for %r", user_id
|
||||
)
|
||||
failed_to_kick_users.append(user_id)
|
||||
|
||||
|
@ -550,18 +532,24 @@ class ShutdownRoomRestServlet(RestServlet):
|
|||
room_id, new_room_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"kicked_users": kicked_users,
|
||||
"failed_to_kick_users": failed_to_kick_users,
|
||||
"local_aliases": aliases_for_room,
|
||||
"new_room_id": new_room_id,
|
||||
}))
|
||||
defer.returnValue(
|
||||
(
|
||||
200,
|
||||
{
|
||||
"kicked_users": kicked_users,
|
||||
"failed_to_kick_users": failed_to_kick_users,
|
||||
"local_aliases": aliases_for_room,
|
||||
"new_room_id": new_room_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class QuarantineMediaInRoom(RestServlet):
|
||||
"""Quarantines all media in a room so that no one can download it via
|
||||
this server.
|
||||
"""
|
||||
|
||||
PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -574,7 +562,7 @@ class QuarantineMediaInRoom(RestServlet):
|
|||
yield assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
num_quarantined = yield self.store.quarantine_media_ids_in_room(
|
||||
room_id, requester.user.to_string(),
|
||||
room_id, requester.user.to_string()
|
||||
)
|
||||
|
||||
defer.returnValue((200, {"num_quarantined": num_quarantined}))
|
||||
|
@ -583,6 +571,7 @@ class QuarantineMediaInRoom(RestServlet):
|
|||
class ListMediaInRoom(RestServlet):
|
||||
"""Lists all of the media in a given room.
|
||||
"""
|
||||
|
||||
PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -613,7 +602,10 @@ class ResetPasswordRestServlet(RestServlet):
|
|||
Returns:
|
||||
200 OK with empty object if success otherwise an error.
|
||||
"""
|
||||
PATTERNS = historical_admin_path_patterns("/reset_password/(?P<target_user_id>[^/]*)")
|
||||
|
||||
PATTERNS = historical_admin_path_patterns(
|
||||
"/reset_password/(?P<target_user_id>[^/]*)"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -633,7 +625,7 @@ class ResetPasswordRestServlet(RestServlet):
|
|||
|
||||
params = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(params, ["new_password"])
|
||||
new_password = params['new_password']
|
||||
new_password = params["new_password"]
|
||||
|
||||
yield self._set_password_handler.set_password(
|
||||
target_user_id, new_password, requester
|
||||
|
@ -650,7 +642,10 @@ class GetUsersPaginatedRestServlet(RestServlet):
|
|||
Returns:
|
||||
200 OK with json object {list[dict[str, Any]], count} or empty object.
|
||||
"""
|
||||
PATTERNS = historical_admin_path_patterns("/users_paginate/(?P<target_user_id>[^/]*)")
|
||||
|
||||
PATTERNS = historical_admin_path_patterns(
|
||||
"/users_paginate/(?P<target_user_id>[^/]*)"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
@ -676,9 +671,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
|
|||
|
||||
logger.info("limit: %s, start: %s", limit, start)
|
||||
|
||||
ret = yield self.handlers.admin_handler.get_users_paginate(
|
||||
order, start, limit
|
||||
)
|
||||
ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -702,13 +695,11 @@ class GetUsersPaginatedRestServlet(RestServlet):
|
|||
order = "name" # order by name in user table
|
||||
params = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(params, ["limit", "start"])
|
||||
limit = params['limit']
|
||||
start = params['start']
|
||||
limit = params["limit"]
|
||||
start = params["start"]
|
||||
logger.info("limit: %s, start: %s", limit, start)
|
||||
|
||||
ret = yield self.handlers.admin_handler.get_users_paginate(
|
||||
order, start, limit
|
||||
)
|
||||
ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
|
@ -722,6 +713,7 @@ class SearchUsersRestServlet(RestServlet):
|
|||
Returns:
|
||||
200 OK with json object {list[dict[str, Any]], count} or empty object.
|
||||
"""
|
||||
|
||||
PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -750,15 +742,14 @@ class SearchUsersRestServlet(RestServlet):
|
|||
term = parse_string(request, "term", required=True)
|
||||
logger.info("term: %s ", term)
|
||||
|
||||
ret = yield self.handlers.admin_handler.search_users(
|
||||
term
|
||||
)
|
||||
ret = yield self.handlers.admin_handler.search_users(term)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class DeleteGroupAdminRestServlet(RestServlet):
|
||||
"""Allows deleting of local groups
|
||||
"""
|
||||
|
||||
PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -800,15 +791,15 @@ class AccountValidityRenewServlet(RestServlet):
|
|||
raise SynapseError(400, "Missing property 'user_id' in the request body")
|
||||
|
||||
expiration_ts = yield self.account_activity_handler.renew_account_for_user(
|
||||
body["user_id"], body.get("expiration_ts"),
|
||||
body["user_id"],
|
||||
body.get("expiration_ts"),
|
||||
not body.get("enable_renewal_emails", True),
|
||||
)
|
||||
|
||||
res = {
|
||||
"expiration_ts": expiration_ts,
|
||||
}
|
||||
res = {"expiration_ts": expiration_ts}
|
||||
defer.returnValue((200, res))
|
||||
|
||||
|
||||
########################################################################################
|
||||
#
|
||||
# please don't add more servlets here: this file is already long and unwieldy. Put
|
||||
|
|
|
@ -46,6 +46,7 @@ class SendServerNoticeServlet(RestServlet):
|
|||
"event_id": "$1895723857jgskldgujpious"
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
|
@ -58,15 +59,9 @@ class SendServerNoticeServlet(RestServlet):
|
|||
|
||||
def register(self, json_resource):
|
||||
PATTERN = "^/_synapse/admin/v1/send_server_notice"
|
||||
json_resource.register_paths("POST", (re.compile(PATTERN + "$"),), self.on_POST)
|
||||
json_resource.register_paths(
|
||||
"POST",
|
||||
(re.compile(PATTERN + "$"), ),
|
||||
self.on_POST,
|
||||
)
|
||||
json_resource.register_paths(
|
||||
"PUT",
|
||||
(re.compile(PATTERN + "/(?P<txn_id>[^/]*)$",), ),
|
||||
self.on_PUT,
|
||||
"PUT", (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),), self.on_PUT
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -96,5 +91,5 @@ class SendServerNoticeServlet(RestServlet):
|
|||
|
||||
def on_PUT(self, request, txn_id):
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request, txn_id,
|
||||
request, self.on_POST, request, txn_id
|
||||
)
|
||||
|
|
|
@ -26,7 +26,6 @@ CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
|
|||
|
||||
|
||||
class HttpTransactionCache(object):
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.auth = self.hs.get_auth()
|
||||
|
@ -53,7 +52,7 @@ class HttpTransactionCache(object):
|
|||
str: A transaction key
|
||||
"""
|
||||
token = self.auth.get_access_token_from_request(request)
|
||||
return request.path.decode('utf8') + "/" + token
|
||||
return request.path.decode("utf8") + "/" + token
|
||||
|
||||
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
|
||||
"""A helper function for fetch_or_execute which extracts
|
||||
|
|
|
@ -56,8 +56,9 @@ class ClientDirectoryServer(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
if "room_id" not in content:
|
||||
raise SynapseError(400, 'Missing params: ["room_id"]',
|
||||
errcode=Codes.BAD_JSON)
|
||||
raise SynapseError(
|
||||
400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON
|
||||
)
|
||||
|
||||
logger.debug("Got content: %s", content)
|
||||
logger.debug("Got room name: %s", room_alias.to_string())
|
||||
|
@ -89,13 +90,11 @@ class ClientDirectoryServer(RestServlet):
|
|||
try:
|
||||
service = yield self.auth.get_appservice_by_req(request)
|
||||
room_alias = RoomAlias.from_string(room_alias)
|
||||
yield dir_handler.delete_appservice_association(
|
||||
service, room_alias
|
||||
)
|
||||
yield dir_handler.delete_appservice_association(service, room_alias)
|
||||
logger.info(
|
||||
"Application service at %s deleted alias %s",
|
||||
service.url,
|
||||
room_alias.to_string()
|
||||
room_alias.to_string(),
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
except AuthError:
|
||||
|
@ -107,14 +106,10 @@ class ClientDirectoryServer(RestServlet):
|
|||
|
||||
room_alias = RoomAlias.from_string(room_alias)
|
||||
|
||||
yield dir_handler.delete_association(
|
||||
requester, room_alias
|
||||
)
|
||||
yield dir_handler.delete_association(requester, room_alias)
|
||||
|
||||
logger.info(
|
||||
"User %s deleted alias %s",
|
||||
user.to_string(),
|
||||
room_alias.to_string()
|
||||
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -135,9 +130,9 @@ class ClientDirectoryListServer(RestServlet):
|
|||
if room is None:
|
||||
raise NotFoundError("Unknown room")
|
||||
|
||||
defer.returnValue((200, {
|
||||
"visibility": "public" if room["is_public"] else "private"
|
||||
}))
|
||||
defer.returnValue(
|
||||
(200, {"visibility": "public" if room["is_public"] else "private"})
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id):
|
||||
|
@ -147,7 +142,7 @@ class ClientDirectoryListServer(RestServlet):
|
|||
visibility = content.get("visibility", "public")
|
||||
|
||||
yield self.handlers.directory_handler.edit_published_room_list(
|
||||
requester, room_id, visibility,
|
||||
requester, room_id, visibility
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -157,7 +152,7 @@ class ClientDirectoryListServer(RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
yield self.handlers.directory_handler.edit_published_room_list(
|
||||
requester, room_id, "private",
|
||||
requester, room_id, "private"
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -191,7 +186,7 @@ class ClientAppserviceDirectoryListServer(RestServlet):
|
|||
)
|
||||
|
||||
yield self.handlers.directory_handler.edit_published_appservice_room_list(
|
||||
requester.app_service.id, network_id, room_id, visibility,
|
||||
requester.app_service.id, network_id, room_id, visibility
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
|
|
@ -38,17 +38,14 @@ class EventStreamRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
allow_guest=True,
|
||||
)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
is_guest = requester.is_guest
|
||||
room_id = None
|
||||
if is_guest:
|
||||
if b"room_id" not in request.args:
|
||||
raise SynapseError(400, "Guest users must specify room_id param")
|
||||
if b"room_id" in request.args:
|
||||
room_id = request.args[b"room_id"][0].decode('ascii')
|
||||
room_id = request.args[b"room_id"][0].decode("ascii")
|
||||
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
|
||||
|
|
|
@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission):
|
|||
to a typed object.
|
||||
"""
|
||||
if "user" in submission:
|
||||
submission["identifier"] = {
|
||||
"type": "m.id.user",
|
||||
"user": submission["user"],
|
||||
}
|
||||
submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
|
||||
del submission["user"]
|
||||
|
||||
if "medium" in submission and "address" in submission:
|
||||
|
@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier):
|
|||
|
||||
msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
|
||||
|
||||
return {
|
||||
"type": "m.id.thirdparty",
|
||||
"medium": "msisdn",
|
||||
"address": msisdn,
|
||||
}
|
||||
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
|
||||
|
||||
|
||||
class LoginRestServlet(RestServlet):
|
||||
|
@ -124,9 +117,9 @@ class LoginRestServlet(RestServlet):
|
|||
# login flow types returned.
|
||||
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
|
||||
|
||||
flows.extend((
|
||||
{"type": t} for t in self.auth_handler.get_supported_login_types()
|
||||
))
|
||||
flows.extend(
|
||||
({"type": t} for t in self.auth_handler.get_supported_login_types())
|
||||
)
|
||||
|
||||
return (200, {"flows": flows})
|
||||
|
||||
|
@ -136,7 +129,8 @@ class LoginRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
self._address_ratelimiter.ratelimit(
|
||||
request.getClientIP(), time_now_s=self.hs.clock.time(),
|
||||
request.getClientIP(),
|
||||
time_now_s=self.hs.clock.time(),
|
||||
rate_hz=self.hs.config.rc_login_address.per_second,
|
||||
burst_count=self.hs.config.rc_login_address.burst_count,
|
||||
update=True,
|
||||
|
@ -144,8 +138,9 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
login_submission = parse_json_object_from_request(request)
|
||||
try:
|
||||
if self.jwt_enabled and (login_submission["type"] ==
|
||||
LoginRestServlet.JWT_TYPE):
|
||||
if self.jwt_enabled and (
|
||||
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||
):
|
||||
result = yield self.do_jwt_login(login_submission)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
result = yield self.do_token_login(login_submission)
|
||||
|
@ -174,10 +169,10 @@ class LoginRestServlet(RestServlet):
|
|||
# field)
|
||||
logger.info(
|
||||
"Got login request with identifier: %r, medium: %r, address: %r, user: %r",
|
||||
login_submission.get('identifier'),
|
||||
login_submission.get('medium'),
|
||||
login_submission.get('address'),
|
||||
login_submission.get('user'),
|
||||
login_submission.get("identifier"),
|
||||
login_submission.get("medium"),
|
||||
login_submission.get("address"),
|
||||
login_submission.get("user"),
|
||||
)
|
||||
login_submission_legacy_convert(login_submission)
|
||||
|
||||
|
@ -194,13 +189,13 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
# convert threepid identifiers to user IDs
|
||||
if identifier["type"] == "m.id.thirdparty":
|
||||
address = identifier.get('address')
|
||||
medium = identifier.get('medium')
|
||||
address = identifier.get("address")
|
||||
medium = identifier.get("medium")
|
||||
|
||||
if medium is None or address is None:
|
||||
raise SynapseError(400, "Invalid thirdparty identifier")
|
||||
|
||||
if medium == 'email':
|
||||
if medium == "email":
|
||||
# For emails, transform the address to lowercase.
|
||||
# We store all email addreses as lowercase in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
|
@ -209,34 +204,28 @@ class LoginRestServlet(RestServlet):
|
|||
# Check for login providers that support 3pid login types
|
||||
canonical_user_id, callback_3pid = (
|
||||
yield self.auth_handler.check_password_provider_3pid(
|
||||
medium,
|
||||
address,
|
||||
login_submission["password"],
|
||||
medium, address, login_submission["password"]
|
||||
)
|
||||
)
|
||||
if canonical_user_id:
|
||||
# Authentication through password provider and 3pid succeeded
|
||||
result = yield self._register_device_with_callback(
|
||||
canonical_user_id, login_submission, callback_3pid,
|
||||
canonical_user_id, login_submission, callback_3pid
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
# No password providers were able to handle this 3pid
|
||||
# Check local store
|
||||
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
medium, address,
|
||||
medium, address
|
||||
)
|
||||
if not user_id:
|
||||
logger.warn(
|
||||
"unknown 3pid identifier medium %s, address %r",
|
||||
medium, address,
|
||||
"unknown 3pid identifier medium %s, address %r", medium, address
|
||||
)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
identifier = {
|
||||
"type": "m.id.user",
|
||||
"user": user_id,
|
||||
}
|
||||
identifier = {"type": "m.id.user", "user": user_id}
|
||||
|
||||
# by this point, the identifier should be an m.id.user: if it's anything
|
||||
# else, we haven't understood it.
|
||||
|
@ -246,22 +235,16 @@ class LoginRestServlet(RestServlet):
|
|||
raise SynapseError(400, "User identifier is missing 'user' key")
|
||||
|
||||
canonical_user_id, callback = yield self.auth_handler.validate_login(
|
||||
identifier["user"],
|
||||
login_submission,
|
||||
identifier["user"], login_submission
|
||||
)
|
||||
|
||||
result = yield self._register_device_with_callback(
|
||||
canonical_user_id, login_submission, callback,
|
||||
canonical_user_id, login_submission, callback
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _register_device_with_callback(
|
||||
self,
|
||||
user_id,
|
||||
login_submission,
|
||||
callback=None,
|
||||
):
|
||||
def _register_device_with_callback(self, user_id, login_submission, callback=None):
|
||||
""" Registers a device with a given user_id. Optionally run a callback
|
||||
function after registration has completed.
|
||||
|
||||
|
@ -277,7 +260,7 @@ class LoginRestServlet(RestServlet):
|
|||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name,
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -294,7 +277,7 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def do_token_login(self, login_submission):
|
||||
token = login_submission['token']
|
||||
token = login_submission["token"]
|
||||
auth_handler = self.auth_handler
|
||||
user_id = (
|
||||
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
|
||||
|
@ -303,7 +286,7 @@ class LoginRestServlet(RestServlet):
|
|||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name,
|
||||
user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -320,15 +303,16 @@ class LoginRestServlet(RestServlet):
|
|||
token = login_submission.get("token", None)
|
||||
if token is None:
|
||||
raise LoginError(
|
||||
401, "Token field for JWT is missing",
|
||||
errcode=Codes.UNAUTHORIZED
|
||||
401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
|
||||
import jwt
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
|
||||
try:
|
||||
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
|
||||
payload = jwt.decode(
|
||||
token, self.jwt_secret, algorithms=[self.jwt_algorithm]
|
||||
)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
|
||||
except InvalidTokenError:
|
||||
|
@ -346,7 +330,7 @@ class LoginRestServlet(RestServlet):
|
|||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
registered_user_id, device_id, initial_display_name,
|
||||
registered_user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -362,7 +346,7 @@ class LoginRestServlet(RestServlet):
|
|||
device_id = login_submission.get("device_id")
|
||||
initial_display_name = login_submission.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
registered_user_id, device_id, initial_display_name,
|
||||
registered_user_id, device_id, initial_display_name
|
||||
)
|
||||
|
||||
result = {
|
||||
|
@ -376,6 +360,7 @@ class LoginRestServlet(RestServlet):
|
|||
|
||||
class BaseSsoRedirectServlet(RestServlet):
|
||||
"""Common base class for /login/sso/redirect impls"""
|
||||
|
||||
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
|
||||
|
||||
def on_GET(self, request):
|
||||
|
@ -401,21 +386,20 @@ class BaseSsoRedirectServlet(RestServlet):
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CasRedirectServlet(RestServlet):
|
||||
class CasRedirectServlet(BaseSsoRedirectServlet):
|
||||
def __init__(self, hs):
|
||||
super(CasRedirectServlet, self).__init__()
|
||||
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
|
||||
self.cas_service_url = hs.config.cas_service_url.encode('ascii')
|
||||
self.cas_server_url = hs.config.cas_server_url.encode("ascii")
|
||||
self.cas_service_url = hs.config.cas_service_url.encode("ascii")
|
||||
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
client_redirect_url_param = urllib.parse.urlencode({
|
||||
b"redirectUrl": client_redirect_url
|
||||
}).encode('ascii')
|
||||
hs_redirect_url = (self.cas_service_url +
|
||||
b"/_matrix/client/r0/login/cas/ticket")
|
||||
service_param = urllib.parse.urlencode({
|
||||
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
|
||||
}).encode('ascii')
|
||||
client_redirect_url_param = urllib.parse.urlencode(
|
||||
{b"redirectUrl": client_redirect_url}
|
||||
).encode("ascii")
|
||||
hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
|
||||
service_param = urllib.parse.urlencode(
|
||||
{b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
|
||||
).encode("ascii")
|
||||
return b"%s/login?%s" % (self.cas_server_url, service_param)
|
||||
|
||||
|
||||
|
@ -436,7 +420,7 @@ class CasTicketServlet(RestServlet):
|
|||
uri = self.cas_server_url + "/proxyValidate"
|
||||
args = {
|
||||
"ticket": parse_string(request, "ticket", required=True),
|
||||
"service": self.cas_service_url
|
||||
"service": self.cas_service_url,
|
||||
}
|
||||
try:
|
||||
body = yield self._http_client.get_raw(uri, args)
|
||||
|
@ -463,7 +447,7 @@ class CasTicketServlet(RestServlet):
|
|||
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
return self._sso_auth_handler.on_successful_auth(
|
||||
user, request, client_redirect_url,
|
||||
user, request, client_redirect_url
|
||||
)
|
||||
|
||||
def parse_cas_response(self, cas_response_body):
|
||||
|
@ -473,7 +457,7 @@ class CasTicketServlet(RestServlet):
|
|||
root = ET.fromstring(cas_response_body)
|
||||
if not root.tag.endswith("serviceResponse"):
|
||||
raise Exception("root of CAS response is not serviceResponse")
|
||||
success = (root[0].tag.endswith("authenticationSuccess"))
|
||||
success = root[0].tag.endswith("authenticationSuccess")
|
||||
for child in root[0]:
|
||||
if child.tag.endswith("user"):
|
||||
user = child.text
|
||||
|
@ -491,11 +475,11 @@ class CasTicketServlet(RestServlet):
|
|||
raise Exception("CAS response does not contain user")
|
||||
except Exception:
|
||||
logger.error("Error parsing CAS response", exc_info=1)
|
||||
raise LoginError(401, "Invalid CAS response",
|
||||
errcode=Codes.UNAUTHORIZED)
|
||||
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
||||
if not success:
|
||||
raise LoginError(401, "Unsuccessful CAS response",
|
||||
errcode=Codes.UNAUTHORIZED)
|
||||
raise LoginError(
|
||||
401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
|
||||
)
|
||||
return user, attributes
|
||||
|
||||
|
||||
|
@ -507,11 +491,11 @@ class SAMLRedirectServlet(BaseSsoRedirectServlet):
|
|||
|
||||
def get_sso_url(self, client_redirect_url):
|
||||
reqid, info = self._saml_client.prepare_for_authenticate(
|
||||
relay_state=client_redirect_url,
|
||||
relay_state=client_redirect_url
|
||||
)
|
||||
|
||||
for key, value in info['headers']:
|
||||
if key == 'Location':
|
||||
for key, value in info["headers"]:
|
||||
if key == "Location":
|
||||
return value
|
||||
|
||||
# this shouldn't happen!
|
||||
|
@ -526,6 +510,7 @@ class SSOAuthHandler(object):
|
|||
Args:
|
||||
hs (synapse.server.HomeServer)
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self._hostname = hs.hostname
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
|
@ -534,8 +519,7 @@ class SSOAuthHandler(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_successful_auth(
|
||||
self, username, request, client_redirect_url,
|
||||
user_display_name=None,
|
||||
self, username, request, client_redirect_url, user_display_name=None
|
||||
):
|
||||
"""Called once the user has successfully authenticated with the SSO.
|
||||
|
||||
|
|
|
@ -46,7 +46,8 @@ class LogoutRestServlet(RestServlet):
|
|||
yield self._auth_handler.delete_access_token(access_token)
|
||||
else:
|
||||
yield self._device_handler.delete_device(
|
||||
requester.user.to_string(), requester.device_id)
|
||||
requester.user.to_string(), requester.device_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ class PresenceStatusRestServlet(RestServlet):
|
|||
|
||||
if requester.user != user:
|
||||
allowed = yield self.presence_handler.is_visible(
|
||||
observed_user=user, observer_user=requester.user,
|
||||
observed_user=user, observer_user=requester.user
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
|
|
|
@ -63,8 +63,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
|
|||
except Exception:
|
||||
defer.returnValue((400, "Unable to parse name"))
|
||||
|
||||
yield self.profile_handler.set_displayname(
|
||||
user, requester, new_name, is_admin)
|
||||
yield self.profile_handler.set_displayname(user, requester, new_name, is_admin)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -113,8 +112,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
|
|||
except Exception:
|
||||
defer.returnValue((400, "Unable to parse name"))
|
||||
|
||||
yield self.profile_handler.set_avatar_url(
|
||||
user, requester, new_name, is_admin)
|
||||
yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
|
|
@ -21,7 +21,11 @@ from synapse.api.errors import (
|
|||
SynapseError,
|
||||
UnrecognizedRequestError,
|
||||
)
|
||||
from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_json_value_from_request,
|
||||
parse_string,
|
||||
)
|
||||
from synapse.push.baserules import BASE_RULE_IDS
|
||||
from synapse.push.clientformat import format_push_rules_for_user
|
||||
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
|
||||
|
@ -32,7 +36,8 @@ from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundExc
|
|||
class PushRuleRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
|
||||
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
|
||||
"Unrecognised request: You probably wanted a trailing slash")
|
||||
"Unrecognised request: You probably wanted a trailing slash"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PushRuleRestServlet, self).__init__()
|
||||
|
@ -54,27 +59,25 @@ class PushRuleRestServlet(RestServlet):
|
|||
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
|
||||
if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
|
||||
raise SynapseError(400, "rule_id may not contain slashes")
|
||||
|
||||
content = parse_json_value_from_request(request)
|
||||
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
if 'attr' in spec:
|
||||
if "attr" in spec:
|
||||
yield self.set_rule_attr(user_id, spec, content)
|
||||
self.notify_user(user_id)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
if spec['rule_id'].startswith('.'):
|
||||
if spec["rule_id"].startswith("."):
|
||||
# Rule ids starting with '.' are reserved for server default rules.
|
||||
raise SynapseError(400, "cannot add new rule_ids that start with '.'")
|
||||
|
||||
try:
|
||||
(conditions, actions) = _rule_tuple_from_request_object(
|
||||
spec['template'],
|
||||
spec['rule_id'],
|
||||
content,
|
||||
spec["template"], spec["rule_id"], content
|
||||
)
|
||||
except InvalidRuleException as e:
|
||||
raise SynapseError(400, str(e))
|
||||
|
@ -95,7 +98,7 @@ class PushRuleRestServlet(RestServlet):
|
|||
conditions=conditions,
|
||||
actions=actions,
|
||||
before=before,
|
||||
after=after
|
||||
after=after,
|
||||
)
|
||||
self.notify_user(user_id)
|
||||
except InconsistentRuleException as e:
|
||||
|
@ -118,9 +121,7 @@ class PushRuleRestServlet(RestServlet):
|
|||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
|
||||
try:
|
||||
yield self.store.delete_push_rule(
|
||||
user_id, namespaced_rule_id
|
||||
)
|
||||
yield self.store.delete_push_rule(user_id, namespaced_rule_id)
|
||||
self.notify_user(user_id)
|
||||
defer.returnValue((200, {}))
|
||||
except StoreError as e:
|
||||
|
@ -149,10 +150,10 @@ class PushRuleRestServlet(RestServlet):
|
|||
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
|
||||
)
|
||||
|
||||
if path[0] == '':
|
||||
if path[0] == "":
|
||||
defer.returnValue((200, rules))
|
||||
elif path[0] == 'global':
|
||||
result = _filter_ruleset_with_path(rules['global'], path[1:])
|
||||
elif path[0] == "global":
|
||||
result = _filter_ruleset_with_path(rules["global"], path[1:])
|
||||
defer.returnValue((200, result))
|
||||
else:
|
||||
raise UnrecognizedRequestError()
|
||||
|
@ -162,12 +163,10 @@ class PushRuleRestServlet(RestServlet):
|
|||
|
||||
def notify_user(self, user_id):
|
||||
stream_id, _ = self.store.get_push_rules_stream_token()
|
||||
self.notifier.on_new_event(
|
||||
"push_rules_key", stream_id, users=[user_id]
|
||||
)
|
||||
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
|
||||
|
||||
def set_rule_attr(self, user_id, spec, val):
|
||||
if spec['attr'] == 'enabled':
|
||||
if spec["attr"] == "enabled":
|
||||
if isinstance(val, dict) and "enabled" in val:
|
||||
val = val["enabled"]
|
||||
if not isinstance(val, bool):
|
||||
|
@ -176,14 +175,12 @@ class PushRuleRestServlet(RestServlet):
|
|||
# bools directly, so let's not break them.
|
||||
raise SynapseError(400, "Value for 'enabled' must be boolean")
|
||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
return self.store.set_push_rule_enabled(
|
||||
user_id, namespaced_rule_id, val
|
||||
)
|
||||
elif spec['attr'] == 'actions':
|
||||
actions = val.get('actions')
|
||||
return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
|
||||
elif spec["attr"] == "actions":
|
||||
actions = val.get("actions")
|
||||
_check_actions(actions)
|
||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
rule_id = spec['rule_id']
|
||||
rule_id = spec["rule_id"]
|
||||
is_default_rule = rule_id.startswith(".")
|
||||
if is_default_rule:
|
||||
if namespaced_rule_id not in BASE_RULE_IDS:
|
||||
|
@ -210,12 +207,12 @@ def _rule_spec_from_path(path):
|
|||
"""
|
||||
if len(path) < 2:
|
||||
raise UnrecognizedRequestError()
|
||||
if path[0] != 'pushrules':
|
||||
if path[0] != "pushrules":
|
||||
raise UnrecognizedRequestError()
|
||||
|
||||
scope = path[1]
|
||||
path = path[2:]
|
||||
if scope != 'global':
|
||||
if scope != "global":
|
||||
raise UnrecognizedRequestError()
|
||||
|
||||
if len(path) == 0:
|
||||
|
@ -229,56 +226,40 @@ def _rule_spec_from_path(path):
|
|||
|
||||
rule_id = path[0]
|
||||
|
||||
spec = {
|
||||
'scope': scope,
|
||||
'template': template,
|
||||
'rule_id': rule_id
|
||||
}
|
||||
spec = {"scope": scope, "template": template, "rule_id": rule_id}
|
||||
|
||||
path = path[1:]
|
||||
|
||||
if len(path) > 0 and len(path[0]) > 0:
|
||||
spec['attr'] = path[0]
|
||||
spec["attr"] = path[0]
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
|
||||
if rule_template in ['override', 'underride']:
|
||||
if 'conditions' not in req_obj:
|
||||
if rule_template in ["override", "underride"]:
|
||||
if "conditions" not in req_obj:
|
||||
raise InvalidRuleException("Missing 'conditions'")
|
||||
conditions = req_obj['conditions']
|
||||
conditions = req_obj["conditions"]
|
||||
for c in conditions:
|
||||
if 'kind' not in c:
|
||||
if "kind" not in c:
|
||||
raise InvalidRuleException("Condition without 'kind'")
|
||||
elif rule_template == 'room':
|
||||
conditions = [{
|
||||
'kind': 'event_match',
|
||||
'key': 'room_id',
|
||||
'pattern': rule_id
|
||||
}]
|
||||
elif rule_template == 'sender':
|
||||
conditions = [{
|
||||
'kind': 'event_match',
|
||||
'key': 'user_id',
|
||||
'pattern': rule_id
|
||||
}]
|
||||
elif rule_template == 'content':
|
||||
if 'pattern' not in req_obj:
|
||||
elif rule_template == "room":
|
||||
conditions = [{"kind": "event_match", "key": "room_id", "pattern": rule_id}]
|
||||
elif rule_template == "sender":
|
||||
conditions = [{"kind": "event_match", "key": "user_id", "pattern": rule_id}]
|
||||
elif rule_template == "content":
|
||||
if "pattern" not in req_obj:
|
||||
raise InvalidRuleException("Content rule missing 'pattern'")
|
||||
pat = req_obj['pattern']
|
||||
pat = req_obj["pattern"]
|
||||
|
||||
conditions = [{
|
||||
'kind': 'event_match',
|
||||
'key': 'content.body',
|
||||
'pattern': pat
|
||||
}]
|
||||
conditions = [{"kind": "event_match", "key": "content.body", "pattern": pat}]
|
||||
else:
|
||||
raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
|
||||
|
||||
if 'actions' not in req_obj:
|
||||
if "actions" not in req_obj:
|
||||
raise InvalidRuleException("No actions found")
|
||||
actions = req_obj['actions']
|
||||
actions = req_obj["actions"]
|
||||
|
||||
_check_actions(actions)
|
||||
|
||||
|
@ -290,9 +271,9 @@ def _check_actions(actions):
|
|||
raise InvalidRuleException("No actions found")
|
||||
|
||||
for a in actions:
|
||||
if a in ['notify', 'dont_notify', 'coalesce']:
|
||||
if a in ["notify", "dont_notify", "coalesce"]:
|
||||
pass
|
||||
elif isinstance(a, dict) and 'set_tweak' in a:
|
||||
elif isinstance(a, dict) and "set_tweak" in a:
|
||||
pass
|
||||
else:
|
||||
raise InvalidRuleException("Unrecognised action")
|
||||
|
@ -304,7 +285,7 @@ def _filter_ruleset_with_path(ruleset, path):
|
|||
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
|
||||
)
|
||||
|
||||
if path[0] == '':
|
||||
if path[0] == "":
|
||||
return ruleset
|
||||
template_kind = path[0]
|
||||
if template_kind not in ruleset:
|
||||
|
@ -314,13 +295,13 @@ def _filter_ruleset_with_path(ruleset, path):
|
|||
raise UnrecognizedRequestError(
|
||||
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
|
||||
)
|
||||
if path[0] == '':
|
||||
if path[0] == "":
|
||||
return ruleset[template_kind]
|
||||
rule_id = path[0]
|
||||
|
||||
the_rule = None
|
||||
for r in ruleset[template_kind]:
|
||||
if r['rule_id'] == rule_id:
|
||||
if r["rule_id"] == rule_id:
|
||||
the_rule = r
|
||||
if the_rule is None:
|
||||
raise NotFoundError
|
||||
|
@ -339,19 +320,19 @@ def _filter_ruleset_with_path(ruleset, path):
|
|||
|
||||
|
||||
def _priority_class_from_spec(spec):
|
||||
if spec['template'] not in PRIORITY_CLASS_MAP.keys():
|
||||
raise InvalidRuleException("Unknown template: %s" % (spec['template']))
|
||||
pc = PRIORITY_CLASS_MAP[spec['template']]
|
||||
if spec["template"] not in PRIORITY_CLASS_MAP.keys():
|
||||
raise InvalidRuleException("Unknown template: %s" % (spec["template"]))
|
||||
pc = PRIORITY_CLASS_MAP[spec["template"]]
|
||||
|
||||
return pc
|
||||
|
||||
|
||||
def _namespaced_rule_id_from_spec(spec):
|
||||
return _namespaced_rule_id(spec, spec['rule_id'])
|
||||
return _namespaced_rule_id(spec, spec["rule_id"])
|
||||
|
||||
|
||||
def _namespaced_rule_id(spec, rule_id):
|
||||
return "global/%s/%s" % (spec['template'], rule_id)
|
||||
return "global/%s/%s" % (spec["template"], rule_id)
|
||||
|
||||
|
||||
class InvalidRuleException(Exception):
|
||||
|
|
|
@ -44,9 +44,7 @@ class PushersRestServlet(RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = requester.user
|
||||
|
||||
pushers = yield self.hs.get_datastore().get_pushers_by_user_id(
|
||||
user.to_string()
|
||||
)
|
||||
pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
|
||||
|
||||
allowed_keys = [
|
||||
"app_display_name",
|
||||
|
@ -87,50 +85,61 @@ class PushersSetRestServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
if ('pushkey' in content and 'app_id' in content
|
||||
and 'kind' in content and
|
||||
content['kind'] is None):
|
||||
if (
|
||||
"pushkey" in content
|
||||
and "app_id" in content
|
||||
and "kind" in content
|
||||
and content["kind"] is None
|
||||
):
|
||||
yield self.pusher_pool.remove_pusher(
|
||||
content['app_id'], content['pushkey'], user_id=user.to_string()
|
||||
content["app_id"], content["pushkey"], user_id=user.to_string()
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
assert_params_in_dict(
|
||||
content,
|
||||
['kind', 'app_id', 'app_display_name',
|
||||
'device_display_name', 'pushkey', 'lang', 'data']
|
||||
[
|
||||
"kind",
|
||||
"app_id",
|
||||
"app_display_name",
|
||||
"device_display_name",
|
||||
"pushkey",
|
||||
"lang",
|
||||
"data",
|
||||
],
|
||||
)
|
||||
|
||||
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
|
||||
logger.debug("set pushkey %s to kind %s", content["pushkey"], content["kind"])
|
||||
logger.debug("Got pushers request with body: %r", content)
|
||||
|
||||
append = False
|
||||
if 'append' in content:
|
||||
append = content['append']
|
||||
if "append" in content:
|
||||
append = content["append"]
|
||||
|
||||
if not append:
|
||||
yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
|
||||
app_id=content['app_id'],
|
||||
pushkey=content['pushkey'],
|
||||
not_user_id=user.to_string()
|
||||
app_id=content["app_id"],
|
||||
pushkey=content["pushkey"],
|
||||
not_user_id=user.to_string(),
|
||||
)
|
||||
|
||||
try:
|
||||
yield self.pusher_pool.add_pusher(
|
||||
user_id=user.to_string(),
|
||||
access_token=requester.access_token_id,
|
||||
kind=content['kind'],
|
||||
app_id=content['app_id'],
|
||||
app_display_name=content['app_display_name'],
|
||||
device_display_name=content['device_display_name'],
|
||||
pushkey=content['pushkey'],
|
||||
lang=content['lang'],
|
||||
data=content['data'],
|
||||
profile_tag=content.get('profile_tag', ""),
|
||||
kind=content["kind"],
|
||||
app_id=content["app_id"],
|
||||
app_display_name=content["app_display_name"],
|
||||
device_display_name=content["device_display_name"],
|
||||
pushkey=content["pushkey"],
|
||||
lang=content["lang"],
|
||||
data=content["data"],
|
||||
profile_tag=content.get("profile_tag", ""),
|
||||
)
|
||||
except PusherConfigException as pce:
|
||||
raise SynapseError(400, "Config Error: " + str(pce),
|
||||
errcode=Codes.MISSING_PARAM)
|
||||
raise SynapseError(
|
||||
400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM
|
||||
)
|
||||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
|
@ -144,6 +153,7 @@ class PushersRemoveRestServlet(RestServlet):
|
|||
"""
|
||||
To allow pusher to be delete by clicking a link (ie. GET request)
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/pushers/remove$", v1=True)
|
||||
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
|
||||
|
||||
|
@ -164,9 +174,7 @@ class PushersRemoveRestServlet(RestServlet):
|
|||
|
||||
try:
|
||||
yield self.pusher_pool.remove_pusher(
|
||||
app_id=app_id,
|
||||
pushkey=pushkey,
|
||||
user_id=user.to_string(),
|
||||
app_id=app_id, pushkey=pushkey, user_id=user.to_string()
|
||||
)
|
||||
except StoreError as se:
|
||||
if se.code != 404:
|
||||
|
@ -177,9 +185,9 @@ class PushersRemoveRestServlet(RestServlet):
|
|||
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (
|
||||
len(PushersRemoveRestServlet.SUCCESS_HTML),
|
||||
))
|
||||
request.setHeader(
|
||||
b"Content-Length", b"%d" % (len(PushersRemoveRestServlet.SUCCESS_HTML),)
|
||||
)
|
||||
request.write(PushersRemoveRestServlet.SUCCESS_HTML)
|
||||
finish_request(request)
|
||||
defer.returnValue(None)
|
||||
|
|
|
@ -61,18 +61,16 @@ class RoomCreateRestServlet(TransactionRestServlet):
|
|||
PATTERNS = "/createRoom"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity
|
||||
http_server.register_paths("OPTIONS",
|
||||
client_patterns("/rooms(?:/.*)?$", v1=True),
|
||||
self.on_OPTIONS)
|
||||
http_server.register_paths(
|
||||
"OPTIONS", client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS
|
||||
)
|
||||
# define CORS for /createRoom[/txnid]
|
||||
http_server.register_paths("OPTIONS",
|
||||
client_patterns("/createRoom(?:/.*)?$", v1=True),
|
||||
self.on_OPTIONS)
|
||||
http_server.register_paths(
|
||||
"OPTIONS", client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS
|
||||
)
|
||||
|
||||
def on_PUT(self, request, txn_id):
|
||||
return self.txns.fetch_or_execute_request(
|
||||
request, self.on_POST, request
|
||||
)
|
||||
return self.txns.fetch_or_execute_request(request, self.on_POST, request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
@ -107,21 +105,23 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
|
||||
|
||||
# /room/$roomid/state/$eventtype/$statekey
|
||||
state_key = ("/rooms/(?P<room_id>[^/]*)/state/"
|
||||
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
|
||||
state_key = (
|
||||
"/rooms/(?P<room_id>[^/]*)/state/"
|
||||
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$"
|
||||
)
|
||||
|
||||
http_server.register_paths("GET",
|
||||
client_patterns(state_key, v1=True),
|
||||
self.on_GET)
|
||||
http_server.register_paths("PUT",
|
||||
client_patterns(state_key, v1=True),
|
||||
self.on_PUT)
|
||||
http_server.register_paths("GET",
|
||||
client_patterns(no_state_key, v1=True),
|
||||
self.on_GET_no_state_key)
|
||||
http_server.register_paths("PUT",
|
||||
client_patterns(no_state_key, v1=True),
|
||||
self.on_PUT_no_state_key)
|
||||
http_server.register_paths(
|
||||
"GET", client_patterns(state_key, v1=True), self.on_GET
|
||||
)
|
||||
http_server.register_paths(
|
||||
"PUT", client_patterns(state_key, v1=True), self.on_PUT
|
||||
)
|
||||
http_server.register_paths(
|
||||
"GET", client_patterns(no_state_key, v1=True), self.on_GET_no_state_key
|
||||
)
|
||||
http_server.register_paths(
|
||||
"PUT", client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key
|
||||
)
|
||||
|
||||
def on_GET_no_state_key(self, request, room_id, event_type):
|
||||
return self.on_GET(request, room_id, event_type, "")
|
||||
|
@ -132,8 +132,9 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_type, state_key):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
format = parse_string(request, "format", default="content",
|
||||
allowed_values=["content", "event"])
|
||||
format = parse_string(
|
||||
request, "format", default="content", allowed_values=["content", "event"]
|
||||
)
|
||||
|
||||
msg_handler = self.message_handler
|
||||
data = yield msg_handler.get_room_data(
|
||||
|
@ -145,9 +146,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
)
|
||||
|
||||
if not data:
|
||||
raise SynapseError(
|
||||
404, "Event not found.", errcode=Codes.NOT_FOUND
|
||||
)
|
||||
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
|
||||
|
||||
if format == "event":
|
||||
event = format_event_for_client_v2(data.get_dict())
|
||||
|
@ -182,9 +181,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
)
|
||||
else:
|
||||
event = yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
event_dict,
|
||||
txn_id=txn_id,
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
|
||||
ret = {}
|
||||
|
@ -195,7 +192,6 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
|||
|
||||
# TODO: Needs unit testing for generic events + feedback
|
||||
class RoomSendEventRestServlet(TransactionRestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomSendEventRestServlet, self).__init__(hs)
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
@ -203,7 +199,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
|
||||
def register(self, http_server):
|
||||
# /rooms/$roomid/send/$event_type[/$txn_id]
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
|
||||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server, with_get=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -218,13 +214,11 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
|||
"sender": requester.user.to_string(),
|
||||
}
|
||||
|
||||
if b'ts' in request.args and requester.app_service:
|
||||
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
|
||||
if b"ts" in request.args and requester.app_service:
|
||||
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
|
||||
|
||||
event = yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
event_dict,
|
||||
txn_id=txn_id,
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {"event_id": event.event_id}))
|
||||
|
@ -247,15 +241,12 @@ class JoinRoomAliasServlet(TransactionRestServlet):
|
|||
|
||||
def register(self, http_server):
|
||||
# /join/$room_identifier[/$txn_id]
|
||||
PATTERNS = ("/join/(?P<room_identifier>[^/]*)")
|
||||
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_identifier, txn_id=None):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
allow_guest=True,
|
||||
)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
try:
|
||||
content = parse_json_object_from_request(request)
|
||||
|
@ -268,7 +259,7 @@ class JoinRoomAliasServlet(TransactionRestServlet):
|
|||
room_id = room_identifier
|
||||
try:
|
||||
remote_room_hosts = [
|
||||
x.decode('ascii') for x in request.args[b"server_name"]
|
||||
x.decode("ascii") for x in request.args[b"server_name"]
|
||||
]
|
||||
except Exception:
|
||||
remote_room_hosts = None
|
||||
|
@ -278,9 +269,9 @@ class JoinRoomAliasServlet(TransactionRestServlet):
|
|||
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
|
||||
room_id = room_id.to_string()
|
||||
else:
|
||||
raise SynapseError(400, "%s was not legal room ID or room alias" % (
|
||||
room_identifier,
|
||||
))
|
||||
raise SynapseError(
|
||||
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
||||
)
|
||||
|
||||
yield self.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
|
@ -320,7 +311,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
|||
# Option to allow servers to require auth when accessing
|
||||
# /publicRooms via CS API. This is especially helpful in private
|
||||
# federations.
|
||||
if self.hs.config.restrict_public_rooms_to_local_users:
|
||||
if not self.hs.config.allow_public_rooms_without_auth:
|
||||
raise
|
||||
|
||||
# We allow people to not be authed if they're just looking at our
|
||||
|
@ -339,14 +330,11 @@ class PublicRoomListRestServlet(TransactionRestServlet):
|
|||
handler = self.hs.get_room_list_handler()
|
||||
if server:
|
||||
data = yield handler.get_remote_public_room_list(
|
||||
server,
|
||||
limit=limit,
|
||||
since_token=since_token,
|
||||
server, limit=limit, since_token=since_token
|
||||
)
|
||||
else:
|
||||
data = yield handler.get_local_public_room_list(
|
||||
limit=limit,
|
||||
since_token=since_token,
|
||||
limit=limit, since_token=since_token
|
||||
)
|
||||
|
||||
defer.returnValue((200, data))
|
||||
|
@ -439,16 +427,13 @@ class RoomMemberListRestServlet(RestServlet):
|
|||
chunk = []
|
||||
|
||||
for event in events:
|
||||
if (
|
||||
(membership and event['content'].get("membership") != membership) or
|
||||
(not_membership and event['content'].get("membership") == not_membership)
|
||||
if (membership and event["content"].get("membership") != membership) or (
|
||||
not_membership and event["content"].get("membership") == not_membership
|
||||
):
|
||||
continue
|
||||
chunk.append(event)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"chunk": chunk
|
||||
}))
|
||||
defer.returnValue((200, {"chunk": chunk}))
|
||||
|
||||
|
||||
# deprecated in favour of /members?membership=join?
|
||||
|
@ -466,12 +451,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
users_with_profile = yield self.message_handler.get_joined_members(
|
||||
requester, room_id,
|
||||
requester, room_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"joined": users_with_profile,
|
||||
}))
|
||||
defer.returnValue((200, {"joined": users_with_profile}))
|
||||
|
||||
|
||||
# TODO: Needs better unit testing
|
||||
|
@ -486,9 +469,7 @@ class RoomMessageListRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
pagination_config = PaginationConfig.from_request(
|
||||
request, default_limit=10,
|
||||
)
|
||||
pagination_config = PaginationConfig.from_request(request, default_limit=10)
|
||||
as_client_event = b"raw" not in request.args
|
||||
filter_bytes = parse_string(request, b"filter", encoding=None)
|
||||
if filter_bytes:
|
||||
|
@ -544,9 +525,7 @@ class RoomInitialSyncRestServlet(RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
content = yield self.initial_sync_handler.room_initial_sync(
|
||||
room_id=room_id,
|
||||
requester=requester,
|
||||
pagin_config=pagination_config,
|
||||
room_id=room_id, requester=requester, pagin_config=pagination_config
|
||||
)
|
||||
defer.returnValue((200, content))
|
||||
|
||||
|
@ -603,30 +582,24 @@ class RoomEventContextServlet(RestServlet):
|
|||
event_filter = None
|
||||
|
||||
results = yield self.room_context_handler.get_event_context(
|
||||
requester.user,
|
||||
room_id,
|
||||
event_id,
|
||||
limit,
|
||||
event_filter,
|
||||
requester.user, room_id, event_id, limit, event_filter
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise SynapseError(
|
||||
404, "Event not found.", errcode=Codes.NOT_FOUND
|
||||
)
|
||||
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
results["events_before"] = yield self._event_serializer.serialize_events(
|
||||
results["events_before"], time_now,
|
||||
results["events_before"], time_now
|
||||
)
|
||||
results["event"] = yield self._event_serializer.serialize_event(
|
||||
results["event"], time_now,
|
||||
results["event"], time_now
|
||||
)
|
||||
results["events_after"] = yield self._event_serializer.serialize_events(
|
||||
results["events_after"], time_now,
|
||||
results["events_after"], time_now
|
||||
)
|
||||
results["state"] = yield self._event_serializer.serialize_events(
|
||||
results["state"], time_now,
|
||||
results["state"], time_now
|
||||
)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
@ -639,20 +612,14 @@ class RoomForgetRestServlet(TransactionRestServlet):
|
|||
self.auth = hs.get_auth()
|
||||
|
||||
def register(self, http_server):
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
|
||||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, txn_id=None):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
allow_guest=False,
|
||||
)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
|
||||
|
||||
yield self.room_member_handler.forget(
|
||||
user=requester.user,
|
||||
room_id=room_id,
|
||||
)
|
||||
yield self.room_member_handler.forget(user=requester.user, room_id=room_id)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -664,7 +631,6 @@ class RoomForgetRestServlet(TransactionRestServlet):
|
|||
|
||||
# TODO: Needs unit testing
|
||||
class RoomMembershipRestServlet(TransactionRestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomMembershipRestServlet, self).__init__(hs)
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
|
@ -672,20 +638,19 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
|
||||
def register(self, http_server):
|
||||
# /rooms/$roomid/[invite|join|leave]
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
|
||||
"(?P<membership_action>join|invite|leave|ban|unban|kick)")
|
||||
PATTERNS = (
|
||||
"/rooms/(?P<room_id>[^/]*)/"
|
||||
"(?P<membership_action>join|invite|leave|ban|unban|kick)"
|
||||
)
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, membership_action, txn_id=None):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
allow_guest=True,
|
||||
)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
if requester.is_guest and membership_action not in {
|
||||
Membership.JOIN,
|
||||
Membership.LEAVE
|
||||
Membership.LEAVE,
|
||||
}:
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
|
@ -704,7 +669,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
content["address"],
|
||||
content["id_server"],
|
||||
requester,
|
||||
txn_id
|
||||
txn_id,
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
return
|
||||
|
@ -715,8 +680,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
|||
target = UserID.from_string(content["user_id"])
|
||||
|
||||
event_content = None
|
||||
if 'reason' in content and membership_action in ['kick', 'ban']:
|
||||
event_content = {'reason': content['reason']}
|
||||
if "reason" in content and membership_action in ["kick", "ban"]:
|
||||
event_content = {"reason": content["reason"]}
|
||||
|
||||
yield self.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
|
@ -755,7 +720,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
|
|||
self.auth = hs.get_auth()
|
||||
|
||||
def register(self, http_server):
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
|
||||
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
|
||||
register_txn_path(self, PATTERNS, http_server)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -817,9 +782,7 @@ class RoomTypingRestServlet(RestServlet):
|
|||
)
|
||||
else:
|
||||
yield self.typing_handler.stopped_typing(
|
||||
target_user=target_user,
|
||||
auth_user=requester.user,
|
||||
room_id=room_id,
|
||||
target_user=target_user, auth_user=requester.user, room_id=room_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -841,9 +804,7 @@ class SearchRestServlet(RestServlet):
|
|||
|
||||
batch = parse_string(request, "next_batch")
|
||||
results = yield self.handlers.search_handler.search(
|
||||
requester.user,
|
||||
content,
|
||||
batch,
|
||||
requester.user, content, batch
|
||||
)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
@ -879,20 +840,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
|
|||
with_get: True to also register respective GET paths for the PUTs.
|
||||
"""
|
||||
http_server.register_paths(
|
||||
"POST",
|
||||
client_patterns(regex_string + "$", v1=True),
|
||||
servlet.on_POST
|
||||
"POST", client_patterns(regex_string + "$", v1=True), servlet.on_POST
|
||||
)
|
||||
http_server.register_paths(
|
||||
"PUT",
|
||||
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
|
||||
servlet.on_PUT
|
||||
servlet.on_PUT,
|
||||
)
|
||||
if with_get:
|
||||
http_server.register_paths(
|
||||
"GET",
|
||||
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
|
||||
servlet.on_GET
|
||||
servlet.on_GET,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -34,8 +34,7 @@ class VoipRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
self.hs.config.turn_allow_guests
|
||||
request, self.hs.config.turn_allow_guests
|
||||
)
|
||||
|
||||
turnUris = self.hs.config.turn_uris
|
||||
|
@ -49,9 +48,7 @@ class VoipRestServlet(RestServlet):
|
|||
username = "%d:%s" % (expiry, requester.user.to_string())
|
||||
|
||||
mac = hmac.new(
|
||||
turnSecret.encode(),
|
||||
msg=username.encode(),
|
||||
digestmod=hashlib.sha1
|
||||
turnSecret.encode(), msg=username.encode(), digestmod=hashlib.sha1
|
||||
)
|
||||
# We need to use standard padded base64 encoding here
|
||||
# encode_base64 because we need to add the standard padding to get the
|
||||
|
@ -65,12 +62,17 @@ class VoipRestServlet(RestServlet):
|
|||
else:
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
defer.returnValue((200, {
|
||||
'username': username,
|
||||
'password': password,
|
||||
'ttl': userLifetime / 1000,
|
||||
'uris': turnUris,
|
||||
}))
|
||||
defer.returnValue(
|
||||
(
|
||||
200,
|
||||
{
|
||||
"username": username,
|
||||
"password": password,
|
||||
"ttl": userLifetime / 1000,
|
||||
"uris": turnUris,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
|
|
@ -52,11 +52,11 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
|
|||
def set_timeline_upper_limit(filter_json, filter_timeline_limit):
|
||||
if filter_timeline_limit < 0:
|
||||
return # no upper limits
|
||||
timeline = filter_json.get('room', {}).get('timeline', {})
|
||||
if 'limit' in timeline:
|
||||
filter_json['room']['timeline']["limit"] = min(
|
||||
filter_json['room']['timeline']['limit'],
|
||||
filter_timeline_limit)
|
||||
timeline = filter_json.get("room", {}).get("timeline", {})
|
||||
if "limit" in timeline:
|
||||
filter_json["room"]["timeline"]["limit"] = min(
|
||||
filter_json["room"]["timeline"]["limit"], filter_timeline_limit
|
||||
)
|
||||
|
||||
|
||||
def interactive_auth_handler(orig):
|
||||
|
@ -74,10 +74,12 @@ def interactive_auth_handler(orig):
|
|||
# ...
|
||||
yield self.auth_handler.check_auth
|
||||
"""
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
res = defer.maybeDeferred(orig, *args, **kwargs)
|
||||
res.addErrback(_catch_incomplete_interactive_auth)
|
||||
return res
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
|
|
|
@ -15,7 +15,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
|
||||
from six.moves import http_client
|
||||
|
||||
|
@ -53,6 +52,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
|||
|
||||
if self.config.email_password_reset_behaviour == "local":
|
||||
from synapse.push.mailer import Mailer, load_jinja2_templates
|
||||
|
||||
templates = load_jinja2_templates(
|
||||
config=hs.config,
|
||||
template_html_name=hs.config.email_password_reset_template_html,
|
||||
|
@ -68,13 +68,17 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
if self.config.email_password_reset_behaviour == "off":
|
||||
raise SynapseError(400, "Password resets have been disabled on this server")
|
||||
if self.config.password_resets_were_disabled_due_to_email_config:
|
||||
logger.warn(
|
||||
"User password resets have been disabled due to lack of email config"
|
||||
)
|
||||
raise SynapseError(
|
||||
400, "Email-based password resets have been disabled on this server"
|
||||
)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_dict(body, [
|
||||
'client_secret', 'email', 'send_attempt'
|
||||
])
|
||||
assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
|
||||
|
||||
# Extract params from body
|
||||
client_secret = body["client_secret"]
|
||||
|
@ -90,24 +94,24 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', email,
|
||||
"email", email
|
||||
)
|
||||
|
||||
if existingUid is None:
|
||||
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
|
||||
|
||||
if self.config.email_password_reset_behaviour == "remote":
|
||||
if 'id_server' not in body:
|
||||
if "id_server" not in body:
|
||||
raise SynapseError(400, "Missing 'id_server' param in body")
|
||||
|
||||
# Have the identity server handle the password reset flow
|
||||
ret = yield self.identity_handler.requestEmailToken(
|
||||
body["id_server"], email, client_secret, send_attempt, next_link,
|
||||
body["id_server"], email, client_secret, send_attempt, next_link
|
||||
)
|
||||
else:
|
||||
# Send password reset emails from Synapse
|
||||
sid = yield self.send_password_reset(
|
||||
email, client_secret, send_attempt, next_link,
|
||||
email, client_secret, send_attempt, next_link
|
||||
)
|
||||
|
||||
# Wrap the session id in a JSON object
|
||||
|
@ -116,13 +120,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
|||
defer.returnValue((200, ret))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_password_reset(
|
||||
self,
|
||||
email,
|
||||
client_secret,
|
||||
send_attempt,
|
||||
next_link=None,
|
||||
):
|
||||
def send_password_reset(self, email, client_secret, send_attempt, next_link=None):
|
||||
"""Send a password reset email
|
||||
|
||||
Args:
|
||||
|
@ -139,14 +137,14 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
|||
# Check that this email/client_secret/send_attempt combo is new or
|
||||
# greater than what we've seen previously
|
||||
session = yield self.datastore.get_threepid_validation_session(
|
||||
"email", client_secret, address=email, validated=False,
|
||||
"email", client_secret, address=email, validated=False
|
||||
)
|
||||
|
||||
# Check to see if a session already exists and that it is not yet
|
||||
# marked as validated
|
||||
if session and session.get("validated_at") is None:
|
||||
session_id = session['session_id']
|
||||
last_send_attempt = session['last_send_attempt']
|
||||
session_id = session["session_id"]
|
||||
last_send_attempt = session["last_send_attempt"]
|
||||
|
||||
# Check that the send_attempt is higher than previous attempts
|
||||
if send_attempt <= last_send_attempt:
|
||||
|
@ -164,22 +162,27 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
|||
# and session_id
|
||||
try:
|
||||
yield self.mailer.send_password_reset_mail(
|
||||
email, token, client_secret, session_id,
|
||||
email, token, client_secret, session_id
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error sending a password reset email to %s", email,
|
||||
)
|
||||
logger.exception("Error sending a password reset email to %s", email)
|
||||
raise SynapseError(
|
||||
500, "An error was encountered when sending the password reset email"
|
||||
)
|
||||
|
||||
token_expires = (self.hs.clock.time_msec() +
|
||||
self.config.email_validation_token_lifetime)
|
||||
token_expires = (
|
||||
self.hs.clock.time_msec() + self.config.email_validation_token_lifetime
|
||||
)
|
||||
|
||||
yield self.datastore.start_or_continue_validation_session(
|
||||
"email", email, session_id, client_secret, send_attempt,
|
||||
next_link, token, token_expires,
|
||||
"email",
|
||||
email,
|
||||
session_id,
|
||||
client_secret,
|
||||
send_attempt,
|
||||
next_link,
|
||||
token,
|
||||
token_expires,
|
||||
)
|
||||
|
||||
defer.returnValue(session_id)
|
||||
|
@ -196,17 +199,14 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
if not self.config.email_password_reset_behaviour == "off":
|
||||
raise SynapseError(400, "Password resets have been disabled on this server")
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_dict(body, [
|
||||
'id_server', 'client_secret',
|
||||
'country', 'phone_number', 'send_attempt',
|
||||
])
|
||||
assert_params_in_dict(
|
||||
body,
|
||||
["id_server", "client_secret", "country", "phone_number", "send_attempt"],
|
||||
)
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
msisdn = phone_number_to_msisdn(body["country"], body["phone_number"])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||
raise SynapseError(
|
||||
|
@ -215,9 +215,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
|
|||
Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||
'msisdn', msisdn
|
||||
)
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn)
|
||||
|
||||
if existingUid is None:
|
||||
raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
|
||||
|
@ -228,9 +226,10 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
|
|||
|
||||
class PasswordResetSubmitTokenServlet(RestServlet):
|
||||
"""Handles 3PID validation token submission"""
|
||||
PATTERNS = [
|
||||
re.compile("^/_synapse/password_reset/(?P<medium>[^/]*)/submit_token/*$"),
|
||||
]
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/password_reset/(?P<medium>[^/]*)/submit_token/*$", releases=(), unstable=True
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
|
@ -248,8 +247,15 @@ class PasswordResetSubmitTokenServlet(RestServlet):
|
|||
def on_GET(self, request, medium):
|
||||
if medium != "email":
|
||||
raise SynapseError(
|
||||
400,
|
||||
"This medium is currently not supported for password resets",
|
||||
400, "This medium is currently not supported for password resets"
|
||||
)
|
||||
if self.config.email_password_reset_behaviour == "off":
|
||||
if self.config.password_resets_were_disabled_due_to_email_config:
|
||||
logger.warn(
|
||||
"User password resets have been disabled due to lack of email config"
|
||||
)
|
||||
raise SynapseError(
|
||||
400, "Email-based password resets have been disabled on this server"
|
||||
)
|
||||
|
||||
sid = parse_string(request, "sid")
|
||||
|
@ -260,10 +266,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
|
|||
try:
|
||||
# Mark the session as valid
|
||||
next_link = yield self.datastore.validate_threepid_session(
|
||||
sid,
|
||||
client_secret,
|
||||
token,
|
||||
self.clock.time_msec(),
|
||||
sid, client_secret, token, self.clock.time_msec()
|
||||
)
|
||||
|
||||
# Perform a 302 redirect if next_link is set
|
||||
|
@ -286,13 +289,11 @@ class PasswordResetSubmitTokenServlet(RestServlet):
|
|||
html = self.load_jinja2_template(
|
||||
self.config.email_template_dir,
|
||||
self.config.email_password_reset_failure_template,
|
||||
template_vars={
|
||||
"failure_reason": e.msg,
|
||||
}
|
||||
template_vars={"failure_reason": e.msg},
|
||||
)
|
||||
request.setResponseCode(e.code)
|
||||
|
||||
request.write(html.encode('utf-8'))
|
||||
request.write(html.encode("utf-8"))
|
||||
finish_request(request)
|
||||
defer.returnValue(None)
|
||||
|
||||
|
@ -318,20 +319,14 @@ class PasswordResetSubmitTokenServlet(RestServlet):
|
|||
def on_POST(self, request, medium):
|
||||
if medium != "email":
|
||||
raise SynapseError(
|
||||
400,
|
||||
"This medium is currently not supported for password resets",
|
||||
400, "This medium is currently not supported for password resets"
|
||||
)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, [
|
||||
'sid', 'client_secret', 'token',
|
||||
])
|
||||
assert_params_in_dict(body, ["sid", "client_secret", "token"])
|
||||
|
||||
valid, _ = yield self.datastore.validate_threepid_validation_token(
|
||||
body['sid'],
|
||||
body['client_secret'],
|
||||
body['token'],
|
||||
self.clock.time_msec(),
|
||||
body["sid"], body["client_secret"], body["token"], self.clock.time_msec()
|
||||
)
|
||||
response_code = 200 if valid else 400
|
||||
|
||||
|
@ -367,29 +362,30 @@ class PasswordRestServlet(RestServlet):
|
|||
if self.auth.has_access_token(request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
params = yield self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request),
|
||||
requester, body, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
user_id = requester.user.to_string()
|
||||
else:
|
||||
requester = None
|
||||
result, params, _ = yield self.auth_handler.check_auth(
|
||||
[[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
|
||||
body, self.hs.get_ip_from_request(request),
|
||||
body,
|
||||
self.hs.get_ip_from_request(request),
|
||||
password_servlet=True,
|
||||
)
|
||||
|
||||
if LoginType.EMAIL_IDENTITY in result:
|
||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||
if 'medium' not in threepid or 'address' not in threepid:
|
||||
if "medium" not in threepid or "address" not in threepid:
|
||||
raise SynapseError(500, "Malformed threepid")
|
||||
if threepid['medium'] == 'email':
|
||||
if threepid["medium"] == "email":
|
||||
# For emails, transform the address to lowercase.
|
||||
# We store all email addreses as lowercase in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
threepid['address'] = threepid['address'].lower()
|
||||
threepid["address"] = threepid["address"].lower()
|
||||
# if using email, we must know about the email they're authing with!
|
||||
threepid_user_id = yield self.datastore.get_user_id_by_threepid(
|
||||
threepid['medium'], threepid['address']
|
||||
threepid["medium"], threepid["address"]
|
||||
)
|
||||
if not threepid_user_id:
|
||||
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
|
||||
|
@ -399,11 +395,9 @@ class PasswordRestServlet(RestServlet):
|
|||
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||
|
||||
assert_params_in_dict(params, ["new_password"])
|
||||
new_password = params['new_password']
|
||||
new_password = params["new_password"]
|
||||
|
||||
yield self._set_password_handler.set_password(
|
||||
user_id, new_password, requester
|
||||
)
|
||||
yield self._set_password_handler.set_password(user_id, new_password, requester)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -438,25 +432,22 @@ class DeactivateAccountRestServlet(RestServlet):
|
|||
# allow ASes to dectivate their own users
|
||||
if requester.app_service:
|
||||
yield self._deactivate_account_handler.deactivate_account(
|
||||
requester.user.to_string(), erase,
|
||||
requester.user.to_string(), erase
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
yield self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request),
|
||||
requester, body, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
result = yield self._deactivate_account_handler.deactivate_account(
|
||||
requester.user.to_string(), erase,
|
||||
id_server=body.get("id_server"),
|
||||
requester.user.to_string(), erase, id_server=body.get("id_server")
|
||||
)
|
||||
if result:
|
||||
id_server_unbind_result = "success"
|
||||
else:
|
||||
id_server_unbind_result = "no-support"
|
||||
|
||||
defer.returnValue((200, {
|
||||
"id_server_unbind_result": id_server_unbind_result,
|
||||
}))
|
||||
defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result}))
|
||||
|
||||
|
||||
class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||
|
@ -472,11 +463,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(
|
||||
body,
|
||||
['id_server', 'client_secret', 'email', 'send_attempt'],
|
||||
body, ["id_server", "client_secret", "email", "send_attempt"]
|
||||
)
|
||||
|
||||
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||
if not check_3pid_allowed(self.hs, "email", body["email"]):
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Your email domain is not authorized on this server",
|
||||
|
@ -484,7 +474,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
"email", body["email"]
|
||||
)
|
||||
|
||||
if existingUid is not None:
|
||||
|
@ -506,12 +496,12 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, [
|
||||
'id_server', 'client_secret',
|
||||
'country', 'phone_number', 'send_attempt',
|
||||
])
|
||||
assert_params_in_dict(
|
||||
body,
|
||||
["id_server", "client_secret", "country", "phone_number", "send_attempt"],
|
||||
)
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
msisdn = phone_number_to_msisdn(body["country"], body["phone_number"])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||
raise SynapseError(
|
||||
|
@ -520,9 +510,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
|||
Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||
'msisdn', msisdn
|
||||
)
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn)
|
||||
|
||||
if existingUid is not None:
|
||||
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
|
||||
|
@ -546,18 +534,16 @@ class ThreepidRestServlet(RestServlet):
|
|||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
threepids = yield self.datastore.user_get_threepids(
|
||||
requester.user.to_string()
|
||||
)
|
||||
threepids = yield self.datastore.user_get_threepids(requester.user.to_string())
|
||||
|
||||
defer.returnValue((200, {'threepids': threepids}))
|
||||
defer.returnValue((200, {"threepids": threepids}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
threePidCreds = body.get('threePidCreds')
|
||||
threePidCreds = body.get('three_pid_creds', threePidCreds)
|
||||
threePidCreds = body.get("threePidCreds")
|
||||
threePidCreds = body.get("three_pid_creds", threePidCreds)
|
||||
if threePidCreds is None:
|
||||
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
|
||||
|
||||
|
@ -567,30 +553,20 @@ class ThreepidRestServlet(RestServlet):
|
|||
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
|
||||
|
||||
if not threepid:
|
||||
raise SynapseError(
|
||||
400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED
|
||||
)
|
||||
raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED)
|
||||
|
||||
for reqd in ['medium', 'address', 'validated_at']:
|
||||
for reqd in ["medium", "address", "validated_at"]:
|
||||
if reqd not in threepid:
|
||||
logger.warn("Couldn't add 3pid: invalid response from ID server")
|
||||
raise SynapseError(500, "Invalid response from ID Server")
|
||||
|
||||
yield self.auth_handler.add_threepid(
|
||||
user_id,
|
||||
threepid['medium'],
|
||||
threepid['address'],
|
||||
threepid['validated_at'],
|
||||
user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
|
||||
)
|
||||
|
||||
if 'bind' in body and body['bind']:
|
||||
logger.debug(
|
||||
"Binding threepid %s to %s",
|
||||
threepid, user_id
|
||||
)
|
||||
yield self.identity_handler.bind_threepid(
|
||||
threePidCreds, user_id
|
||||
)
|
||||
if "bind" in body and body["bind"]:
|
||||
logger.debug("Binding threepid %s to %s", threepid, user_id)
|
||||
yield self.identity_handler.bind_threepid(threePidCreds, user_id)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -606,14 +582,14 @@ class ThreepidDeleteRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, ['medium', 'address'])
|
||||
assert_params_in_dict(body, ["medium", "address"])
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
try:
|
||||
ret = yield self.auth_handler.delete_threepid(
|
||||
user_id, body['medium'], body['address'], body.get("id_server"),
|
||||
user_id, body["medium"], body["address"], body.get("id_server")
|
||||
)
|
||||
except Exception:
|
||||
# NB. This endpoint should succeed if there is nothing to
|
||||
|
@ -627,9 +603,7 @@ class ThreepidDeleteRestServlet(RestServlet):
|
|||
else:
|
||||
id_server_unbind_result = "no-support"
|
||||
|
||||
defer.returnValue((200, {
|
||||
"id_server_unbind_result": id_server_unbind_result,
|
||||
}))
|
||||
defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result}))
|
||||
|
||||
|
||||
class WhoamiRestServlet(RestServlet):
|
||||
|
@ -643,7 +617,7 @@ class WhoamiRestServlet(RestServlet):
|
|||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
defer.returnValue((200, {'user_id': requester.user.to_string()}))
|
||||
defer.returnValue((200, {"user_id": requester.user.to_string()}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
|
|
|
@ -30,6 +30,7 @@ class AccountDataServlet(RestServlet):
|
|||
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
|
||||
GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
|
||||
)
|
||||
|
@ -52,9 +53,7 @@ class AccountDataServlet(RestServlet):
|
|||
user_id, account_data_type, body
|
||||
)
|
||||
|
||||
self.notifier.on_new_event(
|
||||
"account_data_key", max_id, users=[user_id]
|
||||
)
|
||||
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -65,7 +64,7 @@ class AccountDataServlet(RestServlet):
|
|||
raise AuthError(403, "Cannot get account data for other users.")
|
||||
|
||||
event = yield self.store.get_global_account_data_by_type_for_user(
|
||||
account_data_type, user_id,
|
||||
account_data_type, user_id
|
||||
)
|
||||
|
||||
if event is None:
|
||||
|
@ -79,6 +78,7 @@ class RoomAccountDataServlet(RestServlet):
|
|||
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
|
||||
GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/user/(?P<user_id>[^/]*)"
|
||||
"/rooms/(?P<room_id>[^/]*)"
|
||||
|
@ -103,16 +103,14 @@ class RoomAccountDataServlet(RestServlet):
|
|||
raise SynapseError(
|
||||
405,
|
||||
"Cannot set m.fully_read through this API."
|
||||
" Use /rooms/!roomId:server.name/read_markers"
|
||||
" Use /rooms/!roomId:server.name/read_markers",
|
||||
)
|
||||
|
||||
max_id = yield self.store.add_account_data_to_room(
|
||||
user_id, room_id, account_data_type, body
|
||||
)
|
||||
|
||||
self.notifier.on_new_event(
|
||||
"account_data_key", max_id, users=[user_id]
|
||||
)
|
||||
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -123,7 +121,7 @@ class RoomAccountDataServlet(RestServlet):
|
|||
raise AuthError(403, "Cannot get account data for other users.")
|
||||
|
||||
event = yield self.store.get_account_data_for_room_and_type(
|
||||
user_id, room_id, account_data_type,
|
||||
user_id, room_id, account_data_type
|
||||
)
|
||||
|
||||
if event is None:
|
||||
|
|
|
@ -28,7 +28,9 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class AccountValidityRenewServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/account_validity/renew$")
|
||||
SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>"
|
||||
SUCCESS_HTML = (
|
||||
b"<html><body>Your account has been successfully renewed.</body><html>"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
|
@ -47,13 +49,13 @@ class AccountValidityRenewServlet(RestServlet):
|
|||
raise SynapseError(400, "Missing renewal token")
|
||||
renewal_token = request.args[b"token"][0]
|
||||
|
||||
yield self.account_activity_handler.renew_account(renewal_token.decode('utf8'))
|
||||
yield self.account_activity_handler.renew_account(renewal_token.decode("utf8"))
|
||||
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Content-Length", b"%d" % (
|
||||
len(AccountValidityRenewServlet.SUCCESS_HTML),
|
||||
))
|
||||
request.setHeader(
|
||||
b"Content-Length", b"%d" % (len(AccountValidityRenewServlet.SUCCESS_HTML),)
|
||||
)
|
||||
request.write(AccountValidityRenewServlet.SUCCESS_HTML)
|
||||
finish_request(request)
|
||||
defer.returnValue(None)
|
||||
|
@ -77,7 +79,9 @@ class AccountValiditySendMailServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
if not self.account_validity.renew_by_email_enabled:
|
||||
raise AuthError(403, "Account renewal via email is disabled on this server.")
|
||||
raise AuthError(
|
||||
403, "Account renewal via email is disabled on this server."
|
||||
)
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request, allow_expired=True)
|
||||
user_id = requester.user.to_string()
|
||||
|
|
|
@ -122,6 +122,7 @@ class AuthRestServlet(RestServlet):
|
|||
cannot be handled in the normal flow (with requests to the same endpoint).
|
||||
Current use is for web fallback auth.
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -138,11 +139,10 @@ class AuthRestServlet(RestServlet):
|
|||
|
||||
if stagetype == LoginType.RECAPTCHA:
|
||||
html = RECAPTCHA_TEMPLATE % {
|
||||
'session': session,
|
||||
'myurl': "%s/r0/auth/%s/fallback/web" % (
|
||||
CLIENT_API_PREFIX, LoginType.RECAPTCHA
|
||||
),
|
||||
'sitekey': self.hs.config.recaptcha_public_key,
|
||||
"session": session,
|
||||
"myurl": "%s/r0/auth/%s/fallback/web"
|
||||
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
|
||||
"sitekey": self.hs.config.recaptcha_public_key,
|
||||
}
|
||||
html_bytes = html.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
|
@ -154,14 +154,11 @@ class AuthRestServlet(RestServlet):
|
|||
return None
|
||||
elif stagetype == LoginType.TERMS:
|
||||
html = TERMS_TEMPLATE % {
|
||||
'session': session,
|
||||
'terms_url': "%s_matrix/consent?v=%s" % (
|
||||
self.hs.config.public_baseurl,
|
||||
self.hs.config.user_consent_version,
|
||||
),
|
||||
'myurl': "%s/r0/auth/%s/fallback/web" % (
|
||||
CLIENT_API_PREFIX, LoginType.TERMS
|
||||
),
|
||||
"session": session,
|
||||
"terms_url": "%s_matrix/consent?v=%s"
|
||||
% (self.hs.config.public_baseurl, self.hs.config.user_consent_version),
|
||||
"myurl": "%s/r0/auth/%s/fallback/web"
|
||||
% (CLIENT_API_PREFIX, LoginType.TERMS),
|
||||
}
|
||||
html_bytes = html.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
|
@ -187,26 +184,20 @@ class AuthRestServlet(RestServlet):
|
|||
if not response:
|
||||
raise SynapseError(400, "No captcha response supplied")
|
||||
|
||||
authdict = {
|
||||
'response': response,
|
||||
'session': session,
|
||||
}
|
||||
authdict = {"response": response, "session": session}
|
||||
|
||||
success = yield self.auth_handler.add_oob_auth(
|
||||
LoginType.RECAPTCHA,
|
||||
authdict,
|
||||
self.hs.get_ip_from_request(request)
|
||||
LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
if success:
|
||||
html = SUCCESS_TEMPLATE
|
||||
else:
|
||||
html = RECAPTCHA_TEMPLATE % {
|
||||
'session': session,
|
||||
'myurl': "%s/r0/auth/%s/fallback/web" % (
|
||||
CLIENT_API_PREFIX, LoginType.RECAPTCHA
|
||||
),
|
||||
'sitekey': self.hs.config.recaptcha_public_key,
|
||||
"session": session,
|
||||
"myurl": "%s/r0/auth/%s/fallback/web"
|
||||
% (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
|
||||
"sitekey": self.hs.config.recaptcha_public_key,
|
||||
}
|
||||
html_bytes = html.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
|
@ -218,31 +209,28 @@ class AuthRestServlet(RestServlet):
|
|||
|
||||
defer.returnValue(None)
|
||||
elif stagetype == LoginType.TERMS:
|
||||
if ('session' not in request.args or
|
||||
len(request.args['session'])) == 0:
|
||||
if ("session" not in request.args or len(request.args["session"])) == 0:
|
||||
raise SynapseError(400, "No session supplied")
|
||||
|
||||
session = request.args['session'][0]
|
||||
authdict = {'session': session}
|
||||
session = request.args["session"][0]
|
||||
authdict = {"session": session}
|
||||
|
||||
success = yield self.auth_handler.add_oob_auth(
|
||||
LoginType.TERMS,
|
||||
authdict,
|
||||
self.hs.get_ip_from_request(request)
|
||||
LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
if success:
|
||||
html = SUCCESS_TEMPLATE
|
||||
else:
|
||||
html = TERMS_TEMPLATE % {
|
||||
'session': session,
|
||||
'terms_url': "%s_matrix/consent?v=%s" % (
|
||||
"session": session,
|
||||
"terms_url": "%s_matrix/consent?v=%s"
|
||||
% (
|
||||
self.hs.config.public_baseurl,
|
||||
self.hs.config.user_consent_version,
|
||||
),
|
||||
'myurl': "%s/r0/auth/%s/fallback/web" % (
|
||||
CLIENT_API_PREFIX, LoginType.TERMS
|
||||
),
|
||||
"myurl": "%s/r0/auth/%s/fallback/web"
|
||||
% (CLIENT_API_PREFIX, LoginType.TERMS),
|
||||
}
|
||||
html_bytes = html.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
|
|
|
@ -56,6 +56,7 @@ class DeleteDevicesRestServlet(RestServlet):
|
|||
API for bulk deletion of devices. Accepts a JSON object with a devices
|
||||
key which lists the device_ids to delete. Requires user interactive auth.
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/delete_devices")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -84,12 +85,11 @@ class DeleteDevicesRestServlet(RestServlet):
|
|||
assert_params_in_dict(body, ["devices"])
|
||||
|
||||
yield self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request),
|
||||
requester, body, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
yield self.device_handler.delete_devices(
|
||||
requester.user.to_string(),
|
||||
body['devices'],
|
||||
requester.user.to_string(), body["devices"]
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -112,8 +112,7 @@ class DeviceRestServlet(RestServlet):
|
|||
def on_GET(self, request, device_id):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
device = yield self.device_handler.get_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
requester.user.to_string(), device_id
|
||||
)
|
||||
defer.returnValue((200, device))
|
||||
|
||||
|
@ -134,12 +133,10 @@ class DeviceRestServlet(RestServlet):
|
|||
raise
|
||||
|
||||
yield self.auth_handler.validate_user_via_ui_auth(
|
||||
requester, body, self.hs.get_ip_from_request(request),
|
||||
requester, body, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
yield self.device_handler.delete_device(
|
||||
requester.user.to_string(), device_id,
|
||||
)
|
||||
yield self.device_handler.delete_device(requester.user.to_string(), device_id)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -148,9 +145,7 @@ class DeviceRestServlet(RestServlet):
|
|||
|
||||
body = parse_json_object_from_request(request)
|
||||
yield self.device_handler.update_device(
|
||||
requester.user.to_string(),
|
||||
device_id,
|
||||
body
|
||||
requester.user.to_string(), device_id, body
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
|
|
@ -53,8 +53,7 @@ class GetFilterRestServlet(RestServlet):
|
|||
|
||||
try:
|
||||
filter = yield self.filtering.get_user_filter(
|
||||
user_localpart=target_user.localpart,
|
||||
filter_id=filter_id,
|
||||
user_localpart=target_user.localpart, filter_id=filter_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, filter.get_filter_json()))
|
||||
|
@ -84,14 +83,10 @@ class CreateFilterRestServlet(RestServlet):
|
|||
raise AuthError(403, "Can only create filters for local users")
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
set_timeline_upper_limit(
|
||||
content,
|
||||
self.hs.config.filter_timeline_limit
|
||||
)
|
||||
set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit)
|
||||
|
||||
filter_id = yield self.filtering.add_user_filter(
|
||||
user_localpart=target_user.localpart,
|
||||
user_filter=content,
|
||||
user_localpart=target_user.localpart, user_filter=content
|
||||
)
|
||||
|
||||
defer.returnValue((200, {"filter_id": str(filter_id)}))
|
||||
|
|
|
@ -29,6 +29,7 @@ logger = logging.getLogger(__name__)
|
|||
class GroupServlet(RestServlet):
|
||||
"""Get the group profile
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -43,8 +44,7 @@ class GroupServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
group_description = yield self.groups_handler.get_group_profile(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, group_description))
|
||||
|
@ -56,7 +56,7 @@ class GroupServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
yield self.groups_handler.update_group_profile(
|
||||
group_id, requester_user_id, content,
|
||||
group_id, requester_user_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
@ -65,6 +65,7 @@ class GroupServlet(RestServlet):
|
|||
class GroupSummaryServlet(RestServlet):
|
||||
"""Get the full group summary
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -79,8 +80,7 @@ class GroupSummaryServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
get_group_summary = yield self.groups_handler.get_group_summary(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, get_group_summary))
|
||||
|
@ -93,6 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
|||
- /groups/:group/summary/rooms/:room_id
|
||||
- /groups/:group/summary/categories/:category/rooms/:room_id
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/summary"
|
||||
"(/categories/(?P<category_id>[^/]+))?"
|
||||
|
@ -112,7 +113,8 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_summary_room(
|
||||
group_id, requester_user_id,
|
||||
group_id,
|
||||
requester_user_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
content=content,
|
||||
|
@ -126,9 +128,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_summary_room(
|
||||
group_id, requester_user_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
group_id, requester_user_id, room_id=room_id, category_id=category_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
@ -137,6 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
|||
class GroupCategoryServlet(RestServlet):
|
||||
"""Get/add/update/delete a group category
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
|
||||
)
|
||||
|
@ -153,8 +154,7 @@ class GroupCategoryServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_category(
|
||||
group_id, requester_user_id,
|
||||
category_id=category_id,
|
||||
group_id, requester_user_id, category_id=category_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
@ -166,9 +166,7 @@ class GroupCategoryServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_category(
|
||||
group_id, requester_user_id,
|
||||
category_id=category_id,
|
||||
content=content,
|
||||
group_id, requester_user_id, category_id=category_id, content=content
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
@ -179,8 +177,7 @@ class GroupCategoryServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_category(
|
||||
group_id, requester_user_id,
|
||||
category_id=category_id,
|
||||
group_id, requester_user_id, category_id=category_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
@ -189,9 +186,8 @@ class GroupCategoryServlet(RestServlet):
|
|||
class GroupCategoriesServlet(RestServlet):
|
||||
"""Get all group categories
|
||||
"""
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/categories/$"
|
||||
)
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupCategoriesServlet, self).__init__()
|
||||
|
@ -205,7 +201,7 @@ class GroupCategoriesServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_categories(
|
||||
group_id, requester_user_id,
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
@ -214,9 +210,8 @@ class GroupCategoriesServlet(RestServlet):
|
|||
class GroupRoleServlet(RestServlet):
|
||||
"""Get/add/update/delete a group role
|
||||
"""
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
|
||||
)
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupRoleServlet, self).__init__()
|
||||
|
@ -230,8 +225,7 @@ class GroupRoleServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_role(
|
||||
group_id, requester_user_id,
|
||||
role_id=role_id,
|
||||
group_id, requester_user_id, role_id=role_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
@ -243,9 +237,7 @@ class GroupRoleServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_role(
|
||||
group_id, requester_user_id,
|
||||
role_id=role_id,
|
||||
content=content,
|
||||
group_id, requester_user_id, role_id=role_id, content=content
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
@ -256,8 +248,7 @@ class GroupRoleServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_role(
|
||||
group_id, requester_user_id,
|
||||
role_id=role_id,
|
||||
group_id, requester_user_id, role_id=role_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
@ -266,9 +257,8 @@ class GroupRoleServlet(RestServlet):
|
|||
class GroupRolesServlet(RestServlet):
|
||||
"""Get all group roles
|
||||
"""
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/roles/$"
|
||||
)
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupRolesServlet, self).__init__()
|
||||
|
@ -282,7 +272,7 @@ class GroupRolesServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_roles(
|
||||
group_id, requester_user_id,
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
@ -295,6 +285,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
|
|||
- /groups/:group/summary/users/:room_id
|
||||
- /groups/:group/summary/roles/:role/users/:user_id
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/summary"
|
||||
"(/roles/(?P<role_id>[^/]+))?"
|
||||
|
@ -314,7 +305,8 @@ class GroupSummaryUsersRoleServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_summary_user(
|
||||
group_id, requester_user_id,
|
||||
group_id,
|
||||
requester_user_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
content=content,
|
||||
|
@ -328,9 +320,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_summary_user(
|
||||
group_id, requester_user_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
group_id, requester_user_id, user_id=user_id, role_id=role_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
@ -339,6 +329,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
|
|||
class GroupRoomServlet(RestServlet):
|
||||
"""Get all rooms in a group
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -352,7 +343,9 @@ class GroupRoomServlet(RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
|
||||
result = yield self.groups_handler.get_rooms_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
@ -360,6 +353,7 @@ class GroupRoomServlet(RestServlet):
|
|||
class GroupUsersServlet(RestServlet):
|
||||
"""Get all users in a group
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -373,7 +367,9 @@ class GroupUsersServlet(RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
|
||||
result = yield self.groups_handler.get_users_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
@ -381,6 +377,7 @@ class GroupUsersServlet(RestServlet):
|
|||
class GroupInvitedUsersServlet(RestServlet):
|
||||
"""Get users invited to a group
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -395,8 +392,7 @@ class GroupInvitedUsersServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_invited_users_in_group(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -405,6 +401,7 @@ class GroupInvitedUsersServlet(RestServlet):
|
|||
class GroupSettingJoinPolicyServlet(RestServlet):
|
||||
"""Set group join policy
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -420,9 +417,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
|
|||
content = parse_json_object_from_request(request)
|
||||
|
||||
result = yield self.groups_handler.set_group_join_policy(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
content,
|
||||
group_id, requester_user_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -431,6 +426,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
|
|||
class GroupCreateServlet(RestServlet):
|
||||
"""Create a group
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/create_group$")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -451,9 +447,7 @@ class GroupCreateServlet(RestServlet):
|
|||
group_id = GroupID(localpart, self.server_name).to_string()
|
||||
|
||||
result = yield self.groups_handler.create_group(
|
||||
group_id,
|
||||
requester_user_id,
|
||||
content,
|
||||
group_id, requester_user_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -462,6 +456,7 @@ class GroupCreateServlet(RestServlet):
|
|||
class GroupAdminRoomsServlet(RestServlet):
|
||||
"""Add a room to the group
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
|
||||
)
|
||||
|
@ -479,7 +474,7 @@ class GroupAdminRoomsServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.add_room_to_group(
|
||||
group_id, requester_user_id, room_id, content,
|
||||
group_id, requester_user_id, room_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -490,7 +485,7 @@ class GroupAdminRoomsServlet(RestServlet):
|
|||
requester_user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.remove_room_from_group(
|
||||
group_id, requester_user_id, room_id,
|
||||
group_id, requester_user_id, room_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -499,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet):
|
|||
class GroupAdminRoomsConfigServlet(RestServlet):
|
||||
"""Update the config of a room in a group
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
|
||||
"/config/(?P<config_key>[^/]*)$"
|
||||
|
@ -517,7 +513,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.update_room_in_group(
|
||||
group_id, requester_user_id, room_id, config_key, content,
|
||||
group_id, requester_user_id, room_id, config_key, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -526,6 +522,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
|
|||
class GroupAdminUsersInviteServlet(RestServlet):
|
||||
"""Invite a user to the group
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
@ -546,7 +543,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
|
|||
content = parse_json_object_from_request(request)
|
||||
config = content.get("config", {})
|
||||
result = yield self.groups_handler.invite(
|
||||
group_id, user_id, requester_user_id, config,
|
||||
group_id, user_id, requester_user_id, config
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -555,6 +552,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
|
|||
class GroupAdminUsersKickServlet(RestServlet):
|
||||
"""Kick a user from the group
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
@ -572,7 +570,7 @@ class GroupAdminUsersKickServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.remove_user_from_group(
|
||||
group_id, user_id, requester_user_id, content,
|
||||
group_id, user_id, requester_user_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -581,9 +579,8 @@ class GroupAdminUsersKickServlet(RestServlet):
|
|||
class GroupSelfLeaveServlet(RestServlet):
|
||||
"""Leave a joined group
|
||||
"""
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/self/leave$"
|
||||
)
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSelfLeaveServlet, self).__init__()
|
||||
|
@ -598,7 +595,7 @@ class GroupSelfLeaveServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.remove_user_from_group(
|
||||
group_id, requester_user_id, requester_user_id, content,
|
||||
group_id, requester_user_id, requester_user_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -607,9 +604,8 @@ class GroupSelfLeaveServlet(RestServlet):
|
|||
class GroupSelfJoinServlet(RestServlet):
|
||||
"""Attempt to join a group, or knock
|
||||
"""
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/self/join$"
|
||||
)
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSelfJoinServlet, self).__init__()
|
||||
|
@ -624,7 +620,7 @@ class GroupSelfJoinServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.join_group(
|
||||
group_id, requester_user_id, content,
|
||||
group_id, requester_user_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -633,9 +629,8 @@ class GroupSelfJoinServlet(RestServlet):
|
|||
class GroupSelfAcceptInviteServlet(RestServlet):
|
||||
"""Accept a group invite
|
||||
"""
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/self/accept_invite$"
|
||||
)
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSelfAcceptInviteServlet, self).__init__()
|
||||
|
@ -650,7 +645,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.accept_invite(
|
||||
group_id, requester_user_id, content,
|
||||
group_id, requester_user_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
@ -659,9 +654,8 @@ class GroupSelfAcceptInviteServlet(RestServlet):
|
|||
class GroupSelfUpdatePublicityServlet(RestServlet):
|
||||
"""Update whether we publicise a users membership of a group
|
||||
"""
|
||||
PATTERNS = client_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/self/update_publicity$"
|
||||
)
|
||||
|
||||
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSelfUpdatePublicityServlet, self).__init__()
|
||||
|
@ -676,9 +670,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
|
|||
|
||||
content = parse_json_object_from_request(request)
|
||||
publicise = content["publicise"]
|
||||
yield self.store.update_group_publicity(
|
||||
group_id, requester_user_id, publicise,
|
||||
)
|
||||
yield self.store.update_group_publicity(group_id, requester_user_id, publicise)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -686,9 +678,8 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
|
|||
class PublicisedGroupsForUserServlet(RestServlet):
|
||||
"""Get the list of groups a user is advertising
|
||||
"""
|
||||
PATTERNS = client_patterns(
|
||||
"/publicised_groups/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PublicisedGroupsForUserServlet, self).__init__()
|
||||
|
@ -701,9 +692,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
|
|||
def on_GET(self, request, user_id):
|
||||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
result = yield self.groups_handler.get_publicised_groups_for_user(
|
||||
user_id
|
||||
)
|
||||
result = yield self.groups_handler.get_publicised_groups_for_user(user_id)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
@ -711,9 +700,8 @@ class PublicisedGroupsForUserServlet(RestServlet):
|
|||
class PublicisedGroupsForUsersServlet(RestServlet):
|
||||
"""Get the list of groups a user is advertising
|
||||
"""
|
||||
PATTERNS = client_patterns(
|
||||
"/publicised_groups$"
|
||||
)
|
||||
|
||||
PATTERNS = client_patterns("/publicised_groups$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PublicisedGroupsForUsersServlet, self).__init__()
|
||||
|
@ -729,9 +717,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
|
|||
content = parse_json_object_from_request(request)
|
||||
user_ids = content["user_ids"]
|
||||
|
||||
result = yield self.groups_handler.bulk_get_publicised_groups(
|
||||
user_ids
|
||||
)
|
||||
result = yield self.groups_handler.bulk_get_publicised_groups(user_ids)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
@ -739,9 +725,8 @@ class PublicisedGroupsForUsersServlet(RestServlet):
|
|||
class GroupsForUserServlet(RestServlet):
|
||||
"""Get all groups the logged in user is joined to
|
||||
"""
|
||||
PATTERNS = client_patterns(
|
||||
"/joined_groups$"
|
||||
)
|
||||
|
||||
PATTERNS = client_patterns("/joined_groups$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupsForUserServlet, self).__init__()
|
||||
|
|
|
@ -56,6 +56,7 @@ class KeyUploadServlet(RestServlet):
|
|||
},
|
||||
}
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -76,18 +77,19 @@ class KeyUploadServlet(RestServlet):
|
|||
if device_id is not None:
|
||||
# passing the device_id here is deprecated; however, we allow it
|
||||
# for now for compatibility with older clients.
|
||||
if (requester.device_id is not None and
|
||||
device_id != requester.device_id):
|
||||
logger.warning("Client uploading keys for a different device "
|
||||
"(logged in as %s, uploading for %s)",
|
||||
requester.device_id, device_id)
|
||||
if requester.device_id is not None and device_id != requester.device_id:
|
||||
logger.warning(
|
||||
"Client uploading keys for a different device "
|
||||
"(logged in as %s, uploading for %s)",
|
||||
requester.device_id,
|
||||
device_id,
|
||||
)
|
||||
else:
|
||||
device_id = requester.device_id
|
||||
|
||||
if device_id is None:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"To upload keys, you must pass device_id when authenticating"
|
||||
400, "To upload keys, you must pass device_id when authenticating"
|
||||
)
|
||||
|
||||
result = yield self.e2e_keys_handler.upload_keys_for_user(
|
||||
|
@ -159,6 +161,7 @@ class KeyChangesServlet(RestServlet):
|
|||
200 OK
|
||||
{ "changed": ["@foo:example.com"] }
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/keys/changes$")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -184,9 +187,7 @@ class KeyChangesServlet(RestServlet):
|
|||
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
results = yield self.device_handler.get_user_ids_changed(
|
||||
user_id, from_token,
|
||||
)
|
||||
results = yield self.device_handler.get_user_ids_changed(user_id, from_token)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
||||
|
@ -209,6 +210,7 @@ class OneTimeKeyServlet(RestServlet):
|
|||
} } } }
|
||||
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/keys/claim$")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -221,10 +223,7 @@ class OneTimeKeyServlet(RestServlet):
|
|||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
timeout = parse_integer(request, "timeout", 10 * 1000)
|
||||
body = parse_json_object_from_request(request)
|
||||
result = yield self.e2e_keys_handler.claim_one_time_keys(
|
||||
body,
|
||||
timeout,
|
||||
)
|
||||
result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout)
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ class NotificationsServlet(RestServlet):
|
|||
)
|
||||
|
||||
receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(
|
||||
user_id, 'm.read'
|
||||
user_id, "m.read"
|
||||
)
|
||||
|
||||
notif_event_ids = [pa["event_id"] for pa in push_actions]
|
||||
|
@ -67,11 +67,13 @@ class NotificationsServlet(RestServlet):
|
|||
"profile_tag": pa["profile_tag"],
|
||||
"actions": pa["actions"],
|
||||
"ts": pa["received_ts"],
|
||||
"event": (yield self._event_serializer.serialize_event(
|
||||
notif_events[pa["event_id"]],
|
||||
self.clock.time_msec(),
|
||||
event_format=format_event_for_client_v2_without_room_id,
|
||||
)),
|
||||
"event": (
|
||||
yield self._event_serializer.serialize_event(
|
||||
notif_events[pa["event_id"]],
|
||||
self.clock.time_msec(),
|
||||
event_format=format_event_for_client_v2_without_room_id,
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
if pa["room_id"] not in receipts_by_room:
|
||||
|
@ -80,17 +82,15 @@ class NotificationsServlet(RestServlet):
|
|||
receipt = receipts_by_room[pa["room_id"]]
|
||||
|
||||
returned_pa["read"] = (
|
||||
receipt["topological_ordering"], receipt["stream_ordering"]
|
||||
) >= (
|
||||
pa["topological_ordering"], pa["stream_ordering"]
|
||||
)
|
||||
receipt["topological_ordering"],
|
||||
receipt["stream_ordering"],
|
||||
) >= (pa["topological_ordering"], pa["stream_ordering"])
|
||||
returned_push_actions.append(returned_pa)
|
||||
next_token = str(pa["stream_ordering"])
|
||||
|
||||
defer.returnValue((200, {
|
||||
"notifications": returned_push_actions,
|
||||
"next_token": next_token,
|
||||
}))
|
||||
defer.returnValue(
|
||||
(200, {"notifications": returned_push_actions, "next_token": next_token})
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
|
|
|
@ -56,9 +56,8 @@ class IdTokenServlet(RestServlet):
|
|||
"expires_in": 3600,
|
||||
}
|
||||
"""
|
||||
PATTERNS = client_patterns(
|
||||
"/user/(?P<user_id>[^/]*)/openid/request_token"
|
||||
)
|
||||
|
||||
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/openid/request_token")
|
||||
|
||||
EXPIRES_MS = 3600 * 1000
|
||||
|
||||
|
@ -84,12 +83,17 @@ class IdTokenServlet(RestServlet):
|
|||
|
||||
yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"access_token": token,
|
||||
"token_type": "Bearer",
|
||||
"matrix_server_name": self.server_name,
|
||||
"expires_in": self.EXPIRES_MS / 1000,
|
||||
}))
|
||||
defer.returnValue(
|
||||
(
|
||||
200,
|
||||
{
|
||||
"access_token": token,
|
||||
"token_type": "Bearer",
|
||||
"matrix_server_name": self.server_name,
|
||||
"expires_in": self.EXPIRES_MS / 1000,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
|
|
|
@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet):
|
|||
room_id,
|
||||
"m.read",
|
||||
user_id=requester.user.to_string(),
|
||||
event_id=read_event_id
|
||||
event_id=read_event_id,
|
||||
)
|
||||
|
||||
read_marker_event_id = body.get("m.fully_read", None)
|
||||
|
@ -56,7 +56,7 @@ class ReadMarkerRestServlet(RestServlet):
|
|||
yield self.read_marker_handler.received_client_read_marker(
|
||||
room_id,
|
||||
user_id=requester.user.to_string(),
|
||||
event_id=read_marker_event_id
|
||||
event_id=read_marker_event_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
|
|
@ -49,10 +49,7 @@ class ReceiptRestServlet(RestServlet):
|
|||
yield self.presence_handler.bump_presence_active_time(requester.user)
|
||||
|
||||
yield self.receipts_handler.received_client_receipt(
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id=requester.user.to_string(),
|
||||
event_id=event_id
|
||||
room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
|
|
@ -52,6 +52,7 @@ from ._base import client_patterns, interactive_auth_handler
|
|||
if hasattr(hmac, "compare_digest"):
|
||||
compare_digest = hmac.compare_digest
|
||||
else:
|
||||
|
||||
def compare_digest(a, b):
|
||||
return a == b
|
||||
|
||||
|
@ -75,11 +76,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_dict(body, [
|
||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||
])
|
||||
assert_params_in_dict(
|
||||
body, ["id_server", "client_secret", "email", "send_attempt"]
|
||||
)
|
||||
|
||||
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||
if not check_3pid_allowed(self.hs, "email", body["email"]):
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Your email domain is not authorized to register on this server",
|
||||
|
@ -87,7 +88,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
"email", body["email"]
|
||||
)
|
||||
|
||||
if existingUid is not None:
|
||||
|
@ -113,13 +114,12 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
assert_params_in_dict(body, [
|
||||
'id_server', 'client_secret',
|
||||
'country', 'phone_number',
|
||||
'send_attempt',
|
||||
])
|
||||
assert_params_in_dict(
|
||||
body,
|
||||
["id_server", "client_secret", "country", "phone_number", "send_attempt"],
|
||||
)
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
msisdn = phone_number_to_msisdn(body["country"], body["phone_number"])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||
raise SynapseError(
|
||||
|
@ -129,7 +129,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'msisdn', msisdn
|
||||
"msisdn", msisdn
|
||||
)
|
||||
|
||||
if existingUid is not None:
|
||||
|
@ -165,7 +165,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
|
|||
reject_limit=1,
|
||||
# Allow 1 request at a time
|
||||
concurrent_requests=1,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -212,7 +212,8 @@ class RegisterRestServlet(RestServlet):
|
|||
time_now = self.clock.time()
|
||||
|
||||
allowed, time_allowed = self.ratelimiter.can_do_action(
|
||||
client_addr, time_now_s=time_now,
|
||||
client_addr,
|
||||
time_now_s=time_now,
|
||||
rate_hz=self.hs.config.rc_registration.per_second,
|
||||
burst_count=self.hs.config.rc_registration.burst_count,
|
||||
update=False,
|
||||
|
@ -220,7 +221,7 @@ class RegisterRestServlet(RestServlet):
|
|||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now)),
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now))
|
||||
)
|
||||
|
||||
kind = b"user"
|
||||
|
@ -239,18 +240,22 @@ class RegisterRestServlet(RestServlet):
|
|||
# we do basic sanity checks here because the auth layer will store these
|
||||
# in sessions. Pull out the username/password provided to us.
|
||||
desired_password = None
|
||||
if 'password' in body:
|
||||
if (not isinstance(body['password'], string_types) or
|
||||
len(body['password']) > 512):
|
||||
if "password" in body:
|
||||
if (
|
||||
not isinstance(body["password"], string_types)
|
||||
or len(body["password"]) > 512
|
||||
):
|
||||
raise SynapseError(400, "Invalid password")
|
||||
desired_password = body["password"]
|
||||
|
||||
desired_username = None
|
||||
if 'username' in body:
|
||||
if (not isinstance(body['username'], string_types) or
|
||||
len(body['username']) > 512):
|
||||
if "username" in body:
|
||||
if (
|
||||
not isinstance(body["username"], string_types)
|
||||
or len(body["username"]) > 512
|
||||
):
|
||||
raise SynapseError(400, "Invalid username")
|
||||
desired_username = body['username']
|
||||
desired_username = body["username"]
|
||||
|
||||
appservice = None
|
||||
if self.auth.has_access_token(request):
|
||||
|
@ -290,7 +295,7 @@ class RegisterRestServlet(RestServlet):
|
|||
desired_username = desired_username.lower()
|
||||
|
||||
# == Shared Secret Registration == (e.g. create new user scripts)
|
||||
if 'mac' in body:
|
||||
if "mac" in body:
|
||||
# FIXME: Should we really be determining if this is shared secret
|
||||
# auth based purely on the 'mac' key?
|
||||
result = yield self._do_shared_secret_registration(
|
||||
|
@ -305,16 +310,13 @@ class RegisterRestServlet(RestServlet):
|
|||
|
||||
guest_access_token = body.get("guest_access_token", None)
|
||||
|
||||
if (
|
||||
'initial_device_display_name' in body and
|
||||
'password' not in body
|
||||
):
|
||||
if "initial_device_display_name" in body and "password" not in body:
|
||||
# ignore 'initial_device_display_name' if sent without
|
||||
# a password to work around a client bug where it sent
|
||||
# the 'initial_device_display_name' param alone, wiping out
|
||||
# the original registration params
|
||||
logger.warn("Ignoring initial_device_display_name without password")
|
||||
del body['initial_device_display_name']
|
||||
del body["initial_device_display_name"]
|
||||
|
||||
session_id = self.auth_handler.get_session_id(body)
|
||||
registered_user_id = None
|
||||
|
@ -336,8 +338,8 @@ class RegisterRestServlet(RestServlet):
|
|||
|
||||
# FIXME: need a better error than "no auth flow found" for scenarios
|
||||
# where we required 3PID for registration but the user didn't give one
|
||||
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||
require_email = "email" in self.hs.config.registrations_require_3pid
|
||||
require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid
|
||||
|
||||
show_msisdn = True
|
||||
if self.hs.config.disable_msisdn_registration:
|
||||
|
@ -362,9 +364,9 @@ class RegisterRestServlet(RestServlet):
|
|||
if not require_email:
|
||||
flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]])
|
||||
# always let users provide both MSISDN & email
|
||||
flows.extend([
|
||||
[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
||||
])
|
||||
flows.extend(
|
||||
[[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]
|
||||
)
|
||||
else:
|
||||
# only support 3PIDless registration if no 3PIDs are required
|
||||
if not require_email and not require_msisdn:
|
||||
|
@ -378,9 +380,7 @@ class RegisterRestServlet(RestServlet):
|
|||
if not require_email or require_msisdn:
|
||||
flows.extend([[LoginType.MSISDN]])
|
||||
# always let users provide both MSISDN & email
|
||||
flows.extend([
|
||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
|
||||
])
|
||||
flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]])
|
||||
|
||||
# Append m.login.terms to all flows if we're requiring consent
|
||||
if self.hs.config.user_consent_at_registration:
|
||||
|
@ -410,21 +410,20 @@ class RegisterRestServlet(RestServlet):
|
|||
if auth_result:
|
||||
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
|
||||
if login_type in auth_result:
|
||||
medium = auth_result[login_type]['medium']
|
||||
address = auth_result[login_type]['address']
|
||||
medium = auth_result[login_type]["medium"]
|
||||
address = auth_result[login_type]["address"]
|
||||
|
||||
if not check_3pid_allowed(self.hs, medium, address):
|
||||
raise SynapseError(
|
||||
403,
|
||||
"Third party identifiers (email/phone numbers)" +
|
||||
" are not authorized on this server",
|
||||
"Third party identifiers (email/phone numbers)"
|
||||
+ " are not authorized on this server",
|
||||
Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
if registered_user_id is not None:
|
||||
logger.info(
|
||||
"Already registered user ID %r for this session",
|
||||
registered_user_id
|
||||
"Already registered user ID %r for this session", registered_user_id
|
||||
)
|
||||
# don't re-register the threepids
|
||||
registered = False
|
||||
|
@ -451,11 +450,11 @@ class RegisterRestServlet(RestServlet):
|
|||
# the two activation emails, they would register the same 3pid twice.
|
||||
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
|
||||
if login_type in auth_result:
|
||||
medium = auth_result[login_type]['medium']
|
||||
address = auth_result[login_type]['address']
|
||||
medium = auth_result[login_type]["medium"]
|
||||
address = auth_result[login_type]["address"]
|
||||
|
||||
existingUid = yield self.store.get_user_id_by_threepid(
|
||||
medium, address,
|
||||
medium, address
|
||||
)
|
||||
|
||||
if existingUid is not None:
|
||||
|
@ -520,7 +519,7 @@ class RegisterRestServlet(RestServlet):
|
|||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
if not username:
|
||||
raise SynapseError(
|
||||
400, "username must be specified", errcode=Codes.BAD_JSON,
|
||||
400, "username must be specified", errcode=Codes.BAD_JSON
|
||||
)
|
||||
|
||||
# use the username from the original request rather than the
|
||||
|
@ -541,12 +540,10 @@ class RegisterRestServlet(RestServlet):
|
|||
).hexdigest()
|
||||
|
||||
if not compare_digest(want_mac, got_mac):
|
||||
raise SynapseError(
|
||||
403, "HMAC incorrect",
|
||||
)
|
||||
raise SynapseError(403, "HMAC incorrect")
|
||||
|
||||
(user_id, _) = yield self.registration_handler.register(
|
||||
localpart=username, password=password, generate_token=False,
|
||||
localpart=username, password=password, generate_token=False
|
||||
)
|
||||
|
||||
result = yield self._create_registration_details(user_id, body)
|
||||
|
@ -565,21 +562,15 @@ class RegisterRestServlet(RestServlet):
|
|||
Returns:
|
||||
defer.Deferred: (object) dictionary for response from /register
|
||||
"""
|
||||
result = {
|
||||
"user_id": user_id,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
result = {"user_id": user_id, "home_server": self.hs.hostname}
|
||||
if not params.get("inhibit_login", False):
|
||||
device_id = params.get("device_id")
|
||||
initial_display_name = params.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name, is_guest=False,
|
||||
user_id, device_id, initial_display_name, is_guest=False
|
||||
)
|
||||
|
||||
result.update({
|
||||
"access_token": access_token,
|
||||
"device_id": device_id,
|
||||
})
|
||||
result.update({"access_token": access_token, "device_id": device_id})
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -587,9 +578,7 @@ class RegisterRestServlet(RestServlet):
|
|||
if not self.hs.config.allow_guest_access:
|
||||
raise SynapseError(403, "Guest access is disabled")
|
||||
user_id, _ = yield self.registration_handler.register(
|
||||
generate_token=False,
|
||||
make_guest=True,
|
||||
address=address,
|
||||
generate_token=False, make_guest=True, address=address
|
||||
)
|
||||
|
||||
# we don't allow guests to specify their own device_id, because
|
||||
|
@ -597,15 +586,20 @@ class RegisterRestServlet(RestServlet):
|
|||
device_id = synapse.api.auth.GUEST_DEVICE_ID
|
||||
initial_display_name = params.get("initial_device_display_name")
|
||||
device_id, access_token = yield self.registration_handler.register_device(
|
||||
user_id, device_id, initial_display_name, is_guest=True,
|
||||
user_id, device_id, initial_display_name, is_guest=True
|
||||
)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
}))
|
||||
defer.returnValue(
|
||||
(
|
||||
200,
|
||||
{
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
|
|
|
@ -32,7 +32,10 @@ from synapse.http.servlet import (
|
|||
parse_string,
|
||||
)
|
||||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken
|
||||
from synapse.storage.relations import (
|
||||
AggregationPaginationToken,
|
||||
RelationPaginationToken,
|
||||
)
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
|
|
|
@ -33,9 +33,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ReportEventRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
"/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
|
||||
)
|
||||
PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ReportEventRestServlet, self).__init__()
|
||||
|
|
|
@ -129,22 +129,12 @@ class RoomKeysServlet(RestServlet):
|
|||
version = parse_string(request, "version")
|
||||
|
||||
if session_id:
|
||||
body = {
|
||||
"sessions": {
|
||||
session_id: body
|
||||
}
|
||||
}
|
||||
body = {"sessions": {session_id: body}}
|
||||
|
||||
if room_id:
|
||||
body = {
|
||||
"rooms": {
|
||||
room_id: body
|
||||
}
|
||||
}
|
||||
body = {"rooms": {room_id: body}}
|
||||
|
||||
yield self.e2e_room_keys_handler.upload_room_keys(
|
||||
user_id, version, body
|
||||
)
|
||||
yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -212,10 +202,10 @@ class RoomKeysServlet(RestServlet):
|
|||
if session_id:
|
||||
# If the client requests a specific session, but that session was
|
||||
# not backed up, then return an M_NOT_FOUND.
|
||||
if room_keys['rooms'] == {}:
|
||||
if room_keys["rooms"] == {}:
|
||||
raise NotFoundError("No room_keys found")
|
||||
else:
|
||||
room_keys = room_keys['rooms'][room_id]['sessions'][session_id]
|
||||
room_keys = room_keys["rooms"][room_id]["sessions"][session_id]
|
||||
elif room_id:
|
||||
# If the client requests all sessions from a room, but no sessions
|
||||
# are found, then return an empty result rather than an error, so
|
||||
|
@ -223,10 +213,10 @@ class RoomKeysServlet(RestServlet):
|
|||
# empty result is valid. (Similarly if the client requests all
|
||||
# sessions from the backup, but in that case, room_keys is already
|
||||
# in the right format, so we don't need to do anything about it.)
|
||||
if room_keys['rooms'] == {}:
|
||||
room_keys = {'sessions': {}}
|
||||
if room_keys["rooms"] == {}:
|
||||
room_keys = {"sessions": {}}
|
||||
else:
|
||||
room_keys = room_keys['rooms'][room_id]
|
||||
room_keys = room_keys["rooms"][room_id]
|
||||
|
||||
defer.returnValue((200, room_keys))
|
||||
|
||||
|
@ -256,9 +246,7 @@ class RoomKeysServlet(RestServlet):
|
|||
|
||||
|
||||
class RoomKeysNewVersionServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
"/room_keys/version$"
|
||||
)
|
||||
PATTERNS = client_patterns("/room_keys/version$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
|
@ -304,9 +292,7 @@ class RoomKeysNewVersionServlet(RestServlet):
|
|||
user_id = requester.user.to_string()
|
||||
info = parse_json_object_from_request(request)
|
||||
|
||||
new_version = yield self.e2e_room_keys_handler.create_version(
|
||||
user_id, info
|
||||
)
|
||||
new_version = yield self.e2e_room_keys_handler.create_version(user_id, info)
|
||||
defer.returnValue((200, {"version": new_version}))
|
||||
|
||||
# we deliberately don't have a PUT /version, as these things really should
|
||||
|
@ -314,9 +300,7 @@ class RoomKeysNewVersionServlet(RestServlet):
|
|||
|
||||
|
||||
class RoomKeysVersionServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
"/room_keys/version(/(?P<version>[^/]+))?$"
|
||||
)
|
||||
PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$")
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
|
@ -350,9 +334,7 @@ class RoomKeysVersionServlet(RestServlet):
|
|||
user_id = requester.user.to_string()
|
||||
|
||||
try:
|
||||
info = yield self.e2e_room_keys_handler.get_version_info(
|
||||
user_id, version
|
||||
)
|
||||
info = yield self.e2e_room_keys_handler.get_version_info(user_id, version)
|
||||
except SynapseError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
|
||||
|
@ -375,9 +357,7 @@ class RoomKeysVersionServlet(RestServlet):
|
|||
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
yield self.e2e_room_keys_handler.delete_version(
|
||||
user_id, version
|
||||
)
|
||||
yield self.e2e_room_keys_handler.delete_version(user_id, version)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -407,11 +387,11 @@ class RoomKeysVersionServlet(RestServlet):
|
|||
info = parse_json_object_from_request(request)
|
||||
|
||||
if version is None:
|
||||
raise SynapseError(400, "No version specified to update", Codes.MISSING_PARAM)
|
||||
raise SynapseError(
|
||||
400, "No version specified to update", Codes.MISSING_PARAM
|
||||
)
|
||||
|
||||
yield self.e2e_room_keys_handler.update_version(
|
||||
user_id, version, info
|
||||
)
|
||||
yield self.e2e_room_keys_handler.update_version(user_id, version, info)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
|
|
|
@ -47,9 +47,10 @@ class RoomUpgradeRestServlet(RestServlet):
|
|||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
# /rooms/$roomid/upgrade
|
||||
"/rooms/(?P<room_id>[^/]*)/upgrade$",
|
||||
"/rooms/(?P<room_id>[^/]*)/upgrade$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
|
@ -63,7 +64,7 @@ class RoomUpgradeRestServlet(RestServlet):
|
|||
requester = yield self._auth.get_user_by_req(request)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(content, ("new_version", ))
|
||||
assert_params_in_dict(content, ("new_version",))
|
||||
new_version = content["new_version"]
|
||||
|
||||
if new_version not in KNOWN_ROOM_VERSIONS:
|
||||
|
@ -77,9 +78,7 @@ class RoomUpgradeRestServlet(RestServlet):
|
|||
requester, room_id, new_version
|
||||
)
|
||||
|
||||
ret = {
|
||||
"replacement_room": new_room_id,
|
||||
}
|
||||
ret = {"replacement_room": new_room_id}
|
||||
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
class SendToDeviceRestServlet(servlet.RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
|
||||
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
|
|
|
@ -96,44 +96,42 @@ class SyncRestServlet(RestServlet):
|
|||
400, "'from' is not a valid query parameter. Did you mean 'since'?"
|
||||
)
|
||||
|
||||
requester = yield self.auth.get_user_by_req(
|
||||
request, allow_guest=True
|
||||
)
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
user = requester.user
|
||||
device_id = requester.device_id
|
||||
|
||||
timeout = parse_integer(request, "timeout", default=0)
|
||||
since = parse_string(request, "since")
|
||||
set_presence = parse_string(
|
||||
request, "set_presence", default="online",
|
||||
allowed_values=self.ALLOWED_PRESENCE
|
||||
request,
|
||||
"set_presence",
|
||||
default="online",
|
||||
allowed_values=self.ALLOWED_PRESENCE,
|
||||
)
|
||||
filter_id = parse_string(request, "filter", default=None)
|
||||
full_state = parse_boolean(request, "full_state", default=False)
|
||||
|
||||
logger.debug(
|
||||
"/sync: user=%r, timeout=%r, since=%r,"
|
||||
" set_presence=%r, filter_id=%r, device_id=%r" % (
|
||||
user, timeout, since, set_presence, filter_id, device_id
|
||||
)
|
||||
" set_presence=%r, filter_id=%r, device_id=%r"
|
||||
% (user, timeout, since, set_presence, filter_id, device_id)
|
||||
)
|
||||
|
||||
request_key = (user, timeout, since, filter_id, full_state, device_id)
|
||||
|
||||
if filter_id:
|
||||
if filter_id.startswith('{'):
|
||||
if filter_id.startswith("{"):
|
||||
try:
|
||||
filter_object = json.loads(filter_id)
|
||||
set_timeline_upper_limit(filter_object,
|
||||
self.hs.config.filter_timeline_limit)
|
||||
set_timeline_upper_limit(
|
||||
filter_object, self.hs.config.filter_timeline_limit
|
||||
)
|
||||
except Exception:
|
||||
raise SynapseError(400, "Invalid filter JSON")
|
||||
self.filtering.check_valid_filter(filter_object)
|
||||
filter = FilterCollection(filter_object)
|
||||
else:
|
||||
filter = yield self.filtering.get_user_filter(
|
||||
user.localpart, filter_id
|
||||
)
|
||||
filter = yield self.filtering.get_user_filter(user.localpart, filter_id)
|
||||
else:
|
||||
filter = DEFAULT_FILTER_COLLECTION
|
||||
|
||||
|
@ -156,15 +154,19 @@ class SyncRestServlet(RestServlet):
|
|||
affect_presence = set_presence != PresenceState.OFFLINE
|
||||
|
||||
if affect_presence:
|
||||
yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
|
||||
yield self.presence_handler.set_state(
|
||||
user, {"presence": set_presence}, True
|
||||
)
|
||||
|
||||
context = yield self.presence_handler.user_syncing(
|
||||
user.to_string(), affect_presence=affect_presence,
|
||||
user.to_string(), affect_presence=affect_presence
|
||||
)
|
||||
with context:
|
||||
sync_result = yield self.sync_handler.wait_for_sync_for_user(
|
||||
sync_config, since_token=since_token, timeout=timeout,
|
||||
full_state=full_state
|
||||
sync_config,
|
||||
since_token=since_token,
|
||||
timeout=timeout,
|
||||
full_state=full_state,
|
||||
)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
@ -176,53 +178,54 @@ class SyncRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def encode_response(self, time_now, sync_result, access_token_id, filter):
|
||||
if filter.event_format == 'client':
|
||||
if filter.event_format == "client":
|
||||
event_formatter = format_event_for_client_v2_without_room_id
|
||||
elif filter.event_format == 'federation':
|
||||
elif filter.event_format == "federation":
|
||||
event_formatter = format_event_raw
|
||||
else:
|
||||
raise Exception("Unknown event format %s" % (filter.event_format, ))
|
||||
raise Exception("Unknown event format %s" % (filter.event_format,))
|
||||
|
||||
joined = yield self.encode_joined(
|
||||
sync_result.joined, time_now, access_token_id,
|
||||
sync_result.joined,
|
||||
time_now,
|
||||
access_token_id,
|
||||
filter.event_fields,
|
||||
event_formatter,
|
||||
)
|
||||
|
||||
invited = yield self.encode_invited(
|
||||
sync_result.invited, time_now, access_token_id,
|
||||
event_formatter,
|
||||
sync_result.invited, time_now, access_token_id, event_formatter
|
||||
)
|
||||
|
||||
archived = yield self.encode_archived(
|
||||
sync_result.archived, time_now, access_token_id,
|
||||
sync_result.archived,
|
||||
time_now,
|
||||
access_token_id,
|
||||
filter.event_fields,
|
||||
event_formatter,
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"account_data": {"events": sync_result.account_data},
|
||||
"to_device": {"events": sync_result.to_device},
|
||||
"device_lists": {
|
||||
"changed": list(sync_result.device_lists.changed),
|
||||
"left": list(sync_result.device_lists.left),
|
||||
},
|
||||
"presence": SyncRestServlet.encode_presence(
|
||||
sync_result.presence, time_now
|
||||
),
|
||||
"rooms": {
|
||||
"join": joined,
|
||||
"invite": invited,
|
||||
"leave": archived,
|
||||
},
|
||||
"groups": {
|
||||
"join": sync_result.groups.join,
|
||||
"invite": sync_result.groups.invite,
|
||||
"leave": sync_result.groups.leave,
|
||||
},
|
||||
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
|
||||
"next_batch": sync_result.next_batch.to_string(),
|
||||
})
|
||||
defer.returnValue(
|
||||
{
|
||||
"account_data": {"events": sync_result.account_data},
|
||||
"to_device": {"events": sync_result.to_device},
|
||||
"device_lists": {
|
||||
"changed": list(sync_result.device_lists.changed),
|
||||
"left": list(sync_result.device_lists.left),
|
||||
},
|
||||
"presence": SyncRestServlet.encode_presence(
|
||||
sync_result.presence, time_now
|
||||
),
|
||||
"rooms": {"join": joined, "invite": invited, "leave": archived},
|
||||
"groups": {
|
||||
"join": sync_result.groups.join,
|
||||
"invite": sync_result.groups.invite,
|
||||
"leave": sync_result.groups.leave,
|
||||
},
|
||||
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
|
||||
"next_batch": sync_result.next_batch.to_string(),
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def encode_presence(events, time_now):
|
||||
|
@ -262,7 +265,11 @@ class SyncRestServlet(RestServlet):
|
|||
joined = {}
|
||||
for room in rooms:
|
||||
joined[room.room_id] = yield self.encode_room(
|
||||
room, time_now, token_id, joined=True, only_fields=event_fields,
|
||||
room,
|
||||
time_now,
|
||||
token_id,
|
||||
joined=True,
|
||||
only_fields=event_fields,
|
||||
event_formatter=event_formatter,
|
||||
)
|
||||
|
||||
|
@ -290,7 +297,9 @@ class SyncRestServlet(RestServlet):
|
|||
invited = {}
|
||||
for room in rooms:
|
||||
invite = yield self._event_serializer.serialize_event(
|
||||
room.invite, time_now, token_id=token_id,
|
||||
room.invite,
|
||||
time_now,
|
||||
token_id=token_id,
|
||||
event_format=event_formatter,
|
||||
is_invite=True,
|
||||
)
|
||||
|
@ -298,9 +307,7 @@ class SyncRestServlet(RestServlet):
|
|||
invite["unsigned"] = unsigned
|
||||
invited_state = list(unsigned.pop("invite_room_state", []))
|
||||
invited_state.append(invite)
|
||||
invited[room.room_id] = {
|
||||
"invite_state": {"events": invited_state}
|
||||
}
|
||||
invited[room.room_id] = {"invite_state": {"events": invited_state}}
|
||||
|
||||
defer.returnValue(invited)
|
||||
|
||||
|
@ -327,7 +334,10 @@ class SyncRestServlet(RestServlet):
|
|||
joined = {}
|
||||
for room in rooms:
|
||||
joined[room.room_id] = yield self.encode_room(
|
||||
room, time_now, token_id, joined=False,
|
||||
room,
|
||||
time_now,
|
||||
token_id,
|
||||
joined=False,
|
||||
only_fields=event_fields,
|
||||
event_formatter=event_formatter,
|
||||
)
|
||||
|
@ -336,8 +346,7 @@ class SyncRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def encode_room(
|
||||
self, room, time_now, token_id, joined,
|
||||
only_fields, event_formatter,
|
||||
self, room, time_now, token_id, joined, only_fields, event_formatter
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
|
@ -355,9 +364,11 @@ class SyncRestServlet(RestServlet):
|
|||
Returns:
|
||||
dict[str, object]: the room, encoded in our response format
|
||||
"""
|
||||
|
||||
def serialize(events):
|
||||
return self._event_serializer.serialize_events(
|
||||
events, time_now=time_now,
|
||||
events,
|
||||
time_now=time_now,
|
||||
# We don't bundle "live" events, as otherwise clients
|
||||
# will end up double counting annotations.
|
||||
bundle_aggregations=False,
|
||||
|
@ -377,7 +388,9 @@ class SyncRestServlet(RestServlet):
|
|||
if event.room_id != room.room_id:
|
||||
logger.warn(
|
||||
"Event %r is under room %r instead of %r",
|
||||
event.event_id, room.room_id, event.room_id,
|
||||
event.event_id,
|
||||
room.room_id,
|
||||
event.room_id,
|
||||
)
|
||||
|
||||
serialized_state = yield serialize(state_events)
|
||||
|
|
|
@ -29,9 +29,8 @@ class TagListServlet(RestServlet):
|
|||
"""
|
||||
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
|
||||
"""
|
||||
PATTERNS = client_patterns(
|
||||
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
|
||||
)
|
||||
|
||||
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(TagListServlet, self).__init__()
|
||||
|
@ -54,6 +53,7 @@ class TagServlet(RestServlet):
|
|||
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
|
||||
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns(
|
||||
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
|
||||
)
|
||||
|
@ -74,9 +74,7 @@ class TagServlet(RestServlet):
|
|||
|
||||
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
|
||||
|
||||
self.notifier.on_new_event(
|
||||
"account_data_key", max_id, users=[user_id]
|
||||
)
|
||||
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
@ -88,9 +86,7 @@ class TagServlet(RestServlet):
|
|||
|
||||
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
|
||||
|
||||
self.notifier.on_new_event(
|
||||
"account_data_key", max_id, users=[user_id]
|
||||
)
|
||||
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
|
|
@ -57,7 +57,7 @@ class ThirdPartyProtocolServlet(RestServlet):
|
|||
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
protocols = yield self.appservice_handler.get_3pe_protocols(
|
||||
only_protocol=protocol,
|
||||
only_protocol=protocol
|
||||
)
|
||||
if protocol in protocols:
|
||||
defer.returnValue((200, protocols[protocol]))
|
||||
|
|
|
@ -26,6 +26,7 @@ class TokenRefreshRestServlet(RestServlet):
|
|||
Exchanges refresh tokens for a pair of an access token and a new refresh
|
||||
token.
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/tokenrefresh")
|
||||
|
||||
def __init__(self, hs):
|
||||
|
|
|
@ -60,10 +60,7 @@ class UserDirectorySearchRestServlet(RestServlet):
|
|||
user_id = requester.user.to_string()
|
||||
|
||||
if not self.hs.config.user_directory_search_enabled:
|
||||
defer.returnValue((200, {
|
||||
"limited": False,
|
||||
"results": [],
|
||||
}))
|
||||
defer.returnValue((200, {"limited": False, "results": []}))
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
|
@ -76,7 +73,7 @@ class UserDirectorySearchRestServlet(RestServlet):
|
|||
raise SynapseError(400, "`search_term` is required field")
|
||||
|
||||
results = yield self.user_directory_handler.search_users(
|
||||
user_id, search_term, limit,
|
||||
user_id, search_term, limit
|
||||
)
|
||||
|
||||
defer.returnValue((200, results))
|
||||
|
|
|
@ -25,27 +25,28 @@ class VersionsRestServlet(RestServlet):
|
|||
PATTERNS = [re.compile("^/_matrix/client/versions$")]
|
||||
|
||||
def on_GET(self, request):
|
||||
return (200, {
|
||||
"versions": [
|
||||
# XXX: at some point we need to decide whether we need to include
|
||||
# the previous version numbers, given we've defined r0.3.0 to be
|
||||
# backwards compatible with r0.2.0. But need to check how
|
||||
# conscientious we've been in compatibility, and decide whether the
|
||||
# middle number is the major revision when at 0.X.Y (as opposed to
|
||||
# X.Y.Z). And we need to decide whether it's fair to make clients
|
||||
# parse the version string to figure out what's going on.
|
||||
"r0.0.1",
|
||||
"r0.1.0",
|
||||
"r0.2.0",
|
||||
"r0.3.0",
|
||||
"r0.4.0",
|
||||
"r0.5.0",
|
||||
],
|
||||
# as per MSC1497:
|
||||
"unstable_features": {
|
||||
"m.lazy_load_members": True,
|
||||
}
|
||||
})
|
||||
return (
|
||||
200,
|
||||
{
|
||||
"versions": [
|
||||
# XXX: at some point we need to decide whether we need to include
|
||||
# the previous version numbers, given we've defined r0.3.0 to be
|
||||
# backwards compatible with r0.2.0. But need to check how
|
||||
# conscientious we've been in compatibility, and decide whether the
|
||||
# middle number is the major revision when at 0.X.Y (as opposed to
|
||||
# X.Y.Z). And we need to decide whether it's fair to make clients
|
||||
# parse the version string to figure out what's going on.
|
||||
"r0.0.1",
|
||||
"r0.1.0",
|
||||
"r0.2.0",
|
||||
"r0.3.0",
|
||||
"r0.4.0",
|
||||
"r0.5.0",
|
||||
],
|
||||
# as per MSC1497:
|
||||
"unstable_features": {"m.lazy_load_members": True},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(http_server):
|
||||
|
|
|
@ -42,6 +42,7 @@ logger = logging.getLogger(__name__)
|
|||
if hasattr(hmac, "compare_digest"):
|
||||
compare_digest = hmac.compare_digest
|
||||
else:
|
||||
|
||||
def compare_digest(a, b):
|
||||
return a == b
|
||||
|
||||
|
@ -80,6 +81,7 @@ class ConsentResource(Resource):
|
|||
For POST: required; gives the value to be recorded in the database
|
||||
against the user.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
|
@ -98,21 +100,20 @@ class ConsentResource(Resource):
|
|||
if self._default_consent_version is None:
|
||||
raise ConfigError(
|
||||
"Consent resource is enabled but user_consent section is "
|
||||
"missing in config file.",
|
||||
"missing in config file."
|
||||
)
|
||||
|
||||
consent_template_directory = hs.config.user_consent_template_dir
|
||||
|
||||
loader = jinja2.FileSystemLoader(consent_template_directory)
|
||||
self._jinja_env = jinja2.Environment(
|
||||
loader=loader,
|
||||
autoescape=jinja2.select_autoescape(['html', 'htm', 'xml']),
|
||||
loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"])
|
||||
)
|
||||
|
||||
if hs.config.form_secret is None:
|
||||
raise ConfigError(
|
||||
"Consent resource is enabled but form_secret is not set in "
|
||||
"config file. It should be set to an arbitrary secret string.",
|
||||
"config file. It should be set to an arbitrary secret string."
|
||||
)
|
||||
|
||||
self._hmac_secret = hs.config.form_secret.encode("utf-8")
|
||||
|
@ -139,7 +140,7 @@ class ConsentResource(Resource):
|
|||
|
||||
self._check_hash(username, userhmac_bytes)
|
||||
|
||||
if username.startswith('@'):
|
||||
if username.startswith("@"):
|
||||
qualified_user_id = username
|
||||
else:
|
||||
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
||||
|
@ -153,7 +154,8 @@ class ConsentResource(Resource):
|
|||
|
||||
try:
|
||||
self._render_template(
|
||||
request, "%s.html" % (version,),
|
||||
request,
|
||||
"%s.html" % (version,),
|
||||
user=username,
|
||||
userhmac=userhmac,
|
||||
version=version,
|
||||
|
@ -180,7 +182,7 @@ class ConsentResource(Resource):
|
|||
|
||||
self._check_hash(username, userhmac)
|
||||
|
||||
if username.startswith('@'):
|
||||
if username.startswith("@"):
|
||||
qualified_user_id = username
|
||||
else:
|
||||
qualified_user_id = UserID(username, self.hs.hostname).to_string()
|
||||
|
@ -221,11 +223,13 @@ class ConsentResource(Resource):
|
|||
SynapseError if the hash doesn't match
|
||||
|
||||
"""
|
||||
want_mac = hmac.new(
|
||||
key=self._hmac_secret,
|
||||
msg=userid.encode('utf-8'),
|
||||
digestmod=sha256,
|
||||
).hexdigest().encode('ascii')
|
||||
want_mac = (
|
||||
hmac.new(
|
||||
key=self._hmac_secret, msg=userid.encode("utf-8"), digestmod=sha256
|
||||
)
|
||||
.hexdigest()
|
||||
.encode("ascii")
|
||||
)
|
||||
|
||||
if not compare_digest(want_mac, userhmac):
|
||||
raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect")
|
||||
|
|
|
@ -80,33 +80,27 @@ class LocalKey(Resource):
|
|||
for key in self.config.signing_key:
|
||||
verify_key_bytes = key.verify_key.encode()
|
||||
key_id = "%s:%s" % (key.alg, key.version)
|
||||
verify_keys[key_id] = {
|
||||
u"key": encode_base64(verify_key_bytes)
|
||||
}
|
||||
verify_keys[key_id] = {"key": encode_base64(verify_key_bytes)}
|
||||
|
||||
old_verify_keys = {}
|
||||
for key_id, key in self.config.old_signing_keys.items():
|
||||
verify_key_bytes = key.encode()
|
||||
old_verify_keys[key_id] = {
|
||||
u"key": encode_base64(verify_key_bytes),
|
||||
u"expired_ts": key.expired_ts,
|
||||
"key": encode_base64(verify_key_bytes),
|
||||
"expired_ts": key.expired_ts,
|
||||
}
|
||||
|
||||
tls_fingerprints = self.config.tls_fingerprints
|
||||
|
||||
json_object = {
|
||||
u"valid_until_ts": self.valid_until_ts,
|
||||
u"server_name": self.config.server_name,
|
||||
u"verify_keys": verify_keys,
|
||||
u"old_verify_keys": old_verify_keys,
|
||||
u"tls_fingerprints": tls_fingerprints,
|
||||
"valid_until_ts": self.valid_until_ts,
|
||||
"server_name": self.config.server_name,
|
||||
"verify_keys": verify_keys,
|
||||
"old_verify_keys": old_verify_keys,
|
||||
"tls_fingerprints": tls_fingerprints,
|
||||
}
|
||||
for key in self.config.signing_key:
|
||||
json_object = sign_json(
|
||||
json_object,
|
||||
self.config.server_name,
|
||||
key,
|
||||
)
|
||||
json_object = sign_json(json_object, self.config.server_name, key)
|
||||
return json_object
|
||||
|
||||
def render_GET(self, request):
|
||||
|
@ -114,6 +108,4 @@ class LocalKey(Resource):
|
|||
# Update the expiry time if less than half the interval remains.
|
||||
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
|
||||
self.update_response_body(time_now)
|
||||
return respond_with_json_bytes(
|
||||
request, 200, self.response_body,
|
||||
)
|
||||
return respond_with_json_bytes(request, 200, self.response_body)
|
||||
|
|
|
@ -103,20 +103,16 @@ class RemoteKey(Resource):
|
|||
def async_render_GET(self, request):
|
||||
if len(request.postpath) == 1:
|
||||
server, = request.postpath
|
||||
query = {server.decode('ascii'): {}}
|
||||
query = {server.decode("ascii"): {}}
|
||||
elif len(request.postpath) == 2:
|
||||
server, key_id = request.postpath
|
||||
minimum_valid_until_ts = parse_integer(
|
||||
request, "minimum_valid_until_ts"
|
||||
)
|
||||
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
|
||||
arguments = {}
|
||||
if minimum_valid_until_ts is not None:
|
||||
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
|
||||
query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}}
|
||||
query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}}
|
||||
else:
|
||||
raise SynapseError(
|
||||
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
|
||||
)
|
||||
raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)
|
||||
|
||||
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||
|
||||
|
@ -140,8 +136,8 @@ class RemoteKey(Resource):
|
|||
store_queries = []
|
||||
for server_name, key_ids in query.items():
|
||||
if (
|
||||
self.federation_domain_whitelist is not None and
|
||||
server_name not in self.federation_domain_whitelist
|
||||
self.federation_domain_whitelist is not None
|
||||
and server_name not in self.federation_domain_whitelist
|
||||
):
|
||||
logger.debug("Federation denied with %s", server_name)
|
||||
continue
|
||||
|
@ -159,9 +155,7 @@ class RemoteKey(Resource):
|
|||
|
||||
cache_misses = dict()
|
||||
for (server_name, key_id, from_server), results in cached.items():
|
||||
results = [
|
||||
(result["ts_added_ms"], result) for result in results
|
||||
]
|
||||
results = [(result["ts_added_ms"], result) for result in results]
|
||||
|
||||
if not results and key_id is not None:
|
||||
cache_misses.setdefault(server_name, set()).add(key_id)
|
||||
|
@ -178,23 +172,30 @@ class RemoteKey(Resource):
|
|||
logger.debug(
|
||||
"Cached response for %r/%r is older than requested"
|
||||
": valid_until (%r) < minimum_valid_until (%r)",
|
||||
server_name, key_id,
|
||||
ts_valid_until_ms, req_valid_until
|
||||
server_name,
|
||||
key_id,
|
||||
ts_valid_until_ms,
|
||||
req_valid_until,
|
||||
)
|
||||
miss = True
|
||||
else:
|
||||
logger.debug(
|
||||
"Cached response for %r/%r is newer than requested"
|
||||
": valid_until (%r) >= minimum_valid_until (%r)",
|
||||
server_name, key_id,
|
||||
ts_valid_until_ms, req_valid_until
|
||||
server_name,
|
||||
key_id,
|
||||
ts_valid_until_ms,
|
||||
req_valid_until,
|
||||
)
|
||||
elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
|
||||
logger.debug(
|
||||
"Cached response for %r/%r is too old"
|
||||
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
|
||||
server_name, key_id,
|
||||
ts_added_ms, ts_valid_until_ms, time_now_ms
|
||||
server_name,
|
||||
key_id,
|
||||
ts_added_ms,
|
||||
ts_valid_until_ms,
|
||||
time_now_ms,
|
||||
)
|
||||
# We more than half way through the lifetime of the
|
||||
# response. We should fetch a fresh copy.
|
||||
|
@ -203,8 +204,11 @@ class RemoteKey(Resource):
|
|||
logger.debug(
|
||||
"Cached response for %r/%r is still valid"
|
||||
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
|
||||
server_name, key_id,
|
||||
ts_added_ms, ts_valid_until_ms, time_now_ms
|
||||
server_name,
|
||||
key_id,
|
||||
ts_added_ms,
|
||||
ts_valid_until_ms,
|
||||
time_now_ms,
|
||||
)
|
||||
|
||||
if miss:
|
||||
|
@ -216,12 +220,10 @@ class RemoteKey(Resource):
|
|||
|
||||
if cache_misses and query_remote_on_cache_miss:
|
||||
yield self.fetcher.get_keys(cache_misses)
|
||||
yield self.query_keys(
|
||||
request, query, query_remote_on_cache_miss=False
|
||||
)
|
||||
yield self.query_keys(request, query, query_remote_on_cache_miss=False)
|
||||
else:
|
||||
result_io = BytesIO()
|
||||
result_io.write(b"{\"server_keys\":")
|
||||
result_io.write(b'{"server_keys":')
|
||||
sep = b"["
|
||||
for json_bytes in json_results:
|
||||
result_io.write(sep)
|
||||
|
@ -231,6 +233,4 @@ class RemoteKey(Resource):
|
|||
result_io.write(sep)
|
||||
result_io.write(b"]}")
|
||||
|
||||
respond_with_json_bytes(
|
||||
request, 200, result_io.getvalue(),
|
||||
)
|
||||
respond_with_json_bytes(request, 200, result_io.getvalue())
|
||||
|
|
|
@ -44,6 +44,7 @@ class ContentRepoResource(resource.Resource):
|
|||
- Content type base64d (so we can return it when clients GET it)
|
||||
|
||||
"""
|
||||
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, directory):
|
||||
|
@ -56,7 +57,7 @@ class ContentRepoResource(resource.Resource):
|
|||
# servers.
|
||||
|
||||
# TODO: A little crude here, we could do this better.
|
||||
filename = request.path.decode('ascii').split('/')[-1]
|
||||
filename = request.path.decode("ascii").split("/")[-1]
|
||||
# be paranoid
|
||||
filename = re.sub("[^0-9A-z.-_]", "", filename)
|
||||
|
||||
|
@ -69,17 +70,15 @@ class ContentRepoResource(resource.Resource):
|
|||
base64_contentype = filename.split(".")[1]
|
||||
content_type = base64.urlsafe_b64decode(base64_contentype)
|
||||
logger.info("Sending file %s", file_path)
|
||||
f = open(file_path, 'rb')
|
||||
request.setHeader('Content-Type', content_type)
|
||||
f = open(file_path, "rb")
|
||||
request.setHeader("Content-Type", content_type)
|
||||
|
||||
# cache for at least a day.
|
||||
# XXX: we might want to turn this off for data we don't want to
|
||||
# recommend caching as it's sensitive or private - or at least
|
||||
# select private. don't bother setting Expires as all our matrix
|
||||
# clients are smart enough to be happy with Cache-Control (right?)
|
||||
request.setHeader(
|
||||
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
|
||||
)
|
||||
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
|
||||
|
||||
d = FileSender().beginFileTransfer(f, request)
|
||||
|
||||
|
@ -87,13 +86,15 @@ class ContentRepoResource(resource.Resource):
|
|||
def cbFinished(ignored):
|
||||
f.close()
|
||||
finish_request(request)
|
||||
|
||||
d.addCallback(cbFinished)
|
||||
else:
|
||||
respond_with_json_bytes(
|
||||
request,
|
||||
404,
|
||||
json.dumps(cs_error("Not found", code=Codes.NOT_FOUND)),
|
||||
send_cors=True)
|
||||
send_cors=True,
|
||||
)
|
||||
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
|
|
|
@ -38,8 +38,8 @@ def parse_media_id(request):
|
|||
server_name, media_id = request.postpath[:2]
|
||||
|
||||
if isinstance(server_name, bytes):
|
||||
server_name = server_name.decode('utf-8')
|
||||
media_id = media_id.decode('utf8')
|
||||
server_name = server_name.decode("utf-8")
|
||||
media_id = media_id.decode("utf8")
|
||||
|
||||
file_name = None
|
||||
if len(request.postpath) > 2:
|
||||
|
@ -120,11 +120,11 @@ def add_file_headers(request, media_type, file_size, upload_name):
|
|||
# correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we
|
||||
# may as well just do the filename* version.
|
||||
if _can_encode_filename_as_token(upload_name):
|
||||
disposition = 'inline; filename=%s' % (upload_name, )
|
||||
disposition = "inline; filename=%s" % (upload_name,)
|
||||
else:
|
||||
disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name), )
|
||||
disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),)
|
||||
|
||||
request.setHeader(b"Content-Disposition", disposition.encode('ascii'))
|
||||
request.setHeader(b"Content-Disposition", disposition.encode("ascii"))
|
||||
|
||||
# cache for at least a day.
|
||||
# XXX: we might want to turn this off for data we don't want to
|
||||
|
@ -137,10 +137,27 @@ def add_file_headers(request, media_type, file_size, upload_name):
|
|||
|
||||
# separators as defined in RFC2616. SP and HT are handled separately.
|
||||
# see _can_encode_filename_as_token.
|
||||
_FILENAME_SEPARATOR_CHARS = set((
|
||||
"(", ")", "<", ">", "@", ",", ";", ":", "\\", '"',
|
||||
"/", "[", "]", "?", "=", "{", "}",
|
||||
))
|
||||
_FILENAME_SEPARATOR_CHARS = set(
|
||||
(
|
||||
"(",
|
||||
")",
|
||||
"<",
|
||||
">",
|
||||
"@",
|
||||
",",
|
||||
";",
|
||||
":",
|
||||
"\\",
|
||||
'"',
|
||||
"/",
|
||||
"[",
|
||||
"]",
|
||||
"?",
|
||||
"=",
|
||||
"{",
|
||||
"}",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _can_encode_filename_as_token(x):
|
||||
|
@ -271,7 +288,7 @@ def get_filename_from_headers(headers):
|
|||
Returns:
|
||||
A Unicode string of the filename, or None.
|
||||
"""
|
||||
content_disposition = headers.get(b"Content-Disposition", [b''])
|
||||
content_disposition = headers.get(b"Content-Disposition", [b""])
|
||||
|
||||
# No header, bail out.
|
||||
if not content_disposition[0]:
|
||||
|
@ -293,7 +310,7 @@ def get_filename_from_headers(headers):
|
|||
# Once it is decoded, we can then unquote the %-encoded
|
||||
# parts strictly into a unicode string.
|
||||
upload_name = urllib.parse.unquote(
|
||||
upload_name_utf8.decode('ascii'), errors="strict"
|
||||
upload_name_utf8.decode("ascii"), errors="strict"
|
||||
)
|
||||
except UnicodeDecodeError:
|
||||
# Incorrect UTF-8.
|
||||
|
@ -302,7 +319,7 @@ def get_filename_from_headers(headers):
|
|||
# On Python 2, we first unquote the %-encoded parts and then
|
||||
# decode it strictly using UTF-8.
|
||||
try:
|
||||
upload_name = urllib.parse.unquote(upload_name_utf8).decode('utf8')
|
||||
upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8")
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
|
||||
|
@ -310,7 +327,7 @@ def get_filename_from_headers(headers):
|
|||
if not upload_name:
|
||||
upload_name_ascii = params.get(b"filename", None)
|
||||
if upload_name_ascii and is_ascii(upload_name_ascii):
|
||||
upload_name = upload_name_ascii.decode('ascii')
|
||||
upload_name = upload_name_ascii.decode("ascii")
|
||||
|
||||
# This may be None here, indicating we did not find a matching name.
|
||||
return upload_name
|
||||
|
@ -328,19 +345,19 @@ def _parse_header(line):
|
|||
Tuple[bytes, dict[bytes, bytes]]:
|
||||
the main content-type, followed by the parameter dictionary
|
||||
"""
|
||||
parts = _parseparam(b';' + line)
|
||||
parts = _parseparam(b";" + line)
|
||||
key = next(parts)
|
||||
pdict = {}
|
||||
for p in parts:
|
||||
i = p.find(b'=')
|
||||
i = p.find(b"=")
|
||||
if i >= 0:
|
||||
name = p[:i].strip().lower()
|
||||
value = p[i + 1:].strip()
|
||||
value = p[i + 1 :].strip()
|
||||
|
||||
# strip double-quotes
|
||||
if len(value) >= 2 and value[0:1] == value[-1:] == b'"':
|
||||
value = value[1:-1]
|
||||
value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"')
|
||||
value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"')
|
||||
pdict[name] = value
|
||||
|
||||
return key, pdict
|
||||
|
@ -357,16 +374,16 @@ def _parseparam(s):
|
|||
Returns:
|
||||
Iterable[bytes]: the split input
|
||||
"""
|
||||
while s[:1] == b';':
|
||||
while s[:1] == b";":
|
||||
s = s[1:]
|
||||
|
||||
# look for the next ;
|
||||
end = s.find(b';')
|
||||
end = s.find(b";")
|
||||
|
||||
# if there is an odd number of " marks between here and the next ;, skip to the
|
||||
# next ; instead
|
||||
while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2:
|
||||
end = s.find(b';', end + 1)
|
||||
end = s.find(b";", end + 1)
|
||||
|
||||
if end < 0:
|
||||
end = len(s)
|
||||
|
|
|
@ -29,9 +29,7 @@ class MediaConfigResource(Resource):
|
|||
config = hs.get_config()
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
self.limits_dict = {
|
||||
"m.upload.size": config.max_upload_size,
|
||||
}
|
||||
self.limits_dict = {"m.upload.size": config.max_upload_size}
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
|
|
|
@ -54,18 +54,20 @@ class DownloadResource(Resource):
|
|||
b" plugin-types application/pdf;"
|
||||
b" style-src 'unsafe-inline';"
|
||||
b" media-src 'self';"
|
||||
b" object-src 'self';"
|
||||
b" object-src 'self';",
|
||||
)
|
||||
server_name, media_id, name = parse_media_id(request)
|
||||
if server_name == self.server_name:
|
||||
yield self.media_repo.get_local_media(request, media_id, name)
|
||||
else:
|
||||
allow_remote = synapse.http.servlet.parse_boolean(
|
||||
request, "allow_remote", default=True)
|
||||
request, "allow_remote", default=True
|
||||
)
|
||||
if not allow_remote:
|
||||
logger.info(
|
||||
"Rejecting request for remote media %s/%s due to allow_remote",
|
||||
server_name, media_id,
|
||||
server_name,
|
||||
media_id,
|
||||
)
|
||||
respond_404(request)
|
||||
return
|
||||
|
|
|
@ -24,6 +24,7 @@ def _wrap_in_base_path(func):
|
|||
"""Takes a function that returns a relative path and turns it into an
|
||||
absolute path based on the location of the primary media store
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def _wrapped(self, *args, **kwargs):
|
||||
path = func(self, *args, **kwargs)
|
||||
|
@ -43,125 +44,102 @@ class MediaFilePaths(object):
|
|||
def __init__(self, primary_base_path):
|
||||
self.base_path = primary_base_path
|
||||
|
||||
def default_thumbnail_rel(self, default_top_level, default_sub_type, width,
|
||||
height, content_type, method):
|
||||
def default_thumbnail_rel(
|
||||
self, default_top_level, default_sub_type, width, height, content_type, method
|
||||
):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (
|
||||
width, height, top_level_type, sub_type, method
|
||||
)
|
||||
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
|
||||
return os.path.join(
|
||||
"default_thumbnails", default_top_level,
|
||||
default_sub_type, file_name
|
||||
"default_thumbnails", default_top_level, default_sub_type, file_name
|
||||
)
|
||||
|
||||
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
|
||||
|
||||
def local_media_filepath_rel(self, media_id):
|
||||
return os.path.join(
|
||||
"local_content",
|
||||
media_id[0:2], media_id[2:4], media_id[4:]
|
||||
)
|
||||
return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
|
||||
|
||||
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
|
||||
|
||||
def local_media_thumbnail_rel(self, media_id, width, height, content_type,
|
||||
method):
|
||||
def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (
|
||||
width, height, top_level_type, sub_type, method
|
||||
)
|
||||
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
|
||||
return os.path.join(
|
||||
"local_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
file_name
|
||||
"local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name
|
||||
)
|
||||
|
||||
local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
|
||||
|
||||
def remote_media_filepath_rel(self, server_name, file_id):
|
||||
return os.path.join(
|
||||
"remote_content", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:]
|
||||
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
|
||||
)
|
||||
|
||||
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
|
||||
|
||||
def remote_media_thumbnail_rel(self, server_name, file_id, width, height,
|
||||
content_type, method):
|
||||
def remote_media_thumbnail_rel(
|
||||
self, server_name, file_id, width, height, content_type, method
|
||||
):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
|
||||
return os.path.join(
|
||||
"remote_thumbnail", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
file_name
|
||||
"remote_thumbnail",
|
||||
server_name,
|
||||
file_id[0:2],
|
||||
file_id[2:4],
|
||||
file_id[4:],
|
||||
file_name,
|
||||
)
|
||||
|
||||
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
|
||||
|
||||
def remote_media_thumbnail_dir(self, server_name, file_id):
|
||||
return os.path.join(
|
||||
self.base_path, "remote_thumbnail", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
self.base_path,
|
||||
"remote_thumbnail",
|
||||
server_name,
|
||||
file_id[0:2],
|
||||
file_id[2:4],
|
||||
file_id[4:],
|
||||
)
|
||||
|
||||
def url_cache_filepath_rel(self, media_id):
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
return os.path.join(
|
||||
"url_cache",
|
||||
media_id[:10], media_id[11:]
|
||||
)
|
||||
return os.path.join("url_cache", media_id[:10], media_id[11:])
|
||||
else:
|
||||
return os.path.join(
|
||||
"url_cache",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
)
|
||||
return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:])
|
||||
|
||||
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
|
||||
|
||||
def url_cache_filepath_dirs_to_delete(self, media_id):
|
||||
"The dirs to try and remove if we delete the media_id file"
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return [
|
||||
os.path.join(
|
||||
self.base_path, "url_cache",
|
||||
media_id[:10],
|
||||
),
|
||||
]
|
||||
return [os.path.join(self.base_path, "url_cache", media_id[:10])]
|
||||
else:
|
||||
return [
|
||||
os.path.join(
|
||||
self.base_path, "url_cache",
|
||||
media_id[0:2], media_id[2:4],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache",
|
||||
media_id[0:2],
|
||||
),
|
||||
os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]),
|
||||
os.path.join(self.base_path, "url_cache", media_id[0:2]),
|
||||
]
|
||||
|
||||
def url_cache_thumbnail_rel(self, media_id, width, height, content_type,
|
||||
method):
|
||||
def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (
|
||||
width, height, top_level_type, sub_type, method
|
||||
)
|
||||
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
|
||||
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return os.path.join(
|
||||
"url_cache_thumbnails",
|
||||
media_id[:10], media_id[11:],
|
||||
file_name
|
||||
"url_cache_thumbnails", media_id[:10], media_id[11:], file_name
|
||||
)
|
||||
else:
|
||||
return os.path.join(
|
||||
"url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
file_name
|
||||
media_id[0:2],
|
||||
media_id[2:4],
|
||||
media_id[4:],
|
||||
file_name,
|
||||
)
|
||||
|
||||
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
|
||||
|
@ -172,13 +150,15 @@ class MediaFilePaths(object):
|
|||
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[:10], media_id[11:],
|
||||
self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:]
|
||||
)
|
||||
else:
|
||||
return os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
self.base_path,
|
||||
"url_cache_thumbnails",
|
||||
media_id[0:2],
|
||||
media_id[2:4],
|
||||
media_id[4:],
|
||||
)
|
||||
|
||||
def url_cache_thumbnail_dirs_to_delete(self, media_id):
|
||||
|
@ -188,26 +168,21 @@ class MediaFilePaths(object):
|
|||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return [
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[:10], media_id[11:],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[:10],
|
||||
self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:]
|
||||
),
|
||||
os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]),
|
||||
]
|
||||
else:
|
||||
return [
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
self.base_path,
|
||||
"url_cache_thumbnails",
|
||||
media_id[0:2],
|
||||
media_id[2:4],
|
||||
media_id[4:],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4]
|
||||
),
|
||||
os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]),
|
||||
]
|
||||
|
|
|
@ -100,17 +100,16 @@ class MediaRepository(object):
|
|||
storage_providers.append(provider)
|
||||
|
||||
self.media_storage = MediaStorage(
|
||||
self.hs, self.primary_base_path, self.filepaths, storage_providers,
|
||||
self.hs, self.primary_base_path, self.filepaths, storage_providers
|
||||
)
|
||||
|
||||
self.clock.looping_call(
|
||||
self._start_update_recently_accessed,
|
||||
UPDATE_RECENTLY_ACCESSED_TS,
|
||||
self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
|
||||
)
|
||||
|
||||
def _start_update_recently_accessed(self):
|
||||
return run_as_background_process(
|
||||
"update_recently_accessed_media", self._update_recently_accessed,
|
||||
"update_recently_accessed_media", self._update_recently_accessed
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -138,8 +137,9 @@ class MediaRepository(object):
|
|||
self.recently_accessed_locals.add(media_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_content(self, media_type, upload_name, content, content_length,
|
||||
auth_user):
|
||||
def create_content(
|
||||
self, media_type, upload_name, content, content_length, auth_user
|
||||
):
|
||||
"""Store uploaded content for a local user and return the mxc URL
|
||||
|
||||
Args:
|
||||
|
@ -154,10 +154,7 @@ class MediaRepository(object):
|
|||
"""
|
||||
media_id = random_string(24)
|
||||
|
||||
file_info = FileInfo(
|
||||
server_name=None,
|
||||
file_id=media_id,
|
||||
)
|
||||
file_info = FileInfo(server_name=None, file_id=media_id)
|
||||
|
||||
fname = yield self.media_storage.store_file(content, file_info)
|
||||
|
||||
|
@ -172,9 +169,7 @@ class MediaRepository(object):
|
|||
user_id=auth_user,
|
||||
)
|
||||
|
||||
yield self._generate_thumbnails(
|
||||
None, media_id, media_id, media_type,
|
||||
)
|
||||
yield self._generate_thumbnails(None, media_id, media_id, media_type)
|
||||
|
||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||
|
||||
|
@ -205,14 +200,11 @@ class MediaRepository(object):
|
|||
upload_name = name if name else media_info["upload_name"]
|
||||
url_cache = media_info["url_cache"]
|
||||
|
||||
file_info = FileInfo(
|
||||
None, media_id,
|
||||
url_cache=url_cache,
|
||||
)
|
||||
file_info = FileInfo(None, media_id, url_cache=url_cache)
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
yield respond_with_responder(
|
||||
request, responder, media_type, media_length, upload_name,
|
||||
request, responder, media_type, media_length, upload_name
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -232,8 +224,8 @@ class MediaRepository(object):
|
|||
to request
|
||||
"""
|
||||
if (
|
||||
self.federation_domain_whitelist is not None and
|
||||
server_name not in self.federation_domain_whitelist
|
||||
self.federation_domain_whitelist is not None
|
||||
and server_name not in self.federation_domain_whitelist
|
||||
):
|
||||
raise FederationDeniedError(server_name)
|
||||
|
||||
|
@ -244,7 +236,7 @@ class MediaRepository(object):
|
|||
key = (server_name, media_id)
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
responder, media_info = yield self._get_remote_media_impl(
|
||||
server_name, media_id,
|
||||
server_name, media_id
|
||||
)
|
||||
|
||||
# We deliberately stream the file outside the lock
|
||||
|
@ -253,7 +245,7 @@ class MediaRepository(object):
|
|||
media_length = media_info["media_length"]
|
||||
upload_name = name if name else media_info["upload_name"]
|
||||
yield respond_with_responder(
|
||||
request, responder, media_type, media_length, upload_name,
|
||||
request, responder, media_type, media_length, upload_name
|
||||
)
|
||||
else:
|
||||
respond_404(request)
|
||||
|
@ -272,8 +264,8 @@ class MediaRepository(object):
|
|||
Deferred[dict]: The media_info of the file
|
||||
"""
|
||||
if (
|
||||
self.federation_domain_whitelist is not None and
|
||||
server_name not in self.federation_domain_whitelist
|
||||
self.federation_domain_whitelist is not None
|
||||
and server_name not in self.federation_domain_whitelist
|
||||
):
|
||||
raise FederationDeniedError(server_name)
|
||||
|
||||
|
@ -282,7 +274,7 @@ class MediaRepository(object):
|
|||
key = (server_name, media_id)
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
responder, media_info = yield self._get_remote_media_impl(
|
||||
server_name, media_id,
|
||||
server_name, media_id
|
||||
)
|
||||
|
||||
# Ensure we actually use the responder so that it releases resources
|
||||
|
@ -305,9 +297,7 @@ class MediaRepository(object):
|
|||
Returns:
|
||||
Deferred[(Responder, media_info)]
|
||||
"""
|
||||
media_info = yield self.store.get_cached_remote_media(
|
||||
server_name, media_id
|
||||
)
|
||||
media_info = yield self.store.get_cached_remote_media(server_name, media_id)
|
||||
|
||||
# file_id is the ID we use to track the file locally. If we've already
|
||||
# seen the file then reuse the existing ID, otherwise genereate a new
|
||||
|
@ -331,9 +321,7 @@ class MediaRepository(object):
|
|||
|
||||
# Failed to find the file anywhere, lets download it.
|
||||
|
||||
media_info = yield self._download_remote_file(
|
||||
server_name, media_id, file_id
|
||||
)
|
||||
media_info = yield self._download_remote_file(server_name, media_id, file_id)
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
defer.returnValue((responder, media_info))
|
||||
|
@ -354,52 +342,60 @@ class MediaRepository(object):
|
|||
Deferred[MediaInfo]
|
||||
"""
|
||||
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=file_id,
|
||||
)
|
||||
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
||||
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
request_path = "/".join((
|
||||
"/_matrix/media/v1/download", server_name, media_id,
|
||||
))
|
||||
request_path = "/".join(
|
||||
("/_matrix/media/v1/download", server_name, media_id)
|
||||
)
|
||||
try:
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size, args={
|
||||
server_name,
|
||||
request_path,
|
||||
output_stream=f,
|
||||
max_size=self.max_upload_size,
|
||||
args={
|
||||
# tell the remote server to 404 if it doesn't
|
||||
# recognise the server_name, to make sure we don't
|
||||
# end up with a routing loop.
|
||||
"allow_remote": "false",
|
||||
}
|
||||
"allow_remote": "false"
|
||||
},
|
||||
)
|
||||
except RequestSendFailed as e:
|
||||
logger.warn("Request failed fetching remote media %s/%s: %r",
|
||||
server_name, media_id, e)
|
||||
logger.warn(
|
||||
"Request failed fetching remote media %s/%s: %r",
|
||||
server_name,
|
||||
media_id,
|
||||
e,
|
||||
)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
except HttpResponseException as e:
|
||||
logger.warn("HTTP error fetching remote media %s/%s: %s",
|
||||
server_name, media_id, e.response)
|
||||
logger.warn(
|
||||
"HTTP error fetching remote media %s/%s: %s",
|
||||
server_name,
|
||||
media_id,
|
||||
e.response,
|
||||
)
|
||||
if e.code == twisted.web.http.NOT_FOUND:
|
||||
raise e.to_synapse_error()
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
except SynapseError:
|
||||
logger.exception("Failed to fetch remote media %s/%s",
|
||||
server_name, media_id)
|
||||
logger.warn("Failed to fetch remote media %s/%s", server_name, media_id)
|
||||
raise
|
||||
except NotRetryingDestination:
|
||||
logger.warn("Not retrying destination %r", server_name)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch remote media %s/%s",
|
||||
server_name, media_id)
|
||||
logger.exception(
|
||||
"Failed to fetch remote media %s/%s", server_name, media_id
|
||||
)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
yield finish()
|
||||
|
||||
media_type = headers[b"Content-Type"][0].decode('ascii')
|
||||
media_type = headers[b"Content-Type"][0].decode("ascii")
|
||||
upload_name = get_filename_from_headers(headers)
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
|
@ -423,24 +419,23 @@ class MediaRepository(object):
|
|||
"filesystem_id": file_id,
|
||||
}
|
||||
|
||||
yield self._generate_thumbnails(
|
||||
server_name, media_id, file_id, media_type,
|
||||
)
|
||||
yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
|
||||
|
||||
defer.returnValue(media_info)
|
||||
|
||||
def _get_thumbnail_requirements(self, media_type):
|
||||
return self.thumbnail_requirements.get(media_type, ())
|
||||
|
||||
def _generate_thumbnail(self, thumbnailer, t_width, t_height,
|
||||
t_method, t_type):
|
||||
def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
|
||||
m_width = thumbnailer.width
|
||||
m_height = thumbnailer.height
|
||||
|
||||
if m_width * m_height >= self.max_image_pixels:
|
||||
logger.info(
|
||||
"Image too large to thumbnail %r x %r > %r",
|
||||
m_width, m_height, self.max_image_pixels
|
||||
m_width,
|
||||
m_height,
|
||||
self.max_image_pixels,
|
||||
)
|
||||
return
|
||||
|
||||
|
@ -460,17 +455,22 @@ class MediaRepository(object):
|
|||
return t_byte_source
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
|
||||
t_method, t_type, url_cache):
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
|
||||
None, media_id, url_cache=url_cache,
|
||||
))
|
||||
def generate_local_exact_thumbnail(
|
||||
self, media_id, t_width, t_height, t_method, t_type, url_cache
|
||||
):
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
|
||||
FileInfo(None, media_id, url_cache=url_cache)
|
||||
)
|
||||
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
t_byte_source = yield logcontext.defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
self._generate_thumbnail,
|
||||
thumbnailer, t_width, t_height, t_method, t_type
|
||||
thumbnailer,
|
||||
t_width,
|
||||
t_height,
|
||||
t_method,
|
||||
t_type,
|
||||
)
|
||||
|
||||
if t_byte_source:
|
||||
|
@ -487,7 +487,7 @@ class MediaRepository(object):
|
|||
)
|
||||
|
||||
output_path = yield self.media_storage.store_file(
|
||||
t_byte_source, file_info,
|
||||
t_byte_source, file_info
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
@ -503,17 +503,22 @@ class MediaRepository(object):
|
|||
defer.returnValue(output_path)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
|
||||
t_width, t_height, t_method, t_type):
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
|
||||
server_name, file_id, url_cache=False,
|
||||
))
|
||||
def generate_remote_exact_thumbnail(
|
||||
self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
|
||||
):
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
|
||||
FileInfo(server_name, file_id, url_cache=False)
|
||||
)
|
||||
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
t_byte_source = yield logcontext.defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
self._generate_thumbnail,
|
||||
thumbnailer, t_width, t_height, t_method, t_type
|
||||
thumbnailer,
|
||||
t_width,
|
||||
t_height,
|
||||
t_method,
|
||||
t_type,
|
||||
)
|
||||
|
||||
if t_byte_source:
|
||||
|
@ -529,7 +534,7 @@ class MediaRepository(object):
|
|||
)
|
||||
|
||||
output_path = yield self.media_storage.store_file(
|
||||
t_byte_source, file_info,
|
||||
t_byte_source, file_info
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
@ -539,15 +544,22 @@ class MediaRepository(object):
|
|||
t_len = os.path.getsize(output_path)
|
||||
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
server_name,
|
||||
media_id,
|
||||
file_id,
|
||||
t_width,
|
||||
t_height,
|
||||
t_type,
|
||||
t_method,
|
||||
t_len,
|
||||
)
|
||||
|
||||
defer.returnValue(output_path)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
|
||||
url_cache=False):
|
||||
def _generate_thumbnails(
|
||||
self, server_name, media_id, file_id, media_type, url_cache=False
|
||||
):
|
||||
"""Generate and store thumbnails for an image.
|
||||
|
||||
Args:
|
||||
|
@ -566,9 +578,9 @@ class MediaRepository(object):
|
|||
if not requirements:
|
||||
return
|
||||
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
|
||||
server_name, file_id, url_cache=url_cache,
|
||||
))
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(
|
||||
FileInfo(server_name, file_id, url_cache=url_cache)
|
||||
)
|
||||
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
m_width = thumbnailer.width
|
||||
|
@ -577,14 +589,15 @@ class MediaRepository(object):
|
|||
if m_width * m_height >= self.max_image_pixels:
|
||||
logger.info(
|
||||
"Image too large to thumbnail %r x %r > %r",
|
||||
m_width, m_height, self.max_image_pixels
|
||||
m_width,
|
||||
m_height,
|
||||
self.max_image_pixels,
|
||||
)
|
||||
return
|
||||
|
||||
if thumbnailer.transpose_method is not None:
|
||||
m_width, m_height = yield logcontext.defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
thumbnailer.transpose
|
||||
self.hs.get_reactor(), thumbnailer.transpose
|
||||
)
|
||||
|
||||
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
|
||||
|
@ -604,15 +617,11 @@ class MediaRepository(object):
|
|||
# Generate the thumbnail
|
||||
if t_method == "crop":
|
||||
t_byte_source = yield logcontext.defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
thumbnailer.crop,
|
||||
t_width, t_height, t_type,
|
||||
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
|
||||
)
|
||||
elif t_method == "scale":
|
||||
t_byte_source = yield logcontext.defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
thumbnailer.scale,
|
||||
t_width, t_height, t_type,
|
||||
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
|
||||
)
|
||||
else:
|
||||
logger.error("Unrecognized method: %r", t_method)
|
||||
|
@ -634,7 +643,7 @@ class MediaRepository(object):
|
|||
)
|
||||
|
||||
output_path = yield self.media_storage.store_file(
|
||||
t_byte_source, file_info,
|
||||
t_byte_source, file_info
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
@ -644,18 +653,21 @@ class MediaRepository(object):
|
|||
# Write to database
|
||||
if server_name:
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
server_name,
|
||||
media_id,
|
||||
file_id,
|
||||
t_width,
|
||||
t_height,
|
||||
t_type,
|
||||
t_method,
|
||||
t_len,
|
||||
)
|
||||
else:
|
||||
yield self.store.store_local_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"width": m_width,
|
||||
"height": m_height,
|
||||
})
|
||||
defer.returnValue({"width": m_width, "height": m_height})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_old_remote_media(self, before_ts):
|
||||
|
@ -747,11 +759,12 @@ class MediaRepositoryResource(Resource):
|
|||
|
||||
self.putChild(b"upload", UploadResource(hs, media_repo))
|
||||
self.putChild(b"download", DownloadResource(hs, media_repo))
|
||||
self.putChild(b"thumbnail", ThumbnailResource(
|
||||
hs, media_repo, media_repo.media_storage,
|
||||
))
|
||||
self.putChild(
|
||||
b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage)
|
||||
)
|
||||
if hs.config.url_preview_enabled:
|
||||
self.putChild(b"preview_url", PreviewUrlResource(
|
||||
hs, media_repo, media_repo.media_storage,
|
||||
))
|
||||
self.putChild(
|
||||
b"preview_url",
|
||||
PreviewUrlResource(hs, media_repo, media_repo.media_storage),
|
||||
)
|
||||
self.putChild(b"config", MediaConfigResource(hs))
|
||||
|
|
|
@ -66,8 +66,7 @@ class MediaStorage(object):
|
|||
with self.store_into_file(file_info) as (f, fname, finish_cb):
|
||||
# Write to the main repository
|
||||
yield logcontext.defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
_write_file_synchronously, source, f,
|
||||
self.hs.get_reactor(), _write_file_synchronously, source, f
|
||||
)
|
||||
yield finish_cb()
|
||||
|
||||
|
@ -179,7 +178,8 @@ class MediaStorage(object):
|
|||
if res:
|
||||
with res:
|
||||
consumer = BackgroundFileConsumer(
|
||||
open(local_path, "wb"), self.hs.get_reactor())
|
||||
open(local_path, "wb"), self.hs.get_reactor()
|
||||
)
|
||||
yield res.write_to_consumer(consumer)
|
||||
yield consumer.wait()
|
||||
defer.returnValue(local_path)
|
||||
|
@ -217,10 +217,10 @@ class MediaStorage(object):
|
|||
width=file_info.thumbnail_width,
|
||||
height=file_info.thumbnail_height,
|
||||
content_type=file_info.thumbnail_type,
|
||||
method=file_info.thumbnail_method
|
||||
method=file_info.thumbnail_method,
|
||||
)
|
||||
return self.filepaths.remote_media_filepath_rel(
|
||||
file_info.server_name, file_info.file_id,
|
||||
file_info.server_name, file_info.file_id
|
||||
)
|
||||
|
||||
if file_info.thumbnail:
|
||||
|
@ -229,11 +229,9 @@ class MediaStorage(object):
|
|||
width=file_info.thumbnail_width,
|
||||
height=file_info.thumbnail_height,
|
||||
content_type=file_info.thumbnail_type,
|
||||
method=file_info.thumbnail_method
|
||||
method=file_info.thumbnail_method,
|
||||
)
|
||||
return self.filepaths.local_media_filepath_rel(
|
||||
file_info.file_id,
|
||||
)
|
||||
return self.filepaths.local_media_filepath_rel(file_info.file_id)
|
||||
|
||||
|
||||
def _write_file_synchronously(source, dest):
|
||||
|
@ -255,6 +253,7 @@ class FileResponder(Responder):
|
|||
open_file (file): A file like object to be streamed ot the client,
|
||||
is closed when finished streaming.
|
||||
"""
|
||||
|
||||
def __init__(self, open_file):
|
||||
self.open_file = open_file
|
||||
|
||||
|
|
|
@ -92,7 +92,7 @@ class PreviewUrlResource(Resource):
|
|||
)
|
||||
|
||||
self._cleaner_loop = self.clock.looping_call(
|
||||
self._start_expire_url_cache_data, 10 * 1000,
|
||||
self._start_expire_url_cache_data, 10 * 1000
|
||||
)
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
|
@ -121,16 +121,16 @@ class PreviewUrlResource(Resource):
|
|||
for attrib in entry:
|
||||
pattern = entry[attrib]
|
||||
value = getattr(url_tuple, attrib)
|
||||
logger.debug((
|
||||
"Matching attrib '%s' with value '%s' against"
|
||||
" pattern '%s'"
|
||||
) % (attrib, value, pattern))
|
||||
logger.debug(
|
||||
("Matching attrib '%s' with value '%s' against" " pattern '%s'")
|
||||
% (attrib, value, pattern)
|
||||
)
|
||||
|
||||
if value is None:
|
||||
match = False
|
||||
continue
|
||||
|
||||
if pattern.startswith('^'):
|
||||
if pattern.startswith("^"):
|
||||
if not re.match(pattern, getattr(url_tuple, attrib)):
|
||||
match = False
|
||||
continue
|
||||
|
@ -139,12 +139,9 @@ class PreviewUrlResource(Resource):
|
|||
match = False
|
||||
continue
|
||||
if match:
|
||||
logger.warn(
|
||||
"URL %s blocked by url_blacklist entry %s", url, entry
|
||||
)
|
||||
logger.warn("URL %s blocked by url_blacklist entry %s", url, entry)
|
||||
raise SynapseError(
|
||||
403, "URL blocked by url pattern blacklist entry",
|
||||
Codes.UNKNOWN
|
||||
403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN
|
||||
)
|
||||
|
||||
# the in-memory cache:
|
||||
|
@ -156,14 +153,8 @@ class PreviewUrlResource(Resource):
|
|||
observable = self._cache.get(url)
|
||||
|
||||
if not observable:
|
||||
download = run_in_background(
|
||||
self._do_preview,
|
||||
url, requester.user, ts,
|
||||
)
|
||||
observable = ObservableDeferred(
|
||||
download,
|
||||
consumeErrors=True
|
||||
)
|
||||
download = run_in_background(self._do_preview, url, requester.user, ts)
|
||||
observable = ObservableDeferred(download, consumeErrors=True)
|
||||
self._cache[url] = observable
|
||||
else:
|
||||
logger.info("Returning cached response")
|
||||
|
@ -187,15 +178,15 @@ class PreviewUrlResource(Resource):
|
|||
# historical previews, if we have any)
|
||||
cache_result = yield self.store.get_url_cache(url, ts)
|
||||
if (
|
||||
cache_result and
|
||||
cache_result["expires_ts"] > ts and
|
||||
cache_result["response_code"] / 100 == 2
|
||||
cache_result
|
||||
and cache_result["expires_ts"] > ts
|
||||
and cache_result["response_code"] / 100 == 2
|
||||
):
|
||||
# It may be stored as text in the database, not as bytes (such as
|
||||
# PostgreSQL). If so, encode it back before handing it on.
|
||||
og = cache_result["og"]
|
||||
if isinstance(og, six.text_type):
|
||||
og = og.encode('utf8')
|
||||
og = og.encode("utf8")
|
||||
defer.returnValue(og)
|
||||
return
|
||||
|
||||
|
@ -203,33 +194,31 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
logger.debug("got media_info of '%s'" % media_info)
|
||||
|
||||
if _is_media(media_info['media_type']):
|
||||
file_id = media_info['filesystem_id']
|
||||
if _is_media(media_info["media_type"]):
|
||||
file_id = media_info["filesystem_id"]
|
||||
dims = yield self.media_repo._generate_thumbnails(
|
||||
None, file_id, file_id, media_info["media_type"],
|
||||
url_cache=True,
|
||||
None, file_id, file_id, media_info["media_type"], url_cache=True
|
||||
)
|
||||
|
||||
og = {
|
||||
"og:description": media_info['download_name'],
|
||||
"og:image": "mxc://%s/%s" % (
|
||||
self.server_name, media_info['filesystem_id']
|
||||
),
|
||||
"og:image:type": media_info['media_type'],
|
||||
"matrix:image:size": media_info['media_length'],
|
||||
"og:description": media_info["download_name"],
|
||||
"og:image": "mxc://%s/%s"
|
||||
% (self.server_name, media_info["filesystem_id"]),
|
||||
"og:image:type": media_info["media_type"],
|
||||
"matrix:image:size": media_info["media_length"],
|
||||
}
|
||||
|
||||
if dims:
|
||||
og["og:image:width"] = dims['width']
|
||||
og["og:image:height"] = dims['height']
|
||||
og["og:image:width"] = dims["width"]
|
||||
og["og:image:height"] = dims["height"]
|
||||
else:
|
||||
logger.warn("Couldn't get dims for %s" % url)
|
||||
|
||||
# define our OG response for this media
|
||||
elif _is_html(media_info['media_type']):
|
||||
elif _is_html(media_info["media_type"]):
|
||||
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
|
||||
|
||||
with open(media_info['filename'], 'rb') as file:
|
||||
with open(media_info["filename"], "rb") as file:
|
||||
body = file.read()
|
||||
|
||||
encoding = None
|
||||
|
@ -242,45 +231,43 @@ class PreviewUrlResource(Resource):
|
|||
# If we find a match, it should take precedence over the
|
||||
# Content-Type header, so set it here.
|
||||
if match:
|
||||
encoding = match.group(1).decode('ascii')
|
||||
encoding = match.group(1).decode("ascii")
|
||||
|
||||
# If we don't find a match, we'll look at the HTTP Content-Type, and
|
||||
# if that doesn't exist, we'll fall back to UTF-8.
|
||||
if not encoding:
|
||||
match = _content_type_match.match(
|
||||
media_info['media_type']
|
||||
)
|
||||
match = _content_type_match.match(media_info["media_type"])
|
||||
encoding = match.group(1) if match else "utf-8"
|
||||
|
||||
og = decode_and_calc_og(body, media_info['uri'], encoding)
|
||||
og = decode_and_calc_og(body, media_info["uri"], encoding)
|
||||
|
||||
# pre-cache the image for posterity
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
||||
# request itself and benefit from the same caching etc. But for now we
|
||||
# just rely on the caching on the master request to speed things up.
|
||||
if 'og:image' in og and og['og:image']:
|
||||
if "og:image" in og and og["og:image"]:
|
||||
image_info = yield self._download_url(
|
||||
_rebase_url(og['og:image'], media_info['uri']), user
|
||||
_rebase_url(og["og:image"], media_info["uri"]), user
|
||||
)
|
||||
|
||||
if _is_media(image_info['media_type']):
|
||||
if _is_media(image_info["media_type"]):
|
||||
# TODO: make sure we don't choke on white-on-transparent images
|
||||
file_id = image_info['filesystem_id']
|
||||
file_id = image_info["filesystem_id"]
|
||||
dims = yield self.media_repo._generate_thumbnails(
|
||||
None, file_id, file_id, image_info["media_type"],
|
||||
url_cache=True,
|
||||
None, file_id, file_id, image_info["media_type"], url_cache=True
|
||||
)
|
||||
if dims:
|
||||
og["og:image:width"] = dims['width']
|
||||
og["og:image:height"] = dims['height']
|
||||
og["og:image:width"] = dims["width"]
|
||||
og["og:image:height"] = dims["height"]
|
||||
else:
|
||||
logger.warn("Couldn't get dims for %s" % og["og:image"])
|
||||
|
||||
og["og:image"] = "mxc://%s/%s" % (
|
||||
self.server_name, image_info['filesystem_id']
|
||||
self.server_name,
|
||||
image_info["filesystem_id"],
|
||||
)
|
||||
og["og:image:type"] = image_info['media_type']
|
||||
og["matrix:image:size"] = image_info['media_length']
|
||||
og["og:image:type"] = image_info["media_type"]
|
||||
og["matrix:image:size"] = image_info["media_length"]
|
||||
else:
|
||||
del og["og:image"]
|
||||
else:
|
||||
|
@ -289,7 +276,7 @@ class PreviewUrlResource(Resource):
|
|||
|
||||
logger.debug("Calculated OG for %s as %s" % (url, og))
|
||||
|
||||
jsonog = json.dumps(og).encode('utf8')
|
||||
jsonog = json.dumps(og).encode("utf8")
|
||||
|
||||
# store OG in history-aware DB cache
|
||||
yield self.store.store_url_cache(
|
||||
|
@ -310,19 +297,15 @@ class PreviewUrlResource(Resource):
|
|||
# we're most likely being explicitly triggered by a human rather than a
|
||||
# bot, so are we really a robot?
|
||||
|
||||
file_id = datetime.date.today().isoformat() + '_' + random_string(16)
|
||||
file_id = datetime.date.today().isoformat() + "_" + random_string(16)
|
||||
|
||||
file_info = FileInfo(
|
||||
server_name=None,
|
||||
file_id=file_id,
|
||||
url_cache=True,
|
||||
)
|
||||
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
|
||||
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
try:
|
||||
logger.debug("Trying to get url '%s'" % url)
|
||||
length, headers, uri, code = yield self.client.get_file(
|
||||
url, output_stream=f, max_size=self.max_spider_size,
|
||||
url, output_stream=f, max_size=self.max_spider_size
|
||||
)
|
||||
except SynapseError:
|
||||
# Pass SynapseErrors through directly, so that the servlet
|
||||
|
@ -334,24 +317,25 @@ class PreviewUrlResource(Resource):
|
|||
# Note: This will also be the case if one of the resolved IP
|
||||
# addresses is blacklisted
|
||||
raise SynapseError(
|
||||
502, "DNS resolution failure during URL preview generation",
|
||||
Codes.UNKNOWN
|
||||
502,
|
||||
"DNS resolution failure during URL preview generation",
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
except Exception as e:
|
||||
# FIXME: pass through 404s and other error messages nicely
|
||||
logger.warn("Error downloading %s: %r", url, e)
|
||||
|
||||
raise SynapseError(
|
||||
500, "Failed to download content: %s" % (
|
||||
traceback.format_exception_only(sys.exc_info()[0], e),
|
||||
),
|
||||
500,
|
||||
"Failed to download content: %s"
|
||||
% (traceback.format_exception_only(sys.exc_info()[0], e),),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
yield finish()
|
||||
|
||||
try:
|
||||
if b"Content-Type" in headers:
|
||||
media_type = headers[b"Content-Type"][0].decode('ascii')
|
||||
media_type = headers[b"Content-Type"][0].decode("ascii")
|
||||
else:
|
||||
media_type = "application/octet-stream"
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
@ -375,24 +359,26 @@ class PreviewUrlResource(Resource):
|
|||
# therefore not expire it.
|
||||
raise
|
||||
|
||||
defer.returnValue({
|
||||
"media_type": media_type,
|
||||
"media_length": length,
|
||||
"download_name": download_name,
|
||||
"created_ts": time_now_ms,
|
||||
"filesystem_id": file_id,
|
||||
"filename": fname,
|
||||
"uri": uri,
|
||||
"response_code": code,
|
||||
# FIXME: we should calculate a proper expiration based on the
|
||||
# Cache-Control and Expire headers. But for now, assume 1 hour.
|
||||
"expires": 60 * 60 * 1000,
|
||||
"etag": headers["ETag"][0] if "ETag" in headers else None,
|
||||
})
|
||||
defer.returnValue(
|
||||
{
|
||||
"media_type": media_type,
|
||||
"media_length": length,
|
||||
"download_name": download_name,
|
||||
"created_ts": time_now_ms,
|
||||
"filesystem_id": file_id,
|
||||
"filename": fname,
|
||||
"uri": uri,
|
||||
"response_code": code,
|
||||
# FIXME: we should calculate a proper expiration based on the
|
||||
# Cache-Control and Expire headers. But for now, assume 1 hour.
|
||||
"expires": 60 * 60 * 1000,
|
||||
"etag": headers["ETag"][0] if "ETag" in headers else None,
|
||||
}
|
||||
)
|
||||
|
||||
def _start_expire_url_cache_data(self):
|
||||
return run_as_background_process(
|
||||
"expire_url_cache_data", self._expire_url_cache_data,
|
||||
"expire_url_cache_data", self._expire_url_cache_data
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
@ -496,7 +482,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None):
|
|||
# blindly try decoding the body as utf-8, which seems to fix
|
||||
# the charset mismatches on https://google.com
|
||||
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
|
||||
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
|
||||
tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
|
||||
og = _calc_og(tree, media_uri)
|
||||
|
||||
return og
|
||||
|
@ -523,8 +509,8 @@ def _calc_og(tree, media_uri):
|
|||
|
||||
og = {}
|
||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||
if 'content' in tag.attrib:
|
||||
og[tag.attrib['property']] = tag.attrib['content']
|
||||
if "content" in tag.attrib:
|
||||
og[tag.attrib["property"]] = tag.attrib["content"]
|
||||
|
||||
# TODO: grab article: meta tags too, e.g.:
|
||||
|
||||
|
@ -535,39 +521,43 @@ def _calc_og(tree, media_uri):
|
|||
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
|
||||
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
|
||||
|
||||
if 'og:title' not in og:
|
||||
if "og:title" not in og:
|
||||
# do some basic spidering of the HTML
|
||||
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
||||
if title and title[0].text is not None:
|
||||
og['og:title'] = title[0].text.strip()
|
||||
og["og:title"] = title[0].text.strip()
|
||||
else:
|
||||
og['og:title'] = None
|
||||
og["og:title"] = None
|
||||
|
||||
if 'og:image' not in og:
|
||||
if "og:image" not in og:
|
||||
# TODO: extract a favicon failing all else
|
||||
meta_image = tree.xpath(
|
||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
|
||||
)
|
||||
if meta_image:
|
||||
og['og:image'] = _rebase_url(meta_image[0], media_uri)
|
||||
og["og:image"] = _rebase_url(meta_image[0], media_uri)
|
||||
else:
|
||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||
images = sorted(images, key=lambda i: (
|
||||
-1 * float(i.attrib['width']) * float(i.attrib['height'])
|
||||
))
|
||||
images = sorted(
|
||||
images,
|
||||
key=lambda i: (
|
||||
-1 * float(i.attrib["width"]) * float(i.attrib["height"])
|
||||
),
|
||||
)
|
||||
if not images:
|
||||
images = tree.xpath("//img[@src]")
|
||||
if images:
|
||||
og['og:image'] = images[0].attrib['src']
|
||||
og["og:image"] = images[0].attrib["src"]
|
||||
|
||||
if 'og:description' not in og:
|
||||
if "og:description" not in og:
|
||||
meta_description = tree.xpath(
|
||||
"//*/meta"
|
||||
"[translate(@name, 'DESCRIPTION', 'description')='description']"
|
||||
"/@content")
|
||||
"/@content"
|
||||
)
|
||||
if meta_description:
|
||||
og['og:description'] = meta_description[0]
|
||||
og["og:description"] = meta_description[0]
|
||||
else:
|
||||
# grab any text nodes which are inside the <body/> tag...
|
||||
# unless they are within an HTML5 semantic markup tag...
|
||||
|
@ -588,18 +578,18 @@ def _calc_og(tree, media_uri):
|
|||
"script",
|
||||
"noscript",
|
||||
"style",
|
||||
etree.Comment
|
||||
etree.Comment,
|
||||
)
|
||||
|
||||
# Split all the text nodes into paragraphs (by splitting on new
|
||||
# lines)
|
||||
text_nodes = (
|
||||
re.sub(r'\s+', '\n', el).strip()
|
||||
re.sub(r"\s+", "\n", el).strip()
|
||||
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
|
||||
)
|
||||
og['og:description'] = summarize_paragraphs(text_nodes)
|
||||
og["og:description"] = summarize_paragraphs(text_nodes)
|
||||
else:
|
||||
og['og:description'] = summarize_paragraphs([og['og:description']])
|
||||
og["og:description"] = summarize_paragraphs([og["og:description"]])
|
||||
|
||||
# TODO: delete the url downloads to stop diskfilling,
|
||||
# as we only ever cared about its OG
|
||||
|
@ -636,7 +626,7 @@ def _iterate_over_text(tree, *tags_to_ignore):
|
|||
[child, child.tail] if child.tail else [child]
|
||||
for child in el.iterchildren()
|
||||
),
|
||||
elements
|
||||
elements,
|
||||
)
|
||||
|
||||
|
||||
|
@ -647,8 +637,8 @@ def _rebase_url(url, base):
|
|||
url[0] = base[0] or "http"
|
||||
if not url[1]: # fix up hostname
|
||||
url[1] = base[1]
|
||||
if not url[2].startswith('/'):
|
||||
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
|
||||
if not url[2].startswith("/"):
|
||||
url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
|
||||
return urlparse.urlunparse(url)
|
||||
|
||||
|
||||
|
@ -659,9 +649,8 @@ def _is_media(content_type):
|
|||
|
||||
def _is_html(content_type):
|
||||
content_type = content_type.lower()
|
||||
if (
|
||||
content_type.startswith("text/html") or
|
||||
content_type.startswith("application/xhtml")
|
||||
if content_type.startswith("text/html") or content_type.startswith(
|
||||
"application/xhtml"
|
||||
):
|
||||
return True
|
||||
|
||||
|
@ -671,19 +660,19 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
|
|||
# first paragraph and then word boundaries.
|
||||
# TODO: Respect sentences?
|
||||
|
||||
description = ''
|
||||
description = ""
|
||||
|
||||
# Keep adding paragraphs until we get to the MIN_SIZE.
|
||||
for text_node in text_nodes:
|
||||
if len(description) < min_size:
|
||||
text_node = re.sub(r'[\t \r\n]+', ' ', text_node)
|
||||
description += text_node + '\n\n'
|
||||
text_node = re.sub(r"[\t \r\n]+", " ", text_node)
|
||||
description += text_node + "\n\n"
|
||||
else:
|
||||
break
|
||||
|
||||
description = description.strip()
|
||||
description = re.sub(r'[\t ]+', ' ', description)
|
||||
description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description)
|
||||
description = re.sub(r"[\t ]+", " ", description)
|
||||
description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
|
||||
|
||||
# If the concatenation of paragraphs to get above MIN_SIZE
|
||||
# took us over MAX_SIZE, then we need to truncate mid paragraph
|
||||
|
@ -715,5 +704,5 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
|
|||
|
||||
# We always add an ellipsis because at the very least
|
||||
# we chopped mid paragraph.
|
||||
description = new_desc.strip() + u"…"
|
||||
description = new_desc.strip() + "…"
|
||||
return description if description else None
|
||||
|
|
|
@ -32,6 +32,7 @@ class StorageProvider(object):
|
|||
"""A storage provider is a service that can store uploaded media and
|
||||
retrieve them.
|
||||
"""
|
||||
|
||||
def store_file(self, path, file_info):
|
||||
"""Store the file described by file_info. The actual contents can be
|
||||
retrieved by reading the file in file_info.upload_path.
|
||||
|
@ -70,6 +71,7 @@ class StorageProviderWrapper(StorageProvider):
|
|||
uploaded, or todo the upload in the backgroud.
|
||||
store_remote (bool): Whether remote media should be uploaded
|
||||
"""
|
||||
|
||||
def __init__(self, backend, store_local, store_synchronous, store_remote):
|
||||
self.backend = backend
|
||||
self.store_local = store_local
|
||||
|
@ -92,6 +94,7 @@ class StorageProviderWrapper(StorageProvider):
|
|||
return self.backend.store_file(path, file_info)
|
||||
except Exception:
|
||||
logger.exception("Error storing file")
|
||||
|
||||
run_in_background(store)
|
||||
return defer.succeed(None)
|
||||
|
||||
|
@ -123,8 +126,7 @@ class FileStorageProviderBackend(StorageProvider):
|
|||
os.makedirs(dirname)
|
||||
|
||||
return logcontext.defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
shutil.copyfile, primary_fname, backup_fname,
|
||||
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
|
||||
)
|
||||
|
||||
def fetch(self, path, file_info):
|
||||
|
|
|
@ -74,19 +74,18 @@ class ThumbnailResource(Resource):
|
|||
else:
|
||||
if self.dynamic_thumbnails:
|
||||
yield self._select_or_generate_remote_thumbnail(
|
||||
request, server_name, media_id,
|
||||
width, height, method, m_type
|
||||
request, server_name, media_id, width, height, method, m_type
|
||||
)
|
||||
else:
|
||||
yield self._respond_remote_thumbnail(
|
||||
request, server_name, media_id,
|
||||
width, height, method, m_type
|
||||
request, server_name, media_id, width, height, method, m_type
|
||||
)
|
||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_local_thumbnail(self, request, media_id, width, height,
|
||||
method, m_type):
|
||||
def _respond_local_thumbnail(
|
||||
self, request, media_id, width, height, method, m_type
|
||||
):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
|
||||
if not media_info:
|
||||
|
@ -105,7 +104,8 @@ class ThumbnailResource(Resource):
|
|||
)
|
||||
|
||||
file_info = FileInfo(
|
||||
server_name=None, file_id=media_id,
|
||||
server_name=None,
|
||||
file_id=media_id,
|
||||
url_cache=media_info["url_cache"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=thumbnail_info["thumbnail_width"],
|
||||
|
@ -124,9 +124,15 @@ class ThumbnailResource(Resource):
|
|||
respond_404(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
|
||||
desired_height, desired_method,
|
||||
desired_type):
|
||||
def _select_or_generate_local_thumbnail(
|
||||
self,
|
||||
request,
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
|
||||
if not media_info:
|
||||
|
@ -146,7 +152,8 @@ class ThumbnailResource(Resource):
|
|||
|
||||
if t_w and t_h and t_method and t_type:
|
||||
file_info = FileInfo(
|
||||
server_name=None, file_id=media_id,
|
||||
server_name=None,
|
||||
file_id=media_id,
|
||||
url_cache=media_info["url_cache"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=info["thumbnail_width"],
|
||||
|
@ -167,7 +174,11 @@ class ThumbnailResource(Resource):
|
|||
|
||||
# Okay, so we generate one.
|
||||
file_path = yield self.media_repo.generate_local_exact_thumbnail(
|
||||
media_id, desired_width, desired_height, desired_method, desired_type,
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
url_cache=media_info["url_cache"],
|
||||
)
|
||||
|
||||
|
@ -178,13 +189,20 @@ class ThumbnailResource(Resource):
|
|||
respond_404(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
|
||||
desired_width, desired_height,
|
||||
desired_method, desired_type):
|
||||
def _select_or_generate_remote_thumbnail(
|
||||
self,
|
||||
request,
|
||||
server_name,
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
):
|
||||
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
|
||||
|
||||
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id,
|
||||
server_name, media_id
|
||||
)
|
||||
|
||||
file_id = media_info["filesystem_id"]
|
||||
|
@ -197,7 +215,8 @@ class ThumbnailResource(Resource):
|
|||
|
||||
if t_w and t_h and t_method and t_type:
|
||||
file_info = FileInfo(
|
||||
server_name=server_name, file_id=media_info["filesystem_id"],
|
||||
server_name=server_name,
|
||||
file_id=media_info["filesystem_id"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=info["thumbnail_width"],
|
||||
thumbnail_height=info["thumbnail_height"],
|
||||
|
@ -217,8 +236,13 @@ class ThumbnailResource(Resource):
|
|||
|
||||
# Okay, so we generate one.
|
||||
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
|
||||
server_name, file_id, media_id, desired_width,
|
||||
desired_height, desired_method, desired_type
|
||||
server_name,
|
||||
file_id,
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
)
|
||||
|
||||
if file_path:
|
||||
|
@ -228,15 +252,16 @@ class ThumbnailResource(Resource):
|
|||
respond_404(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
|
||||
height, method, m_type):
|
||||
def _respond_remote_thumbnail(
|
||||
self, request, server_name, media_id, width, height, method, m_type
|
||||
):
|
||||
# TODO: Don't download the whole remote file
|
||||
# We should proxy the thumbnail from the remote server instead of
|
||||
# downloading the remote file and generating our own thumbnails.
|
||||
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
|
||||
|
||||
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id,
|
||||
server_name, media_id
|
||||
)
|
||||
|
||||
if thumbnail_infos:
|
||||
|
@ -244,7 +269,8 @@ class ThumbnailResource(Resource):
|
|||
width, height, method, m_type, thumbnail_infos
|
||||
)
|
||||
file_info = FileInfo(
|
||||
server_name=server_name, file_id=media_info["filesystem_id"],
|
||||
server_name=server_name,
|
||||
file_id=media_info["filesystem_id"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=thumbnail_info["thumbnail_width"],
|
||||
thumbnail_height=thumbnail_info["thumbnail_height"],
|
||||
|
@ -261,8 +287,14 @@ class ThumbnailResource(Resource):
|
|||
logger.info("Failed to find any generated thumbnails")
|
||||
respond_404(request)
|
||||
|
||||
def _select_thumbnail(self, desired_width, desired_height, desired_method,
|
||||
desired_type, thumbnail_infos):
|
||||
def _select_thumbnail(
|
||||
self,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
desired_type,
|
||||
thumbnail_infos,
|
||||
):
|
||||
d_w = desired_width
|
||||
d_h = desired_height
|
||||
|
||||
|
@ -280,15 +312,27 @@ class ThumbnailResource(Resource):
|
|||
type_quality = desired_type != info["thumbnail_type"]
|
||||
length_quality = info["thumbnail_length"]
|
||||
if t_w >= d_w or t_h >= d_h:
|
||||
info_list.append((
|
||||
aspect_quality, min_quality, size_quality, type_quality,
|
||||
length_quality, info
|
||||
))
|
||||
info_list.append(
|
||||
(
|
||||
aspect_quality,
|
||||
min_quality,
|
||||
size_quality,
|
||||
type_quality,
|
||||
length_quality,
|
||||
info,
|
||||
)
|
||||
)
|
||||
else:
|
||||
info_list2.append((
|
||||
aspect_quality, min_quality, size_quality, type_quality,
|
||||
length_quality, info
|
||||
))
|
||||
info_list2.append(
|
||||
(
|
||||
aspect_quality,
|
||||
min_quality,
|
||||
size_quality,
|
||||
type_quality,
|
||||
length_quality,
|
||||
info,
|
||||
)
|
||||
)
|
||||
if info_list:
|
||||
return min(info_list)[-1]
|
||||
else:
|
||||
|
@ -304,13 +348,11 @@ class ThumbnailResource(Resource):
|
|||
type_quality = desired_type != info["thumbnail_type"]
|
||||
length_quality = info["thumbnail_length"]
|
||||
if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
|
||||
info_list.append((
|
||||
size_quality, type_quality, length_quality, info
|
||||
))
|
||||
info_list.append((size_quality, type_quality, length_quality, info))
|
||||
elif t_method == "scale":
|
||||
info_list2.append((
|
||||
size_quality, type_quality, length_quality, info
|
||||
))
|
||||
info_list2.append(
|
||||
(size_quality, type_quality, length_quality, info)
|
||||
)
|
||||
if info_list:
|
||||
return min(info_list)[-1]
|
||||
else:
|
||||
|
|
|
@ -28,16 +28,13 @@ EXIF_TRANSPOSE_MAPPINGS = {
|
|||
5: Image.TRANSPOSE,
|
||||
6: Image.ROTATE_270,
|
||||
7: Image.TRANSVERSE,
|
||||
8: Image.ROTATE_90
|
||||
8: Image.ROTATE_90,
|
||||
}
|
||||
|
||||
|
||||
class Thumbnailer(object):
|
||||
|
||||
FORMATS = {
|
||||
"image/jpeg": "JPEG",
|
||||
"image/png": "PNG",
|
||||
}
|
||||
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
|
||||
|
||||
def __init__(self, input_path):
|
||||
self.image = Image.open(input_path)
|
||||
|
@ -110,17 +107,13 @@ class Thumbnailer(object):
|
|||
"""
|
||||
if width * self.height > height * self.width:
|
||||
scaled_height = (width * self.height) // self.width
|
||||
scaled_image = self.image.resize(
|
||||
(width, scaled_height), Image.ANTIALIAS
|
||||
)
|
||||
scaled_image = self.image.resize((width, scaled_height), Image.ANTIALIAS)
|
||||
crop_top = (scaled_height - height) // 2
|
||||
crop_bottom = height + crop_top
|
||||
cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
|
||||
else:
|
||||
scaled_width = (height * self.width) // self.height
|
||||
scaled_image = self.image.resize(
|
||||
(scaled_width, height), Image.ANTIALIAS
|
||||
)
|
||||
scaled_image = self.image.resize((scaled_width, height), Image.ANTIALIAS)
|
||||
crop_left = (scaled_width - width) // 2
|
||||
crop_right = width + crop_left
|
||||
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
|
||||
|
|
|
@ -55,48 +55,36 @@ class UploadResource(Resource):
|
|||
requester = yield self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
content_length = request.getHeader(b"Content-Length").decode('ascii')
|
||||
content_length = request.getHeader(b"Content-Length").decode("ascii")
|
||||
if content_length is None:
|
||||
raise SynapseError(
|
||||
msg="Request must specify a Content-Length", code=400
|
||||
)
|
||||
raise SynapseError(msg="Request must specify a Content-Length", code=400)
|
||||
if int(content_length) > self.max_upload_size:
|
||||
raise SynapseError(
|
||||
msg="Upload request body is too large",
|
||||
code=413,
|
||||
)
|
||||
raise SynapseError(msg="Upload request body is too large", code=413)
|
||||
|
||||
upload_name = parse_string(request, b"filename", encoding=None)
|
||||
if upload_name:
|
||||
try:
|
||||
upload_name = upload_name.decode('utf8')
|
||||
upload_name = upload_name.decode("utf8")
|
||||
except UnicodeDecodeError:
|
||||
raise SynapseError(
|
||||
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
|
||||
code=400,
|
||||
msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
|
||||
)
|
||||
|
||||
headers = request.requestHeaders
|
||||
|
||||
if headers.hasHeader(b"Content-Type"):
|
||||
media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii')
|
||||
media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii")
|
||||
else:
|
||||
raise SynapseError(
|
||||
msg="Upload request missing 'Content-Type'",
|
||||
code=400,
|
||||
)
|
||||
raise SynapseError(msg="Upload request missing 'Content-Type'", code=400)
|
||||
|
||||
# if headers.hasHeader(b"Content-Disposition"):
|
||||
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
|
||||
# TODO(markjh): parse content-dispostion
|
||||
|
||||
content_uri = yield self.media_repo.create_content(
|
||||
media_type, upload_name, request.content,
|
||||
content_length, requester.user
|
||||
media_type, upload_name, request.content, content_length, requester.user
|
||||
)
|
||||
|
||||
logger.info("Uploaded content with URI %r", content_uri)
|
||||
|
||||
respond_with_json(
|
||||
request, 200, {"content_uri": content_uri}, send_cors=True
|
||||
)
|
||||
respond_with_json(request, 200, {"content_uri": content_uri}, send_cors=True)
|
||||
|
|
|
@ -30,7 +30,7 @@ class SAML2MetadataResource(Resource):
|
|||
|
||||
def render_GET(self, request):
|
||||
metadata_xml = saml2.metadata.create_metadata_string(
|
||||
configfile=None, config=self.sp_config,
|
||||
configfile=None, config=self.sp_config
|
||||
)
|
||||
request.setHeader(b"Content-Type", b"text/xml; charset=utf-8")
|
||||
return metadata_xml
|
||||
|
|
|
@ -44,18 +44,16 @@ class SAML2ResponseResource(Resource):
|
|||
|
||||
@wrap_html_request_handler
|
||||
def _async_render_POST(self, request):
|
||||
resp_bytes = parse_string(request, 'SAMLResponse', required=True)
|
||||
relay_state = parse_string(request, 'RelayState', required=True)
|
||||
resp_bytes = parse_string(request, "SAMLResponse", required=True)
|
||||
relay_state = parse_string(request, "RelayState", required=True)
|
||||
|
||||
try:
|
||||
saml2_auth = self._saml_client.parse_authn_request_response(
|
||||
resp_bytes, saml2.BINDING_HTTP_POST,
|
||||
resp_bytes, saml2.BINDING_HTTP_POST
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Exception parsing SAML2 response", exc_info=1)
|
||||
raise CodeMessageException(
|
||||
400, "Unable to parse SAML2 response: %s" % (e,),
|
||||
)
|
||||
raise CodeMessageException(400, "Unable to parse SAML2 response: %s" % (e,))
|
||||
|
||||
if saml2_auth.not_signed:
|
||||
raise CodeMessageException(400, "SAML2 response was not signed")
|
||||
|
@ -67,6 +65,5 @@ class SAML2ResponseResource(Resource):
|
|||
|
||||
displayName = saml2_auth.ava.get("displayName", [None])[0]
|
||||
return self._sso_auth_handler.on_successful_auth(
|
||||
username, request, relay_state,
|
||||
user_display_name=displayName,
|
||||
username, request, relay_state, user_display_name=displayName
|
||||
)
|
||||
|
|
|
@ -29,6 +29,7 @@ class WellKnownBuilder(object):
|
|||
Args:
|
||||
hs (synapse.server.HomeServer):
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self._config = hs.config
|
||||
|
||||
|
@ -37,15 +38,11 @@ class WellKnownBuilder(object):
|
|||
if self._config.public_baseurl is None:
|
||||
return None
|
||||
|
||||
result = {
|
||||
"m.homeserver": {
|
||||
"base_url": self._config.public_baseurl,
|
||||
},
|
||||
}
|
||||
result = {"m.homeserver": {"base_url": self._config.public_baseurl}}
|
||||
|
||||
if self._config.default_identity_server:
|
||||
result["m.identity_server"] = {
|
||||
"base_url": self._config.default_identity_server,
|
||||
"base_url": self._config.default_identity_server
|
||||
}
|
||||
|
||||
return result
|
||||
|
@ -66,7 +63,7 @@ class WellKnownResource(Resource):
|
|||
if not r:
|
||||
request.setResponseCode(404)
|
||||
request.setHeader(b"Content-Type", b"text/plain")
|
||||
return b'.well-known not available'
|
||||
return b".well-known not available"
|
||||
|
||||
logger.debug("returning: %s", r)
|
||||
request.setHeader(b"Content-Type", b"application/json")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue