mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-17 23:07:09 -05:00
Convert the remaining media repo code to async / await. (#7947)
This commit is contained in:
parent
8553f46498
commit
68626ff8e9
1
changelog.d/7947.misc
Normal file
1
changelog.d/7947.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert various parts of the codebase to async/await.
|
@ -17,7 +17,9 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import urllib
|
import urllib
|
||||||
|
from typing import Awaitable
|
||||||
|
|
||||||
|
from twisted.internet.interfaces import IConsumer
|
||||||
from twisted.protocols.basic import FileSender
|
from twisted.protocols.basic import FileSender
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError, cs_error
|
from synapse.api.errors import Codes, SynapseError, cs_error
|
||||||
@ -240,14 +242,14 @@ class Responder(object):
|
|||||||
held can be cleaned up.
|
held can be cleaned up.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def write_to_consumer(self, consumer):
|
def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
|
||||||
"""Stream response into consumer
|
"""Stream response into consumer
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
consumer (IConsumer)
|
consumer: The consumer to stream into.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Resolves once the response has finished being written
|
Resolves once the response has finished being written
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -18,10 +18,11 @@ import errno
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from typing import Dict, Tuple
|
from typing import IO, Dict, Optional, Tuple
|
||||||
|
|
||||||
import twisted.internet.error
|
import twisted.internet.error
|
||||||
import twisted.web.http
|
import twisted.web.http
|
||||||
|
from twisted.web.http import Request
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string
|
|||||||
|
|
||||||
from ._base import (
|
from ._base import (
|
||||||
FileInfo,
|
FileInfo,
|
||||||
|
Responder,
|
||||||
get_filename_from_headers,
|
get_filename_from_headers,
|
||||||
respond_404,
|
respond_404,
|
||||||
respond_with_responder,
|
respond_with_responder,
|
||||||
@ -135,19 +137,24 @@ class MediaRepository(object):
|
|||||||
self.recently_accessed_locals.add(media_id)
|
self.recently_accessed_locals.add(media_id)
|
||||||
|
|
||||||
async def create_content(
|
async def create_content(
|
||||||
self, media_type, upload_name, content, content_length, auth_user
|
self,
|
||||||
):
|
media_type: str,
|
||||||
|
upload_name: str,
|
||||||
|
content: IO,
|
||||||
|
content_length: int,
|
||||||
|
auth_user: str,
|
||||||
|
) -> str:
|
||||||
"""Store uploaded content for a local user and return the mxc URL
|
"""Store uploaded content for a local user and return the mxc URL
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
media_type(str): The content type of the file
|
media_type: The content type of the file
|
||||||
upload_name(str): The name of the file
|
upload_name: The name of the file
|
||||||
content: A file like object that is the content to store
|
content: A file like object that is the content to store
|
||||||
content_length(int): The length of the content
|
content_length: The length of the content
|
||||||
auth_user(str): The user_id of the uploader
|
auth_user: The user_id of the uploader
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[str]: The mxc url of the stored content
|
The mxc url of the stored content
|
||||||
"""
|
"""
|
||||||
media_id = random_string(24)
|
media_id = random_string(24)
|
||||||
|
|
||||||
@ -170,19 +177,20 @@ class MediaRepository(object):
|
|||||||
|
|
||||||
return "mxc://%s/%s" % (self.server_name, media_id)
|
return "mxc://%s/%s" % (self.server_name, media_id)
|
||||||
|
|
||||||
async def get_local_media(self, request, media_id, name):
|
async def get_local_media(
|
||||||
|
self, request: Request, media_id: str, name: Optional[str]
|
||||||
|
) -> None:
|
||||||
"""Responds to reqests for local media, if exists, or returns 404.
|
"""Responds to reqests for local media, if exists, or returns 404.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request(twisted.web.http.Request)
|
request: The incoming request.
|
||||||
media_id (str): The media ID of the content. (This is the same as
|
media_id: The media ID of the content. (This is the same as
|
||||||
the file_id for local content.)
|
the file_id for local content.)
|
||||||
name (str|None): Optional name that, if specified, will be used as
|
name: Optional name that, if specified, will be used as
|
||||||
the filename in the Content-Disposition header of the response.
|
the filename in the Content-Disposition header of the response.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Resolves once a response has successfully been written
|
Resolves once a response has successfully been written to request
|
||||||
to request
|
|
||||||
"""
|
"""
|
||||||
media_info = await self.store.get_local_media(media_id)
|
media_info = await self.store.get_local_media(media_id)
|
||||||
if not media_info or media_info["quarantined_by"]:
|
if not media_info or media_info["quarantined_by"]:
|
||||||
@ -203,20 +211,20 @@ class MediaRepository(object):
|
|||||||
request, responder, media_type, media_length, upload_name
|
request, responder, media_type, media_length, upload_name
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_remote_media(self, request, server_name, media_id, name):
|
async def get_remote_media(
|
||||||
|
self, request: Request, server_name: str, media_id: str, name: Optional[str]
|
||||||
|
) -> None:
|
||||||
"""Respond to requests for remote media.
|
"""Respond to requests for remote media.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request(twisted.web.http.Request)
|
request: The incoming request.
|
||||||
server_name (str): Remote server_name where the media originated.
|
server_name: Remote server_name where the media originated.
|
||||||
media_id (str): The media ID of the content (as defined by the
|
media_id: The media ID of the content (as defined by the remote server).
|
||||||
remote server).
|
name: Optional name that, if specified, will be used as
|
||||||
name (str|None): Optional name that, if specified, will be used as
|
|
||||||
the filename in the Content-Disposition header of the response.
|
the filename in the Content-Disposition header of the response.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred: Resolves once a response has successfully been written
|
Resolves once a response has successfully been written to request
|
||||||
to request
|
|
||||||
"""
|
"""
|
||||||
if (
|
if (
|
||||||
self.federation_domain_whitelist is not None
|
self.federation_domain_whitelist is not None
|
||||||
@ -245,17 +253,16 @@ class MediaRepository(object):
|
|||||||
else:
|
else:
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
|
|
||||||
async def get_remote_media_info(self, server_name, media_id):
|
async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
|
||||||
"""Gets the media info associated with the remote file, downloading
|
"""Gets the media info associated with the remote file, downloading
|
||||||
if necessary.
|
if necessary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_name (str): Remote server_name where the media originated.
|
server_name: Remote server_name where the media originated.
|
||||||
media_id (str): The media ID of the content (as defined by the
|
media_id: The media ID of the content (as defined by the remote server).
|
||||||
remote server).
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict]: The media_info of the file
|
The media info of the file
|
||||||
"""
|
"""
|
||||||
if (
|
if (
|
||||||
self.federation_domain_whitelist is not None
|
self.federation_domain_whitelist is not None
|
||||||
@ -278,7 +285,9 @@ class MediaRepository(object):
|
|||||||
|
|
||||||
return media_info
|
return media_info
|
||||||
|
|
||||||
async def _get_remote_media_impl(self, server_name, media_id):
|
async def _get_remote_media_impl(
|
||||||
|
self, server_name: str, media_id: str
|
||||||
|
) -> Tuple[Optional[Responder], dict]:
|
||||||
"""Looks for media in local cache, if not there then attempt to
|
"""Looks for media in local cache, if not there then attempt to
|
||||||
download from remote server.
|
download from remote server.
|
||||||
|
|
||||||
@ -288,7 +297,7 @@ class MediaRepository(object):
|
|||||||
remote server).
|
remote server).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[(Responder, media_info)]
|
A tuple of responder and the media info of the file.
|
||||||
"""
|
"""
|
||||||
media_info = await self.store.get_cached_remote_media(server_name, media_id)
|
media_info = await self.store.get_cached_remote_media(server_name, media_id)
|
||||||
|
|
||||||
@ -319,19 +328,21 @@ class MediaRepository(object):
|
|||||||
responder = await self.media_storage.fetch_media(file_info)
|
responder = await self.media_storage.fetch_media(file_info)
|
||||||
return responder, media_info
|
return responder, media_info
|
||||||
|
|
||||||
async def _download_remote_file(self, server_name, media_id, file_id):
|
async def _download_remote_file(
|
||||||
|
self, server_name: str, media_id: str, file_id: str
|
||||||
|
) -> dict:
|
||||||
"""Attempt to download the remote file from the given server name,
|
"""Attempt to download the remote file from the given server name,
|
||||||
using the given file_id as the local id.
|
using the given file_id as the local id.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_name (str): Originating server
|
server_name: Originating server
|
||||||
media_id (str): The media ID of the content (as defined by the
|
media_id: The media ID of the content (as defined by the
|
||||||
remote server). This is different than the file_id, which is
|
remote server). This is different than the file_id, which is
|
||||||
locally generated.
|
locally generated.
|
||||||
file_id (str): Local file ID
|
file_id: Local file ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[MediaInfo]
|
The media info of the file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
file_info = FileInfo(server_name=server_name, file_id=file_id)
|
||||||
@ -549,25 +560,31 @@ class MediaRepository(object):
|
|||||||
return output_path
|
return output_path
|
||||||
|
|
||||||
async def _generate_thumbnails(
|
async def _generate_thumbnails(
|
||||||
self, server_name, media_id, file_id, media_type, url_cache=False
|
self,
|
||||||
):
|
server_name: Optional[str],
|
||||||
|
media_id: str,
|
||||||
|
file_id: str,
|
||||||
|
media_type: str,
|
||||||
|
url_cache: bool = False,
|
||||||
|
) -> Optional[dict]:
|
||||||
"""Generate and store thumbnails for an image.
|
"""Generate and store thumbnails for an image.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_name (str|None): The server name if remote media, else None if local
|
server_name: The server name if remote media, else None if local
|
||||||
media_id (str): The media ID of the content. (This is the same as
|
media_id: The media ID of the content. (This is the same as
|
||||||
the file_id for local content)
|
the file_id for local content)
|
||||||
file_id (str): Local file ID
|
file_id: Local file ID
|
||||||
media_type (str): The content type of the file
|
media_type: The content type of the file
|
||||||
url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
|
url_cache: If we are thumbnailing images downloaded for the URL cache,
|
||||||
used exclusively by the url previewer
|
used exclusively by the url previewer
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict]: Dict with "width" and "height" keys of original image
|
Dict with "width" and "height" keys of original image or None if the
|
||||||
|
media cannot be thumbnailed.
|
||||||
"""
|
"""
|
||||||
requirements = self._get_thumbnail_requirements(media_type)
|
requirements = self._get_thumbnail_requirements(media_type)
|
||||||
if not requirements:
|
if not requirements:
|
||||||
return
|
return None
|
||||||
|
|
||||||
input_path = await self.media_storage.ensure_media_is_in_local_cache(
|
input_path = await self.media_storage.ensure_media_is_in_local_cache(
|
||||||
FileInfo(server_name, file_id, url_cache=url_cache)
|
FileInfo(server_name, file_id, url_cache=url_cache)
|
||||||
@ -584,7 +601,7 @@ class MediaRepository(object):
|
|||||||
m_height,
|
m_height,
|
||||||
self.max_image_pixels,
|
self.max_image_pixels,
|
||||||
)
|
)
|
||||||
return
|
return None
|
||||||
|
|
||||||
if thumbnailer.transpose_method is not None:
|
if thumbnailer.transpose_method is not None:
|
||||||
m_width, m_height = await defer_to_thread(
|
m_width, m_height = await defer_to_thread(
|
||||||
|
@ -12,13 +12,12 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
from typing import Optional
|
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
|
||||||
|
|
||||||
from twisted.protocols.basic import FileSender
|
from twisted.protocols.basic import FileSender
|
||||||
|
|
||||||
@ -26,6 +25,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable
|
|||||||
from synapse.util.file_consumer import BackgroundFileConsumer
|
from synapse.util.file_consumer import BackgroundFileConsumer
|
||||||
|
|
||||||
from ._base import FileInfo, Responder
|
from ._base import FileInfo, Responder
|
||||||
|
from .filepath import MediaFilePaths
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
from .storage_provider import StorageProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -34,20 +39,25 @@ class MediaStorage(object):
|
|||||||
"""Responsible for storing/fetching files from local sources.
|
"""Responsible for storing/fetching files from local sources.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hs (synapse.server.Homeserver)
|
hs
|
||||||
local_media_directory (str): Base path where we store media on disk
|
local_media_directory: Base path where we store media on disk
|
||||||
filepaths (MediaFilePaths)
|
filepaths
|
||||||
storage_providers ([StorageProvider]): List of StorageProvider that are
|
storage_providers: List of StorageProvider that are used to fetch and store files.
|
||||||
used to fetch and store files.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs, local_media_directory, filepaths, storage_providers):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hs: "HomeServer",
|
||||||
|
local_media_directory: str,
|
||||||
|
filepaths: MediaFilePaths,
|
||||||
|
storage_providers: Sequence["StorageProvider"],
|
||||||
|
):
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.local_media_directory = local_media_directory
|
self.local_media_directory = local_media_directory
|
||||||
self.filepaths = filepaths
|
self.filepaths = filepaths
|
||||||
self.storage_providers = storage_providers
|
self.storage_providers = storage_providers
|
||||||
|
|
||||||
async def store_file(self, source, file_info: FileInfo) -> str:
|
async def store_file(self, source: IO, file_info: FileInfo) -> str:
|
||||||
"""Write `source` to the on disk media store, and also any other
|
"""Write `source` to the on disk media store, and also any other
|
||||||
configured storage providers
|
configured storage providers
|
||||||
|
|
||||||
@ -69,7 +79,7 @@ class MediaStorage(object):
|
|||||||
return fname
|
return fname
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def store_into_file(self, file_info):
|
def store_into_file(self, file_info: FileInfo):
|
||||||
"""Context manager used to get a file like object to write into, as
|
"""Context manager used to get a file like object to write into, as
|
||||||
described by file_info.
|
described by file_info.
|
||||||
|
|
||||||
@ -85,7 +95,7 @@ class MediaStorage(object):
|
|||||||
error.
|
error.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_info (FileInfo): Info about the file to store
|
file_info: Info about the file to store
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -143,9 +153,9 @@ class MediaStorage(object):
|
|||||||
return FileResponder(open(local_path, "rb"))
|
return FileResponder(open(local_path, "rb"))
|
||||||
|
|
||||||
for provider in self.storage_providers:
|
for provider in self.storage_providers:
|
||||||
res = provider.fetch(path, file_info)
|
res = provider.fetch(path, file_info) # type: Any
|
||||||
# Fetch is supposed to return an Awaitable, but guard against
|
# Fetch is supposed to return an Awaitable[Responder], but guard
|
||||||
# improper implementations.
|
# against improper implementations.
|
||||||
if inspect.isawaitable(res):
|
if inspect.isawaitable(res):
|
||||||
res = await res
|
res = await res
|
||||||
if res:
|
if res:
|
||||||
@ -174,9 +184,9 @@ class MediaStorage(object):
|
|||||||
os.makedirs(dirname)
|
os.makedirs(dirname)
|
||||||
|
|
||||||
for provider in self.storage_providers:
|
for provider in self.storage_providers:
|
||||||
res = provider.fetch(path, file_info)
|
res = provider.fetch(path, file_info) # type: Any
|
||||||
# Fetch is supposed to return an Awaitable, but guard against
|
# Fetch is supposed to return an Awaitable[Responder], but guard
|
||||||
# improper implementations.
|
# against improper implementations.
|
||||||
if inspect.isawaitable(res):
|
if inspect.isawaitable(res):
|
||||||
res = await res
|
res = await res
|
||||||
if res:
|
if res:
|
||||||
@ -190,17 +200,11 @@ class MediaStorage(object):
|
|||||||
|
|
||||||
raise Exception("file could not be found")
|
raise Exception("file could not be found")
|
||||||
|
|
||||||
def _file_info_to_path(self, file_info):
|
def _file_info_to_path(self, file_info: FileInfo) -> str:
|
||||||
"""Converts file_info into a relative path.
|
"""Converts file_info into a relative path.
|
||||||
|
|
||||||
The path is suitable for storing files under a directory, e.g. used to
|
The path is suitable for storing files under a directory, e.g. used to
|
||||||
store files on local FS under the base media repository directory.
|
store files on local FS under the base media repository directory.
|
||||||
|
|
||||||
Args:
|
|
||||||
file_info (FileInfo)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str
|
|
||||||
"""
|
"""
|
||||||
if file_info.url_cache:
|
if file_info.url_cache:
|
||||||
if file_info.thumbnail:
|
if file_info.thumbnail:
|
||||||
|
@ -231,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource):
|
|||||||
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
|
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
|
||||||
respond_with_json_bytes(request, 200, og, send_cors=True)
|
respond_with_json_bytes(request, 200, og, send_cors=True)
|
||||||
|
|
||||||
async def _do_preview(self, url, user, ts):
|
async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
|
||||||
"""Check the db, and download the URL and build a preview
|
"""Check the db, and download the URL and build a preview
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
url (str):
|
url: The URL to preview.
|
||||||
user (str):
|
user: The user requesting the preview.
|
||||||
ts (int):
|
ts: The timestamp requested for the preview.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[bytes]: json-encoded og data
|
json-encoded og data
|
||||||
"""
|
"""
|
||||||
# check the URL cache in the DB (which will also provide us with
|
# check the URL cache in the DB (which will also provide us with
|
||||||
# historical previews, if we have any)
|
# historical previews, if we have any)
|
||||||
|
@ -16,62 +16,62 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
|
from typing import Optional
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.config._base import Config
|
from synapse.config._base import Config
|
||||||
from synapse.logging.context import defer_to_thread, run_in_background
|
from synapse.logging.context import defer_to_thread, run_in_background
|
||||||
|
|
||||||
|
from ._base import FileInfo, Responder
|
||||||
from .media_storage import FileResponder
|
from .media_storage import FileResponder
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class StorageProvider(object):
|
class StorageProvider:
|
||||||
"""A storage provider is a service that can store uploaded media and
|
"""A storage provider is a service that can store uploaded media and
|
||||||
retrieve them.
|
retrieve them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def store_file(self, path, file_info):
|
async def store_file(self, path: str, file_info: FileInfo):
|
||||||
"""Store the file described by file_info. The actual contents can be
|
"""Store the file described by file_info. The actual contents can be
|
||||||
retrieved by reading the file in file_info.upload_path.
|
retrieved by reading the file in file_info.upload_path.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str): Relative path of file in local cache
|
path: Relative path of file in local cache
|
||||||
file_info (FileInfo)
|
file_info: The metadata of the file.
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred
|
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
def fetch(self, path, file_info):
|
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
|
||||||
"""Attempt to fetch the file described by file_info and stream it
|
"""Attempt to fetch the file described by file_info and stream it
|
||||||
into writer.
|
into writer.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str): Relative path of file in local cache
|
path: Relative path of file in local cache
|
||||||
file_info (FileInfo)
|
file_info: The metadata of the file.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred(Responder): Returns a Responder if the provider has the file,
|
Returns a Responder if the provider has the file, otherwise returns None.
|
||||||
otherwise returns None.
|
|
||||||
"""
|
"""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class StorageProviderWrapper(StorageProvider):
|
class StorageProviderWrapper(StorageProvider):
|
||||||
"""Wraps a storage provider and provides various config options
|
"""Wraps a storage provider and provides various config options
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
backend (StorageProvider)
|
backend: The storage provider to wrap.
|
||||||
store_local (bool): Whether to store new local files or not.
|
store_local: Whether to store new local files or not.
|
||||||
store_synchronous (bool): Whether to wait for file to be successfully
|
store_synchronous: Whether to wait for file to be successfully
|
||||||
uploaded, or todo the upload in the background.
|
uploaded, or todo the upload in the background.
|
||||||
store_remote (bool): Whether remote media should be uploaded
|
store_remote: Whether remote media should be uploaded
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, backend, store_local, store_synchronous, store_remote):
|
def __init__(
|
||||||
|
self,
|
||||||
|
backend: StorageProvider,
|
||||||
|
store_local: bool,
|
||||||
|
store_synchronous: bool,
|
||||||
|
store_remote: bool,
|
||||||
|
):
|
||||||
self.backend = backend
|
self.backend = backend
|
||||||
self.store_local = store_local
|
self.store_local = store_local
|
||||||
self.store_synchronous = store_synchronous
|
self.store_synchronous = store_synchronous
|
||||||
@ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "StorageProviderWrapper[%s]" % (self.backend,)
|
return "StorageProviderWrapper[%s]" % (self.backend,)
|
||||||
|
|
||||||
def store_file(self, path, file_info):
|
async def store_file(self, path, file_info):
|
||||||
if not file_info.server_name and not self.store_local:
|
if not file_info.server_name and not self.store_local:
|
||||||
return defer.succeed(None)
|
return None
|
||||||
|
|
||||||
if file_info.server_name and not self.store_remote:
|
if file_info.server_name and not self.store_remote:
|
||||||
return defer.succeed(None)
|
return None
|
||||||
|
|
||||||
if self.store_synchronous:
|
if self.store_synchronous:
|
||||||
return self.backend.store_file(path, file_info)
|
return await self.backend.store_file(path, file_info)
|
||||||
else:
|
else:
|
||||||
# TODO: Handle errors.
|
# TODO: Handle errors.
|
||||||
def store():
|
def store():
|
||||||
@ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider):
|
|||||||
logger.exception("Error storing file")
|
logger.exception("Error storing file")
|
||||||
|
|
||||||
run_in_background(store)
|
run_in_background(store)
|
||||||
return defer.succeed(None)
|
return None
|
||||||
|
|
||||||
def fetch(self, path, file_info):
|
async def fetch(self, path, file_info):
|
||||||
return self.backend.fetch(path, file_info)
|
return await self.backend.fetch(path, file_info)
|
||||||
|
|
||||||
|
|
||||||
class FileStorageProviderBackend(StorageProvider):
|
class FileStorageProviderBackend(StorageProvider):
|
||||||
@ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider):
|
|||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
|
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
|
||||||
|
|
||||||
def store_file(self, path, file_info):
|
async def store_file(self, path, file_info):
|
||||||
"""See StorageProvider.store_file"""
|
"""See StorageProvider.store_file"""
|
||||||
|
|
||||||
primary_fname = os.path.join(self.cache_directory, path)
|
primary_fname = os.path.join(self.cache_directory, path)
|
||||||
@ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider):
|
|||||||
if not os.path.exists(dirname):
|
if not os.path.exists(dirname):
|
||||||
os.makedirs(dirname)
|
os.makedirs(dirname)
|
||||||
|
|
||||||
return defer_to_thread(
|
return await defer_to_thread(
|
||||||
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
|
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
|
||||||
)
|
)
|
||||||
|
|
||||||
def fetch(self, path, file_info):
|
async def fetch(self, path, file_info):
|
||||||
"""See StorageProvider.fetch"""
|
"""See StorageProvider.fetch"""
|
||||||
|
|
||||||
backup_fname = os.path.join(self.base_directory, path)
|
backup_fname = os.path.join(self.base_directory, path)
|
||||||
|
Loading…
Reference in New Issue
Block a user