main: Route ui messages to the correct daemon.

This commit is contained in:
Damir Jelić 2019-05-08 15:47:16 +02:00
parent 2946acbd6c
commit cc7b257345
2 changed files with 86 additions and 74 deletions

View File

@ -43,7 +43,6 @@ class ProxyDaemon:
store = attr.ib(type=PanStore, init=False)
homeserver_url = attr.ib(init=False, default=attr.Factory(dict))
pan_clients = attr.ib(init=False, default=attr.Factory(dict))
queue_task = attr.ib(init=False)
client_info = attr.ib(
init=False,
default=attr.Factory(dict),
@ -89,9 +88,6 @@ class ProxyDaemon:
pan_client.start_loop()
loop = asyncio.get_event_loop()
self.queue_task = loop.create_task(self.queue_loop())
async def _verify_device(self, client, device):
ret = client.verify_device(device)
@ -123,72 +119,53 @@ class ProxyDaemon:
message = InfoMessage(string)
await self.queue.put(message)
async def queue_loop(self):
while True:
message = await self.recv_queue.get()
logger.debug(f"Daemon got message {message}")
async def receive_message(self, message):
client = self.pan_clients.get(message.pan_user)
if isinstance(
message,
(DeviceVerifyMessage, DeviceUnverifyMessage,
DeviceConfirmSasMessage)
):
client = self.pan_clients.get(message.pan_user, None)
if isinstance(
message,
(DeviceVerifyMessage, DeviceUnverifyMessage,
DeviceConfirmSasMessage)
):
if not client:
msg = f"No pan client found for {message.pan_user}."
logger.warn(msg)
self.send_info(msg)
return
device = client.device_store[message.user_id].get(
message.device_id,
None
)
device = client.device_store[message.user_id].get(
message.device_id,
None
)
if not device:
msg = (f"No device found for {message.user_id} and "
f"{message.device_id}")
await self.send_info(msg)
logger.info(msg)
return
if not device:
msg = (f"No device found for {message.user_id} and "
f"{message.device_id}")
await self.send_info(msg)
logger.info(msg)
return
if isinstance(message, DeviceVerifyMessage):
await self._verify_device(client, device)
elif isinstance(message, DeviceUnverifyMessage):
await self._unverify_device(client, device)
elif isinstance(message, DeviceConfirmSasMessage):
await client.confirm_sas(message)
if isinstance(message, DeviceVerifyMessage):
await self._verify_device(client, device)
elif isinstance(message, DeviceUnverifyMessage):
await self._unverify_device(client, device)
elif isinstance(message, DeviceConfirmSasMessage):
await client.confirm_sas(message)
elif isinstance(message, ExportKeysMessage):
path = os.path.abspath(message.file_path)
logger.info(f"Exporting keys to {path}")
elif isinstance(message, ExportKeysMessage):
client = self.pan_clients.get(message.pan_user, None)
try:
client.export_keys(path, message.passphrase)
except OSError as e:
logger.warn(f"Error exporting keys for {client.user_id} to"
f" {path} {e}")
if not client:
return
elif isinstance(message, ImportKeysMessage):
path = os.path.abspath(message.file_path)
logger.info(f"Importing keys from {path}")
path = os.path.abspath(message.file_path)
logger.info(f"Exporting keys to {path}")
try:
client.export_keys(path, message.passphrase)
except OSError as e:
logger.warn(f"Error exporting keys for {client.user_id} to"
f" {path} {e}")
elif isinstance(message, ImportKeysMessage):
client = self.pan_clients.get(message.pan_user, None)
if not client:
return
path = os.path.abspath(message.file_path)
logger.info(f"Importing keys from {path}")
try:
client.import_keys(path, message.passphrase)
except (OSError, EncryptionError) as e:
logger.warn(f"Error importing keys for {client.user_id} "
f"from {path} {e}")
try:
client.import_keys(path, message.passphrase)
except (OSError, EncryptionError) as e:
logger.warn(f"Error importing keys for {client.user_id} "
f"from {path} {e}")
def get_access_token(self, request):
# type: (aiohttp.web.BaseRequest) -> str
@ -670,5 +647,3 @@ class ProxyDaemon:
if self.default_session:
await self.default_session.close()
self.default_session = None
self.queue_task.cancel()

View File

@ -5,12 +5,14 @@ import os
import click
import janus
from typing import Optional
from appdirs import user_data_dir, user_config_dir
from logbook import StderrHandler
from aiohttp import web
from pantalaimon.ui import GlibT
from pantalaimon.ui import GlibT, InfoMessage
from pantalaimon.daemon import ProxyDaemon
from pantalaimon.config import PanConfig, PanConfigError, parse_log_level
from pantalaimon.log import logger
@ -67,6 +69,34 @@ async def init(data_dir, server_conf, send_queue, recv_queue):
return proxy, runner, site
async def message_router(receive_queue, send_queue, proxies):
"""Find the recipient of a message and forward it to the right proxy."""
def find_proxy_by_user(user):
# type: (str) -> Optional[ProxyDaemon]
for proxy in proxies:
if user in proxy.pan_clients:
return proxy
return None
async def send_info(string):
message = InfoMessage(string)
await send_queue.put(message)
while True:
message = await receive_queue.get()
logger.debug(f"Router got message {message}")
proxy = find_proxy_by_user(message.pan_user)
if not proxy:
msg = f"No pan client found for {message.pan_user}."
logger.warn(msg)
send_info(msg)
await proxy.receive_message(message)
@click.command(
help=("pantalaimon is a reverse proxy for matrix homeservers that "
"transparently encrypts and decrypts messages for clients that "
@ -113,18 +143,19 @@ def main(
ui_queue = janus.Queue(loop=loop)
servers = []
proxies = []
for server_conf in pan_conf.servers.values():
servers.append(
loop.run_until_complete(
init(
data_dir,
server_conf,
pan_queue.async_q,
ui_queue.async_q
)
proxy, runner, site = loop.run_until_complete(
init(
data_dir,
server_conf,
pan_queue.async_q,
ui_queue.async_q
)
)
servers.append((proxy, runner, site))
proxies.append(proxy)
glib_thread = GlibT(pan_queue.sync_q, ui_queue.sync_q, data_dir)
@ -137,12 +168,16 @@ def main(
glib_thread.stop()
await fut
message_router_task = loop.create_task(
message_router(ui_queue.async_q, pan_queue.async_q, proxies)
)
home = os.path.expanduser("~")
os.chdir(home)
try:
for proxy, _, site in servers:
click.echo(f"======== Starting daemon for homserver "
click.echo(f"======== Starting daemon for homeserver "
f"{proxy.name} on {site.name} ========")
loop.run_until_complete(site.start())
@ -153,6 +188,8 @@ def main(
loop.run_until_complete(runner.cleanup())
loop.run_until_complete(wait_for_glib(glib_thread, glib_fut))
message_router_task.cancel()
loop.run_until_complete(asyncio.wait({message_router_task}))
loop.close()