diff --git a/pantalaimon/client.py b/pantalaimon/client.py index f2f6895..84e6121 100644 --- a/pantalaimon/client.py +++ b/pantalaimon/client.py @@ -14,10 +14,10 @@ import asyncio import os +import time from collections import defaultdict from pprint import pformat from urllib.parse import urlparse - from aiohttp.client_exceptions import ClientConnectionError from jsonschema import Draft4Validator, FormatChecker, validators from playhouse.sqliteq import SqliteQueueDatabase @@ -49,10 +49,9 @@ from nio import ( ) from nio.crypto import Sas from nio.store import SqliteStore - from pantalaimon.index import INDEXING_ENABLED from pantalaimon.log import logger -from pantalaimon.store import FetchTask, MediaInfo +from pantalaimon.store import FetchTask, MediaInfo, PanSqliteStore from pantalaimon.thread_messages import ( DaemonResponse, InviteSasSignal, @@ -161,7 +160,7 @@ class PanClient(AsyncClient): media_info=None, ): config = config or AsyncClientConfig( - store=store_class or SqliteStore, store_name="pan.db" + store=store_class or PanSqliteStore, store_name="pan.db" ) super().__init__(homeserver, user_id, device_id, store_path, config, ssl, proxy) @@ -176,6 +175,7 @@ class PanClient(AsyncClient): self.pan_store = pan_store self.pan_conf = pan_conf self.media_info = media_info + self.last_sync_request_ts = 0 if INDEXING_ENABLED: logger.info("Indexing enabled.") @@ -198,6 +198,7 @@ class PanClient(AsyncClient): self.send_semaphores = defaultdict(asyncio.Semaphore) self.send_decision_queues = dict() # type: asyncio.Queue self.last_sync_token = None + self.last_sync_task = None self.history_fetcher_task = None self.history_fetch_queue = asyncio.Queue() @@ -526,6 +527,22 @@ class PanClient(AsyncClient): ) await self.send_update_device(device) + def ensure_sync_running(self, loop_sleep_time=100): + self.last_sync_request_ts = int(time.time()) + if self.task is None: + self.start_loop(loop_sleep_time) + + async def can_stop_sync(self): + try: + while True: + await asyncio.sleep(self.pan_conf.sync_stop_after) + if time.time() - self.last_sync_request_ts > self.pan_conf.sync_stop_after: + await self.loop_stop() + break + except (asyncio.CancelledError, KeyboardInterrupt): + return + + def start_loop(self, loop_sleep_time=100): """Start a loop that runs forever and keeps on syncing with the server. @@ -543,6 +560,8 @@ class PanClient(AsyncClient): timeout = 30000 sync_filter = {"room": {"state": {"lazy_load_members": True}}} next_batch = self.pan_store.load_token(self.server_name, self.user_id) + self.last_sync_token = next_batch + self.last_sync_request_ts = int(time.time()) # We don't store any room state so initial sync needs to be with the # full_state parameter. Subsequent ones are normal. @@ -558,6 +577,10 @@ class PanClient(AsyncClient): ) self.task = task + # if self.pan_conf.sync_stop_after > 0: + # self.last_sync_task = loop.create_task(self.can_stop_sync()) + + return task async def start_sas(self, message, device): @@ -776,7 +799,7 @@ class PanClient(AsyncClient): async def loop_stop(self): """Stop the client loop.""" - logger.info("Stopping the sync loop") + logger.info(f"Stopping the sync loop for {self.user_id}") if self.task and not self.task.done(): self.task.cancel() @@ -788,6 +811,16 @@ class PanClient(AsyncClient): self.task = None + if self.last_sync_task and not self.last_sync_task.done(): + self.last_sync_task.cancel() + + try: + await self.last_sync_task + except KeyboardInterrupt: + pass + + self.last_sync_task = None + if self.history_fetcher_task and not self.history_fetcher_task.done(): self.history_fetcher_task.cancel() diff --git a/pantalaimon/config.py b/pantalaimon/config.py index a5e59d1..15871c6 100644 --- a/pantalaimon/config.py +++ b/pantalaimon/config.py @@ -39,6 +39,8 @@ class PanConfigParser(configparser.ConfigParser): "IndexingBatchSize": "100", "HistoryFetchDelay": "3000", "DebugEncryption": "False", + "SyncOnStartup": "False", + "StopSyncingTimeout": "600", "DropOldKeys": "False", }, converters={ @@ -122,6 +124,12 @@ class ServerConfig: the room history. history_fetch_delay (int): The delay between room history fetching requests in seconds. + sync_on_startup (bool): Begin syncing all accounts registered with + pantalaimon on startup. + sync_stop_after (int): The number of seconds to wait since the + client has requested a /sync, before stopping a sync. + store_forgetful (bool): Enable or disable discarding of previous sessions + from the store. drop_old_keys (bool): Should Pantalaimon only keep the most recent decryption key around. """ @@ -140,6 +148,9 @@ class ServerConfig: index_encrypted_only = attr.ib(type=bool, default=True) indexing_batch_size = attr.ib(type=int, default=100) history_fetch_delay = attr.ib(type=int, default=3) + sync_on_startup = attr.ib(type=bool, default=False) + sync_stop_after = attr.ib(type=int, default=600) + store_forgetful = attr.ib(type=bool, default=True) drop_old_keys = attr.ib(type=bool, default=False) @@ -204,9 +215,12 @@ class PanConfig: proxy = section.geturl("Proxy") search_requests = section.getboolean("SearchRequests") index_encrypted_only = section.getboolean("IndexEncryptedOnly") - + store_forgetful = config["Default"].getboolean("StoreForgetful") indexing_batch_size = section.getint("IndexingBatchSize") + sync_on_startup = False #section.getboolean("SyncOnStartup") + sync_stop_after = 600 #section.getint("SyncStopAfter") + if not 1 < indexing_batch_size <= 1000: raise PanConfigError( "The indexing batch size needs to be " @@ -247,6 +261,9 @@ class PanConfig: index_encrypted_only, indexing_batch_size, history_fetch_delay / 1000, + sync_on_startup, + sync_stop_after, + store_forgetful, drop_old_keys, ) diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index 6d47b36..3b73130 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -15,6 +15,7 @@ import asyncio import json import os +import time import urllib.parse import concurrent.futures from io import BufferedReader, BytesIO @@ -22,7 +23,6 @@ from json import JSONDecodeError from typing import Any, Dict from urllib.parse import urlparse from uuid import uuid4 - import aiohttp import attr import keyring @@ -163,6 +163,7 @@ class ProxyDaemon: pan_client.user_id = user_id pan_client.access_token = token pan_client.load_store() + pan_client.store.forgetful = self.conf.store_forgetful self.pan_clients[user_id] = pan_client loop.create_task( @@ -172,8 +173,8 @@ class ProxyDaemon: ) loop.create_task(pan_client.send_update_devices(pan_client.device_store)) - - pan_client.start_loop() + if self.conf.sync_on_startup: + pan_client.start_loop() async def _find_client(self, access_token): client_info = self.client_info.get(access_token, None) @@ -756,6 +757,8 @@ class ProxyDaemon: if not client: return self._unknown_token + client.ensure_sync_running() + sync_filter = request.query.get("filter", None) query = CIMultiDict(request.query) diff --git a/pantalaimon/store.py b/pantalaimon/store.py index 0dfe045..105ed32 100644 --- a/pantalaimon/store.py +++ b/pantalaimon/store.py @@ -27,6 +27,9 @@ from nio.store import ( DeviceTrustState, use_database, use_database_atomic, + SqliteStore, + MegolmInboundSessions, + ForwardedChains ) from peewee import SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField from cachetools import LRUCache @@ -454,6 +457,22 @@ class PanStore: return store +class PanSqliteStore(SqliteStore): + forgetful = False + + @use_database + def save_inbound_group_session(self, session): + """Save the provided Megolm inbound group session to the database. + Args: + session (InboundGroupSession): The session to save. + """ + # Delete previous sessions + if self.forgetful: + MegolmInboundSessions.delete().where( + (MegolmInboundSessions.sender_key == session.sender_key) | + (MegolmInboundSessions.room_id == session.room_id) + ).execute() + super().save_inbound_group_session(session) class KeyDroppingSqliteStore(SqliteStore): @use_database