mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-02-02 10:35:10 -05:00
daemon: Decouple the client sync from the daemon sync.
This commit is contained in:
parent
f27eb836fe
commit
1378dca195
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
@ -6,7 +7,10 @@ from nio import (
|
|||||||
RoomEncryptedEvent,
|
RoomEncryptedEvent,
|
||||||
MegolmEvent,
|
MegolmEvent,
|
||||||
EncryptionError,
|
EncryptionError,
|
||||||
SyncResponse
|
SyncResponse,
|
||||||
|
KeysQueryResponse,
|
||||||
|
LocalProtocolError,
|
||||||
|
GroupEncryptionError
|
||||||
)
|
)
|
||||||
|
|
||||||
from pantalaimon.log import logger
|
from pantalaimon.log import logger
|
||||||
@ -15,6 +19,96 @@ from pantalaimon.log import logger
|
|||||||
class PantaClient(AsyncClient):
|
class PantaClient(AsyncClient):
|
||||||
"""A wrapper class around a nio AsyncClient extending its functionality."""
|
"""A wrapper class around a nio AsyncClient extending its functionality."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
homeserver,
|
||||||
|
user="",
|
||||||
|
device_id="",
|
||||||
|
store_path="",
|
||||||
|
config=None,
|
||||||
|
ssl=None,
|
||||||
|
proxy=None
|
||||||
|
):
|
||||||
|
super().__init__(homeserver, user, device_id, store_path, config,
|
||||||
|
ssl, proxy)
|
||||||
|
|
||||||
|
self.loop_running = False
|
||||||
|
self.loop_stopped = asyncio.Event()
|
||||||
|
self.synced = asyncio.Event()
|
||||||
|
|
||||||
|
def verify_devices(self, changed_devices):
|
||||||
|
# Verify new devices automatically for now.
|
||||||
|
for user_id, device_dict in changed_devices.items():
|
||||||
|
for device in device_dict.values():
|
||||||
|
if device.deleted:
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.info("Automatically verifying device {} of "
|
||||||
|
"user {}".format(device.id, user_id))
|
||||||
|
self.verify_device(device)
|
||||||
|
|
||||||
|
async def loop(self):
|
||||||
|
"""Start a loop that runs forever and keeps on syncing with the server.
|
||||||
|
|
||||||
|
The loop can be stopped with the stop_loop() method.
|
||||||
|
"""
|
||||||
|
self.loop_running = True
|
||||||
|
self.loop_stopped.clear()
|
||||||
|
|
||||||
|
logger.info(f"Starting sync loop for {self.user_id}")
|
||||||
|
|
||||||
|
while self.loop_running:
|
||||||
|
if not self.logged_in:
|
||||||
|
# TODO login
|
||||||
|
pass
|
||||||
|
|
||||||
|
# TODO use user lazy loading here
|
||||||
|
response = await self.sync(30000)
|
||||||
|
|
||||||
|
if self.should_upload_keys:
|
||||||
|
await self.keys_upload()
|
||||||
|
|
||||||
|
if self.should_query_keys:
|
||||||
|
key_query_response = await self.keys_query()
|
||||||
|
if isinstance(key_query_response, KeysQueryResponse):
|
||||||
|
self.verify_devices(key_query_response.changed)
|
||||||
|
|
||||||
|
if not isinstance(response, SyncResponse):
|
||||||
|
# TODO error handling
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.synced.set()
|
||||||
|
self.synced.clear()
|
||||||
|
|
||||||
|
logger.info("Stopping the sync loop")
|
||||||
|
self.loop_stopped.set()
|
||||||
|
|
||||||
|
async def loop_stop(self):
|
||||||
|
"""Stop the client loop.
|
||||||
|
|
||||||
|
Raises LocalProtocolError if the loop isn't running.
|
||||||
|
"""
|
||||||
|
if not self.loop_running:
|
||||||
|
LocalProtocolError("Loop is not running")
|
||||||
|
|
||||||
|
self.loop_running = False
|
||||||
|
await self.loop_stopped.wait()
|
||||||
|
|
||||||
|
async def encrypt(self, room_id, msgtype, content):
|
||||||
|
try:
|
||||||
|
return super().encrypt(
|
||||||
|
room_id,
|
||||||
|
msgtype,
|
||||||
|
content
|
||||||
|
)
|
||||||
|
except GroupEncryptionError:
|
||||||
|
await self.share_group_session(room_id)
|
||||||
|
return super().encrypt(
|
||||||
|
room_id,
|
||||||
|
msgtype,
|
||||||
|
content
|
||||||
|
)
|
||||||
|
|
||||||
def decrypt_sync_body(self, body):
|
def decrypt_sync_body(self, body):
|
||||||
# type: (Dict[Any, Any]) -> Dict[Any, Any]
|
# type: (Dict[Any, Any]) -> Dict[Any, Any]
|
||||||
"""Go through a json sync response and decrypt megolm encrypted events.
|
"""Go through a json sync response and decrypt megolm encrypted events.
|
||||||
@ -34,7 +128,7 @@ class PantaClient(AsyncClient):
|
|||||||
for event in room_dict["timeline"]["events"]:
|
for event in room_dict["timeline"]["events"]:
|
||||||
if event["type"] != "m.room.encrypted":
|
if event["type"] != "m.room.encrypted":
|
||||||
logger.info("Event is not encrypted: "
|
logger.info("Event is not encrypted: "
|
||||||
"{}".format(pformat(event)))
|
"\n{}".format(pformat(event)))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
parsed_event = RoomEncryptedEvent.parse_event(event)
|
parsed_event = RoomEncryptedEvent.parse_event(event)
|
||||||
@ -42,7 +136,7 @@ class PantaClient(AsyncClient):
|
|||||||
|
|
||||||
if not isinstance(parsed_event, MegolmEvent):
|
if not isinstance(parsed_event, MegolmEvent):
|
||||||
logger.warn("Encrypted event is not a megolm event:"
|
logger.warn("Encrypted event is not a megolm event:"
|
||||||
"{}".format(pformat(event)))
|
"\n{}".format(pformat(event)))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -5,6 +5,7 @@ import asyncio
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import ssl
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from ipaddress import ip_address
|
from ipaddress import ip_address
|
||||||
@ -26,6 +27,12 @@ from pantalaimon.client import PantaClient
|
|||||||
from pantalaimon.log import logger
|
from pantalaimon.log import logger
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
class Client:
|
||||||
|
user_id = attr.ib(type=str)
|
||||||
|
access_token = attr.ib(type=str)
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
class ProxyDaemon:
|
class ProxyDaemon:
|
||||||
homeserver = attr.ib()
|
homeserver = attr.ib()
|
||||||
@ -33,7 +40,12 @@ class ProxyDaemon:
|
|||||||
proxy = attr.ib(default=None)
|
proxy = attr.ib(default=None)
|
||||||
ssl = attr.ib(default=None)
|
ssl = attr.ib(default=None)
|
||||||
|
|
||||||
client_sessions = attr.ib(init=False, default=attr.Factory(dict))
|
panta_clients = attr.ib(init=False, default=attr.Factory(dict))
|
||||||
|
client_info = attr.ib(
|
||||||
|
init=False,
|
||||||
|
default=attr.Factory(dict),
|
||||||
|
type=dict
|
||||||
|
)
|
||||||
default_session = attr.ib(init=False, default=None)
|
default_session = attr.ib(init=False, default=None)
|
||||||
|
|
||||||
def get_access_token(self, request):
|
def get_access_token(self, request):
|
||||||
@ -55,7 +67,12 @@ class ProxyDaemon:
|
|||||||
|
|
||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
async def forward_request(self, request, session):
|
async def forward_request(
|
||||||
|
self,
|
||||||
|
request,
|
||||||
|
params=None,
|
||||||
|
session=None
|
||||||
|
):
|
||||||
# type: (aiohttp.BaseRequest, aiohttp.ClientSession) -> str
|
# type: (aiohttp.BaseRequest, aiohttp.ClientSession) -> str
|
||||||
"""Forward the given request to our configured homeserver.
|
"""Forward the given request to our configured homeserver.
|
||||||
|
|
||||||
@ -65,14 +82,21 @@ class ProxyDaemon:
|
|||||||
session (aiohttp.ClientSession): The client session that should be
|
session (aiohttp.ClientSession): The client session that should be
|
||||||
used to forward the request.
|
used to forward the request.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not session:
|
||||||
|
if not self.default_session:
|
||||||
|
self.default_session = ClientSession()
|
||||||
|
session = self.default_session
|
||||||
|
|
||||||
path = request.path
|
path = request.path
|
||||||
method = request.method
|
method = request.method
|
||||||
data = await request.text()
|
|
||||||
|
|
||||||
headers = CIMultiDict(request.headers)
|
headers = CIMultiDict(request.headers)
|
||||||
headers.pop("Host", None)
|
headers.pop("Host", None)
|
||||||
|
|
||||||
params = request.query
|
params = params or request.query
|
||||||
|
|
||||||
|
data = await request.text()
|
||||||
|
|
||||||
return await session.request(
|
return await session.request(
|
||||||
method,
|
method,
|
||||||
@ -81,25 +105,16 @@ class ProxyDaemon:
|
|||||||
params=params,
|
params=params,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
proxy=self.proxy,
|
proxy=self.proxy,
|
||||||
ssl=False
|
ssl=self.ssl
|
||||||
)
|
)
|
||||||
|
|
||||||
async def router(self, request):
|
async def router(self, request):
|
||||||
"""Catchall request router."""
|
"""Catchall request router."""
|
||||||
session = None
|
resp = await self.forward_request(request)
|
||||||
|
|
||||||
token = self.get_access_token(request)
|
return(
|
||||||
client = self.client_sessions.get(token, None)
|
await self.to_web_response(resp)
|
||||||
|
)
|
||||||
if client:
|
|
||||||
session = client.client_session
|
|
||||||
else:
|
|
||||||
if not self.default_session:
|
|
||||||
self.default_session = ClientSession()
|
|
||||||
session = self.default_session
|
|
||||||
|
|
||||||
resp = await self.forward_request(request, session)
|
|
||||||
return(web.Response(text=await resp.text()))
|
|
||||||
|
|
||||||
def _get_login_user(self, body):
|
def _get_login_user(self, body):
|
||||||
identifier = body.get("identifier", None)
|
identifier = body.get("identifier", None)
|
||||||
@ -114,6 +129,35 @@ class ProxyDaemon:
|
|||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
async def start_panta_client(self, access_token, user, user_id, password):
|
||||||
|
client = Client(user_id, access_token)
|
||||||
|
self.client_info[access_token] = client
|
||||||
|
|
||||||
|
if user_id in self.panta_clients:
|
||||||
|
logger.info(f"Background sync client already exists for {user_id},"
|
||||||
|
f" not starting new one")
|
||||||
|
return
|
||||||
|
|
||||||
|
panta_client = PantaClient(
|
||||||
|
self.homeserver,
|
||||||
|
user,
|
||||||
|
store_path=self.data_dir,
|
||||||
|
ssl=self.ssl,
|
||||||
|
proxy=self.proxy
|
||||||
|
)
|
||||||
|
response = await panta_client.login(password, "pantalaimon")
|
||||||
|
|
||||||
|
if not isinstance(response, LoginResponse):
|
||||||
|
await panta_client.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Succesfully started new background sync client for "
|
||||||
|
f"{user_id}")
|
||||||
|
|
||||||
|
self.panta_clients[user_id] = panta_client
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.create_task(panta_client.loop())
|
||||||
|
|
||||||
async def login(self, request):
|
async def login(self, request):
|
||||||
try:
|
try:
|
||||||
@ -137,28 +181,30 @@ class ProxyDaemon:
|
|||||||
|
|
||||||
user = self._get_login_user(body)
|
user = self._get_login_user(body)
|
||||||
password = body.get("password", "")
|
password = body.get("password", "")
|
||||||
device_id = body.get("device_id", "")
|
|
||||||
device_name = body.get("initial_device_display_name", "pantalaimon")
|
|
||||||
|
|
||||||
client = PantaClient(
|
logger.info(f"New user logging in: {user}")
|
||||||
self.homeserver,
|
|
||||||
user,
|
|
||||||
device_id,
|
|
||||||
store_path=self.data_dir,
|
|
||||||
ssl=self.ssl,
|
|
||||||
proxy=self.proxy
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await client.login(password, device_name)
|
response = await self.forward_request(request)
|
||||||
|
|
||||||
if isinstance(response, LoginResponse):
|
try:
|
||||||
self.client_sessions[response.access_token] = client
|
json_response = await response.json()
|
||||||
else:
|
except JSONDecodeError:
|
||||||
await client.close()
|
json_response = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
if response.status == 200 and json_response:
|
||||||
|
user_id = json_response.get("user_id", None)
|
||||||
|
access_token = json_response.get("access_token", None)
|
||||||
|
|
||||||
|
if user_id and access_token:
|
||||||
|
logger.info(f"User: {user} succesfully logged in, starting "
|
||||||
|
f"a background sync client.")
|
||||||
|
await self.start_panta_client(access_token, user, user_id,
|
||||||
|
password)
|
||||||
|
|
||||||
return web.Response(
|
return web.Response(
|
||||||
status=response.transport_response.status,
|
status=response.status,
|
||||||
text=await response.transport_response.text()
|
text=await response.text()
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -198,7 +244,8 @@ class ProxyDaemon:
|
|||||||
return self._missing_token
|
return self._missing_token
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client = self.client_sessions[access_token]
|
client_info = self.client_info[access_token]
|
||||||
|
client = self.panta_clients[client_info.user_id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return self._unknown_token
|
return self._unknown_token
|
||||||
|
|
||||||
@ -223,39 +270,27 @@ class ProxyDaemon:
|
|||||||
# if timeline_filter:
|
# if timeline_filter:
|
||||||
# types_filter = timeline_filter.get("types", None)
|
# types_filter = timeline_filter.get("types", None)
|
||||||
|
|
||||||
response = await client.sync(timeout, sync_filter)
|
query = CIMultiDict(request.query)
|
||||||
|
query.pop("filter", None)
|
||||||
|
|
||||||
|
response = await self.forward_request(request, query)
|
||||||
|
|
||||||
|
if response.status == 200:
|
||||||
|
json_response = await response.json()
|
||||||
|
json_response = client.decrypt_sync_body(json_response)
|
||||||
|
|
||||||
if not isinstance(response, SyncResponse):
|
|
||||||
return web.Response(
|
return web.Response(
|
||||||
status=response.transport_response.status,
|
status=response.status,
|
||||||
|
text=json.dumps(json_response)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return web.Response(
|
||||||
|
status=response.status,
|
||||||
text=await response.text()
|
text=await response.text()
|
||||||
)
|
)
|
||||||
|
|
||||||
if client.should_upload_keys:
|
async def to_web_response(self, response):
|
||||||
await client.keys_upload()
|
return web.Response(status=response.status, text=await response.text())
|
||||||
|
|
||||||
if client.should_query_keys:
|
|
||||||
key_query_response = await client.keys_query()
|
|
||||||
|
|
||||||
# Verify new devices automatically for now.
|
|
||||||
if isinstance(key_query_response, KeysQueryResponse):
|
|
||||||
for user_id, device_dict in key_query_response.changed.items():
|
|
||||||
for device in device_dict.values():
|
|
||||||
if device.deleted:
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info("Automatically verifying device {} of "
|
|
||||||
"user {}".format(device.id, user_id))
|
|
||||||
client.verify_device(device)
|
|
||||||
|
|
||||||
json_response = await response.transport_response.json()
|
|
||||||
|
|
||||||
decrypted_response = client.decrypt_sync_body(json_response)
|
|
||||||
|
|
||||||
return web.Response(
|
|
||||||
status=response.transport_response.status,
|
|
||||||
text=json.dumps(decrypted_response)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def send_message(self, request):
|
async def send_message(self, request):
|
||||||
access_token = self.get_access_token(request)
|
access_token = self.get_access_token(request)
|
||||||
@ -264,12 +299,26 @@ class ProxyDaemon:
|
|||||||
return self._missing_token
|
return self._missing_token
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client = self.client_sessions[access_token]
|
client_info = self.client_info[access_token]
|
||||||
|
client = self.panta_clients[client_info.user_id]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return self._unknown_token
|
return self._unknown_token
|
||||||
|
|
||||||
msgtype = request.match_info["event_type"]
|
|
||||||
room_id = request.match_info["room_id"]
|
room_id = request.match_info["room_id"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
encrypt = client.rooms[room_id].encrypted
|
||||||
|
except KeyError:
|
||||||
|
return await self.to_web_response(
|
||||||
|
await self.forward_request(request)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not encrypt:
|
||||||
|
return await self.to_web_response(
|
||||||
|
await self.forward_request(request)
|
||||||
|
)
|
||||||
|
|
||||||
|
msgtype = request.match_info["event_type"]
|
||||||
txnid = request.match_info["txnid"]
|
txnid = request.match_info["txnid"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -293,7 +342,8 @@ class ProxyDaemon:
|
|||||||
|
|
||||||
This method is called when we shut the whole app down
|
This method is called when we shut the whole app down
|
||||||
"""
|
"""
|
||||||
for client in self.client_sessions.values():
|
for client in self.panta_clients.values():
|
||||||
|
await client.loop_stop()
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
if self.default_session:
|
if self.default_session:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user