Convert the remaining media repo code to async / await. (#7947)

This commit is contained in:
Patrick Cloke 2020-07-27 14:40:11 -04:00 committed by GitHub
parent 8553f46498
commit 68626ff8e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 131 additions and 107 deletions

1
changelog.d/7947.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -17,7 +17,9 @@
import logging import logging
import os import os
import urllib import urllib
from typing import Awaitable
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error from synapse.api.errors import Codes, SynapseError, cs_error
@ -240,14 +242,14 @@ class Responder(object):
held can be cleaned up. held can be cleaned up.
""" """
def write_to_consumer(self, consumer): def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer """Stream response into consumer
Args: Args:
consumer (IConsumer) consumer: The consumer to stream into.
Returns: Returns:
Deferred: Resolves once the response has finished being written Resolves once the response has finished being written
""" """
pass pass

View File

@ -18,10 +18,11 @@ import errno
import logging import logging
import os import os
import shutil import shutil
from typing import Dict, Tuple from typing import IO, Dict, Optional, Tuple
import twisted.internet.error import twisted.internet.error
import twisted.web.http import twisted.web.http
from twisted.web.http import Request
from twisted.web.resource import Resource from twisted.web.resource import Resource
from synapse.api.errors import ( from synapse.api.errors import (
@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string
from ._base import ( from ._base import (
FileInfo, FileInfo,
Responder,
get_filename_from_headers, get_filename_from_headers,
respond_404, respond_404,
respond_with_responder, respond_with_responder,
@ -135,19 +137,24 @@ class MediaRepository(object):
self.recently_accessed_locals.add(media_id) self.recently_accessed_locals.add(media_id)
async def create_content( async def create_content(
self, media_type, upload_name, content, content_length, auth_user self,
): media_type: str,
upload_name: str,
content: IO,
content_length: int,
auth_user: str,
) -> str:
"""Store uploaded content for a local user and return the mxc URL """Store uploaded content for a local user and return the mxc URL
Args: Args:
media_type(str): The content type of the file media_type: The content type of the file
upload_name(str): The name of the file upload_name: The name of the file
content: A file like object that is the content to store content: A file like object that is the content to store
content_length(int): The length of the content content_length: The length of the content
auth_user(str): The user_id of the uploader auth_user: The user_id of the uploader
Returns: Returns:
Deferred[str]: The mxc url of the stored content The mxc url of the stored content
""" """
media_id = random_string(24) media_id = random_string(24)
@ -170,19 +177,20 @@ class MediaRepository(object):
return "mxc://%s/%s" % (self.server_name, media_id) return "mxc://%s/%s" % (self.server_name, media_id)
async def get_local_media(self, request, media_id, name): async def get_local_media(
self, request: Request, media_id: str, name: Optional[str]
) -> None:
"""Responds to reqests for local media, if exists, or returns 404. """Responds to reqests for local media, if exists, or returns 404.
Args: Args:
request(twisted.web.http.Request) request: The incoming request.
media_id (str): The media ID of the content. (This is the same as media_id: The media ID of the content. (This is the same as
the file_id for local content.) the file_id for local content.)
name (str|None): Optional name that, if specified, will be used as name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response. the filename in the Content-Disposition header of the response.
Returns: Returns:
Deferred: Resolves once a response has successfully been written Resolves once a response has successfully been written to request
to request
""" """
media_info = await self.store.get_local_media(media_id) media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]: if not media_info or media_info["quarantined_by"]:
@ -203,20 +211,20 @@ class MediaRepository(object):
request, responder, media_type, media_length, upload_name request, responder, media_type, media_length, upload_name
) )
async def get_remote_media(self, request, server_name, media_id, name): async def get_remote_media(
self, request: Request, server_name: str, media_id: str, name: Optional[str]
) -> None:
"""Respond to requests for remote media. """Respond to requests for remote media.
Args: Args:
request(twisted.web.http.Request) request: The incoming request.
server_name (str): Remote server_name where the media originated. server_name: Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the media_id: The media ID of the content (as defined by the remote server).
remote server). name: Optional name that, if specified, will be used as
name (str|None): Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response. the filename in the Content-Disposition header of the response.
Returns: Returns:
Deferred: Resolves once a response has successfully been written Resolves once a response has successfully been written to request
to request
""" """
if ( if (
self.federation_domain_whitelist is not None self.federation_domain_whitelist is not None
@ -245,17 +253,16 @@ class MediaRepository(object):
else: else:
respond_404(request) respond_404(request)
async def get_remote_media_info(self, server_name, media_id): async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
"""Gets the media info associated with the remote file, downloading """Gets the media info associated with the remote file, downloading
if necessary. if necessary.
Args: Args:
server_name (str): Remote server_name where the media originated. server_name: Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the media_id: The media ID of the content (as defined by the remote server).
remote server).
Returns: Returns:
Deferred[dict]: The media_info of the file The media info of the file
""" """
if ( if (
self.federation_domain_whitelist is not None self.federation_domain_whitelist is not None
@ -278,7 +285,9 @@ class MediaRepository(object):
return media_info return media_info
async def _get_remote_media_impl(self, server_name, media_id): async def _get_remote_media_impl(
self, server_name: str, media_id: str
) -> Tuple[Optional[Responder], dict]:
"""Looks for media in local cache, if not there then attempt to """Looks for media in local cache, if not there then attempt to
download from remote server. download from remote server.
@ -288,7 +297,7 @@ class MediaRepository(object):
remote server). remote server).
Returns: Returns:
Deferred[(Responder, media_info)] A tuple of responder and the media info of the file.
""" """
media_info = await self.store.get_cached_remote_media(server_name, media_id) media_info = await self.store.get_cached_remote_media(server_name, media_id)
@ -319,19 +328,21 @@ class MediaRepository(object):
responder = await self.media_storage.fetch_media(file_info) responder = await self.media_storage.fetch_media(file_info)
return responder, media_info return responder, media_info
async def _download_remote_file(self, server_name, media_id, file_id): async def _download_remote_file(
self, server_name: str, media_id: str, file_id: str
) -> dict:
"""Attempt to download the remote file from the given server name, """Attempt to download the remote file from the given server name,
using the given file_id as the local id. using the given file_id as the local id.
Args: Args:
server_name (str): Originating server server_name: Originating server
media_id (str): The media ID of the content (as defined by the media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is remote server). This is different than the file_id, which is
locally generated. locally generated.
file_id (str): Local file ID file_id: Local file ID
Returns: Returns:
Deferred[MediaInfo] The media info of the file.
""" """
file_info = FileInfo(server_name=server_name, file_id=file_id) file_info = FileInfo(server_name=server_name, file_id=file_id)
@ -549,25 +560,31 @@ class MediaRepository(object):
return output_path return output_path
async def _generate_thumbnails( async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False self,
): server_name: Optional[str],
media_id: str,
file_id: str,
media_type: str,
url_cache: bool = False,
) -> Optional[dict]:
"""Generate and store thumbnails for an image. """Generate and store thumbnails for an image.
Args: Args:
server_name (str|None): The server name if remote media, else None if local server_name: The server name if remote media, else None if local
media_id (str): The media ID of the content. (This is the same as media_id: The media ID of the content. (This is the same as
the file_id for local content) the file_id for local content)
file_id (str): Local file ID file_id: Local file ID
media_type (str): The content type of the file media_type: The content type of the file
url_cache (bool): If we are thumbnailing images downloaded for the URL cache, url_cache: If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer used exclusively by the url previewer
Returns: Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image Dict with "width" and "height" keys of original image or None if the
media cannot be thumbnailed.
""" """
requirements = self._get_thumbnail_requirements(media_type) requirements = self._get_thumbnail_requirements(media_type)
if not requirements: if not requirements:
return return None
input_path = await self.media_storage.ensure_media_is_in_local_cache( input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache) FileInfo(server_name, file_id, url_cache=url_cache)
@ -584,7 +601,7 @@ class MediaRepository(object):
m_height, m_height,
self.max_image_pixels, self.max_image_pixels,
) )
return return None
if thumbnailer.transpose_method is not None: if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread( m_width, m_height = await defer_to_thread(

View File

@ -12,13 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import inspect import inspect
import logging import logging
import os import os
import shutil import shutil
from typing import Optional from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
@ -26,6 +25,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder from ._base import FileInfo, Responder
from .filepath import MediaFilePaths
if TYPE_CHECKING:
from synapse.server import HomeServer
from .storage_provider import StorageProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,20 +39,25 @@ class MediaStorage(object):
"""Responsible for storing/fetching files from local sources. """Responsible for storing/fetching files from local sources.
Args: Args:
hs (synapse.server.Homeserver) hs
local_media_directory (str): Base path where we store media on disk local_media_directory: Base path where we store media on disk
filepaths (MediaFilePaths) filepaths
storage_providers ([StorageProvider]): List of StorageProvider that are storage_providers: List of StorageProvider that are used to fetch and store files.
used to fetch and store files.
""" """
def __init__(self, hs, local_media_directory, filepaths, storage_providers): def __init__(
self,
hs: "HomeServer",
local_media_directory: str,
filepaths: MediaFilePaths,
storage_providers: Sequence["StorageProvider"],
):
self.hs = hs self.hs = hs
self.local_media_directory = local_media_directory self.local_media_directory = local_media_directory
self.filepaths = filepaths self.filepaths = filepaths
self.storage_providers = storage_providers self.storage_providers = storage_providers
async def store_file(self, source, file_info: FileInfo) -> str: async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other """Write `source` to the on disk media store, and also any other
configured storage providers configured storage providers
@ -69,7 +79,7 @@ class MediaStorage(object):
return fname return fname
@contextlib.contextmanager @contextlib.contextmanager
def store_into_file(self, file_info): def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as """Context manager used to get a file like object to write into, as
described by file_info. described by file_info.
@ -85,7 +95,7 @@ class MediaStorage(object):
error. error.
Args: Args:
file_info (FileInfo): Info about the file to store file_info: Info about the file to store
Example: Example:
@ -143,9 +153,9 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb")) return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers: for provider in self.storage_providers:
res = provider.fetch(path, file_info) res = provider.fetch(path, file_info) # type: Any
# Fetch is supposed to return an Awaitable, but guard against # Fetch is supposed to return an Awaitable[Responder], but guard
# improper implementations. # against improper implementations.
if inspect.isawaitable(res): if inspect.isawaitable(res):
res = await res res = await res
if res: if res:
@ -174,9 +184,9 @@ class MediaStorage(object):
os.makedirs(dirname) os.makedirs(dirname)
for provider in self.storage_providers: for provider in self.storage_providers:
res = provider.fetch(path, file_info) res = provider.fetch(path, file_info) # type: Any
# Fetch is supposed to return an Awaitable, but guard against # Fetch is supposed to return an Awaitable[Responder], but guard
# improper implementations. # against improper implementations.
if inspect.isawaitable(res): if inspect.isawaitable(res):
res = await res res = await res
if res: if res:
@ -190,17 +200,11 @@ class MediaStorage(object):
raise Exception("file could not be found") raise Exception("file could not be found")
def _file_info_to_path(self, file_info): def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path. """Converts file_info into a relative path.
The path is suitable for storing files under a directory, e.g. used to The path is suitable for storing files under a directory, e.g. used to
store files on local FS under the base media repository directory. store files on local FS under the base media repository directory.
Args:
file_info (FileInfo)
Returns:
str
""" """
if file_info.url_cache: if file_info.url_cache:
if file_info.thumbnail: if file_info.thumbnail:

View File

@ -231,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True) respond_with_json_bytes(request, 200, og, send_cors=True)
async def _do_preview(self, url, user, ts): async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
"""Check the db, and download the URL and build a preview """Check the db, and download the URL and build a preview
Args: Args:
url (str): url: The URL to preview.
user (str): user: The user requesting the preview.
ts (int): ts: The timestamp requested for the preview.
Returns: Returns:
Deferred[bytes]: json-encoded og data json-encoded og data
""" """
# check the URL cache in the DB (which will also provide us with # check the URL cache in the DB (which will also provide us with
# historical previews, if we have any) # historical previews, if we have any)

View File

@ -16,62 +16,62 @@
import logging import logging
import os import os
import shutil import shutil
from typing import Optional
from twisted.internet import defer
from synapse.config._base import Config from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background from synapse.logging.context import defer_to_thread, run_in_background
from ._base import FileInfo, Responder
from .media_storage import FileResponder from .media_storage import FileResponder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class StorageProvider(object): class StorageProvider:
"""A storage provider is a service that can store uploaded media and """A storage provider is a service that can store uploaded media and
retrieve them. retrieve them.
""" """
def store_file(self, path, file_info): async def store_file(self, path: str, file_info: FileInfo):
"""Store the file described by file_info. The actual contents can be """Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path. retrieved by reading the file in file_info.upload_path.
Args: Args:
path (str): Relative path of file in local cache path: Relative path of file in local cache
file_info (FileInfo) file_info: The metadata of the file.
Returns:
Deferred
""" """
pass
def fetch(self, path, file_info): async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it """Attempt to fetch the file described by file_info and stream it
into writer. into writer.
Args: Args:
path (str): Relative path of file in local cache path: Relative path of file in local cache
file_info (FileInfo) file_info: The metadata of the file.
Returns: Returns:
Deferred(Responder): Returns a Responder if the provider has the file, Returns a Responder if the provider has the file, otherwise returns None.
otherwise returns None.
""" """
pass
class StorageProviderWrapper(StorageProvider): class StorageProviderWrapper(StorageProvider):
"""Wraps a storage provider and provides various config options """Wraps a storage provider and provides various config options
Args: Args:
backend (StorageProvider) backend: The storage provider to wrap.
store_local (bool): Whether to store new local files or not. store_local: Whether to store new local files or not.
store_synchronous (bool): Whether to wait for file to be successfully store_synchronous: Whether to wait for file to be successfully
uploaded, or todo the upload in the background. uploaded, or todo the upload in the background.
store_remote (bool): Whether remote media should be uploaded store_remote: Whether remote media should be uploaded
""" """
def __init__(self, backend, store_local, store_synchronous, store_remote): def __init__(
self,
backend: StorageProvider,
store_local: bool,
store_synchronous: bool,
store_remote: bool,
):
self.backend = backend self.backend = backend
self.store_local = store_local self.store_local = store_local
self.store_synchronous = store_synchronous self.store_synchronous = store_synchronous
@ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider):
def __str__(self): def __str__(self):
return "StorageProviderWrapper[%s]" % (self.backend,) return "StorageProviderWrapper[%s]" % (self.backend,)
def store_file(self, path, file_info): async def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local: if not file_info.server_name and not self.store_local:
return defer.succeed(None) return None
if file_info.server_name and not self.store_remote: if file_info.server_name and not self.store_remote:
return defer.succeed(None) return None
if self.store_synchronous: if self.store_synchronous:
return self.backend.store_file(path, file_info) return await self.backend.store_file(path, file_info)
else: else:
# TODO: Handle errors. # TODO: Handle errors.
def store(): def store():
@ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider):
logger.exception("Error storing file") logger.exception("Error storing file")
run_in_background(store) run_in_background(store)
return defer.succeed(None) return None
def fetch(self, path, file_info): async def fetch(self, path, file_info):
return self.backend.fetch(path, file_info) return await self.backend.fetch(path, file_info)
class FileStorageProviderBackend(StorageProvider): class FileStorageProviderBackend(StorageProvider):
@ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self): def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,) return "FileStorageProviderBackend[%s]" % (self.base_directory,)
def store_file(self, path, file_info): async def store_file(self, path, file_info):
"""See StorageProvider.store_file""" """See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path) primary_fname = os.path.join(self.cache_directory, path)
@ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
return defer_to_thread( return await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
) )
def fetch(self, path, file_info): async def fetch(self, path, file_info):
"""See StorageProvider.fetch""" """See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path) backup_fname = os.path.join(self.base_directory, path)