diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index f8a8488..637e373 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -16,6 +16,7 @@ import asyncio import json import os import urllib.parse +import concurrent.futures from json import JSONDecodeError from typing import Any, Dict @@ -26,7 +27,15 @@ from aiohttp import ClientSession, web from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError from jsonschema import ValidationError from multidict import CIMultiDict -from nio import Api, EncryptionError, LoginResponse, OlmTrustError, SendRetryError +from nio import ( + Api, + EncryptionError, + LoginResponse, + OlmTrustError, + SendRetryError, + DownloadResponse, +) +from nio.crypto import decrypt_attachment from pantalaimon.client import ( SEARCH_TERMS_SCHEMA, @@ -993,6 +1002,64 @@ class ProxyDaemon: return web.json_response(result, headers=CORS_HEADERS, status=200) + async def download(self, request): + server_name = request.match_info["server_name"] + media_id = request.match_info["media_id"] + file_name = request.match_info.get("file_name") + + try: + media_info = self.media_info[(server_name, media_id)] + except KeyError: + media_info = self.store.load_media(self.name, server_name, media_id) + + if not media_info: + logger.info(f"No media info found for {server_name}/{media_id}") + return await self.forward_to_web(request) + + self.media_info[(server_name, media_id)] = media_info + + try: + key = media_info.key["k"] + hash = media_info.hashes["sha256"] + except KeyError: + logger.warn( + f"Media info for {server_name}/{media_id} doesn't contain a key or hash." + ) + return await self.forward_to_web(request) + + if not self.pan_clients: + return await self.forward_to_web(request) + + client = next(iter(self.pan_clients.values())) + + try: + response = await client.download(server_name, media_id, file_name) + except ClientConnectionError as e: + return web.Response(status=500, text=str(e)) + + if not isinstance(response, DownloadResponse): + return web.Response( + status=response.transport_response.status, + content_type=response.transport_response.content_type, + headers=CORS_HEADERS, + body=await response.transport_response.read(), + ) + + logger.info(f"Decrypting media {server_name}/{media_id}") + + loop = asyncio.get_running_loop() + with concurrent.futures.ProcessPoolExecutor() as pool: + decrypted_file = await loop.run_in_executor( + pool, decrypt_attachment, response.body, key, hash, media_info.iv + ) + + return web.Response( + status=response.transport_response.status, + content_type=response.transport_response.content_type, + headers=CORS_HEADERS, + body=decrypted_file, + ) + async def well_known(self, _): """Intercept well-known requests diff --git a/pantalaimon/main.py b/pantalaimon/main.py index 4b7e384..419ba8d 100644 --- a/pantalaimon/main.py +++ b/pantalaimon/main.py @@ -74,6 +74,13 @@ async def init(data_dir, server_conf, send_queue, recv_queue): web.get("/.well-known/matrix/client", proxy.well_known), web.post("/_matrix/client/r0/search", proxy.search), web.options("/_matrix/client/r0/search", proxy.search_opts), + web.get( + "/_matrix/media/r0/download/{server_name}/{media_id}", proxy.download + ), + web.get( + "/_matrix/media/r0/download/{server_name}/{media_id}/{file_name}", + proxy.download, + ), ] ) app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)