diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index 55b67a2..3522db5 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -35,6 +35,7 @@ from nio import ( OlmTrustError, SendRetryError, DownloadResponse, + UploadResponse, ) from nio.crypto import decrypt_attachment @@ -851,9 +852,46 @@ class ProxyDaemon: # The room isn't encrypted just forward the message. if not encrypt: + content_msgtype = "" + msgtype = request.match_info["event_type"] + if content_msgtype == "m.image" or content_msgtype == "m.video" \ + or content_msgtype == "m.audio" or content_msgtype == "m.file": + content_uri = request.match_info["content"]["url"] + + upload = self.store.load_upload(content_uri) + if upload is None: + return await self.forward_to_web(request, token=client.access_token) + + server_name = request.match_info["server_name"] + media_id = content_uri + file_name = request.match_info.get("file_name") + + response, decrypted_file, error = self._load_media(server_name, media_id, file_name, request) + + if response is None and decrypted_file is None and error is None: + return await self.forward_to_web(request, token=client.access_token) + if error is ClientConnectionError: + return await self.forward_to_web(request, token=client.access_token) + if error is KeyError: + return await self.forward_to_web(request, token=client.access_token) + + if not isinstance(response, DownloadResponse): + return await self.forward_to_web(request, token=client.access_token) + + decrypted_upload = client.upload( + data_provider=decrypted_file, + content_type=response.content_type, + filename=file_name, + encrypt=False, + ) + + if not isinstance(response, UploadResponse): + return await self.forward_to_web(request, token=client.access_token) + + request.match_info["content"]["url"] = decrypted_upload.content_uri + return await self.forward_to_web(request, token=client.access_token) - msgtype = request.match_info["event_type"] txnid = request.match_info.get("txnid", uuid4()) try: @@ -1039,11 +1077,36 @@ class ProxyDaemon: return web.json_response(result, headers=CORS_HEADERS, status=200) - async def download(self, request): - server_name = request.match_info["server_name"] - media_id = request.match_info["media_id"] - file_name = request.match_info.get("file_name") + async def upload(self, request): + file_name = request.query.get("filename", "") + content_type = request.headers.get("Content-Type", "application/octet-stream") + client = next(iter(self.pan_clients.values())) + try: + response = await client.upload( + data_provider=await request.read, + content_type=content_type, + filename=file_name, + encrypt=True, + ) + + if not isinstance(response, UploadResponse): + return web.Response( + status=response.transport_response.status, + content_type=response.transport_response.content_type, + headers=CORS_HEADERS, + body=await response.transport_response.read(), + ) + + self.store.save_upload(response.content_uri) + + except ClientConnectionError as e: + return web.Response(status=500, text=str(e)) + except SendRetryError as e: + return web.Response(status=503, text=str(e)) + + + def _load_media(self, server_name, media_id, file_name, request): try: media_info = self.media_info[(server_name, media_id)] except KeyError: @@ -1062,25 +1125,19 @@ class ProxyDaemon: logger.warn( f"Media info for {server_name}/{media_id} doesn't contain a key or hash." ) - return await self.forward_to_web(request) - + return None, None, KeyError if not self.pan_clients: - return await self.forward_to_web(request) + return None, None, None client = next(iter(self.pan_clients.values())) try: response = await client.download(server_name, media_id, file_name) except ClientConnectionError as e: - return web.Response(status=500, text=str(e)) + return None, None, e if not isinstance(response, DownloadResponse): - return web.Response( - status=response.transport_response.status, - content_type=response.transport_response.content_type, - headers=CORS_HEADERS, - body=await response.transport_response.read(), - ) + return response, None, None logger.info(f"Decrypting media {server_name}/{media_id}") @@ -1090,6 +1147,31 @@ class ProxyDaemon: pool, decrypt_attachment, response.body, key, hash, media_info.iv ) + return response, decrypted_file, None + + + async def download(self, request): + server_name = request.match_info["server_name"] + media_id = request.match_info["media_id"] + file_name = request.match_info.get("file_name") + + response, decrypted_file, error = self._load_media(server_name, media_id, file_name, request) + + if response is None and decrypted_file is None and error is None: + return await self.forward_to_web(request) + if error is ClientConnectionError: + return web.Response(status=500, text=str(error)) + if error is KeyError: + return await self.forward_to_web(request) + + if not isinstance(response, DownloadResponse): + return web.Response( + status=response.transport_response.status, + content_type=response.transport_response.content_type, + headers=CORS_HEADERS, + body=await response.transport_response.read(), + ) + return web.Response( status=response.transport_response.status, content_type=response.transport_response.content_type, @@ -1097,6 +1179,7 @@ class ProxyDaemon: body=decrypted_file, ) + async def well_known(self, _): """Intercept well-known requests diff --git a/pantalaimon/main.py b/pantalaimon/main.py index 12311ca..b75ae21 100644 --- a/pantalaimon/main.py +++ b/pantalaimon/main.py @@ -93,6 +93,10 @@ async def init(data_dir, server_conf, send_queue, recv_queue): "/_matrix/media/r0/download/{server_name}/{media_id}/{file_name}", proxy.download, ), + web.post( + r"/_matrix/media/r0/upload", + proxy.upload, + ), ] ) app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router) diff --git a/pantalaimon/store.py b/pantalaimon/store.py index b73cb0f..24e447d 100644 --- a/pantalaimon/store.py +++ b/pantalaimon/store.py @@ -48,6 +48,11 @@ class MediaInfo: hashes = attr.ib(type=dict) +@attr.s +class UploadInfo: + content_uri = attr.ib(type=str) + + class DictField(TextField): def python_value(self, value): # pragma: no cover return json.loads(value) @@ -112,6 +117,10 @@ class PanMediaInfo(Model): class Meta: constraints = [SQL("UNIQUE(server_id, mxc_server, mxc_path)")] +class PanUploadInfo(Model): + content_uri = TextField() + class Meta: + constraints = [SQL("UNIQUE(content_uri)")] @attr.s class ClientInfo: @@ -162,6 +171,23 @@ class PanStore: except DoesNotExist: return None + @use_database + def save_upload(self, content_uri): + PanUploadInfo.insert( + content_uri=content_uri, + ).on_conflict_ignore().execute() + + @use_database + def load_upload(self, content_uri): + u = PanUploadInfo.get_or_none( + PanUploadInfo.content_uri == content_uri, + ) + + if not u: + return None + + return UploadInfo(u.content_uri) + @use_database def save_media(self, server, media): server = Servers.get(name=server)