mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-03-10 16:30:06 -04:00
daemon: Override the download endpoint and decrypt the files if necessary.
This commit is contained in:
parent
a8ba24339f
commit
4dac44cfd7
@ -16,6 +16,7 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import urllib.parse
|
||||
import concurrent.futures
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Dict
|
||||
|
||||
@ -26,7 +27,15 @@ from aiohttp import ClientSession, web
|
||||
from aiohttp.client_exceptions import ClientConnectionError, ContentTypeError
|
||||
from jsonschema import ValidationError
|
||||
from multidict import CIMultiDict
|
||||
from nio import Api, EncryptionError, LoginResponse, OlmTrustError, SendRetryError
|
||||
from nio import (
|
||||
Api,
|
||||
EncryptionError,
|
||||
LoginResponse,
|
||||
OlmTrustError,
|
||||
SendRetryError,
|
||||
DownloadResponse,
|
||||
)
|
||||
from nio.crypto import decrypt_attachment
|
||||
|
||||
from pantalaimon.client import (
|
||||
SEARCH_TERMS_SCHEMA,
|
||||
@ -993,6 +1002,64 @@ class ProxyDaemon:
|
||||
|
||||
return web.json_response(result, headers=CORS_HEADERS, status=200)
|
||||
|
||||
async def download(self, request):
|
||||
server_name = request.match_info["server_name"]
|
||||
media_id = request.match_info["media_id"]
|
||||
file_name = request.match_info.get("file_name")
|
||||
|
||||
try:
|
||||
media_info = self.media_info[(server_name, media_id)]
|
||||
except KeyError:
|
||||
media_info = self.store.load_media(self.name, server_name, media_id)
|
||||
|
||||
if not media_info:
|
||||
logger.info(f"No media info found for {server_name}/{media_id}")
|
||||
return await self.forward_to_web(request)
|
||||
|
||||
self.media_info[(server_name, media_id)] = media_info
|
||||
|
||||
try:
|
||||
key = media_info.key["k"]
|
||||
hash = media_info.hashes["sha256"]
|
||||
except KeyError:
|
||||
logger.warn(
|
||||
f"Media info for {server_name}/{media_id} doesn't contain a key or hash."
|
||||
)
|
||||
return await self.forward_to_web(request)
|
||||
|
||||
if not self.pan_clients:
|
||||
return await self.forward_to_web(request)
|
||||
|
||||
client = next(iter(self.pan_clients.values()))
|
||||
|
||||
try:
|
||||
response = await client.download(server_name, media_id, file_name)
|
||||
except ClientConnectionError as e:
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
if not isinstance(response, DownloadResponse):
|
||||
return web.Response(
|
||||
status=response.transport_response.status,
|
||||
content_type=response.transport_response.content_type,
|
||||
headers=CORS_HEADERS,
|
||||
body=await response.transport_response.read(),
|
||||
)
|
||||
|
||||
logger.info(f"Decrypting media {server_name}/{media_id}")
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
with concurrent.futures.ProcessPoolExecutor() as pool:
|
||||
decrypted_file = await loop.run_in_executor(
|
||||
pool, decrypt_attachment, response.body, key, hash, media_info.iv
|
||||
)
|
||||
|
||||
return web.Response(
|
||||
status=response.transport_response.status,
|
||||
content_type=response.transport_response.content_type,
|
||||
headers=CORS_HEADERS,
|
||||
body=decrypted_file,
|
||||
)
|
||||
|
||||
async def well_known(self, _):
|
||||
"""Intercept well-known requests
|
||||
|
||||
|
@ -74,6 +74,13 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
|
||||
web.get("/.well-known/matrix/client", proxy.well_known),
|
||||
web.post("/_matrix/client/r0/search", proxy.search),
|
||||
web.options("/_matrix/client/r0/search", proxy.search_opts),
|
||||
web.get(
|
||||
"/_matrix/media/r0/download/{server_name}/{media_id}", proxy.download
|
||||
),
|
||||
web.get(
|
||||
"/_matrix/media/r0/download/{server_name}/{media_id}/{file_name}",
|
||||
proxy.download,
|
||||
),
|
||||
]
|
||||
)
|
||||
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
|
||||
|
Loading…
x
Reference in New Issue
Block a user