daemon: Override the download endpoint and decrypt the files if necessary.

This commit is contained in:
Damir Jelić 2020-02-20 13:16:54 +01:00
parent a8ba24339f
commit 4dac44cfd7
2 changed files with 75 additions and 1 deletions

View File

@ -16,6 +16,7 @@ import asyncio
import json import json
import os import os
import urllib.parse import urllib.parse
import concurrent.futures
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Dict from typing import Any, Dict
@ -26,7 +27,15 @@ from aiohttp import ClientSession, web
from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError
from jsonschema import ValidationError from jsonschema import ValidationError
from multidict import CIMultiDict from multidict import CIMultiDict
from nio import Api, EncryptionError, LoginResponse, OlmTrustError, SendRetryError from nio import (
Api,
EncryptionError,
LoginResponse,
OlmTrustError,
SendRetryError,
DownloadResponse,
)
from nio.crypto import decrypt_attachment
from pantalaimon.client import ( from pantalaimon.client import (
SEARCH_TERMS_SCHEMA, SEARCH_TERMS_SCHEMA,
@ -993,6 +1002,64 @@ class ProxyDaemon:
return web.json_response(result, headers=CORS_HEADERS, status=200) return web.json_response(result, headers=CORS_HEADERS, status=200)
async def download(self, request):
server_name = request.match_info["server_name"]
media_id = request.match_info["media_id"]
file_name = request.match_info.get("file_name")
try:
media_info = self.media_info[(server_name, media_id)]
except KeyError:
media_info = self.store.load_media(self.name, server_name, media_id)
if not media_info:
logger.info(f"No media info found for {server_name}/{media_id}")
return await self.forward_to_web(request)
self.media_info[(server_name, media_id)] = media_info
try:
key = media_info.key["k"]
hash = media_info.hashes["sha256"]
except KeyError:
logger.warn(
f"Media info for {server_name}/{media_id} doesn't contain a key or hash."
)
return await self.forward_to_web(request)
if not self.pan_clients:
return await self.forward_to_web(request)
client = next(iter(self.pan_clients.values()))
try:
response = await client.download(server_name, media_id, file_name)
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
if not isinstance(response, DownloadResponse):
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,
headers=CORS_HEADERS,
body=await response.transport_response.read(),
)
logger.info(f"Decrypting media {server_name}/{media_id}")
loop = asyncio.get_running_loop()
with concurrent.futures.ProcessPoolExecutor() as pool:
decrypted_file = await loop.run_in_executor(
pool, decrypt_attachment, response.body, key, hash, media_info.iv
)
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,
headers=CORS_HEADERS,
body=decrypted_file,
)
async def well_known(self, _): async def well_known(self, _):
"""Intercept well-known requests """Intercept well-known requests

View File

@ -74,6 +74,13 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
web.get("/.well-known/matrix/client", proxy.well_known), web.get("/.well-known/matrix/client", proxy.well_known),
web.post("/_matrix/client/r0/search", proxy.search), web.post("/_matrix/client/r0/search", proxy.search),
web.options("/_matrix/client/r0/search", proxy.search_opts), web.options("/_matrix/client/r0/search", proxy.search_opts),
web.get(
"/_matrix/media/r0/download/{server_name}/{media_id}", proxy.download
),
web.get(
"/_matrix/media/r0/download/{server_name}/{media_id}/{file_name}",
proxy.download,
),
] ]
) )
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router) app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)