Merge branch 'develop' into jaywink/admin-forward-extremities

This commit is contained in:
Jason Robinson 2021-01-23 21:41:35 +02:00
commit 8965b6cfec
209 changed files with 7827 additions and 2430 deletions

View file

@ -15,6 +15,9 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.http import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
assert_requester_is_admin,
assert_user_is_admin,
)
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, room_id: str):
async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, user_id: str):
async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, server_name: str, media_id: str):
async def on_POST(
self, request: Request, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
return 200, {}
class ProtectMediaByID(RestServlet):
"""Protect local media from being quarantined.
"""
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
logging.info("Protecting local media by ID: %s", media_id)
# Quarantine this media id
await self.store.mark_local_media_as_safe(media_id)
return 200, {}
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_GET(self, request, room_id):
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_media_cache")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
async def on_POST(self, request):
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
async def on_DELETE(self, request, server_name: str, media_id: str):
async def on_DELETE(
self, request: Request, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if self.server_name != server_name:
@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
async def on_POST(self, request, server_name: str):
async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
return 200, {"deleted_media": deleted_media, "total": total}
def register_servlets_for_media_repo(hs, http_server):
def register_servlets_for_media_repo(hs: "HomeServer", http_server):
"""
Media repo specific APIs.
"""
@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
QuarantineMediaInRoom(hs).register(http_server)
QuarantineMediaByID(hs).register(http_server)
QuarantineMediaByUser(hs).register(http_server)
ProtectMediaByID(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
DeleteMediaByID(hs).register(http_server)
DeleteMediaByDateSize(hs).register(http_server)

View file

@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet):
The parameter `deactivated` can be used to include deactivated users.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
if start < 0:
raise SynapseError(
400,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
400,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
user_id = parse_string(request, "user_id", default=None)
name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet):
start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
if len(users) >= limit:
if (start + limit) < total:
ret["next_token"] = str(start + len(users))
return 200, ret
@ -244,7 +259,7 @@ class UserRestServletV2(RestServlet):
if deactivate and not user["deactivated"]:
await self.deactivate_account_handler.deactivate_account(
target_user.to_string(), False
target_user.to_string(), False, requester, by_admin=True
)
elif not deactivate and user["deactivated"]:
if "password" not in body:
@ -486,12 +501,22 @@ class WhoisRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
self.auth = hs.get_auth()
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
if not self.is_mine(UserID.from_string(target_user_id)):
raise SynapseError(400, "Can only deactivate local users")
if not await self.store.get_user_by_id(target_user_id):
raise NotFoundError("User not found")
async def on_POST(self, request, target_user_id):
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True)
erase = body.get("erase", False)
if not isinstance(erase, bool):
@ -501,10 +526,8 @@ class DeactivateAccountRestServlet(RestServlet):
Codes.BAD_JSON,
)
UserID.from_string(target_user_id)
result = await self._deactivate_account_handler.deactivate_account(
target_user_id, erase
target_user_id, erase, requester, by_admin=True
)
if result:
id_server_unbind_result = "success"
@ -714,13 +737,6 @@ class UserMembershipRestServlet(RestServlet):
async def on_GET(self, request, user_id):
await assert_requester_is_admin(self.auth, request)
if not self.is_mine(UserID.from_string(user_id)):
raise SynapseError(400, "Can only lookup local users")
user = await self.store.get_user_by_id(user_id)
if user is None:
raise NotFoundError("Unknown user")
room_ids = await self.store.get_rooms_for_user(user_id)
ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
return 200, ret

View file

@ -46,7 +46,7 @@ from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from synapse.util import json_decoder
from synapse.util.stringutils import random_string
from synapse.util.stringutils import parse_and_validate_server_name, random_string
if TYPE_CHECKING:
import synapse.server
@ -347,8 +347,6 @@ class PublicRoomListRestServlet(TransactionRestServlet):
# provided.
if server:
raise e
else:
pass
limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None)
@ -359,6 +357,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server_name:
# Ensure the server is valid.
try:
parse_and_validate_server_name(server)
except ValueError:
raise SynapseError(
400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
)
try:
data = await handler.get_remote_public_room_list(
server, limit=limit, since_token=since_token
@ -402,6 +408,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server_name:
# Ensure the server is valid.
try:
parse_and_validate_server_name(server)
except ValueError:
raise SynapseError(
400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
)
try:
data = await handler.get_remote_public_room_list(
server,

View file

@ -20,9 +20,6 @@ from http import HTTPStatus
from typing import TYPE_CHECKING
from urllib.parse import urlparse
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
@ -31,6 +28,7 @@ from synapse.api.errors import (
ThreepidValidationError,
)
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
@ -46,6 +44,10 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -189,11 +191,7 @@ class PasswordRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
try:
params, session_id = await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"modify your account password",
requester, request, body, "modify your account password",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth, but
@ -204,7 +202,9 @@ class PasswordRestServlet(RestServlet):
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
)
raise
user_id = requester.user.to_string()
@ -215,7 +215,6 @@ class PasswordRestServlet(RestServlet):
[[LoginType.EMAIL_IDENTITY]],
request,
body,
self.hs.get_ip_from_request(request),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
@ -227,7 +226,9 @@ class PasswordRestServlet(RestServlet):
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
)
raise
@ -260,7 +261,7 @@ class PasswordRestServlet(RestServlet):
password_hash = await self.auth_handler.hash(new_password)
elif session_id is not None:
password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None
session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
else:
# UI validation was skipped, but the request did not include a new
@ -304,19 +305,18 @@ class DeactivateAccountRestServlet(RestServlet):
# allow ASes to deactivate their own users
if requester.app_service:
await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase
requester.user.to_string(), erase, requester
)
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"deactivate your account",
requester, request, body, "deactivate your account",
)
result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase, id_server=body.get("id_server")
requester.user.to_string(),
erase,
requester,
id_server=body.get("id_server"),
)
if result:
id_server_unbind_result = "success"
@ -695,11 +695,7 @@ class ThreepidAddRestServlet(RestServlet):
assert_valid_client_secret(client_secret)
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"add a third-party identifier to your account",
requester, request, body, "add a third-party identifier to your account",
)
validation_session = await self.identity_handler.validate_threepid_session(

View file

@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, account_data_type):
if self._is_worker:
raise Exception("Cannot handle PUT /account_data on worker")
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
body = parse_json_object_from_request(request)
max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
await self.handler.add_account_data_for_user(user_id, account_data_type, body)
return 200, {}
@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, account_data_type):
if self._is_worker:
raise Exception("Cannot handle PUT /account_data on worker")
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
" Use /rooms/!roomId:server.name/read_markers",
)
max_id = await self.store.add_account_data_to_room(
await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {}
async def on_GET(self, request, user_id, room_id, account_data_type):

View file

@ -19,7 +19,6 @@ from typing import TYPE_CHECKING
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.handlers.sso import SsoIdentityProvider
from synapse.http.server import respond_with_html
from synapse.http.servlet import RestServlet, parse_string
@ -46,22 +45,6 @@ class AuthRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
# SSO configuration.
self._cas_enabled = hs.config.cas_enabled
if self._cas_enabled:
self._cas_handler = hs.get_cas_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
self._saml_enabled = hs.config.saml2_enabled
if self._saml_enabled:
self._saml_handler = hs.get_saml_handler()
self._oidc_enabled = hs.config.oidc_enabled
if self._oidc_enabled:
self._oidc_handler = hs.get_oidc_handler()
self._cas_server_url = hs.config.cas_server_url
self._cas_service_url = hs.config.cas_service_url
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
self.success_template = hs.config.fallback_success_template
@ -90,21 +73,7 @@ class AuthRestServlet(RestServlet):
elif stagetype == LoginType.SSO:
# Display a confirmation page which prompts the user to
# re-authenticate with their SSO provider.
if self._cas_enabled:
sso_auth_provider = self._cas_handler # type: SsoIdentityProvider
elif self._saml_enabled:
sso_auth_provider = self._saml_handler
elif self._oidc_enabled:
sso_auth_provider = self._oidc_handler
else:
raise SynapseError(400, "Homeserver not configured for SSO.")
sso_redirect_url = await sso_auth_provider.handle_redirect_request(
request, None, session
)
html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
html = await self.auth_handler.start_sso_ui_auth(request, session)
else:
raise SynapseError(404, "Unknown auth stage type")
@ -128,7 +97,7 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
success = await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
LoginType.RECAPTCHA, authdict, request.getClientIP()
)
if success:
@ -144,7 +113,7 @@ class AuthRestServlet(RestServlet):
authdict = {"session": session}
success = await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
LoginType.TERMS, authdict, request.getClientIP()
)
if success:

View file

@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"remove device(s) from your account",
requester, request, body, "remove device(s) from your account",
)
await self.device_handler.delete_devices(
@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet):
raise
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"remove a device from your account",
requester, request, body, "remove a device from your account",
)
await self.device_handler.delete_device(requester.user.to_string(), device_id)

View file

@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
self.hs.get_ip_from_request(request),
"add a device signing key to your account",
requester, request, body, "add a device signing key to your account",
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)

View file

@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
@ -353,7 +354,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
ip = self.hs.get_ip_from_request(request)
ip = request.getClientIP()
with self.ratelimiter.ratelimit(ip) as wait_deferred:
await wait_deferred
@ -494,11 +495,11 @@ class RegisterRestServlet(RestServlet):
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = await self.auth_handler.get_session_data(
session_id, "registered_user_id", None
session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
)
# Extract the previously-hashed password from the session.
password_hash = await self.auth_handler.get_session_data(
session_id, "password_hash", None
session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
# Ensure that the username is valid.
@ -513,11 +514,7 @@ class RegisterRestServlet(RestServlet):
# not this will raise a user-interactive auth error.
try:
auth_result, params, session_id = await self.auth_handler.check_ui_auth(
self._registration_flows,
request,
body,
self.hs.get_ip_from_request(request),
"register a new account",
self._registration_flows, request, body, "register a new account",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth.
@ -532,7 +529,9 @@ class RegisterRestServlet(RestServlet):
if not password_hash and password:
password_hash = await self.auth_handler.hash(password)
await self.auth_handler.set_session_data(
e.session_id, "password_hash", password_hash
e.session_id,
UIAuthSessionDataConstants.PASSWORD_HASH,
password_hash,
)
raise
@ -633,7 +632,9 @@ class RegisterRestServlet(RestServlet):
# Remember that the user account has been registered (and the user
# ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
session_id, "registered_user_id", registered_user_id
session_id,
UIAuthSessionDataConstants.REGISTERED_USER_ID,
registered_user_id,
)
registered = True

View file

@ -58,8 +58,7 @@ class TagServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, tag):
requester = await self.auth.get_user_by_req(request)
@ -68,9 +67,7 @@ class TagServlet(RestServlet):
body = parse_json_object_from_request(request)
max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
await self.handler.add_tag_to_room(user_id, room_id, tag, body)
return 200, {}
@ -79,9 +76,7 @@ class TagServlet(RestServlet):
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
await self.handler.remove_tag_from_room(user_id, room_id, tag)
return 200, {}

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -17,10 +17,11 @@
import logging
import os
import urllib
from typing import Awaitable
from typing import Awaitable, Dict, Generator, List, Optional, Tuple
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
]
def parse_media_id(request):
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
@ -69,7 +70,7 @@ def parse_media_id(request):
)
def respond_404(request):
def respond_404(request: Request) -> None:
respond_with_json(
request,
404,
@ -79,8 +80,12 @@ def respond_404(request):
async def respond_with_file(
request, media_type, file_path, file_size=None, upload_name=None
):
request: Request,
media_type: str,
file_path: str,
file_size: Optional[int] = None,
upload_name: Optional[str] = None,
) -> None:
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
@ -98,15 +103,20 @@ async def respond_with_file(
respond_404(request)
def add_file_headers(request, media_type, file_size, upload_name):
def add_file_headers(
request: Request,
media_type: str,
file_size: Optional[int],
upload_name: Optional[str],
) -> None:
"""Adds the correct response headers in preparation for responding with the
media.
Args:
request (twisted.web.http.Request)
media_type (str): The media/content type.
file_size (int): Size in bytes of the media, if known.
upload_name (str): The name of the requested file, if any.
request
media_type: The media/content type.
file_size: Size in bytes of the media, if known.
upload_name: The name of the requested file, if any.
"""
def _quote(x):
@ -153,7 +163,8 @@ def add_file_headers(request, media_type, file_size, upload_name):
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
request.setHeader(b"Content-Length", b"%d" % (file_size,))
if file_size is not None:
request.setHeader(b"Content-Length", b"%d" % (file_size,))
# Tell web crawlers to not index, archive, or follow links in media. This
# should help to prevent things in the media repo from showing up in web
@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
}
def _can_encode_filename_as_token(x):
def _can_encode_filename_as_token(x: str) -> bool:
for c in x:
# from RFC2616:
#
@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x):
async def respond_with_responder(
request, responder, media_type, file_size, upload_name=None
):
request: Request,
responder: "Optional[Responder]",
media_type: str,
file_size: Optional[int],
upload_name: Optional[str] = None,
) -> None:
"""Responds to the request with given responder. If responder is None then
returns 404.
Args:
request (twisted.web.http.Request)
responder (Responder|None)
media_type (str): The media/content type.
file_size (int|None): Size in bytes of the media. If not known it should be None
upload_name (str|None): The name of the requested file, if any.
request
responder
media_type: The media/content type.
file_size: Size in bytes of the media. If not known it should be None
upload_name: The name of the requested file, if any.
"""
if request._disconnected:
logger.warning(
@ -285,6 +300,7 @@ class FileInfo:
thumbnail_height (int)
thumbnail_method (str)
thumbnail_type (str): Content type of thumbnail, e.g. image/png
thumbnail_length (int): The size of the media file, in bytes.
"""
def __init__(
@ -297,6 +313,7 @@ class FileInfo:
thumbnail_height=None,
thumbnail_method=None,
thumbnail_type=None,
thumbnail_length=None,
):
self.server_name = server_name
self.file_id = file_id
@ -306,24 +323,25 @@ class FileInfo:
self.thumbnail_height = thumbnail_height
self.thumbnail_method = thumbnail_method
self.thumbnail_type = thumbnail_type
self.thumbnail_length = thumbnail_length
def get_filename_from_headers(headers):
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
"""
Get the filename of the downloaded file by inspecting the
Content-Disposition HTTP header.
Args:
headers (dict[bytes, list[bytes]]): The HTTP request headers.
headers: The HTTP request headers.
Returns:
A Unicode string of the filename, or None.
The filename, or None.
"""
content_disposition = headers.get(b"Content-Disposition", [b""])
# No header, bail out.
if not content_disposition[0]:
return
return None
_, params = _parse_header(content_disposition[0])
@ -356,17 +374,16 @@ def get_filename_from_headers(headers):
return upload_name
def _parse_header(line):
def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
"""Parse a Content-type like header.
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
line (bytes): header to be parsed
line: header to be parsed
Returns:
Tuple[bytes, dict[bytes, bytes]]:
the main content-type, followed by the parameter dictionary
The main content-type, followed by the parameter dictionary
"""
parts = _parseparam(b";" + line)
key = next(parts)
@ -386,16 +403,16 @@ def _parse_header(line):
return key, pdict
def _parseparam(s):
def _parseparam(s: bytes) -> Generator[bytes, None, None]:
"""Generator which splits the input on ;, respecting double-quoted sequences
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
s (bytes): header to be parsed
s: header to be parsed
Returns:
Iterable[bytes]: the split input
The split input
"""
while s[:1] == b";":
s = s[1:]

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 Will Hunt <will@half-shot.uk>
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,22 +15,29 @@
# limitations under the License.
#
from typing import TYPE_CHECKING
from twisted.web.http import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class MediaConfigResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
config = hs.get_config()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
async def _async_render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,24 +14,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
import synapse.http.servlet
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean
from ._base import parse_media_id, respond_404
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__)
class DownloadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self.server_name = hs.hostname
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
request.setHeader(
b"Content-Security-Policy",
@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)
else:
allow_remote = synapse.http.servlet.parse_boolean(
request, "allow_remote", default=True
)
allow_remote = parse_boolean(request, "allow_remote", default=True)
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -16,11 +17,12 @@
import functools
import os
import re
from typing import Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
def _wrap_in_base_path(func):
def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
"""Takes a function that returns a relative path and turns it into an
absolute path based on the location of the primary media store
"""
@ -41,12 +43,18 @@ class MediaFilePaths:
to write to the backup media store (when one is configured)
"""
def __init__(self, primary_base_path):
def __init__(self, primary_base_path: str):
self.base_path = primary_base_path
def default_thumbnail_rel(
self, default_top_level, default_sub_type, width, height, content_type, method
):
self,
default_top_level: str,
default_sub_type: str,
width: int,
height: int,
content_type: str,
method: str,
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@ -55,12 +63,14 @@ class MediaFilePaths:
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
def local_media_filepath_rel(self, media_id):
def local_media_filepath_rel(self, media_id: str) -> str:
return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
def local_media_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@ -86,7 +96,7 @@ class MediaFilePaths:
media_id[4:],
)
def remote_media_filepath_rel(self, server_name, file_id):
def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
return os.path.join(
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
)
@ -94,8 +104,14 @@ class MediaFilePaths:
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
def remote_media_thumbnail_rel(
self, server_name, file_id, width, height, content_type, method
):
self,
server_name: str,
file_id: str,
width: int,
height: int,
content_type: str,
method: str,
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@ -113,7 +129,7 @@ class MediaFilePaths:
# Should be removed after some time, when most of the thumbnails are stored
# using the new path.
def remote_media_thumbnail_rel_legacy(
self, server_name, file_id, width, height, content_type
self, server_name: str, file_id: str, width: int, height: int, content_type: str
):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@ -126,7 +142,7 @@ class MediaFilePaths:
file_name,
)
def remote_media_thumbnail_dir(self, server_name, file_id):
def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join(
self.base_path,
"remote_thumbnail",
@ -136,7 +152,7 @@ class MediaFilePaths:
file_id[4:],
)
def url_cache_filepath_rel(self, media_id):
def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -146,7 +162,7 @@ class MediaFilePaths:
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
def url_cache_filepath_dirs_to_delete(self, media_id):
def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@ -156,7 +172,9 @@ class MediaFilePaths:
os.path.join(self.base_path, "url_cache", media_id[0:2]),
]
def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
def url_cache_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -178,7 +196,7 @@ class MediaFilePaths:
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
def url_cache_thumbnail_directory(self, media_id):
def url_cache_thumbnail_directory(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -195,7 +213,7 @@ class MediaFilePaths:
media_id[4:],
)
def url_cache_thumbnail_dirs_to_delete(self, media_id):
def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,12 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import errno
import logging
import os
import shutil
from typing import IO, Dict, List, Optional, Tuple
from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
from .thumbnailer import Thumbnailer, ThumbnailError
from .upload_resource import UploadResource
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.client = hs.get_federation_http_client()
@ -73,16 +76,16 @@ class MediaRepository:
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
self.primary_base_path = hs.config.media_store_path
self.filepaths = MediaFilePaths(self.primary_base_path)
self.primary_base_path = hs.config.media_store_path # type: str
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
self.recently_accessed_locals = set()
self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
self.recently_accessed_locals = set() # type: Set[str]
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@ -113,7 +116,7 @@ class MediaRepository:
"update_recently_accessed_media", self._update_recently_accessed
)
async def _update_recently_accessed(self):
async def _update_recently_accessed(self) -> None:
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
@ -124,12 +127,12 @@ class MediaRepository:
local_media, remote_media, self.clock.time_msec()
)
def mark_recently_accessed(self, server_name, media_id):
def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
"""Mark the given media as recently accessed.
Args:
server_name (str|None): Origin server of media, or None if local
media_id (str): The media ID of the content
server_name: Origin server of media, or None if local
media_id: The media ID of the content
"""
if server_name:
self.recently_accessed_remotes.add((server_name, media_id))
@ -459,7 +462,14 @@ class MediaRepository:
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
def _generate_thumbnail(
self,
thumbnailer: Thumbnailer,
t_width: int,
t_height: int,
t_method: str,
t_type: str,
) -> Optional[BytesIO]:
m_width = thumbnailer.width
m_height = thumbnailer.height
@ -470,22 +480,20 @@ class MediaRepository:
m_height,
self.max_image_pixels,
)
return
return None
if thumbnailer.transpose_method is not None:
m_width, m_height = thumbnailer.transpose()
if t_method == "crop":
t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
return thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
t_width, t_height = thumbnailer.aspect(t_width, t_height)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
else:
t_byte_source = None
return thumbnailer.scale(t_width, t_height, t_type)
return t_byte_source
return None
async def generate_local_exact_thumbnail(
self,
@ -776,7 +784,7 @@ class MediaRepository:
return {"width": m_width, "height": m_height}
async def delete_old_remote_media(self, before_ts):
async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0
@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
within a given rectangle.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
# If we're not configured to use it, raise if we somehow got here.
if not hs.config.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.")

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vecotr Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -18,6 +18,8 @@ import os
import shutil
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@ -270,7 +272,7 @@ class MediaStorage:
return self.filepaths.local_media_filepath_rel(file_info.file_id)
def _write_file_synchronously(source, dest):
def _write_file_synchronously(source: IO, dest: IO) -> None:
"""Write `source` to the file like `dest` synchronously. Should be called
from a thread.
@ -286,14 +288,14 @@ class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.
Args:
open_file (file): A file like object to be streamed ot the client,
open_file: A file like object to be streamed ot the client,
is closed when finished streaming.
"""
def __init__(self, open_file):
def __init__(self, open_file: IO):
self.open_file = open_file
def write_to_consumer(self, consumer):
def write_to_consumer(self, consumer: IConsumer) -> Deferred:
return make_deferred_yieldable(
FileSender().beginFileTransfer(self.open_file, consumer)
)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import errno
import fnmatch
@ -23,12 +23,13 @@ import re
import shutil
import sys
import traceback
from typing import Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
from urllib import parse as urlparse
import attr
from twisted.internet.error import DNSLookupError
from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string
from ._base import FileInfo
if TYPE_CHECKING:
from lxml import etree
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__)
_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
@ -119,7 +127,12 @@ class OEmbedError(Exception):
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__()
self.auth = hs.get_auth()
@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000
)
async def _async_render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request: Request) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
async def _download_url(self, url: str, user):
async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"expire_url_cache_data", self._expire_url_cache_data
)
async def _expire_url_cache_data(self):
async def _expire_url_cache_data(self) -> None:
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
@ -676,7 +689,9 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]:
# If there's no body, nothing useful is going to be found.
if not body:
return {}
@ -697,7 +712,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]
return og
def _calc_og(tree, media_uri):
def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
@ -801,7 +816,9 @@ def _calc_og(tree, media_uri):
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og["og:description"] = summarize_paragraphs(text_nodes)
else:
elif og["og:description"]:
# This must be a non-empty string at this point.
assert isinstance(og["og:description"], str)
og["og:description"] = summarize_paragraphs([og["og:description"]])
# TODO: delete the url downloads to stop diskfilling,
@ -809,7 +826,9 @@ def _calc_og(tree, media_uri):
return og
def _iterate_over_text(tree, *tags_to_ignore):
def _iterate_over_text(
tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
@ -843,32 +862,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
)
def _rebase_url(url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith("/"):
url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
return urlparse.urlunparse(url)
def _rebase_url(url: str, base: str) -> str:
base_parts = list(urlparse.urlparse(base))
url_parts = list(urlparse.urlparse(url))
if not url_parts[0]: # fix up schema
url_parts[0] = base_parts[0] or "http"
if not url_parts[1]: # fix up hostname
url_parts[1] = base_parts[1]
if not url_parts[2].startswith("/"):
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
return urlparse.urlunparse(url_parts)
def _is_media(content_type):
if content_type.lower().startswith("image/"):
return True
def _is_media(content_type: str) -> bool:
return content_type.lower().startswith("image/")
def _is_html(content_type):
def _is_html(content_type: str) -> bool:
content_type = content_type.lower()
if content_type.startswith("text/html") or content_type.startswith(
return content_type.startswith("text/html") or content_type.startswith(
"application/xhtml"
):
return True
)
def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
# Try to get a summary of between 200 and 500 words, respecting
# first paragraph and then word boundaries.
# TODO: Respect sentences?

View file

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
import os
import shutil
from typing import Optional
from typing import TYPE_CHECKING, Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
@ -27,13 +28,17 @@ from .media_storage import FileResponder
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class StorageProvider:
class StorageProvider(metaclass=abc.ABCMeta):
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
async def store_file(self, path: str, file_info: FileInfo):
@abc.abstractmethod
async def store_file(self, path: str, file_info: FileInfo) -> None:
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
@ -42,6 +47,7 @@ class StorageProvider:
file_info: The metadata of the file.
"""
@abc.abstractmethod
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
self.store_synchronous = store_synchronous
self.store_remote = store_remote
def __str__(self):
def __str__(self) -> str:
return "StorageProviderWrapper[%s]" % (self.backend,)
async def store_file(self, path, file_info):
async def store_file(self, path: str, file_info: FileInfo) -> None:
if not file_info.server_name and not self.store_local:
return None
@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
return await maybe_awaitable(self.backend.store_file(path, file_info))
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
async def store():
@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider):
logger.exception("Error storing file")
run_in_background(store)
return None
async def fetch(self, path, file_info):
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
return await maybe_awaitable(self.backend.fetch(path, file_info))
@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider):
"""A storage provider that stores files in a directory on a filesystem.
Args:
hs (HomeServer)
hs
config: The config returned by `parse_config`.
"""
def __init__(self, hs, config):
def __init__(self, hs: "HomeServer", config: str):
self.hs = hs
self.cache_directory = hs.config.media_store_path
self.base_directory = config
@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
async def store_file(self, path, file_info):
async def store_file(self, path: str, file_info: FileInfo) -> None:
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
return await defer_to_thread(
await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
async def fetch(self, path, file_info):
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname):
return FileResponder(open(backup_fname, "rb"))
return None
@staticmethod
def parse_config(config):
def parse_config(config: dict) -> str:
"""Called on startup to parse config supplied. This should parse
the config and raise if there is a problem.

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -15,10 +16,14 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from twisted.web.http import Request
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
FileInfo,
@ -28,13 +33,22 @@ from ._base import (
respond_with_responder,
)
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__)
class ThumbnailResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__()
self.store = hs.get_datastore()
@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
self.media_repo.mark_recently_accessed(server_name, media_id)
async def _respond_local_thumbnail(
self, request, media_id, width, height, method, m_type
):
self,
request: Request,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
) -> None:
media_info = await self.store.get_local_media(media_id)
if not media_info:
@ -86,41 +106,27 @@ class ThumbnailResource(DirectServeJsonResource):
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
file_info = FileInfo(
server_name=None,
file_id=media_id,
url_cache=media_info["url_cache"],
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
)
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Couldn't find any generated thumbnails")
respond_404(request)
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_id,
url_cache=media_info["url_cache"],
server_name=None,
)
async def _select_or_generate_local_thumbnail(
self,
request,
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
):
request: Request,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
) -> None:
media_info = await self.store.get_local_media(media_id)
if not media_info:
@ -178,14 +184,14 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_remote_thumbnail(
self,
request,
server_name,
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
):
request: Request,
server_name: str,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
) -> None:
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = await self.store.get_remote_media_thumbnails(
@ -239,8 +245,15 @@ class ThumbnailResource(DirectServeJsonResource):
raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
):
self,
request: Request,
server_name: str,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
@ -249,97 +262,185 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_info["filesystem_id"],
url_cache=None,
server_name=server_name,
)
async def _select_and_respond_with_thumbnail(
self,
request: Request,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[Dict[str, Any]],
file_id: str,
url_cache: Optional[str] = None,
server_name: Optional[str] = None,
) -> None:
"""
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
Args:
request: The incoming request.
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: The URL cache value.
server_name: The server name, if this is a remote thumbnail.
"""
if thumbnail_infos:
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
file_info = self._select_thumbnail(
desired_width,
desired_height,
desired_method,
desired_type,
thumbnail_infos,
file_id,
url_cache,
server_name,
)
file_info = FileInfo(
server_name=server_name,
file_id=media_info["filesystem_id"],
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
)
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
if not file_info:
logger.info("Couldn't find a thumbnail matching the desired inputs")
respond_404(request)
return
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(request, responder, t_type, t_length)
await respond_with_responder(
request, responder, file_info.thumbnail_type, file_info.thumbnail_length
)
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)
def _select_thumbnail(
self,
desired_width,
desired_height,
desired_method,
desired_type,
thumbnail_infos,
):
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[Dict[str, Any]],
file_id: str,
url_cache: Optional[str],
server_name: Optional[str],
) -> Optional[FileInfo]:
"""
Choose an appropriate thumbnail from the previously generated thumbnails.
Args:
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: The URL cache value.
server_name: The server name, if this is a remote thumbnail.
Returns:
The thumbnail which best matches the desired parameters.
"""
desired_method = desired_method.lower()
# The chosen thumbnail.
thumbnail_info = None
d_w = desired_width
d_h = desired_height
if desired_method.lower() == "crop":
if desired_method == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list = []
# Other thumbnails.
crop_info_list2 = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info["thumbnail_method"] != "crop":
continue
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
t_method = info["thumbnail_method"]
if t_method == "crop":
aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h:
crop_info_list.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
else:
crop_info_list2.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
if crop_info_list:
return min(crop_info_list)[-1]
else:
return min(crop_info_list2)[-1]
else:
info_list = []
info_list2 = []
for info in thumbnail_infos:
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
t_method = info["thumbnail_method"]
aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
if t_w >= d_w or t_h >= d_h:
crop_info_list.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
else:
crop_info_list2.append(
(
aspect_quality,
min_quality,
size_quality,
type_quality,
length_quality,
info,
)
)
if crop_info_list:
thumbnail_info = min(crop_info_list)[-1]
elif crop_info_list2:
thumbnail_info = min(crop_info_list2)[-1]
elif desired_method == "scale":
# Thumbnails that match equal or larger sizes of desired width/height.
info_list = []
# Other thumbnails.
info_list2 = []
for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info["thumbnail_method"] != "scale":
continue
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info))
elif t_method == "scale":
else:
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
if info_list:
return min(info_list)[-1]
else:
return min(info_list2)[-1]
thumbnail_info = min(info_list)[-1]
elif info_list2:
thumbnail_info = min(info_list2)[-1]
if thumbnail_info:
return FileInfo(
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
thumbnail_length=thumbnail_info["thumbnail_length"],
)
# No matching thumbnail was found.
return None

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,6 +15,7 @@
# limitations under the License.
import logging
from io import BytesIO
from typing import Tuple
from PIL import Image
@ -39,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
def __init__(self, input_path):
def __init__(self, input_path: str):
try:
self.image = Image.open(input_path)
except OSError as e:
@ -59,11 +61,11 @@ class Thumbnailer:
# A lot of parsing errors can happen when parsing EXIF
logger.info("Error parsing image EXIF information: %s", e)
def transpose(self):
def transpose(self) -> Tuple[int, int]:
"""Transpose the image using its EXIF Orientation tag
Returns:
Tuple[int, int]: (width, height) containing the new image size in pixels.
A tuple containing the new image size in pixels as (width, height).
"""
if self.transpose_method is not None:
self.image = self.image.transpose(self.transpose_method)
@ -73,7 +75,7 @@ class Thumbnailer:
self.image.info["exif"] = None
return self.image.size
def aspect(self, max_width, max_height):
def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
"""Calculate the largest size that preserves aspect ratio which
fits within the given rectangle::
@ -91,7 +93,7 @@ class Thumbnailer:
else:
return (max_height * self.width) // self.height, max_height
def _resize(self, width, height):
def _resize(self, width: int, height: int) -> Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
# looks awful
@ -99,7 +101,7 @@ class Thumbnailer:
self.image = self.image.convert("RGB")
return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width, height, output_type):
def scale(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales the image to the given dimensions.
Returns:
@ -108,7 +110,7 @@ class Thumbnailer:
scaled = self._resize(width, height)
return self._encode_image(scaled, output_type)
def crop(self, width, height, output_type):
def crop(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales and crops the image to the given dimensions preserving
aspect::
(w_in / h_in) = (w_scaled / h_scaled)
@ -136,7 +138,7 @@ class Thumbnailer:
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self._encode_image(cropped, output_type)
def _encode_image(self, output_image, output_type):
def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO()
fmt = self.FORMATS[output_type]
if fmt == "JPEG":

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,18 +15,25 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__)
class UploadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource):
self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock()
async def _async_render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request):
async def _async_render_POST(self, request: Request) -> None:
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point

View file

@ -45,7 +45,9 @@ class PickIdpResource(DirectServeHtmlResource):
self._server_name = hs.hostname
async def _async_render_GET(self, request: SynapseRequest) -> None:
client_redirect_url = parse_string(request, "redirectUrl", required=True)
client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding="utf-8"
)
idp = parse_string(request, "idp", required=False)
# if we need to pick an IdP, do so

View file

@ -34,10 +34,6 @@ class WellKnownBuilder:
self._config = hs.config
def get_well_known(self):
# if we don't have a public_baseurl, we can't help much here.
if self._config.public_baseurl is None:
return None
result = {"m.homeserver": {"base_url": self._config.public_baseurl}}
if self._config.default_identity_server: