Merge remote-tracking branch 'origin/master' into hs/sync-options

This commit is contained in:
Will Hunt 2021-01-14 18:45:45 +00:00
commit f44456a5db
10 changed files with 346 additions and 60 deletions

View File

@ -16,14 +16,14 @@ before_install:
matrix:
include:
- python: 3.6
env: TOXENV=py36
- python: 3.7
env: TOXENV=py37
- python: 3.7
- python: 3.8
env: TOXENV=py38
- python: 3.9
env: TOXENV=py39
- python: 3.9
env: TOXENV=coverage
install: pip install tox-travis PyGObject dbus-python aioresponses
install: pip install tox-travis aioresponses
script: tox
after_success:

View File

@ -128,7 +128,7 @@ Default location of the configuration file.
The following example shows a configured pantalaimon proxy with the name
.Em Clocktown ,
the homeserver URL is set to
.Em https://example.org ,
.Em https://localhost:8448 ,
the pantalaimon proxy is listening for client connections on the address
.Em localhost ,
and port

View File

@ -111,7 +111,7 @@ overridden using appropriate environment variables.
The following example shows a configured pantalaimon proxy with the name
*Clocktown*,
the homeserver URL is set to
*https://example.org*,
*https://localhost:8448*,
the pantalaimon proxy is listening for client connections on the address
*localhost*,
and port

View File

@ -18,8 +18,10 @@ import os
import time
import urllib.parse
import concurrent.futures
from io import BufferedReader, BytesIO
from json import JSONDecodeError
from typing import Any, Dict
from urllib.parse import urlparse
from uuid import uuid4
import aiohttp
import attr
@ -35,6 +37,7 @@ from nio import (
OlmTrustError,
SendRetryError,
DownloadResponse,
UploadResponse,
)
from nio.crypto import decrypt_attachment
@ -48,7 +51,7 @@ from pantalaimon.client import (
)
from pantalaimon.index import INDEXING_ENABLED, InvalidQueryError
from pantalaimon.log import logger
from pantalaimon.store import ClientInfo, PanStore
from pantalaimon.store import ClientInfo, PanStore, MediaInfo
from pantalaimon.thread_messages import (
AcceptSasMessage,
CancelSasMessage,
@ -80,6 +83,11 @@ CORS_HEADERS = {
}
class NotDecryptedAvailableError(Exception):
"""Exception that signals that no decrypted upload is available"""
pass
@attr.s
class ProxyDaemon:
name = attr.ib()
@ -102,6 +110,7 @@ class ProxyDaemon:
client_info = attr.ib(init=False, default=attr.Factory(dict), type=dict)
default_session = attr.ib(init=False, default=None)
media_info = attr.ib(init=False, default=None)
upload_info = attr.ib(init=False, default=None)
database_name = "pan.db"
def __attrs_post_init__(self):
@ -112,6 +121,7 @@ class ProxyDaemon:
self.store = PanStore(self.data_dir)
accounts = self.store.load_users(self.name)
self.media_info = self.store.load_media(self.name)
self.upload_info = self.store.load_upload(self.name)
for user_id, device_id in accounts:
if self.conf.keyring:
@ -826,6 +836,60 @@ class ProxyDaemon:
body=await response.read(),
)
def _get_upload_and_media_info(self, content_key, content):
content_uri = content[content_key]
try:
upload_info = self.upload_info[content_uri]
except KeyError:
upload_info = self.store.load_upload(self.name, content_uri)
if not upload_info:
return None, None
self.upload_info[content_uri] = upload_info
content_uri = content[content_key]
mxc = urlparse(content_uri)
mxc_server = mxc.netloc.strip("/")
mxc_path = mxc.path.strip("/")
media_info = self.store.load_media(self.name, mxc_server, mxc_path)
if not media_info:
return None, None
self.media_info[(mxc_server, mxc_path)] = media_info
return upload_info, media_info
async def _map_decrypted_uri(self, content_key, content, request, client):
upload_info, media_info = self._get_upload_and_media_info(content_key, content)
if not upload_info or not media_info:
raise NotDecryptedAvailableError
response, decrypted_file = await self._load_decrypted_file(media_info.mxc_server, media_info.mxc_path,
upload_info.filename)
if response is None and decrypted_file is None:
raise NotDecryptedAvailableError
if not isinstance(response, DownloadResponse):
raise NotDecryptedAvailableError
decrypted_upload, _ = await client.upload(
data_provider=BufferedReader(BytesIO(decrypted_file)),
content_type=response.content_type,
filename=upload_info.filename,
encrypt=False,
filesize=len(decrypted_file),
)
if not isinstance(decrypted_upload, UploadResponse):
raise NotDecryptedAvailableError
content[content_key] = decrypted_upload.content_uri
return content
async def send_message(self, request):
access_token = self.get_access_token(request)
@ -851,23 +915,55 @@ class ProxyDaemon:
if request.match_info["event_type"] == "m.reaction":
encrypt = False
# The room isn't encrypted just forward the message.
if not encrypt:
return await self.forward_to_web(request, token=client.access_token)
msgtype = request.match_info["event_type"]
txnid = request.match_info.get("txnid", uuid4())
try:
content = await request.json()
except (JSONDecodeError, ContentTypeError):
return self._not_json
# The room isn't encrypted just forward the message.
if not encrypt:
content_msgtype = content["msgtype"]
if content_msgtype in ["m.image", "m.video", "m.audio", "m.file"] or msgtype == "m.room.avatar":
try:
content = await self._map_decrypted_uri("url", content, request, client)
return await self.forward_to_web(request, data=json.dumps(content), token=client.access_token)
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
except (KeyError, NotDecryptedAvailableError):
return await self.forward_to_web(request, token=client.access_token)
return await self.forward_to_web(request, token=client.access_token)
txnid = request.match_info.get("txnid", uuid4())
async def _send(ignore_unverified=False):
try:
response = await client.room_send(
room_id, msgtype, content, txnid, ignore_unverified
)
content_msgtype = content["msgtype"]
if content_msgtype in ["m.image", "m.video", "m.audio", "m.file"] or msgtype == "m.room.avatar":
upload_info, media_info = self._get_upload_and_media_info("url", content)
if not upload_info or not media_info:
response = await client.room_send(
room_id, msgtype, content, txnid, ignore_unverified
)
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,
headers=CORS_HEADERS,
body=await response.transport_response.read(),
)
media_content = media_info.to_content(content, upload_info.mimetype)
response = await client.room_send(
room_id, msgtype, media_content, txnid, ignore_unverified
)
else:
response = await client.room_send(
room_id, msgtype, content, txnid, ignore_unverified
)
return web.Response(
status=response.transport_response.status,
@ -1041,42 +1137,39 @@ class ProxyDaemon:
return web.json_response(result, headers=CORS_HEADERS, status=200)
async def download(self, request):
server_name = request.match_info["server_name"]
media_id = request.match_info["media_id"]
file_name = request.match_info.get("file_name")
try:
media_info = self.media_info[(server_name, media_id)]
except KeyError:
media_info = self.store.load_media(self.name, server_name, media_id)
if not media_info:
logger.info(f"No media info found for {server_name}/{media_id}")
return await self.forward_to_web(request)
self.media_info[(server_name, media_id)] = media_info
try:
key = media_info.key["k"]
hash = media_info.hashes["sha256"]
except KeyError:
logger.warn(
f"Media info for {server_name}/{media_id} doesn't contain a key or hash."
)
return await self.forward_to_web(request)
if not self.pan_clients:
return await self.forward_to_web(request)
async def upload(self, request):
file_name = request.query.get("filename", "")
content_type = request.headers.get("Content-Type", "application/octet-stream")
client = next(iter(self.pan_clients.values()))
body = await request.read()
try:
response = await client.download(server_name, media_id, file_name)
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
response, maybe_keys = await client.upload(
data_provider=BufferedReader(BytesIO(body)),
content_type=content_type,
filename=file_name,
encrypt=True,
filesize=len(body),
)
if not isinstance(response, UploadResponse):
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,
headers=CORS_HEADERS,
body=await response.transport_response.read(),
)
self.store.save_upload(self.name, response.content_uri, file_name, content_type)
mxc = urlparse(response.content_uri)
mxc_server = mxc.netloc.strip("/")
mxc_path = mxc.path.strip("/")
logger.info(f"Adding media info for {mxc_server}/{mxc_path} to the store")
media_info = MediaInfo(mxc_server, mxc_path, maybe_keys["key"], maybe_keys["iv"], maybe_keys["hashes"])
self.store.save_media(self.name, media_info)
if not isinstance(response, DownloadResponse):
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,
@ -1084,6 +1177,44 @@ class ProxyDaemon:
body=await response.transport_response.read(),
)
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
except SendRetryError as e:
return web.Response(status=503, text=str(e))
async def _load_decrypted_file(self, server_name, media_id, file_name):
try:
media_info = self.media_info[(server_name, media_id)]
except KeyError:
media_info = self.store.load_media(self.name, server_name, media_id)
if not media_info:
logger.info(f"No media info found for {server_name}/{media_id}")
return None, None
self.media_info[(server_name, media_id)] = media_info
try:
key = media_info.key["k"]
hash = media_info.hashes["sha256"]
except KeyError as e:
logger.warn(
f"Media info for {server_name}/{media_id} doesn't contain a key or hash."
)
raise e
if not self.pan_clients:
return None, None
client = next(iter(self.pan_clients.values()))
try:
response = await client.download(server_name, media_id, file_name)
except ClientConnectionError as e:
raise e
if not isinstance(response, DownloadResponse):
return response, None
logger.info(f"Decrypting media {server_name}/{media_id}")
loop = asyncio.get_running_loop()
@ -1092,6 +1223,54 @@ class ProxyDaemon:
pool, decrypt_attachment, response.body, key, hash, media_info.iv
)
return response, decrypted_file
async def profile(self, request):
access_token = self.get_access_token(request)
if not access_token:
return self._missing_token
client = await self._find_client(access_token)
if not client:
return self._unknown_token
try:
content = await request.json()
except (JSONDecodeError, ContentTypeError):
return self._not_json
try:
content = await self._map_decrypted_uri("avatar_url", content, request, client)
return await self.forward_to_web(request, data=json.dumps(content), token=client.access_token)
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
except (KeyError, NotDecryptedAvailableError):
return await self.forward_to_web(request, token=client.access_token)
async def download(self, request):
server_name = request.match_info["server_name"]
media_id = request.match_info["media_id"]
file_name = request.match_info.get("file_name")
try:
response, decrypted_file = await self._load_decrypted_file(server_name, media_id, file_name)
if response is None and decrypted_file is None:
return await self.forward_to_web(request)
except ClientConnectionError as e:
return web.Response(status=500, text=str(e))
except KeyError:
return await self.forward_to_web(request)
if not isinstance(response, DownloadResponse):
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,
headers=CORS_HEADERS,
body=await response.transport_response.read(),
)
return web.Response(
status=response.transport_response.status,
content_type=response.transport_response.content_type,

View File

@ -93,6 +93,15 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
"/_matrix/media/r0/download/{server_name}/{media_id}/{file_name}",
proxy.download,
),
web.post(
r"/_matrix/media/r0/upload",
proxy.upload,
),
web.put(
r"/_matrix/client/r0/profile/{userId}/avatar_url",
proxy.profile,
),
]
)
app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)

View File

@ -15,7 +15,7 @@
import json
import os
from collections import defaultdict
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import attr
from nio.crypto import TrustState
@ -31,6 +31,7 @@ from cachetools import LRUCache
MAX_LOADED_MEDIA = 10000
MAX_LOADED_UPLOAD = 10000
@attr.s
@ -47,6 +48,25 @@ class MediaInfo:
iv = attr.ib(type=str)
hashes = attr.ib(type=dict)
def to_content(self, content: Dict, mime_type: str) -> Dict[Any, Any]:
content["file"] = {
"v": "v2",
"key": self.key,
"iv": self.iv,
"hashes": self.hashes,
"url": content["url"],
"mimetype": mime_type,
}
return content
@attr.s
class UploadInfo:
content_uri = attr.ib(type=str)
filename = attr.ib(type=str)
mimetype = attr.ib(type=str)
class DictField(TextField):
def python_value(self, value): # pragma: no cover
@ -113,6 +133,18 @@ class PanMediaInfo(Model):
constraints = [SQL("UNIQUE(server_id, mxc_server, mxc_path)")]
class PanUploadInfo(Model):
server = ForeignKeyField(
model=Servers, column_name="server_id", backref="upload", on_delete="CASCADE"
)
content_uri = TextField()
filename = TextField()
mimetype = TextField()
class Meta:
constraints = [SQL("UNIQUE(server_id, content_uri)")]
@attr.s
class ClientInfo:
user_id = attr.ib(type=str)
@ -135,6 +167,7 @@ class PanStore:
PanSyncTokens,
PanFetcherTasks,
PanMediaInfo,
PanUploadInfo,
]
def __attrs_post_init__(self):
@ -162,6 +195,43 @@ class PanStore:
except DoesNotExist:
return None
@use_database
def save_upload(self, server, content_uri, filename, mimetype):
server = Servers.get(name=server)
PanUploadInfo.insert(
server=server,
content_uri=content_uri,
filename=filename,
mimetype=mimetype,
).on_conflict_ignore().execute()
@use_database
def load_upload(self, server, content_uri=None):
server, _ = Servers.get_or_create(name=server)
if not content_uri:
upload_cache = LRUCache(maxsize=MAX_LOADED_UPLOAD)
for i, u in enumerate(server.upload):
if i > MAX_LOADED_UPLOAD:
break
upload = UploadInfo(u.content_uri, u.filename, u.mimetype)
upload_cache[u.content_uri] = upload
return upload_cache
else:
u = PanUploadInfo.get_or_none(
PanUploadInfo.server == server,
PanUploadInfo.content_uri == content_uri,
)
if not u:
return None
return UploadInfo(u.content_uri, u.filename, u.mimetype)
@use_database
def save_media(self, server, media):
server = Servers.get(name=server)
@ -226,6 +296,7 @@ class PanStore:
user=user, room_id=task.room_id, token=task.token
).execute()
@use_database
def load_fetcher_tasks(self, server, pan_user):
server = Servers.get(name=server)
user = ServerUsers.get(server=server, user_id=pan_user)

View File

@ -30,6 +30,7 @@ if UI_ENABLED:
from gi.repository import GLib
from pydbus import SessionBus
from pydbus.generic import signal
from dbus.mainloop.glib import DBusGMainLoop
from nio import RoomKeyRequest, RoomKeyRequestCancellation
@ -447,6 +448,7 @@ if UI_ENABLED:
config = attr.ib()
loop = attr.ib(init=False)
dbus_loop = attr.ib(init=False)
store = attr.ib(init=False)
users = attr.ib(init=False)
devices = attr.ib(init=False)
@ -457,6 +459,7 @@ if UI_ENABLED:
def __attrs_post_init__(self):
self.loop = None
self.dbus_loop = None
id_counter = IdCounter()
@ -632,11 +635,12 @@ if UI_ENABLED:
return True
def run(self):
self.dbus_loop = DBusGMainLoop()
self.loop = GLib.MainLoop()
if self.config.notifications:
try:
notify2.init("pantalaimon", mainloop=self.loop)
notify2.init("pantalaimon", mainloop=self.dbus_loop)
self.notifications = True
except dbus.DBusException:
logger.error(
@ -646,6 +650,7 @@ if UI_ENABLED:
self.notifications = False
GLib.timeout_add(100, self.message_callback)
if not self.loop:
return

View File

@ -26,15 +26,15 @@ setup(
"logbook >= 1.5.3",
"peewee >= 3.13.1",
"janus >= 0.5",
"cachetools >= 3.0.0"
"prompt_toolkit>2<4",
"cachetools >= 3.0.0",
"prompt_toolkit > 2, < 4",
"typing;python_version<'3.5'",
"matrix-nio[e2e] >= 0.14, < 0.16"
],
extras_require={
"ui": [
"dbus-python >= 1.2, < 1.3",
"PyGObject >= 3.36, < 3.37",
"PyGObject >= 3.36, < 3.39",
"pydbus >= 0.6, < 0.7",
"notify2 >= 0.3, < 0.4",
]

View File

@ -8,7 +8,7 @@ from nio import RoomMessage, RoomEncryptedMedia
from urllib.parse import urlparse
from conftest import faker
from pantalaimon.index import INDEXING_ENABLED
from pantalaimon.store import FetchTask, MediaInfo
from pantalaimon.store import FetchTask, MediaInfo, UploadInfo
TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost"
TEST_ROOM2 = "!testroom:localhost"
@ -177,3 +177,25 @@ class TestClass(object):
media_info = media_cache[(mxc_server, mxc_path)]
assert media_info == media
assert media_info == panstore.load_media(server_name, mxc_server, mxc_path)
def test_upload_storage(self, panstore):
server_name = "test"
upload_cache = panstore.load_upload(server_name)
assert not upload_cache
filename = "orange_cat.jpg"
mimetype = "image/jpeg"
event = self.encrypted_media_event
assert not panstore.load_upload(server_name, event.url)
upload = UploadInfo(event.url, filename, mimetype)
panstore.save_upload(server_name, event.url, filename, mimetype)
upload_cache = panstore.load_upload(server_name)
assert (event.url) in upload_cache
upload_info = upload_cache[event.url]
assert upload_info == upload
assert upload_info == panstore.load_upload(server_name, event.url)

10
tox.ini
View File

@ -1,11 +1,11 @@
# content of: tox.ini , put in same dir as setup.py
[tox]
envlist = py36,py37,coverage
envlist = py38,py39,coverage
[testenv]
basepython =
py36: python3.6
py37: python3.7
py3: python3.7
py38: python3.8
py39: python3.9
py3: python3.9
deps = -rtest-requirements.txt
install_command = pip install {opts} {packages}
@ -15,7 +15,7 @@ commands = pytest
usedevelop = True
[testenv:coverage]
basepython = python3.7
basepython = python3.9
commands =
pytest --cov=pantalaimon --cov-report term-missing
coverage xml