Convert the remaining media repo code to async / await. (#7947)

This commit is contained in:
Patrick Cloke 2020-07-27 14:40:11 -04:00 committed by GitHub
parent 8553f46498
commit 68626ff8e9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 131 additions and 107 deletions

View file

@ -18,10 +18,11 @@ import errno
import logging
import os
import shutil
from typing import Dict, Tuple
from typing import IO, Dict, Optional, Tuple
import twisted.internet.error
import twisted.web.http
from twisted.web.http import Request
from twisted.web.resource import Resource
from synapse.api.errors import (
@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string
from ._base import (
FileInfo,
Responder,
get_filename_from_headers,
respond_404,
respond_with_responder,
@ -135,19 +137,24 @@ class MediaRepository(object):
self.recently_accessed_locals.add(media_id)
async def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
self,
media_type: str,
upload_name: str,
content: IO,
content_length: int,
auth_user: str,
) -> str:
"""Store uploaded content for a local user and return the mxc URL
Args:
media_type(str): The content type of the file
upload_name(str): The name of the file
media_type: The content type of the file
upload_name: The name of the file
content: A file like object that is the content to store
content_length(int): The length of the content
auth_user(str): The user_id of the uploader
content_length: The length of the content
auth_user: The user_id of the uploader
Returns:
Deferred[str]: The mxc url of the stored content
The mxc url of the stored content
"""
media_id = random_string(24)
@ -170,19 +177,20 @@ class MediaRepository(object):
return "mxc://%s/%s" % (self.server_name, media_id)
async def get_local_media(self, request, media_id, name):
async def get_local_media(
self, request: Request, media_id: str, name: Optional[str]
) -> None:
"""Responds to reqests for local media, if exists, or returns 404.
Args:
request(twisted.web.http.Request)
media_id (str): The media ID of the content. (This is the same as
request: The incoming request.
media_id: The media ID of the content. (This is the same as
the file_id for local content.)
name (str|None): Optional name that, if specified, will be used as
name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
Deferred: Resolves once a response has successfully been written
to request
Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
@ -203,20 +211,20 @@ class MediaRepository(object):
request, responder, media_type, media_length, upload_name
)
async def get_remote_media(self, request, server_name, media_id, name):
async def get_remote_media(
self, request: Request, server_name: str, media_id: str, name: Optional[str]
) -> None:
"""Respond to requests for remote media.
Args:
request(twisted.web.http.Request)
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
name (str|None): Optional name that, if specified, will be used as
request: The incoming request.
server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the remote server).
name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
Deferred: Resolves once a response has successfully been written
to request
Resolves once a response has successfully been written to request
"""
if (
self.federation_domain_whitelist is not None
@ -245,17 +253,16 @@ class MediaRepository(object):
else:
respond_404(request)
async def get_remote_media_info(self, server_name, media_id):
async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
"""Gets the media info associated with the remote file, downloading
if necessary.
Args:
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the remote server).
Returns:
Deferred[dict]: The media_info of the file
The media info of the file
"""
if (
self.federation_domain_whitelist is not None
@ -278,7 +285,9 @@ class MediaRepository(object):
return media_info
async def _get_remote_media_impl(self, server_name, media_id):
async def _get_remote_media_impl(
self, server_name: str, media_id: str
) -> Tuple[Optional[Responder], dict]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@ -288,7 +297,7 @@ class MediaRepository(object):
remote server).
Returns:
Deferred[(Responder, media_info)]
A tuple of responder and the media info of the file.
"""
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@ -319,19 +328,21 @@ class MediaRepository(object):
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
async def _download_remote_file(self, server_name, media_id, file_id):
async def _download_remote_file(
self, server_name: str, media_id: str, file_id: str
) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
Args:
server_name (str): Originating server
media_id (str): The media ID of the content (as defined by the
server_name: Originating server
media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
file_id (str): Local file ID
file_id: Local file ID
Returns:
Deferred[MediaInfo]
The media info of the file.
"""
file_info = FileInfo(server_name=server_name, file_id=file_id)
@ -549,25 +560,31 @@ class MediaRepository(object):
return output_path
async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
self,
server_name: Optional[str],
media_id: str,
file_id: str,
media_type: str,
url_cache: bool = False,
) -> Optional[dict]:
"""Generate and store thumbnails for an image.
Args:
server_name (str|None): The server name if remote media, else None if local
media_id (str): The media ID of the content. (This is the same as
server_name: The server name if remote media, else None if local
media_id: The media ID of the content. (This is the same as
the file_id for local content)
file_id (str): Local file ID
media_type (str): The content type of the file
url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
file_id: Local file ID
media_type: The content type of the file
url_cache: If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image
Dict with "width" and "height" keys of original image or None if the
media cannot be thumbnailed.
"""
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
return None
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
@ -584,7 +601,7 @@ class MediaRepository(object):
m_height,
self.max_image_pixels,
)
return
return None
if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread(