mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-01-23 13:51:15 -05:00
daemon: Decouple the client sync from the daemon sync.
This commit is contained in:
parent
f27eb836fe
commit
1378dca195
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from pprint import pformat
|
||||
|
||||
@ -6,7 +7,10 @@ from nio import (
|
||||
RoomEncryptedEvent,
|
||||
MegolmEvent,
|
||||
EncryptionError,
|
||||
SyncResponse
|
||||
SyncResponse,
|
||||
KeysQueryResponse,
|
||||
LocalProtocolError,
|
||||
GroupEncryptionError
|
||||
)
|
||||
|
||||
from pantalaimon.log import logger
|
||||
@ -15,6 +19,96 @@ from pantalaimon.log import logger
|
||||
class PantaClient(AsyncClient):
|
||||
"""A wrapper class around a nio AsyncClient extending its functionality."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
homeserver,
|
||||
user="",
|
||||
device_id="",
|
||||
store_path="",
|
||||
config=None,
|
||||
ssl=None,
|
||||
proxy=None
|
||||
):
|
||||
super().__init__(homeserver, user, device_id, store_path, config,
|
||||
ssl, proxy)
|
||||
|
||||
self.loop_running = False
|
||||
self.loop_stopped = asyncio.Event()
|
||||
self.synced = asyncio.Event()
|
||||
|
||||
def verify_devices(self, changed_devices):
|
||||
# Verify new devices automatically for now.
|
||||
for user_id, device_dict in changed_devices.items():
|
||||
for device in device_dict.values():
|
||||
if device.deleted:
|
||||
continue
|
||||
|
||||
logger.info("Automatically verifying device {} of "
|
||||
"user {}".format(device.id, user_id))
|
||||
self.verify_device(device)
|
||||
|
||||
async def loop(self):
|
||||
"""Start a loop that runs forever and keeps on syncing with the server.
|
||||
|
||||
The loop can be stopped with the stop_loop() method.
|
||||
"""
|
||||
self.loop_running = True
|
||||
self.loop_stopped.clear()
|
||||
|
||||
logger.info(f"Starting sync loop for {self.user_id}")
|
||||
|
||||
while self.loop_running:
|
||||
if not self.logged_in:
|
||||
# TODO login
|
||||
pass
|
||||
|
||||
# TODO use user lazy loading here
|
||||
response = await self.sync(30000)
|
||||
|
||||
if self.should_upload_keys:
|
||||
await self.keys_upload()
|
||||
|
||||
if self.should_query_keys:
|
||||
key_query_response = await self.keys_query()
|
||||
if isinstance(key_query_response, KeysQueryResponse):
|
||||
self.verify_devices(key_query_response.changed)
|
||||
|
||||
if not isinstance(response, SyncResponse):
|
||||
# TODO error handling
|
||||
pass
|
||||
|
||||
self.synced.set()
|
||||
self.synced.clear()
|
||||
|
||||
logger.info("Stopping the sync loop")
|
||||
self.loop_stopped.set()
|
||||
|
||||
async def loop_stop(self):
|
||||
"""Stop the client loop.
|
||||
|
||||
Raises LocalProtocolError if the loop isn't running.
|
||||
"""
|
||||
if not self.loop_running:
|
||||
LocalProtocolError("Loop is not running")
|
||||
|
||||
self.loop_running = False
|
||||
await self.loop_stopped.wait()
|
||||
|
||||
async def encrypt(self, room_id, msgtype, content):
|
||||
try:
|
||||
return super().encrypt(
|
||||
room_id,
|
||||
msgtype,
|
||||
content
|
||||
)
|
||||
except GroupEncryptionError:
|
||||
await self.share_group_session(room_id)
|
||||
return super().encrypt(
|
||||
room_id,
|
||||
msgtype,
|
||||
content
|
||||
)
|
||||
|
||||
def decrypt_sync_body(self, body):
|
||||
# type: (Dict[Any, Any]) -> Dict[Any, Any]
|
||||
"""Go through a json sync response and decrypt megolm encrypted events.
|
||||
@ -34,7 +128,7 @@ class PantaClient(AsyncClient):
|
||||
for event in room_dict["timeline"]["events"]:
|
||||
if event["type"] != "m.room.encrypted":
|
||||
logger.info("Event is not encrypted: "
|
||||
"{}".format(pformat(event)))
|
||||
"\n{}".format(pformat(event)))
|
||||
continue
|
||||
|
||||
parsed_event = RoomEncryptedEvent.parse_event(event)
|
||||
@ -42,7 +136,7 @@ class PantaClient(AsyncClient):
|
||||
|
||||
if not isinstance(parsed_event, MegolmEvent):
|
||||
logger.warn("Encrypted event is not a megolm event:"
|
||||
"{}".format(pformat(event)))
|
||||
"\n{}".format(pformat(event)))
|
||||
continue
|
||||
|
||||
try:
|
||||
|
@ -5,6 +5,7 @@ import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import json
|
||||
import ssl
|
||||
|
||||
import click
|
||||
from ipaddress import ip_address
|
||||
@ -26,6 +27,12 @@ from pantalaimon.client import PantaClient
|
||||
from pantalaimon.log import logger
|
||||
|
||||
|
||||
@attr.s
|
||||
class Client:
|
||||
user_id = attr.ib(type=str)
|
||||
access_token = attr.ib(type=str)
|
||||
|
||||
|
||||
@attr.s
|
||||
class ProxyDaemon:
|
||||
homeserver = attr.ib()
|
||||
@ -33,7 +40,12 @@ class ProxyDaemon:
|
||||
proxy = attr.ib(default=None)
|
||||
ssl = attr.ib(default=None)
|
||||
|
||||
client_sessions = attr.ib(init=False, default=attr.Factory(dict))
|
||||
panta_clients = attr.ib(init=False, default=attr.Factory(dict))
|
||||
client_info = attr.ib(
|
||||
init=False,
|
||||
default=attr.Factory(dict),
|
||||
type=dict
|
||||
)
|
||||
default_session = attr.ib(init=False, default=None)
|
||||
|
||||
def get_access_token(self, request):
|
||||
@ -55,7 +67,12 @@ class ProxyDaemon:
|
||||
|
||||
return access_token
|
||||
|
||||
async def forward_request(self, request, session):
|
||||
async def forward_request(
|
||||
self,
|
||||
request,
|
||||
params=None,
|
||||
session=None
|
||||
):
|
||||
# type: (aiohttp.BaseRequest, aiohttp.ClientSession) -> str
|
||||
"""Forward the given request to our configured homeserver.
|
||||
|
||||
@ -65,14 +82,21 @@ class ProxyDaemon:
|
||||
session (aiohttp.ClientSession): The client session that should be
|
||||
used to forward the request.
|
||||
"""
|
||||
|
||||
if not session:
|
||||
if not self.default_session:
|
||||
self.default_session = ClientSession()
|
||||
session = self.default_session
|
||||
|
||||
path = request.path
|
||||
method = request.method
|
||||
data = await request.text()
|
||||
|
||||
headers = CIMultiDict(request.headers)
|
||||
headers.pop("Host", None)
|
||||
|
||||
params = request.query
|
||||
params = params or request.query
|
||||
|
||||
data = await request.text()
|
||||
|
||||
return await session.request(
|
||||
method,
|
||||
@ -81,25 +105,16 @@ class ProxyDaemon:
|
||||
params=params,
|
||||
headers=headers,
|
||||
proxy=self.proxy,
|
||||
ssl=False
|
||||
ssl=self.ssl
|
||||
)
|
||||
|
||||
async def router(self, request):
|
||||
"""Catchall request router."""
|
||||
session = None
|
||||
resp = await self.forward_request(request)
|
||||
|
||||
token = self.get_access_token(request)
|
||||
client = self.client_sessions.get(token, None)
|
||||
|
||||
if client:
|
||||
session = client.client_session
|
||||
else:
|
||||
if not self.default_session:
|
||||
self.default_session = ClientSession()
|
||||
session = self.default_session
|
||||
|
||||
resp = await self.forward_request(request, session)
|
||||
return(web.Response(text=await resp.text()))
|
||||
return(
|
||||
await self.to_web_response(resp)
|
||||
)
|
||||
|
||||
def _get_login_user(self, body):
|
||||
identifier = body.get("identifier", None)
|
||||
@ -114,6 +129,35 @@ class ProxyDaemon:
|
||||
|
||||
return user
|
||||
|
||||
async def start_panta_client(self, access_token, user, user_id, password):
|
||||
client = Client(user_id, access_token)
|
||||
self.client_info[access_token] = client
|
||||
|
||||
if user_id in self.panta_clients:
|
||||
logger.info(f"Background sync client already exists for {user_id},"
|
||||
f" not starting new one")
|
||||
return
|
||||
|
||||
panta_client = PantaClient(
|
||||
self.homeserver,
|
||||
user,
|
||||
store_path=self.data_dir,
|
||||
ssl=self.ssl,
|
||||
proxy=self.proxy
|
||||
)
|
||||
response = await panta_client.login(password, "pantalaimon")
|
||||
|
||||
if not isinstance(response, LoginResponse):
|
||||
await panta_client.close()
|
||||
return
|
||||
|
||||
logger.info(f"Succesfully started new background sync client for "
|
||||
f"{user_id}")
|
||||
|
||||
self.panta_clients[user_id] = panta_client
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.create_task(panta_client.loop())
|
||||
|
||||
async def login(self, request):
|
||||
try:
|
||||
@ -137,28 +181,30 @@ class ProxyDaemon:
|
||||
|
||||
user = self._get_login_user(body)
|
||||
password = body.get("password", "")
|
||||
device_id = body.get("device_id", "")
|
||||
device_name = body.get("initial_device_display_name", "pantalaimon")
|
||||
|
||||
client = PantaClient(
|
||||
self.homeserver,
|
||||
user,
|
||||
device_id,
|
||||
store_path=self.data_dir,
|
||||
ssl=self.ssl,
|
||||
proxy=self.proxy
|
||||
)
|
||||
logger.info(f"New user logging in: {user}")
|
||||
|
||||
response = await client.login(password, device_name)
|
||||
response = await self.forward_request(request)
|
||||
|
||||
if isinstance(response, LoginResponse):
|
||||
self.client_sessions[response.access_token] = client
|
||||
else:
|
||||
await client.close()
|
||||
try:
|
||||
json_response = await response.json()
|
||||
except JSONDecodeError:
|
||||
json_response = None
|
||||
pass
|
||||
|
||||
if response.status == 200 and json_response:
|
||||
user_id = json_response.get("user_id", None)
|
||||
access_token = json_response.get("access_token", None)
|
||||
|
||||
if user_id and access_token:
|
||||
logger.info(f"User: {user} succesfully logged in, starting "
|
||||
f"a background sync client.")
|
||||
await self.start_panta_client(access_token, user, user_id,
|
||||
password)
|
||||
|
||||
return web.Response(
|
||||
status=response.transport_response.status,
|
||||
text=await response.transport_response.text()
|
||||
status=response.status,
|
||||
text=await response.text()
|
||||
)
|
||||
|
||||
@property
|
||||
@ -198,7 +244,8 @@ class ProxyDaemon:
|
||||
return self._missing_token
|
||||
|
||||
try:
|
||||
client = self.client_sessions[access_token]
|
||||
client_info = self.client_info[access_token]
|
||||
client = self.panta_clients[client_info.user_id]
|
||||
except KeyError:
|
||||
return self._unknown_token
|
||||
|
||||
@ -223,39 +270,27 @@ class ProxyDaemon:
|
||||
# if timeline_filter:
|
||||
# types_filter = timeline_filter.get("types", None)
|
||||
|
||||
response = await client.sync(timeout, sync_filter)
|
||||
query = CIMultiDict(request.query)
|
||||
query.pop("filter", None)
|
||||
|
||||
response = await self.forward_request(request, query)
|
||||
|
||||
if response.status == 200:
|
||||
json_response = await response.json()
|
||||
json_response = client.decrypt_sync_body(json_response)
|
||||
|
||||
if not isinstance(response, SyncResponse):
|
||||
return web.Response(
|
||||
status=response.transport_response.status,
|
||||
status=response.status,
|
||||
text=json.dumps(json_response)
|
||||
)
|
||||
else:
|
||||
return web.Response(
|
||||
status=response.status,
|
||||
text=await response.text()
|
||||
)
|
||||
|
||||
if client.should_upload_keys:
|
||||
await client.keys_upload()
|
||||
|
||||
if client.should_query_keys:
|
||||
key_query_response = await client.keys_query()
|
||||
|
||||
# Verify new devices automatically for now.
|
||||
if isinstance(key_query_response, KeysQueryResponse):
|
||||
for user_id, device_dict in key_query_response.changed.items():
|
||||
for device in device_dict.values():
|
||||
if device.deleted:
|
||||
continue
|
||||
|
||||
logger.info("Automatically verifying device {} of "
|
||||
"user {}".format(device.id, user_id))
|
||||
client.verify_device(device)
|
||||
|
||||
json_response = await response.transport_response.json()
|
||||
|
||||
decrypted_response = client.decrypt_sync_body(json_response)
|
||||
|
||||
return web.Response(
|
||||
status=response.transport_response.status,
|
||||
text=json.dumps(decrypted_response)
|
||||
)
|
||||
async def to_web_response(self, response):
|
||||
return web.Response(status=response.status, text=await response.text())
|
||||
|
||||
async def send_message(self, request):
|
||||
access_token = self.get_access_token(request)
|
||||
@ -264,12 +299,26 @@ class ProxyDaemon:
|
||||
return self._missing_token
|
||||
|
||||
try:
|
||||
client = self.client_sessions[access_token]
|
||||
client_info = self.client_info[access_token]
|
||||
client = self.panta_clients[client_info.user_id]
|
||||
except KeyError:
|
||||
return self._unknown_token
|
||||
|
||||
msgtype = request.match_info["event_type"]
|
||||
room_id = request.match_info["room_id"]
|
||||
|
||||
try:
|
||||
encrypt = client.rooms[room_id].encrypted
|
||||
except KeyError:
|
||||
return await self.to_web_response(
|
||||
await self.forward_request(request)
|
||||
)
|
||||
|
||||
if not encrypt:
|
||||
return await self.to_web_response(
|
||||
await self.forward_request(request)
|
||||
)
|
||||
|
||||
msgtype = request.match_info["event_type"]
|
||||
txnid = request.match_info["txnid"]
|
||||
|
||||
try:
|
||||
@ -293,7 +342,8 @@ class ProxyDaemon:
|
||||
|
||||
This method is called when we shut the whole app down
|
||||
"""
|
||||
for client in self.client_sessions.values():
|
||||
for client in self.panta_clients.values():
|
||||
await client.loop_stop()
|
||||
await client.close()
|
||||
|
||||
if self.default_session:
|
||||
|
Loading…
Reference in New Issue
Block a user