From f0ea2ebd3d862fea30a6c3f6057918254b998fd7 Mon Sep 17 00:00:00 2001 From: Andrea Spacca Date: Wed, 6 Jan 2021 18:08:03 +0100 Subject: [PATCH] CR fixes --- pantalaimon/daemon.py | 51 +++++++++++++++++++++++++++++++------------ pantalaimon/store.py | 27 +++++++++++++++++++---- tests/store_test.py | 4 ++-- 3 files changed, 62 insertions(+), 20 deletions(-) diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index 5e2010a..3a35a77 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -829,7 +829,7 @@ class ProxyDaemon: body=await response.read(), ) - async def _map_media_upload(self, content_key, content, request, client): + def _get_upload_and_media_info(self, content_key, content, request): content_uri = content[content_key] try: @@ -837,13 +837,11 @@ class ProxyDaemon: except KeyError: upload_info = self.store.load_upload(self.name, content_uri) if not upload_info: - logger.info(f"No upload info found for {self.name}/{content_uri}") - - return await self.forward_to_web(request, token=client.access_token) - - self.upload_info[content_uri] = upload_info + return None, None + self.upload_info[content_uri] = upload_info + content_uri = content[content_key] mxc = urlparse(content_uri) mxc_server = mxc.netloc.strip("/") mxc_path = mxc.path.strip("/") @@ -851,12 +849,19 @@ class ProxyDaemon: logger.info(f"Adding media info for {mxc_server}/{mxc_path} to the store") - media = MediaInfo(mxc_server, mxc_path, upload_info.key, upload_info.iv, upload_info.hashes) - self.media_info[(mxc_server, mxc_path)] = media - self.store.save_media(self.name, media) + media_info = MediaInfo(mxc_server, mxc_path, upload_info.key, upload_info.iv, upload_info.hashes) + self.media_info[(mxc_server, mxc_path)] = media_info + self.store.save_media(self.name, media_info) + return upload_info, media_info, file_name + + async def _map_media_upload(self, content_key, content, request, client): try: - response, decrypted_file = await self._load_media(mxc_server, mxc_path, file_name) + upload_info, media_info, file_name = self._get_upload_and_media_info(content_key, content, request) + if not upload_info: + return await self.forward_to_web(request, token=client.access_token) + + response, decrypted_file = await self._load_media(media_info.mcx_server, media_info.mxc_path, file_name) if response is None and decrypted_file is None: return await self.forward_to_web(request, token=client.access_token) @@ -868,11 +873,12 @@ class ProxyDaemon: if not isinstance(response, DownloadResponse): return await self.forward_to_web(request, token=client.access_token) - decrypted_upload, maybe_keys = await client.upload( + decrypted_upload, _ = await client.upload( data_provider=BufferedReader(BytesIO(decrypted_file)), content_type=response.content_type, filename=file_name, encrypt=False, + filesize=len(decrypted_file), ) if not isinstance(decrypted_upload, UploadResponse): @@ -909,7 +915,6 @@ class ProxyDaemon: msgtype = request.match_info["event_type"] - content = "" try: content = await request.json() except (JSONDecodeError, ContentTypeError): @@ -931,6 +936,23 @@ class ProxyDaemon: async def _send(ignore_unverified=False): try: + content_msgtype = content["msgtype"] + if content_msgtype in ["m.image", "m.video", "m.audio", "m.file"] or msgtype == "m.room.avatar": + upload_info, media_info, file_name = self._get_upload_and_media_info("url", content, request) + if not upload_info: + response = await client.room_send( + room_id, msgtype, content, txnid, ignore_unverified + ) + + return web.Response( + status=response.transport_response.status, + content_type=response.transport_response.content_type, + headers=CORS_HEADERS, + body=await response.transport_response.read(), + ) + + content = media_info.to_content(file_name, content_msgtype, upload_info.mimetype), + response = await client.room_send( room_id, msgtype, content, txnid, ignore_unverified ) @@ -1119,6 +1141,7 @@ class ProxyDaemon: content_type=content_type, filename=file_name, encrypt=True, + filesize=len(body), ) if not isinstance(response, UploadResponse): @@ -1129,7 +1152,7 @@ class ProxyDaemon: body=await response.transport_response.read(), ) - self.store.save_upload(self.name, response.content_uri, maybe_keys) + self.store.save_upload(self.name, response.content_uri, maybe_keys, content_type) return web.Response( status=response.transport_response.status, @@ -1214,7 +1237,7 @@ class ProxyDaemon: file_name = request.match_info.get("file_name") try: - response, decrypted_file = await self._load_media(server_name, media_id, file_name, request) + response, decrypted_file = await self._load_media(server_name, media_id, file_name) if response is None and decrypted_file is None: return await self.forward_to_web(request) diff --git a/pantalaimon/store.py b/pantalaimon/store.py index fbde863..79e05d7 100644 --- a/pantalaimon/store.py +++ b/pantalaimon/store.py @@ -15,7 +15,7 @@ import json import os from collections import defaultdict -from typing import List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import attr from nio.crypto import TrustState @@ -48,6 +48,23 @@ class MediaInfo: iv = attr.ib(type=str) hashes = attr.ib(type=dict) + def to_content(self, file_name: str, msgtype: str, mime_type: str) -> Dict[Any, Any]: + content = { + "body": file_name, + "file": { + "v": "v2", + "key": self.key, + "iv": self.iv, + "hashes": self.hashes, + "url": self.url, + "mimetype": mime_type, + } + } + + if msgtype: + content["msgtype"] = msgtype + + return content @attr.s class UploadInfo: @@ -55,6 +72,7 @@ class UploadInfo: key = attr.ib(type=dict) iv = attr.ib(type=str) hashes = attr.ib(type=dict) + mimetype = attr.ib(type=str) class DictField(TextField): @@ -186,7 +204,7 @@ class PanStore: return None @use_database - def save_upload(self, server, content_uri, upload): + def save_upload(self, server, content_uri, upload, mimetype): server = Servers.get(name=server) PanUploadInfo.insert( @@ -195,6 +213,7 @@ class PanStore: key=upload["key"], iv=upload["iv"], hashes=upload["hashes"], + mimetype=mimetype, ).on_conflict_ignore().execute() @use_database @@ -208,7 +227,7 @@ class PanStore: if i > MAX_LOADED_UPLOAD: break - upload = UploadInfo(u.content_uri, u.key, u.iv, u.hashes) + upload = UploadInfo(u.content_uri, u.key, u.iv, u.hashes, u.mimetype) upload_cache[u.content_uri] = upload return upload_cache @@ -221,7 +240,7 @@ class PanStore: if not u: return None - return UploadInfo(u.content_uri, u.key, u.iv, u.hashes) + return UploadInfo(u.content_uri, u.key, u.iv, u.hashes, u.mimetype) @use_database def save_media(self, server, media): diff --git a/tests/store_test.py b/tests/store_test.py index 981170f..b31d41a 100644 --- a/tests/store_test.py +++ b/tests/store_test.py @@ -187,9 +187,9 @@ class TestClass(object): assert not panstore.load_upload(server_name, event.url) - upload = UploadInfo(event.url, event.key, event.iv, event.hashes) + upload = UploadInfo(event.url, event.key, event.iv, event.hashes, event.mimetype) - panstore.save_upload(server_name, event.url, {"key": event.key, "iv": event.iv, "hashes": event.hashes}) + panstore.save_upload(server_name, event.url, {"key": event.key, "iv": event.iv, "hashes": event.hashes}, event.mimetype) upload_cache = panstore.load_upload(server_name)