WIP: first commit

This commit is contained in:
Andrea Spacca 2020-12-22 02:37:27 +01:00
parent 5b1e220f5e
commit e5922da6ec
3 changed files with 128 additions and 15 deletions

View File

@ -35,6 +35,7 @@ from nio import (
OlmTrustError, OlmTrustError,
SendRetryError, SendRetryError,
DownloadResponse, DownloadResponse,
UploadResponse,
) )
from nio.crypto import decrypt_attachment from nio.crypto import decrypt_attachment
@ -851,9 +852,46 @@ class ProxyDaemon:
# The room isn't encrypted just forward the message. # The room isn't encrypted just forward the message.
if not encrypt: if not encrypt:
content_msgtype = ""
msgtype = request.match_info["event_type"]
if content_msgtype == "m.image" or content_msgtype == "m.video" \
or content_msgtype == "m.audio" or content_msgtype == "m.file":
content_uri = request.match_info["content"]["url"]
upload = self.store.load_upload(content_uri)
if upload is None:
return await self.forward_to_web(request, token=client.access_token)
server_name = request.match_info["server_name"]
media_id = content_uri
file_name = request.match_info.get("file_name")
response, decrypted_file, error = self._load_media(server_name, media_id, file_name, request)
if response is None and decrypted_file is None and error is None:
return await self.forward_to_web(request, token=client.access_token)
if error is ClientConnectionError:
return await self.forward_to_web(request, token=client.access_token)
if error is KeyError:
return await self.forward_to_web(request, token=client.access_token)
if not isinstance(response, DownloadResponse):
return await self.forward_to_web(request, token=client.access_token)
decrypted_upload = client.upload(
data_provider=decrypted_file,
content_type=response.content_type,
filename=file_name,
encrypt=False,
)
if not isinstance(response, UploadResponse):
return await self.forward_to_web(request, token=client.access_token)
request.match_info["content"]["url"] = decrypted_upload.content_uri
return await self.forward_to_web(request, token=client.access_token) return await self.forward_to_web(request, token=client.access_token)
msgtype = request.match_info["event_type"]
txnid = request.match_info.get("txnid", uuid4()) txnid = request.match_info.get("txnid", uuid4())
try: try:
@ -1039,11 +1077,36 @@ class ProxyDaemon:
return web.json_response(result, headers=CORS_HEADERS, status=200) return web.json_response(result, headers=CORS_HEADERS, status=200)
async def download(self, request): async def upload(self, request):
server_name = request.match_info["server_name"] file_name = request.query.get("filename", "")
media_id = request.match_info["media_id"] content_type = request.headers.get("Content-Type", "application/octet-stream")
file_name = request.match_info.get("file_name") client = next(iter(self.pan_clients.values()))
try:
response = await client.upload(
data_provider=await request.read,
content_type=content_type,
filename=file_name,
encrypt=True,
)
if not isinstance(response, UploadResponse):
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,
headers=CORS_HEADERS,
body=await response.transport_response.read(),
)
self.store.save_upload(response.content_uri)
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
except SendRetryError as e:
return web.Response(status=503, text=str(e))
def _load_media(self, server_name, media_id, file_name, request):
try: try:
media_info = self.media_info[(server_name, media_id)] media_info = self.media_info[(server_name, media_id)]
except KeyError: except KeyError:
@ -1062,25 +1125,19 @@ class ProxyDaemon:
logger.warn( logger.warn(
f"Media info for {server_name}/{media_id} doesn't contain a key or hash." f"Media info for {server_name}/{media_id} doesn't contain a key or hash."
) )
return await self.forward_to_web(request) return None, None, KeyError
if not self.pan_clients: if not self.pan_clients:
return await self.forward_to_web(request) return None, None, None
client = next(iter(self.pan_clients.values())) client = next(iter(self.pan_clients.values()))
try: try:
response = await client.download(server_name, media_id, file_name) response = await client.download(server_name, media_id, file_name)
except ClientConnectionError as e: except ClientConnectionError as e:
return web.Response(status=500, text=str(e)) return None, None, e
if not isinstance(response, DownloadResponse): if not isinstance(response, DownloadResponse):
return web.Response( return response, None, None
status=response.transport_response.status,
content_type=response.transport_response.content_type,
headers=CORS_HEADERS,
body=await response.transport_response.read(),
)
logger.info(f"Decrypting media {server_name}/{media_id}") logger.info(f"Decrypting media {server_name}/{media_id}")
@ -1090,6 +1147,31 @@ class ProxyDaemon:
pool, decrypt_attachment, response.body, key, hash, media_info.iv pool, decrypt_attachment, response.body, key, hash, media_info.iv
) )
return response, decrypted_file, None
async def download(self, request):
server_name = request.match_info["server_name"]
media_id = request.match_info["media_id"]
file_name = request.match_info.get("file_name")
response, decrypted_file, error = self._load_media(server_name, media_id, file_name, request)
if response is None and decrypted_file is None and error is None:
return await self.forward_to_web(request)
if error is ClientConnectionError:
return web.Response(status=500, text=str(error))
if error is KeyError:
return await self.forward_to_web(request)
if not isinstance(response, DownloadResponse):
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,
headers=CORS_HEADERS,
body=await response.transport_response.read(),
)
return web.Response( return web.Response(
status=response.transport_response.status, status=response.transport_response.status,
content_type=response.transport_response.content_type, content_type=response.transport_response.content_type,
@ -1097,6 +1179,7 @@ class ProxyDaemon:
body=decrypted_file, body=decrypted_file,
) )
async def well_known(self, _): async def well_known(self, _):
"""Intercept well-known requests """Intercept well-known requests

View File

@ -93,6 +93,10 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
"/_matrix/media/r0/download/{server_name}/{media_id}/{file_name}", "/_matrix/media/r0/download/{server_name}/{media_id}/{file_name}",
proxy.download, proxy.download,
), ),
web.post(
r"/_matrix/media/r0/upload",
proxy.upload,
),
] ]
) )
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router) app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)

View File

@ -48,6 +48,11 @@ class MediaInfo:
hashes = attr.ib(type=dict) hashes = attr.ib(type=dict)
@attr.s
class UploadInfo:
content_uri = attr.ib(type=str)
class DictField(TextField): class DictField(TextField):
def python_value(self, value): # pragma: no cover def python_value(self, value): # pragma: no cover
return json.loads(value) return json.loads(value)
@ -112,6 +117,10 @@ class PanMediaInfo(Model):
class Meta: class Meta:
constraints = [SQL("UNIQUE(server_id, mxc_server, mxc_path)")] constraints = [SQL("UNIQUE(server_id, mxc_server, mxc_path)")]
class PanUploadInfo(Model):
content_uri = TextField()
class Meta:
constraints = [SQL("UNIQUE(content_uri)")]
@attr.s @attr.s
class ClientInfo: class ClientInfo:
@ -162,6 +171,23 @@ class PanStore:
except DoesNotExist: except DoesNotExist:
return None return None
@use_database
def save_upload(self, content_uri):
PanUploadInfo.insert(
content_uri=content_uri,
).on_conflict_ignore().execute()
@use_database
def load_upload(self, content_uri):
u = PanUploadInfo.get_or_none(
PanUploadInfo.content_uri == content_uri,
)
if not u:
return None
return UploadInfo(u.content_uri)
@use_database @use_database
def save_media(self, server, media): def save_media(self, server, media):
server = Servers.get(name=server) server = Servers.get(name=server)