Fix slipped logging context when media rejected (#17239)

When a module rejects a piece of media we end up trying to close the
same logging context twice.

Instead of fixing the existing code we refactor to use an async context
manager, which is easier to write correctly.
This commit is contained in:
Erik Johnston 2024-05-29 11:14:42 +01:00 committed by GitHub
parent ad179b0136
commit bb5a692946
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 56 additions and 92 deletions

1
changelog.d/17239.misc Normal file
View File

@ -0,0 +1 @@
Fix errors in logs about closing incorrect logging contexts when media gets rejected by a module.

View File

@ -650,7 +650,7 @@ class MediaRepository:
file_info = FileInfo(server_name=server_name, file_id=file_id) file_info = FileInfo(server_name=server_name, file_id=file_id)
with self.media_storage.store_into_file(file_info) as (f, fname, finish): async with self.media_storage.store_into_file(file_info) as (f, fname):
try: try:
length, headers = await self.client.download_media( length, headers = await self.client.download_media(
server_name, server_name,
@ -693,8 +693,6 @@ class MediaRepository:
) )
raise SynapseError(502, "Failed to fetch remote media") raise SynapseError(502, "Failed to fetch remote media")
await finish()
if b"Content-Type" in headers: if b"Content-Type" in headers:
media_type = headers[b"Content-Type"][0].decode("ascii") media_type = headers[b"Content-Type"][0].decode("ascii")
else: else:
@ -1045,14 +1043,9 @@ class MediaRepository:
), ),
) )
with self.media_storage.store_into_file(file_info) as ( async with self.media_storage.store_into_file(file_info) as (f, fname):
f,
fname,
finish,
):
try: try:
await self.media_storage.write_to_file(t_byte_source, f) await self.media_storage.write_to_file(t_byte_source, f)
await finish()
finally: finally:
t_byte_source.close() t_byte_source.close()

View File

@ -27,10 +27,9 @@ from typing import (
IO, IO,
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Awaitable, AsyncIterator,
BinaryIO, BinaryIO,
Callable, Callable,
Generator,
Optional, Optional,
Sequence, Sequence,
Tuple, Tuple,
@ -97,11 +96,9 @@ class MediaStorage:
the file path written to in the primary media store the file path written to in the primary media store
""" """
with self.store_into_file(file_info) as (f, fname, finish_cb): async with self.store_into_file(file_info) as (f, fname):
# Write to the main media repository # Write to the main media repository
await self.write_to_file(source, f) await self.write_to_file(source, f)
# Write to the other storage providers
await finish_cb()
return fname return fname
@ -111,32 +108,27 @@ class MediaStorage:
await defer_to_thread(self.reactor, _write_file_synchronously, source, output) await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@trace_with_opname("MediaStorage.store_into_file") @trace_with_opname("MediaStorage.store_into_file")
@contextlib.contextmanager @contextlib.asynccontextmanager
def store_into_file( async def store_into_file(
self, file_info: FileInfo self, file_info: FileInfo
) -> Generator[Tuple[BinaryIO, str, Callable[[], Awaitable[None]]], None, None]: ) -> AsyncIterator[Tuple[BinaryIO, str]]:
"""Context manager used to get a file like object to write into, as """Async Context manager used to get a file like object to write into, as
described by file_info. described by file_info.
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file Actually yields a 2-tuple (file, fname,), where file is a file
like object that can be written to, fname is the absolute path of file like object that can be written to and fname is the absolute path of file
on disk, and finish_cb is a function that returns an awaitable. on disk.
fname can be used to read the contents from after upload, e.g. to fname can be used to read the contents from after upload, e.g. to
generate thumbnails. generate thumbnails.
finish_cb must be called and waited on after the file has been successfully been
written to. Should not be called if there was an error. Checks for spam and
stores the file into the configured storage providers.
Args: Args:
file_info: Info about the file to store file_info: Info about the file to store
Example: Example:
with media_storage.store_into_file(info) as (f, fname, finish_cb): async with media_storage.store_into_file(info) as (f, fname,):
# .. write into f ... # .. write into f ...
await finish_cb()
""" """
path = self._file_info_to_path(file_info) path = self._file_info_to_path(file_info)
@ -145,29 +137,30 @@ class MediaStorage:
dirname = os.path.dirname(fname) dirname = os.path.dirname(fname)
os.makedirs(dirname, exist_ok=True) os.makedirs(dirname, exist_ok=True)
finished_called = [False]
main_media_repo_write_trace_scope = start_active_span( main_media_repo_write_trace_scope = start_active_span(
"writing to main media repo" "writing to main media repo"
) )
main_media_repo_write_trace_scope.__enter__() main_media_repo_write_trace_scope.__enter__()
with main_media_repo_write_trace_scope:
try: try:
with open(fname, "wb") as f: with open(fname, "wb") as f:
yield f, fname
async def finish() -> None: except Exception as e:
# When someone calls finish, we assume they are done writing to the main media repo try:
main_media_repo_write_trace_scope.__exit__(None, None, None) os.remove(fname)
except Exception:
pass
raise e from None
with start_active_span("writing to other storage providers"): with start_active_span("writing to other storage providers"):
# Ensure that all writes have been flushed and close the spam_check = (
# file. await self._spam_checker_module_callbacks.check_media_file_for_spam(
f.flush()
f.close()
spam_check = await self._spam_checker_module_callbacks.check_media_file_for_spam(
ReadableFileWrapper(self.clock, fname), file_info ReadableFileWrapper(self.clock, fname), file_info
) )
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
logger.info("Blocking media due to spam checker") logger.info("Blocking media due to spam checker")
# Note that we'll delete the stored media, due to the # Note that we'll delete the stored media, due to the
@ -181,27 +174,6 @@ class MediaStorage:
with start_active_span(str(provider)): with start_active_span(str(provider)):
await provider.store_file(path, file_info) await provider.store_file(path, file_info)
finished_called[0] = True
yield f, fname, finish
except Exception as e:
try:
main_media_repo_write_trace_scope.__exit__(
type(e), None, e.__traceback__
)
os.remove(fname)
except Exception:
pass
raise e from None
if not finished_called:
exc = Exception("Finished callback not called")
main_media_repo_write_trace_scope.__exit__(
type(exc), None, exc.__traceback__
)
raise exc
async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]: async def fetch_media(self, file_info: FileInfo) -> Optional[Responder]:
"""Attempts to fetch media described by file_info from the local cache """Attempts to fetch media described by file_info from the local cache
and configured storage providers. and configured storage providers.

View File

@ -592,7 +592,7 @@ class UrlPreviewer:
file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
with self.media_storage.store_into_file(file_info) as (f, fname, finish): async with self.media_storage.store_into_file(file_info) as (f, fname):
if url.startswith("data:"): if url.startswith("data:"):
if not allow_data_urls: if not allow_data_urls:
raise SynapseError( raise SynapseError(
@ -603,8 +603,6 @@ class UrlPreviewer:
else: else:
download_result = await self._download_url(url, f) download_result = await self._download_url(url, f)
await finish()
try: try:
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()

View File

@ -93,13 +93,13 @@ class UnstableMediaDomainBlockingTests(unittest.HomeserverTestCase):
# from a regular 404. # from a regular 404.
file_id = "abcdefg12345" file_id = "abcdefg12345"
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id) file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
with hs.get_media_repository().media_storage.store_into_file(file_info) as (
f, media_storage = hs.get_media_repository().media_storage
fname,
finish, ctx = media_storage.store_into_file(file_info)
): (f, fname) = self.get_success(ctx.__aenter__())
f.write(SMALL_PNG) f.write(SMALL_PNG)
self.get_success(finish()) self.get_success(ctx.__aexit__(None, None, None))
self.get_success( self.get_success(
self.store.store_cached_remote_media( self.store.store_cached_remote_media(

View File

@ -44,13 +44,13 @@ class MediaDomainBlockingTests(unittest.HomeserverTestCase):
# from a regular 404. # from a regular 404.
file_id = "abcdefg12345" file_id = "abcdefg12345"
file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id) file_info = FileInfo(server_name=self.remote_server_name, file_id=file_id)
with hs.get_media_repository().media_storage.store_into_file(file_info) as (
f, media_storage = hs.get_media_repository().media_storage
fname,
finish, ctx = media_storage.store_into_file(file_info)
): (f, fname) = self.get_success(ctx.__aenter__())
f.write(SMALL_PNG) f.write(SMALL_PNG)
self.get_success(finish()) self.get_success(ctx.__aexit__(None, None, None))
self.get_success( self.get_success(
self.store.store_cached_remote_media( self.store.store_cached_remote_media(