This commit is contained in:
Andrea Spacca 2020-12-22 17:12:21 +01:00
parent ce0fa21f94
commit 1168bcf7ff
3 changed files with 120 additions and 66 deletions

View File

@ -17,6 +17,7 @@ import json
import os import os
import urllib.parse import urllib.parse
import concurrent.futures import concurrent.futures
from io import BufferedReader, BytesIO
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Dict from typing import Any, Dict
from urllib.parse import urlparse from urllib.parse import urlparse
@ -50,7 +51,7 @@ from pantalaimon.client import (
) )
from pantalaimon.index import INDEXING_ENABLED, InvalidQueryError from pantalaimon.index import INDEXING_ENABLED, InvalidQueryError
from pantalaimon.log import logger from pantalaimon.log import logger
from pantalaimon.store import ClientInfo, PanStore from pantalaimon.store import ClientInfo, PanStore, MediaInfo
from pantalaimon.thread_messages import ( from pantalaimon.thread_messages import (
AcceptSasMessage, AcceptSasMessage,
CancelSasMessage, CancelSasMessage,
@ -826,6 +827,52 @@ class ProxyDaemon:
body=await response.read(), body=await response.read(),
) )
def _map_media_upload(self, content, request, client):
content_uri = content["url"]
upload = self.store.load_upload(content_uri)
if upload is None:
return await self.forward_to_web(request, token=client.access_token)
mxc = urlparse(content_uri)
mxc_server = mxc.netloc.strip("/")
mxc_path = mxc.path.strip("/")
file_name = request.match_info.get("file_name")
logger.info(f"Adding media info for {mxc_server}/{mxc_path} to the store")
media = MediaInfo(mxc_server, mxc_path, upload.key, upload.iv, upload.hashes)
self.media_info[(mxc_server, mxc_path)] = media
client = next(iter(self.pan_clients.values()))
self.store.save_media(client.server_name, media)
try:
response, decrypted_file, error = await self._load_media(mxc_server, mxc_path, file_name, request)
if response is None and decrypted_file is None:
return await self.forward_to_web(request, token=client.access_token)
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
except KeyError:
return await self.forward_to_web(request, token=client.access_token)
if not isinstance(response, DownloadResponse):
return await self.forward_to_web(request, token=client.access_token)
decrypted_upload = await client.upload(
data_provider=BufferedReader(BytesIO(decrypted_file)),
content_type=response.content_type,
filename=file_name,
encrypt=False,
)
if not isinstance(decrypted_upload[0], UploadResponse):
raise ValueError
content["url"] = decrypted_upload[0].content_uri
return content
async def send_message(self, request): async def send_message(self, request):
access_token = self.get_access_token(request) access_token = self.get_access_token(request)
@ -851,56 +898,27 @@ class ProxyDaemon:
if request.match_info["event_type"] == "m.reaction": if request.match_info["event_type"] == "m.reaction":
encrypt = False encrypt = False
msgtype = request.match_info["event_type"]
# The room isn't encrypted just forward the message. # The room isn't encrypted just forward the message.
if not encrypt: if not encrypt:
content_msgtype = "" content = ""
msgtype = request.match_info["event_type"]
if content_msgtype == "m.image" or content_msgtype == "m.video" \
or content_msgtype == "m.audio" or content_msgtype == "m.file":
content_uri = request.match_info["content"]["url"]
upload = self.store.load_upload(content_uri)
if upload is None:
return await self.forward_to_web(request, token=client.access_token)
mxc = urlparse(content_uri)
server_name = mxc.netloc.strip("/")
media_id = mxc.path.strip("/")
file_name = request.match_info.get("file_name")
response, decrypted_file, error = await self._load_media(server_name, media_id, file_name, request)
if response is None and decrypted_file is None and error is None:
return await self.forward_to_web(request, token=client.access_token)
if error is ClientConnectionError:
return await self.forward_to_web(request, token=client.access_token)
if error is KeyError:
return await self.forward_to_web(request, token=client.access_token)
if not isinstance(response, DownloadResponse):
return await self.forward_to_web(request, token=client.access_token)
decrypted_upload = client.upload(
data_provider=decrypted_file,
content_type=response.content_type,
filename=file_name,
encrypt=False,
)
if not isinstance(response, UploadResponse):
return await self.forward_to_web(request, token=client.access_token)
request.match_info["content"]["url"] = decrypted_upload.content_uri
return await self.forward_to_web(request, token=client.access_token)
txnid = request.match_info.get("txnid", uuid4())
try: try:
content = await request.json() content = await request.json()
except (JSONDecodeError, ContentTypeError): except (JSONDecodeError, ContentTypeError):
return self._not_json return self._not_json
content_msgtype = content["msgtype"]
if content_msgtype in ["m.image", "m.video", "m.audio", "m.file"]:
try:
content = self._map_media_upload(content, request, client)
except ValueError:
return await self.forward_to_web(request, token=client.access_token)
return await self.forward_to_web(request, token=client.access_token)
txnid = request.match_info.get("txnid", uuid4())
async def _send(ignore_unverified=False): async def _send(ignore_unverified=False):
try: try:
response = await client.room_send( response = await client.room_send(
@ -1084,30 +1102,37 @@ class ProxyDaemon:
content_type = request.headers.get("Content-Type", "application/octet-stream") content_type = request.headers.get("Content-Type", "application/octet-stream")
client = next(iter(self.pan_clients.values())) client = next(iter(self.pan_clients.values()))
body = await request.read()
try: try:
response = await client.upload( response = await client.upload(
data_provider=await request.read, data_provider=BufferedReader(BytesIO(body)),
content_type=content_type, content_type=content_type,
filename=file_name, filename=file_name,
encrypt=True, encrypt=True,
) )
if not isinstance(response, UploadResponse): if not isinstance(response[0], UploadResponse):
return web.Response( return web.Response(
status=response.transport_response.status, status=response[0].transport_response.status,
content_type=response.transport_response.content_type, content_type=response[0].transport_response.content_type,
headers=CORS_HEADERS, headers=CORS_HEADERS,
body=await response.transport_response.read(), body=await response[0].transport_response.read(),
) )
self.store.save_upload(response.content_uri) self.store.save_upload(response[0].content_uri, response[1])
return web.Response(
status=response[0].transport_response.status,
content_type=response[0].transport_response.content_type,
headers=CORS_HEADERS,
body=await response[0].transport_response.read(),
)
except ClientConnectionError as e: except ClientConnectionError as e:
return web.Response(status=500, text=str(e)) return web.Response(status=500, text=str(e))
except SendRetryError as e: except SendRetryError as e:
return web.Response(status=503, text=str(e)) return web.Response(status=503, text=str(e))
async def _load_media(self, server_name, media_id, file_name, request): async def _load_media(self, server_name, media_id, file_name, request):
try: try:
media_info = self.media_info[(server_name, media_id)] media_info = self.media_info[(server_name, media_id)]
@ -1116,30 +1141,30 @@ class ProxyDaemon:
if not media_info: if not media_info:
logger.info(f"No media info found for {server_name}/{media_id}") logger.info(f"No media info found for {server_name}/{media_id}")
return None, None, None return None, None
self.media_info[(server_name, media_id)] = media_info self.media_info[(server_name, media_id)] = media_info
try: try:
key = media_info.key["k"] key = media_info.key["k"]
hash = media_info.hashes["sha256"] hash = media_info.hashes["sha256"]
except KeyError: except KeyError as e:
logger.warn( logger.warn(
f"Media info for {server_name}/{media_id} doesn't contain a key or hash." f"Media info for {server_name}/{media_id} doesn't contain a key or hash."
) )
return None, None, KeyError raise e
if not self.pan_clients: if not self.pan_clients:
return None, None, None return None, None
client = next(iter(self.pan_clients.values())) client = next(iter(self.pan_clients.values()))
try: try:
response = await client.download(server_name, media_id, file_name) response = await client.download(server_name, media_id, file_name)
except ClientConnectionError as e: except ClientConnectionError as e:
return None, None, e raise e
if not isinstance(response, DownloadResponse): if not isinstance(response, DownloadResponse):
return response, None, None return response, None
logger.info(f"Decrypting media {server_name}/{media_id}") logger.info(f"Decrypting media {server_name}/{media_id}")
@ -1149,7 +1174,7 @@ class ProxyDaemon:
pool, decrypt_attachment, response.body, key, hash, media_info.iv pool, decrypt_attachment, response.body, key, hash, media_info.iv
) )
return response, decrypted_file, None return response, decrypted_file
async def download(self, request): async def download(self, request):
@ -1157,13 +1182,14 @@ class ProxyDaemon:
media_id = request.match_info["media_id"] media_id = request.match_info["media_id"]
file_name = request.match_info.get("file_name") file_name = request.match_info.get("file_name")
try:
response, decrypted_file, error = await self._load_media(server_name, media_id, file_name, request) response, decrypted_file, error = await self._load_media(server_name, media_id, file_name, request)
if response is None and decrypted_file is None and error is None: if response is None and decrypted_file is None:
return await self.forward_to_web(request) return await self.forward_to_web(request)
if error is ClientConnectionError: except ClientConnectionError as e:
return web.Response(status=500, text=str(error)) return web.Response(status=500, text=str(e))
if error is KeyError: except KeyError:
return await self.forward_to_web(request) return await self.forward_to_web(request)
if not isinstance(response, DownloadResponse): if not isinstance(response, DownloadResponse):
@ -1181,7 +1207,6 @@ class ProxyDaemon:
body=decrypted_file, body=decrypted_file,
) )
async def well_known(self, _): async def well_known(self, _):
"""Intercept well-known requests """Intercept well-known requests

View File

@ -51,6 +51,9 @@ class MediaInfo:
@attr.s @attr.s
class UploadInfo: class UploadInfo:
content_uri = attr.ib(type=str) content_uri = attr.ib(type=str)
key = attr.ib(type=dict)
iv = attr.ib(type=str)
hashes = attr.ib(type=dict)
class DictField(TextField): class DictField(TextField):
@ -117,11 +120,17 @@ class PanMediaInfo(Model):
class Meta: class Meta:
constraints = [SQL("UNIQUE(server_id, mxc_server, mxc_path)")] constraints = [SQL("UNIQUE(server_id, mxc_server, mxc_path)")]
class PanUploadInfo(Model): class PanUploadInfo(Model):
content_uri = TextField() content_uri = TextField()
key = DictField()
hashes = DictField()
iv = TextField()
class Meta: class Meta:
constraints = [SQL("UNIQUE(content_uri)")] constraints = [SQL("UNIQUE(content_uri)")]
@attr.s @attr.s
class ClientInfo: class ClientInfo:
user_id = attr.ib(type=str) user_id = attr.ib(type=str)
@ -144,6 +153,7 @@ class PanStore:
PanSyncTokens, PanSyncTokens,
PanFetcherTasks, PanFetcherTasks,
PanMediaInfo, PanMediaInfo,
PanUploadInfo,
] ]
def __attrs_post_init__(self): def __attrs_post_init__(self):
@ -172,9 +182,12 @@ class PanStore:
return None return None
@use_database @use_database
def save_upload(self, content_uri): def save_upload(self, content_uri, media):
PanUploadInfo.insert( PanUploadInfo.insert(
content_uri=content_uri, content_uri=content_uri,
key=media["key"],
iv=media["iv"],
hashes=media["hashes"],
).on_conflict_ignore().execute() ).on_conflict_ignore().execute()
@use_database @use_database
@ -186,7 +199,7 @@ class PanStore:
if not u: if not u:
return None return None
return UploadInfo(u.content_uri) return UploadInfo(u.content_uri, u.key, u.iv, u.hashes)
@use_database @use_database
def save_media(self, server, media): def save_media(self, server, media):

View File

@ -8,7 +8,7 @@ from nio import RoomMessage, RoomEncryptedMedia
from urllib.parse import urlparse from urllib.parse import urlparse
from conftest import faker from conftest import faker
from pantalaimon.index import INDEXING_ENABLED from pantalaimon.index import INDEXING_ENABLED
from pantalaimon.store import FetchTask, MediaInfo from pantalaimon.store import FetchTask, MediaInfo, UploadInfo
TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost" TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost"
TEST_ROOM2 = "!testroom:localhost" TEST_ROOM2 = "!testroom:localhost"
@ -177,3 +177,19 @@ class TestClass(object):
media_info = media_cache[(mxc_server, mxc_path)] media_info = media_cache[(mxc_server, mxc_path)]
assert media_info == media assert media_info == media
assert media_info == panstore.load_media(server_name, mxc_server, mxc_path) assert media_info == panstore.load_media(server_name, mxc_server, mxc_path)
def test_upload_storage(self, panstore):
event = self.encrypted_media_event
assert not panstore.load_upload(event.url)
upload = UploadInfo(event.url, event.key, event.iv, event.hashes)
panstore.save_upload(event.url)
upload_cache = panstore.load_upload(event.url)
assert (event.url) in upload_cache
upload_info = upload_cache[(event.url)]
assert upload_info == upload
assert upload_info == panstore.load_upload(event.url)