pantalaimon/pantalaimon/daemon.py

640 lines
18 KiB
Python
Executable File

#!/usr/bin/env python3
import asyncio
import json
import os
import sys
from ipaddress import ip_address
from json import JSONDecodeError
from urllib.parse import urlparse
import aiohttp
import attr
import click
import keyring
import logbook
from aiohttp import ClientSession, web
from aiohttp.client_exceptions import ContentTypeError
from appdirs import user_data_dir
from logbook import StderrHandler
from multidict import CIMultiDict
from nio import EncryptionError, GroupEncryptionError, LoginResponse
from pantalaimon.client import PanClient
from pantalaimon.log import logger
from pantalaimon.store import ClientInfo, PanStore
@attr.s
class ProxyDaemon:
    """Reverse proxy between Matrix clients and a single homeserver.

    Requests are forwarded to the configured homeserver. For every user that
    logs in through the proxy a background ``PanClient`` is kept in
    ``pan_clients``; it is used to decrypt sync and messages responses and to
    send events into rooms that the client reports as encrypted.

    Attributes:
        homeserver: Parsed homeserver URL, must provide ``geturl()`` and
            ``hostname``.
        data_dir: Directory handed to ``PanStore`` and ``PanClient``.
        proxy: Optional proxy URL used for outgoing requests.
        ssl: SSL setting passed through to outgoing requests.
    """

    homeserver = attr.ib()
    data_dir = attr.ib()
    proxy = attr.ib(default=None)
    ssl = attr.ib(default=None)

    store = attr.ib(type=PanStore, init=False)
    # Replaced by the homeserver URL string in __attrs_post_init__.
    homeserver_url = attr.ib(init=False, default=attr.Factory(dict))
    # Maps user ids to their background PanClient.
    pan_clients = attr.ib(init=False, default=attr.Factory(dict))
    # Maps access tokens to ClientInfo objects.
    client_info = attr.ib(
        init=False,
        default=attr.Factory(dict),
        type=dict
    )
    # Lazily created session used when forward_request() gets none.
    default_session = attr.ib(init=False, default=None)
    database_name = "pan.db"

    def __attrs_post_init__(self):
        """Open the store and restore a client for every stored account.

        Accounts whose access token is missing from the keyring are skipped.
        """
        self.homeserver_url = self.homeserver.geturl()
        self.hostname = self.homeserver.hostname
        self.store = PanStore(self.data_dir)

        accounts = self.store.load_users(self.hostname)
        self.client_info = self.store.load_clients(self.hostname)

        for user_id, device_id in accounts:
            token = keyring.get_password(
                "pantalaimon",
                f"{user_id}-{device_id}-token"
            )

            if not token:
                logger.warn(f"Not restoring client for {user_id} {device_id}, "
                            f"missing access token.")
                continue

            logger.info(f"Restoring client for {user_id} {device_id}")

            pan_client = PanClient(
                self.homeserver_url,
                user_id,
                device_id,
                store_path=self.data_dir,
                ssl=self.ssl,
                proxy=self.proxy
            )
            pan_client.user_id = user_id
            pan_client.access_token = token
            pan_client.load_store()
            self.pan_clients[user_id] = pan_client

            pan_client.start_loop()

    def get_access_token(self, request):
        # type: (aiohttp.web.BaseRequest) -> str
        """Extract the access token from the request.

        This method extracts the access token either from the query string or
        from the Authorization header of the request.

        Returns the access token if it was found, an empty string otherwise.
        """
        access_token = request.query.get("access_token", "")

        if access_token:
            return access_token

        header = request.headers.get("Authorization", "").strip()
        scheme, _, value = header.partition(" ")

        # Remove only the leading auth scheme. str.strip("Bearer ") would
        # treat its argument as a set of characters and also eat token
        # characters such as "a", "e" or "r" from both ends of the token.
        if scheme.lower() == "bearer":
            return value.strip()

        # No recognised scheme: hand the header value back unchanged so the
        # token lookup decides what to do with it.
        return header

    async def forward_request(
        self,
        request,      # type: aiohttp.web.BaseRequest
        params=None,  # type: CIMultiDict
        session=None  # type: aiohttp.ClientSession
    ):
        # type: (...) -> aiohttp.ClientResponse
        """Forward the given request to our configured homeserver.

        Args:
            request (aiohttp.BaseRequest): The request that should be
                forwarded.
            params (CIMultiDict): Query parameters that should be sent
                instead of the ones of the request. If None the query of the
                request is used.
            session (aiohttp.ClientSession): The client session that should be
                used to forward the request. If not given a shared default
                session is created on first use.

        Returns the response of the homeserver.
        """
        if not session:
            if not self.default_session:
                self.default_session = ClientSession()
            session = self.default_session

        path = request.path
        method = request.method

        headers = CIMultiDict(request.headers)
        headers.pop("Host", None)

        # Only fall back if no params were passed at all. An empty mapping is
        # a valid override (e.g. a sync query whose only entry was removed)
        # and must not be replaced by the original query.
        if params is None:
            params = request.query

        data = await request.text()

        return await session.request(
            method,
            self.homeserver_url + path,
            data=data,
            params=params,
            headers=headers,
            proxy=self.proxy,
            ssl=self.ssl
        )

    async def router(self, request):
        """Catchall request router."""
        resp = await self.forward_request(request)
        return await self.to_web_response(resp)

    def _get_login_user(self, body):
        """Return the user name of a login request body.

        The ``user`` field of a dict ``identifier`` is preferred, the top
        level ``user`` field is the fallback. Returns an empty string if
        neither is set.
        """
        identifier = body.get("identifier", None)
        user = None

        # A malformed (non-dict) identifier is ignored instead of crashing.
        if isinstance(identifier, dict):
            user = identifier.get("user", None)

        return user or body.get("user", "")

    async def start_pan_client(self, access_token, user, user_id, password):
        """Register a logged in client and start its background sync client.

        The access token to user id mapping is always stored. A new
        ``PanClient`` is only logged in and started if none exists yet for
        the user id.

        Args:
            access_token (str): The access token the homeserver handed out.
            user (str): The user name that was used to log in.
            user_id (str): The user id the homeserver returned.
            password (str): The password used to log in the background
                client.
        """
        client = ClientInfo(user_id, access_token)
        self.client_info[access_token] = client
        self.store.save_client(self.hostname, client)
        self.store.save_server_user(self.hostname, user_id)

        if user_id in self.pan_clients:
            logger.info(f"Background sync client already exists for {user_id},"
                        f" not starting new one")
            return

        pan_client = PanClient(
            self.homeserver_url,
            user,
            store_path=self.data_dir,
            ssl=self.ssl,
            proxy=self.proxy
        )
        response = await pan_client.login(password, "pantalaimon")

        if not isinstance(response, LoginResponse):
            logger.warn(f"Could not log in background sync client for "
                        f"{user_id}")
            await pan_client.close()
            return

        logger.info(f"Succesfully started new background sync client for "
                    f"{user_id}")

        self.pan_clients[user_id] = pan_client

        keyring.set_password(
            "pantalaimon",
            f"{user_id}-{pan_client.device_id}-token",
            pan_client.access_token
        )

        pan_client.start_loop()

    async def login(self, request):
        """Forward a login request and start a sync client if it succeeds.

        Returns the response of the homeserver, or a 500 M_NOT_JSON response
        if the request body could not be parsed as JSON.
        """
        try:
            body = await request.json()
        except (JSONDecodeError, ContentTypeError):
            # After a long debugging session the culprit ended up being aiohttp
            # and a similar bug to
            # https://github.com/aio-libs/aiohttp/issues/2277 but in the server
            # part of aiohttp. The bug is fixed in the latest master of
            # aiohttp.
            # Return 500 here for now since quaternion doesn't work otherwise.
            # After aiohttp 4.0 gets released, replace this with a 400
            # M_NOT_JSON response.
            return web.Response(
                status=500,
                text=json.dumps({
                    "errcode": "M_NOT_JSON",
                    "error": "Request did not contain valid JSON."
                })
            )

        if not isinstance(body, dict):
            # Valid JSON but not an object: there are no credentials to pick
            # up, so let the homeserver answer the request.
            return await self.to_web_response(
                await self.forward_request(request)
            )

        user = self._get_login_user(body)
        password = body.get("password", "")

        logger.info(f"New user logging in: {user}")

        response = await self.forward_request(request)

        try:
            json_response = await response.json()
        except (JSONDecodeError, ContentTypeError):
            json_response = None

        if response.status == 200 and isinstance(json_response, dict):
            user_id = json_response.get("user_id", None)
            access_token = json_response.get("access_token", None)

            if user_id and access_token:
                logger.info(f"User: {user} succesfully logged in, starting "
                            f"a background sync client.")
                await self.start_pan_client(access_token, user, user_id,
                                            password)

        return web.Response(
            status=response.status,
            text=await response.text()
        )

    @property
    def _missing_token(self):
        """A fresh 401 M_MISSING_TOKEN response."""
        return web.Response(
            status=401,
            text=json.dumps({
                "errcode": "M_MISSING_TOKEN",
                "error": "Missing access token."
            })
        )

    @property
    def _unknown_token(self):
        """A fresh 401 M_UNKNOWN_TOKEN response."""
        return web.Response(
            status=401,
            text=json.dumps({
                "errcode": "M_UNKNOWN_TOKEN",
                "error": "Unrecognised access token."
            })
        )

    @property
    def _not_json(self):
        """A fresh 400 M_NOT_JSON response."""
        return web.Response(
            status=400,
            text=json.dumps({
                "errcode": "M_NOT_JSON",
                "error": "Request did not contain valid JSON."
            })
        )

    def _get_client(self, request):
        """Find the background client belonging to the request.

        Returns a ``(client, error)`` tuple. Exactly one of the two is None:
        ``error`` is a ready to send 401 response if the access token is
        missing or not known to the daemon.
        """
        access_token = self.get_access_token(request)

        if not access_token:
            return None, self._missing_token

        try:
            client_info = self.client_info[access_token]
            client = self.pan_clients[client_info.user_id]
        except KeyError:
            return None, self._unknown_token

        return client, None

    async def _decrypted_response(self, response, decrypt):
        """Turn a homeserver response into a web response.

        The JSON body of a 200 response is passed through ``decrypt`` first.
        Any other response, or a 200 response without a JSON body, is passed
        on unchanged.
        """
        if response.status != 200:
            return await self.to_web_response(response)

        try:
            json_response = await response.json()
        except (JSONDecodeError, ContentTypeError):
            return await self.to_web_response(response)

        return web.Response(
            status=response.status,
            text=json.dumps(decrypt(json_response))
        )

    async def sync(self, request):
        """Forward a sync request and decrypt the body of the response."""
        client, error = self._get_client(request)

        if error is not None:
            return error

        # The filter of the client is not forwarded to the homeserver.
        # TODO edit the sync filter to not filter encrypted messages
        # TODO do the same with an uploaded filter
        query = CIMultiDict(request.query)
        query.pop("filter", None)

        response = await self.forward_request(request, query)

        return await self._decrypted_response(
            response,
            client.decrypt_sync_body
        )

    async def messages(self, request):
        """Forward a messages request and decrypt the body of the response."""
        client, error = self._get_client(request)

        if error is not None:
            return error

        response = await self.forward_request(request)

        return await self._decrypted_response(
            response,
            client.decrypt_messages_body
        )

    async def to_web_response(self, response):
        """Convert a client response into a web response (status and text)."""
        return web.Response(status=response.status, text=await response.text())

    async def send_message(self, request):
        """Send a message, going through the background client if needed.

        Requests for rooms that are unknown to the background client or that
        it does not report as encrypted are forwarded unchanged. For
        encrypted rooms the event is sent with the background client
        instead; if that raises a ``GroupEncryptionError`` the group session
        is shared and the send is retried once.
        """
        client, error = self._get_client(request)

        if error is not None:
            return error

        room_id = request.match_info["room_id"]

        try:
            encrypt = client.rooms[room_id].encrypted
        except KeyError:
            encrypt = False

        if not encrypt:
            return await self.to_web_response(
                await self.forward_request(request)
            )

        msgtype = request.match_info["event_type"]
        txnid = request.match_info["txnid"]

        try:
            content = await request.json()
        except (JSONDecodeError, ContentTypeError):
            return self._not_json

        try:
            response = await client.room_send(room_id, msgtype, content, txnid)
        except GroupEncryptionError:
            await client.share_group_session(room_id)
            response = await client.room_send(room_id, msgtype, content, txnid)

        return web.Response(
            status=response.transport_response.status,
            text=await response.transport_response.text()
        )

    async def shutdown(self, app):
        """Shut the daemon down closing all the client sessions it has.

        This method is called when we shut the whole app down
        """
        for client in self.pan_clients.values():
            await client.loop_stop()
            await client.close()

        if self.default_session:
            await self.default_session.close()
            self.default_session = None
async def init(homeserver, http_proxy, ssl):
    """Initialize the proxy and the http server.

    Args:
        homeserver: Parsed URL of the homeserver, passed on to ProxyDaemon.
        http_proxy: Proxy URL used to reach the homeserver, or None.
        ssl: SSL setting passed on to ProxyDaemon.

    Returns a tuple containing the ProxyDaemon and the web application
    with all routes registered.

    Raises:
        OSError: If the data directory can't be created.
    """
    data_dir = user_data_dir("pantalaimon", "")

    # An already existing directory is fine. Every other failure (for
    # example missing permissions) is reported instead of being swallowed.
    os.makedirs(data_dir, exist_ok=True)

    proxy = ProxyDaemon(homeserver, data_dir, proxy=http_proxy, ssl=ssl)

    app = web.Application()
    app.add_routes([
        web.post("/_matrix/client/r0/login", proxy.login),
        web.get("/_matrix/client/r0/sync", proxy.sync),
        web.get("/_matrix/client/r0/rooms/{room_id}/messages", proxy.messages),
        web.put(
            r"/_matrix/client/r0/rooms/{room_id}/send/{event_type}/{txnid}",
            proxy.send_message
        ),
    ])
    # Everything that has no dedicated handler goes to the catchall router.
    app.router.add_route("*", "/" + "{proxyPath:.*}", proxy.router)
    app.on_shutdown.append(proxy.shutdown)

    return proxy, app
class URL(click.ParamType):
    """Click parameter type that parses http(s) URLs with urlparse."""

    name = 'url'

    def convert(self, value, param, ctx):
        """Return the parsed URL, fail the parameter if it is not valid."""
        try:
            parsed = urlparse(value)

            if parsed.scheme not in ('http', 'https'):
                self.fail(f"Invalid URL scheme {parsed.scheme}. Only HTTP(s) "
                          "URLs are allowed")

            # Accessing the port raises a ValueError for an invalid port.
            _ = parsed.port
        except ValueError as e:
            self.fail(f"Error parsing URL: {e}")

        return parsed
class ipaddress(click.ParamType):
    """Click parameter type that converts a string with ip_address()."""

    name = "ipaddress"

    def convert(self, value, param, ctx):
        """Return the parsed address, fail the parameter if parsing fails."""
        try:
            return ip_address(value)
        except ValueError as e:
            # self.fail() raises, so nothing is returned on this path.
            self.fail(f"Error parsing ip address: {e}")
@click.group(
    help=(
        "pantalaimon is a reverse proxy for matrix homeservers that "
        "transparently encrypts and decrypts messages for clients that "
        "connect to pantalaimon."
        "\n\n"
        "HOMESERVER - the homeserver that the daemon should connect to."
    )
)
def cli():
    """Command group that all pantalaimon subcommands are attached to."""
def _find_device(user):
    """Return the device that is stored for the given user.

    Looks the user up in the accounts of the pantalaimon store. If the user
    is not found a message is printed and the process exits with a non-zero
    status, so callers and shell scripts can detect the failure.
    """
    data_dir = user_data_dir("pantalaimon", "")
    store = PanStore(data_dir)

    for user_id, device in store.load_all_users():
        if user == user_id:
            return device

    click.echo("No such user/device combination found.")
    # A bare sys.exit() would report success (status 0) for this error.
    sys.exit(1)
@cli.command(
    "keys-import",
    help="Import encryption keys into the pantalaimon store."
)
@click.argument("user", type=str)
@click.argument("infile", type=click.Path(exists=True))
@click.argument("passphrase", type=str)
def keys_import(user, infile, passphrase):
    """Import the keys in infile into the store of the user's device."""
    device = _find_device(user)
    store_dir = user_data_dir("pantalaimon", "")

    click.echo(f"Importing encryption keys for {user}, {device}...")

    # The client is only used to access the local store.
    pan_client = PanClient("", user, device, store_dir)
    pan_client.user_id = user
    pan_client.load_store()

    try:
        pan_client.import_keys(infile, passphrase)
    except (OSError, EncryptionError) as e:
        click.echo(f"Error importing keys: {e}")
    else:
        click.echo(
            f"Succesfully imported encryption keys for {user}, {device}."
        )
@cli.command(
    "keys-export",
    help="Export encryption keys from the pantalaimon store."
)
@click.argument("user", type=str)
@click.argument("outfile", type=click.Path())
@click.argument("passphrase", type=str)
def keys_export(user, outfile, passphrase):
    """Export the keys in the store of the user's device to outfile."""
    device = _find_device(user)
    store_dir = user_data_dir("pantalaimon", "")

    click.echo(f"Exporting encryption keys for {user}, {device}...")

    # The client is only used to access the local store.
    pan_client = PanClient("", user, device, store_dir)
    pan_client.user_id = user
    pan_client.load_store()

    try:
        pan_client.export_keys(outfile, passphrase)
    except OSError as e:
        click.echo(f"Error exporting keys: {e}")
    else:
        click.echo(
            f"Succesfully exported encryption keys for {user}, {device}."
        )
@cli.command("list-users", help="List the user/device pairs of the daemon")
def list_users():
    """Print every user/device pair found in the pantalaimon store."""
    store = PanStore(user_data_dir("pantalaimon", ""))
    pairs = store.load_all_users()

    click.echo("Pantalaimon users:")

    for user_id, device_id in pairs:
        click.echo(f" {user_id} - {device_id}")
@cli.command(help=("Start the daemon"))
@click.option(
    "--proxy",
    type=URL(),
    default=None,
    help="A proxy that will be used to connect to the homeserver."
)
@click.option(
    "-k",
    "--ssl-insecure/--no-ssl-insecure",
    default=False,
    help="Disable SSL verification for the homeserver connection."
)
@click.option(
    "-l",
    "--listen-address",
    type=ipaddress(),
    default=ip_address("127.0.0.1"),
    help=("The listening address for incoming client connections "
          "(default: 127.0.0.1)")
)
@click.option(
    "-p",
    "--listen-port",
    type=int,
    default=8009,
    help="The listening port for incoming client connections (default: 8009)"
)
@click.option(
    "--log-level",
    type=click.Choice(["error", "warning", "info", "debug"]),
    default="error"
)
@click.argument(
    "homeserver",
    type=URL(),
)
def start(
    proxy,
    ssl_insecure,
    listen_address,
    listen_port,
    log_level,
    homeserver
):
    """Set up logging, build the proxy application and serve it."""
    # False is handed down when verification was disabled, None otherwise.
    ssl = False if ssl_insecure is not False else None

    StderrHandler(level=log_level.upper()).push_application()

    levels = {
        "info": logbook.INFO,
        "warning": logbook.WARNING,
        "error": logbook.ERROR,
        "debug": logbook.DEBUG,
    }
    if log_level in levels:
        logger.level = levels[log_level]

    proxy_url = proxy.geturl() if proxy else None

    loop = asyncio.get_event_loop()
    _daemon, app = loop.run_until_complete(init(homeserver, proxy_url, ssl))

    web.run_app(app, host=str(listen_address), port=listen_port)
# Run the command line interface when the module is executed as a script.
if __name__ == "__main__":
    cli()