From 1168bcf7ffb9900330d456306fcef89065a5215f Mon Sep 17 00:00:00 2001 From: Andrea Spacca Date: Tue, 22 Dec 2020 17:12:21 +0100 Subject: [PATCH] CR fixes --- pantalaimon/daemon.py | 151 ++++++++++++++++++++++++------------------ pantalaimon/store.py | 17 ++++- tests/store_test.py | 18 ++++- 3 files changed, 120 insertions(+), 66 deletions(-) diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index 1e731c3..6130df7 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -17,6 +17,7 @@ import json import os import urllib.parse import concurrent.futures +from io import BufferedReader, BytesIO from json import JSONDecodeError from typing import Any, Dict from urllib.parse import urlparse @@ -50,7 +51,7 @@ from pantalaimon.client import ( ) from pantalaimon.index import INDEXING_ENABLED, InvalidQueryError from pantalaimon.log import logger -from pantalaimon.store import ClientInfo, PanStore +from pantalaimon.store import ClientInfo, PanStore, MediaInfo from pantalaimon.thread_messages import ( AcceptSasMessage, CancelSasMessage, @@ -826,6 +827,52 @@ class ProxyDaemon: body=await response.read(), ) + def _map_media_upload(self, content, request, client): + content_uri = content["url"] + + upload = self.store.load_upload(content_uri) + if upload is None: + return await self.forward_to_web(request, token=client.access_token) + + mxc = urlparse(content_uri) + mxc_server = mxc.netloc.strip("/") + mxc_path = mxc.path.strip("/") + file_name = request.match_info.get("file_name") + + logger.info(f"Adding media info for {mxc_server}/{mxc_path} to the store") + + media = MediaInfo(mxc_server, mxc_path, upload.key, upload.iv, upload.hashes) + self.media_info[(mxc_server, mxc_path)] = media + client = next(iter(self.pan_clients.values())) + self.store.save_media(client.server_name, media) + + try: + response, decrypted_file, error = await self._load_media(mxc_server, mxc_path, file_name, request) + + if response is None and decrypted_file is None: + return await self.forward_to_web(request, token=client.access_token) + except ClientConnectionError as e: + return web.Response(status=500, text=str(e)) + except KeyError: + return await self.forward_to_web(request, token=client.access_token) + + if not isinstance(response, DownloadResponse): + return await self.forward_to_web(request, token=client.access_token) + + decrypted_upload = await client.upload( + data_provider=BufferedReader(BytesIO(decrypted_file)), + content_type=response.content_type, + filename=file_name, + encrypt=False, + ) + + if not isinstance(decrypted_upload[0], UploadResponse): + raise ValueError + + content["url"] = decrypted_upload[0].content_uri + + return content + async def send_message(self, request): access_token = self.get_access_token(request) @@ -851,56 +898,27 @@ class ProxyDaemon: if request.match_info["event_type"] == "m.reaction": encrypt = False + msgtype = request.match_info["event_type"] + # The room isn't encrypted just forward the message. if not encrypt: - content_msgtype = "" - msgtype = request.match_info["event_type"] - if content_msgtype == "m.image" or content_msgtype == "m.video" \ - or content_msgtype == "m.audio" or content_msgtype == "m.file": - content_uri = request.match_info["content"]["url"] + content = "" + try: + content = await request.json() + except (JSONDecodeError, ContentTypeError): + return self._not_json - upload = self.store.load_upload(content_uri) - if upload is None: + content_msgtype = content["msgtype"] + if content_msgtype in ["m.image", "m.video", "m.audio", "m.file"]: + try: + content = self._map_media_upload(content, request, client) + except ValueError: return await self.forward_to_web(request, token=client.access_token) - mxc = urlparse(content_uri) - server_name = mxc.netloc.strip("/") - media_id = mxc.path.strip("/") - file_name = request.match_info.get("file_name") - - response, decrypted_file, error = await self._load_media(server_name, media_id, file_name, request) - - if response is None and decrypted_file is None and error is None: - return await self.forward_to_web(request, token=client.access_token) - if error is ClientConnectionError: - return await self.forward_to_web(request, token=client.access_token) - if error is KeyError: - return await self.forward_to_web(request, token=client.access_token) - - if not isinstance(response, DownloadResponse): - return await self.forward_to_web(request, token=client.access_token) - - decrypted_upload = client.upload( - data_provider=decrypted_file, - content_type=response.content_type, - filename=file_name, - encrypt=False, - ) - - if not isinstance(response, UploadResponse): - return await self.forward_to_web(request, token=client.access_token) - - request.match_info["content"]["url"] = decrypted_upload.content_uri - return await self.forward_to_web(request, token=client.access_token) txnid = request.match_info.get("txnid", uuid4()) - try: - content = await request.json() - except (JSONDecodeError, ContentTypeError): - return self._not_json - async def _send(ignore_unverified=False): try: response = await client.room_send( @@ -1084,30 +1102,37 @@ class ProxyDaemon: content_type = request.headers.get("Content-Type", "application/octet-stream") client = next(iter(self.pan_clients.values())) + body = await request.read() try: response = await client.upload( - data_provider=await request.read, + data_provider=BufferedReader(BytesIO(body)), content_type=content_type, filename=file_name, encrypt=True, ) - if not isinstance(response, UploadResponse): + if not isinstance(response[0], UploadResponse): return web.Response( - status=response.transport_response.status, - content_type=response.transport_response.content_type, + status=response[0].transport_response.status, + content_type=response[0].transport_response.content_type, headers=CORS_HEADERS, - body=await response.transport_response.read(), + body=await response[0].transport_response.read(), ) - self.store.save_upload(response.content_uri) + self.store.save_upload(response[0].content_uri, response[1]) + + return web.Response( + status=response[0].transport_response.status, + content_type=response[0].transport_response.content_type, + headers=CORS_HEADERS, + body=await response[0].transport_response.read(), + ) except ClientConnectionError as e: return web.Response(status=500, text=str(e)) except SendRetryError as e: return web.Response(status=503, text=str(e)) - async def _load_media(self, server_name, media_id, file_name, request): try: media_info = self.media_info[(server_name, media_id)] @@ -1116,30 +1141,30 @@ class ProxyDaemon: if not media_info: logger.info(f"No media info found for {server_name}/{media_id}") - return None, None, None + return None, None self.media_info[(server_name, media_id)] = media_info try: key = media_info.key["k"] hash = media_info.hashes["sha256"] - except KeyError: + except KeyError as e: logger.warn( f"Media info for {server_name}/{media_id} doesn't contain a key or hash." ) - return None, None, KeyError + raise e if not self.pan_clients: - return None, None, None + return None, None client = next(iter(self.pan_clients.values())) try: response = await client.download(server_name, media_id, file_name) except ClientConnectionError as e: - return None, None, e + raise e if not isinstance(response, DownloadResponse): - return response, None, None + return response, None logger.info(f"Decrypting media {server_name}/{media_id}") @@ -1149,7 +1174,7 @@ class ProxyDaemon: pool, decrypt_attachment, response.body, key, hash, media_info.iv ) - return response, decrypted_file, None + return response, decrypted_file async def download(self, request): @@ -1157,13 +1182,14 @@ class ProxyDaemon: media_id = request.match_info["media_id"] file_name = request.match_info.get("file_name") - response, decrypted_file, error = await self._load_media(server_name, media_id, file_name, request) + try: + response, decrypted_file, error = await self._load_media(server_name, media_id, file_name, request) - if response is None and decrypted_file is None and error is None: - return await self.forward_to_web(request) - if error is ClientConnectionError: - return web.Response(status=500, text=str(error)) - if error is KeyError: + if response is None and decrypted_file is None: + return await self.forward_to_web(request) + except ClientConnectionError as e: + return web.Response(status=500, text=str(e)) + except KeyError: return await self.forward_to_web(request) if not isinstance(response, DownloadResponse): @@ -1181,7 +1207,6 @@ class ProxyDaemon: body=decrypted_file, ) - async def well_known(self, _): """Intercept well-known requests diff --git a/pantalaimon/store.py b/pantalaimon/store.py index 24e447d..4e93f71 100644 --- a/pantalaimon/store.py +++ b/pantalaimon/store.py @@ -51,6 +51,9 @@ class MediaInfo: @attr.s class UploadInfo: content_uri = attr.ib(type=str) + key = attr.ib(type=dict) + iv = attr.ib(type=str) + hashes = attr.ib(type=dict) class DictField(TextField): @@ -117,11 +120,17 @@ class PanMediaInfo(Model): class Meta: constraints = [SQL("UNIQUE(server_id, mxc_server, mxc_path)")] + class PanUploadInfo(Model): content_uri = TextField() + key = DictField() + hashes = DictField() + iv = TextField() + class Meta: constraints = [SQL("UNIQUE(content_uri)")] + @attr.s class ClientInfo: user_id = attr.ib(type=str) @@ -144,6 +153,7 @@ class PanStore: PanSyncTokens, PanFetcherTasks, PanMediaInfo, + PanUploadInfo, ] def __attrs_post_init__(self): @@ -172,9 +182,12 @@ class PanStore: return None @use_database - def save_upload(self, content_uri): + def save_upload(self, content_uri, media): PanUploadInfo.insert( content_uri=content_uri, + key=media["key"], + iv=media["iv"], + hashes=media["hashes"], ).on_conflict_ignore().execute() @use_database @@ -186,7 +199,7 @@ class PanStore: if not u: return None - return UploadInfo(u.content_uri) + return UploadInfo(u.content_uri, u.key, u.iv, u.hashes) @use_database def save_media(self, server, media): diff --git a/tests/store_test.py b/tests/store_test.py index 2f31587..5ea5f96 100644 --- a/tests/store_test.py +++ b/tests/store_test.py @@ -8,7 +8,7 @@ from nio import RoomMessage, RoomEncryptedMedia from urllib.parse import urlparse from conftest import faker from pantalaimon.index import INDEXING_ENABLED -from pantalaimon.store import FetchTask, MediaInfo +from pantalaimon.store import FetchTask, MediaInfo, UploadInfo TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost" TEST_ROOM2 = "!testroom:localhost" @@ -177,3 +177,19 @@ class TestClass(object): media_info = media_cache[(mxc_server, mxc_path)] assert media_info == media assert media_info == panstore.load_media(server_name, mxc_server, mxc_path) + + def test_upload_storage(self, panstore): + event = self.encrypted_media_event + + assert not panstore.load_upload(event.url) + + upload = UploadInfo(event.url, event.key, event.iv, event.hashes) + + panstore.save_upload(event.url) + + upload_cache = panstore.load_upload(event.url) + + assert (event.url) in upload_cache + upload_info = upload_cache[(event.url)] + assert upload_info == upload + assert upload_info == panstore.load_upload(event.url)