Register media servlets via regex. (#16419)

This converts the media servlet URLs in the same way as
(most) of the rest of Synapse. This will give more flexibility
in the versions each endpoint exists under.
This commit is contained in:
Patrick Cloke 2023-10-06 07:22:55 -04:00 committed by GitHub
parent 5946074d69
commit 26b960b08b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 297 additions and 337 deletions

1
changelog.d/16419.misc Normal file
View File

@ -0,0 +1 @@
Update registration of media repository URLs.

View File

@ -266,7 +266,7 @@ class HttpServer(Protocol):
def register_paths( def register_paths(
self, self,
method: str, method: str,
path_patterns: Iterable[Pattern], path_patterns: Iterable[Pattern[str]],
callback: ServletCallback, callback: ServletCallback,
servlet_classname: str, servlet_classname: str,
) -> None: ) -> None:

View File

@ -26,11 +26,11 @@ from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from twisted.web.server import Request from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError, cs_error from synapse.api.errors import Codes, cs_error
from synapse.http.server import finish_request, respond_with_json from synapse.http.server import finish_request, respond_with_json
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.util.stringutils import is_ascii, parse_and_validate_server_name from synapse.util.stringutils import is_ascii
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -84,52 +84,12 @@ INLINE_CONTENT_TYPES = [
] ]
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
"""Parses the server name, media ID and optional file name from the request URI
Also performs some rough validation on the server name.
Args:
request: The `Request`.
Returns:
A tuple containing the parsed server name, media ID and optional file name.
Raises:
SynapseError(404): if parsing or validation fail for any reason
"""
try:
# The type on postpath seems incorrect in Twisted 21.2.0.
postpath: List[bytes] = request.postpath # type: ignore
assert postpath
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
server_name_bytes, media_id_bytes = postpath[:2]
server_name = server_name_bytes.decode("utf-8")
media_id = media_id_bytes.decode("utf8")
# Validate the server name, raising if invalid
parse_and_validate_server_name(server_name)
file_name = None
if len(postpath) > 2:
try:
file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
except UnicodeDecodeError:
pass
return server_name, media_id, file_name
except Exception:
raise SynapseError(
404, "Invalid media id token %r" % (request.postpath,), Codes.UNKNOWN
)
def respond_404(request: SynapseRequest) -> None: def respond_404(request: SynapseRequest) -> None:
assert request.path is not None
respond_with_json( respond_with_json(
request, request,
404, 404,
cs_error("Not found %r" % (request.postpath,), code=Codes.NOT_FOUND), cs_error("Not found '%s'" % (request.path.decode(),), code=Codes.NOT_FOUND),
send_cors=True, send_cors=True,
) )

View File

@ -48,6 +48,7 @@ from synapse.media.filepath import MediaFilePaths
from synapse.media.media_storage import MediaStorage from synapse.media.media_storage import MediaStorage
from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.storage_provider import StorageProviderWrapper
from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.thumbnailer import Thumbnailer, ThumbnailError
from synapse.media.url_previewer import UrlPreviewer
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID from synapse.types import UserID
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
@ -114,7 +115,7 @@ class MediaRepository:
) )
storage_providers.append(provider) storage_providers.append(provider)
self.media_storage = MediaStorage( self.media_storage: MediaStorage = MediaStorage(
self.hs, self.primary_base_path, self.filepaths, storage_providers self.hs, self.primary_base_path, self.filepaths, storage_providers
) )
@ -142,6 +143,13 @@ class MediaRepository:
MEDIA_RETENTION_CHECK_PERIOD_MS, MEDIA_RETENTION_CHECK_PERIOD_MS,
) )
if hs.config.media.url_preview_enabled:
self.url_previewer: Optional[UrlPreviewer] = UrlPreviewer(
hs, self, self.media_storage
)
else:
self.url_previewer = None
def _start_update_recently_accessed(self) -> Deferred: def _start_update_recently_accessed(self) -> Deferred:
return run_as_background_process( return run_as_background_process(
"update_recently_accessed_media", self._update_recently_accessed "update_recently_accessed_media", self._update_recently_accessed

View File

@ -14,17 +14,19 @@
# limitations under the License. # limitations under the License.
# #
import re
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import respond_with_json
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
class MediaConfigResource(DirectServeJsonResource): class MediaConfigResource(RestServlet):
isLeaf = True PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/config$")]
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
@ -33,9 +35,6 @@ class MediaConfigResource(DirectServeJsonResource):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.media.max_upload_size} self.limits_dict = {"m.upload.size": config.media.max_upload_size}
async def _async_render_GET(self, request: SynapseRequest) -> None: async def on_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request) await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True) respond_with_json(request, 200, self.limits_dict, send_cors=True)
async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
respond_with_json(request, 200, {}, send_cors=True)

View File

@ -13,16 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING import re
from typing import TYPE_CHECKING, Optional
from synapse.http.server import ( from synapse.http.server import set_corp_headers, set_cors_headers
DirectServeJsonResource, from synapse.http.servlet import RestServlet, parse_boolean
set_corp_headers,
set_cors_headers,
)
from synapse.http.servlet import parse_boolean
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.media._base import parse_media_id, respond_404 from synapse.media._base import respond_404
from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.media.media_repository import MediaRepository from synapse.media.media_repository import MediaRepository
@ -31,15 +29,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DownloadResource(DirectServeJsonResource): class DownloadResource(RestServlet):
isLeaf = True PATTERNS = [
re.compile(
"/_matrix/media/(r0|v3|v1)/download/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)(/(?P<file_name>[^/]*))?$"
)
]
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__() super().__init__()
self.media_repo = media_repo self.media_repo = media_repo
self._is_mine_server_name = hs.is_mine_server_name self._is_mine_server_name = hs.is_mine_server_name
async def _async_render_GET(self, request: SynapseRequest) -> None: async def on_GET(
self,
request: SynapseRequest,
server_name: str,
media_id: str,
file_name: Optional[str] = None,
) -> None:
# Validate the server name, raising if invalid
parse_and_validate_server_name(server_name)
set_cors_headers(request) set_cors_headers(request)
set_corp_headers(request) set_corp_headers(request)
request.setHeader( request.setHeader(
@ -58,9 +69,8 @@ class DownloadResource(DirectServeJsonResource):
b"Referrer-Policy", b"Referrer-Policy",
b"no-referrer", b"no-referrer",
) )
server_name, media_id, name = parse_media_id(request)
if self._is_mine_server_name(server_name): if self._is_mine_server_name(server_name):
await self.media_repo.get_local_media(request, media_id, name) await self.media_repo.get_local_media(request, media_id, file_name)
else: else:
allow_remote = parse_boolean(request, "allow_remote", default=True) allow_remote = parse_boolean(request, "allow_remote", default=True)
if not allow_remote: if not allow_remote:
@ -72,4 +82,6 @@ class DownloadResource(DirectServeJsonResource):
respond_404(request) respond_404(request)
return return
await self.media_repo.get_remote_media(request, server_name, media_id, name) await self.media_repo.get_remote_media(
request, server_name, media_id, file_name
)

View File

@ -15,7 +15,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
from synapse.http.server import UnrecognizedRequestResource from synapse.http.server import HttpServer, JsonResource
from .config_resource import MediaConfigResource from .config_resource import MediaConfigResource
from .download_resource import DownloadResource from .download_resource import DownloadResource
@ -27,7 +27,7 @@ if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
class MediaRepositoryResource(UnrecognizedRequestResource): class MediaRepositoryResource(JsonResource):
"""File uploading and downloading. """File uploading and downloading.
Uploads are POSTed to a resource which returns a token which is used to GET Uploads are POSTed to a resource which returns a token which is used to GET
@ -70,6 +70,11 @@ class MediaRepositoryResource(UnrecognizedRequestResource):
width and height are close to the requested size and the aspect matches width and height are close to the requested size and the aspect matches
the requested size. The client should scale the image if it needs to fit the requested size. The client should scale the image if it needs to fit
within a given rectangle. within a given rectangle.
This gets mounted at various points under /_matrix/media, including:
* /_matrix/media/r0
* /_matrix/media/v1
* /_matrix/media/v3
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
@ -77,17 +82,23 @@ class MediaRepositoryResource(UnrecognizedRequestResource):
if not hs.config.media.can_load_media_repo: if not hs.config.media.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.") raise ConfigError("Synapse is not configured to use a media repo.")
super().__init__() JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None:
media_repo = hs.get_media_repository() media_repo = hs.get_media_repository()
self.putChild(b"upload", UploadResource(hs, media_repo)) # Note that many of these should not exist as v1 endpoints, but empirically
self.putChild(b"download", DownloadResource(hs, media_repo)) # a lot of traffic still goes to them.
self.putChild(
b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) UploadResource(hs, media_repo).register(http_server)
DownloadResource(hs, media_repo).register(http_server)
ThumbnailResource(hs, media_repo, media_repo.media_storage).register(
http_server
) )
if hs.config.media.url_preview_enabled: if hs.config.media.url_preview_enabled:
self.putChild( PreviewUrlResource(hs, media_repo, media_repo.media_storage).register(
b"preview_url", http_server
PreviewUrlResource(hs, media_repo, media_repo.media_storage),
) )
self.putChild(b"config", MediaConfigResource(hs)) MediaConfigResource(hs).register(http_server)

View File

@ -13,24 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from synapse.http.server import ( from synapse.http.server import respond_with_json_bytes
DirectServeJsonResource, from synapse.http.servlet import RestServlet, parse_integer, parse_string
respond_with_json,
respond_with_json_bytes,
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.media.media_storage import MediaStorage from synapse.media.media_storage import MediaStorage
from synapse.media.url_previewer import UrlPreviewer
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.media.media_repository import MediaRepository from synapse.media.media_repository import MediaRepository
from synapse.server import HomeServer from synapse.server import HomeServer
class PreviewUrlResource(DirectServeJsonResource): class PreviewUrlResource(RestServlet):
""" """
The `GET /_matrix/media/r0/preview_url` endpoint provides a generic preview API The `GET /_matrix/media/r0/preview_url` endpoint provides a generic preview API
for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix for URLs which outputs Open Graph (https://ogp.me/) responses (with some Matrix
@ -48,7 +44,7 @@ class PreviewUrlResource(DirectServeJsonResource):
* Matrix cannot be used to distribute the metadata between homeservers. * Matrix cannot be used to distribute the metadata between homeservers.
""" """
isLeaf = True PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/preview_url$")]
def __init__( def __init__(
self, self,
@ -62,14 +58,10 @@ class PreviewUrlResource(DirectServeJsonResource):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.media_repo = media_repo self.media_repo = media_repo
self.media_storage = media_storage self.media_storage = media_storage
assert self.media_repo.url_previewer is not None
self.url_previewer = self.media_repo.url_previewer
self._url_previewer = UrlPreviewer(hs, media_repo, media_storage) async def on_GET(self, request: SynapseRequest) -> None:
async def _async_render_OPTIONS(self, request: SynapseRequest) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_GET(self, request: SynapseRequest) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render? # XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
url = parse_string(request, "url", required=True) url = parse_string(request, "url", required=True)
@ -77,5 +69,5 @@ class PreviewUrlResource(DirectServeJsonResource):
if ts is None: if ts is None:
ts = self.clock.time_msec() ts = self.clock.time_msec()
og = await self._url_previewer.preview(url, requester.user, ts) og = await self.url_previewer.preview(url, requester.user, ts)
respond_with_json_bytes(request, 200, og, send_cors=True) respond_with_json_bytes(request, 200, og, send_cors=True)

View File

@ -13,29 +13,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.api.errors import Codes, SynapseError, cs_error from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
from synapse.http.server import ( from synapse.http.server import respond_with_json, set_corp_headers, set_cors_headers
DirectServeJsonResource, from synapse.http.servlet import RestServlet, parse_integer, parse_string
respond_with_json,
set_corp_headers,
set_cors_headers,
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.media._base import ( from synapse.media._base import (
FileInfo, FileInfo,
ThumbnailInfo, ThumbnailInfo,
parse_media_id,
respond_404, respond_404,
respond_with_file, respond_with_file,
respond_with_responder, respond_with_responder,
) )
from synapse.media.media_storage import MediaStorage from synapse.media.media_storage import MediaStorage
from synapse.util.stringutils import parse_and_validate_server_name
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.media.media_repository import MediaRepository from synapse.media.media_repository import MediaRepository
@ -44,8 +39,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThumbnailResource(DirectServeJsonResource): class ThumbnailResource(RestServlet):
isLeaf = True PATTERNS = [
re.compile(
"/_matrix/media/(r0|v3|v1)/thumbnail/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
)
]
def __init__( def __init__(
self, self,
@ -60,12 +59,17 @@ class ThumbnailResource(DirectServeJsonResource):
self.media_storage = media_storage self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self._is_mine_server_name = hs.is_mine_server_name self._is_mine_server_name = hs.is_mine_server_name
self._server_name = hs.hostname
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
async def _async_render_GET(self, request: SynapseRequest) -> None: async def on_GET(
self, request: SynapseRequest, server_name: str, media_id: str
) -> None:
# Validate the server name, raising if invalid
parse_and_validate_server_name(server_name)
set_cors_headers(request) set_cors_headers(request)
set_corp_headers(request) set_corp_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True) width = parse_integer(request, "width", required=True)
height = parse_integer(request, "height", required=True) height = parse_integer(request, "height", required=True)
method = parse_string(request, "method", "scale") method = parse_string(request, "method", "scale")
@ -418,13 +422,14 @@ class ThumbnailResource(DirectServeJsonResource):
# `dynamic_thumbnails` is disabled. # `dynamic_thumbnails` is disabled.
logger.info("Failed to find any generated thumbnails") logger.info("Failed to find any generated thumbnails")
assert request.path is not None
respond_with_json( respond_with_json(
request, request,
400, 400,
cs_error( cs_error(
"Cannot find any thumbnails for the requested media (%r). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)" "Cannot find any thumbnails for the requested media ('%s'). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
% ( % (
request.postpath, request.path.decode(),
", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()), ", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
), ),
code=Codes.UNKNOWN, code=Codes.UNKNOWN,

View File

@ -14,11 +14,12 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re
from typing import IO, TYPE_CHECKING, Dict, List, Optional from typing import IO, TYPE_CHECKING, Dict, List, Optional
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import respond_with_json
from synapse.http.servlet import parse_bytes_from_args from synapse.http.servlet import RestServlet, parse_bytes_from_args
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.media.media_storage import SpamMediaException from synapse.media.media_storage import SpamMediaException
@ -29,8 +30,8 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UploadResource(DirectServeJsonResource): class UploadResource(RestServlet):
isLeaf = True PATTERNS = [re.compile("/_matrix/media/(r0|v3|v1)/upload")]
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"): def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__() super().__init__()
@ -43,10 +44,7 @@ class UploadResource(DirectServeJsonResource):
self.max_upload_size = hs.config.media.max_upload_size self.max_upload_size = hs.config.media.max_upload_size
self.clock = hs.get_clock() self.clock = hs.get_clock()
async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: async def on_POST(self, request: SynapseRequest) -> None:
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
raw_content_length = request.getHeader("Content-Length") raw_content_length = request.getHeader("Content-Length")
if raw_content_length is None: if raw_content_length is None:

View File

@ -28,6 +28,7 @@ from typing_extensions import Literal
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.api.errors import Codes from synapse.api.errors import Codes
from synapse.events import EventBase from synapse.events import EventBase
@ -41,12 +42,13 @@ from synapse.module_api import ModuleApi
from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login from synapse.rest.client import login
from synapse.rest.media.thumbnail_resource import ThumbnailResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias from synapse.types import JsonDict, RoomAlias
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeChannel, FakeSite, make_request from tests.server import FakeChannel
from tests.test_utils import SMALL_PNG from tests.test_utils import SMALL_PNG
from tests.utils import default_config from tests.utils import default_config
@ -288,22 +290,22 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return hs return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
media_resource = hs.get_media_repository_resource()
self.download_resource = media_resource.children[b"download"]
self.thumbnail_resource = media_resource.children[b"thumbnail"]
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.media_repo = hs.get_media_repository() self.media_repo = hs.get_media_repository()
self.media_id = "example.com/12345" self.media_id = "example.com/12345"
def create_resource_dict(self) -> Dict[str, Resource]:
resources = super().create_resource_dict()
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
return resources
def _req( def _req(
self, content_disposition: Optional[bytes], include_content_type: bool = True self, content_disposition: Optional[bytes], include_content_type: bool = True
) -> FakeChannel: ) -> FakeChannel:
channel = make_request( channel = self.make_request(
self.reactor,
FakeSite(self.download_resource, self.reactor),
"GET", "GET",
self.media_id, f"/_matrix/media/v3/download/{self.media_id}",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -481,11 +483,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
# Fetching again should work, without re-requesting the image from the # Fetching again should work, without re-requesting the image from the
# remote. # remote.
params = "?width=32&height=32&method=scale" params = "?width=32&height=32&method=scale"
channel = make_request( channel = self.make_request(
self.reactor,
FakeSite(self.thumbnail_resource, self.reactor),
"GET", "GET",
self.media_id + params, f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -511,11 +511,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
) )
shutil.rmtree(thumbnail_dir, ignore_errors=True) shutil.rmtree(thumbnail_dir, ignore_errors=True)
channel = make_request( channel = self.make_request(
self.reactor,
FakeSite(self.thumbnail_resource, self.reactor),
"GET", "GET",
self.media_id + params, f"/_matrix/media/v3/thumbnail/{self.media_id}{params}",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -549,11 +547,9 @@ class MediaRepoTests(unittest.HomeserverTestCase):
""" """
params = "?width=32&height=32&method=" + method params = "?width=32&height=32&method=" + method
channel = make_request( channel = self.make_request(
self.reactor,
FakeSite(self.thumbnail_resource, self.reactor),
"GET", "GET",
self.media_id + params, f"/_matrix/media/r0/thumbnail/{self.media_id}{params}",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -590,7 +586,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body, channel.json_body,
{ {
"errcode": "M_UNKNOWN", "errcode": "M_UNKNOWN",
"error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)", "error": "Cannot find any thumbnails for the requested media ('/_matrix/media/r0/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
}, },
) )
else: else:
@ -600,7 +596,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.json_body, channel.json_body,
{ {
"errcode": "M_NOT_FOUND", "errcode": "M_NOT_FOUND",
"error": "Not found [b'example.com', b'12345']", "error": "Not found '/_matrix/media/r0/thumbnail/example.com/12345'",
}, },
) )
@ -609,12 +605,17 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"""Test that choosing between thumbnails with the same quality rating succeeds. """Test that choosing between thumbnails with the same quality rating succeeds.
We are not particular about which thumbnail is chosen.""" We are not particular about which thumbnail is chosen."""
media_repo = self.hs.get_media_repository()
thumbnail_resouce = ThumbnailResource(
self.hs, media_repo, media_repo.media_storage
)
self.assertIsNotNone( self.assertIsNotNone(
self.thumbnail_resource._select_thumbnail( thumbnail_resouce._select_thumbnail(
desired_width=desired_size, desired_width=desired_size,
desired_height=desired_size, desired_height=desired_size,
desired_method=method, desired_method=method,
desired_type=self.test_image.content_type, desired_type=self.test_image.content_type, # type: ignore[arg-type]
# Provide two identical thumbnails which are guaranteed to have the same # Provide two identical thumbnails which are guaranteed to have the same
# quality rating. # quality rating.
thumbnail_infos=[ thumbnail_infos=[
@ -636,7 +637,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
}, },
], ],
file_id=f"image{self.test_image.extension.decode()}", file_id=f"image{self.test_image.extension.decode()}",
url_cache=None, url_cache=False,
server_name=None, server_name=None,
) )
) )
@ -725,13 +726,13 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
self.user = self.register_user("user", "pass") self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass") self.tok = self.login("user", "pass")
# Allow for uploading and downloading to/from the media repo
self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"]
self.upload_resource = self.media_repo.children[b"upload"]
load_legacy_spam_checkers(hs) load_legacy_spam_checkers(hs)
def create_resource_dict(self) -> Dict[str, Resource]:
resources = super().create_resource_dict()
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
return resources
def default_config(self) -> Dict[str, Any]: def default_config(self) -> Dict[str, Any]:
config = default_config("test") config = default_config("test")
@ -751,9 +752,7 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
def test_upload_innocent(self) -> None: def test_upload_innocent(self) -> None:
"""Attempt to upload some innocent data that should be allowed.""" """Attempt to upload some innocent data that should be allowed."""
self.helper.upload_media( self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
)
def test_upload_ban(self) -> None: def test_upload_ban(self) -> None:
"""Attempt to upload some data that includes bytes "evil", which should """Attempt to upload some data that includes bytes "evil", which should
@ -762,9 +761,7 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase):
data = b"Some evil data" data = b"Some evil data"
self.helper.upload_media( self.helper.upload_media(data, tok=self.tok, expect_code=400)
self.upload_resource, data, tok=self.tok, expect_code=400
)
EVIL_DATA = b"Some evil data" EVIL_DATA = b"Some evil data"
@ -781,15 +778,15 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
self.user = self.register_user("user", "pass") self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass") self.tok = self.login("user", "pass")
# Allow for uploading and downloading to/from the media repo
self.media_repo = hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"]
self.upload_resource = self.media_repo.children[b"upload"]
hs.get_module_api().register_spam_checker_callbacks( hs.get_module_api().register_spam_checker_callbacks(
check_media_file_for_spam=self.check_media_file_for_spam check_media_file_for_spam=self.check_media_file_for_spam
) )
def create_resource_dict(self) -> Dict[str, Resource]:
resources = super().create_resource_dict()
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
return resources
async def check_media_file_for_spam( async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]: ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]:
@ -805,21 +802,16 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
def test_upload_innocent(self) -> None: def test_upload_innocent(self) -> None:
"""Attempt to upload some innocent data that should be allowed.""" """Attempt to upload some innocent data that should be allowed."""
self.helper.upload_media( self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200)
self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
)
def test_upload_ban(self) -> None: def test_upload_ban(self) -> None:
"""Attempt to upload some data that includes bytes "evil", which should """Attempt to upload some data that includes bytes "evil", which should
get rejected by the spam checker. get rejected by the spam checker.
""" """
self.helper.upload_media( self.helper.upload_media(EVIL_DATA, tok=self.tok, expect_code=400)
self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400
)
self.helper.upload_media( self.helper.upload_media(
self.upload_resource,
EVIL_DATA_EXPERIMENT, EVIL_DATA_EXPERIMENT,
tok=self.tok, tok=self.tok,
expect_code=400, expect_code=400,

View File

@ -61,9 +61,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
return self.setup_test_homeserver(config=config) return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
media_repo_resource = hs.get_media_repository_resource() media_repo = hs.get_media_repository()
preview_url = media_repo_resource.children[b"preview_url"] assert media_repo.url_previewer is not None
self.url_previewer = preview_url._url_previewer self.url_previewer = media_repo.url_previewer
def test_all_urls_allowed(self) -> None: def test_all_urls_allowed(self) -> None:
self.assertFalse(self.url_previewer._is_url_blocked("http://matrix.org")) self.assertFalse(self.url_previewer._is_url_blocked("http://matrix.org"))

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import os import os
from typing import Optional, Tuple from typing import Any, Optional, Tuple
from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
from twisted.internet.protocol import Factory from twisted.internet.protocol import Factory
@ -29,7 +29,7 @@ from synapse.util import Clock
from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeSite, FakeTransport, make_request from tests.server import FakeChannel, FakeTransport, make_request
from tests.test_utils import SMALL_PNG from tests.test_utils import SMALL_PNG
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,6 +56,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
return conf return conf
def make_worker_hs(
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
) -> HomeServer:
worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs)
# Force the media paths onto the replication resource.
worker_hs.get_media_repository_resource().register_servlets(
self._hs_to_site[worker_hs].resource, worker_hs
)
return worker_hs
def _get_media_req( def _get_media_req(
self, hs: HomeServer, target: str, media_id: str self, hs: HomeServer, target: str, media_id: str
) -> Tuple[FakeChannel, Request]: ) -> Tuple[FakeChannel, Request]:
@ -68,12 +78,11 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
The channel for the *client* request and the *outbound* request for The channel for the *client* request and the *outbound* request for
the media which the caller should respond to. the media which the caller should respond to.
""" """
resource = hs.get_media_repository_resource().children[b"download"]
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(resource, self.reactor), self._hs_to_site[hs],
"GET", "GET",
f"/{target}/{media_id}", f"/_matrix/media/r0/download/{target}/{media_id}",
shorthand=False, shorthand=False,
access_token=self.access_token, access_token=self.access_token,
await_result=False, await_result=False,

View File

@ -13,10 +13,12 @@
# limitations under the License. # limitations under the License.
import urllib.parse import urllib.parse
from typing import Dict
from parameterized import parameterized from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin import synapse.rest.admin
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -26,7 +28,6 @@ from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
from tests.test_utils import SMALL_PNG from tests.test_utils import SMALL_PNG
@ -55,21 +56,18 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
room.register_servlets, room.register_servlets,
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def create_resource_dict(self) -> Dict[str, Resource]:
# Allow for uploading and downloading to/from the media repo resources = super().create_resource_dict()
self.media_repo = hs.get_media_repository_resource() resources["/_matrix/media"] = self.hs.get_media_repository_resource()
self.download_resource = self.media_repo.children[b"download"] return resources
self.upload_resource = self.media_repo.children[b"upload"]
def _ensure_quarantined( def _ensure_quarantined(
self, admin_user_tok: str, server_and_media_id: str self, admin_user_tok: str, server_and_media_id: str
) -> None: ) -> None:
"""Ensure a piece of media is quarantined when trying to access it.""" """Ensure a piece of media is quarantined when trying to access it."""
channel = make_request( channel = self.make_request(
self.reactor,
FakeSite(self.download_resource, self.reactor),
"GET", "GET",
server_and_media_id, f"/_matrix/media/v3/download/{server_and_media_id}",
shorthand=False, shorthand=False,
access_token=admin_user_tok, access_token=admin_user_tok,
) )
@ -117,20 +115,16 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
non_admin_user_tok = self.login("id_nonadmin", "pass") non_admin_user_tok = self.login("id_nonadmin", "pass")
# Upload some media into the room # Upload some media into the room
response = self.helper.upload_media( response = self.helper.upload_media(SMALL_PNG, tok=admin_user_tok)
self.upload_resource, SMALL_PNG, tok=admin_user_tok
)
# Extract media ID from the response # Extract media ID from the response
server_name_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' server_name_and_media_id = response["content_uri"][6:] # Cut off 'mxc://'
server_name, media_id = server_name_and_media_id.split("/") server_name, media_id = server_name_and_media_id.split("/")
# Attempt to access the media # Attempt to access the media
channel = make_request( channel = self.make_request(
self.reactor,
FakeSite(self.download_resource, self.reactor),
"GET", "GET",
server_name_and_media_id, f"/_matrix/media/v3/download/{server_name_and_media_id}",
shorthand=False, shorthand=False,
access_token=non_admin_user_tok, access_token=non_admin_user_tok,
) )
@ -173,12 +167,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok) self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok)
# Upload some media # Upload some media
response_1 = self.helper.upload_media( response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
self.upload_resource, SMALL_PNG, tok=non_admin_user_tok response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
)
response_2 = self.helper.upload_media(
self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
# Extract mxcs # Extract mxcs
mxc_1 = response_1["content_uri"] mxc_1 = response_1["content_uri"]
@ -227,12 +217,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
non_admin_user_tok = self.login("user_nonadmin", "pass") non_admin_user_tok = self.login("user_nonadmin", "pass")
# Upload some media # Upload some media
response_1 = self.helper.upload_media( response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
self.upload_resource, SMALL_PNG, tok=non_admin_user_tok response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
)
response_2 = self.helper.upload_media(
self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
# Extract media IDs # Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:] server_and_media_id_1 = response_1["content_uri"][6:]
@ -265,12 +251,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
non_admin_user_tok = self.login("user_nonadmin", "pass") non_admin_user_tok = self.login("user_nonadmin", "pass")
# Upload some media # Upload some media
response_1 = self.helper.upload_media( response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
self.upload_resource, SMALL_PNG, tok=non_admin_user_tok response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok)
)
response_2 = self.helper.upload_media(
self.upload_resource, SMALL_PNG, tok=non_admin_user_tok
)
# Extract media IDs # Extract media IDs
server_and_media_id_1 = response_1["content_uri"][6:] server_and_media_id_1 = response_1["content_uri"][6:]
@ -304,11 +286,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1) self._ensure_quarantined(admin_user_tok, server_and_media_id_1)
# Attempt to access each piece of media # Attempt to access each piece of media
channel = make_request( channel = self.make_request(
self.reactor,
FakeSite(self.download_resource, self.reactor),
"GET", "GET",
server_and_media_id_2, f"/_matrix/media/v3/download/{server_and_media_id_2}",
shorthand=False, shorthand=False,
access_token=non_admin_user_tok, access_token=non_admin_user_tok,
) )

View File

@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
from typing import Dict
from parameterized import parameterized from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes from synapse.api.errors import Codes
@ -26,22 +28,27 @@ from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
from tests.test_utils import SMALL_PNG from tests.test_utils import SMALL_PNG
VALID_TIMESTAMP = 1609459200000 # 2021-01-01 in milliseconds VALID_TIMESTAMP = 1609459200000 # 2021-01-01 in milliseconds
INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds
class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): class _AdminMediaTests(unittest.HomeserverTestCase):
servlets = [ servlets = [
synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo, synapse.rest.admin.register_servlets_for_media_repo,
login.register_servlets, login.register_servlets,
] ]
def create_resource_dict(self) -> Dict[str, Resource]:
resources = super().create_resource_dict()
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
return resources
class DeleteMediaByIDTestCase(_AdminMediaTests):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname self.server_name = hs.hostname
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
@ -117,12 +124,8 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
Tests that delete a media is successfully Tests that delete a media is successfully
""" """
download_resource = self.media_repo.children[b"download"]
upload_resource = self.media_repo.children[b"upload"]
# Upload some media into the room # Upload some media into the room
response = self.helper.upload_media( response = self.helper.upload_media(
upload_resource,
SMALL_PNG, SMALL_PNG,
tok=self.admin_user_tok, tok=self.admin_user_tok,
expect_code=200, expect_code=200,
@ -134,11 +137,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertEqual(server_name, self.server_name) self.assertEqual(server_name, self.server_name)
# Attempt to access media # Attempt to access media
channel = make_request( channel = self.make_request(
self.reactor,
FakeSite(download_resource, self.reactor),
"GET", "GET",
server_and_media_id, f"/_matrix/media/v3/download/{server_and_media_id}",
shorthand=False, shorthand=False,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
@ -173,11 +174,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
) )
# Attempt to access media # Attempt to access media
channel = make_request( channel = self.make_request(
self.reactor,
FakeSite(download_resource, self.reactor),
"GET", "GET",
server_and_media_id, f"/_matrix/media/v3/download/{server_and_media_id}",
shorthand=False, shorthand=False,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
@ -194,7 +193,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(os.path.exists(local_path)) self.assertFalse(os.path.exists(local_path))
class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
servlets = [ servlets = [
synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo, synapse.rest.admin.register_servlets_for_media_repo,
@ -529,11 +528,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
""" """
Create a media and return media_id and server_and_media_id Create a media and return media_id and server_and_media_id
""" """
upload_resource = self.media_repo.children[b"upload"]
# Upload some media into the room # Upload some media into the room
response = self.helper.upload_media( response = self.helper.upload_media(
upload_resource,
SMALL_PNG, SMALL_PNG,
tok=self.admin_user_tok, tok=self.admin_user_tok,
expect_code=200, expect_code=200,
@ -553,16 +549,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
""" """
Try to access a media and check the result Try to access a media and check the result
""" """
download_resource = self.media_repo.children[b"download"]
media_id = server_and_media_id.split("/")[1] media_id = server_and_media_id.split("/")[1]
local_path = self.filepaths.local_media_filepath(media_id) local_path = self.filepaths.local_media_filepath(media_id)
channel = make_request( channel = self.make_request(
self.reactor,
FakeSite(download_resource, self.reactor),
"GET", "GET",
server_and_media_id, f"/_matrix/media/v3/download/{server_and_media_id}",
shorthand=False, shorthand=False,
access_token=self.admin_user_tok, access_token=self.admin_user_tok,
) )
@ -591,27 +583,16 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.assertFalse(os.path.exists(local_path)) self.assertFalse(os.path.exists(local_path))
class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): class QuarantineMediaByIDTestCase(_AdminMediaTests):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
media_repo = hs.get_media_repository_resource()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.server_name = hs.hostname self.server_name = hs.hostname
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
# Create media
upload_resource = media_repo.children[b"upload"]
# Upload some media into the room # Upload some media into the room
response = self.helper.upload_media( response = self.helper.upload_media(
upload_resource,
SMALL_PNG, SMALL_PNG,
tok=self.admin_user_tok, tok=self.admin_user_tok,
expect_code=200, expect_code=200,
@ -720,26 +701,16 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(media_info["quarantined_by"]) self.assertFalse(media_info["quarantined_by"])
class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): class ProtectMediaByIDTestCase(_AdminMediaTests):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo,
login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
media_repo = hs.get_media_repository_resource() hs.get_media_repository_resource()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
# Create media
upload_resource = media_repo.children[b"upload"]
# Upload some media into the room # Upload some media into the room
response = self.helper.upload_media( response = self.helper.upload_media(
upload_resource,
SMALL_PNG, SMALL_PNG,
tok=self.admin_user_tok, tok=self.admin_user_tok,
expect_code=200, expect_code=200,
@ -816,7 +787,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase):
self.assertFalse(media_info["safe_from_quarantine"]) self.assertFalse(media_info["safe_from_quarantine"])
class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): class PurgeMediaCacheTestCase(_AdminMediaTests):
servlets = [ servlets = [
synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets,
synapse.rest.admin.register_servlets_for_media_repo, synapse.rest.admin.register_servlets_for_media_repo,

View File

@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional from typing import Dict, List, Optional
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin import synapse.rest.admin
from synapse.api.errors import Codes from synapse.api.errors import Codes
@ -34,8 +35,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
] ]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository_resource()
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass") self.admin_user_tok = self.login("admin", "pass")
@ -44,6 +43,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
self.url = "/_synapse/admin/v1/statistics/users/media" self.url = "/_synapse/admin/v1/statistics/users/media"
def create_resource_dict(self) -> Dict[str, Resource]:
resources = super().create_resource_dict()
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
return resources
def test_no_auth(self) -> None: def test_no_auth(self) -> None:
""" """
Try to list users without authentication. Try to list users without authentication.
@ -470,12 +474,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase):
user_token: Access token of the user user_token: Access token of the user
number_media: Number of media to be created for the user number_media: Number of media to be created for the user
""" """
upload_resource = self.media_repo.children[b"upload"]
for _ in range(number_media): for _ in range(number_media):
# Upload some media into the room # Upload some media into the room
self.helper.upload_media( self.helper.upload_media(SMALL_PNG, tok=user_token, expect_code=200)
upload_resource, SMALL_PNG, tok=user_token, expect_code=200
)
def _check_fields(self, content: List[JsonDict]) -> None: def _check_fields(self, content: List[JsonDict]) -> None:
"""Checks that all attributes are present in content """Checks that all attributes are present in content

View File

@ -17,12 +17,13 @@ import hmac
import os import os
import urllib.parse import urllib.parse
from binascii import unhexlify from binascii import unhexlify
from typing import List, Optional from typing import Dict, List, Optional
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
from parameterized import parameterized, parameterized_class from parameterized import parameterized, parameterized_class
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes
@ -45,7 +46,6 @@ from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import FakeSite, make_request
from tests.test_utils import SMALL_PNG from tests.test_utils import SMALL_PNG
from tests.unittest import override_config from tests.unittest import override_config
@ -3421,7 +3421,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.media_repo = hs.get_media_repository_resource()
self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.filepaths = MediaFilePaths(hs.config.media.media_store_path)
self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user = self.register_user("admin", "pass", admin=True)
@ -3432,6 +3431,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
self.other_user self.other_user
) )
def create_resource_dict(self) -> Dict[str, Resource]:
resources = super().create_resource_dict()
resources["/_matrix/media"] = self.hs.get_media_repository_resource()
return resources
@parameterized.expand(["GET", "DELETE"]) @parameterized.expand(["GET", "DELETE"])
def test_no_auth(self, method: str) -> None: def test_no_auth(self, method: str) -> None:
"""Try to list media of an user without authentication.""" """Try to list media of an user without authentication."""
@ -3907,12 +3911,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
Returns: Returns:
The ID of the newly created media. The ID of the newly created media.
""" """
upload_resource = self.media_repo.children[b"upload"]
download_resource = self.media_repo.children[b"download"]
# Upload some media into the room # Upload some media into the room
response = self.helper.upload_media( response = self.helper.upload_media(
upload_resource, image_data, user_token, filename, expect_code=200 image_data, user_token, filename, expect_code=200
) )
# Extract media ID from the response # Extract media ID from the response
@ -3920,11 +3921,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
media_id = server_and_media_id.split("/")[1] media_id = server_and_media_id.split("/")[1]
# Try to access a media and to create `last_access_ts` # Try to access a media and to create `last_access_ts`
channel = make_request( channel = self.make_request(
self.reactor,
FakeSite(download_resource, self.reactor),
"GET", "GET",
server_and_media_id, f"/_matrix/media/v3/download/{server_and_media_id}",
shorthand=False, shorthand=False,
access_token=user_token, access_token=user_token,
) )

View File

@ -37,7 +37,6 @@ import attr
from typing_extensions import Literal from typing_extensions import Literal
from twisted.test.proto_helpers import MemoryReactorClock from twisted.test.proto_helpers import MemoryReactorClock
from twisted.web.resource import Resource
from twisted.web.server import Site from twisted.web.server import Site
from synapse.api.constants import Membership from synapse.api.constants import Membership
@ -45,7 +44,7 @@ from synapse.api.errors import Codes
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from tests.server import FakeChannel, FakeSite, make_request from tests.server import FakeChannel, make_request
from tests.test_utils.html_parsers import TestHtmlParser from tests.test_utils.html_parsers import TestHtmlParser
from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
@ -558,7 +557,6 @@ class RestHelper:
def upload_media( def upload_media(
self, self,
resource: Resource,
image_data: bytes, image_data: bytes,
tok: str, tok: str,
filename: str = "test.png", filename: str = "test.png",
@ -576,7 +574,7 @@ class RestHelper:
path = "/_matrix/media/r0/upload?filename=%s" % (filename,) path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
channel = make_request( channel = make_request(
self.reactor, self.reactor,
FakeSite(resource, self.reactor), self.site,
"POST", "POST",
path, path,
content=image_data, content=image_data,

View File

@ -24,10 +24,10 @@ from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IAddress, IResolutionReceiver from twisted.internet.interfaces import IAddress, IResolutionReceiver
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
from twisted.web.resource import Resource
from synapse.config.oembed import OEmbedEndpointConfig from synapse.config.oembed import OEmbedEndpointConfig
from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
from synapse.rest.media.media_repository_resource import MediaRepositoryResource
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
@ -117,8 +117,8 @@ class URLPreviewTests(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.media_repo = hs.get_media_repository() self.media_repo = hs.get_media_repository()
media_repo_resource = hs.get_media_repository_resource() assert self.media_repo.url_previewer is not None
self.preview_url = media_repo_resource.children[b"preview_url"] self.url_previewer = self.media_repo.url_previewer
self.lookups: Dict[str, Any] = {} self.lookups: Dict[str, Any] = {}
@ -143,8 +143,15 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.reactor.nameResolver = Resolver() # type: ignore[assignment] self.reactor.nameResolver = Resolver() # type: ignore[assignment]
def create_test_resource(self) -> MediaRepositoryResource: def create_resource_dict(self) -> Dict[str, Resource]:
return self.hs.get_media_repository_resource() """Create a resource tree for the test server
A resource tree is a mapping from path to twisted.web.resource.
The default implementation creates a JsonResource and calls each function in
`servlets` to register servlets against it.
"""
return {"/_matrix/media": self.hs.get_media_repository_resource()}
def _assert_small_png(self, json_body: JsonDict) -> None: def _assert_small_png(self, json_body: JsonDict) -> None:
"""Assert properties from the SMALL_PNG test image.""" """Assert properties from the SMALL_PNG test image."""
@ -159,7 +166,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://matrix.org", "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -183,7 +190,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Check the cache returns the correct response # Check the cache returns the correct response
channel = self.make_request( channel = self.make_request(
"GET", "preview_url?url=http://matrix.org", shorthand=False "GET",
"/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
) )
# Check the cache response has the same content # Check the cache response has the same content
@ -193,13 +202,15 @@ class URLPreviewTests(unittest.HomeserverTestCase):
) )
# Clear the in-memory cache # Clear the in-memory cache
self.assertIn("http://matrix.org", self.preview_url._url_previewer._cache) self.assertIn("http://matrix.org", self.url_previewer._cache)
self.preview_url._url_previewer._cache.pop("http://matrix.org") self.url_previewer._cache.pop("http://matrix.org")
self.assertNotIn("http://matrix.org", self.preview_url._url_previewer._cache) self.assertNotIn("http://matrix.org", self.url_previewer._cache)
# Check the database cache returns the correct response # Check the database cache returns the correct response
channel = self.make_request( channel = self.make_request(
"GET", "preview_url?url=http://matrix.org", shorthand=False "GET",
"/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False,
) )
# Check the cache response has the same content # Check the cache response has the same content
@ -221,7 +232,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://matrix.org", "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -251,7 +262,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://matrix.org", "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -287,7 +298,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://matrix.org", "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -328,7 +339,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://matrix.org", "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -363,7 +374,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://matrix.org", "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -396,7 +407,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://example.com", "/_matrix/media/v3/preview_url?url=http://example.com",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -425,7 +436,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
channel = self.make_request( channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False "GET",
"/_matrix/media/v3/preview_url?url=http://example.com",
shorthand=False,
) )
# No requests made. # No requests made.
@ -446,7 +459,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
channel = self.make_request( channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False "GET",
"/_matrix/media/v3/preview_url?url=http://example.com",
shorthand=False,
) )
self.assertEqual(channel.code, 502) self.assertEqual(channel.code, 502)
@ -463,7 +478,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
Blocked IP addresses, accessed directly, are not spidered. Blocked IP addresses, accessed directly, are not spidered.
""" """
channel = self.make_request( channel = self.make_request(
"GET", "preview_url?url=http://192.168.1.1", shorthand=False "GET",
"/_matrix/media/v3/preview_url?url=http://192.168.1.1",
shorthand=False,
) )
# No requests made. # No requests made.
@ -479,7 +496,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
Blocked IP ranges, accessed directly, are not spidered. Blocked IP ranges, accessed directly, are not spidered.
""" """
channel = self.make_request( channel = self.make_request(
"GET", "preview_url?url=http://1.1.1.2", shorthand=False "GET", "/_matrix/media/v3/preview_url?url=http://1.1.1.2", shorthand=False
) )
self.assertEqual(channel.code, 403) self.assertEqual(channel.code, 403)
@ -497,7 +514,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://example.com", "/_matrix/media/v3/preview_url?url=http://example.com",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -533,7 +550,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
] ]
channel = self.make_request( channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False "GET",
"/_matrix/media/v3/preview_url?url=http://example.com",
shorthand=False,
) )
self.assertEqual(channel.code, 502) self.assertEqual(channel.code, 502)
self.assertEqual( self.assertEqual(
@ -553,7 +572,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
] ]
channel = self.make_request( channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False "GET",
"/_matrix/media/v3/preview_url?url=http://example.com",
shorthand=False,
) )
# No requests made. # No requests made.
@ -574,7 +595,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
channel = self.make_request( channel = self.make_request(
"GET", "preview_url?url=http://example.com", shorthand=False "GET",
"/_matrix/media/v3/preview_url?url=http://example.com",
shorthand=False,
) )
self.assertEqual(channel.code, 502) self.assertEqual(channel.code, 502)
@ -591,10 +614,11 @@ class URLPreviewTests(unittest.HomeserverTestCase):
OPTIONS returns the OPTIONS. OPTIONS returns the OPTIONS.
""" """
channel = self.make_request( channel = self.make_request(
"OPTIONS", "preview_url?url=http://example.com", shorthand=False "OPTIONS",
"/_matrix/media/v3/preview_url?url=http://example.com",
shorthand=False,
) )
self.assertEqual(channel.code, 200) self.assertEqual(channel.code, 204)
self.assertEqual(channel.json_body, {})
def test_accept_language_config_option(self) -> None: def test_accept_language_config_option(self) -> None:
""" """
@ -605,7 +629,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Build and make a request to the server # Build and make a request to the server
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://example.com", "/_matrix/media/v3/preview_url?url=http://example.com",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -658,7 +682,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://matrix.org", "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -708,7 +732,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://matrix.org", "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -750,7 +774,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://matrix.org", "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -790,7 +814,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://matrix.org", "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -831,7 +855,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"preview_url?{query_params}", f"/_matrix/media/v3/preview_url?{query_params}",
shorthand=False, shorthand=False,
) )
self.pump() self.pump()
@ -852,7 +876,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://matrix.org", "/_matrix/media/v3/preview_url?url=http://matrix.org",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -889,7 +913,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://twitter.com/matrixdotorg/status/12345", "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -949,7 +973,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://twitter.com/matrixdotorg/status/12345", "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -998,7 +1022,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://www.hulu.com/watch/12345", "/_matrix/media/v3/preview_url?url=http://www.hulu.com/watch/12345",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1043,7 +1067,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://twitter.com/matrixdotorg/status/12345", "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1072,7 +1096,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", "/_matrix/media/v3/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1164,7 +1188,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", "/_matrix/media/v3/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1205,7 +1229,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=http://cdn.twitter.com/matrixdotorg", "/_matrix/media/v3/preview_url?url=http://cdn.twitter.com/matrixdotorg",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1247,7 +1271,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Check fetching # Check fetching
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"download/{host}/{media_id}", f"/_matrix/media/v3/download/{host}/{media_id}",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1260,7 +1284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"download/{host}/{media_id}", f"/_matrix/media/v3/download/{host}/{media_id}",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1295,7 +1319,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# Check fetching # Check fetching
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale", f"/_matrix/media/v3/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1313,7 +1337,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale", f"/_matrix/media/v3/thumbnail/{host}/{media_id}?width=32&height=32&method=scale",
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1343,7 +1367,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertTrue(os.path.isdir(thumbnail_dir)) self.assertTrue(os.path.isdir(thumbnail_dir))
self.reactor.advance(IMAGE_CACHE_EXPIRY_MS * 1000 + 1) self.reactor.advance(IMAGE_CACHE_EXPIRY_MS * 1000 + 1)
self.get_success(self.preview_url._url_previewer._expire_url_cache_data()) self.get_success(self.url_previewer._expire_url_cache_data())
for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs: for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs:
self.assertFalse( self.assertFalse(
@ -1363,7 +1387,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=" + bad_url, "/_matrix/media/v3/preview_url?url=" + bad_url,
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1372,7 +1396,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=" + good_url, "/_matrix/media/v3/preview_url?url=" + good_url,
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )
@ -1404,7 +1428,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel = self.make_request( channel = self.make_request(
"GET", "GET",
"preview_url?url=" + bad_url, "/_matrix/media/v3/preview_url?url=" + bad_url,
shorthand=False, shorthand=False,
await_result=False, await_result=False,
) )

View File

@ -60,7 +60,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
from synapse.http.server import JsonResource from synapse.http.server import JsonResource, OptionsResource
from synapse.http.site import SynapseRequest, SynapseSite from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import ( from synapse.logging.context import (
SENTINEL_CONTEXT, SENTINEL_CONTEXT,
@ -459,7 +459,7 @@ class HomeserverTestCase(TestCase):
The default calls `self.create_resource_dict` and builds the resultant dict The default calls `self.create_resource_dict` and builds the resultant dict
into a tree. into a tree.
""" """
root_resource = Resource() root_resource = OptionsResource()
create_resource_tree(self.create_resource_dict(), root_resource) create_resource_tree(self.create_resource_dict(), root_resource)
return root_resource return root_resource