mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-02-25 09:01:27 -05:00
CR fixes
This commit is contained in:
parent
ce0fa21f94
commit
1168bcf7ff
@ -17,6 +17,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
|
from io import BufferedReader, BytesIO
|
||||||
from json import JSONDecodeError
|
from json import JSONDecodeError
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
@ -50,7 +51,7 @@ from pantalaimon.client import (
|
|||||||
)
|
)
|
||||||
from pantalaimon.index import INDEXING_ENABLED, InvalidQueryError
|
from pantalaimon.index import INDEXING_ENABLED, InvalidQueryError
|
||||||
from pantalaimon.log import logger
|
from pantalaimon.log import logger
|
||||||
from pantalaimon.store import ClientInfo, PanStore
|
from pantalaimon.store import ClientInfo, PanStore, MediaInfo
|
||||||
from pantalaimon.thread_messages import (
|
from pantalaimon.thread_messages import (
|
||||||
AcceptSasMessage,
|
AcceptSasMessage,
|
||||||
CancelSasMessage,
|
CancelSasMessage,
|
||||||
@ -826,6 +827,52 @@ class ProxyDaemon:
|
|||||||
body=await response.read(),
|
body=await response.read(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _map_media_upload(self, content, request, client):
|
||||||
|
content_uri = content["url"]
|
||||||
|
|
||||||
|
upload = self.store.load_upload(content_uri)
|
||||||
|
if upload is None:
|
||||||
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
|
||||||
|
mxc = urlparse(content_uri)
|
||||||
|
mxc_server = mxc.netloc.strip("/")
|
||||||
|
mxc_path = mxc.path.strip("/")
|
||||||
|
file_name = request.match_info.get("file_name")
|
||||||
|
|
||||||
|
logger.info(f"Adding media info for {mxc_server}/{mxc_path} to the store")
|
||||||
|
|
||||||
|
media = MediaInfo(mxc_server, mxc_path, upload.key, upload.iv, upload.hashes)
|
||||||
|
self.media_info[(mxc_server, mxc_path)] = media
|
||||||
|
client = next(iter(self.pan_clients.values()))
|
||||||
|
self.store.save_media(client.server_name, media)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response, decrypted_file, error = await self._load_media(mxc_server, mxc_path, file_name, request)
|
||||||
|
|
||||||
|
if response is None and decrypted_file is None:
|
||||||
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
except ClientConnectionError as e:
|
||||||
|
return web.Response(status=500, text=str(e))
|
||||||
|
except KeyError:
|
||||||
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
|
||||||
|
if not isinstance(response, DownloadResponse):
|
||||||
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
|
||||||
|
decrypted_upload = await client.upload(
|
||||||
|
data_provider=BufferedReader(BytesIO(decrypted_file)),
|
||||||
|
content_type=response.content_type,
|
||||||
|
filename=file_name,
|
||||||
|
encrypt=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(decrypted_upload[0], UploadResponse):
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
content["url"] = decrypted_upload[0].content_uri
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
async def send_message(self, request):
|
async def send_message(self, request):
|
||||||
access_token = self.get_access_token(request)
|
access_token = self.get_access_token(request)
|
||||||
|
|
||||||
@ -851,56 +898,27 @@ class ProxyDaemon:
|
|||||||
if request.match_info["event_type"] == "m.reaction":
|
if request.match_info["event_type"] == "m.reaction":
|
||||||
encrypt = False
|
encrypt = False
|
||||||
|
|
||||||
|
msgtype = request.match_info["event_type"]
|
||||||
|
|
||||||
# The room isn't encrypted just forward the message.
|
# The room isn't encrypted just forward the message.
|
||||||
if not encrypt:
|
if not encrypt:
|
||||||
content_msgtype = ""
|
content = ""
|
||||||
msgtype = request.match_info["event_type"]
|
|
||||||
if content_msgtype == "m.image" or content_msgtype == "m.video" \
|
|
||||||
or content_msgtype == "m.audio" or content_msgtype == "m.file":
|
|
||||||
content_uri = request.match_info["content"]["url"]
|
|
||||||
|
|
||||||
upload = self.store.load_upload(content_uri)
|
|
||||||
if upload is None:
|
|
||||||
return await self.forward_to_web(request, token=client.access_token)
|
|
||||||
|
|
||||||
mxc = urlparse(content_uri)
|
|
||||||
server_name = mxc.netloc.strip("/")
|
|
||||||
media_id = mxc.path.strip("/")
|
|
||||||
file_name = request.match_info.get("file_name")
|
|
||||||
|
|
||||||
response, decrypted_file, error = await self._load_media(server_name, media_id, file_name, request)
|
|
||||||
|
|
||||||
if response is None and decrypted_file is None and error is None:
|
|
||||||
return await self.forward_to_web(request, token=client.access_token)
|
|
||||||
if error is ClientConnectionError:
|
|
||||||
return await self.forward_to_web(request, token=client.access_token)
|
|
||||||
if error is KeyError:
|
|
||||||
return await self.forward_to_web(request, token=client.access_token)
|
|
||||||
|
|
||||||
if not isinstance(response, DownloadResponse):
|
|
||||||
return await self.forward_to_web(request, token=client.access_token)
|
|
||||||
|
|
||||||
decrypted_upload = client.upload(
|
|
||||||
data_provider=decrypted_file,
|
|
||||||
content_type=response.content_type,
|
|
||||||
filename=file_name,
|
|
||||||
encrypt=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(response, UploadResponse):
|
|
||||||
return await self.forward_to_web(request, token=client.access_token)
|
|
||||||
|
|
||||||
request.match_info["content"]["url"] = decrypted_upload.content_uri
|
|
||||||
|
|
||||||
return await self.forward_to_web(request, token=client.access_token)
|
|
||||||
|
|
||||||
txnid = request.match_info.get("txnid", uuid4())
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
content = await request.json()
|
content = await request.json()
|
||||||
except (JSONDecodeError, ContentTypeError):
|
except (JSONDecodeError, ContentTypeError):
|
||||||
return self._not_json
|
return self._not_json
|
||||||
|
|
||||||
|
content_msgtype = content["msgtype"]
|
||||||
|
if content_msgtype in ["m.image", "m.video", "m.audio", "m.file"]:
|
||||||
|
try:
|
||||||
|
content = self._map_media_upload(content, request, client)
|
||||||
|
except ValueError:
|
||||||
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
|
||||||
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
|
||||||
|
txnid = request.match_info.get("txnid", uuid4())
|
||||||
|
|
||||||
async def _send(ignore_unverified=False):
|
async def _send(ignore_unverified=False):
|
||||||
try:
|
try:
|
||||||
response = await client.room_send(
|
response = await client.room_send(
|
||||||
@ -1084,30 +1102,37 @@ class ProxyDaemon:
|
|||||||
content_type = request.headers.get("Content-Type", "application/octet-stream")
|
content_type = request.headers.get("Content-Type", "application/octet-stream")
|
||||||
client = next(iter(self.pan_clients.values()))
|
client = next(iter(self.pan_clients.values()))
|
||||||
|
|
||||||
|
body = await request.read()
|
||||||
try:
|
try:
|
||||||
response = await client.upload(
|
response = await client.upload(
|
||||||
data_provider=await request.read,
|
data_provider=BufferedReader(BytesIO(body)),
|
||||||
content_type=content_type,
|
content_type=content_type,
|
||||||
filename=file_name,
|
filename=file_name,
|
||||||
encrypt=True,
|
encrypt=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not isinstance(response, UploadResponse):
|
if not isinstance(response[0], UploadResponse):
|
||||||
return web.Response(
|
return web.Response(
|
||||||
status=response.transport_response.status,
|
status=response[0].transport_response.status,
|
||||||
content_type=response.transport_response.content_type,
|
content_type=response[0].transport_response.content_type,
|
||||||
headers=CORS_HEADERS,
|
headers=CORS_HEADERS,
|
||||||
body=await response.transport_response.read(),
|
body=await response[0].transport_response.read(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.store.save_upload(response.content_uri)
|
self.store.save_upload(response[0].content_uri, response[1])
|
||||||
|
|
||||||
|
return web.Response(
|
||||||
|
status=response[0].transport_response.status,
|
||||||
|
content_type=response[0].transport_response.content_type,
|
||||||
|
headers=CORS_HEADERS,
|
||||||
|
body=await response[0].transport_response.read(),
|
||||||
|
)
|
||||||
|
|
||||||
except ClientConnectionError as e:
|
except ClientConnectionError as e:
|
||||||
return web.Response(status=500, text=str(e))
|
return web.Response(status=500, text=str(e))
|
||||||
except SendRetryError as e:
|
except SendRetryError as e:
|
||||||
return web.Response(status=503, text=str(e))
|
return web.Response(status=503, text=str(e))
|
||||||
|
|
||||||
|
|
||||||
async def _load_media(self, server_name, media_id, file_name, request):
|
async def _load_media(self, server_name, media_id, file_name, request):
|
||||||
try:
|
try:
|
||||||
media_info = self.media_info[(server_name, media_id)]
|
media_info = self.media_info[(server_name, media_id)]
|
||||||
@ -1116,30 +1141,30 @@ class ProxyDaemon:
|
|||||||
|
|
||||||
if not media_info:
|
if not media_info:
|
||||||
logger.info(f"No media info found for {server_name}/{media_id}")
|
logger.info(f"No media info found for {server_name}/{media_id}")
|
||||||
return None, None, None
|
return None, None
|
||||||
|
|
||||||
self.media_info[(server_name, media_id)] = media_info
|
self.media_info[(server_name, media_id)] = media_info
|
||||||
|
|
||||||
try:
|
try:
|
||||||
key = media_info.key["k"]
|
key = media_info.key["k"]
|
||||||
hash = media_info.hashes["sha256"]
|
hash = media_info.hashes["sha256"]
|
||||||
except KeyError:
|
except KeyError as e:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"Media info for {server_name}/{media_id} doesn't contain a key or hash."
|
f"Media info for {server_name}/{media_id} doesn't contain a key or hash."
|
||||||
)
|
)
|
||||||
return None, None, KeyError
|
raise e
|
||||||
if not self.pan_clients:
|
if not self.pan_clients:
|
||||||
return None, None, None
|
return None, None
|
||||||
|
|
||||||
client = next(iter(self.pan_clients.values()))
|
client = next(iter(self.pan_clients.values()))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.download(server_name, media_id, file_name)
|
response = await client.download(server_name, media_id, file_name)
|
||||||
except ClientConnectionError as e:
|
except ClientConnectionError as e:
|
||||||
return None, None, e
|
raise e
|
||||||
|
|
||||||
if not isinstance(response, DownloadResponse):
|
if not isinstance(response, DownloadResponse):
|
||||||
return response, None, None
|
return response, None
|
||||||
|
|
||||||
logger.info(f"Decrypting media {server_name}/{media_id}")
|
logger.info(f"Decrypting media {server_name}/{media_id}")
|
||||||
|
|
||||||
@ -1149,7 +1174,7 @@ class ProxyDaemon:
|
|||||||
pool, decrypt_attachment, response.body, key, hash, media_info.iv
|
pool, decrypt_attachment, response.body, key, hash, media_info.iv
|
||||||
)
|
)
|
||||||
|
|
||||||
return response, decrypted_file, None
|
return response, decrypted_file
|
||||||
|
|
||||||
|
|
||||||
async def download(self, request):
|
async def download(self, request):
|
||||||
@ -1157,13 +1182,14 @@ class ProxyDaemon:
|
|||||||
media_id = request.match_info["media_id"]
|
media_id = request.match_info["media_id"]
|
||||||
file_name = request.match_info.get("file_name")
|
file_name = request.match_info.get("file_name")
|
||||||
|
|
||||||
|
try:
|
||||||
response, decrypted_file, error = await self._load_media(server_name, media_id, file_name, request)
|
response, decrypted_file, error = await self._load_media(server_name, media_id, file_name, request)
|
||||||
|
|
||||||
if response is None and decrypted_file is None and error is None:
|
if response is None and decrypted_file is None:
|
||||||
return await self.forward_to_web(request)
|
return await self.forward_to_web(request)
|
||||||
if error is ClientConnectionError:
|
except ClientConnectionError as e:
|
||||||
return web.Response(status=500, text=str(error))
|
return web.Response(status=500, text=str(e))
|
||||||
if error is KeyError:
|
except KeyError:
|
||||||
return await self.forward_to_web(request)
|
return await self.forward_to_web(request)
|
||||||
|
|
||||||
if not isinstance(response, DownloadResponse):
|
if not isinstance(response, DownloadResponse):
|
||||||
@ -1181,7 +1207,6 @@ class ProxyDaemon:
|
|||||||
body=decrypted_file,
|
body=decrypted_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def well_known(self, _):
|
async def well_known(self, _):
|
||||||
"""Intercept well-known requests
|
"""Intercept well-known requests
|
||||||
|
|
||||||
|
@ -51,6 +51,9 @@ class MediaInfo:
|
|||||||
@attr.s
|
@attr.s
|
||||||
class UploadInfo:
|
class UploadInfo:
|
||||||
content_uri = attr.ib(type=str)
|
content_uri = attr.ib(type=str)
|
||||||
|
key = attr.ib(type=dict)
|
||||||
|
iv = attr.ib(type=str)
|
||||||
|
hashes = attr.ib(type=dict)
|
||||||
|
|
||||||
|
|
||||||
class DictField(TextField):
|
class DictField(TextField):
|
||||||
@ -117,11 +120,17 @@ class PanMediaInfo(Model):
|
|||||||
class Meta:
|
class Meta:
|
||||||
constraints = [SQL("UNIQUE(server_id, mxc_server, mxc_path)")]
|
constraints = [SQL("UNIQUE(server_id, mxc_server, mxc_path)")]
|
||||||
|
|
||||||
|
|
||||||
class PanUploadInfo(Model):
|
class PanUploadInfo(Model):
|
||||||
content_uri = TextField()
|
content_uri = TextField()
|
||||||
|
key = DictField()
|
||||||
|
hashes = DictField()
|
||||||
|
iv = TextField()
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
constraints = [SQL("UNIQUE(content_uri)")]
|
constraints = [SQL("UNIQUE(content_uri)")]
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
class ClientInfo:
|
class ClientInfo:
|
||||||
user_id = attr.ib(type=str)
|
user_id = attr.ib(type=str)
|
||||||
@ -144,6 +153,7 @@ class PanStore:
|
|||||||
PanSyncTokens,
|
PanSyncTokens,
|
||||||
PanFetcherTasks,
|
PanFetcherTasks,
|
||||||
PanMediaInfo,
|
PanMediaInfo,
|
||||||
|
PanUploadInfo,
|
||||||
]
|
]
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self):
|
||||||
@ -172,9 +182,12 @@ class PanStore:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@use_database
|
@use_database
|
||||||
def save_upload(self, content_uri):
|
def save_upload(self, content_uri, media):
|
||||||
PanUploadInfo.insert(
|
PanUploadInfo.insert(
|
||||||
content_uri=content_uri,
|
content_uri=content_uri,
|
||||||
|
key=media["key"],
|
||||||
|
iv=media["iv"],
|
||||||
|
hashes=media["hashes"],
|
||||||
).on_conflict_ignore().execute()
|
).on_conflict_ignore().execute()
|
||||||
|
|
||||||
@use_database
|
@use_database
|
||||||
@ -186,7 +199,7 @@ class PanStore:
|
|||||||
if not u:
|
if not u:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return UploadInfo(u.content_uri)
|
return UploadInfo(u.content_uri, u.key, u.iv, u.hashes)
|
||||||
|
|
||||||
@use_database
|
@use_database
|
||||||
def save_media(self, server, media):
|
def save_media(self, server, media):
|
||||||
|
@ -8,7 +8,7 @@ from nio import RoomMessage, RoomEncryptedMedia
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from conftest import faker
|
from conftest import faker
|
||||||
from pantalaimon.index import INDEXING_ENABLED
|
from pantalaimon.index import INDEXING_ENABLED
|
||||||
from pantalaimon.store import FetchTask, MediaInfo
|
from pantalaimon.store import FetchTask, MediaInfo, UploadInfo
|
||||||
|
|
||||||
TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost"
|
TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost"
|
||||||
TEST_ROOM2 = "!testroom:localhost"
|
TEST_ROOM2 = "!testroom:localhost"
|
||||||
@ -177,3 +177,19 @@ class TestClass(object):
|
|||||||
media_info = media_cache[(mxc_server, mxc_path)]
|
media_info = media_cache[(mxc_server, mxc_path)]
|
||||||
assert media_info == media
|
assert media_info == media
|
||||||
assert media_info == panstore.load_media(server_name, mxc_server, mxc_path)
|
assert media_info == panstore.load_media(server_name, mxc_server, mxc_path)
|
||||||
|
|
||||||
|
def test_upload_storage(self, panstore):
|
||||||
|
event = self.encrypted_media_event
|
||||||
|
|
||||||
|
assert not panstore.load_upload(event.url)
|
||||||
|
|
||||||
|
upload = UploadInfo(event.url, event.key, event.iv, event.hashes)
|
||||||
|
|
||||||
|
panstore.save_upload(event.url)
|
||||||
|
|
||||||
|
upload_cache = panstore.load_upload(event.url)
|
||||||
|
|
||||||
|
assert (event.url) in upload_cache
|
||||||
|
upload_info = upload_cache[(event.url)]
|
||||||
|
assert upload_info == upload
|
||||||
|
assert upload_info == panstore.load_upload(event.url)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user