diff --git a/pantalaimon/client.py b/pantalaimon/client.py index b2d2a72..01b3295 100644 --- a/pantalaimon/client.py +++ b/pantalaimon/client.py @@ -14,6 +14,7 @@ import asyncio import os +import time from collections import defaultdict from pprint import pformat from typing import Any, Dict, Optional @@ -177,6 +178,7 @@ class PanClient(AsyncClient): self.pan_store = pan_store self.pan_conf = pan_conf self.media_info = media_info + self.last_sync_request_ts = 0 if INDEXING_ENABLED: logger.info("Indexing enabled.") @@ -199,6 +201,7 @@ class PanClient(AsyncClient): self.send_semaphores = defaultdict(asyncio.Semaphore) self.send_decision_queues = dict() # type: asyncio.Queue self.last_sync_token = None + self.last_sync_task = None self.history_fetcher_task = None self.history_fetch_queue = asyncio.Queue() @@ -523,6 +526,22 @@ class PanClient(AsyncClient): ) await self.send_update_device(device) + def ensure_sync_running(self, loop_sleep_time=100): + self.last_sync_request_ts = int(time.time()) + if self.task is None: + self.start_loop(loop_sleep_time) + + async def can_stop_sync(self): + try: + while True: + await asyncio.sleep(self.pan_conf.sync_stop_after) + if time.time() - self.last_sync_request_ts > self.pan_conf.sync_stop_after: + await self.loop_stop() + break + except (asyncio.CancelledError, KeyboardInterrupt): + return + + def start_loop(self, loop_sleep_time=100): """Start a loop that runs forever and keeps on syncing with the server. @@ -541,6 +560,7 @@ class PanClient(AsyncClient): sync_filter = {"room": {"state": {"lazy_load_members": True}}} next_batch = self.pan_store.load_token(self.server_name, self.user_id) self.last_sync_token = next_batch + self.last_sync_request_ts = int(time.time()) # We don't store any room state so initial sync needs to be with the # full_state parameter. Subsequent ones are normal. @@ -555,6 +575,10 @@ class PanClient(AsyncClient): ) self.task = task + if self.pan_conf.sync_stop_after > 0: + self.last_sync_task = loop.create_task(self.can_stop_sync()) + + return task async def start_sas(self, message, device): @@ -774,7 +798,7 @@ class PanClient(AsyncClient): async def loop_stop(self): """Stop the client loop.""" - logger.info("Stopping the sync loop") + logger.info(f"Stopping the sync loop for {self.user_id}") if self.task and not self.task.done(): self.task.cancel() @@ -786,6 +810,16 @@ class PanClient(AsyncClient): self.task = None + if self.last_sync_task and not self.last_sync_task.done(): + self.last_sync_task.cancel() + + try: + await self.last_sync_task + except KeyboardInterrupt: + pass + + self.last_sync_task = None + if self.history_fetcher_task and not self.history_fetcher_task.done(): self.history_fetcher_task.cancel() diff --git a/pantalaimon/config.py b/pantalaimon/config.py index 2a01714..47ca53d 100644 --- a/pantalaimon/config.py +++ b/pantalaimon/config.py @@ -39,6 +39,8 @@ class PanConfigParser(configparser.ConfigParser): "IndexingBatchSize": "100", "HistoryFetchDelay": "3000", "DebugEncryption": "False", + "SyncOnStartup": "True", + "StopSyncingTimeout": "0" }, converters={ "address": parse_address, @@ -121,6 +123,10 @@ class ServerConfig: the room history. history_fetch_delay (int): The delay between room history fetching requests in seconds. + sync_on_startup (bool): Begin syncing all accounts registered with + pantalaimon on startup. + sync_stop_after (int): The number of seconds to wait since the + client has requested a /sync, before stopping a sync. """ name = attr.ib(type=str) @@ -137,6 +143,8 @@ class ServerConfig: index_encrypted_only = attr.ib(type=bool, default=True) indexing_batch_size = attr.ib(type=int, default=100) history_fetch_delay = attr.ib(type=int, default=3) + sync_on_startup = attr.ib(type=bool, default=True) + sync_stop_after = attr.ib(type=int, default=0) @attr.s @@ -204,6 +212,9 @@ class PanConfig: indexing_batch_size = section.getint("IndexingBatchSize") + sync_on_startup = section.getboolean("SyncOnStartup") + sync_stop_after = section.getint("SyncStopAfter") + if not 1 < indexing_batch_size <= 1000: raise PanConfigError( "The indexing batch size needs to be " @@ -243,6 +254,8 @@ class PanConfig: index_encrypted_only, indexing_batch_size, history_fetch_delay / 1000, + sync_on_startup, + sync_stop_after, ) self.servers[section_name] = server_conf diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index 55b67a2..e98512b 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -15,12 +15,12 @@ import asyncio import json import os +import time import urllib.parse import concurrent.futures from json import JSONDecodeError from typing import Any, Dict from uuid import uuid4 - import aiohttp import attr import keyring @@ -159,8 +159,8 @@ class ProxyDaemon: ) loop.create_task(pan_client.send_update_devices(pan_client.device_store)) - - pan_client.start_loop() + if self.conf.sync_on_startup: + pan_client.start_loop() async def _find_client(self, access_token): client_info = self.client_info.get(access_token, None) @@ -736,6 +736,8 @@ class ProxyDaemon: if not client: return self._unknown_token + client.ensure_sync_running() + sync_filter = request.query.get("filter", None) query = CIMultiDict(request.query)