mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-01-24 06:11:16 -05:00
WIP: first commit
This commit is contained in:
parent
5b1e220f5e
commit
e5922da6ec
@ -35,6 +35,7 @@ from nio import (
|
|||||||
OlmTrustError,
|
OlmTrustError,
|
||||||
SendRetryError,
|
SendRetryError,
|
||||||
DownloadResponse,
|
DownloadResponse,
|
||||||
|
UploadResponse,
|
||||||
)
|
)
|
||||||
from nio.crypto import decrypt_attachment
|
from nio.crypto import decrypt_attachment
|
||||||
|
|
||||||
@ -851,9 +852,46 @@ class ProxyDaemon:
|
|||||||
|
|
||||||
# The room isn't encrypted just forward the message.
|
# The room isn't encrypted just forward the message.
|
||||||
if not encrypt:
|
if not encrypt:
|
||||||
|
content_msgtype = ""
|
||||||
|
msgtype = request.match_info["event_type"]
|
||||||
|
if content_msgtype == "m.image" or content_msgtype == "m.video" \
|
||||||
|
or content_msgtype == "m.audio" or content_msgtype == "m.file":
|
||||||
|
content_uri = request.match_info["content"]["url"]
|
||||||
|
|
||||||
|
upload = self.store.load_upload(content_uri)
|
||||||
|
if upload is None:
|
||||||
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
|
||||||
|
server_name = request.match_info["server_name"]
|
||||||
|
media_id = content_uri
|
||||||
|
file_name = request.match_info.get("file_name")
|
||||||
|
|
||||||
|
response, decrypted_file, error = self._load_media(server_name, media_id, file_name, request)
|
||||||
|
|
||||||
|
if response is None and decrypted_file is None and error is None:
|
||||||
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
if error is ClientConnectionError:
|
||||||
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
if error is KeyError:
|
||||||
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
|
||||||
|
if not isinstance(response, DownloadResponse):
|
||||||
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
|
||||||
|
decrypted_upload = client.upload(
|
||||||
|
data_provider=decrypted_file,
|
||||||
|
content_type=response.content_type,
|
||||||
|
filename=file_name,
|
||||||
|
encrypt=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(response, UploadResponse):
|
||||||
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
|
||||||
|
request.match_info["content"]["url"] = decrypted_upload.content_uri
|
||||||
|
|
||||||
return await self.forward_to_web(request, token=client.access_token)
|
return await self.forward_to_web(request, token=client.access_token)
|
||||||
|
|
||||||
msgtype = request.match_info["event_type"]
|
|
||||||
txnid = request.match_info.get("txnid", uuid4())
|
txnid = request.match_info.get("txnid", uuid4())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1039,11 +1077,36 @@ class ProxyDaemon:
|
|||||||
|
|
||||||
return web.json_response(result, headers=CORS_HEADERS, status=200)
|
return web.json_response(result, headers=CORS_HEADERS, status=200)
|
||||||
|
|
||||||
async def download(self, request):
|
async def upload(self, request):
|
||||||
server_name = request.match_info["server_name"]
|
file_name = request.query.get("filename", "")
|
||||||
media_id = request.match_info["media_id"]
|
content_type = request.headers.get("Content-Type", "application/octet-stream")
|
||||||
file_name = request.match_info.get("file_name")
|
client = next(iter(self.pan_clients.values()))
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await client.upload(
|
||||||
|
data_provider=await request.read,
|
||||||
|
content_type=content_type,
|
||||||
|
filename=file_name,
|
||||||
|
encrypt=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(response, UploadResponse):
|
||||||
|
return web.Response(
|
||||||
|
status=response.transport_response.status,
|
||||||
|
content_type=response.transport_response.content_type,
|
||||||
|
headers=CORS_HEADERS,
|
||||||
|
body=await response.transport_response.read(),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.store.save_upload(response.content_uri)
|
||||||
|
|
||||||
|
except ClientConnectionError as e:
|
||||||
|
return web.Response(status=500, text=str(e))
|
||||||
|
except SendRetryError as e:
|
||||||
|
return web.Response(status=503, text=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
def _load_media(self, server_name, media_id, file_name, request):
|
||||||
try:
|
try:
|
||||||
media_info = self.media_info[(server_name, media_id)]
|
media_info = self.media_info[(server_name, media_id)]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -1062,25 +1125,19 @@ class ProxyDaemon:
|
|||||||
logger.warn(
|
logger.warn(
|
||||||
f"Media info for {server_name}/{media_id} doesn't contain a key or hash."
|
f"Media info for {server_name}/{media_id} doesn't contain a key or hash."
|
||||||
)
|
)
|
||||||
return await self.forward_to_web(request)
|
return None, None, KeyError
|
||||||
|
|
||||||
if not self.pan_clients:
|
if not self.pan_clients:
|
||||||
return await self.forward_to_web(request)
|
return None, None, None
|
||||||
|
|
||||||
client = next(iter(self.pan_clients.values()))
|
client = next(iter(self.pan_clients.values()))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await client.download(server_name, media_id, file_name)
|
response = await client.download(server_name, media_id, file_name)
|
||||||
except ClientConnectionError as e:
|
except ClientConnectionError as e:
|
||||||
return web.Response(status=500, text=str(e))
|
return None, None, e
|
||||||
|
|
||||||
if not isinstance(response, DownloadResponse):
|
if not isinstance(response, DownloadResponse):
|
||||||
return web.Response(
|
return response, None, None
|
||||||
status=response.transport_response.status,
|
|
||||||
content_type=response.transport_response.content_type,
|
|
||||||
headers=CORS_HEADERS,
|
|
||||||
body=await response.transport_response.read(),
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Decrypting media {server_name}/{media_id}")
|
logger.info(f"Decrypting media {server_name}/{media_id}")
|
||||||
|
|
||||||
@ -1090,6 +1147,31 @@ class ProxyDaemon:
|
|||||||
pool, decrypt_attachment, response.body, key, hash, media_info.iv
|
pool, decrypt_attachment, response.body, key, hash, media_info.iv
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return response, decrypted_file, None
|
||||||
|
|
||||||
|
|
||||||
|
async def download(self, request):
|
||||||
|
server_name = request.match_info["server_name"]
|
||||||
|
media_id = request.match_info["media_id"]
|
||||||
|
file_name = request.match_info.get("file_name")
|
||||||
|
|
||||||
|
response, decrypted_file, error = self._load_media(server_name, media_id, file_name, request)
|
||||||
|
|
||||||
|
if response is None and decrypted_file is None and error is None:
|
||||||
|
return await self.forward_to_web(request)
|
||||||
|
if error is ClientConnectionError:
|
||||||
|
return web.Response(status=500, text=str(error))
|
||||||
|
if error is KeyError:
|
||||||
|
return await self.forward_to_web(request)
|
||||||
|
|
||||||
|
if not isinstance(response, DownloadResponse):
|
||||||
|
return web.Response(
|
||||||
|
status=response.transport_response.status,
|
||||||
|
content_type=response.transport_response.content_type,
|
||||||
|
headers=CORS_HEADERS,
|
||||||
|
body=await response.transport_response.read(),
|
||||||
|
)
|
||||||
|
|
||||||
return web.Response(
|
return web.Response(
|
||||||
status=response.transport_response.status,
|
status=response.transport_response.status,
|
||||||
content_type=response.transport_response.content_type,
|
content_type=response.transport_response.content_type,
|
||||||
@ -1097,6 +1179,7 @@ class ProxyDaemon:
|
|||||||
body=decrypted_file,
|
body=decrypted_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def well_known(self, _):
|
async def well_known(self, _):
|
||||||
"""Intercept well-known requests
|
"""Intercept well-known requests
|
||||||
|
|
||||||
|
@ -93,6 +93,10 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
|
|||||||
"/_matrix/media/r0/download/{server_name}/{media_id}/{file_name}",
|
"/_matrix/media/r0/download/{server_name}/{media_id}/{file_name}",
|
||||||
proxy.download,
|
proxy.download,
|
||||||
),
|
),
|
||||||
|
web.post(
|
||||||
|
r"/_matrix/media/r0/upload",
|
||||||
|
proxy.upload,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
|
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
|
||||||
|
@ -48,6 +48,11 @@ class MediaInfo:
|
|||||||
hashes = attr.ib(type=dict)
|
hashes = attr.ib(type=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
class UploadInfo:
|
||||||
|
content_uri = attr.ib(type=str)
|
||||||
|
|
||||||
|
|
||||||
class DictField(TextField):
|
class DictField(TextField):
|
||||||
def python_value(self, value): # pragma: no cover
|
def python_value(self, value): # pragma: no cover
|
||||||
return json.loads(value)
|
return json.loads(value)
|
||||||
@ -112,6 +117,10 @@ class PanMediaInfo(Model):
|
|||||||
class Meta:
|
class Meta:
|
||||||
constraints = [SQL("UNIQUE(server_id, mxc_server, mxc_path)")]
|
constraints = [SQL("UNIQUE(server_id, mxc_server, mxc_path)")]
|
||||||
|
|
||||||
|
class PanUploadInfo(Model):
|
||||||
|
content_uri = TextField()
|
||||||
|
class Meta:
|
||||||
|
constraints = [SQL("UNIQUE(content_uri)")]
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
class ClientInfo:
|
class ClientInfo:
|
||||||
@ -162,6 +171,23 @@ class PanStore:
|
|||||||
except DoesNotExist:
|
except DoesNotExist:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@use_database
|
||||||
|
def save_upload(self, content_uri):
|
||||||
|
PanUploadInfo.insert(
|
||||||
|
content_uri=content_uri,
|
||||||
|
).on_conflict_ignore().execute()
|
||||||
|
|
||||||
|
@use_database
|
||||||
|
def load_upload(self, content_uri):
|
||||||
|
u = PanUploadInfo.get_or_none(
|
||||||
|
PanUploadInfo.content_uri == content_uri,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not u:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return UploadInfo(u.content_uri)
|
||||||
|
|
||||||
@use_database
|
@use_database
|
||||||
def save_media(self, server, media):
|
def save_media(self, server, media):
|
||||||
server = Servers.get(name=server)
|
server = Servers.get(name=server)
|
||||||
|
Loading…
Reference in New Issue
Block a user