Add type hints to media rest resources. (#9093)

This commit is contained in:
Patrick Cloke 2021-01-15 10:57:37 -05:00 committed by GitHub
parent 0dd2649c12
commit d34c6e1279
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 286 additions and 165 deletions

1
changelog.d/9093.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to media repository.

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -17,10 +17,11 @@
import logging import logging
import os import os
import urllib import urllib
from typing import Awaitable from typing import Awaitable, Dict, Generator, List, Optional, Tuple
from twisted.internet.interfaces import IConsumer from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError, cs_error from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json from synapse.http.server import finish_request, respond_with_json
@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
] ]
def parse_media_id(request): def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try: try:
# This allows users to append e.g. /test.png to the URL. Useful for # This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type. # clients that parse the URL to see content type.
@ -69,7 +70,7 @@ def parse_media_id(request):
) )
def respond_404(request): def respond_404(request: Request) -> None:
respond_with_json( respond_with_json(
request, request,
404, 404,
@ -79,8 +80,12 @@ def respond_404(request):
async def respond_with_file( async def respond_with_file(
request, media_type, file_path, file_size=None, upload_name=None request: Request,
): media_type: str,
file_path: str,
file_size: Optional[int] = None,
upload_name: Optional[str] = None,
) -> None:
logger.debug("Responding with %r", file_path) logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path): if os.path.isfile(file_path):
@ -98,15 +103,20 @@ async def respond_with_file(
respond_404(request) respond_404(request)
def add_file_headers(request, media_type, file_size, upload_name): def add_file_headers(
request: Request,
media_type: str,
file_size: Optional[int],
upload_name: Optional[str],
) -> None:
"""Adds the correct response headers in preparation for responding with the """Adds the correct response headers in preparation for responding with the
media. media.
Args: Args:
request (twisted.web.http.Request) request
media_type (str): The media/content type. media_type: The media/content type.
file_size (int): Size in bytes of the media, if known. file_size: Size in bytes of the media, if known.
upload_name (str): The name of the requested file, if any. upload_name: The name of the requested file, if any.
""" """
def _quote(x): def _quote(x):
@ -153,6 +163,7 @@ def add_file_headers(request, media_type, file_size, upload_name):
# select private. don't bother setting Expires as all our # select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control # clients are smart enough to be happy with Cache-Control
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
if file_size is not None:
request.setHeader(b"Content-Length", b"%d" % (file_size,)) request.setHeader(b"Content-Length", b"%d" % (file_size,))
# Tell web crawlers to not index, archive, or follow links in media. This # Tell web crawlers to not index, archive, or follow links in media. This
@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
} }
def _can_encode_filename_as_token(x): def _can_encode_filename_as_token(x: str) -> bool:
for c in x: for c in x:
# from RFC2616: # from RFC2616:
# #
@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x):
async def respond_with_responder( async def respond_with_responder(
request, responder, media_type, file_size, upload_name=None request: Request,
): responder: "Optional[Responder]",
media_type: str,
file_size: Optional[int],
upload_name: Optional[str] = None,
) -> None:
"""Responds to the request with given responder. If responder is None then """Responds to the request with given responder. If responder is None then
returns 404. returns 404.
Args: Args:
request (twisted.web.http.Request) request
responder (Responder|None) responder
media_type (str): The media/content type. media_type: The media/content type.
file_size (int|None): Size in bytes of the media. If not known it should be None file_size: Size in bytes of the media. If not known it should be None
upload_name (str|None): The name of the requested file, if any. upload_name: The name of the requested file, if any.
""" """
if request._disconnected: if request._disconnected:
logger.warning( logger.warning(
@ -308,22 +323,22 @@ class FileInfo:
self.thumbnail_type = thumbnail_type self.thumbnail_type = thumbnail_type
def get_filename_from_headers(headers): def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
""" """
Get the filename of the downloaded file by inspecting the Get the filename of the downloaded file by inspecting the
Content-Disposition HTTP header. Content-Disposition HTTP header.
Args: Args:
headers (dict[bytes, list[bytes]]): The HTTP request headers. headers: The HTTP request headers.
Returns: Returns:
A Unicode string of the filename, or None. The filename, or None.
""" """
content_disposition = headers.get(b"Content-Disposition", [b""]) content_disposition = headers.get(b"Content-Disposition", [b""])
# No header, bail out. # No header, bail out.
if not content_disposition[0]: if not content_disposition[0]:
return return None
_, params = _parse_header(content_disposition[0]) _, params = _parse_header(content_disposition[0])
@ -356,17 +371,16 @@ def get_filename_from_headers(headers):
return upload_name return upload_name
def _parse_header(line): def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
"""Parse a Content-type like header. """Parse a Content-type like header.
Cargo-culted from `cgi`, but works on bytes rather than strings. Cargo-culted from `cgi`, but works on bytes rather than strings.
Args: Args:
line (bytes): header to be parsed line: header to be parsed
Returns: Returns:
Tuple[bytes, dict[bytes, bytes]]: The main content-type, followed by the parameter dictionary
the main content-type, followed by the parameter dictionary
""" """
parts = _parseparam(b";" + line) parts = _parseparam(b";" + line)
key = next(parts) key = next(parts)
@ -386,16 +400,16 @@ def _parse_header(line):
return key, pdict return key, pdict
def _parseparam(s): def _parseparam(s: bytes) -> Generator[bytes, None, None]:
"""Generator which splits the input on ;, respecting double-quoted sequences """Generator which splits the input on ;, respecting double-quoted sequences
Cargo-culted from `cgi`, but works on bytes rather than strings. Cargo-culted from `cgi`, but works on bytes rather than strings.
Args: Args:
s (bytes): header to be parsed s: header to be parsed
Returns: Returns:
Iterable[bytes]: the split input The split input
""" """
while s[:1] == b";": while s[:1] == b";":
s = s[1:] s = s[1:]

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2018 Will Hunt <will@half-shot.uk> # Copyright 2018 Will Hunt <will@half-shot.uk>
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,22 +15,29 @@
# limitations under the License. # limitations under the License.
# #
from typing import TYPE_CHECKING
from twisted.web.http import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class MediaConfigResource(DirectServeJsonResource): class MediaConfigResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
config = hs.get_config() config = hs.get_config()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size} self.limits_dict = {"m.upload.size": config.max_upload_size}
async def _async_render_GET(self, request): async def _async_render_GET(self, request: Request) -> None:
await self.auth.get_user_by_req(request) await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True) respond_with_json(request, 200, self.limits_dict, send_cors=True)
async def _async_render_OPTIONS(self, request): async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,24 +14,31 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
import synapse.http.servlet
from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean
from ._base import parse_media_id, respond_404 from ._base import parse_media_id, respond_404
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DownloadResource(DirectServeJsonResource): class DownloadResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo): def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__() super().__init__()
self.media_repo = media_repo self.media_repo = media_repo
self.server_name = hs.hostname self.server_name = hs.hostname
async def _async_render_GET(self, request): async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request) set_cors_headers(request)
request.setHeader( request.setHeader(
b"Content-Security-Policy", b"Content-Security-Policy",
@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
if server_name == self.server_name: if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name) await self.media_repo.get_local_media(request, media_id, name)
else: else:
allow_remote = synapse.http.servlet.parse_boolean( allow_remote = parse_boolean(request, "allow_remote", default=True)
request, "allow_remote", default=True
)
if not allow_remote: if not allow_remote:
logger.info( logger.info(
"Rejecting request for remote media %s/%s due to allow_remote", "Rejecting request for remote media %s/%s due to allow_remote",

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -16,11 +17,12 @@
import functools import functools
import os import os
import re import re
from typing import Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
def _wrap_in_base_path(func): def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
"""Takes a function that returns a relative path and turns it into an """Takes a function that returns a relative path and turns it into an
absolute path based on the location of the primary media store absolute path based on the location of the primary media store
""" """
@ -41,12 +43,18 @@ class MediaFilePaths:
to write to the backup media store (when one is configured) to write to the backup media store (when one is configured)
""" """
def __init__(self, primary_base_path): def __init__(self, primary_base_path: str):
self.base_path = primary_base_path self.base_path = primary_base_path
def default_thumbnail_rel( def default_thumbnail_rel(
self, default_top_level, default_sub_type, width, height, content_type, method self,
): default_top_level: str,
default_sub_type: str,
width: int,
height: int,
content_type: str,
method: str,
) -> str:
top_level_type, sub_type = content_type.split("/") top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join( return os.path.join(
@ -55,12 +63,14 @@ class MediaFilePaths:
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
def local_media_filepath_rel(self, media_id): def local_media_filepath_rel(self, media_id: str) -> str:
return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:]) return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
def local_media_thumbnail_rel(self, media_id, width, height, content_type, method): def local_media_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
top_level_type, sub_type = content_type.split("/") top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join( return os.path.join(
@ -86,7 +96,7 @@ class MediaFilePaths:
media_id[4:], media_id[4:],
) )
def remote_media_filepath_rel(self, server_name, file_id): def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
return os.path.join( return os.path.join(
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
) )
@ -94,8 +104,14 @@ class MediaFilePaths:
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
def remote_media_thumbnail_rel( def remote_media_thumbnail_rel(
self, server_name, file_id, width, height, content_type, method self,
): server_name: str,
file_id: str,
width: int,
height: int,
content_type: str,
method: str,
) -> str:
top_level_type, sub_type = content_type.split("/") top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join( return os.path.join(
@ -113,7 +129,7 @@ class MediaFilePaths:
# Should be removed after some time, when most of the thumbnails are stored # Should be removed after some time, when most of the thumbnails are stored
# using the new path. # using the new path.
def remote_media_thumbnail_rel_legacy( def remote_media_thumbnail_rel_legacy(
self, server_name, file_id, width, height, content_type self, server_name: str, file_id: str, width: int, height: int, content_type: str
): ):
top_level_type, sub_type = content_type.split("/") top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@ -126,7 +142,7 @@ class MediaFilePaths:
file_name, file_name,
) )
def remote_media_thumbnail_dir(self, server_name, file_id): def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join( return os.path.join(
self.base_path, self.base_path,
"remote_thumbnail", "remote_thumbnail",
@ -136,7 +152,7 @@ class MediaFilePaths:
file_id[4:], file_id[4:],
) )
def url_cache_filepath_rel(self, media_id): def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id): if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING> # Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf # E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -146,7 +162,7 @@ class MediaFilePaths:
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
def url_cache_filepath_dirs_to_delete(self, media_id): def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file" "The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id): if NEW_FORMAT_ID_RE.match(media_id):
return [os.path.join(self.base_path, "url_cache", media_id[:10])] return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@ -156,7 +172,9 @@ class MediaFilePaths:
os.path.join(self.base_path, "url_cache", media_id[0:2]), os.path.join(self.base_path, "url_cache", media_id[0:2]),
] ]
def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method): def url_cache_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
# Media id is of the form <DATE><RANDOM_STRING> # Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf # E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -178,7 +196,7 @@ class MediaFilePaths:
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
def url_cache_thumbnail_directory(self, media_id): def url_cache_thumbnail_directory(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING> # Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf # E.g.: 2017-09-28-fsdRDt24DS234dsf
@ -195,7 +213,7 @@ class MediaFilePaths:
media_id[4:], media_id[4:],
) )
def url_cache_thumbnail_dirs_to_delete(self, media_id): def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails" "The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING> # Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf # E.g.: 2017-09-28-fsdRDt24DS234dsf

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,12 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import errno import errno
import logging import logging
import os import os
import shutil import shutil
from typing import IO, Dict, List, Optional, Tuple from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error import twisted.internet.error
import twisted.web.http import twisted.web.http
@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
from .thumbnailer import Thumbnailer, ThumbnailError from .thumbnailer import Thumbnailer, ThumbnailError
from .upload_resource import UploadResource from .upload_resource import UploadResource
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository: class MediaRepository:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.client = hs.get_federation_http_client() self.client = hs.get_federation_http_client()
@ -73,16 +76,16 @@ class MediaRepository:
self.max_upload_size = hs.config.max_upload_size self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels self.max_image_pixels = hs.config.max_image_pixels
self.primary_base_path = hs.config.media_store_path self.primary_base_path = hs.config.media_store_path # type: str
self.filepaths = MediaFilePaths(self.primary_base_path) self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote") self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set() self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
self.recently_accessed_locals = set() self.recently_accessed_locals = set() # type: Set[str]
self.federation_domain_whitelist = hs.config.federation_domain_whitelist self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@ -113,7 +116,7 @@ class MediaRepository:
"update_recently_accessed_media", self._update_recently_accessed "update_recently_accessed_media", self._update_recently_accessed
) )
async def _update_recently_accessed(self): async def _update_recently_accessed(self) -> None:
remote_media = self.recently_accessed_remotes remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set() self.recently_accessed_remotes = set()
@ -124,12 +127,12 @@ class MediaRepository:
local_media, remote_media, self.clock.time_msec() local_media, remote_media, self.clock.time_msec()
) )
def mark_recently_accessed(self, server_name, media_id): def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
"""Mark the given media as recently accessed. """Mark the given media as recently accessed.
Args: Args:
server_name (str|None): Origin server of media, or None if local server_name: Origin server of media, or None if local
media_id (str): The media ID of the content media_id: The media ID of the content
""" """
if server_name: if server_name:
self.recently_accessed_remotes.add((server_name, media_id)) self.recently_accessed_remotes.add((server_name, media_id))
@ -459,7 +462,14 @@ class MediaRepository:
def _get_thumbnail_requirements(self, media_type): def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ()) return self.thumbnail_requirements.get(media_type, ())
def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): def _generate_thumbnail(
self,
thumbnailer: Thumbnailer,
t_width: int,
t_height: int,
t_method: str,
t_type: str,
) -> Optional[BytesIO]:
m_width = thumbnailer.width m_width = thumbnailer.width
m_height = thumbnailer.height m_height = thumbnailer.height
@ -470,22 +480,20 @@ class MediaRepository:
m_height, m_height,
self.max_image_pixels, self.max_image_pixels,
) )
return return None
if thumbnailer.transpose_method is not None: if thumbnailer.transpose_method is not None:
m_width, m_height = thumbnailer.transpose() m_width, m_height = thumbnailer.transpose()
if t_method == "crop": if t_method == "crop":
t_byte_source = thumbnailer.crop(t_width, t_height, t_type) return thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale": elif t_method == "scale":
t_width, t_height = thumbnailer.aspect(t_width, t_height) t_width, t_height = thumbnailer.aspect(t_width, t_height)
t_width = min(m_width, t_width) t_width = min(m_width, t_width)
t_height = min(m_height, t_height) t_height = min(m_height, t_height)
t_byte_source = thumbnailer.scale(t_width, t_height, t_type) return thumbnailer.scale(t_width, t_height, t_type)
else:
t_byte_source = None
return t_byte_source return None
async def generate_local_exact_thumbnail( async def generate_local_exact_thumbnail(
self, self,
@ -776,7 +784,7 @@ class MediaRepository:
return {"width": m_width, "height": m_height} return {"width": m_width, "height": m_height}
async def delete_old_remote_media(self, before_ts): async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
old_media = await self.store.get_remote_media_before(before_ts) old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0 deleted = 0
@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
within a given rectangle. within a given rectangle.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
# If we're not configured to use it, raise if we somehow got here. # If we're not configured to use it, raise if we somehow got here.
if not hs.config.can_load_media_repo: if not hs.config.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.") raise ConfigError("Synapse is not configured to use a media repo.")

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2018 New Vecotr Ltd # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -18,6 +18,8 @@ import os
import shutil import shutil
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@ -270,7 +272,7 @@ class MediaStorage:
return self.filepaths.local_media_filepath_rel(file_info.file_id) return self.filepaths.local_media_filepath_rel(file_info.file_id)
def _write_file_synchronously(source, dest): def _write_file_synchronously(source: IO, dest: IO) -> None:
"""Write `source` to the file like `dest` synchronously. Should be called """Write `source` to the file like `dest` synchronously. Should be called
from a thread. from a thread.
@ -286,14 +288,14 @@ class FileResponder(Responder):
"""Wraps an open file that can be sent to a request. """Wraps an open file that can be sent to a request.
Args: Args:
open_file (file): A file like object to be streamed ot the client, open_file: A file like object to be streamed ot the client,
is closed when finished streaming. is closed when finished streaming.
""" """
def __init__(self, open_file): def __init__(self, open_file: IO):
self.open_file = open_file self.open_file = open_file
def write_to_consumer(self, consumer): def write_to_consumer(self, consumer: IConsumer) -> Deferred:
return make_deferred_yieldable( return make_deferred_yieldable(
FileSender().beginFileTransfer(self.open_file, consumer) FileSender().beginFileTransfer(self.open_file, consumer)
) )

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import datetime import datetime
import errno import errno
import fnmatch import fnmatch
@ -23,12 +23,13 @@ import re
import shutil import shutil
import sys import sys
import traceback import traceback
from typing import Dict, Optional from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
from urllib import parse as urlparse from urllib import parse as urlparse
import attr import attr
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string
from ._base import FileInfo from ._base import FileInfo
if TYPE_CHECKING:
from lxml import etree
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I) _charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
@ -119,7 +127,12 @@ class OEmbedError(Exception):
class PreviewUrlResource(DirectServeJsonResource): class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo, media_storage): def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__() super().__init__()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000 self._start_expire_url_cache_data, 10 * 1000
) )
async def _async_render_OPTIONS(self, request): async def _async_render_OPTIONS(self, request: Request) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET") request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_GET(self, request): async def _async_render_GET(self, request: Request) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render? # XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e raise OEmbedError() from e
async def _download_url(self, url: str, user): async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
# TODO: we should probably honour robots.txt... except in practice # TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a # we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot? # bot, so are we really a robot?
@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"expire_url_cache_data", self._expire_url_cache_data "expire_url_cache_data", self._expire_url_cache_data
) )
async def _expire_url_cache_data(self): async def _expire_url_cache_data(self) -> None:
"""Clean up expired url cache content, media and thumbnails. """Clean up expired url cache content, media and thumbnails.
""" """
# TODO: Delete from backup media store # TODO: Delete from backup media store
@ -676,7 +689,9 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache") logger.debug("No media removed from url cache")
def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]: def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]:
# If there's no body, nothing useful is going to be found. # If there's no body, nothing useful is going to be found.
if not body: if not body:
return {} return {}
@ -697,7 +712,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]
return og return og
def _calc_og(tree, media_uri): def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response. # suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them # if we see any image URLs in the OG response, then spider them
@ -801,7 +816,9 @@ def _calc_og(tree, media_uri):
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE) for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
) )
og["og:description"] = summarize_paragraphs(text_nodes) og["og:description"] = summarize_paragraphs(text_nodes)
else: elif og["og:description"]:
# This must be a non-empty string at this point.
assert isinstance(og["og:description"], str)
og["og:description"] = summarize_paragraphs([og["og:description"]]) og["og:description"] = summarize_paragraphs([og["og:description"]])
# TODO: delete the url downloads to stop diskfilling, # TODO: delete the url downloads to stop diskfilling,
@ -809,7 +826,9 @@ def _calc_og(tree, media_uri):
return og return og
def _iterate_over_text(tree, *tags_to_ignore): def _iterate_over_text(
tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion, """Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags. skipping text nodes inside certain tags.
""" """
@ -843,32 +862,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
) )
def _rebase_url(url, base): def _rebase_url(url: str, base: str) -> str:
base = list(urlparse.urlparse(base)) base_parts = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url)) url_parts = list(urlparse.urlparse(url))
if not url[0]: # fix up schema if not url_parts[0]: # fix up schema
url[0] = base[0] or "http" url_parts[0] = base_parts[0] or "http"
if not url[1]: # fix up hostname if not url_parts[1]: # fix up hostname
url[1] = base[1] url_parts[1] = base_parts[1]
if not url[2].startswith("/"): if not url_parts[2].startswith("/"):
url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2] url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
return urlparse.urlunparse(url) return urlparse.urlunparse(url_parts)
def _is_media(content_type): def _is_media(content_type: str) -> bool:
if content_type.lower().startswith("image/"): return content_type.lower().startswith("image/")
return True
def _is_html(content_type): def _is_html(content_type: str) -> bool:
content_type = content_type.lower() content_type = content_type.lower()
if content_type.startswith("text/html") or content_type.startswith( return content_type.startswith("text/html") or content_type.startswith(
"application/xhtml" "application/xhtml"
): )
return True
def summarize_paragraphs(text_nodes, min_size=200, max_size=500): def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
# Try to get a summary of between 200 and 500 words, respecting # Try to get a summary of between 200 and 500 words, respecting
# first paragraph and then word boundaries. # first paragraph and then word boundaries.
# TODO: Respect sentences? # TODO: Respect sentences?

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import abc
import logging import logging
import os import os
import shutil import shutil
from typing import Optional from typing import TYPE_CHECKING, Optional
from synapse.config._base import Config from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background from synapse.logging.context import defer_to_thread, run_in_background
@ -27,13 +28,17 @@ from .media_storage import FileResponder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
class StorageProvider:
class StorageProvider(metaclass=abc.ABCMeta):
"""A storage provider is a service that can store uploaded media and """A storage provider is a service that can store uploaded media and
retrieve them. retrieve them.
""" """
async def store_file(self, path: str, file_info: FileInfo): @abc.abstractmethod
async def store_file(self, path: str, file_info: FileInfo) -> None:
"""Store the file described by file_info. The actual contents can be """Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path. retrieved by reading the file in file_info.upload_path.
@ -42,6 +47,7 @@ class StorageProvider:
file_info: The metadata of the file. file_info: The metadata of the file.
""" """
@abc.abstractmethod
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it """Attempt to fetch the file described by file_info and stream it
into writer. into writer.
@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
self.store_synchronous = store_synchronous self.store_synchronous = store_synchronous
self.store_remote = store_remote self.store_remote = store_remote
def __str__(self): def __str__(self) -> str:
return "StorageProviderWrapper[%s]" % (self.backend,) return "StorageProviderWrapper[%s]" % (self.backend,)
async def store_file(self, path, file_info): async def store_file(self, path: str, file_info: FileInfo) -> None:
if not file_info.server_name and not self.store_local: if not file_info.server_name and not self.store_local:
return None return None
@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous: if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard # store_file is supposed to return an Awaitable, but guard
# against improper implementations. # against improper implementations.
return await maybe_awaitable(self.backend.store_file(path, file_info)) await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else: else:
# TODO: Handle errors. # TODO: Handle errors.
async def store(): async def store():
@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider):
logger.exception("Error storing file") logger.exception("Error storing file")
run_in_background(store) run_in_background(store)
return None
async def fetch(self, path, file_info): async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
# store_file is supposed to return an Awaitable, but guard # store_file is supposed to return an Awaitable, but guard
# against improper implementations. # against improper implementations.
return await maybe_awaitable(self.backend.fetch(path, file_info)) return await maybe_awaitable(self.backend.fetch(path, file_info))
@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider):
"""A storage provider that stores files in a directory on a filesystem. """A storage provider that stores files in a directory on a filesystem.
Args: Args:
hs (HomeServer) hs
config: The config returned by `parse_config`. config: The config returned by `parse_config`.
""" """
def __init__(self, hs, config): def __init__(self, hs: "HomeServer", config: str):
self.hs = hs self.hs = hs
self.cache_directory = hs.config.media_store_path self.cache_directory = hs.config.media_store_path
self.base_directory = config self.base_directory = config
@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self): def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,) return "FileStorageProviderBackend[%s]" % (self.base_directory,)
async def store_file(self, path, file_info): async def store_file(self, path: str, file_info: FileInfo) -> None:
"""See StorageProvider.store_file""" """See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path) primary_fname = os.path.join(self.cache_directory, path)
@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
return await defer_to_thread( await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
) )
async def fetch(self, path, file_info): async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""See StorageProvider.fetch""" """See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path) backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname): if os.path.isfile(backup_fname):
return FileResponder(open(backup_fname, "rb")) return FileResponder(open(backup_fname, "rb"))
return None
@staticmethod @staticmethod
def parse_config(config): def parse_config(config: dict) -> str:
"""Called on startup to parse config supplied. This should parse """Called on startup to parse config supplied. This should parse
the config and raise if there is a problem. the config and raise if there is a problem.

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,10 +16,14 @@
import logging import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import ( from ._base import (
FileInfo, FileInfo,
@ -28,13 +33,22 @@ from ._base import (
respond_with_responder, respond_with_responder,
) )
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThumbnailResource(DirectServeJsonResource): class ThumbnailResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo, media_storage): def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__() super().__init__()
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname self.server_name = hs.hostname
async def _async_render_GET(self, request): async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request) set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request) server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True) width = parse_integer(request, "width", required=True)
@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
self.media_repo.mark_recently_accessed(server_name, media_id) self.media_repo.mark_recently_accessed(server_name, media_id)
async def _respond_local_thumbnail( async def _respond_local_thumbnail(
self, request, media_id, width, height, method, m_type self,
): request: Request,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
) -> None:
media_info = await self.store.get_local_media(media_id) media_info = await self.store.get_local_media(media_id)
if not media_info: if not media_info:
@ -114,13 +134,13 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_local_thumbnail( async def _select_or_generate_local_thumbnail(
self, self,
request, request: Request,
media_id, media_id: str,
desired_width, desired_width: int,
desired_height, desired_height: int,
desired_method, desired_method: str,
desired_type, desired_type: str,
): ) -> None:
media_info = await self.store.get_local_media(media_id) media_info = await self.store.get_local_media(media_id)
if not media_info: if not media_info:
@ -178,14 +198,14 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_remote_thumbnail( async def _select_or_generate_remote_thumbnail(
self, self,
request, request: Request,
server_name, server_name: str,
media_id, media_id: str,
desired_width, desired_width: int,
desired_height, desired_height: int,
desired_method, desired_method: str,
desired_type, desired_type: str,
): ) -> None:
media_info = await self.media_repo.get_remote_media_info(server_name, media_id) media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = await self.store.get_remote_media_thumbnails( thumbnail_infos = await self.store.get_remote_media_thumbnails(
@ -239,8 +259,15 @@ class ThumbnailResource(DirectServeJsonResource):
raise SynapseError(400, "Failed to generate thumbnail.") raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail( async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type self,
): request: Request,
server_name: str,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
) -> None:
# TODO: Don't download the whole remote file # TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of # We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails. # downloading the remote file and generating our own thumbnails.
@ -275,12 +302,12 @@ class ThumbnailResource(DirectServeJsonResource):
def _select_thumbnail( def _select_thumbnail(
self, self,
desired_width, desired_width: int,
desired_height, desired_height: int,
desired_method, desired_method: str,
desired_type, desired_type: str,
thumbnail_infos, thumbnail_infos,
): ) -> dict:
d_w = desired_width d_w = desired_width
d_h = desired_height d_h = desired_height

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from io import BytesIO from io import BytesIO
from typing import Tuple
from PIL import Image from PIL import Image
@ -39,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
def __init__(self, input_path): def __init__(self, input_path: str):
try: try:
self.image = Image.open(input_path) self.image = Image.open(input_path)
except OSError as e: except OSError as e:
@ -59,11 +61,11 @@ class Thumbnailer:
# A lot of parsing errors can happen when parsing EXIF # A lot of parsing errors can happen when parsing EXIF
logger.info("Error parsing image EXIF information: %s", e) logger.info("Error parsing image EXIF information: %s", e)
def transpose(self): def transpose(self) -> Tuple[int, int]:
"""Transpose the image using its EXIF Orientation tag """Transpose the image using its EXIF Orientation tag
Returns: Returns:
Tuple[int, int]: (width, height) containing the new image size in pixels. A tuple containing the new image size in pixels as (width, height).
""" """
if self.transpose_method is not None: if self.transpose_method is not None:
self.image = self.image.transpose(self.transpose_method) self.image = self.image.transpose(self.transpose_method)
@ -73,7 +75,7 @@ class Thumbnailer:
self.image.info["exif"] = None self.image.info["exif"] = None
return self.image.size return self.image.size
def aspect(self, max_width, max_height): def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
"""Calculate the largest size that preserves aspect ratio which """Calculate the largest size that preserves aspect ratio which
fits within the given rectangle:: fits within the given rectangle::
@ -91,7 +93,7 @@ class Thumbnailer:
else: else:
return (max_height * self.width) // self.height, max_height return (max_height * self.width) // self.height, max_height
def _resize(self, width, height): def _resize(self, width: int, height: int) -> Image:
# 1-bit or 8-bit color palette images need converting to RGB # 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which # otherwise they will be scaled using nearest neighbour which
# looks awful # looks awful
@ -99,7 +101,7 @@ class Thumbnailer:
self.image = self.image.convert("RGB") self.image = self.image.convert("RGB")
return self.image.resize((width, height), Image.ANTIALIAS) return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width, height, output_type): def scale(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales the image to the given dimensions. """Rescales the image to the given dimensions.
Returns: Returns:
@ -108,7 +110,7 @@ class Thumbnailer:
scaled = self._resize(width, height) scaled = self._resize(width, height)
return self._encode_image(scaled, output_type) return self._encode_image(scaled, output_type)
def crop(self, width, height, output_type): def crop(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales and crops the image to the given dimensions preserving """Rescales and crops the image to the given dimensions preserving
aspect:: aspect::
(w_in / h_in) = (w_scaled / h_scaled) (w_in / h_in) = (w_scaled / h_scaled)
@ -136,7 +138,7 @@ class Thumbnailer:
cropped = scaled_image.crop((crop_left, 0, crop_right, height)) cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self._encode_image(cropped, output_type) return self._encode_image(cropped, output_type)
def _encode_image(self, output_image, output_type): def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO() output_bytes_io = BytesIO()
fmt = self.FORMATS[output_type] fmt = self.FORMATS[output_type]
if fmt == "JPEG": if fmt == "JPEG":

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,18 +15,25 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string from synapse.http.servlet import parse_string
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UploadResource(DirectServeJsonResource): class UploadResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs, media_repo): def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__() super().__init__()
self.media_repo = media_repo self.media_repo = media_repo
@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource):
self.max_upload_size = hs.config.max_upload_size self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock() self.clock = hs.get_clock()
async def _async_render_OPTIONS(self, request): async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True) respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request): async def _async_render_POST(self, request: Request) -> None:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point # already been uploaded to a tmp file at this point

View File

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
async def get_local_media_before( async def get_local_media_before(
self, before_ts: int, size_gt: int, keep_profiles: bool, self, before_ts: int, size_gt: int, keep_profiles: bool,
) -> Optional[List[str]]: ) -> List[str]:
# to find files that have never been accessed (last_access_ts IS NULL) # to find files that have never been accessed (last_access_ts IS NULL)
# compare with `created_ts` # compare with `created_ts`