Support MSC3916 by adding a federation /thumbnail endpoint and authenticated _matrix/client/v1/media/thumbnail endpoint (#17388)

[MSC3916](https://github.com/matrix-org/matrix-spec-proposals/pull/3916)
added the endpoints `_matrix/federation/v1/media/thumbnail` and the
authenticated `_matrix/client/v1/media/thumbnail`.

This PR implements those endpoints, along with stabilizing
`_matrix/client/v1/media/config` and
`_matrix/client/v1/media/preview_url`.

Complement tests are at
https://github.com/matrix-org/complement/pull/728
This commit is contained in:
Shay 2024-07-08 02:11:20 -07:00 committed by GitHub
parent 20de685a4b
commit cf69f8d59b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 585 additions and 131 deletions

View File

@ -0,0 +1,3 @@
Support [MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/rav/authentication-for-media/proposals/3916-authentication-for-media.md)
by adding `_matrix/client/v1/media/thumbnail`, `_matrix/federation/v1/media/thumbnail` endpoints and stabilizing the
remaining `_matrix/client/v1/media` endpoints.

View File

@ -437,10 +437,6 @@ class ExperimentalConfig(Config):
"msc3823_account_suspension", False "msc3823_account_suspension", False
) )
self.msc3916_authenticated_media_enabled = experimental.get(
"msc3916_authenticated_media_enabled", False
)
# MSC4151: Report room API (Client-Server API) # MSC4151: Report room API (Client-Server API)
self.msc4151_enabled: bool = experimental.get("msc4151_enabled", False) self.msc4151_enabled: bool = experimental.get("msc4151_enabled", False)

View File

@ -33,6 +33,7 @@ from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES, FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet, FederationAccountStatusServlet,
FederationMediaDownloadServlet, FederationMediaDownloadServlet,
FederationMediaThumbnailServlet,
FederationUnstableClientKeysClaimServlet, FederationUnstableClientKeysClaimServlet,
) )
from synapse.http.server import HttpServer, JsonResource from synapse.http.server import HttpServer, JsonResource
@ -316,7 +317,10 @@ def register_servlets(
): ):
continue continue
if servletclass == FederationMediaDownloadServlet: if (
servletclass == FederationMediaDownloadServlet
or servletclass == FederationMediaThumbnailServlet
):
if not hs.config.server.enable_media_repo: if not hs.config.server.enable_media_repo:
continue continue

View File

@ -363,6 +363,8 @@ class BaseFederationServlet:
if ( if (
func.__self__.__class__.__name__ # type: ignore func.__self__.__class__.__name__ # type: ignore
== "FederationMediaDownloadServlet" == "FederationMediaDownloadServlet"
or func.__self__.__class__.__name__ # type: ignore
== "FederationMediaThumbnailServlet"
): ):
response = await func( response = await func(
origin, content, request, *args, **kwargs origin, content, request, *args, **kwargs
@ -375,6 +377,8 @@ class BaseFederationServlet:
if ( if (
func.__self__.__class__.__name__ # type: ignore func.__self__.__class__.__name__ # type: ignore
== "FederationMediaDownloadServlet" == "FederationMediaDownloadServlet"
or func.__self__.__class__.__name__ # type: ignore
== "FederationMediaThumbnailServlet"
): ):
response = await func( response = await func(
origin, content, request, *args, **kwargs origin, content, request, *args, **kwargs

View File

@ -46,11 +46,13 @@ from synapse.http.servlet import (
parse_boolean_from_args, parse_boolean_from_args,
parse_integer, parse_integer,
parse_integer_from_args, parse_integer_from_args,
parse_string,
parse_string_from_args, parse_string_from_args,
parse_strings_from_args, parse_strings_from_args,
) )
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.media._base import DEFAULT_MAX_TIMEOUT_MS, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS from synapse.media._base import DEFAULT_MAX_TIMEOUT_MS, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS
from synapse.media.thumbnailer import ThumbnailProvider
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import SYNAPSE_VERSION from synapse.util import SYNAPSE_VERSION
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
@ -826,6 +828,59 @@ class FederationMediaDownloadServlet(BaseFederationServerServlet):
) )
class FederationMediaThumbnailServlet(BaseFederationServerServlet):
"""
Implementation of new federation media `/thumbnail` endpoint outlined in MSC3916. Returns
a multipart/mixed response consisting of a JSON object and the requested media
item. This endpoint only returns local media.
"""
PATH = "/media/thumbnail/(?P<media_id>[^/]*)"
RATELIMIT = True
def __init__(
self,
hs: "HomeServer",
ratelimiter: FederationRateLimiter,
authenticator: Authenticator,
server_name: str,
):
super().__init__(hs, authenticator, ratelimiter, server_name)
self.media_repo = self.hs.get_media_repository()
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self.thumbnail_provider = ThumbnailProvider(
hs, self.media_repo, self.media_repo.media_storage
)
async def on_GET(
self,
origin: Optional[str],
content: Literal[None],
request: SynapseRequest,
media_id: str,
) -> None:
width = parse_integer(request, "width", required=True)
height = parse_integer(request, "height", required=True)
method = parse_string(request, "method", "scale")
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
m_type = "image/png"
max_timeout_ms = parse_integer(
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
)
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
if self.dynamic_thumbnails:
await self.thumbnail_provider.select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms, True
)
else:
await self.thumbnail_provider.respond_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms, True
)
self.media_repo.mark_recently_accessed(None, media_id)
FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = ( FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationSendServlet, FederationSendServlet,
FederationEventServlet, FederationEventServlet,
@ -858,4 +913,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationMakeKnockServlet, FederationMakeKnockServlet,
FederationAccountStatusServlet, FederationAccountStatusServlet,
FederationMediaDownloadServlet, FederationMediaDownloadServlet,
FederationMediaThumbnailServlet,
) )

View File

@ -542,7 +542,12 @@ class MediaRepository:
respond_404(request) respond_404(request)
async def get_remote_media_info( async def get_remote_media_info(
self, server_name: str, media_id: str, max_timeout_ms: int, ip_address: str self,
server_name: str,
media_id: str,
max_timeout_ms: int,
ip_address: str,
use_federation: bool,
) -> RemoteMedia: ) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading """Gets the media info associated with the remote file, downloading
if necessary. if necessary.
@ -553,6 +558,8 @@ class MediaRepository:
max_timeout_ms: the maximum number of milliseconds to wait for the max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded. media to be uploaded.
ip_address: IP address of the requester ip_address: IP address of the requester
use_federation: if a download is necessary, whether to request the remote file
over the federation `/download` endpoint
Returns: Returns:
The media info of the file The media info of the file
@ -573,7 +580,7 @@ class MediaRepository:
max_timeout_ms, max_timeout_ms,
self.download_ratelimiter, self.download_ratelimiter,
ip_address, ip_address,
False, use_federation,
) )
# Ensure we actually use the responder so that it releases resources # Ensure we actually use the responder so that it releases resources

View File

@ -36,9 +36,11 @@ from synapse.media._base import (
ThumbnailInfo, ThumbnailInfo,
respond_404, respond_404,
respond_with_file, respond_with_file,
respond_with_multipart_responder,
respond_with_responder, respond_with_responder,
) )
from synapse.media.media_storage import MediaStorage from synapse.media.media_storage import FileResponder, MediaStorage
from synapse.storage.databases.main.media_repository import LocalMedia
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.media.media_repository import MediaRepository from synapse.media.media_repository import MediaRepository
@ -271,6 +273,7 @@ class ThumbnailProvider:
method: str, method: str,
m_type: str, m_type: str,
max_timeout_ms: int, max_timeout_ms: int,
for_federation: bool,
) -> None: ) -> None:
media_info = await self.media_repo.get_local_media_info( media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms request, media_id, max_timeout_ms
@ -290,6 +293,8 @@ class ThumbnailProvider:
media_id, media_id,
url_cache=bool(media_info.url_cache), url_cache=bool(media_info.url_cache),
server_name=None, server_name=None,
for_federation=for_federation,
media_info=media_info,
) )
async def select_or_generate_local_thumbnail( async def select_or_generate_local_thumbnail(
@ -301,6 +306,7 @@ class ThumbnailProvider:
desired_method: str, desired_method: str,
desired_type: str, desired_type: str,
max_timeout_ms: int, max_timeout_ms: int,
for_federation: bool,
) -> None: ) -> None:
media_info = await self.media_repo.get_local_media_info( media_info = await self.media_repo.get_local_media_info(
request, media_id, max_timeout_ms request, media_id, max_timeout_ms
@ -326,10 +332,16 @@ class ThumbnailProvider:
responder = await self.media_storage.fetch_media(file_info) responder = await self.media_storage.fetch_media(file_info)
if responder: if responder:
await respond_with_responder( if for_federation:
request, responder, info.type, info.length await respond_with_multipart_responder(
) self.hs.get_clock(), request, responder, media_info
return )
return
else:
await respond_with_responder(
request, responder, info.type, info.length
)
return
logger.debug("We don't have a thumbnail of that size. Generating") logger.debug("We don't have a thumbnail of that size. Generating")
@ -344,7 +356,15 @@ class ThumbnailProvider:
) )
if file_path: if file_path:
await respond_with_file(request, desired_type, file_path) if for_federation:
await respond_with_multipart_responder(
self.hs.get_clock(),
request,
FileResponder(open(file_path, "rb")),
media_info,
)
else:
await respond_with_file(request, desired_type, file_path)
else: else:
logger.warning("Failed to generate thumbnail") logger.warning("Failed to generate thumbnail")
raise SynapseError(400, "Failed to generate thumbnail.") raise SynapseError(400, "Failed to generate thumbnail.")
@ -360,9 +380,10 @@ class ThumbnailProvider:
desired_type: str, desired_type: str,
max_timeout_ms: int, max_timeout_ms: int,
ip_address: str, ip_address: str,
use_federation: bool,
) -> None: ) -> None:
media_info = await self.media_repo.get_remote_media_info( media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms, ip_address server_name, media_id, max_timeout_ms, ip_address, use_federation
) )
if not media_info: if not media_info:
respond_404(request) respond_404(request)
@ -424,12 +445,13 @@ class ThumbnailProvider:
m_type: str, m_type: str,
max_timeout_ms: int, max_timeout_ms: int,
ip_address: str, ip_address: str,
use_federation: bool,
) -> None: ) -> None:
# TODO: Don't download the whole remote file # TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of # We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails. # downloading the remote file and generating our own thumbnails.
media_info = await self.media_repo.get_remote_media_info( media_info = await self.media_repo.get_remote_media_info(
server_name, media_id, max_timeout_ms, ip_address server_name, media_id, max_timeout_ms, ip_address, use_federation
) )
if not media_info: if not media_info:
return return
@ -448,6 +470,7 @@ class ThumbnailProvider:
media_info.filesystem_id, media_info.filesystem_id,
url_cache=False, url_cache=False,
server_name=server_name, server_name=server_name,
for_federation=False,
) )
async def _select_and_respond_with_thumbnail( async def _select_and_respond_with_thumbnail(
@ -461,7 +484,9 @@ class ThumbnailProvider:
media_id: str, media_id: str,
file_id: str, file_id: str,
url_cache: bool, url_cache: bool,
for_federation: bool,
server_name: Optional[str] = None, server_name: Optional[str] = None,
media_info: Optional[LocalMedia] = None,
) -> None: ) -> None:
""" """
Respond to a request with an appropriate thumbnail from the previously generated thumbnails. Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
@ -476,6 +501,8 @@ class ThumbnailProvider:
file_id: The ID of the media that a thumbnail is being requested for. file_id: The ID of the media that a thumbnail is being requested for.
url_cache: True if this is from a URL cache. url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail. server_name: The server name, if this is a remote thumbnail.
for_federation: whether the request is from the federation /thumbnail request
media_info: metadata about the media being requested.
""" """
logger.debug( logger.debug(
"_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s", "_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
@ -511,13 +538,20 @@ class ThumbnailProvider:
responder = await self.media_storage.fetch_media(file_info) responder = await self.media_storage.fetch_media(file_info)
if responder: if responder:
await respond_with_responder( if for_federation:
request, assert media_info is not None
responder, await respond_with_multipart_responder(
file_info.thumbnail.type, self.hs.get_clock(), request, responder, media_info
file_info.thumbnail.length, )
) return
return else:
await respond_with_responder(
request,
responder,
file_info.thumbnail.type,
file_info.thumbnail.length,
)
return
# If we can't find the thumbnail we regenerate it. This can happen # If we can't find the thumbnail we regenerate it. This can happen
# if e.g. we've deleted the thumbnails but still have the original # if e.g. we've deleted the thumbnails but still have the original
@ -558,12 +592,18 @@ class ThumbnailProvider:
) )
responder = await self.media_storage.fetch_media(file_info) responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder( if for_federation:
request, assert media_info is not None
responder, await respond_with_multipart_responder(
file_info.thumbnail.type, self.hs.get_clock(), request, responder, media_info
file_info.thumbnail.length, )
) else:
await respond_with_responder(
request,
responder,
file_info.thumbnail.type,
file_info.thumbnail.length,
)
else: else:
# This might be because: # This might be because:
# 1. We can't create thumbnails for the given media (corrupted or # 1. We can't create thumbnails for the given media (corrupted or

View File

@ -47,7 +47,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UnstablePreviewURLServlet(RestServlet): class PreviewURLServlet(RestServlet):
""" """
Same as `GET /_matrix/media/r0/preview_url`, this endpoint provides a generic preview API Same as `GET /_matrix/media/r0/preview_url`, this endpoint provides a generic preview API
for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
@ -65,9 +65,7 @@ class UnstablePreviewURLServlet(RestServlet):
* Matrix cannot be used to distribute the metadata between homeservers. * Matrix cannot be used to distribute the metadata between homeservers.
""" """
PATTERNS = [ PATTERNS = [re.compile(r"^/_matrix/client/v1/media/preview_url$")]
re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/preview_url$")
]
def __init__( def __init__(
self, self,
@ -95,10 +93,8 @@ class UnstablePreviewURLServlet(RestServlet):
respond_with_json_bytes(request, 200, og, send_cors=True) respond_with_json_bytes(request, 200, og, send_cors=True)
class UnstableMediaConfigResource(RestServlet): class MediaConfigResource(RestServlet):
PATTERNS = [ PATTERNS = [re.compile(r"^/_matrix/client/v1/media/config$")]
re.compile(r"^/_matrix/client/unstable/org.matrix.msc3916/media/config$")
]
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -112,10 +108,10 @@ class UnstableMediaConfigResource(RestServlet):
respond_with_json(request, 200, self.limits_dict, send_cors=True) respond_with_json(request, 200, self.limits_dict, send_cors=True)
class UnstableThumbnailResource(RestServlet): class ThumbnailResource(RestServlet):
PATTERNS = [ PATTERNS = [
re.compile( re.compile(
"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$" "/_matrix/client/v1/media/thumbnail/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
) )
] ]
@ -159,11 +155,25 @@ class UnstableThumbnailResource(RestServlet):
if self._is_mine_server_name(server_name): if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails: if self.dynamic_thumbnails:
await self.thumbnailer.select_or_generate_local_thumbnail( await self.thumbnailer.select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms request,
media_id,
width,
height,
method,
m_type,
max_timeout_ms,
False,
) )
else: else:
await self.thumbnailer.respond_local_thumbnail( await self.thumbnailer.respond_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms request,
media_id,
width,
height,
method,
m_type,
max_timeout_ms,
False,
) )
self.media_repo.mark_recently_accessed(None, media_id) self.media_repo.mark_recently_accessed(None, media_id)
else: else:
@ -191,6 +201,7 @@ class UnstableThumbnailResource(RestServlet):
m_type, m_type,
max_timeout_ms, max_timeout_ms,
ip_address, ip_address,
True,
) )
self.media_repo.mark_recently_accessed(server_name, media_id) self.media_repo.mark_recently_accessed(server_name, media_id)
@ -260,11 +271,9 @@ class DownloadResource(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
media_repo = hs.get_media_repository() media_repo = hs.get_media_repository()
if hs.config.media.url_preview_enabled: if hs.config.media.url_preview_enabled:
UnstablePreviewURLServlet(hs, media_repo, media_repo.media_storage).register( PreviewURLServlet(hs, media_repo, media_repo.media_storage).register(
http_server http_server
) )
UnstableMediaConfigResource(hs).register(http_server) MediaConfigResource(hs).register(http_server)
UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register( ThumbnailResource(hs, media_repo, media_repo.media_storage).register(http_server)
http_server
)
DownloadResource(hs, media_repo).register(http_server) DownloadResource(hs, media_repo).register(http_server)

View File

@ -88,11 +88,25 @@ class ThumbnailResource(RestServlet):
if self._is_mine_server_name(server_name): if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails: if self.dynamic_thumbnails:
await self.thumbnail_provider.select_or_generate_local_thumbnail( await self.thumbnail_provider.select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms request,
media_id,
width,
height,
method,
m_type,
max_timeout_ms,
False,
) )
else: else:
await self.thumbnail_provider.respond_local_thumbnail( await self.thumbnail_provider.respond_local_thumbnail(
request, media_id, width, height, method, m_type, max_timeout_ms request,
media_id,
width,
height,
method,
m_type,
max_timeout_ms,
False,
) )
self.media_repo.mark_recently_accessed(None, media_id) self.media_repo.mark_recently_accessed(None, media_id)
else: else:
@ -120,5 +134,6 @@ class ThumbnailResource(RestServlet):
m_type, m_type,
max_timeout_ms, max_timeout_ms,
ip_address, ip_address,
False,
) )
self.media_repo.mark_recently_accessed(server_name, media_id) self.media_repo.mark_recently_accessed(server_name, media_id)

View File

@ -35,6 +35,7 @@ from synapse.types import UserID
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.media.test_media_storage import small_png
from tests.test_utils import SMALL_PNG from tests.test_utils import SMALL_PNG
@ -146,3 +147,112 @@ class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
# check that the png file exists and matches what was uploaded # check that the png file exists and matches what was uploaded
found_file = any(SMALL_PNG in field for field in stripped_bytes) found_file = any(SMALL_PNG in field for field in stripped_bytes)
self.assertTrue(found_file) self.assertTrue(found_file)
class FederationThumbnailTest(unittest.FederatingHomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
self.addCleanup(shutil.rmtree, self.test_dir)
self.primary_base_path = os.path.join(self.test_dir, "primary")
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
hs.config.media.media_store_path = self.primary_base_path
storage_providers = [
StorageProviderWrapper(
FileStorageProviderBackend(hs, self.secondary_base_path),
store_local=True,
store_remote=False,
store_synchronous=True,
)
]
self.filepaths = MediaFilePaths(self.primary_base_path)
self.media_storage = MediaStorage(
hs, self.primary_base_path, self.filepaths, storage_providers
)
self.media_repo = hs.get_media_repository()
def test_thumbnail_download_scaled(self) -> None:
content = io.BytesIO(small_png.data)
content_uri = self.get_success(
self.media_repo.create_content(
"image/png",
"test_png_thumbnail",
content,
67,
UserID.from_string("@user_id:whatever.org"),
)
)
# test with an image file
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/v1/media/thumbnail/{content_uri.media_id}?width=32&height=32&method=scale",
)
self.pump()
self.assertEqual(200, channel.code)
content_type = channel.headers.getRawHeaders("content-type")
assert content_type is not None
assert "multipart/mixed" in content_type[0]
assert "boundary" in content_type[0]
# extract boundary
boundary = content_type[0].split("boundary=")[1]
# split on boundary and check that json field and expected value exist
body = channel.result.get("body")
assert body is not None
stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8"))
found_json = any(
b"\r\nContent-Type: application/json\r\n\r\n{}" in field
for field in stripped_bytes
)
self.assertTrue(found_json)
# check that the png file exists and matches the expected scaled bytes
found_file = any(small_png.expected_scaled in field for field in stripped_bytes)
self.assertTrue(found_file)
def test_thumbnail_download_cropped(self) -> None:
content = io.BytesIO(small_png.data)
content_uri = self.get_success(
self.media_repo.create_content(
"image/png",
"test_png_thumbnail",
content,
67,
UserID.from_string("@user_id:whatever.org"),
)
)
# test with an image file
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/v1/media/thumbnail/{content_uri.media_id}?width=32&height=32&method=crop",
)
self.pump()
self.assertEqual(200, channel.code)
content_type = channel.headers.getRawHeaders("content-type")
assert content_type is not None
assert "multipart/mixed" in content_type[0]
assert "boundary" in content_type[0]
# extract boundary
boundary = content_type[0].split("boundary=")[1]
# split on boundary and check that json field and expected value exist
body = channel.result.get("body")
assert body is not None
stripped_bytes = body.split(b"\r\n" + b"--" + boundary.encode("utf-8"))
found_json = any(
b"\r\nContent-Type: application/json\r\n\r\n{}" in field
for field in stripped_bytes
)
self.assertTrue(found_json)
# check that the png file exists and matches the expected cropped bytes
found_file = any(
small_png.expected_cropped in field for field in stripped_bytes
)
self.assertTrue(found_file)

View File

@ -18,7 +18,6 @@
# [This file includes modifications made by New Vector Limited] # [This file includes modifications made by New Vector Limited]
# #
# #
import itertools
import os import os
import shutil import shutil
import tempfile import tempfile
@ -227,19 +226,15 @@ test_images = [
empty_file, empty_file,
SVG, SVG,
] ]
urls = [ input_values = [(x,) for x in test_images]
"_matrix/media/r0/thumbnail",
"_matrix/client/unstable/org.matrix.msc3916/media/thumbnail",
]
@parameterized_class(("test_image", "url"), itertools.product(test_images, urls)) @parameterized_class(("test_image",), input_values)
class MediaRepoTests(unittest.HomeserverTestCase): class MediaRepoTests(unittest.HomeserverTestCase):
servlets = [media.register_servlets] servlets = [media.register_servlets]
test_image: ClassVar[TestImage] test_image: ClassVar[TestImage]
hijack_auth = True hijack_auth = True
user_id = "@test:user" user_id = "@test:user"
url: ClassVar[str]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.fetches: List[ self.fetches: List[
@ -304,7 +299,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"config": {"directory": self.storage_path}, "config": {"directory": self.storage_path},
} }
config["media_storage_providers"] = [provider_config] config["media_storage_providers"] = [provider_config]
config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
hs = self.setup_test_homeserver(config=config, federation_http_client=client) hs = self.setup_test_homeserver(config=config, federation_http_client=client)
@ -509,7 +503,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=scale" params = "?width=32&height=32&method=scale"
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"/{self.url}/{self.media_id}{params}", f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -537,7 +531,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"/{self.url}/{self.media_id}{params}", f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -573,7 +567,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
params = "?width=32&height=32&method=" + method params = "?width=32&height=32&method=" + method
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"/{self.url}/{self.media_id}{params}", f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -608,7 +602,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body, channel.json_body,
{ {
"errcode": "M_UNKNOWN", "errcode": "M_UNKNOWN",
"error": f"Cannot find any thumbnails for the requested media ('/{self.url}/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)", "error": "Cannot find any thumbnails for the requested media ('/_matrix/media/r0/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
}, },
) )
else: else:
@ -618,7 +612,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body, channel.json_body,
{ {
"errcode": "M_NOT_FOUND", "errcode": "M_NOT_FOUND",
"error": f"Not found '/{self.url}/example.com/12345'", "error": "Not found '/_matrix/media/r0/thumbnail/example.com/12345'",
}, },
) )

View File

@ -23,12 +23,15 @@ import io
import json import json
import os import os
import re import re
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type import shutil
from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Tuple, Type
from unittest.mock import MagicMock, Mock, patch from unittest.mock import MagicMock, Mock, patch
from urllib import parse from urllib import parse
from urllib.parse import quote, urlencode from urllib.parse import quote, urlencode
from parameterized import parameterized_class from parameterized import parameterized, parameterized_class
from PIL import Image as Image
from typing_extensions import ClassVar
from twisted.internet import defer from twisted.internet import defer
from twisted.internet._resolver import HostResolution from twisted.internet._resolver import HostResolution
@ -40,7 +43,6 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import UNKNOWN_LENGTH, IResponse from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
from twisted.web.resource import Resource
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
@ -48,7 +50,8 @@ from synapse.config.oembed import OEmbedEndpointConfig
from synapse.http.client import MultipartResponse from synapse.http.client import MultipartResponse
from synapse.http.types import QueryParams from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.media._base import FileInfo from synapse.media._base import FileInfo, ThumbnailInfo
from synapse.media.thumbnailer import ThumbnailProvider
from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, media from synapse.rest.client import login, media
@ -76,7 +79,7 @@ except ImportError:
lxml = None # type: ignore[assignment] lxml = None # type: ignore[assignment]
class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase): class MediaDomainBlockingTests(unittest.HomeserverTestCase):
remote_media_id = "doesnotmatter" remote_media_id = "doesnotmatter"
remote_server_name = "evil.com" remote_server_name = "evil.com"
servlets = [ servlets = [
@ -144,7 +147,6 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
# Should result in a 404. # Should result in a 404.
"prevent_media_downloads_from": ["evil.com"], "prevent_media_downloads_from": ["evil.com"],
"dynamic_thumbnails": True, "dynamic_thumbnails": True,
"experimental_features": {"msc3916_authenticated_media_enabled": True},
} }
) )
def test_cannot_download_blocked_media_thumbnail(self) -> None: def test_cannot_download_blocked_media_thumbnail(self) -> None:
@ -153,7 +155,7 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
""" """
response = self.make_request( response = self.make_request(
"GET", "GET",
f"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100", f"/_matrix/client/v1/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100",
shorthand=False, shorthand=False,
content={"width": 100, "height": 100}, content={"width": 100, "height": 100},
access_token=self.tok, access_token=self.tok,
@ -166,7 +168,6 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
# This proves we haven't broken anything. # This proves we haven't broken anything.
"prevent_media_downloads_from": ["not-listed.com"], "prevent_media_downloads_from": ["not-listed.com"],
"dynamic_thumbnails": True, "dynamic_thumbnails": True,
"experimental_features": {"msc3916_authenticated_media_enabled": True},
} }
) )
def test_remote_media_thumbnail_normally_unblocked(self) -> None: def test_remote_media_thumbnail_normally_unblocked(self) -> None:
@ -175,14 +176,14 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
""" """
response = self.make_request( response = self.make_request(
"GET", "GET",
f"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100", f"/_matrix/client/v1/media/thumbnail/evil.com/{self.remote_media_id}?width=100&height=100",
shorthand=False, shorthand=False,
access_token=self.tok, access_token=self.tok,
) )
self.assertEqual(response.code, 200) self.assertEqual(response.code, 200)
class UnstableURLPreviewTests(unittest.HomeserverTestCase): class URLPreviewTests(unittest.HomeserverTestCase):
if not lxml: if not lxml:
skip = "url preview feature requires lxml" skip = "url preview feature requires lxml"
@ -198,7 +199,6 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config() config = self.default_config()
config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
config["url_preview_enabled"] = True config["url_preview_enabled"] = True
config["max_spider_size"] = 9999999 config["max_spider_size"] = 9999999
config["url_preview_ip_range_blacklist"] = ( config["url_preview_ip_range_blacklist"] = (
@ -284,18 +284,6 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
self.reactor.nameResolver = Resolver() # type: ignore[assignment] self.reactor.nameResolver = Resolver() # type: ignore[assignment]
def create_resource_dict(self) -> Dict[str, Resource]:
"""Create a resource tree for the test server
A resource tree is a mapping from path to twisted.web.resource.
The default implementation creates a JsonResource and calls each function in
`servlets` to register servlets against it.
"""
resources = super().create_resource_dict()
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
return resources
def _assert_small_png(self, json_body: JsonDict) -> None: def _assert_small_png(self, json_body: JsonDict) -> None:
"""Assert properties from the SMALL_PNG test image.""" """Assert properties from the SMALL_PNG test image."""
self.assertTrue(json_body["og:image"].startswith("mxc://")) self.assertTrue(json_body["og:image"].startswith("mxc://"))
@ -309,7 +297,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -334,7 +322,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
# Check the cache returns the correct response # Check the cache returns the correct response
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
) )
@ -352,7 +340,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
# Check the database cache returns the correct response # Check the database cache returns the correct response
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
) )
@ -375,7 +363,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -405,7 +393,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -441,7 +429,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -482,7 +470,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -517,7 +505,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -550,7 +538,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com", "/_matrix/client/v1/media/preview_url?url=http://example.com",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -580,7 +568,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com", "/_matrix/client/v1/media/preview_url?url=http://example.com",
shorthand=False, shorthand=False,
) )
@ -603,7 +591,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com", "/_matrix/client/v1/media/preview_url?url=http://example.com",
shorthand=False, shorthand=False,
) )
@ -622,7 +610,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
""" """
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://192.168.1.1", "/_matrix/client/v1/media/preview_url?url=http://192.168.1.1",
shorthand=False, shorthand=False,
) )
@ -640,7 +628,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
""" """
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://1.1.1.2", "/_matrix/client/v1/media/preview_url?url=http://1.1.1.2",
shorthand=False, shorthand=False,
) )
@ -659,7 +647,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com", "/_matrix/client/v1/media/preview_url?url=http://example.com",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -696,7 +684,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com", "/_matrix/client/v1/media/preview_url?url=http://example.com",
shorthand=False, shorthand=False,
) )
self.assertEqual(channel.code, 502) self.assertEqual(channel.code, 502)
@ -718,7 +706,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com", "/_matrix/client/v1/media/preview_url?url=http://example.com",
shorthand=False, shorthand=False,
) )
@ -741,7 +729,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com", "/_matrix/client/v1/media/preview_url?url=http://example.com",
shorthand=False, shorthand=False,
) )
@ -760,7 +748,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
""" """
channel = self.make_request( channel = self.make_request(
"OPTIONS", "OPTIONS",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com", "/_matrix/client/v1/media/preview_url?url=http://example.com",
shorthand=False, shorthand=False,
) )
self.assertEqual(channel.code, 204) self.assertEqual(channel.code, 204)
@ -774,7 +762,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
# Build and make a request to the server # Build and make a request to the server
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://example.com", "/_matrix/client/v1/media/preview_url?url=http://example.com",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -827,7 +815,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -877,7 +865,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -919,7 +907,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -959,7 +947,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1000,7 +988,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?{query_params}", f"/_matrix/client/v1/media/preview_url?{query_params}",
shorthand=False, shorthand=False,
) )
self.pump() self.pump()
@ -1021,7 +1009,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://matrix.org", "/_matrix/client/v1/media/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1058,7 +1046,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345", "/_matrix/client/v1/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1118,7 +1106,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345", "/_matrix/client/v1/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1167,7 +1155,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://www.hulu.com/watch/12345", "/_matrix/client/v1/media/preview_url?url=http://www.hulu.com/watch/12345",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1212,7 +1200,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345", "/_matrix/client/v1/media/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1241,7 +1229,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", "/_matrix/client/v1/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1333,7 +1321,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", "/_matrix/client/v1/media/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1374,7 +1362,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=http://cdn.twitter.com/matrixdotorg", "/_matrix/client/v1/media/preview_url?url=http://cdn.twitter.com/matrixdotorg",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1416,7 +1404,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
# Check fetching # Check fetching
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"/_matrix/media/v3/download/{host}/{media_id}", f"/_matrix/client/v1/media/download/{host}/{media_id}",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1429,7 +1417,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"/_matrix/media/v3/download/{host}/{media_id}", f"/_matrix/client/v1/download/{host}/{media_id}",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1464,7 +1452,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
# Check fetching # Check fetching
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale", f"/_matrix/client/v1/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1482,7 +1470,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"/_matrix/client/unstable/org.matrix.msc3916/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale", f"/_matrix/client/v1/media/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1532,8 +1520,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=" "/_matrix/client/v1/media/preview_url?url=" + bad_url,
+ bad_url,
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1542,8 +1529,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=" "/_matrix/client/v1/media/preview_url?url=" + good_url,
+ good_url,
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1575,8 +1561,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/preview_url?url=" "/_matrix/client/v1/media/preview_url?url=" + bad_url,
+ bad_url,
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1584,7 +1569,7 @@ class UnstableURLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 403, channel.result) self.assertEqual(channel.code, 403, channel.result)
class UnstableMediaConfigTest(unittest.HomeserverTestCase): class MediaConfigTest(unittest.HomeserverTestCase):
servlets = [ servlets = [
media.register_servlets, media.register_servlets,
admin.register_servlets, admin.register_servlets,
@ -1595,7 +1580,6 @@ class UnstableMediaConfigTest(unittest.HomeserverTestCase):
self, reactor: ThreadedMemoryReactorClock, clock: Clock self, reactor: ThreadedMemoryReactorClock, clock: Clock
) -> HomeServer: ) -> HomeServer:
config = self.default_config() config = self.default_config()
config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
self.storage_path = self.mktemp() self.storage_path = self.mktemp()
self.media_store_path = self.mktemp() self.media_store_path = self.mktemp()
@ -1622,7 +1606,7 @@ class UnstableMediaConfigTest(unittest.HomeserverTestCase):
def test_media_config(self) -> None: def test_media_config(self) -> None:
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"/_matrix/client/unstable/org.matrix.msc3916/media/config", "/_matrix/client/v1/media/config",
shorthand=False, shorthand=False,
access_token=self.tok, access_token=self.tok,
) )
@ -1899,7 +1883,7 @@ input_values = [(x,) for x in test_images]
@parameterized_class(("test_image",), input_values) @parameterized_class(("test_image",), input_values)
class DownloadTestCase(unittest.HomeserverTestCase): class DownloadAndThumbnailTestCase(unittest.HomeserverTestCase):
test_image: ClassVar[TestImage] test_image: ClassVar[TestImage]
servlets = [ servlets = [
media.register_servlets, media.register_servlets,
@ -2005,7 +1989,6 @@ class DownloadTestCase(unittest.HomeserverTestCase):
"config": {"directory": self.storage_path}, "config": {"directory": self.storage_path},
} }
config["media_storage_providers"] = [provider_config] config["media_storage_providers"] = [provider_config]
config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
hs = self.setup_test_homeserver(config=config, federation_http_client=client) hs = self.setup_test_homeserver(config=config, federation_http_client=client)
@ -2164,7 +2147,7 @@ class DownloadTestCase(unittest.HomeserverTestCase):
def test_unknown_federation_endpoint(self) -> None: def test_unknown_federation_endpoint(self) -> None:
""" """
Test that if the downloadd request to remote federation endpoint returns a 404 Test that if the download request to remote federation endpoint returns a 404
we fall back to the _matrix/media endpoint we fall back to the _matrix/media endpoint
""" """
channel = self.make_request( channel = self.make_request(
@ -2210,3 +2193,236 @@ class DownloadTestCase(unittest.HomeserverTestCase):
self.pump() self.pump()
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 200)
def test_thumbnail_crop(self) -> None:
"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop",
self.test_image.expected_cropped,
expected_found=self.test_image.expected_found,
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
def test_thumbnail_scale(self) -> None:
"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale",
self.test_image.expected_scaled,
expected_found=self.test_image.expected_found,
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
def test_invalid_type(self) -> None:
"""An invalid thumbnail type is never available."""
self._test_thumbnail(
"invalid",
None,
expected_found=False,
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
)
def test_no_thumbnail_crop(self) -> None:
"""
Override the config to generate only scaled thumbnails, but request a cropped one.
"""
self._test_thumbnail(
"crop",
None,
expected_found=False,
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
)
def test_no_thumbnail_scale(self) -> None:
"""
Override the config to generate only cropped thumbnails, but request a scaled one.
"""
self._test_thumbnail(
"scale",
None,
expected_found=False,
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
def test_thumbnail_repeated_thumbnail(self) -> None:
"""Test that fetching the same thumbnail works, and deleting the on disk
thumbnail regenerates it.
"""
self._test_thumbnail(
"scale",
self.test_image.expected_scaled,
expected_found=self.test_image.expected_found,
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
if not self.test_image.expected_found:
return
# Fetching again should work, without re-requesting the image from the
# remote.
params = "?width=32&height=32&method=scale"
channel = self.make_request(
"GET",
f"/_matrix/client/v1/media/thumbnail/{self.remote}/{self.media_id}{params}",
shorthand=False,
await_result=False,
access_token=self.tok,
)
self.pump()
self.assertEqual(channel.code, 200)
if self.test_image.expected_scaled:
self.assertEqual(
channel.result["body"],
self.test_image.expected_scaled,
channel.result["body"],
)
# Deleting the thumbnail on disk then re-requesting it should work as
# Synapse should regenerate missing thumbnails.
info = self.get_success(
self.store.get_cached_remote_media(self.remote, self.media_id)
)
assert info is not None
file_id = info.filesystem_id
thumbnail_dir = self.media_repo.filepaths.remote_media_thumbnail_dir(
self.remote, file_id
)
shutil.rmtree(thumbnail_dir, ignore_errors=True)
channel = self.make_request(
"GET",
f"/_matrix/client/v1/media/thumbnail/{self.remote}/{self.media_id}{params}",
shorthand=False,
await_result=False,
access_token=self.tok,
)
self.pump()
self.assertEqual(channel.code, 200)
if self.test_image.expected_scaled:
self.assertEqual(
channel.result["body"],
self.test_image.expected_scaled,
channel.result["body"],
)
def _test_thumbnail(
self,
method: str,
expected_body: Optional[bytes],
expected_found: bool,
unable_to_thumbnail: bool = False,
) -> None:
"""Test the given thumbnailing method works as expected.
Args:
method: The thumbnailing method to use (crop, scale).
expected_body: The expected bytes from thumbnailing, or None if
test should just check for a valid image.
expected_found: True if the file should exist on the server, or False if
a 404/400 is expected.
unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
False if the thumbnailing should succeed or a normal 404 is expected.
"""
params = "?width=32&height=32&method=" + method
channel = self.make_request(
"GET",
f"/_matrix/client/v1/media/thumbnail/{self.remote}/{self.media_id}{params}",
shorthand=False,
await_result=False,
access_token=self.tok,
)
self.pump()
headers = {
b"Content-Length": [b"%d" % (len(self.test_image.data))],
b"Content-Type": [self.test_image.content_type],
}
self.fetches[0][0].callback(
(self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
if expected_found:
self.assertEqual(channel.code, 200)
self.assertEqual(
channel.headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
[b"cross-origin"],
)
if expected_body is not None:
self.assertEqual(
channel.result["body"], expected_body, channel.result["body"]
)
else:
# ensure that the result is at least some valid image
Image.open(io.BytesIO(channel.result["body"]))
elif unable_to_thumbnail:
# A 400 with a JSON body.
self.assertEqual(channel.code, 400)
self.assertEqual(
channel.json_body,
{
"errcode": "M_UNKNOWN",
"error": "Cannot find any thumbnails for the requested media ('/_matrix/client/v1/media/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
},
)
else:
# A 404 with a JSON body.
self.assertEqual(channel.code, 404)
self.assertEqual(
channel.json_body,
{
"errcode": "M_NOT_FOUND",
"error": "Not found '/_matrix/client/v1/media/thumbnail/example.com/12345'",
},
)
@parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
def test_same_quality(self, method: str, desired_size: int) -> None:
"""Test that choosing between thumbnails with the same quality rating succeeds.
We are not particular about which thumbnail is chosen."""
content_type = self.test_image.content_type.decode()
media_repo = self.hs.get_media_repository()
thumbnail_provider = ThumbnailProvider(
self.hs, media_repo, media_repo.media_storage
)
self.assertIsNotNone(
thumbnail_provider._select_thumbnail(
desired_width=desired_size,
desired_height=desired_size,
desired_method=method,
desired_type=content_type,
# Provide two identical thumbnails which are guaranteed to have the same
# quality rating.
thumbnail_infos=[
ThumbnailInfo(
width=32,
height=32,
method=method,
type=content_type,
length=256,
),
ThumbnailInfo(
width=32,
height=32,
method=method,
type=content_type,
length=256,
),
],
file_id=f"image{self.test_image.extension.decode()}",
url_cache=False,
server_name=None,
)
)