From 68626ff8e98443d6dc470970274a853a93fceefa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 27 Jul 2020 14:40:11 -0400 Subject: [PATCH] Convert the remaining media repo code to async / await. (#7947) --- changelog.d/7947.misc | 1 + synapse/rest/media/v1/_base.py | 8 +- synapse/rest/media/v1/media_repository.py | 105 ++++++++++-------- synapse/rest/media/v1/media_storage.py | 52 +++++---- synapse/rest/media/v1/preview_url_resource.py | 10 +- synapse/rest/media/v1/storage_provider.py | 62 +++++------ 6 files changed, 131 insertions(+), 107 deletions(-) create mode 100644 changelog.d/7947.misc diff --git a/changelog.d/7947.misc b/changelog.d/7947.misc new file mode 100644 index 000000000..dfe4c0317 --- /dev/null +++ b/changelog.d/7947.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 9a847130c..20ddb9550 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -17,7 +17,9 @@ import logging import os import urllib +from typing import Awaitable +from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.api.errors import Codes, SynapseError, cs_error @@ -240,14 +242,14 @@ class Responder(object): held can be cleaned up. """ - def write_to_consumer(self, consumer): + def write_to_consumer(self, consumer: IConsumer) -> Awaitable: """Stream response into consumer Args: - consumer (IConsumer) + consumer: The consumer to stream into. Returns: - Deferred: Resolves once the response has finished being written + Resolves once the response has finished being written """ pass diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 45628c07b..6fb4039e9 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -18,10 +18,11 @@ import errno import logging import os import shutil -from typing import Dict, Tuple +from typing import IO, Dict, Optional, Tuple import twisted.internet.error import twisted.web.http +from twisted.web.http import Request from twisted.web.resource import Resource from synapse.api.errors import ( @@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string from ._base import ( FileInfo, + Responder, get_filename_from_headers, respond_404, respond_with_responder, @@ -135,19 +137,24 @@ class MediaRepository(object): self.recently_accessed_locals.add(media_id) async def create_content( - self, media_type, upload_name, content, content_length, auth_user - ): + self, + media_type: str, + upload_name: str, + content: IO, + content_length: int, + auth_user: str, + ) -> str: """Store uploaded content for a local user and return the mxc URL Args: - media_type(str): The content type of the file - upload_name(str): The name of the file + media_type: The content type of the file + upload_name: The name of the file content: A file like object that is the content to store - content_length(int): The length of the content - auth_user(str): The user_id of the uploader + content_length: The length of the content + auth_user: The user_id of the uploader Returns: - Deferred[str]: The mxc url of the stored content + The mxc url of the stored content """ media_id = random_string(24) @@ -170,19 +177,20 @@ class MediaRepository(object): return "mxc://%s/%s" % (self.server_name, media_id) - async def get_local_media(self, request, media_id, name): + async def get_local_media( + self, request: Request, media_id: str, name: Optional[str] + ) -> None: """Responds to reqests for local media, if exists, or returns 404. Args: - request(twisted.web.http.Request) - media_id (str): The media ID of the content. (This is the same as + request: The incoming request. + media_id: The media ID of the content. (This is the same as the file_id for local content.) - name (str|None): Optional name that, if specified, will be used as + name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. Returns: - Deferred: Resolves once a response has successfully been written - to request + Resolves once a response has successfully been written to request """ media_info = await self.store.get_local_media(media_id) if not media_info or media_info["quarantined_by"]: @@ -203,20 +211,20 @@ class MediaRepository(object): request, responder, media_type, media_length, upload_name ) - async def get_remote_media(self, request, server_name, media_id, name): + async def get_remote_media( + self, request: Request, server_name: str, media_id: str, name: Optional[str] + ) -> None: """Respond to requests for remote media. Args: - request(twisted.web.http.Request) - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the - remote server). - name (str|None): Optional name that, if specified, will be used as + request: The incoming request. + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). + name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. Returns: - Deferred: Resolves once a response has successfully been written - to request + Resolves once a response has successfully been written to request """ if ( self.federation_domain_whitelist is not None @@ -245,17 +253,16 @@ class MediaRepository(object): else: respond_404(request) - async def get_remote_media_info(self, server_name, media_id): + async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: """Gets the media info associated with the remote file, downloading if necessary. Args: - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the - remote server). + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). Returns: - Deferred[dict]: The media_info of the file + The media info of the file """ if ( self.federation_domain_whitelist is not None @@ -278,7 +285,9 @@ class MediaRepository(object): return media_info - async def _get_remote_media_impl(self, server_name, media_id): + async def _get_remote_media_impl( + self, server_name: str, media_id: str + ) -> Tuple[Optional[Responder], dict]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -288,7 +297,7 @@ class MediaRepository(object): remote server). Returns: - Deferred[(Responder, media_info)] + A tuple of responder and the media info of the file. """ media_info = await self.store.get_cached_remote_media(server_name, media_id) @@ -319,19 +328,21 @@ class MediaRepository(object): responder = await self.media_storage.fetch_media(file_info) return responder, media_info - async def _download_remote_file(self, server_name, media_id, file_id): + async def _download_remote_file( + self, server_name: str, media_id: str, file_id: str + ) -> dict: """Attempt to download the remote file from the given server name, using the given file_id as the local id. Args: - server_name (str): Originating server - media_id (str): The media ID of the content (as defined by the + server_name: Originating server + media_id: The media ID of the content (as defined by the remote server). This is different than the file_id, which is locally generated. - file_id (str): Local file ID + file_id: Local file ID Returns: - Deferred[MediaInfo] + The media info of the file. """ file_info = FileInfo(server_name=server_name, file_id=file_id) @@ -549,25 +560,31 @@ class MediaRepository(object): return output_path async def _generate_thumbnails( - self, server_name, media_id, file_id, media_type, url_cache=False - ): + self, + server_name: Optional[str], + media_id: str, + file_id: str, + media_type: str, + url_cache: bool = False, + ) -> Optional[dict]: """Generate and store thumbnails for an image. Args: - server_name (str|None): The server name if remote media, else None if local - media_id (str): The media ID of the content. (This is the same as + server_name: The server name if remote media, else None if local + media_id: The media ID of the content. (This is the same as the file_id for local content) - file_id (str): Local file ID - media_type (str): The content type of the file - url_cache (bool): If we are thumbnailing images downloaded for the URL cache, + file_id: Local file ID + media_type: The content type of the file + url_cache: If we are thumbnailing images downloaded for the URL cache, used exclusively by the url previewer Returns: - Deferred[dict]: Dict with "width" and "height" keys of original image + Dict with "width" and "height" keys of original image or None if the + media cannot be thumbnailed. """ requirements = self._get_thumbnail_requirements(media_type) if not requirements: - return + return None input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) @@ -584,7 +601,7 @@ class MediaRepository(object): m_height, self.max_image_pixels, ) - return + return None if thumbnailer.transpose_method is not None: m_width, m_height = await defer_to_thread( diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py index 66bc1c336..858b6d300 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py @@ -12,13 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import contextlib import inspect import logging import os import shutil -from typing import Optional +from typing import IO, TYPE_CHECKING, Any, Optional, Sequence from twisted.protocols.basic import FileSender @@ -26,6 +25,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.util.file_consumer import BackgroundFileConsumer from ._base import FileInfo, Responder +from .filepath import MediaFilePaths + +if TYPE_CHECKING: + from synapse.server import HomeServer + + from .storage_provider import StorageProvider logger = logging.getLogger(__name__) @@ -34,20 +39,25 @@ class MediaStorage(object): """Responsible for storing/fetching files from local sources. Args: - hs (synapse.server.Homeserver) - local_media_directory (str): Base path where we store media on disk - filepaths (MediaFilePaths) - storage_providers ([StorageProvider]): List of StorageProvider that are - used to fetch and store files. + hs + local_media_directory: Base path where we store media on disk + filepaths + storage_providers: List of StorageProvider that are used to fetch and store files. """ - def __init__(self, hs, local_media_directory, filepaths, storage_providers): + def __init__( + self, + hs: "HomeServer", + local_media_directory: str, + filepaths: MediaFilePaths, + storage_providers: Sequence["StorageProvider"], + ): self.hs = hs self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers - async def store_file(self, source, file_info: FileInfo) -> str: + async def store_file(self, source: IO, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other configured storage providers @@ -69,7 +79,7 @@ class MediaStorage(object): return fname @contextlib.contextmanager - def store_into_file(self, file_info): + def store_into_file(self, file_info: FileInfo): """Context manager used to get a file like object to write into, as described by file_info. @@ -85,7 +95,7 @@ class MediaStorage(object): error. Args: - file_info (FileInfo): Info about the file to store + file_info: Info about the file to store Example: @@ -143,9 +153,9 @@ class MediaStorage(object): return FileResponder(open(local_path, "rb")) for provider in self.storage_providers: - res = provider.fetch(path, file_info) - # Fetch is supposed to return an Awaitable, but guard against - # improper implementations. + res = provider.fetch(path, file_info) # type: Any + # Fetch is supposed to return an Awaitable[Responder], but guard + # against improper implementations. if inspect.isawaitable(res): res = await res if res: @@ -174,9 +184,9 @@ class MediaStorage(object): os.makedirs(dirname) for provider in self.storage_providers: - res = provider.fetch(path, file_info) - # Fetch is supposed to return an Awaitable, but guard against - # improper implementations. + res = provider.fetch(path, file_info) # type: Any + # Fetch is supposed to return an Awaitable[Responder], but guard + # against improper implementations. if inspect.isawaitable(res): res = await res if res: @@ -190,17 +200,11 @@ class MediaStorage(object): raise Exception("file could not be found") - def _file_info_to_path(self, file_info): + def _file_info_to_path(self, file_info: FileInfo) -> str: """Converts file_info into a relative path. The path is suitable for storing files under a directory, e.g. used to store files on local FS under the base media repository directory. - - Args: - file_info (FileInfo) - - Returns: - str """ if file_info.url_cache: if file_info.thumbnail: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 13d1a6d2e..e12f65a20 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -231,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource): og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) respond_with_json_bytes(request, 200, og, send_cors=True) - async def _do_preview(self, url, user, ts): + async def _do_preview(self, url: str, user: str, ts: int) -> bytes: """Check the db, and download the URL and build a preview Args: - url (str): - user (str): - ts (int): + url: The URL to preview. + user: The user requesting the preview. + ts: The timestamp requested for the preview. Returns: - Deferred[bytes]: json-encoded og data + json-encoded og data """ # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 858680be2..a33f56e80 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -16,62 +16,62 @@ import logging import os import shutil - -from twisted.internet import defer +from typing import Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background +from ._base import FileInfo, Responder from .media_storage import FileResponder logger = logging.getLogger(__name__) -class StorageProvider(object): +class StorageProvider: """A storage provider is a service that can store uploaded media and retrieve them. """ - def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo): """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. Args: - path (str): Relative path of file in local cache - file_info (FileInfo) - - Returns: - Deferred + path: Relative path of file in local cache + file_info: The metadata of the file. """ - pass - def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """Attempt to fetch the file described by file_info and stream it into writer. Args: - path (str): Relative path of file in local cache - file_info (FileInfo) + path: Relative path of file in local cache + file_info: The metadata of the file. Returns: - Deferred(Responder): Returns a Responder if the provider has the file, - otherwise returns None. + Returns a Responder if the provider has the file, otherwise returns None. """ - pass class StorageProviderWrapper(StorageProvider): """Wraps a storage provider and provides various config options Args: - backend (StorageProvider) - store_local (bool): Whether to store new local files or not. - store_synchronous (bool): Whether to wait for file to be successfully + backend: The storage provider to wrap. + store_local: Whether to store new local files or not. + store_synchronous: Whether to wait for file to be successfully uploaded, or todo the upload in the background. - store_remote (bool): Whether remote media should be uploaded + store_remote: Whether remote media should be uploaded """ - def __init__(self, backend, store_local, store_synchronous, store_remote): + def __init__( + self, + backend: StorageProvider, + store_local: bool, + store_synchronous: bool, + store_remote: bool, + ): self.backend = backend self.store_local = store_local self.store_synchronous = store_synchronous @@ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider): def __str__(self): return "StorageProviderWrapper[%s]" % (self.backend,) - def store_file(self, path, file_info): + async def store_file(self, path, file_info): if not file_info.server_name and not self.store_local: - return defer.succeed(None) + return None if file_info.server_name and not self.store_remote: - return defer.succeed(None) + return None if self.store_synchronous: - return self.backend.store_file(path, file_info) + return await self.backend.store_file(path, file_info) else: # TODO: Handle errors. def store(): @@ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider): logger.exception("Error storing file") run_in_background(store) - return defer.succeed(None) + return None - def fetch(self, path, file_info): - return self.backend.fetch(path, file_info) + async def fetch(self, path, file_info): + return await self.backend.fetch(path, file_info) class FileStorageProviderBackend(StorageProvider): @@ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider): def __str__(self): return "FileStorageProviderBackend[%s]" % (self.base_directory,) - def store_file(self, path, file_info): + async def store_file(self, path, file_info): """See StorageProvider.store_file""" primary_fname = os.path.join(self.cache_directory, path) @@ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return defer_to_thread( + return await defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) - def fetch(self, path, file_info): + async def fetch(self, path, file_info): """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path)