Convert the remaining media repo code to async / await. (#7947)

This commit is contained in:
Patrick Cloke 2020-07-27 14:40:11 -04:00 committed by GitHub
parent 8553f46498
commit 68626ff8e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 131 additions and 107 deletions

1
changelog.d/7947.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -17,7 +17,9 @@
import logging
import os
import urllib
from typing import Awaitable
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
@ -240,14 +242,14 @@ class Responder(object):
held can be cleaned up.
"""
def write_to_consumer(self, consumer):
def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer
Args:
consumer (IConsumer)
consumer: The consumer to stream into.
Returns:
Deferred: Resolves once the response has finished being written
Resolves once the response has finished being written
"""
pass

View File

@ -18,10 +18,11 @@ import errno
import logging
import os
import shutil
from typing import Dict, Tuple
from typing import IO, Dict, Optional, Tuple
import twisted.internet.error
import twisted.web.http
from twisted.web.http import Request
from twisted.web.resource import Resource
from synapse.api.errors import (
@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string
from ._base import (
FileInfo,
Responder,
get_filename_from_headers,
respond_404,
respond_with_responder,
@ -135,19 +137,24 @@ class MediaRepository(object):
self.recently_accessed_locals.add(media_id)
async def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
self,
media_type: str,
upload_name: str,
content: IO,
content_length: int,
auth_user: str,
) -> str:
"""Store uploaded content for a local user and return the mxc URL
Args:
media_type(str): The content type of the file
upload_name(str): The name of the file
media_type: The content type of the file
upload_name: The name of the file
content: A file like object that is the content to store
content_length(int): The length of the content
auth_user(str): The user_id of the uploader
content_length: The length of the content
auth_user: The user_id of the uploader
Returns:
Deferred[str]: The mxc url of the stored content
The mxc url of the stored content
"""
media_id = random_string(24)
@ -170,19 +177,20 @@ class MediaRepository(object):
return "mxc://%s/%s" % (self.server_name, media_id)
async def get_local_media(self, request, media_id, name):
async def get_local_media(
self, request: Request, media_id: str, name: Optional[str]
) -> None:
"""Responds to reqests for local media, if exists, or returns 404.
Args:
request(twisted.web.http.Request)
media_id (str): The media ID of the content. (This is the same as
request: The incoming request.
media_id: The media ID of the content. (This is the same as
the file_id for local content.)
name (str|None): Optional name that, if specified, will be used as
name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
Deferred: Resolves once a response has successfully been written
to request
Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
@ -203,20 +211,20 @@ class MediaRepository(object):
request, responder, media_type, media_length, upload_name
)
async def get_remote_media(self, request, server_name, media_id, name):
async def get_remote_media(
self, request: Request, server_name: str, media_id: str, name: Optional[str]
) -> None:
"""Respond to requests for remote media.
Args:
request(twisted.web.http.Request)
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
name (str|None): Optional name that, if specified, will be used as
request: The incoming request.
server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the remote server).
name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
Deferred: Resolves once a response has successfully been written
to request
Resolves once a response has successfully been written to request
"""
if (
self.federation_domain_whitelist is not None
@ -245,17 +253,16 @@ class MediaRepository(object):
else:
respond_404(request)
async def get_remote_media_info(self, server_name, media_id):
async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
"""Gets the media info associated with the remote file, downloading
if necessary.
Args:
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the remote server).
Returns:
Deferred[dict]: The media_info of the file
The media info of the file
"""
if (
self.federation_domain_whitelist is not None
@ -278,7 +285,9 @@ class MediaRepository(object):
return media_info
async def _get_remote_media_impl(self, server_name, media_id):
async def _get_remote_media_impl(
self, server_name: str, media_id: str
) -> Tuple[Optional[Responder], dict]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@ -288,7 +297,7 @@ class MediaRepository(object):
remote server).
Returns:
Deferred[(Responder, media_info)]
A tuple of responder and the media info of the file.
"""
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@ -319,19 +328,21 @@ class MediaRepository(object):
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
async def _download_remote_file(self, server_name, media_id, file_id):
async def _download_remote_file(
self, server_name: str, media_id: str, file_id: str
) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
Args:
server_name (str): Originating server
media_id (str): The media ID of the content (as defined by the
server_name: Originating server
media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
file_id (str): Local file ID
file_id: Local file ID
Returns:
Deferred[MediaInfo]
The media info of the file.
"""
file_info = FileInfo(server_name=server_name, file_id=file_id)
@ -549,25 +560,31 @@ class MediaRepository(object):
return output_path
async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
self,
server_name: Optional[str],
media_id: str,
file_id: str,
media_type: str,
url_cache: bool = False,
) -> Optional[dict]:
"""Generate and store thumbnails for an image.
Args:
server_name (str|None): The server name if remote media, else None if local
media_id (str): The media ID of the content. (This is the same as
server_name: The server name if remote media, else None if local
media_id: The media ID of the content. (This is the same as
the file_id for local content)
file_id (str): Local file ID
media_type (str): The content type of the file
url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
file_id: Local file ID
media_type: The content type of the file
url_cache: If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image
Dict with "width" and "height" keys of original image or None if the
media cannot be thumbnailed.
"""
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
return None
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
@ -584,7 +601,7 @@ class MediaRepository(object):
m_height,
self.max_image_pixels,
)
return
return None
if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread(

View File

@ -12,13 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
import logging
import os
import shutil
from typing import Optional
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
from twisted.protocols.basic import FileSender
@ -26,6 +25,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder
from .filepath import MediaFilePaths
if TYPE_CHECKING:
from synapse.server import HomeServer
from .storage_provider import StorageProvider
logger = logging.getLogger(__name__)
@ -34,20 +39,25 @@ class MediaStorage(object):
"""Responsible for storing/fetching files from local sources.
Args:
hs (synapse.server.Homeserver)
local_media_directory (str): Base path where we store media on disk
filepaths (MediaFilePaths)
storage_providers ([StorageProvider]): List of StorageProvider that are
used to fetch and store files.
hs
local_media_directory: Base path where we store media on disk
filepaths
storage_providers: List of StorageProvider that are used to fetch and store files.
"""
def __init__(self, hs, local_media_directory, filepaths, storage_providers):
def __init__(
self,
hs: "HomeServer",
local_media_directory: str,
filepaths: MediaFilePaths,
storage_providers: Sequence["StorageProvider"],
):
self.hs = hs
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
async def store_file(self, source, file_info: FileInfo) -> str:
async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
configured storage providers
@ -69,7 +79,7 @@ class MediaStorage(object):
return fname
@contextlib.contextmanager
def store_into_file(self, file_info):
def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
described by file_info.
@ -85,7 +95,7 @@ class MediaStorage(object):
error.
Args:
file_info (FileInfo): Info about the file to store
file_info: Info about the file to store
Example:
@ -143,9 +153,9 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
res = provider.fetch(path, file_info)
# Fetch is supposed to return an Awaitable, but guard against
# improper implementations.
res = provider.fetch(path, file_info) # type: Any
# Fetch is supposed to return an Awaitable[Responder], but guard
# against improper implementations.
if inspect.isawaitable(res):
res = await res
if res:
@ -174,9 +184,9 @@ class MediaStorage(object):
os.makedirs(dirname)
for provider in self.storage_providers:
res = provider.fetch(path, file_info)
# Fetch is supposed to return an Awaitable, but guard against
# improper implementations.
res = provider.fetch(path, file_info) # type: Any
# Fetch is supposed to return an Awaitable[Responder], but guard
# against improper implementations.
if inspect.isawaitable(res):
res = await res
if res:
@ -190,17 +200,11 @@ class MediaStorage(object):
raise Exception("file could not be found")
def _file_info_to_path(self, file_info):
def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path.
The path is suitable for storing files under a directory, e.g. used to
store files on local FS under the base media repository directory.
Args:
file_info (FileInfo)
Returns:
str
"""
if file_info.url_cache:
if file_info.thumbnail:

View File

@ -231,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True)
async def _do_preview(self, url, user, ts):
async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
"""Check the db, and download the URL and build a preview
Args:
url (str):
user (str):
ts (int):
url: The URL to preview.
user: The user requesting the preview.
ts: The timestamp requested for the preview.
Returns:
Deferred[bytes]: json-encoded og data
json-encoded og data
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)

View File

@ -16,62 +16,62 @@
import logging
import os
import shutil
from twisted.internet import defer
from typing import Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
from ._base import FileInfo, Responder
from .media_storage import FileResponder
logger = logging.getLogger(__name__)
class StorageProvider(object):
class StorageProvider:
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
def store_file(self, path, file_info):
async def store_file(self, path: str, file_info: FileInfo):
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
Args:
path (str): Relative path of file in local cache
file_info (FileInfo)
Returns:
Deferred
path: Relative path of file in local cache
file_info: The metadata of the file.
"""
pass
def fetch(self, path, file_info):
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
Args:
path (str): Relative path of file in local cache
file_info (FileInfo)
path: Relative path of file in local cache
file_info: The metadata of the file.
Returns:
Deferred(Responder): Returns a Responder if the provider has the file,
otherwise returns None.
Returns a Responder if the provider has the file, otherwise returns None.
"""
pass
class StorageProviderWrapper(StorageProvider):
"""Wraps a storage provider and provides various config options
Args:
backend (StorageProvider)
store_local (bool): Whether to store new local files or not.
store_synchronous (bool): Whether to wait for file to be successfully
backend: The storage provider to wrap.
store_local: Whether to store new local files or not.
store_synchronous: Whether to wait for file to be successfully
uploaded, or todo the upload in the background.
store_remote (bool): Whether remote media should be uploaded
store_remote: Whether remote media should be uploaded
"""
def __init__(self, backend, store_local, store_synchronous, store_remote):
def __init__(
self,
backend: StorageProvider,
store_local: bool,
store_synchronous: bool,
store_remote: bool,
):
self.backend = backend
self.store_local = store_local
self.store_synchronous = store_synchronous
@ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider):
def __str__(self):
return "StorageProviderWrapper[%s]" % (self.backend,)
def store_file(self, path, file_info):
async def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local:
return defer.succeed(None)
return None
if file_info.server_name and not self.store_remote:
return defer.succeed(None)
return None
if self.store_synchronous:
return self.backend.store_file(path, file_info)
return await self.backend.store_file(path, file_info)
else:
# TODO: Handle errors.
def store():
@ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider):
logger.exception("Error storing file")
run_in_background(store)
return defer.succeed(None)
return None
def fetch(self, path, file_info):
return self.backend.fetch(path, file_info)
async def fetch(self, path, file_info):
return await self.backend.fetch(path, file_info)
class FileStorageProviderBackend(StorageProvider):
@ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
def store_file(self, path, file_info):
async def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
return defer_to_thread(
return await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
def fetch(self, path, file_info):
async def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)