daemon: Override the download endpoint and decrypt the files if necessary.

This commit is contained in:
Damir Jelić 2020-02-20 13:16:54 +01:00
parent a8ba24339f
commit 4dac44cfd7
2 changed files with 75 additions and 1 deletions

View File

@ -16,6 +16,7 @@ import asyncio
import json
import os
import urllib.parse
import concurrent.futures
from json import JSONDecodeError
from typing import Any, Dict
@ -26,7 +27,15 @@ from aiohttp import ClientSession, web
from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError
from jsonschema import ValidationError
from multidict import CIMultiDict
from nio import Api, EncryptionError, LoginResponse, OlmTrustError, SendRetryError
from nio import (
Api,
EncryptionError,
LoginResponse,
OlmTrustError,
SendRetryError,
DownloadResponse,
)
from nio.crypto import decrypt_attachment
from pantalaimon.client import (
SEARCH_TERMS_SCHEMA,
@ -993,6 +1002,64 @@ class ProxyDaemon:
return web.json_response(result, headers=CORS_HEADERS, status=200)
async def download(self, request):
server_name = request.match_info["server_name"]
media_id = request.match_info["media_id"]
file_name = request.match_info.get("file_name")
try:
media_info = self.media_info[(server_name, media_id)]
except KeyError:
media_info = self.store.load_media(self.name, server_name, media_id)
if not media_info:
logger.info(f"No media info found for {server_name}/{media_id}")
return await self.forward_to_web(request)
self.media_info[(server_name, media_id)] = media_info
try:
key = media_info.key["k"]
hash = media_info.hashes["sha256"]
except KeyError:
logger.warn(
f"Media info for {server_name}/{media_id} doesn't contain a key or hash."
)
return await self.forward_to_web(request)
if not self.pan_clients:
return await self.forward_to_web(request)
client = next(iter(self.pan_clients.values()))
try:
response = await client.download(server_name, media_id, file_name)
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
if not isinstance(response, DownloadResponse):
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,
headers=CORS_HEADERS,
body=await response.transport_response.read(),
)
logger.info(f"Decrypting media {server_name}/{media_id}")
loop = asyncio.get_running_loop()
with concurrent.futures.ProcessPoolExecutor() as pool:
decrypted_file = await loop.run_in_executor(
pool, decrypt_attachment, response.body, key, hash, media_info.iv
)
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,
headers=CORS_HEADERS,
body=decrypted_file,
)
async def well_known(self, _):
"""Intercept well-known requests

View File

@ -74,6 +74,13 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
web.get("/.well-known/matrix/client", proxy.well_known),
web.post("/_matrix/client/r0/search", proxy.search),
web.options("/_matrix/client/r0/search", proxy.search_opts),
web.get(
"/_matrix/media/r0/download/{server_name}/{media_id}", proxy.download
),
web.get(
"/_matrix/media/r0/download/{server_name}/{media_id}/{file_name}",
proxy.download,
),
]
)
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)