daemon: Decouple the client sync from the daemon sync.

This commit is contained in:
Damir Jelić 2019-04-04 19:39:44 +02:00
parent f27eb836fe
commit 1378dca195
2 changed files with 214 additions and 70 deletions

View File

@ -1,3 +1,4 @@
import asyncio
from typing import Any, Dict
from pprint import pformat
@ -6,7 +7,10 @@ from nio import (
RoomEncryptedEvent,
MegolmEvent,
EncryptionError,
SyncResponse
SyncResponse,
KeysQueryResponse,
LocalProtocolError,
GroupEncryptionError
)
from pantalaimon.log import logger
@ -15,6 +19,96 @@ from pantalaimon.log import logger
class PantaClient(AsyncClient):
"""A wrapper class around a nio AsyncClient extending its functionality."""
def __init__(
self,
homeserver,
user="",
device_id="",
store_path="",
config=None,
ssl=None,
proxy=None
):
super().__init__(homeserver, user, device_id, store_path, config,
ssl, proxy)
self.loop_running = False
self.loop_stopped = asyncio.Event()
self.synced = asyncio.Event()
def verify_devices(self, changed_devices):
# Verify new devices automatically for now.
for user_id, device_dict in changed_devices.items():
for device in device_dict.values():
if device.deleted:
continue
logger.info("Automatically verifying device {} of "
"user {}".format(device.id, user_id))
self.verify_device(device)
async def loop(self):
"""Start a loop that runs forever and keeps on syncing with the server.
The loop can be stopped with the stop_loop() method.
"""
self.loop_running = True
self.loop_stopped.clear()
logger.info(f"Starting sync loop for {self.user_id}")
while self.loop_running:
if not self.logged_in:
# TODO login
pass
# TODO use user lazy loading here
response = await self.sync(30000)
if self.should_upload_keys:
await self.keys_upload()
if self.should_query_keys:
key_query_response = await self.keys_query()
if isinstance(key_query_response, KeysQueryResponse):
self.verify_devices(key_query_response.changed)
if not isinstance(response, SyncResponse):
# TODO error handling
pass
self.synced.set()
self.synced.clear()
logger.info("Stopping the sync loop")
self.loop_stopped.set()
async def loop_stop(self):
"""Stop the client loop.
Raises LocalProtocolError if the loop isn't running.
"""
if not self.loop_running:
LocalProtocolError("Loop is not running")
self.loop_running = False
await self.loop_stopped.wait()
async def encrypt(self, room_id, msgtype, content):
try:
return super().encrypt(
room_id,
msgtype,
content
)
except GroupEncryptionError:
await self.share_group_session(room_id)
return super().encrypt(
room_id,
msgtype,
content
)
def decrypt_sync_body(self, body):
# type: (Dict[Any, Any]) -> Dict[Any, Any]
"""Go through a json sync response and decrypt megolm encrypted events.
@ -34,7 +128,7 @@ class PantaClient(AsyncClient):
for event in room_dict["timeline"]["events"]:
if event["type"] != "m.room.encrypted":
logger.info("Event is not encrypted: "
"{}".format(pformat(event)))
"\n{}".format(pformat(event)))
continue
parsed_event = RoomEncryptedEvent.parse_event(event)
@ -42,7 +136,7 @@ class PantaClient(AsyncClient):
if not isinstance(parsed_event, MegolmEvent):
logger.warn("Encrypted event is not a megolm event:"
"{}".format(pformat(event)))
"\n{}".format(pformat(event)))
continue
try:

View File

@ -5,6 +5,7 @@ import asyncio
import aiohttp
import os
import json
import ssl
import click
from ipaddress import ip_address
@ -26,6 +27,12 @@ from pantalaimon.client import PantaClient
from pantalaimon.log import logger
@attr.s
class Client:
user_id = attr.ib(type=str)
access_token = attr.ib(type=str)
@attr.s
class ProxyDaemon:
homeserver = attr.ib()
@ -33,7 +40,12 @@ class ProxyDaemon:
proxy = attr.ib(default=None)
ssl = attr.ib(default=None)
client_sessions = attr.ib(init=False, default=attr.Factory(dict))
panta_clients = attr.ib(init=False, default=attr.Factory(dict))
client_info = attr.ib(
init=False,
default=attr.Factory(dict),
type=dict
)
default_session = attr.ib(init=False, default=None)
def get_access_token(self, request):
@ -55,7 +67,12 @@ class ProxyDaemon:
return access_token
async def forward_request(self, request, session):
async def forward_request(
self,
request,
params=None,
session=None
):
# type: (aiohttp.BaseRequest, aiohttp.ClientSession) -> str
"""Forward the given request to our configured homeserver.
@ -65,14 +82,21 @@ class ProxyDaemon:
session (aiohttp.ClientSession): The client session that should be
used to forward the request.
"""
if not session:
if not self.default_session:
self.default_session = ClientSession()
session = self.default_session
path = request.path
method = request.method
data = await request.text()
headers = CIMultiDict(request.headers)
headers.pop("Host", None)
params = request.query
params = params or request.query
data = await request.text()
return await session.request(
method,
@ -81,25 +105,16 @@ class ProxyDaemon:
params=params,
headers=headers,
proxy=self.proxy,
ssl=False
ssl=self.ssl
)
async def router(self, request):
"""Catchall request router."""
session = None
resp = await self.forward_request(request)
token = self.get_access_token(request)
client = self.client_sessions.get(token, None)
if client:
session = client.client_session
else:
if not self.default_session:
self.default_session = ClientSession()
session = self.default_session
resp = await self.forward_request(request, session)
return(web.Response(text=await resp.text()))
return(
await self.to_web_response(resp)
)
def _get_login_user(self, body):
identifier = body.get("identifier", None)
@ -114,6 +129,35 @@ class ProxyDaemon:
return user
async def start_panta_client(self, access_token, user, user_id, password):
client = Client(user_id, access_token)
self.client_info[access_token] = client
if user_id in self.panta_clients:
logger.info(f"Background sync client already exists for {user_id},"
f" not starting new one")
return
panta_client = PantaClient(
self.homeserver,
user,
store_path=self.data_dir,
ssl=self.ssl,
proxy=self.proxy
)
response = await panta_client.login(password, "pantalaimon")
if not isinstance(response, LoginResponse):
await panta_client.close()
return
logger.info(f"Succesfully started new background sync client for "
f"{user_id}")
self.panta_clients[user_id] = panta_client
loop = asyncio.get_event_loop()
loop.create_task(panta_client.loop())
async def login(self, request):
try:
@ -137,28 +181,30 @@ class ProxyDaemon:
user = self._get_login_user(body)
password = body.get("password", "")
device_id = body.get("device_id", "")
device_name = body.get("initial_device_display_name", "pantalaimon")
client = PantaClient(
self.homeserver,
user,
device_id,
store_path=self.data_dir,
ssl=self.ssl,
proxy=self.proxy
)
logger.info(f"New user logging in: {user}")
response = await client.login(password, device_name)
response = await self.forward_request(request)
if isinstance(response, LoginResponse):
self.client_sessions[response.access_token] = client
else:
await client.close()
try:
json_response = await response.json()
except JSONDecodeError:
json_response = None
pass
if response.status == 200 and json_response:
user_id = json_response.get("user_id", None)
access_token = json_response.get("access_token", None)
if user_id and access_token:
logger.info(f"User: {user} succesfully logged in, starting "
f"a background sync client.")
await self.start_panta_client(access_token, user, user_id,
password)
return web.Response(
status=response.transport_response.status,
text=await response.transport_response.text()
status=response.status,
text=await response.text()
)
@property
@ -198,7 +244,8 @@ class ProxyDaemon:
return self._missing_token
try:
client = self.client_sessions[access_token]
client_info = self.client_info[access_token]
client = self.panta_clients[client_info.user_id]
except KeyError:
return self._unknown_token
@ -223,39 +270,27 @@ class ProxyDaemon:
# if timeline_filter:
# types_filter = timeline_filter.get("types", None)
response = await client.sync(timeout, sync_filter)
query = CIMultiDict(request.query)
query.pop("filter", None)
response = await self.forward_request(request, query)
if response.status == 200:
json_response = await response.json()
json_response = client.decrypt_sync_body(json_response)
if not isinstance(response, SyncResponse):
return web.Response(
status=response.transport_response.status,
status=response.status,
text=json.dumps(json_response)
)
else:
return web.Response(
status=response.status,
text=await response.text()
)
if client.should_upload_keys:
await client.keys_upload()
if client.should_query_keys:
key_query_response = await client.keys_query()
# Verify new devices automatically for now.
if isinstance(key_query_response, KeysQueryResponse):
for user_id, device_dict in key_query_response.changed.items():
for device in device_dict.values():
if device.deleted:
continue
logger.info("Automatically verifying device {} of "
"user {}".format(device.id, user_id))
client.verify_device(device)
json_response = await response.transport_response.json()
decrypted_response = client.decrypt_sync_body(json_response)
return web.Response(
status=response.transport_response.status,
text=json.dumps(decrypted_response)
)
async def to_web_response(self, response):
return web.Response(status=response.status, text=await response.text())
async def send_message(self, request):
access_token = self.get_access_token(request)
@ -264,12 +299,26 @@ class ProxyDaemon:
return self._missing_token
try:
client = self.client_sessions[access_token]
client_info = self.client_info[access_token]
client = self.panta_clients[client_info.user_id]
except KeyError:
return self._unknown_token
msgtype = request.match_info["event_type"]
room_id = request.match_info["room_id"]
try:
encrypt = client.rooms[room_id].encrypted
except KeyError:
return await self.to_web_response(
await self.forward_request(request)
)
if not encrypt:
return await self.to_web_response(
await self.forward_request(request)
)
msgtype = request.match_info["event_type"]
txnid = request.match_info["txnid"]
try:
@ -293,7 +342,8 @@ class ProxyDaemon:
This method is called when we shut the whole app down
"""
for client in self.client_sessions.values():
for client in self.panta_clients.values():
await client.loop_stop()
await client.close()
if self.default_session: