Add supporting options for disallowing the sync on startup and stopping the sync after some time

This commit is contained in:
Will Hunt 2021-01-14 17:57:20 +00:00
parent 9c65c06075
commit 7446cfc084
3 changed files with 53 additions and 4 deletions

View File

@ -14,6 +14,7 @@
import asyncio
import os
import time
from collections import defaultdict
from pprint import pformat
from typing import Any, Dict, Optional
@ -177,6 +178,7 @@ class PanClient(AsyncClient):
self.pan_store = pan_store
self.pan_conf = pan_conf
self.media_info = media_info
self.last_sync_request_ts = 0
if INDEXING_ENABLED:
logger.info("Indexing enabled.")
@ -199,6 +201,7 @@ class PanClient(AsyncClient):
self.send_semaphores = defaultdict(asyncio.Semaphore)
self.send_decision_queues = dict() # type: asyncio.Queue
self.last_sync_token = None
self.last_sync_task = None
self.history_fetcher_task = None
self.history_fetch_queue = asyncio.Queue()
@ -523,6 +526,22 @@ class PanClient(AsyncClient):
)
await self.send_update_device(device)
def ensure_sync_running(self, loop_sleep_time=100):
self.last_sync_request_ts = int(time.time())
if self.task is None:
self.start_loop(loop_sleep_time)
async def can_stop_sync(self):
try:
while True:
await asyncio.sleep(self.pan_conf.sync_stop_after)
if time.time() - self.last_sync_request_ts > self.pan_conf.sync_stop_after:
await self.loop_stop()
break
except (asyncio.CancelledError, KeyboardInterrupt):
return
def start_loop(self, loop_sleep_time=100):
"""Start a loop that runs forever and keeps on syncing with the server.
@ -541,6 +560,7 @@ class PanClient(AsyncClient):
sync_filter = {"room": {"state": {"lazy_load_members": True}}}
next_batch = self.pan_store.load_token(self.server_name, self.user_id)
self.last_sync_token = next_batch
self.last_sync_request_ts = int(time.time())
# We don't store any room state so initial sync needs to be with the
# full_state parameter. Subsequent ones are normal.
@ -555,6 +575,10 @@ class PanClient(AsyncClient):
)
self.task = task
if self.pan_conf.sync_stop_after > 0:
self.last_sync_task = loop.create_task(self.can_stop_sync())
return task
async def start_sas(self, message, device):
@ -774,7 +798,7 @@ class PanClient(AsyncClient):
async def loop_stop(self):
"""Stop the client loop."""
logger.info("Stopping the sync loop")
logger.info(f"Stopping the sync loop for {self.user_id}")
if self.task and not self.task.done():
self.task.cancel()
@ -786,6 +810,16 @@ class PanClient(AsyncClient):
self.task = None
if self.last_sync_task and not self.last_sync_task.done():
self.last_sync_task.cancel()
try:
await self.last_sync_task
except KeyboardInterrupt:
pass
self.last_sync_task = None
if self.history_fetcher_task and not self.history_fetcher_task.done():
self.history_fetcher_task.cancel()

View File

@ -39,6 +39,8 @@ class PanConfigParser(configparser.ConfigParser):
"IndexingBatchSize": "100",
"HistoryFetchDelay": "3000",
"DebugEncryption": "False",
"SyncOnStartup": "True",
"StopSyncingTimeout": "0"
},
converters={
"address": parse_address,
@ -121,6 +123,10 @@ class ServerConfig:
the room history.
history_fetch_delay (int): The delay between room history fetching
requests in seconds.
sync_on_startup (bool): Begin syncing all accounts registered with
pantalaimon on startup.
sync_stop_after (int): The number of seconds to wait since the
client has requested a /sync, before stopping a sync.
"""
name = attr.ib(type=str)
@ -137,6 +143,8 @@ class ServerConfig:
index_encrypted_only = attr.ib(type=bool, default=True)
indexing_batch_size = attr.ib(type=int, default=100)
history_fetch_delay = attr.ib(type=int, default=3)
sync_on_startup = attr.ib(type=bool, default=True)
sync_stop_after = attr.ib(type=int, default=0)
@attr.s
@ -204,6 +212,9 @@ class PanConfig:
indexing_batch_size = section.getint("IndexingBatchSize")
sync_on_startup = section.getboolean("SyncOnStartup")
sync_stop_after = section.getint("SyncStopAfter")
if not 1 < indexing_batch_size <= 1000:
raise PanConfigError(
"The indexing batch size needs to be "
@ -243,6 +254,8 @@ class PanConfig:
index_encrypted_only,
indexing_batch_size,
history_fetch_delay / 1000,
sync_on_startup,
sync_stop_after,
)
self.servers[section_name] = server_conf

View File

@ -15,12 +15,12 @@
import asyncio
import json
import os
import time
import urllib.parse
import concurrent.futures
from json import JSONDecodeError
from typing import Any, Dict
from uuid import uuid4
import aiohttp
import attr
import keyring
@ -159,8 +159,8 @@ class ProxyDaemon:
)
loop.create_task(pan_client.send_update_devices(pan_client.device_store))
pan_client.start_loop()
if self.conf.sync_on_startup:
pan_client.start_loop()
async def _find_client(self, access_token):
client_info = self.client_info.get(access_token, None)
@ -736,6 +736,8 @@ class ProxyDaemon:
if not client:
return self._unknown_token
client.ensure_sync_running()
sync_filter = request.query.get("filter", None)
query = CIMultiDict(request.query)