Support MSC3916 by adding _matrix/client/v1/media/download endpoint (#17365)

This commit is contained in:
Shay 2024-07-02 06:07:04 -07:00 committed by GitHub
parent b3b793786c
commit 8f890447b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 1718 additions and 84 deletions

View File

@ -0,0 +1 @@
Support [MSC3916](https://github.com/matrix-org/matrix-spec-proposals/blob/rav/authentication-for-media/proposals/3916-authentication-for-media.md) by adding _matrix/client/v1/media/download endpoint.

View File

@ -117,7 +117,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
}, },
"media_repository": { "media_repository": {
"app": "synapse.app.generic_worker", "app": "synapse.app.generic_worker",
"listener_resources": ["media"], "listener_resources": ["media", "client"],
"endpoint_patterns": [ "endpoint_patterns": [
"^/_matrix/media/", "^/_matrix/media/",
"^/_synapse/admin/v1/purge_media_cache$", "^/_synapse/admin/v1/purge_media_cache$",
@ -125,6 +125,7 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"^/_synapse/admin/v1/user/.*/media.*$", "^/_synapse/admin/v1/user/.*/media.*$",
"^/_synapse/admin/v1/media/.*$", "^/_synapse/admin/v1/media/.*$",
"^/_synapse/admin/v1/quarantine_media/.*$", "^/_synapse/admin/v1/quarantine_media/.*$",
"^/_matrix/client/v1/media/.*$",
], ],
# The first configured media worker will run the media background jobs # The first configured media worker will run the media background jobs
"shared_extra_conf": { "shared_extra_conf": {

View File

@ -117,6 +117,19 @@ each upgrade are complete before moving on to the next upgrade, to avoid
stacking them up. You can monitor the currently running background updates with stacking them up. You can monitor the currently running background updates with
[the Admin API](usage/administration/admin_api/background_updates.html#status). [the Admin API](usage/administration/admin_api/background_updates.html#status).
# Upgrading to v1.111.0
## New worker endpoints for authenticated client media
[Media repository workers](./workers.md#synapseappmedia_repository) handling
Media APIs can now handle the following endpoint pattern:
```
^/_matrix/client/v1/media/.*$
```
Please update your reverse proxy configuration.
# Upgrading to v1.106.0 # Upgrading to v1.106.0
## Minimum supported Rust version ## Minimum supported Rust version

View File

@ -739,6 +739,7 @@ An example for a federation sender instance:
Handles the media repository. It can handle all endpoints starting with: Handles the media repository. It can handle all endpoints starting with:
/_matrix/media/ /_matrix/media/
/_matrix/client/v1/media/
... and the following regular expressions matching media-specific administration APIs: ... and the following regular expressions matching media-specific administration APIs:

View File

@ -96,3 +96,6 @@ ignore_missing_imports = True
# https://github.com/twisted/treq/pull/366 # https://github.com/twisted/treq/pull/366
[mypy-treq.*] [mypy-treq.*]
ignore_missing_imports = True ignore_missing_imports = True
[mypy-multipart.*]
ignore_missing_imports = True

18
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. # This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
[[package]] [[package]]
name = "annotated-types" name = "annotated-types"
@ -2039,6 +2039,20 @@ files = [
[package.dependencies] [package.dependencies]
six = ">=1.5" six = ">=1.5"
[[package]]
name = "python-multipart"
version = "0.0.9"
description = "A streaming multipart parser for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"},
{file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"},
]
[package.extras]
dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"]
[[package]] [[package]]
name = "pytz" name = "pytz"
version = "2022.7.1" version = "2022.7.1"
@ -3187,4 +3201,4 @@ user-search = ["pyicu"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.8.0" python-versions = "^3.8.0"
content-hash = "107c8fb5c67360340854fbdba3c085fc5f9c7be24bcb592596a914eea621faea" content-hash = "e8d5806e10eb69bc06900fde18ea3df38f38490ab6baa73fe4a563dfb6abacba"

View File

@ -224,6 +224,8 @@ pydantic = ">=1.7.4, <3"
# needed. # needed.
setuptools_rust = ">=1.3" setuptools_rust = ">=1.3"
# This is used for parsing multipart responses
python-multipart = ">=0.0.9"
# Optional Dependencies # Optional Dependencies
# --------------------- # ---------------------

View File

@ -130,7 +130,8 @@ class Ratelimiter:
Overrides the value set during instantiation if set. Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited. burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set. Overrides the value set during instantiation if set.
update: Whether to count this check as performing the action update: Whether to count this check as performing the action. If the action
cannot be performed, the user's action count is not incremented at all.
n_actions: The number of times the user wants to do this action. If the user n_actions: The number of times the user wants to do this action. If the user
cannot do all of the actions, the user's action count is not incremented cannot do all of the actions, the user's action count is not incremented
at all. at all.

View File

@ -1871,6 +1871,52 @@ class FederationClient(FederationBase):
return filtered_statuses, filtered_failures return filtered_statuses, filtered_failures
async def federation_download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Union[
Tuple[int, Dict[bytes, List[bytes]], bytes],
Tuple[int, Dict[bytes, List[bytes]]],
]:
try:
return await self.transport_layer.federation_download_media(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
# fallback to the _matrix/media/v3/download endpoint. Otherwise, consider it a legitimate error
# and raise.
if not is_unknown_endpoint(e):
raise
logger.debug(
"Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path",
destination,
media_id,
)
return await self.transport_layer.download_media_v3(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
async def download_media( async def download_media(
self, self,
destination: str, destination: str,

View File

@ -824,7 +824,6 @@ class TransportLayerClient:
ip_address: str, ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]: ) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}" path = f"/_matrix/media/r0/download/{destination}/{media_id}"
return await self.client.get_file( return await self.client.get_file(
destination, destination,
path, path,
@ -852,7 +851,6 @@ class TransportLayerClient:
ip_address: str, ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]: ) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}" path = f"/_matrix/media/v3/download/{destination}/{media_id}"
return await self.client.get_file( return await self.client.get_file(
destination, destination,
path, path,
@ -873,6 +871,29 @@ class TransportLayerClient:
ip_address=ip_address, ip_address=ip_address,
) )
async def federation_download_media(
self,
destination: str,
media_id: str,
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
path = f"/_matrix/federation/v1/media/download/{media_id}"
return await self.client.federation_get_file(
destination,
path,
output_stream=output_stream,
max_size=max_size,
args={
"timeout_ms": str(max_timeout_ms),
},
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
def _create_path(federation_prefix: str, path: str, *args: str) -> str: def _create_path(federation_prefix: str, path: str, *args: str) -> str:
""" """

View File

@ -32,8 +32,8 @@ from synapse.federation.transport.server._base import (
from synapse.federation.transport.server.federation import ( from synapse.federation.transport.server.federation import (
FEDERATION_SERVLET_CLASSES, FEDERATION_SERVLET_CLASSES,
FederationAccountStatusServlet, FederationAccountStatusServlet,
FederationMediaDownloadServlet,
FederationUnstableClientKeysClaimServlet, FederationUnstableClientKeysClaimServlet,
FederationUnstableMediaDownloadServlet,
) )
from synapse.http.server import HttpServer, JsonResource from synapse.http.server import HttpServer, JsonResource
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -316,11 +316,8 @@ def register_servlets(
): ):
continue continue
if servletclass == FederationUnstableMediaDownloadServlet: if servletclass == FederationMediaDownloadServlet:
if ( if not hs.config.server.enable_media_repo:
not hs.config.server.enable_media_repo
or not hs.config.experimental.msc3916_authenticated_media_enabled
):
continue continue
servletclass( servletclass(

View File

@ -362,7 +362,7 @@ class BaseFederationServlet:
return None return None
if ( if (
func.__self__.__class__.__name__ # type: ignore func.__self__.__class__.__name__ # type: ignore
== "FederationUnstableMediaDownloadServlet" == "FederationMediaDownloadServlet"
): ):
response = await func( response = await func(
origin, content, request, *args, **kwargs origin, content, request, *args, **kwargs
@ -374,7 +374,7 @@ class BaseFederationServlet:
else: else:
if ( if (
func.__self__.__class__.__name__ # type: ignore func.__self__.__class__.__name__ # type: ignore
== "FederationUnstableMediaDownloadServlet" == "FederationMediaDownloadServlet"
): ):
response = await func( response = await func(
origin, content, request, *args, **kwargs origin, content, request, *args, **kwargs

View File

@ -790,7 +790,7 @@ class FederationAccountStatusServlet(BaseFederationServerServlet):
return 200, {"account_statuses": statuses, "failures": failures} return 200, {"account_statuses": statuses, "failures": failures}
class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet): class FederationMediaDownloadServlet(BaseFederationServerServlet):
""" """
Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns Implementation of new federation media `/download` endpoint outlined in MSC3916. Returns
a multipart/mixed response consisting of a JSON object and the requested media a multipart/mixed response consisting of a JSON object and the requested media
@ -798,7 +798,6 @@ class FederationUnstableMediaDownloadServlet(BaseFederationServerServlet):
""" """
PATH = "/media/download/(?P<media_id>[^/]*)" PATH = "/media/download/(?P<media_id>[^/]*)"
PREFIX = FEDERATION_UNSTABLE_PREFIX + "/org.matrix.msc3916"
RATELIMIT = True RATELIMIT = True
def __init__( def __init__(
@ -858,5 +857,5 @@ FEDERATION_SERVLET_CLASSES: Tuple[Type[BaseFederationServlet], ...] = (
FederationV1SendKnockServlet, FederationV1SendKnockServlet,
FederationMakeKnockServlet, FederationMakeKnockServlet,
FederationAccountStatusServlet, FederationAccountStatusServlet,
FederationUnstableMediaDownloadServlet, FederationMediaDownloadServlet,
) )

View File

@ -35,6 +35,8 @@ from typing import (
Union, Union,
) )
import attr
import multipart
import treq import treq
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from netaddr import AddrFormatError, IPAddress, IPSet from netaddr import AddrFormatError, IPAddress, IPSet
@ -1006,6 +1008,130 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
self._maybe_fail() self._maybe_fail()
@attr.s(auto_attribs=True, slots=True)
class MultipartResponse:
"""
A small class to hold parsed values of a multipart response.
"""
json: bytes = b"{}"
length: Optional[int] = None
content_type: Optional[bytes] = None
disposition: Optional[bytes] = None
url: Optional[bytes] = None
class _MultipartParserProtocol(protocol.Protocol):
"""
Protocol to read and parse a MSC3916 multipart/mixed response
"""
transport: Optional[ITCPTransport] = None
def __init__(
self,
stream: ByteWriteable,
deferred: defer.Deferred,
boundary: str,
max_length: Optional[int],
) -> None:
self.stream = stream
self.deferred = deferred
self.boundary = boundary
self.max_length = max_length
self.parser = None
self.multipart_response = MultipartResponse()
self.has_redirect = False
self.in_json = False
self.json_done = False
self.file_length = 0
self.total_length = 0
self.in_disposition = False
self.in_content_type = False
def dataReceived(self, incoming_data: bytes) -> None:
if self.deferred.called:
return
# we don't have a parser yet, instantiate it
if not self.parser:
def on_header_field(data: bytes, start: int, end: int) -> None:
if data[start:end] == b"Location":
self.has_redirect = True
if data[start:end] == b"Content-Disposition":
self.in_disposition = True
if data[start:end] == b"Content-Type":
self.in_content_type = True
def on_header_value(data: bytes, start: int, end: int) -> None:
# the first header should be content-type for application/json
if not self.in_json and not self.json_done:
assert data[start:end] == b"application/json"
self.in_json = True
elif self.has_redirect:
self.multipart_response.url = data[start:end]
elif self.in_content_type:
self.multipart_response.content_type = data[start:end]
self.in_content_type = False
elif self.in_disposition:
self.multipart_response.disposition = data[start:end]
self.in_disposition = False
def on_part_data(data: bytes, start: int, end: int) -> None:
# we've seen json header but haven't written the json data
if self.in_json and not self.json_done:
self.multipart_response.json = data[start:end]
self.json_done = True
# we have a redirect header rather than a file, and have already captured it
elif self.has_redirect:
return
# otherwise we are in the file part
else:
logger.info("Writing multipart file data to stream")
try:
self.stream.write(data[start:end])
except Exception as e:
logger.warning(
f"Exception encountered writing file data to stream: {e}"
)
self.deferred.errback()
self.file_length += end - start
callbacks = {
"on_header_field": on_header_field,
"on_header_value": on_header_value,
"on_part_data": on_part_data,
}
self.parser = multipart.MultipartParser(self.boundary, callbacks)
self.total_length += len(incoming_data)
if self.max_length is not None and self.total_length >= self.max_length:
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()
try:
self.parser.write(incoming_data) # type: ignore[attr-defined]
except Exception as e:
logger.warning(f"Exception writing to multipart parser: {e}")
self.deferred.errback()
return
def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return
if reason.check(ResponseDone):
self.multipart_response.length = self.file_length
self.deferred.callback(self.multipart_response)
else:
self.deferred.errback(reason)
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
@ -1091,6 +1217,32 @@ def read_body_with_max_size(
return d return d
def read_multipart_response(
response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int]
) -> "defer.Deferred[MultipartResponse]":
"""
Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into
the stream passed in and returning a deferred resolving to a MultipartResponse
Args:
response: The HTTP response to read from.
stream: The file-object to write to.
boundary: the multipart/mixed boundary string
max_length: maximum allowable length of the response
"""
d: defer.Deferred[MultipartResponse] = defer.Deferred()
# If the Content-Length header gives a size larger than the maximum allowed
# size, do not bother downloading the body.
if max_length is not None and response.length != UNKNOWN_LENGTH:
if response.length > max_length:
response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
return d
response.deliverBody(_MultipartParserProtocol(stream, d, boundary, max_length))
return d
def encode_query_args(args: Optional[QueryParams]) -> bytes: def encode_query_args(args: Optional[QueryParams]) -> bytes:
""" """
Encodes a map of query arguments to bytes which can be appended to a URL. Encodes a map of query arguments to bytes which can be appended to a URL.

View File

@ -75,9 +75,11 @@ from synapse.http.client import (
BlocklistingAgentWrapper, BlocklistingAgentWrapper,
BodyExceededMaxSize, BodyExceededMaxSize,
ByteWriteable, ByteWriteable,
SimpleHttpClient,
_make_scheduler, _make_scheduler,
encode_query_args, encode_query_args,
read_body_with_max_size, read_body_with_max_size,
read_multipart_response,
) )
from synapse.http.connectproxyclient import BearerProxyCredentials from synapse.http.connectproxyclient import BearerProxyCredentials
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
@ -466,6 +468,13 @@ class MatrixFederationHttpClient:
self._sleeper = AwakenableSleeper(self.reactor) self._sleeper = AwakenableSleeper(self.reactor)
self._simple_http_client = SimpleHttpClient(
hs,
ip_blocklist=hs.config.server.federation_ip_range_blocklist,
ip_allowlist=hs.config.server.federation_ip_range_allowlist,
use_proxy=True,
)
def wake_destination(self, destination: str) -> None: def wake_destination(self, destination: str) -> None:
"""Called when the remote server may have come back online.""" """Called when the remote server may have come back online."""
@ -1553,6 +1562,189 @@ class MatrixFederationHttpClient:
) )
return length, headers return length, headers
async def federation_get_file(
self,
destination: str,
path: str,
output_stream: BinaryIO,
download_ratelimiter: Ratelimiter,
ip_address: str,
max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
ignore_backoff: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
"""GETs a file from a given homeserver over the federation /download endpoint
Args:
destination: The remote server to send the HTTP request to.
path: The HTTP path to GET.
output_stream: File to write the response body to.
download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
requester IP
ip_address: IP address of the requester
max_size: maximum allowable size in bytes of the file
args: Optional dictionary used to create the query string.
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
Returns:
Resolves to an (int, dict, bytes) tuple of
the file length, a dict of the response headers, and the file json
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
NotRetryingDestination: If we are not yet ready to retry this
server.
FederationDeniedError: If this destination is not on our
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
SynapseError: If the requested file exceeds ratelimits or the response from the
remote server is not a multipart response
AssertionError: if the resolved multipart response's length is None
"""
request = MatrixFederationRequest(
method="GET", destination=destination, path=path, query=args
)
# check for a minimum balance of 1MiB in ratelimiter before initiating request
send_req, _ = await download_ratelimiter.can_do_action(
requester=None, key=ip_address, n_actions=1048576, update=False
)
if not send_req:
msg = "Requested file size exceeds ratelimits"
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
response = await self._send_request(
request,
retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)
headers = dict(response.headers.getAllRawHeaders())
expected_size = response.length
# if we don't get an expected length then use the max length
if expected_size == UNKNOWN_LENGTH:
expected_size = max_size
logger.debug(
f"File size unknown, assuming file is max allowable size: {max_size}"
)
read_body, _ = await download_ratelimiter.can_do_action(
requester=None,
key=ip_address,
n_actions=expected_size,
)
if not read_body:
msg = "Requested file size exceeds ratelimits"
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
# this should be a multipart/mixed response with the boundary string in the header
try:
raw_content_type = headers.get(b"Content-Type")
assert raw_content_type is not None
content_type = raw_content_type[0].decode("UTF-8")
content_type_parts = content_type.split("boundary=")
boundary = content_type_parts[1]
except Exception:
msg = "Remote response is malformed: expected Content-Type of multipart/mixed with a boundary present."
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.BAD_GATEWAY, msg)
try:
# add a byte of headroom to max size as `_MultipartParserProtocol.dataReceived` errs at >=
deferred = read_multipart_response(
response, output_stream, boundary, expected_size + 1
)
deferred.addTimeout(self.default_timeout_seconds, self.reactor)
except BodyExceededMaxSize:
msg = "Requested file is too large > %r bytes" % (expected_size,)
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
except defer.TimeoutError as e:
logger.warning(
"{%s} [%s] Timed out reading response - %s %s",
request.txn_id,
request.destination,
request.method,
request.uri.decode("ascii"),
)
raise RequestSendFailed(e, can_retry=True) from e
except ResponseFailed as e:
logger.warning(
"{%s} [%s] Failed to read response - %s %s",
request.txn_id,
request.destination,
request.method,
request.uri.decode("ascii"),
)
raise RequestSendFailed(e, can_retry=True) from e
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",
request.txn_id,
request.destination,
e,
)
raise
multipart_response = await make_deferred_yieldable(deferred)
if not multipart_response.url:
assert multipart_response.length is not None
length = multipart_response.length
headers[b"Content-Type"] = [multipart_response.content_type]
headers[b"Content-Disposition"] = [multipart_response.disposition]
# the response contained a redirect url to download the file from
else:
str_url = multipart_response.url.decode("utf-8")
logger.info(
"{%s} [%s] File download redirected, now downloading from: %s",
request.txn_id,
request.destination,
str_url,
)
length, headers, _, _ = await self._simple_http_client.get_file(
str_url, output_stream, expected_size
)
logger.info(
"{%s} [%s] Completed: %d %s [%d bytes] %s %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
length,
request.method,
request.uri.decode("ascii"),
)
return length, headers, multipart_response.json
def _flatten_response_never_received(e: BaseException) -> str: def _flatten_response_never_received(e: BaseException) -> str:
if hasattr(e, "reasons"): if hasattr(e, "reasons"):

View File

@ -221,6 +221,7 @@ def add_file_headers(
# select private. don't bother setting Expires as all our # select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control # clients are smart enough to be happy with Cache-Control
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
if file_size is not None: if file_size is not None:
request.setHeader(b"Content-Length", b"%d" % (file_size,)) request.setHeader(b"Content-Length", b"%d" % (file_size,))
@ -302,12 +303,37 @@ async def respond_with_multipart_responder(
) )
return return
if media_info.media_type.lower().split(";", 1)[0] in INLINE_CONTENT_TYPES:
disposition = "inline"
else:
disposition = "attachment"
def _quote(x: str) -> str:
return urllib.parse.quote(x.encode("utf-8"))
if media_info.upload_name:
if _can_encode_filename_as_token(media_info.upload_name):
disposition = "%s; filename=%s" % (
disposition,
media_info.upload_name,
)
else:
disposition = "%s; filename*=utf-8''%s" % (
disposition,
_quote(media_info.upload_name),
)
from synapse.media.media_storage import MultipartFileConsumer from synapse.media.media_storage import MultipartFileConsumer
# note that currently the json_object is just {}, this will change when linked media # note that currently the json_object is just {}, this will change when linked media
# is implemented # is implemented
multipart_consumer = MultipartFileConsumer( multipart_consumer = MultipartFileConsumer(
clock, request, media_info.media_type, {}, media_info.media_length clock,
request,
media_info.media_type,
{},
disposition,
media_info.media_length,
) )
logger.debug("Responding to media request with responder %s", responder) logger.debug("Responding to media request with responder %s", responder)

View File

@ -480,6 +480,7 @@ class MediaRepository:
name: Optional[str], name: Optional[str],
max_timeout_ms: int, max_timeout_ms: int,
ip_address: str, ip_address: str,
use_federation_endpoint: bool,
) -> None: ) -> None:
"""Respond to requests for remote media. """Respond to requests for remote media.
@ -492,6 +493,8 @@ class MediaRepository:
max_timeout_ms: the maximum number of milliseconds to wait for the max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded. media to be uploaded.
ip_address: the IP address of the requester ip_address: the IP address of the requester
use_federation_endpoint: whether to request the remote media over the new
federation `/download` endpoint
Returns: Returns:
Resolves once a response has successfully been written to request Resolves once a response has successfully been written to request
@ -522,6 +525,7 @@ class MediaRepository:
max_timeout_ms, max_timeout_ms,
self.download_ratelimiter, self.download_ratelimiter,
ip_address, ip_address,
use_federation_endpoint,
) )
# We deliberately stream the file outside the lock # We deliberately stream the file outside the lock
@ -569,6 +573,7 @@ class MediaRepository:
max_timeout_ms, max_timeout_ms,
self.download_ratelimiter, self.download_ratelimiter,
ip_address, ip_address,
False,
) )
# Ensure we actually use the responder so that it releases resources # Ensure we actually use the responder so that it releases resources
@ -585,6 +590,7 @@ class MediaRepository:
max_timeout_ms: int, max_timeout_ms: int,
download_ratelimiter: Ratelimiter, download_ratelimiter: Ratelimiter,
ip_address: str, ip_address: str,
use_federation_endpoint: bool,
) -> Tuple[Optional[Responder], RemoteMedia]: ) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to """Looks for media in local cache, if not there then attempt to
download from remote server. download from remote server.
@ -598,6 +604,8 @@ class MediaRepository:
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
requester IP. requester IP.
ip_address: the IP address of the requester ip_address: the IP address of the requester
use_federation_endpoint: whether to request the remote media over the new federation
/download endpoint
Returns: Returns:
A tuple of responder and the media info of the file. A tuple of responder and the media info of the file.
@ -629,9 +637,23 @@ class MediaRepository:
# Failed to find the file anywhere, lets download it. # Failed to find the file anywhere, lets download it.
try: try:
if not use_federation_endpoint:
media_info = await self._download_remote_file( media_info = await self._download_remote_file(
server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address server_name,
media_id,
max_timeout_ms,
download_ratelimiter,
ip_address,
) )
else:
media_info = await self._federation_download_remote_file(
server_name,
media_id,
max_timeout_ms,
download_ratelimiter,
ip_address,
)
except SynapseError: except SynapseError:
raise raise
except Exception as e: except Exception as e:
@ -775,6 +797,129 @@ class MediaRepository:
quarantined_by=None, quarantined_by=None,
) )
async def _federation_download_remote_file(
self,
server_name: str,
media_id: str,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name.
Uses the given file_id as the local id and downloads the file over the federation
v1 download endpoint
Args:
server_name: Originating server
media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
requester IP
ip_address: the IP address of the requester
Returns:
The media info of the file.
"""
file_id = random_string(24)
file_info = FileInfo(server_name=server_name, file_id=file_id)
async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
res = await self.client.federation_download_media(
server_name,
media_id,
output_stream=f,
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
# if we had to fall back to the _matrix/media endpoint it will only return
# the headers and length, check the length of the tuple before unpacking
if len(res) == 3:
length, headers, json = res
else:
length, headers = res
except RequestSendFailed as e:
logger.warning(
"Request failed fetching remote media %s/%s: %r",
server_name,
media_id,
e,
)
raise SynapseError(502, "Failed to fetch remote media")
except HttpResponseException as e:
logger.warning(
"HTTP error fetching remote media %s/%s: %s",
server_name,
media_id,
e.response,
)
if e.code == twisted.web.http.NOT_FOUND:
raise e.to_synapse_error()
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError:
logger.warning(
"Failed to fetch remote media %s/%s", server_name, media_id
)
raise
except NotRetryingDestination:
logger.warning("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
logger.exception(
"Failed to fetch remote media %s/%s", server_name, media_id
)
raise SynapseError(502, "Failed to fetch remote media")
if b"Content-Type" in headers:
media_type = headers[b"Content-Type"][0].decode("ascii")
else:
media_type = "application/octet-stream"
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()
# Multiple remote media download requests can race (when using
# multiple media repos), so this may throw a violation constraint
# exception. If it does we'll delete the newly downloaded file from
# disk (as we're in the ctx manager).
#
# However: we've already called `finish()` so we may have also
# written to the storage providers. This is preferable to the
# alternative where we call `finish()` *after* this, where we could
# end up having an entry in the DB but fail to write the files to
# the storage providers.
await self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=time_now_ms,
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
logger.debug("Stored remote media in file %r", fname)
return RemoteMedia(
media_origin=server_name,
media_id=media_id,
media_type=media_type,
media_length=length,
upload_name=upload_name,
created_ts=time_now_ms,
filesystem_id=file_id,
last_access_ts=time_now_ms,
quarantined_by=None,
)
def _get_thumbnail_requirements( def _get_thumbnail_requirements(
self, media_type: str self, media_type: str
) -> Tuple[ThumbnailRequirement, ...]: ) -> Tuple[ThumbnailRequirement, ...]:

View File

@ -401,13 +401,14 @@ class MultipartFileConsumer:
wrapped_consumer: interfaces.IConsumer, wrapped_consumer: interfaces.IConsumer,
file_content_type: str, file_content_type: str,
json_object: JsonDict, json_object: JsonDict,
content_length: Optional[int] = None, disposition: str,
content_length: Optional[int],
) -> None: ) -> None:
self.clock = clock self.clock = clock
self.wrapped_consumer = wrapped_consumer self.wrapped_consumer = wrapped_consumer
self.json_field = json_object self.json_field = json_object
self.json_field_written = False self.json_field_written = False
self.content_type_written = False self.file_headers_written = False
self.file_content_type = file_content_type self.file_content_type = file_content_type
self.boundary = uuid4().hex.encode("ascii") self.boundary = uuid4().hex.encode("ascii")
@ -420,6 +421,7 @@ class MultipartFileConsumer:
self.paused = False self.paused = False
self.length = content_length self.length = content_length
self.disposition = disposition
### IConsumer APIs ### ### IConsumer APIs ###
@ -488,11 +490,13 @@ class MultipartFileConsumer:
self.json_field_written = True self.json_field_written = True
# if we haven't written the content type yet, do so # if we haven't written the content type yet, do so
if not self.content_type_written: if not self.file_headers_written:
type = self.file_content_type.encode("utf-8") type = self.file_content_type.encode("utf-8")
content_type = Header(b"Content-Type", type) content_type = Header(b"Content-Type", type)
self.wrapped_consumer.write(bytes(content_type) + CRLF + CRLF) self.wrapped_consumer.write(bytes(content_type) + CRLF)
self.content_type_written = True disp_header = Header(b"Content-Disposition", self.disposition)
self.wrapped_consumer.write(bytes(disp_header) + CRLF + CRLF)
self.file_headers_written = True
self.wrapped_consumer.write(data) self.wrapped_consumer.write(data)
@ -506,7 +510,6 @@ class MultipartFileConsumer:
producing data for good. producing data for good.
""" """
assert self.producer is not None assert self.producer is not None
self.paused = True self.paused = True
self.producer.stopProducing() self.producer.stopProducing()
@ -518,7 +521,6 @@ class MultipartFileConsumer:
the time being, and to stop until C{resumeProducing()} is called. the time being, and to stop until C{resumeProducing()} is called.
""" """
assert self.producer is not None assert self.producer is not None
self.paused = True self.paused = True
if self.streaming: if self.streaming:
@ -549,7 +551,7 @@ class MultipartFileConsumer:
""" """
if not self.length: if not self.length:
return None return None
# calculate length of json field and content-type header # calculate length of json field and content-type, disposition headers
json_field = json.dumps(self.json_field) json_field = json.dumps(self.json_field)
json_bytes = json_field.encode("utf-8") json_bytes = json_field.encode("utf-8")
json_length = len(json_bytes) json_length = len(json_bytes)
@ -558,9 +560,13 @@ class MultipartFileConsumer:
content_type = Header(b"Content-Type", type) content_type = Header(b"Content-Type", type)
type_length = len(bytes(content_type)) type_length = len(bytes(content_type))
# 154 is the length of the elements that aren't variable, ie disp = self.disposition.encode("utf-8")
disp_header = Header(b"Content-Disposition", disp)
disp_length = len(bytes(disp_header))
# 156 is the length of the elements that aren't variable, ie
# CRLFs and boundary strings, etc # CRLFs and boundary strings, etc
self.length += json_length + type_length + 154 self.length += json_length + type_length + disp_length + 156
return self.length return self.length
@ -569,7 +575,6 @@ class MultipartFileConsumer:
async def _resumeProducingRepeatedly(self) -> None: async def _resumeProducingRepeatedly(self) -> None:
assert self.producer is not None assert self.producer is not None
assert not self.streaming assert not self.streaming
producer = cast("interfaces.IPullProducer", self.producer) producer = cast("interfaces.IPullProducer", self.producer)
self.paused = False self.paused = False

View File

@ -145,6 +145,10 @@ class ClientRestResource(JsonResource):
password_policy.register_servlets(hs, client_resource) password_policy.register_servlets(hs, client_resource)
knock.register_servlets(hs, client_resource) knock.register_servlets(hs, client_resource)
appservice_ping.register_servlets(hs, client_resource) appservice_ping.register_servlets(hs, client_resource)
if hs.config.server.enable_media_repo:
from synapse.rest.client import media
media.register_servlets(hs, client_resource)
# moving to /_synapse/admin # moving to /_synapse/admin
if is_main_process: if is_main_process:

View File

@ -22,6 +22,7 @@
import logging import logging
import re import re
from typing import Optional
from synapse.http.server import ( from synapse.http.server import (
HttpServer, HttpServer,
@ -194,14 +195,76 @@ class UnstableThumbnailResource(RestServlet):
self.media_repo.mark_recently_accessed(server_name, media_id) self.media_repo.mark_recently_accessed(server_name, media_id)
class DownloadResource(RestServlet):
PATTERNS = [
re.compile(
"/_matrix/client/v1/media/download/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)(/(?P<file_name>[^/]*))?$"
)
]
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self._is_mine_server_name = hs.is_mine_server_name
self.auth = hs.get_auth()
async def on_GET(
self,
request: SynapseRequest,
server_name: str,
media_id: str,
file_name: Optional[str] = None,
) -> None:
# Validate the server name, raising if invalid
parse_and_validate_server_name(server_name)
await self.auth.get_user_by_req(request)
set_cors_headers(request)
set_corp_headers(request)
request.setHeader(
b"Content-Security-Policy",
b"sandbox;"
b" default-src 'none';"
b" script-src 'none';"
b" plugin-types application/pdf;"
b" style-src 'unsafe-inline';"
b" media-src 'self';"
b" object-src 'self';",
)
# Limited non-standard form of CSP for IE11
request.setHeader(b"X-Content-Security-Policy", b"sandbox;")
request.setHeader(b"Referrer-Policy", b"no-referrer")
max_timeout_ms = parse_integer(
request, "timeout_ms", default=DEFAULT_MAX_TIMEOUT_MS
)
max_timeout_ms = min(max_timeout_ms, MAXIMUM_ALLOWED_MAX_TIMEOUT_MS)
if self._is_mine_server_name(server_name):
await self.media_repo.get_local_media(
request, media_id, file_name, max_timeout_ms
)
else:
ip_address = request.getClientAddress().host
await self.media_repo.get_remote_media(
request,
server_name,
media_id,
file_name,
max_timeout_ms,
ip_address,
True,
)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.experimental.msc3916_authenticated_media_enabled:
media_repo = hs.get_media_repository() media_repo = hs.get_media_repository()
if hs.config.media.url_preview_enabled: if hs.config.media.url_preview_enabled:
UnstablePreviewURLServlet( UnstablePreviewURLServlet(hs, media_repo, media_repo.media_storage).register(
hs, media_repo, media_repo.media_storage http_server
).register(http_server) )
UnstableMediaConfigResource(hs).register(http_server) UnstableMediaConfigResource(hs).register(http_server)
UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register( UnstableThumbnailResource(hs, media_repo, media_repo.media_storage).register(
http_server http_server
) )
DownloadResource(hs, media_repo).register(http_server)

View File

@ -105,4 +105,5 @@ class DownloadResource(RestServlet):
file_name, file_name,
max_timeout_ms, max_timeout_ms,
ip_address, ip_address,
False,
) )

View File

@ -36,10 +36,9 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.test_utils import SMALL_PNG from tests.test_utils import SMALL_PNG
from tests.unittest import override_config
class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase): class FederationMediaDownloadsTest(unittest.FederatingHomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs) super().prepare(reactor, clock, hs)
@ -65,9 +64,6 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
) )
self.media_repo = hs.get_media_repository() self.media_repo = hs.get_media_repository()
@override_config(
{"experimental_features": {"msc3916_authenticated_media_enabled": True}}
)
def test_file_download(self) -> None: def test_file_download(self) -> None:
content = io.BytesIO(b"file_to_stream") content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success( content_uri = self.get_success(
@ -82,7 +78,7 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
# test with a text file # test with a text file
channel = self.make_signed_federation_request( channel = self.make_signed_federation_request(
"GET", "GET",
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}", f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
) )
self.pump() self.pump()
self.assertEqual(200, channel.code) self.assertEqual(200, channel.code)
@ -106,7 +102,8 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
# check that the text file and expected value exist # check that the text file and expected value exist
found_file = any( found_file = any(
"\r\nContent-Type: text/plain\r\n\r\nfile_to_stream" in field "\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream"
in field
for field in stripped for field in stripped
) )
self.assertTrue(found_file) self.assertTrue(found_file)
@ -124,7 +121,7 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
# test with an image file # test with an image file
channel = self.make_signed_federation_request( channel = self.make_signed_federation_request(
"GET", "GET",
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}", f"/_matrix/federation/v1/media/download/{content_uri.media_id}",
) )
self.pump() self.pump()
self.assertEqual(200, channel.code) self.assertEqual(200, channel.code)
@ -149,25 +146,3 @@ class FederationUnstableMediaDownloadsTest(unittest.FederatingHomeserverTestCase
# check that the png file exists and matches what was uploaded # check that the png file exists and matches what was uploaded
found_file = any(SMALL_PNG in field for field in stripped_bytes) found_file = any(SMALL_PNG in field for field in stripped_bytes)
self.assertTrue(found_file) self.assertTrue(found_file)
@override_config(
{"experimental_features": {"msc3916_authenticated_media_enabled": False}}
)
def test_disable_config(self) -> None:
content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success(
self.media_repo.create_content(
"text/plain",
"test_upload",
content,
46,
UserID.from_string("@user_id:whatever.org"),
)
)
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/unstable/org.matrix.msc3916/media/download/{content_uri.media_id}",
)
self.pump()
self.assertEqual(404, channel.code)
self.assertEqual(channel.json_body.get("errcode"), "M_UNRECOGNIZED")

View File

@ -37,18 +37,155 @@ from synapse.http.client import (
BlocklistingAgentWrapper, BlocklistingAgentWrapper,
BlocklistingReactorWrapper, BlocklistingReactorWrapper,
BodyExceededMaxSize, BodyExceededMaxSize,
MultipartResponse,
_DiscardBodyWithMaxSizeProtocol, _DiscardBodyWithMaxSizeProtocol,
_MultipartParserProtocol,
read_body_with_max_size, read_body_with_max_size,
read_multipart_response,
) )
from tests.server import FakeTransport, get_clock from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase from tests.unittest import TestCase
class ReadMultipartResponseTests(TestCase):
data1 = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_"
data2 = b"to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
redirect_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nLocation: https://cdn.example.org/ab/c1/2345.txt\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
def _build_multipart_response(
self, response_length: Union[int, str], max_length: int
) -> Tuple[
BytesIO,
"Deferred[MultipartResponse]",
_MultipartParserProtocol,
]:
"""Start reading the body, returns the response, result and proto"""
response = Mock(length=response_length)
result = BytesIO()
boundary = "6067d4698f8d40a0a794ea7d7379d53a"
deferred = read_multipart_response(response, result, boundary, max_length)
# Fish the protocol out of the response.
protocol = response.deliverBody.call_args[0][0]
protocol.transport = Mock()
return result, deferred, protocol
def _assert_error(
self,
deferred: "Deferred[MultipartResponse]",
protocol: _MultipartParserProtocol,
) -> None:
"""Ensure that the expected error is received."""
assert isinstance(deferred.result, Failure)
self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
assert protocol.transport is not None
# type-ignore: presumably abortConnection has been replaced with a Mock.
protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined]
def _cleanup_error(self, deferred: "Deferred[MultipartResponse]") -> None:
"""Ensure that the error in the Deferred is handled gracefully."""
called = [False]
def errback(f: Failure) -> None:
called[0] = True
deferred.addErrback(errback)
self.assertTrue(called[0])
def test_parse_file(self) -> None:
"""
Check that a multipart response containing a file is properly parsed
into the json/file parts, and the json and file are properly captured
"""
result, deferred, protocol = self._build_multipart_response(249, 250)
# Start sending data.
protocol.dataReceived(self.data1)
protocol.dataReceived(self.data2)
# Close the connection.
protocol.connectionLost(Failure(ResponseDone()))
multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]
self.assertEqual(multipart_response.json, b"{}")
self.assertEqual(result.getvalue(), b"file_to_stream")
self.assertEqual(multipart_response.length, len(b"file_to_stream"))
self.assertEqual(multipart_response.content_type, b"text/plain")
self.assertEqual(
multipart_response.disposition, b"inline; filename=test_upload"
)
def test_parse_redirect(self) -> None:
"""
check that a multipart response containing a redirect is properly parsed and redirect url is
returned
"""
result, deferred, protocol = self._build_multipart_response(249, 250)
# Start sending data.
protocol.dataReceived(self.redirect_data)
# Close the connection.
protocol.connectionLost(Failure(ResponseDone()))
multipart_response: MultipartResponse = deferred.result # type: ignore[assignment]
self.assertEqual(multipart_response.json, b"{}")
self.assertEqual(result.getvalue(), b"")
self.assertEqual(
multipart_response.url, b"https://cdn.example.org/ab/c1/2345.txt"
)
def test_too_large(self) -> None:
"""A response which is too large raises an exception."""
result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
# Start sending data.
protocol.dataReceived(self.data1)
self.assertEqual(result.getvalue(), b"file_")
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
def test_additional_data(self) -> None:
"""A connection can receive data after being closed."""
result, deferred, protocol = self._build_multipart_response(UNKNOWN_LENGTH, 180)
# Start sending data.
protocol.dataReceived(self.data1)
self._assert_error(deferred, protocol)
# More data might have come in.
protocol.dataReceived(self.data2)
self.assertEqual(result.getvalue(), b"file_")
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
def test_content_length(self) -> None:
"""The body shouldn't be read (at all) if the Content-Length header is too large."""
result, deferred, protocol = self._build_multipart_response(250, 1)
# Deferred shouldn't be called yet.
self.assertFalse(deferred.called)
# Start sending data.
protocol.dataReceived(self.data1)
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
# The data is never consumed.
self.assertEqual(result.getvalue(), b"")
class ReadBodyWithMaxSizeTests(TestCase): class ReadBodyWithMaxSizeTests(TestCase):
def _build_response( def _build_response(self, length: Union[int, str] = UNKNOWN_LENGTH) -> Tuple[
self, length: Union[int, str] = UNKNOWN_LENGTH BytesIO,
) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]: "Deferred[int]",
_DiscardBodyWithMaxSizeProtocol,
]:
"""Start reading the body, returns the response, result and proto""" """Start reading the body, returns the response, result and proto"""
response = Mock(length=length) response = Mock(length=length)
result = BytesIO() result = BytesIO()

View File

@ -129,7 +129,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
@attr.s(auto_attribs=True, slots=True, frozen=True) @attr.s(auto_attribs=True, slots=True, frozen=True)
class _TestImage: class TestImage:
"""An image for testing thumbnailing with the expected results """An image for testing thumbnailing with the expected results
Attributes: Attributes:
@ -158,7 +158,7 @@ class _TestImage:
is_inline: bool = True is_inline: bool = True
small_png = _TestImage( small_png = TestImage(
SMALL_PNG, SMALL_PNG,
b"image/png", b"image/png",
b".png", b".png",
@ -175,7 +175,7 @@ small_png = _TestImage(
), ),
) )
small_png_with_transparency = _TestImage( small_png_with_transparency = TestImage(
unhexlify( unhexlify(
b"89504e470d0a1a0a0000000d49484452000000010000000101000" b"89504e470d0a1a0a0000000d49484452000000010000000101000"
b"00000376ef9240000000274524e5300010194fdae0000000a4944" b"00000376ef9240000000274524e5300010194fdae0000000a4944"
@ -188,7 +188,7 @@ small_png_with_transparency = _TestImage(
# different versions of Pillow. # different versions of Pillow.
) )
small_lossless_webp = _TestImage( small_lossless_webp = TestImage(
unhexlify( unhexlify(
b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700" b"524946461a000000574542505650384c0d0000002f0000001007" b"1011118888fe0700"
), ),
@ -196,7 +196,7 @@ small_lossless_webp = _TestImage(
b".webp", b".webp",
) )
empty_file = _TestImage( empty_file = TestImage(
b"", b"",
b"image/gif", b"image/gif",
b".gif", b".gif",
@ -204,7 +204,7 @@ empty_file = _TestImage(
unable_to_thumbnail=True, unable_to_thumbnail=True,
) )
SVG = _TestImage( SVG = TestImage(
b"""<?xml version="1.0"?> b"""<?xml version="1.0"?>
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" <!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"
"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">
@ -236,7 +236,7 @@ urls = [
@parameterized_class(("test_image", "url"), itertools.product(test_images, urls)) @parameterized_class(("test_image", "url"), itertools.product(test_images, urls))
class MediaRepoTests(unittest.HomeserverTestCase): class MediaRepoTests(unittest.HomeserverTestCase):
servlets = [media.register_servlets] servlets = [media.register_servlets]
test_image: ClassVar[_TestImage] test_image: ClassVar[TestImage]
hijack_auth = True hijack_auth = True
user_id = "@test:user" user_id = "@test:user"
url: ClassVar[str] url: ClassVar[str]

View File

@ -28,7 +28,7 @@ from twisted.web.http import HTTPChannel
from twisted.web.server import Request from twisted.web.server import Request
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login from synapse.rest.client import login, media
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
@ -255,6 +255,238 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
return sum(len(files) for _, _, files in os.walk(path)) return sum(len(files) for _, _, files in os.walk(path))
class AuthenticatedMediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks running multiple media repos work correctly using autheticated media paths"""
servlets = [
admin.register_servlets_for_client_rest_resource,
login.register_servlets,
media.register_servlets,
]
file_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: text/plain\r\nContent-Disposition: inline; filename=test_upload\r\n\r\nfile_to_stream\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
self.reactor.lookups["example.com"] = "1.2.3.4"
def default_config(self) -> dict:
conf = super().default_config()
conf["federation_custom_ca_list"] = [get_test_ca_cert_file()]
return conf
def make_worker_hs(
self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any
) -> HomeServer:
worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs)
# Force the media paths onto the replication resource.
worker_hs.get_media_repository_resource().register_servlets(
self._hs_to_site[worker_hs].resource, worker_hs
)
return worker_hs
def _get_media_req(
self, hs: HomeServer, target: str, media_id: str
) -> Tuple[FakeChannel, Request]:
"""Request some remote media from the given HS by calling the download
API.
This then triggers an outbound request from the HS to the target.
Returns:
The channel for the *client* request and the *outbound* request for
the media which the caller should respond to.
"""
channel = make_request(
self.reactor,
self._hs_to_site[hs],
"GET",
f"/_matrix/client/v1/media/download/{target}/{media_id}",
shorthand=False,
access_token=self.access_token,
await_result=False,
)
self.pump()
clients = self.reactor.tcpClients
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
# build the test server
server_factory = Factory.forProtocol(HTTPChannel)
# Request.finish expects the factory to have a 'log' method.
server_factory.log = _log_request
server_tls_protocol = wrap_server_factory_for_tls(
server_factory, self.reactor, sanlist=[b"DNS:example.com"]
).buildProtocol(None)
# now, tell the client protocol factory to build the client protocol (it will be a
# _WrappingProtocol, around a TLSMemoryBIOProtocol, around an
# HTTP11ClientProtocol) and wire the output of said protocol up to the server via
# a FakeTransport.
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection(
FakeTransport(server_tls_protocol, self.reactor, client_protocol)
)
# tell the server tls protocol to send its stuff back to the client, too
server_tls_protocol.makeConnection(
FakeTransport(client_protocol, self.reactor, server_tls_protocol)
)
# fish the test server back out of the server-side TLS protocol.
http_server: HTTPChannel = server_tls_protocol.wrappedProtocol
# give the reactor a pump to get the TLS juices flowing.
self.reactor.pump((0.1,))
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b"GET")
self.assertEqual(
request.path,
f"/_matrix/federation/v1/media/download/{media_id}".encode(),
)
self.assertEqual(
request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")]
)
return channel, request
def test_basic(self) -> None:
"""Test basic fetching of remote media from a single worker."""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
request.setResponseCode(200)
request.responseHeaders.setRawHeaders(
b"Content-Type",
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
)
request.write(self.file_data)
request.finish()
self.pump(0.1)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.result["body"], b"file_to_stream")
def test_download_simple_file_race(self) -> None:
"""Test that fetching remote media from two different processes at the
same time works.
"""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
hs2 = self.make_worker_hs("synapse.app.generic_worker")
start_count = self._count_remote_media()
# Make two requests without responding to the outbound media requests.
channel1, request1 = self._get_media_req(hs1, "example.com:443", "ABC123")
channel2, request2 = self._get_media_req(hs2, "example.com:443", "ABC123")
# Respond to the first outbound media request and check that the client
# request is successful
request1.setResponseCode(200)
request1.responseHeaders.setRawHeaders(
b"Content-Type",
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
)
request1.write(self.file_data)
request1.finish()
self.pump(0.1)
self.assertEqual(channel1.code, 200, channel1.result["body"])
self.assertEqual(channel1.result["body"], b"file_to_stream")
# Now respond to the second with the same content.
request2.setResponseCode(200)
request2.responseHeaders.setRawHeaders(
b"Content-Type",
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
)
request2.write(self.file_data)
request2.finish()
self.pump(0.1)
self.assertEqual(channel2.code, 200, channel2.result["body"])
self.assertEqual(channel2.result["body"], b"file_to_stream")
# We expect only one new file to have been persisted.
self.assertEqual(start_count + 1, self._count_remote_media())
def test_download_image_race(self) -> None:
"""Test that fetching remote *images* from two different processes at
the same time works.
This checks that races generating thumbnails are handled correctly.
"""
hs1 = self.make_worker_hs("synapse.app.generic_worker")
hs2 = self.make_worker_hs("synapse.app.generic_worker")
start_count = self._count_remote_thumbnails()
channel1, request1 = self._get_media_req(hs1, "example.com:443", "PIC1")
channel2, request2 = self._get_media_req(hs2, "example.com:443", "PIC1")
request1.setResponseCode(200)
request1.responseHeaders.setRawHeaders(
b"Content-Type",
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
)
img_data = b"\r\n\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: application/json\r\n\r\n{}\r\n--6067d4698f8d40a0a794ea7d7379d53a\r\nContent-Type: image/png\r\nContent-Disposition: inline; filename=test_img\r\n\r\n"
request1.write(img_data)
request1.write(SMALL_PNG)
request1.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n")
request1.finish()
self.pump(0.1)
self.assertEqual(channel1.code, 200, channel1.result["body"])
self.assertEqual(channel1.result["body"], SMALL_PNG)
request2.setResponseCode(200)
request2.responseHeaders.setRawHeaders(
b"Content-Type",
["multipart/mixed; boundary=6067d4698f8d40a0a794ea7d7379d53a"],
)
request2.write(img_data)
request2.write(SMALL_PNG)
request2.write(b"\r\n--6067d4698f8d40a0a794ea7d7379d53a--\r\n\r\n")
request2.finish()
self.pump(0.1)
self.assertEqual(channel2.code, 200, channel2.result["body"])
self.assertEqual(channel2.result["body"], SMALL_PNG)
# We expect only three new thumbnails to have been persisted.
self.assertEqual(start_count + 3, self._count_remote_thumbnails())
def _count_remote_media(self) -> int:
"""Count the number of files in our remote media directory."""
path = os.path.join(
self.hs.get_media_repository().primary_base_path, "remote_content"
)
return sum(len(files) for _, _, files in os.walk(path))
def _count_remote_thumbnails(self) -> int:
"""Count the number of files in our remote thumbnails directory."""
path = os.path.join(
self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
)
return sum(len(files) for _, _, files in os.walk(path))
def _log_request(request: Request) -> None: def _log_request(request: Request) -> None:
"""Implements Factory.log, which is expected by Request.finish""" """Implements Factory.log, which is expected by Request.finish"""
logger.info("Completed request %s", request) logger.info("Completed request %s", request)

View File

@ -19,31 +19,54 @@
# #
# #
import base64 import base64
import io
import json import json
import os import os
import re import re
from typing import Any, Dict, Optional, Sequence, Tuple, Type from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type
from unittest.mock import MagicMock, Mock, patch
from urllib import parse
from urllib.parse import quote, urlencode from urllib.parse import quote, urlencode
from parameterized import parameterized_class
from twisted.internet import defer
from twisted.internet._resolver import HostResolution from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import Deferred
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IAddress, IResolutionReceiver from twisted.internet.interfaces import IAddress, IResolutionReceiver
from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
from twisted.web.http_headers import Headers
from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.api.errors import HttpResponseException
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.oembed import OEmbedEndpointConfig from synapse.config.oembed import OEmbedEndpointConfig
from synapse.http.client import MultipartResponse
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.media._base import FileInfo from synapse.media._base import FileInfo
from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, media from synapse.rest.client import login, media
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict, UserID
from synapse.util import Clock from synapse.util import Clock
from synapse.util.stringutils import parse_and_validate_mxc_uri from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest from tests import unittest
from tests.server import FakeTransport, ThreadedMemoryReactorClock from tests.media.test_media_storage import (
SVG,
TestImage,
empty_file,
small_lossless_webp,
small_png,
small_png_with_transparency,
)
from tests.server import FakeChannel, FakeTransport, ThreadedMemoryReactorClock
from tests.test_utils import SMALL_PNG from tests.test_utils import SMALL_PNG
from tests.unittest import override_config from tests.unittest import override_config
@ -1607,3 +1630,583 @@ class UnstableMediaConfigTest(unittest.HomeserverTestCase):
self.assertEqual( self.assertEqual(
channel.json_body["m.upload.size"], self.hs.config.media.max_upload_size channel.json_body["m.upload.size"], self.hs.config.media.max_upload_size
) )
class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
servlets = [
media.register_servlets,
login.register_servlets,
admin.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config["media_store_path"] = self.media_store_path
provider_config = {
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.repo = hs.get_media_repository()
self.client = hs.get_federation_http_client()
self.store = hs.get_datastores().main
self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass")
# mock actually reading file body
def read_multipart_response_30MiB(*args: Any, **kwargs: Any) -> Deferred:
d: Deferred = defer.Deferred()
d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None))
return d
def read_multipart_response_50MiB(*args: Any, **kwargs: Any) -> Deferred:
d: Deferred = defer.Deferred()
d.callback(MultipartResponse(b"{}", 31457280, b"img/png", None))
return d
@patch(
"synapse.http.matrixfederationclient.read_multipart_response",
read_multipart_response_30MiB,
)
def test_download_ratelimit_default(self) -> None:
"""
Test remote media download ratelimiting against default configuration - 500MB bucket
and 87kb/second drain rate
"""
# mock out actually sending the request, returns a 30MiB response
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
resp = MagicMock(spec=IResponse)
resp.code = 200
resp.length = 31457280
resp.headers = Headers(
{"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
)
resp.phrase = b"OK"
return resp
self.client._send_request = _send_request # type: ignore
# first request should go through
channel = self.make_request(
"GET",
"/_matrix/client/v1/media/download/remote.org/abc",
shorthand=False,
access_token=self.tok,
)
assert channel.code == 200
# next 15 should go through
for i in range(15):
channel2 = self.make_request(
"GET",
f"/_matrix/client/v1/media/download/remote.org/abc{i}",
shorthand=False,
access_token=self.tok,
)
assert channel2.code == 200
# 17th will hit ratelimit
channel3 = self.make_request(
"GET",
"/_matrix/client/v1/media/download/remote.org/abcd",
shorthand=False,
access_token=self.tok,
)
assert channel3.code == 429
# however, a request from a different IP will go through
channel4 = self.make_request(
"GET",
"/_matrix/client/v1/media/download/remote.org/abcde",
shorthand=False,
client_ip="187.233.230.159",
access_token=self.tok,
)
assert channel4.code == 200
# at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another
# 30MiB download is authorized - The last download was blocked at 503,316,480.
# The next download will be authorized when bucket hits 492,830,720
# (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760
# needs to drain before another download will be authorized, that will take ~=
# 2 minutes (10,485,760/89,088/60)
self.reactor.pump([2.0 * 60.0])
# enough has drained and next request goes through
channel5 = self.make_request(
"GET",
"/_matrix/client/v1/media/download/remote.org/abcdef",
shorthand=False,
access_token=self.tok,
)
assert channel5.code == 200
@override_config(
{
"remote_media_download_per_second": "50M",
"remote_media_download_burst_count": "50M",
}
)
@patch(
"synapse.http.matrixfederationclient.read_multipart_response",
read_multipart_response_50MiB,
)
def test_download_rate_limit_config(self) -> None:
"""
Test that download rate limit config options are correctly picked up and applied
"""
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
resp = MagicMock(spec=IResponse)
resp.code = 200
resp.length = 52428800
resp.headers = Headers(
{"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
)
resp.phrase = b"OK"
return resp
self.client._send_request = _send_request # type: ignore
# first request should go through
channel = self.make_request(
"GET",
"/_matrix/client/v1/media/download/remote.org/abc",
shorthand=False,
access_token=self.tok,
)
assert channel.code == 200
# immediate second request should fail
channel = self.make_request(
"GET",
"/_matrix/client/v1/media/download/remote.org/abcd",
shorthand=False,
access_token=self.tok,
)
assert channel.code == 429
# advance half a second
self.reactor.pump([0.5])
# request still fails
channel = self.make_request(
"GET",
"/_matrix/client/v1/media/download/remote.org/abcde",
shorthand=False,
access_token=self.tok,
)
assert channel.code == 429
# advance another half second
self.reactor.pump([0.5])
# enough has drained from bucket and request is successful
channel = self.make_request(
"GET",
"/_matrix/client/v1/media/download/remote.org/abcdef",
shorthand=False,
access_token=self.tok,
)
assert channel.code == 200
@patch(
"synapse.http.matrixfederationclient.read_multipart_response",
read_multipart_response_30MiB,
)
def test_download_ratelimit_max_size_sub(self) -> None:
"""
Test that if no content-length is provided, the default max size is applied instead
"""
# mock out actually sending the request
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
resp = MagicMock(spec=IResponse)
resp.code = 200
resp.length = UNKNOWN_LENGTH
resp.headers = Headers(
{"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]}
)
resp.phrase = b"OK"
return resp
self.client._send_request = _send_request # type: ignore
# ten requests should go through using the max size (500MB/50MB)
for i in range(10):
channel2 = self.make_request(
"GET",
f"/_matrix/client/v1/media/download/remote.org/abc{i}",
shorthand=False,
access_token=self.tok,
)
assert channel2.code == 200
# eleventh will hit ratelimit
channel3 = self.make_request(
"GET",
"/_matrix/client/v1/media/download/remote.org/abcd",
shorthand=False,
access_token=self.tok,
)
assert channel3.code == 429
def test_file_download(self) -> None:
content = io.BytesIO(b"file_to_stream")
content_uri = self.get_success(
self.repo.create_content(
"text/plain",
"test_upload",
content,
46,
UserID.from_string("@user_id:whatever.org"),
)
)
# test with a text file
channel = self.make_request(
"GET",
f"/_matrix/client/v1/media/download/test/{content_uri.media_id}",
shorthand=False,
access_token=self.tok,
)
self.pump()
self.assertEqual(200, channel.code)
test_images = [
small_png,
small_png_with_transparency,
small_lossless_webp,
empty_file,
SVG,
]
input_values = [(x,) for x in test_images]
@parameterized_class(("test_image",), input_values)
class DownloadTestCase(unittest.HomeserverTestCase):
test_image: ClassVar[TestImage]
servlets = [
media.register_servlets,
login.register_servlets,
admin.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.fetches: List[
Tuple[
"Deferred[Any]",
str,
str,
Optional[QueryParams],
]
] = []
def federation_get_file(
destination: str,
path: str,
output_stream: BinaryIO,
download_ratelimiter: Ratelimiter,
ip_address: Any,
max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]], bytes]]":
"""A mock for MatrixFederationHttpClient.federation_get_file."""
def write_to(
r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
data, response = r
output_stream.write(data)
return response
def write_err(f: Failure) -> Failure:
f.trap(HttpResponseException)
output_stream.write(f.value.response)
return f
d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]] = (
Deferred()
)
self.fetches.append((d, destination, path, args))
# Note that this callback changes the value held by d.
d_after_callback = d.addCallbacks(write_to, write_err)
return make_deferred_yieldable(d_after_callback)
def get_file(
destination: str,
path: str,
output_stream: BinaryIO,
download_ratelimiter: Ratelimiter,
ip_address: Any,
max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
"""A mock for MatrixFederationHttpClient.get_file."""
def write_to(
r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
) -> Tuple[int, Dict[bytes, List[bytes]]]:
data, response = r
output_stream.write(data)
return response
def write_err(f: Failure) -> Failure:
f.trap(HttpResponseException)
output_stream.write(f.value.response)
return f
d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
self.fetches.append((d, destination, path, args))
# Note that this callback changes the value held by d.
d_after_callback = d.addCallbacks(write_to, write_err)
return make_deferred_yieldable(d_after_callback)
# Mock out the homeserver's MatrixFederationHttpClient
client = Mock()
client.federation_get_file = federation_get_file
client.get_file = get_file
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config = self.default_config()
config["media_store_path"] = self.media_store_path
config["max_image_pixels"] = 2000000
provider_config = {
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
config["experimental_features"] = {"msc3916_authenticated_media_enabled": True}
hs = self.setup_test_homeserver(config=config, federation_http_client=client)
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.media_repo = hs.get_media_repository()
self.remote = "example.com"
self.media_id = "12345"
self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass")
def _req(
self, content_disposition: Optional[bytes], include_content_type: bool = True
) -> FakeChannel:
channel = self.make_request(
"GET",
f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
shorthand=False,
await_result=False,
access_token=self.tok,
)
self.pump()
# We've made one fetch, to example.com, using the federation media URL
self.assertEqual(len(self.fetches), 1)
self.assertEqual(self.fetches[0][1], "example.com")
self.assertEqual(
self.fetches[0][2], "/_matrix/federation/v1/media/download/" + self.media_id
)
self.assertEqual(
self.fetches[0][3],
{"timeout_ms": "20000"},
)
headers = {
b"Content-Length": [b"%d" % (len(self.test_image.data))],
}
if include_content_type:
headers[b"Content-Type"] = [self.test_image.content_type]
if content_disposition:
headers[b"Content-Disposition"] = [content_disposition]
self.fetches[0][0].callback(
(self.test_image.data, (len(self.test_image.data), headers, b"{}"))
)
self.pump()
self.assertEqual(channel.code, 200)
return channel
def test_handle_missing_content_type(self) -> None:
channel = self._req(
b"attachment; filename=out" + self.test_image.extension,
include_content_type=False,
)
headers = channel.headers
self.assertEqual(channel.code, 200)
self.assertEqual(
headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
)
def test_disposition_filename_ascii(self) -> None:
"""
If the filename is filename=<ascii> then Synapse will decode it as an
ASCII string, and use filename= in the response.
"""
channel = self._req(b"attachment; filename=out" + self.test_image.extension)
headers = channel.headers
self.assertEqual(
headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
)
self.assertEqual(
headers.getRawHeaders(b"Content-Disposition"),
[
(b"inline" if self.test_image.is_inline else b"attachment")
+ b"; filename=out"
+ self.test_image.extension
],
)
def test_disposition_filenamestar_utf8escaped(self) -> None:
"""
If the filename is filename=*utf8''<utf8 escaped> then Synapse will
correctly decode it as the UTF-8 string, and use filename* in the
response.
"""
filename = parse.quote("\u2603".encode()).encode("ascii")
channel = self._req(
b"attachment; filename*=utf-8''" + filename + self.test_image.extension
)
headers = channel.headers
self.assertEqual(
headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
)
self.assertEqual(
headers.getRawHeaders(b"Content-Disposition"),
[
(b"inline" if self.test_image.is_inline else b"attachment")
+ b"; filename*=utf-8''"
+ filename
+ self.test_image.extension
],
)
def test_disposition_none(self) -> None:
"""
If there is no filename, Content-Disposition should only
be a disposition type.
"""
channel = self._req(None)
headers = channel.headers
self.assertEqual(
headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
)
self.assertEqual(
headers.getRawHeaders(b"Content-Disposition"),
[b"inline" if self.test_image.is_inline else b"attachment"],
)
def test_x_robots_tag_header(self) -> None:
"""
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
to not index, archive, or follow links in media.
"""
channel = self._req(b"attachment; filename=out" + self.test_image.extension)
headers = channel.headers
self.assertEqual(
headers.getRawHeaders(b"X-Robots-Tag"),
[b"noindex, nofollow, noarchive, noimageindex"],
)
def test_cross_origin_resource_policy_header(self) -> None:
"""
Test that the Cross-Origin-Resource-Policy header is set to "cross-origin"
allowing web clients to embed media from the downloads API.
"""
channel = self._req(b"attachment; filename=out" + self.test_image.extension)
headers = channel.headers
self.assertEqual(
headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
[b"cross-origin"],
)
def test_unknown_federation_endpoint(self) -> None:
"""
Test that if the downloadd request to remote federation endpoint returns a 404
we fall back to the _matrix/media endpoint
"""
channel = self.make_request(
"GET",
f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
shorthand=False,
await_result=False,
access_token=self.tok,
)
self.pump()
# We've made one fetch, to example.com, using the media URL, and asking
# the other server not to do a remote fetch
self.assertEqual(len(self.fetches), 1)
self.assertEqual(self.fetches[0][1], "example.com")
self.assertEqual(
self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}"
)
# The result which says the endpoint is unknown.
unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
self.fetches[0][0].errback(
HttpResponseException(404, "NOT FOUND", unknown_endpoint)
)
self.pump()
# There should now be another request to the _matrix/media/v3/download URL.
self.assertEqual(len(self.fetches), 2)
self.assertEqual(self.fetches[1][1], "example.com")
self.assertEqual(
self.fetches[1][2],
f"/_matrix/media/v3/download/example.com/{self.media_id}",
)
headers = {
b"Content-Length": [b"%d" % (len(self.test_image.data))],
}
self.fetches[1][0].callback(
(self.test_image.data, (len(self.test_image.data), headers))
)
self.pump()
self.assertEqual(channel.code, 200)