Refactor media modules. (#15146)

* Removes the `v1` directory from `test.rest.media.v1`.
* Moves the non-REST code from `synapse.rest.media.v1` to `synapse.media`.
* Flatten the `v1` directory from `synapse.rest.media`,  but leave compatiblity
  with 3rd party media repositories and spam checkers.
This commit is contained in:
Patrick Cloke 2023-02-27 08:26:05 -05:00 committed by GitHub
parent 3f2ef205e2
commit 4fc8875876
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 1190 additions and 1123 deletions

View file

@ -22,11 +22,10 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_boolean
from synapse.http.site import SynapseRequest
from ._base import parse_media_id, respond_404
from synapse.media._base import parse_media_id, respond_404
if TYPE_CHECKING:
from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)

View file

@ -0,0 +1,93 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from synapse.config._base import ConfigError
from synapse.http.server import UnrecognizedRequestResource
from .config_resource import MediaConfigResource
from .download_resource import DownloadResource
from .preview_url_resource import PreviewUrlResource
from .thumbnail_resource import ThumbnailResource
from .upload_resource import UploadResource
if TYPE_CHECKING:
from synapse.server import HomeServer
class MediaRepositoryResource(UnrecognizedRequestResource):
"""File uploading and downloading.
Uploads are POSTed to a resource which returns a token which is used to GET
the download::
=> POST /_matrix/media/r0/upload HTTP/1.1
Content-Type: <media-type>
Content-Length: <content-length>
<media>
<= HTTP/1.1 200 OK
Content-Type: application/json
{ "content_uri": "mxc://<server-name>/<media-id>" }
=> GET /_matrix/media/r0/download/<server-name>/<media-id> HTTP/1.1
<= HTTP/1.1 200 OK
Content-Type: <media-type>
Content-Disposition: attachment;filename=<upload-filename>
<media>
Clients can get thumbnails by supplying a desired width and height and
thumbnailing method::
=> GET /_matrix/media/r0/thumbnail/<server_name>
/<media-id>?width=<w>&height=<h>&method=<m> HTTP/1.1
<= HTTP/1.1 200 OK
Content-Type: image/jpeg or image/png
<thumbnail>
The thumbnail methods are "crop" and "scale". "scale" tries to return an
image where either the width or the height is smaller than the requested
size. The client should then scale and letterbox the image if it needs to
fit within a given rectangle. "crop" tries to return an image where the
width and height are close to the requested size and the aspect matches
the requested size. The client should scale the image if it needs to fit
within a given rectangle.
"""
def __init__(self, hs: "HomeServer"):
# If we're not configured to use it, raise if we somehow got here.
if not hs.config.media.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.")
super().__init__()
media_repo = hs.get_media_repository()
self.putChild(b"upload", UploadResource(hs, media_repo))
self.putChild(b"download", DownloadResource(hs, media_repo))
self.putChild(
b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage)
)
if hs.config.media.url_preview_enabled:
self.putChild(
b"preview_url",
PreviewUrlResource(hs, media_repo, media_repo.media_storage),
)
self.putChild(b"config", MediaConfigResource(hs))

View file

@ -40,21 +40,19 @@ from synapse.http.server import (
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.media._base import FileInfo, get_filename_from_headers
from synapse.media.media_storage import MediaStorage
from synapse.media.oembed import OEmbedProvider
from synapse.media.preview_html import decode_body, parse_html_to_open_graph
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.oembed import OEmbedProvider
from synapse.rest.media.v1.preview_html import decode_body, parse_html_to_open_graph
from synapse.types import JsonDict, UserID
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.stringutils import random_string
from ._base import FileInfo
if TYPE_CHECKING:
from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)

View file

@ -27,9 +27,7 @@ from synapse.http.server import (
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
from synapse.media._base import (
FileInfo,
ThumbnailInfo,
parse_media_id,
@ -37,9 +35,10 @@ from ._base import (
respond_with_file,
respond_with_responder,
)
from synapse.media.media_storage import MediaStorage
if TYPE_CHECKING:
from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)

View file

@ -20,10 +20,10 @@ from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_bytes_from_args
from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import SpamMediaException
from synapse.media.media_storage import SpamMediaException
if TYPE_CHECKING:
from synapse.rest.media.v1.media_repository import MediaRepository
from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer
logger = logging.getLogger(__name__)

View file

@ -1,5 +1,4 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,468 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import urllib
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
import attr
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.util.stringutils import is_ascii, parse_and_validate_server_name
logger = logging.getLogger(__name__)
# list all text content types that will have the charset default to UTF-8 when
# none is given
TEXT_CONTENT_TYPES = [
"text/css",
"text/csv",
"text/html",
"text/calendar",
"text/plain",
"text/javascript",
"application/json",
"application/ld+json",
"application/rtf",
"image/svg+xml",
"text/xml",
]
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
"""Parses the server name, media ID and optional file name from the request URI
Also performs some rough validation on the server name.
Args:
request: The `Request`.
Returns:
A tuple containing the parsed server name, media ID and optional file name.
Raises:
SynapseError(404): if parsing or validation fail for any reason
"""
try:
# The type on postpath seems incorrect in Twisted 21.2.0.
postpath: List[bytes] = request.postpath # type: ignore
assert postpath
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
server_name_bytes, media_id_bytes = postpath[:2]
server_name = server_name_bytes.decode("utf-8")
media_id = media_id_bytes.decode("utf8")
# Validate the server name, raising if invalid
parse_and_validate_server_name(server_name)
file_name = None
if len(postpath) > 2:
try:
file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
except UnicodeDecodeError:
pass
return server_name, media_id, file_name
except Exception:
raise SynapseError(
404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN
)
def respond_404(request: SynapseRequest) -> None:
respond_with_json(
request,
404,
cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND),
send_cors=True,
)
async def respond_with_file(
request: SynapseRequest,
media_type: str,
file_path: str,
file_size: Optional[int] = None,
upload_name: Optional[str] = None,
) -> None:
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
if file_size is None:
stat = os.stat(file_path)
file_size = stat.st_size
add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
await make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
finish_request(request)
else:
respond_404(request)
def add_file_headers(
request: Request,
media_type: str,
file_size: Optional[int],
upload_name: Optional[str],
) -> None:
"""Adds the correct response headers in preparation for responding with the
media.
Args:
request
media_type: The media/content type.
file_size: Size in bytes of the media, if known.
upload_name: The name of the requested file, if any.
"""
def _quote(x: str) -> str:
return urllib.parse.quote(x.encode("utf-8"))
# Default to a UTF-8 charset for text content types.
# ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16'
if media_type.lower() in TEXT_CONTENT_TYPES:
content_type = media_type + "; charset=UTF-8"
else:
content_type = media_type
request.setHeader(b"Content-Type", content_type.encode("UTF-8"))
if upload_name:
# RFC6266 section 4.1 [1] defines both `filename` and `filename*`.
#
# `filename` is defined to be a `value`, which is defined by RFC2616
# section 3.6 [2] to be a `token` or a `quoted-string`, where a `token`
# is (essentially) a single US-ASCII word, and a `quoted-string` is a
# US-ASCII string surrounded by double-quotes, using backslash as an
# escape character. Note that %-encoding is *not* permitted.
#
# `filename*` is defined to be an `ext-value`, which is defined in
# RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`,
# where `value-chars` is essentially a %-encoded string in the given charset.
#
# [1]: https://tools.ietf.org/html/rfc6266#section-4.1
# [2]: https://tools.ietf.org/html/rfc2616#section-3.6
# [3]: https://tools.ietf.org/html/rfc5987#section-3.2.1
# We avoid the quoted-string version of `filename`, because (a) synapse didn't
# correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we
# may as well just do the filename* version.
if _can_encode_filename_as_token(upload_name):
disposition = "inline; filename=%s" % (upload_name,)
else:
disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),)
request.setHeader(b"Content-Disposition", disposition.encode("ascii"))
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
if file_size is not None:
request.setHeader(b"Content-Length", b"%d" % (file_size,))
# Tell web crawlers to not index, archive, or follow links in media. This
# should help to prevent things in the media repo from showing up in web
# search results.
request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
# separators as defined in RFC2616. SP and HT are handled separately.
# see _can_encode_filename_as_token.
_FILENAME_SEPARATOR_CHARS = {
"(",
")",
"<",
">",
"@",
",",
";",
":",
"\\",
'"',
"/",
"[",
"]",
"?",
"=",
"{",
"}",
}
def _can_encode_filename_as_token(x: str) -> bool:
for c in x:
# from RFC2616:
#
# token = 1*<any CHAR except CTLs or separators>
#
# separators = "(" | ")" | "<" | ">" | "@"
# | "," | ";" | ":" | "\" | <">
# | "/" | "[" | "]" | "?" | "="
# | "{" | "}" | SP | HT
#
# CHAR = <any US-ASCII character (octets 0 - 127)>
#
# CTL = <any US-ASCII control character
# (octets 0 - 31) and DEL (127)>
#
if ord(c) >= 127 or ord(c) <= 32 or c in _FILENAME_SEPARATOR_CHARS:
return False
return True
async def respond_with_responder(
request: SynapseRequest,
responder: "Optional[Responder]",
media_type: str,
file_size: Optional[int],
upload_name: Optional[str] = None,
) -> None:
"""Responds to the request with given responder. If responder is None then
returns 404.
Args:
request
responder
media_type: The media/content type.
file_size: Size in bytes of the media. If not known it should be None
upload_name: The name of the requested file, if any.
"""
if not responder:
respond_404(request)
return
# If we have a responder we *must* use it as a context manager.
with responder:
if request._disconnected:
logger.warning(
"Not sending response to request %s, already disconnected.", request
)
return
logger.debug("Responding to media request with responder %s", responder)
add_file_headers(request, media_type, file_size, upload_name)
try:
await responder.write_to_consumer(request)
except Exception as e:
# The majority of the time this will be due to the client having gone
# away. Unfortunately, Twisted simply throws a generic exception at us
# in that case.
logger.warning("Failed to write to consumer: %s %s", type(e), e)
# Unregister the producer, if it has one, so Twisted doesn't complain
if request.producer:
request.unregisterProducer()
finish_request(request)
class Responder(ABC):
"""Represents a response that can be streamed to the requester.
Responder is a context manager which *must* be used, so that any resources
held can be cleaned up.
"""
@abstractmethod
def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer
Args:
consumer: The consumer to stream into.
Returns:
Resolves once the response has finished being written
"""
raise NotImplementedError()
def __enter__(self) -> None: # noqa: B027
pass
def __exit__( # noqa: B027
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
pass
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThumbnailInfo:
"""Details about a generated thumbnail."""
width: int
height: int
method: str
# Content type of thumbnail, e.g. image/png
type: str
# The size of the media file, in bytes.
length: Optional[int] = None
@attr.s(slots=True, frozen=True, auto_attribs=True)
class FileInfo:
"""Details about a requested/uploaded file."""
# The server name where the media originated from, or None if local.
server_name: Optional[str]
# The local ID of the file. For local files this is the same as the media_id
file_id: str
# If the file is for the url preview cache
url_cache: bool = False
# Whether the file is a thumbnail or not.
thumbnail: Optional[ThumbnailInfo] = None
# The below properties exist to maintain compatibility with third-party modules.
@property
def thumbnail_width(self) -> Optional[int]:
if not self.thumbnail:
return None
return self.thumbnail.width
@property
def thumbnail_height(self) -> Optional[int]:
if not self.thumbnail:
return None
return self.thumbnail.height
@property
def thumbnail_method(self) -> Optional[str]:
if not self.thumbnail:
return None
return self.thumbnail.method
@property
def thumbnail_type(self) -> Optional[str]:
if not self.thumbnail:
return None
return self.thumbnail.type
@property
def thumbnail_length(self) -> Optional[int]:
if not self.thumbnail:
return None
return self.thumbnail.length
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
"""
Get the filename of the downloaded file by inspecting the
Content-Disposition HTTP header.
Args:
headers: The HTTP request headers.
Returns:
The filename, or None.
"""
content_disposition = headers.get(b"Content-Disposition", [b""])
# No header, bail out.
if not content_disposition[0]:
return None
_, params = _parse_header(content_disposition[0])
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get(b"filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith(b"utf-8''"):
upload_name_utf8 = upload_name_utf8[7:]
# We have a filename*= section. This MUST be ASCII, and any UTF-8
# bytes are %-quoted.
try:
# Once it is decoded, we can then unquote the %-encoded
# parts strictly into a unicode string.
upload_name = urllib.parse.unquote(
upload_name_utf8.decode("ascii"), errors="strict"
)
except UnicodeDecodeError:
# Incorrect UTF-8.
pass
# If there isn't check for an ascii name.
if not upload_name:
upload_name_ascii = params.get(b"filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
upload_name = upload_name_ascii.decode("ascii")
# This may be None here, indicating we did not find a matching name.
return upload_name
def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
"""Parse a Content-type like header.
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
line: header to be parsed
Returns:
The main content-type, followed by the parameter dictionary
"""
parts = _parseparam(b";" + line)
key = next(parts)
pdict = {}
for p in parts:
i = p.find(b"=")
if i >= 0:
name = p[:i].strip().lower()
value = p[i + 1 :].strip()
# strip double-quotes
if len(value) >= 2 and value[0:1] == value[-1:] == b'"':
value = value[1:-1]
value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"')
pdict[name] = value
return key, pdict
def _parseparam(s: bytes) -> Generator[bytes, None, None]:
"""Generator which splits the input on ;, respecting double-quoted sequences
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
s: header to be parsed
Returns:
The split input
"""
while s[:1] == b";":
s = s[1:]
# look for the next ;
end = s.find(b";")
# if there is an odd number of " marks between here and the next ;, skip to the
# next ; instead
while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2:
end = s.find(b";", end + 1)
if end < 0:
end = len(s)
f = s[:end]
yield f.strip()
s = s[end:]
# This exists purely for backwards compatibility with media providers and spam checkers.
from synapse.media._base import FileInfo, Responder # noqa: F401

View file

@ -1,410 +0,0 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import os
import re
import string
from typing import Any, Callable, List, TypeVar, Union, cast
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
F = TypeVar("F", bound=Callable[..., str])
def _wrap_in_base_path(func: F) -> F:
"""Takes a function that returns a relative path and turns it into an
absolute path based on the location of the primary media store
"""
@functools.wraps(func)
def _wrapped(self: "MediaFilePaths", *args: Any, **kwargs: Any) -> str:
path = func(self, *args, **kwargs)
return os.path.join(self.base_path, path)
return cast(F, _wrapped)
GetPathMethod = TypeVar(
"GetPathMethod", bound=Union[Callable[..., str], Callable[..., List[str]]]
)
def _wrap_with_jail_check(relative: bool) -> Callable[[GetPathMethod], GetPathMethod]:
"""Wraps a path-returning method to check that the returned path(s) do not escape
the media store directory.
The path-returning method may return either a single path, or a list of paths.
The check is not expected to ever fail, unless `func` is missing a call to
`_validate_path_component`, or `_validate_path_component` is buggy.
Args:
relative: A boolean indicating whether the wrapped method returns paths relative
to the media store directory.
Returns:
A method which will wrap a path-returning method, adding a check to ensure that
the returned path(s) lie within the media store directory. The check will raise
a `ValueError` if it fails.
"""
def _wrap_with_jail_check_inner(func: GetPathMethod) -> GetPathMethod:
@functools.wraps(func)
def _wrapped(
self: "MediaFilePaths", *args: Any, **kwargs: Any
) -> Union[str, List[str]]:
path_or_paths = func(self, *args, **kwargs)
if isinstance(path_or_paths, list):
paths_to_check = path_or_paths
else:
paths_to_check = [path_or_paths]
for path in paths_to_check:
# Construct the path that will ultimately be used.
# We cannot guess whether `path` is relative to the media store
# directory, since the media store directory may itself be a relative
# path.
if relative:
path = os.path.join(self.base_path, path)
normalized_path = os.path.normpath(path)
# Now that `normpath` has eliminated `../`s and `./`s from the path,
# `os.path.commonpath` can be used to check whether it lies within the
# media store directory.
if (
os.path.commonpath([normalized_path, self.normalized_base_path])
!= self.normalized_base_path
):
# The path resolves to outside the media store directory,
# or `self.base_path` is `.`, which is an unlikely configuration.
raise ValueError(f"Invalid media store path: {path!r}")
# Note that `os.path.normpath`/`abspath` has a subtle caveat:
# `a/b/c/../c` will normalize to `a/b/c`, but the former refers to a
# different path if `a/b/c` is a symlink. That is, the check above is
# not perfect and may allow a certain restricted subset of untrustworthy
# paths through. Since the check above is secondary to the main
# `_validate_path_component` checks, it's less important for it to be
# perfect.
#
# As an alternative, `os.path.realpath` will resolve symlinks, but
# proves problematic if there are symlinks inside the media store.
# eg. if `url_store/` is symlinked to elsewhere, its canonical path
# won't match that of the main media store directory.
return path_or_paths
return cast(GetPathMethod, _wrapped)
return _wrap_with_jail_check_inner
ALLOWED_CHARACTERS = set(
string.ascii_letters
+ string.digits
+ "_-"
+ ".[]:" # Domain names, IPv6 addresses and ports in server names
)
FORBIDDEN_NAMES = {
"",
os.path.curdir, # "." for the current platform
os.path.pardir, # ".." for the current platform
}
def _validate_path_component(name: str) -> str:
"""Checks that the given string can be safely used as a path component
Args:
name: The path component to check.
Returns:
The path component if valid.
Raises:
ValueError: If `name` cannot be safely used as a path component.
"""
if not ALLOWED_CHARACTERS.issuperset(name) or name in FORBIDDEN_NAMES:
raise ValueError(f"Invalid path component: {name!r}")
return name
class MediaFilePaths:
"""Describes where files are stored on disk.
Most of the functions have a `*_rel` variant which returns a file path that
is relative to the base media store path. This is mainly used when we want
to write to the backup media store (when one is configured)
"""
def __init__(self, primary_base_path: str):
self.base_path = primary_base_path
self.normalized_base_path = os.path.normpath(self.base_path)
# Refuse to initialize if paths cannot be validated correctly for the current
# platform.
assert os.path.sep not in ALLOWED_CHARACTERS
assert os.path.altsep not in ALLOWED_CHARACTERS
# On Windows, paths have all sorts of weirdness which `_validate_path_component`
# does not consider. In any case, the remote media store can't work correctly
# for certain homeservers there, since ":"s aren't allowed in paths.
assert os.name == "posix"
@_wrap_with_jail_check(relative=True)
def local_media_filepath_rel(self, media_id: str) -> str:
return os.path.join(
"local_content",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
)
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
@_wrap_with_jail_check(relative=True)
def local_media_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
"local_thumbnails",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
_validate_path_component(file_name),
)
local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
@_wrap_with_jail_check(relative=False)
def local_media_thumbnail_dir(self, media_id: str) -> str:
"""
Retrieve the local store path of thumbnails of a given media_id
Args:
media_id: The media ID to query.
Returns:
Path of local_thumbnails from media_id
"""
return os.path.join(
self.base_path,
"local_thumbnails",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
)
@_wrap_with_jail_check(relative=True)
def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
return os.path.join(
"remote_content",
_validate_path_component(server_name),
_validate_path_component(file_id[0:2]),
_validate_path_component(file_id[2:4]),
_validate_path_component(file_id[4:]),
)
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
@_wrap_with_jail_check(relative=True)
def remote_media_thumbnail_rel(
self,
server_name: str,
file_id: str,
width: int,
height: int,
content_type: str,
method: str,
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
"remote_thumbnail",
_validate_path_component(server_name),
_validate_path_component(file_id[0:2]),
_validate_path_component(file_id[2:4]),
_validate_path_component(file_id[4:]),
_validate_path_component(file_name),
)
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
# Legacy path that was used to store thumbnails previously.
# Should be removed after some time, when most of the thumbnails are stored
# using the new path.
@_wrap_with_jail_check(relative=True)
def remote_media_thumbnail_rel_legacy(
self, server_name: str, file_id: str, width: int, height: int, content_type: str
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
"remote_thumbnail",
_validate_path_component(server_name),
_validate_path_component(file_id[0:2]),
_validate_path_component(file_id[2:4]),
_validate_path_component(file_id[4:]),
_validate_path_component(file_name),
)
@_wrap_with_jail_check(relative=False)
def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join(
self.base_path,
"remote_thumbnail",
_validate_path_component(server_name),
_validate_path_component(file_id[0:2]),
_validate_path_component(file_id[2:4]),
_validate_path_component(file_id[4:]),
)
@_wrap_with_jail_check(relative=True)
def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
return os.path.join(
"url_cache",
_validate_path_component(media_id[:10]),
_validate_path_component(media_id[11:]),
)
else:
return os.path.join(
"url_cache",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
)
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
@_wrap_with_jail_check(relative=False)
def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
return [
os.path.join(
self.base_path, "url_cache", _validate_path_component(media_id[:10])
)
]
else:
return [
os.path.join(
self.base_path,
"url_cache",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
),
os.path.join(
self.base_path, "url_cache", _validate_path_component(media_id[0:2])
),
]
@_wrap_with_jail_check(relative=True)
def url_cache_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join(
"url_cache_thumbnails",
_validate_path_component(media_id[:10]),
_validate_path_component(media_id[11:]),
_validate_path_component(file_name),
)
else:
return os.path.join(
"url_cache_thumbnails",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
_validate_path_component(file_name),
)
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
@_wrap_with_jail_check(relative=True)
def url_cache_thumbnail_directory_rel(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join(
"url_cache_thumbnails",
_validate_path_component(media_id[:10]),
_validate_path_component(media_id[11:]),
)
else:
return os.path.join(
"url_cache_thumbnails",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
)
url_cache_thumbnail_directory = _wrap_in_base_path(
url_cache_thumbnail_directory_rel
)
@_wrap_with_jail_check(relative=False)
def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
if NEW_FORMAT_ID_RE.match(media_id):
return [
os.path.join(
self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[:10]),
_validate_path_component(media_id[11:]),
),
os.path.join(
self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[:10]),
),
]
else:
return [
os.path.join(
self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
_validate_path_component(media_id[4:]),
),
os.path.join(
self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[0:2]),
_validate_path_component(media_id[2:4]),
),
os.path.join(
self.base_path,
"url_cache_thumbnails",
_validate_path_component(media_id[0:2]),
),
]

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,364 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import os
import shutil
from types import TracebackType
from typing import (
IO,
TYPE_CHECKING,
Any,
Awaitable,
BinaryIO,
Callable,
Generator,
Optional,
Sequence,
Tuple,
Type,
)
#
import attr
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
import synapse
from synapse.api.errors import NotFoundError
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util import Clock
from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder
from .filepath import MediaFilePaths
if TYPE_CHECKING:
from synapse.rest.media.v1.storage_provider import StorageProvider
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class MediaStorage:
"""Responsible for storing/fetching files from local sources.
Args:
hs
local_media_directory: Base path where we store media on disk
filepaths
storage_providers: List of StorageProvider that are used to fetch and store files.
"""
def __init__(
self,
hs: "HomeServer",
local_media_directory: str,
filepaths: MediaFilePaths,
storage_providers: Sequence["StorageProvider"],
):
self.hs = hs
self.reactor = hs.get_reactor()
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
self.spam_checker = hs.get_spam_checker()
self.clock = hs.get_clock()
async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
configured storage providers
Args:
source: A file like object that should be written
file_info: Info about the file to store
Returns:
the file path written to in the primary media store
"""
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
await self.write_to_file(source, f)
await finish_cb()
return fname
async def write_to_file(self, source: IO, output: IO) -> None:
"""Asynchronously write the `source` to `output`."""
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@contextlib.contextmanager
def store_into_file(
self, file_info: FileInfo
) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]:
"""Context manager used to get a file like object to write into, as
described by file_info.
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
like object that can be written to, fname is the absolute path of file
on disk, and finish_cb is a function that returns an awaitable.
fname can be used to read the contents from after upload, e.g. to
generate thumbnails.
finish_cb must be called and waited on after the file has been
successfully been written to. Should not be called if there was an
error.
Args:
file_info: Info about the file to store
Example:
with media_storage.store_into_file(info) as (f, fname, finish_cb):
# .. write into f ...
await finish_cb()
"""
path = self._file_info_to_path(file_info)
fname = os.path.join(self.local_media_directory, path)
dirname = os.path.dirname(fname)
os.makedirs(dirname, exist_ok=True)
finished_called = [False]
try:
with open(fname, "wb") as f:
async def finish() -> None:
# Ensure that all writes have been flushed and close the
# file.
f.flush()
f.close()
spam_check = await self.spam_checker.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info
)
if spam_check != synapse.module_api.NOT_SPAM:
logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the
# try/except below. The media also won't be stored in
# the DB.
# We currently ignore any additional field returned by
# the spam-check API.
raise SpamMediaException(errcode=spam_check[0])
for provider in self.storage_providers:
await provider.store_file(path, file_info)
finished_called[0] = True
yield f, fname, finish
except Exception as e:
try:
os.remove(fname)
except Exception:
pass
raise e from None
if not finished_called:
raise Exception("Finished callback not called")
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache
and configured storage providers.
Args:
file_info
Returns:
Returns a Responder if the file was found, otherwise None.
"""
paths = [self._file_info_to_path(file_info)]
# fallback for remote thumbnails with no method in the filename
if file_info.thumbnail and file_info.server_name:
paths.append(
self.filepaths.remote_media_thumbnail_rel_legacy(
server_name=file_info.server_name,
file_id=file_info.file_id,
width=file_info.thumbnail.width,
height=file_info.thumbnail.height,
content_type=file_info.thumbnail.type,
)
)
for path in paths:
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
logger.debug("responding with local file %s", local_path)
return FileResponder(open(local_path, "rb"))
logger.debug("local file %s did not exist", local_path)
for provider in self.storage_providers:
for path in paths:
res: Any = await provider.fetch(path, file_info)
if res:
logger.debug("Streaming %s from %s", path, provider)
return res
logger.debug("%s not found on %s", path, provider)
return None
async def ensure_media_is_in_local_cache(self, file_info: FileInfo) -> str:
"""Ensures that the given file is in the local cache. Attempts to
download it from storage providers if it isn't.
Args:
file_info
Returns:
Full path to local file
"""
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
return local_path
# Fallback for paths without method names
# Should be removed in the future
if file_info.thumbnail and file_info.server_name:
legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
server_name=file_info.server_name,
file_id=file_info.file_id,
width=file_info.thumbnail.width,
height=file_info.thumbnail.height,
content_type=file_info.thumbnail.type,
)
legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
if os.path.exists(legacy_local_path):
return legacy_local_path
dirname = os.path.dirname(local_path)
os.makedirs(dirname, exist_ok=True)
for provider in self.storage_providers:
res: Any = await provider.fetch(path, file_info)
if res:
with res:
consumer = BackgroundFileConsumer(
open(local_path, "wb"), self.reactor
)
await res.write_to_consumer(consumer)
await consumer.wait()
return local_path
raise NotFoundError()
def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path.
The path is suitable for storing files under a directory, e.g. used to
store files on local FS under the base media repository directory.
"""
if file_info.url_cache:
if file_info.thumbnail:
return self.filepaths.url_cache_thumbnail_rel(
media_id=file_info.file_id,
width=file_info.thumbnail.width,
height=file_info.thumbnail.height,
content_type=file_info.thumbnail.type,
method=file_info.thumbnail.method,
)
return self.filepaths.url_cache_filepath_rel(file_info.file_id)
if file_info.server_name:
if file_info.thumbnail:
return self.filepaths.remote_media_thumbnail_rel(
server_name=file_info.server_name,
file_id=file_info.file_id,
width=file_info.thumbnail.width,
height=file_info.thumbnail.height,
content_type=file_info.thumbnail.type,
method=file_info.thumbnail.method,
)
return self.filepaths.remote_media_filepath_rel(
file_info.server_name, file_info.file_id
)
if file_info.thumbnail:
return self.filepaths.local_media_thumbnail_rel(
media_id=file_info.file_id,
width=file_info.thumbnail.width,
height=file_info.thumbnail.height,
content_type=file_info.thumbnail.type,
method=file_info.thumbnail.method,
)
return self.filepaths.local_media_filepath_rel(file_info.file_id)
def _write_file_synchronously(source: IO, dest: IO) -> None:
"""Write `source` to the file like `dest` synchronously. Should be called
from a thread.
Args:
source: A file like object that's to be written
dest: A file like object to be written to
"""
source.seek(0) # Ensure we read from the start of the file
shutil.copyfileobj(source, dest)
class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.
Args:
open_file: A file like object to be streamed ot the client,
is closed when finished streaming.
"""
def __init__(self, open_file: IO):
self.open_file = open_file
def write_to_consumer(self, consumer: IConsumer) -> Deferred:
return make_deferred_yieldable(
FileSender().beginFileTransfer(self.open_file, consumer)
)
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
self.open_file.close()
class SpamMediaException(NotFoundError):
"""The media was blocked by a spam checker, so we simply 404 the request (in
the same way as if it was quarantined).
"""
@attr.s(slots=True, auto_attribs=True)
class ReadableFileWrapper:
"""Wrapper that allows reading a file in chunks, yielding to the reactor,
and writing to a callback.
This is simplified `FileSender` that takes an IO object rather than an
`IConsumer`.
"""
CHUNK_SIZE = 2**14
clock: Clock
path: str
async def write_chunks_to(self, callback: Callable[[bytes], object]) -> None:
"""Reads the file in chunks and calls the callback with each chunk."""
with open(self.path, "rb") as file:
while True:
chunk = file.read(self.CHUNK_SIZE)
if not chunk:
break
callback(chunk)
# We yield to the reactor by sleeping for 0 seconds.
await self.clock.sleep(0)
# This exists purely for backwards compatibility with spam checkers.
from synapse.media.media_storage import ReadableFileWrapper # noqa: F401

View file

@ -1,265 +0,0 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import logging
import urllib.parse
from typing import TYPE_CHECKING, List, Optional
import attr
from synapse.rest.media.v1.preview_html import parse_html_description
from synapse.types import JsonDict
from synapse.util import json_decoder
if TYPE_CHECKING:
from lxml import etree
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class OEmbedResult:
# The Open Graph result (converted from the oEmbed result).
open_graph_result: JsonDict
# The author_name of the oEmbed result
author_name: Optional[str]
# Number of milliseconds to cache the content, according to the oEmbed response.
#
# This will be None if no cache-age is provided in the oEmbed response (or
# if the oEmbed response cannot be turned into an Open Graph response).
cache_age: Optional[int]
class OEmbedProvider:
"""
A helper for accessing oEmbed content.
It can be used to check if a URL should be accessed via oEmbed and for
requesting/parsing oEmbed content.
"""
def __init__(self, hs: "HomeServer"):
self._oembed_patterns = {}
for oembed_endpoint in hs.config.oembed.oembed_patterns:
api_endpoint = oembed_endpoint.api_endpoint
# Only JSON is supported at the moment. This could be declared in
# the formats field. Otherwise, if the endpoint ends in .xml assume
# it doesn't support JSON.
if (
oembed_endpoint.formats is not None
and "json" not in oembed_endpoint.formats
) or api_endpoint.endswith(".xml"):
logger.info(
"Ignoring oEmbed endpoint due to not supporting JSON: %s",
api_endpoint,
)
continue
# Iterate through each URL pattern and point it to the endpoint.
for pattern in oembed_endpoint.url_patterns:
self._oembed_patterns[pattern] = api_endpoint
def get_oembed_url(self, url: str) -> Optional[str]:
"""
Check whether the URL should be downloaded as oEmbed content instead.
Args:
url: The URL to check.
Returns:
A URL to use instead or None if the original URL should be used.
"""
for url_pattern, endpoint in self._oembed_patterns.items():
if url_pattern.fullmatch(url):
# TODO Specify max height / width.
# Note that only the JSON format is supported, some endpoints want
# this in the URL, others want it as an argument.
endpoint = endpoint.replace("{format}", "json")
args = {"url": url, "format": "json"}
query_str = urllib.parse.urlencode(args, True)
return f"{endpoint}?{query_str}"
# No match.
return None
def autodiscover_from_html(self, tree: "etree.Element") -> Optional[str]:
"""
Search an HTML document for oEmbed autodiscovery information.
Args:
tree: The parsed HTML body.
Returns:
The URL to use for oEmbed information, or None if no URL was found.
"""
# Search for link elements with the proper rel and type attributes.
for tag in tree.xpath(
"//link[@rel='alternate'][@type='application/json+oembed']"
):
if "href" in tag.attrib:
return tag.attrib["href"]
# Some providers (e.g. Flickr) use alternative instead of alternate.
for tag in tree.xpath(
"//link[@rel='alternative'][@type='application/json+oembed']"
):
if "href" in tag.attrib:
return tag.attrib["href"]
return None
def parse_oembed_response(self, url: str, raw_body: bytes) -> OEmbedResult:
"""
Parse the oEmbed response into an Open Graph response.
Args:
url: The URL which is being previewed (not the one which was
requested).
raw_body: The oEmbed response as JSON encoded as bytes.
Returns:
json-encoded Open Graph data
"""
try:
# oEmbed responses *must* be UTF-8 according to the spec.
oembed = json_decoder.decode(raw_body.decode("utf-8"))
except ValueError:
return OEmbedResult({}, None, None)
# The version is a required string field, but not always provided,
# or sometimes provided as a float. Be lenient.
oembed_version = oembed.get("version", "1.0")
if oembed_version != "1.0" and oembed_version != 1:
return OEmbedResult({}, None, None)
# Attempt to parse the cache age, if possible.
try:
cache_age = int(oembed.get("cache_age")) * 1000
except (TypeError, ValueError):
# If the cache age cannot be parsed (e.g. wrong type or invalid
# string), ignore it.
cache_age = None
# The oEmbed response converted to Open Graph.
open_graph_response: JsonDict = {"og:url": url}
title = oembed.get("title")
if title and isinstance(title, str):
# A common WordPress plug-in seems to incorrectly escape entities
# in the oEmbed response.
open_graph_response["og:title"] = html.unescape(title)
author_name = oembed.get("author_name")
if not isinstance(author_name, str):
author_name = None
# Use the provider name and as the site.
provider_name = oembed.get("provider_name")
if provider_name and isinstance(provider_name, str):
open_graph_response["og:site_name"] = provider_name
# If a thumbnail exists, use it. Note that dimensions will be calculated later.
thumbnail_url = oembed.get("thumbnail_url")
if thumbnail_url and isinstance(thumbnail_url, str):
open_graph_response["og:image"] = thumbnail_url
# Process each type separately.
oembed_type = oembed.get("type")
if oembed_type == "rich":
html_str = oembed.get("html")
if isinstance(html_str, str):
calc_description_and_urls(open_graph_response, html_str)
elif oembed_type == "photo":
# If this is a photo, use the full image, not the thumbnail.
url = oembed.get("url")
if url and isinstance(url, str):
open_graph_response["og:image"] = url
elif oembed_type == "video":
open_graph_response["og:type"] = "video.other"
html_str = oembed.get("html")
if html_str and isinstance(html_str, str):
calc_description_and_urls(open_graph_response, oembed["html"])
for size in ("width", "height"):
val = oembed.get(size)
if type(val) is int:
open_graph_response[f"og:video:{size}"] = val
elif oembed_type == "link":
open_graph_response["og:type"] = "website"
else:
logger.warning("Unknown oEmbed type: %s", oembed_type)
return OEmbedResult(open_graph_response, author_name, cache_age)
def _fetch_urls(tree: "etree.Element", tag_name: str) -> List[str]:
results = []
for tag in tree.xpath("//*/" + tag_name):
if "src" in tag.attrib:
results.append(tag.attrib["src"])
return results
def calc_description_and_urls(open_graph_response: JsonDict, html_body: str) -> None:
"""
Calculate description for an HTML document.
This uses lxml to convert the HTML document into plaintext. If errors
occur during processing of the document, an empty response is returned.
Args:
open_graph_response: The current Open Graph summary. This is updated with additional fields.
html_body: The HTML document, as bytes.
Returns:
The summary
"""
# If there's no body, nothing useful is going to be found.
if not html_body:
return
from lxml import etree
# Create an HTML parser. If this fails, log and return no metadata.
parser = etree.HTMLParser(recover=True, encoding="utf-8")
# Attempt to parse the body. If this fails, log and return no metadata.
tree = etree.fromstring(html_body, parser)
# The data was successfully parsed, but no tree was found.
if tree is None:
return
# Attempt to find interesting URLs (images, videos, embeds).
if "og:image" not in open_graph_response:
image_urls = _fetch_urls(tree, "img")
if image_urls:
open_graph_response["og:image"] = image_urls[0]
video_urls = _fetch_urls(tree, "video") + _fetch_urls(tree, "embed")
if video_urls:
open_graph_response["og:video"] = video_urls[0]
description = parse_html_description(tree)
if description:
open_graph_response["og:description"] = description

View file

@ -1,501 +0,0 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import codecs
import logging
import re
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
Union,
)
if TYPE_CHECKING:
from lxml import etree
logger = logging.getLogger(__name__)
_charset_match = re.compile(
rb'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9_-]+)"?', flags=re.I
)
_xml_encoding_match = re.compile(
rb'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9_-]+)"', flags=re.I
)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
# Certain elements aren't meant for display.
ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"}
def _normalise_encoding(encoding: str) -> Optional[str]:
"""Use the Python codec's name as the normalised entry."""
try:
return codecs.lookup(encoding).name
except LookupError:
return None
def _get_html_media_encodings(
body: bytes, content_type: Optional[str]
) -> Iterable[str]:
"""
Get potential encoding of the body based on the (presumably) HTML body or the content-type header.
The precedence used for finding a character encoding is:
1. <meta> tag with a charset declared.
2. The XML document's character encoding attribute.
3. The Content-Type header.
4. Fallback to utf-8.
5. Fallback to windows-1252.
This roughly follows the algorithm used by BeautifulSoup's bs4.dammit.EncodingDetector.
Args:
body: The HTML document, as bytes.
content_type: The Content-Type header.
Returns:
The character encoding of the body, as a string.
"""
# There's no point in returning an encoding more than once.
attempted_encodings: Set[str] = set()
# Limit searches to the first 1kb, since it ought to be at the top.
body_start = body[:1024]
# Check if it has an encoding set in a meta tag.
match = _charset_match.search(body_start)
if match:
encoding = _normalise_encoding(match.group(1).decode("ascii"))
if encoding:
attempted_encodings.add(encoding)
yield encoding
# TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
# Check if it has an XML document with an encoding.
match = _xml_encoding_match.match(body_start)
if match:
encoding = _normalise_encoding(match.group(1).decode("ascii"))
if encoding and encoding not in attempted_encodings:
attempted_encodings.add(encoding)
yield encoding
# Check the HTTP Content-Type header for a character set.
if content_type:
content_match = _content_type_match.match(content_type)
if content_match:
encoding = _normalise_encoding(content_match.group(1))
if encoding and encoding not in attempted_encodings:
attempted_encodings.add(encoding)
yield encoding
# Finally, fallback to UTF-8, then windows-1252.
for fallback in ("utf-8", "cp1252"):
if fallback not in attempted_encodings:
yield fallback
def decode_body(
body: bytes, uri: str, content_type: Optional[str] = None
) -> Optional["etree.Element"]:
"""
This uses lxml to parse the HTML document.
Args:
body: The HTML document, as bytes.
uri: The URI used to download the body.
content_type: The Content-Type header.
Returns:
The parsed HTML body, or None if an error occurred during processed.
"""
# If there's no body, nothing useful is going to be found.
if not body:
return None
# The idea here is that multiple encodings are tried until one works.
# Unfortunately the result is never used and then LXML will decode the string
# again with the found encoding.
for encoding in _get_html_media_encodings(body, content_type):
try:
body.decode(encoding)
except Exception:
pass
else:
break
else:
logger.warning("Unable to decode HTML body for %s", uri)
return None
from lxml import etree
# Create an HTML parser.
parser = etree.HTMLParser(recover=True, encoding=encoding)
# Attempt to parse the body. Returns None if the body was successfully
# parsed, but no tree was found.
return etree.fromstring(body, parser)
def _get_meta_tags(
tree: "etree.Element",
property: str,
prefix: str,
property_mapper: Optional[Callable[[str], Optional[str]]] = None,
) -> Dict[str, Optional[str]]:
"""
Search for meta tags prefixed with a particular string.
Args:
tree: The parsed HTML document.
property: The name of the property which contains the tag name, e.g.
"property" for Open Graph.
prefix: The prefix on the property to search for, e.g. "og" for Open Graph.
property_mapper: An optional callable to map the property to the Open Graph
form. Can return None for a key to ignore that key.
Returns:
A map of tag name to value.
"""
results: Dict[str, Optional[str]] = {}
for tag in tree.xpath(
f"//*/meta[starts-with(@{property}, '{prefix}:')][@content][not(@content='')]"
):
# if we've got more than 50 tags, someone is taking the piss
if len(results) >= 50:
logger.warning(
"Skipping parsing of Open Graph for page with too many '%s:' tags",
prefix,
)
return {}
key = tag.attrib[property]
if property_mapper:
key = property_mapper(key)
# None is a special value used to ignore a value.
if key is None:
continue
results[key] = tag.attrib["content"]
return results
def _map_twitter_to_open_graph(key: str) -> Optional[str]:
"""
Map a Twitter card property to the analogous Open Graph property.
Args:
key: The Twitter card property (starts with "twitter:").
Returns:
The Open Graph property (starts with "og:") or None to have this property
be ignored.
"""
# Twitter card properties with no analogous Open Graph property.
if key == "twitter:card" or key == "twitter:creator":
return None
if key == "twitter:site":
return "og:site_name"
# Otherwise, swap twitter to og.
return "og" + key[7:]
def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
"""
Parse the HTML document into an Open Graph response.
This uses lxml to search the HTML document for Open Graph data (or
synthesizes it from the document).
Args:
tree: The parsed HTML document.
Returns:
The Open Graph response as a dictionary.
"""
# Search for Open Graph (og:) meta tags, e.g.:
#
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = _get_meta_tags(tree, "property", "og")
# TODO: Search for properties specific to the different Open Graph types,
# such as article: meta tags, e.g.:
#
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
# Search for Twitter Card (twitter:) meta tags, e.g.:
#
# "twitter:site" : "@matrixdotorg"
# "twitter:creator" : "@matrixdotorg"
#
# Twitter cards tags also duplicate Open Graph tags.
#
# See https://developer.twitter.com/en/docs/twitter-for-websites/cards/guides/getting-started
twitter = _get_meta_tags(tree, "name", "twitter", _map_twitter_to_open_graph)
# Merge the Twitter values with the Open Graph values, but do not overwrite
# information from Open Graph tags.
for key, value in twitter.items():
if key not in og:
og[key] = value
if "og:title" not in og:
# Attempt to find a title from the title tag, or the biggest header on the page.
title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()")
if title:
og["og:title"] = title[0].strip()
else:
og["og:title"] = None
if "og:image" not in og:
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]"
)
# If a meta image is found, use it.
if meta_image:
og["og:image"] = meta_image[0]
else:
# Try to find images which are larger than 10px by 10px.
#
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(
images,
key=lambda i: (
-1 * float(i.attrib["width"]) * float(i.attrib["height"])
),
)
# If no images were found, try to find *any* images.
if not images:
images = tree.xpath("//img[@src][1]")
if images:
og["og:image"] = images[0].attrib["src"]
# Finally, fallback to the favicon if nothing else.
else:
favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]")
if favicons:
og["og:image"] = favicons[0]
if "og:description" not in og:
# Check the first meta description tag for content.
meta_description = tree.xpath(
"//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]"
)
# If a meta description is found with content, use it.
if meta_description:
og["og:description"] = meta_description[0]
else:
og["og:description"] = parse_html_description(tree)
elif og["og:description"]:
# This must be a non-empty string at this point.
assert isinstance(og["og:description"], str)
og["og:description"] = summarize_paragraphs([og["og:description"]])
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
return og
def parse_html_description(tree: "etree.Element") -> Optional[str]:
"""
Calculate a text description based on an HTML document.
Grabs any text nodes which are inside the <body/> tag, unless they are within
an HTML5 semantic markup tag (<header/>, <nav/>, <aside/>, <footer/>), or
if they are within a <script/>, <svg/> or <style/> tag, or if they are within
a tag whose content is usually only shown to old browsers
(<iframe/>, <video/>, <canvas/>, <picture/>).
This is a very very very coarse approximation to a plain text render of the page.
Args:
tree: The parsed HTML document.
Returns:
The plain text description, or None if one cannot be generated.
"""
# We don't just use XPATH here as that is slow on some machines.
from lxml import etree
TAGS_TO_REMOVE = {
"header",
"nav",
"aside",
"footer",
"script",
"noscript",
"style",
"svg",
"iframe",
"video",
"canvas",
"img",
"picture",
etree.Comment,
}
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
re.sub(r"\s+", "\n", el).strip()
for el in _iterate_over_text(tree.find("body"), TAGS_TO_REMOVE)
)
return summarize_paragraphs(text_nodes)
def _iterate_over_text(
tree: Optional["etree.Element"],
tags_to_ignore: Set[Union[str, "etree.Comment"]],
stack_limit: int = 1024,
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
Args:
tree: The parent element to iterate. Can be None if there isn't one.
tags_to_ignore: Set of tags to ignore
stack_limit: Maximum stack size limit for depth-first traversal.
Nodes will be dropped if this limit is hit, which may truncate the
textual result.
Intended to limit the maximum working memory when generating a preview.
"""
if tree is None:
return
# This is a stack whose items are elements to iterate over *or* strings
# to be returned.
elements: List[Union[str, "etree.Element"]] = [tree]
while elements:
el = elements.pop()
if isinstance(el, str):
yield el
elif el.tag not in tags_to_ignore:
# If the element isn't meant for display, ignore it.
if el.get("role") in ARIA_ROLES_TO_IGNORE:
continue
# el.text is the text before the first child, so we can immediately
# return it if the text exists.
if el.text:
yield el.text
# We add to the stack all the element's children, interspersed with
# each child's tail text (if it exists).
#
# We iterate in reverse order so that earlier pieces of text appear
# closer to the top of the stack.
for child in el.iterchildren(reversed=True):
if len(elements) > stack_limit:
# We've hit our limit for working memory
break
if child.tail:
# The tail text of a node is text that comes *after* the node,
# so we always include it even if we ignore the child node.
elements.append(child.tail)
elements.append(child)
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
"""
Try to get a summary respecting first paragraph and then word boundaries.
Args:
text_nodes: The paragraphs to summarize.
min_size: The minimum number of words to include.
max_size: The maximum number of words to include.
Returns:
A summary of the text nodes, or None if that was not possible.
"""
# TODO: Respect sentences?
description = ""
# Keep adding paragraphs until we get to the MIN_SIZE.
for text_node in text_nodes:
if len(description) < min_size:
text_node = re.sub(r"[\t \r\n]+", " ", text_node)
description += text_node + "\n\n"
else:
break
description = description.strip()
description = re.sub(r"[\t ]+", " ", description)
description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
# If the concatenation of paragraphs to get above MIN_SIZE
# took us over MAX_SIZE, then we need to truncate mid paragraph
if len(description) > max_size:
new_desc = ""
# This splits the paragraph into words, but keeping the
# (preceding) whitespace intact so we can easily concat
# words back together.
for match in re.finditer(r"\s*\S+", description):
word = match.group()
# Keep adding words while the total length is less than
# MAX_SIZE.
if len(word) + len(new_desc) < max_size:
new_desc += word
else:
# At this point the next word *will* take us over
# MAX_SIZE, but we also want to ensure that its not
# a huge word. If it is add it anyway and we'll
# truncate later.
if len(new_desc) < min_size:
new_desc += word
break
# Double check that we're not over the limit
if len(new_desc) > max_size:
new_desc = new_desc[:max_size]
# We always add an ellipsis because at the very least
# we chopped mid paragraph.
description = new_desc.strip() + ""
return description if description else None

View file

@ -1,4 +1,4 @@
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -11,171 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import abc
import logging
import os
import shutil
from typing import TYPE_CHECKING, Callable, Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
from synapse.util.async_helpers import maybe_awaitable
from ._base import FileInfo, Responder
from .media_storage import FileResponder
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from synapse.server import HomeServer
class StorageProvider(metaclass=abc.ABCMeta):
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
@abc.abstractmethod
async def store_file(self, path: str, file_info: FileInfo) -> None:
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
Args:
path: Relative path of file in local cache
file_info: The metadata of the file.
"""
@abc.abstractmethod
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
Args:
path: Relative path of file in local cache
file_info: The metadata of the file.
Returns:
Returns a Responder if the provider has the file, otherwise returns None.
"""
class StorageProviderWrapper(StorageProvider):
"""Wraps a storage provider and provides various config options
Args:
backend: The storage provider to wrap.
store_local: Whether to store new local files or not.
store_synchronous: Whether to wait for file to be successfully
uploaded, or todo the upload in the background.
store_remote: Whether remote media should be uploaded
"""
def __init__(
self,
backend: StorageProvider,
store_local: bool,
store_synchronous: bool,
store_remote: bool,
):
self.backend = backend
self.store_local = store_local
self.store_synchronous = store_synchronous
self.store_remote = store_remote
def __str__(self) -> str:
return "StorageProviderWrapper[%s]" % (self.backend,)
async def store_file(self, path: str, file_info: FileInfo) -> None:
if not file_info.server_name and not self.store_local:
return None
if file_info.server_name and not self.store_remote:
return None
if file_info.url_cache:
# The URL preview cache is short lived and not worth offloading or
# backing up.
return None
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
async def store() -> None:
try:
return await maybe_awaitable(
self.backend.store_file(path, file_info)
)
except Exception:
logger.exception("Error storing file")
run_in_background(store)
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
if file_info.url_cache:
# Files in the URL preview cache definitely aren't stored here,
# so avoid any potentially slow I/O or network access.
return None
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
return await maybe_awaitable(self.backend.fetch(path, file_info))
class FileStorageProviderBackend(StorageProvider):
"""A storage provider that stores files in a directory on a filesystem.
Args:
hs
config: The config returned by `parse_config`.
"""
def __init__(self, hs: "HomeServer", config: str):
self.hs = hs
self.cache_directory = hs.config.media.media_store_path
self.base_directory = config
def __str__(self) -> str:
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
async def store_file(self, path: str, file_info: FileInfo) -> None:
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
backup_fname = os.path.join(self.base_directory, path)
dirname = os.path.dirname(backup_fname)
os.makedirs(dirname, exist_ok=True)
# mypy needs help inferring the type of the second parameter, which is generic
shutil_copyfile: Callable[[str, str], str] = shutil.copyfile
await defer_to_thread(
self.hs.get_reactor(),
shutil_copyfile,
primary_fname,
backup_fname,
)
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname):
return FileResponder(open(backup_fname, "rb"))
return None
@staticmethod
def parse_config(config: dict) -> str:
"""Called on startup to parse config supplied. This should parse
the config and raise if there is a problem.
The returned value is passed into the constructor.
In this case we only care about a single param, the directory, so let's
just pull that out.
"""
return Config.ensure_directory(config["directory"])
# This exists purely for backwards compatibility with media providers.
from synapse.media.storage_provider import StorageProvider # noqa: F401

View file

@ -1,221 +0,0 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from io import BytesIO
from types import TracebackType
from typing import Optional, Tuple, Type
from PIL import Image
logger = logging.getLogger(__name__)
EXIF_ORIENTATION_TAG = 0x0112
EXIF_TRANSPOSE_MAPPINGS = {
2: Image.FLIP_LEFT_RIGHT,
3: Image.ROTATE_180,
4: Image.FLIP_TOP_BOTTOM,
5: Image.TRANSPOSE,
6: Image.ROTATE_270,
7: Image.TRANSVERSE,
8: Image.ROTATE_90,
}
class ThumbnailError(Exception):
"""An error occurred generating a thumbnail."""
class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
@staticmethod
def set_limits(max_image_pixels: int) -> None:
Image.MAX_IMAGE_PIXELS = max_image_pixels
def __init__(self, input_path: str):
# Have we closed the image?
self._closed = False
try:
self.image = Image.open(input_path)
except OSError as e:
# If an error occurs opening the image, a thumbnail won't be able to
# be generated.
raise ThumbnailError from e
except Image.DecompressionBombError as e:
# If an image decompression bomb error occurs opening the image,
# then the image exceeds the pixel limit and a thumbnail won't
# be able to be generated.
raise ThumbnailError from e
self.width, self.height = self.image.size
self.transpose_method = None
try:
# We don't use ImageOps.exif_transpose since it crashes with big EXIF
#
# Ignore safety: Pillow seems to acknowledge that this method is
# "private, experimental, but generally widely used". Pillow 6
# includes a public getexif() method (no underscore) that we might
# consider using instead when we can bump that dependency.
#
# At the time of writing, Debian buster (currently oldstable)
# provides version 5.4.1. It's expected to EOL in mid-2022, see
# https://wiki.debian.org/DebianReleases#Production_Releases
image_exif = self.image._getexif() # type: ignore
if image_exif is not None:
image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
assert type(image_orientation) is int
self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
except Exception as e:
# A lot of parsing errors can happen when parsing EXIF
logger.info("Error parsing image EXIF information: %s", e)
def transpose(self) -> Tuple[int, int]:
"""Transpose the image using its EXIF Orientation tag
Returns:
A tuple containing the new image size in pixels as (width, height).
"""
if self.transpose_method is not None:
# Safety: `transpose` takes an int rather than e.g. an IntEnum.
# self.transpose_method is set above to be a value in
# EXIF_TRANSPOSE_MAPPINGS, and that only contains correct values.
with self.image:
self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type]
self.width, self.height = self.image.size
self.transpose_method = None
# We don't need EXIF any more
self.image.info["exif"] = None
return self.image.size
def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
"""Calculate the largest size that preserves aspect ratio which
fits within the given rectangle::
(w_in / h_in) = (w_out / h_out)
w_out = max(min(w_max, h_max * (w_in / h_in)), 1)
h_out = max(min(h_max, w_max * (h_in / w_in)), 1)
Args:
max_width: The largest possible width.
max_height: The largest possible height.
"""
if max_width * self.height < max_height * self.width:
return max_width, max((max_width * self.height) // self.width, 1)
else:
return max((max_height * self.width) // self.height, 1), max_height
def _resize(self, width: int, height: int) -> Image.Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
# looks awful.
#
# If the image has transparency, use RGBA instead.
if self.image.mode in ["1", "L", "P"]:
if self.image.info.get("transparency", None) is not None:
with self.image:
self.image = self.image.convert("RGBA")
else:
with self.image:
self.image = self.image.convert("RGB")
return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales the image to the given dimensions.
Returns:
The bytes of the encoded image ready to be written to disk
"""
with self._resize(width, height) as scaled:
return self._encode_image(scaled, output_type)
def crop(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales and crops the image to the given dimensions preserving
aspect::
(w_in / h_in) = (w_scaled / h_scaled)
w_scaled = max(w_out, h_out * (w_in / h_in))
h_scaled = max(h_out, w_out * (h_in / w_in))
Args:
max_width: The largest possible width.
max_height: The largest possible height.
Returns:
The bytes of the encoded image ready to be written to disk
"""
if width * self.height > height * self.width:
scaled_width = width
scaled_height = (width * self.height) // self.width
crop_top = (scaled_height - height) // 2
crop_bottom = height + crop_top
crop = (0, crop_top, width, crop_bottom)
else:
scaled_width = (height * self.width) // self.height
scaled_height = height
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
crop = (crop_left, 0, crop_right, height)
with self._resize(scaled_width, scaled_height) as scaled_image:
with scaled_image.crop(crop) as cropped:
return self._encode_image(cropped, output_type)
def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO()
fmt = self.FORMATS[output_type]
if fmt == "JPEG":
output_image = output_image.convert("RGB")
output_image.save(output_bytes_io, fmt, quality=80)
return output_bytes_io
def close(self) -> None:
"""Closes the underlying image file.
Once closed no other functions can be called.
Can be called multiple times.
"""
if self._closed:
return
self._closed = True
# Since we run this on the finalizer then we need to handle `__init__`
# raising an exception before it can define `self.image`.
image = getattr(self, "image", None)
if image is None:
return
image.close()
def __enter__(self) -> "Thumbnailer":
"""Make `Thumbnailer` a context manager that calls `close` on
`__exit__`.
"""
return self
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.close()
def __del__(self) -> None:
# Make sure we actually do close the image, rather than leak data.
self.close()