main: Route ui messages to the correct daemon.

This commit is contained in:
Damir Jelić 2019-05-08 15:47:16 +02:00
parent 2946acbd6c
commit cc7b257345
2 changed files with 86 additions and 74 deletions

View File

@ -43,7 +43,6 @@ class ProxyDaemon:
store = attr.ib(type=PanStore, init=False) store = attr.ib(type=PanStore, init=False)
homeserver_url = attr.ib(init=False, default=attr.Factory(dict)) homeserver_url = attr.ib(init=False, default=attr.Factory(dict))
pan_clients = attr.ib(init=False, default=attr.Factory(dict)) pan_clients = attr.ib(init=False, default=attr.Factory(dict))
queue_task = attr.ib(init=False)
client_info = attr.ib( client_info = attr.ib(
init=False, init=False,
default=attr.Factory(dict), default=attr.Factory(dict),
@ -89,9 +88,6 @@ class ProxyDaemon:
pan_client.start_loop() pan_client.start_loop()
loop = asyncio.get_event_loop()
self.queue_task = loop.create_task(self.queue_loop())
async def _verify_device(self, client, device): async def _verify_device(self, client, device):
ret = client.verify_device(device) ret = client.verify_device(device)
@ -123,72 +119,53 @@ class ProxyDaemon:
message = InfoMessage(string) message = InfoMessage(string)
await self.queue.put(message) await self.queue.put(message)
async def queue_loop(self): async def receive_message(self, message):
while True: client = self.pan_clients.get(message.pan_user)
message = await self.recv_queue.get()
logger.debug(f"Daemon got message {message}")
if isinstance( if isinstance(
message, message,
(DeviceVerifyMessage, DeviceUnverifyMessage, (DeviceVerifyMessage, DeviceUnverifyMessage,
DeviceConfirmSasMessage) DeviceConfirmSasMessage)
): ):
client = self.pan_clients.get(message.pan_user, None)
if not client: device = client.device_store[message.user_id].get(
msg = f"No pan client found for {message.pan_user}." message.device_id,
logger.warn(msg) None
self.send_info(msg) )
return
device = client.device_store[message.user_id].get( if not device:
message.device_id, msg = (f"No device found for {message.user_id} and "
None f"{message.device_id}")
) await self.send_info(msg)
logger.info(msg)
return
if not device: if isinstance(message, DeviceVerifyMessage):
msg = (f"No device found for {message.user_id} and " await self._verify_device(client, device)
f"{message.device_id}") elif isinstance(message, DeviceUnverifyMessage):
await self.send_info(msg) await self._unverify_device(client, device)
logger.info(msg) elif isinstance(message, DeviceConfirmSasMessage):
return await client.confirm_sas(message)
if isinstance(message, DeviceVerifyMessage): elif isinstance(message, ExportKeysMessage):
await self._verify_device(client, device) path = os.path.abspath(message.file_path)
elif isinstance(message, DeviceUnverifyMessage): logger.info(f"Exporting keys to {path}")
await self._unverify_device(client, device)
elif isinstance(message, DeviceConfirmSasMessage):
await client.confirm_sas(message)
elif isinstance(message, ExportKeysMessage): try:
client = self.pan_clients.get(message.pan_user, None) client.export_keys(path, message.passphrase)
except OSError as e:
logger.warn(f"Error exporting keys for {client.user_id} to"
f" {path} {e}")
if not client: elif isinstance(message, ImportKeysMessage):
return path = os.path.abspath(message.file_path)
logger.info(f"Importing keys from {path}")
path = os.path.abspath(message.file_path) try:
logger.info(f"Exporting keys to {path}") client.import_keys(path, message.passphrase)
except (OSError, EncryptionError) as e:
try: logger.warn(f"Error importing keys for {client.user_id} "
client.export_keys(path, message.passphrase) f"from {path} {e}")
except OSError as e:
logger.warn(f"Error exporting keys for {client.user_id} to"
f" {path} {e}")
elif isinstance(message, ImportKeysMessage):
client = self.pan_clients.get(message.pan_user, None)
if not client:
return
path = os.path.abspath(message.file_path)
logger.info(f"Importing keys from {path}")
try:
client.import_keys(path, message.passphrase)
except (OSError, EncryptionError) as e:
logger.warn(f"Error importing keys for {client.user_id} "
f"from {path} {e}")
def get_access_token(self, request): def get_access_token(self, request):
# type: (aiohttp.web.BaseRequest) -> str # type: (aiohttp.web.BaseRequest) -> str
@ -670,5 +647,3 @@ class ProxyDaemon:
if self.default_session: if self.default_session:
await self.default_session.close() await self.default_session.close()
self.default_session = None self.default_session = None
self.queue_task.cancel()

View File

@ -5,12 +5,14 @@ import os
import click import click
import janus import janus
from typing import Optional
from appdirs import user_data_dir, user_config_dir from appdirs import user_data_dir, user_config_dir
from logbook import StderrHandler from logbook import StderrHandler
from aiohttp import web from aiohttp import web
from pantalaimon.ui import GlibT from pantalaimon.ui import GlibT, InfoMessage
from pantalaimon.daemon import ProxyDaemon from pantalaimon.daemon import ProxyDaemon
from pantalaimon.config import PanConfig, PanConfigError, parse_log_level from pantalaimon.config import PanConfig, PanConfigError, parse_log_level
from pantalaimon.log import logger from pantalaimon.log import logger
@ -67,6 +69,34 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
return proxy, runner, site return proxy, runner, site
async def message_router(receive_queue, send_queue, proxies):
"""Find the recipient of a message and forward it to the right proxy."""
def find_proxy_by_user(user):
# type: (str) -> Optional[ProxyDaemon]
for proxy in proxies:
if user in proxy.pan_clients:
return proxy
return None
async def send_info(string):
message = InfoMessage(string)
await send_queue.put(message)
while True:
message = await receive_queue.get()
logger.debug(f"Router got message {message}")
proxy = find_proxy_by_user(message.pan_user)
if not proxy:
msg = f"No pan client found for {message.pan_user}."
logger.warn(msg)
send_info(msg)
await proxy.receive_message(message)
@click.command( @click.command(
help=("pantalaimon is a reverse proxy for matrix homeservers that " help=("pantalaimon is a reverse proxy for matrix homeservers that "
"transparently encrypts and decrypts messages for clients that " "transparently encrypts and decrypts messages for clients that "
@ -113,18 +143,19 @@ def main(
ui_queue = janus.Queue(loop=loop) ui_queue = janus.Queue(loop=loop)
servers = [] servers = []
proxies = []
for server_conf in pan_conf.servers.values(): for server_conf in pan_conf.servers.values():
servers.append( proxy, runner, site = loop.run_until_complete(
loop.run_until_complete( init(
init( data_dir,
data_dir, server_conf,
server_conf, pan_queue.async_q,
pan_queue.async_q, ui_queue.async_q
ui_queue.async_q
)
) )
) )
servers.append((proxy, runner, site))
proxies.append(proxy)
glib_thread = GlibT(pan_queue.sync_q, ui_queue.sync_q, data_dir) glib_thread = GlibT(pan_queue.sync_q, ui_queue.sync_q, data_dir)
@ -137,12 +168,16 @@ def main(
glib_thread.stop() glib_thread.stop()
await fut await fut
message_router_task = loop.create_task(
message_router(ui_queue.async_q, pan_queue.async_q, proxies)
)
home = os.path.expanduser("~") home = os.path.expanduser("~")
os.chdir(home) os.chdir(home)
try: try:
for proxy, _, site in servers: for proxy, _, site in servers:
click.echo(f"======== Starting daemon for homserver " click.echo(f"======== Starting daemon for homeserver "
f"{proxy.name} on {site.name} ========") f"{proxy.name} on {site.name} ========")
loop.run_until_complete(site.start()) loop.run_until_complete(site.start())
@ -153,6 +188,8 @@ def main(
loop.run_until_complete(runner.cleanup()) loop.run_until_complete(runner.cleanup())
loop.run_until_complete(wait_for_glib(glib_thread, glib_fut)) loop.run_until_complete(wait_for_glib(glib_thread, glib_fut))
message_router_task.cancel()
loop.run_until_complete(asyncio.wait({message_router_task}))
loop.close() loop.close()