This commit is contained in:
Andrea Spacca 2021-01-06 18:08:03 +01:00
parent 3d1a807f7e
commit f0ea2ebd3d
3 changed files with 62 additions and 20 deletions

View File

@ -829,7 +829,7 @@ class ProxyDaemon:
body=await response.read(),
)
async def _map_media_upload(self, content_key, content, request, client):
def _get_upload_and_media_info(self, content_key, content, request):
content_uri = content[content_key]
try:
@ -837,13 +837,11 @@ class ProxyDaemon:
except KeyError:
upload_info = self.store.load_upload(self.name, content_uri)
if not upload_info:
logger.info(f"No upload info found for {self.name}/{content_uri}")
return await self.forward_to_web(request, token=client.access_token)
self.upload_info[content_uri] = upload_info
return None, None
self.upload_info[content_uri] = upload_info
content_uri = content[content_key]
mxc = urlparse(content_uri)
mxc_server = mxc.netloc.strip("/")
mxc_path = mxc.path.strip("/")
@ -851,12 +849,19 @@ class ProxyDaemon:
logger.info(f"Adding media info for {mxc_server}/{mxc_path} to the store")
media = MediaInfo(mxc_server, mxc_path, upload_info.key, upload_info.iv, upload_info.hashes)
self.media_info[(mxc_server, mxc_path)] = media
self.store.save_media(self.name, media)
media_info = MediaInfo(mxc_server, mxc_path, upload_info.key, upload_info.iv, upload_info.hashes)
self.media_info[(mxc_server, mxc_path)] = media_info
self.store.save_media(self.name, media_info)
return upload_info, media_info, file_name
async def _map_media_upload(self, content_key, content, request, client):
try:
response, decrypted_file = await self._load_media(mxc_server, mxc_path, file_name)
upload_info, media_info, file_name = self._get_upload_and_media_info(content_key, content, request)
if not upload_info:
return await self.forward_to_web(request, token=client.access_token)
response, decrypted_file = await self._load_media(media_info.mcx_server, media_info.mxc_path, file_name)
if response is None and decrypted_file is None:
return await self.forward_to_web(request, token=client.access_token)
@ -868,11 +873,12 @@ class ProxyDaemon:
if not isinstance(response, DownloadResponse):
return await self.forward_to_web(request, token=client.access_token)
decrypted_upload, maybe_keys = await client.upload(
decrypted_upload, _ = await client.upload(
data_provider=BufferedReader(BytesIO(decrypted_file)),
content_type=response.content_type,
filename=file_name,
encrypt=False,
filesize=len(decrypted_file),
)
if not isinstance(decrypted_upload, UploadResponse):
@ -909,7 +915,6 @@ class ProxyDaemon:
msgtype = request.match_info["event_type"]
content = ""
try:
content = await request.json()
except (JSONDecodeError, ContentTypeError):
@ -931,6 +936,23 @@ class ProxyDaemon:
async def _send(ignore_unverified=False):
try:
content_msgtype = content["msgtype"]
if content_msgtype in ["m.image", "m.video", "m.audio", "m.file"] or msgtype == "m.room.avatar":
upload_info, media_info, file_name = self._get_upload_and_media_info("url", content, request)
if not upload_info:
response = await client.room_send(
room_id, msgtype, content, txnid, ignore_unverified
)
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,
headers=CORS_HEADERS,
body=await response.transport_response.read(),
)
content = media_info.to_content(file_name, content_msgtype, upload_info.mimetype),
response = await client.room_send(
room_id, msgtype, content, txnid, ignore_unverified
)
@ -1119,6 +1141,7 @@ class ProxyDaemon:
content_type=content_type,
filename=file_name,
encrypt=True,
filesize=len(body),
)
if not isinstance(response, UploadResponse):
@ -1129,7 +1152,7 @@ class ProxyDaemon:
body=await response.transport_response.read(),
)
self.store.save_upload(self.name, response.content_uri, maybe_keys)
self.store.save_upload(self.name, response.content_uri, maybe_keys, content_type)
return web.Response(
status=response.transport_response.status,
@ -1214,7 +1237,7 @@ class ProxyDaemon:
file_name = request.match_info.get("file_name")
try:
response, decrypted_file = await self._load_media(server_name, media_id, file_name, request)
response, decrypted_file = await self._load_media(server_name, media_id, file_name)
if response is None and decrypted_file is None:
return await self.forward_to_web(request)

View File

@ -15,7 +15,7 @@
import json
import os
from collections import defaultdict
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import attr
from nio.crypto import TrustState
@ -48,6 +48,23 @@ class MediaInfo:
iv = attr.ib(type=str)
hashes = attr.ib(type=dict)
def to_content(self, file_name: str, msgtype: str, mime_type: str) -> Dict[Any, Any]:
content = {
"body": file_name,
"file": {
"v": "v2",
"key": self.key,
"iv": self.iv,
"hashes": self.hashes,
"url": self.url,
"mimetype": mime_type,
}
}
if msgtype:
content["msgtype"] = msgtype
return content
@attr.s
class UploadInfo:
@ -55,6 +72,7 @@ class UploadInfo:
key = attr.ib(type=dict)
iv = attr.ib(type=str)
hashes = attr.ib(type=dict)
mimetype = attr.ib(type=str)
class DictField(TextField):
@ -186,7 +204,7 @@ class PanStore:
return None
@use_database
def save_upload(self, server, content_uri, upload):
def save_upload(self, server, content_uri, upload, mimetype):
server = Servers.get(name=server)
PanUploadInfo.insert(
@ -195,6 +213,7 @@ class PanStore:
key=upload["key"],
iv=upload["iv"],
hashes=upload["hashes"],
mimetype=mimetype,
).on_conflict_ignore().execute()
@use_database
@ -208,7 +227,7 @@ class PanStore:
if i > MAX_LOADED_UPLOAD:
break
upload = UploadInfo(u.content_uri, u.key, u.iv, u.hashes)
upload = UploadInfo(u.content_uri, u.key, u.iv, u.hashes, u.mimetype)
upload_cache[u.content_uri] = upload
return upload_cache
@ -221,7 +240,7 @@ class PanStore:
if not u:
return None
return UploadInfo(u.content_uri, u.key, u.iv, u.hashes)
return UploadInfo(u.content_uri, u.key, u.iv, u.hashes, u.mimetype)
@use_database
def save_media(self, server, media):

View File

@ -187,9 +187,9 @@ class TestClass(object):
assert not panstore.load_upload(server_name, event.url)
upload = UploadInfo(event.url, event.key, event.iv, event.hashes)
upload = UploadInfo(event.url, event.key, event.iv, event.hashes, event.mimetype)
panstore.save_upload(server_name, event.url, {"key": event.key, "iv": event.iv, "hashes": event.hashes})
panstore.save_upload(server_name, event.url, {"key": event.key, "iv": event.iv, "hashes": event.hashes}, event.mimetype)
upload_cache = panstore.load_upload(server_name)