This commit is contained in:
Will Hunt 2024-12-05 18:34:19 +01:00 committed by GitHub
commit f4cc85247a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 81 additions and 9 deletions

View file

@ -14,10 +14,10 @@
import asyncio
import os
import time
from collections import defaultdict
from pprint import pformat
from urllib.parse import urlparse
from aiohttp.client_exceptions import ClientConnectionError
from jsonschema import Draft4Validator, FormatChecker, validators
from playhouse.sqliteq import SqliteQueueDatabase
@ -49,10 +49,9 @@ from nio import (
)
from nio.crypto import Sas
from nio.store import SqliteStore
from pantalaimon.index import INDEXING_ENABLED
from pantalaimon.log import logger
from pantalaimon.store import FetchTask, MediaInfo
from pantalaimon.store import FetchTask, MediaInfo, PanSqliteStore
from pantalaimon.thread_messages import (
DaemonResponse,
InviteSasSignal,
@ -161,7 +160,7 @@ class PanClient(AsyncClient):
media_info=None,
):
config = config or AsyncClientConfig(
store=store_class or SqliteStore, store_name="pan.db"
store=store_class or PanSqliteStore, store_name="pan.db"
)
super().__init__(homeserver, user_id, device_id, store_path, config, ssl, proxy)
@ -176,6 +175,7 @@ class PanClient(AsyncClient):
self.pan_store = pan_store
self.pan_conf = pan_conf
self.media_info = media_info
self.last_sync_request_ts = 0
if INDEXING_ENABLED:
logger.info("Indexing enabled.")
@ -198,6 +198,7 @@ class PanClient(AsyncClient):
self.send_semaphores = defaultdict(asyncio.Semaphore)
self.send_decision_queues = dict() # type: asyncio.Queue
self.last_sync_token = None
self.last_sync_task = None
self.history_fetcher_task = None
self.history_fetch_queue = asyncio.Queue()
@ -526,6 +527,22 @@ class PanClient(AsyncClient):
)
await self.send_update_device(device)
def ensure_sync_running(self, loop_sleep_time=100):
self.last_sync_request_ts = int(time.time())
if self.task is None:
self.start_loop(loop_sleep_time)
async def can_stop_sync(self):
try:
while True:
await asyncio.sleep(self.pan_conf.sync_stop_after)
if time.time() - self.last_sync_request_ts > self.pan_conf.sync_stop_after:
await self.loop_stop()
break
except (asyncio.CancelledError, KeyboardInterrupt):
return
def start_loop(self, loop_sleep_time=100):
"""Start a loop that runs forever and keeps on syncing with the server.
@ -543,6 +560,8 @@ class PanClient(AsyncClient):
timeout = 30000
sync_filter = {"room": {"state": {"lazy_load_members": True}}}
next_batch = self.pan_store.load_token(self.server_name, self.user_id)
self.last_sync_token = next_batch
self.last_sync_request_ts = int(time.time())
# We don't store any room state so initial sync needs to be with the
# full_state parameter. Subsequent ones are normal.
@ -558,6 +577,10 @@ class PanClient(AsyncClient):
)
self.task = task
# if self.pan_conf.sync_stop_after > 0:
# self.last_sync_task = loop.create_task(self.can_stop_sync())
return task
async def start_sas(self, message, device):
@ -776,7 +799,7 @@ class PanClient(AsyncClient):
async def loop_stop(self):
"""Stop the client loop."""
logger.info("Stopping the sync loop")
logger.info(f"Stopping the sync loop for {self.user_id}")
if self.task and not self.task.done():
self.task.cancel()
@ -788,6 +811,16 @@ class PanClient(AsyncClient):
self.task = None
if self.last_sync_task and not self.last_sync_task.done():
self.last_sync_task.cancel()
try:
await self.last_sync_task
except KeyboardInterrupt:
pass
self.last_sync_task = None
if self.history_fetcher_task and not self.history_fetcher_task.done():
self.history_fetcher_task.cancel()

View file

@ -39,6 +39,8 @@ class PanConfigParser(configparser.ConfigParser):
"IndexingBatchSize": "100",
"HistoryFetchDelay": "3000",
"DebugEncryption": "False",
"SyncOnStartup": "False",
"StopSyncingTimeout": "600",
"DropOldKeys": "False",
},
converters={
@ -122,6 +124,12 @@ class ServerConfig:
the room history.
history_fetch_delay (int): The delay between room history fetching
requests in seconds.
sync_on_startup (bool): Begin syncing all accounts registered with
pantalaimon on startup.
sync_stop_after (int): The number of seconds to wait since the
client has requested a /sync, before stopping a sync.
store_forgetful (bool): Enable or disable discarding of previous sessions
from the store.
drop_old_keys (bool): Should Pantalaimon only keep the most recent
decryption key around.
"""
@ -140,6 +148,9 @@ class ServerConfig:
index_encrypted_only = attr.ib(type=bool, default=True)
indexing_batch_size = attr.ib(type=int, default=100)
history_fetch_delay = attr.ib(type=int, default=3)
sync_on_startup = attr.ib(type=bool, default=False)
sync_stop_after = attr.ib(type=int, default=600)
store_forgetful = attr.ib(type=bool, default=True)
drop_old_keys = attr.ib(type=bool, default=False)
@ -204,9 +215,12 @@ class PanConfig:
proxy = section.geturl("Proxy")
search_requests = section.getboolean("SearchRequests")
index_encrypted_only = section.getboolean("IndexEncryptedOnly")
store_forgetful = config["Default"].getboolean("StoreForgetful")
indexing_batch_size = section.getint("IndexingBatchSize")
sync_on_startup = False #section.getboolean("SyncOnStartup")
sync_stop_after = 600 #section.getint("SyncStopAfter")
if not 1 < indexing_batch_size <= 1000:
raise PanConfigError(
"The indexing batch size needs to be "
@ -247,6 +261,9 @@ class PanConfig:
index_encrypted_only,
indexing_batch_size,
history_fetch_delay / 1000,
sync_on_startup,
sync_stop_after,
store_forgetful,
drop_old_keys,
)

View file

@ -15,6 +15,7 @@
import asyncio
import json
import os
import time
import urllib.parse
import concurrent.futures
from io import BufferedReader, BytesIO
@ -22,7 +23,6 @@ from json import JSONDecodeError
from typing import Any, Dict
from urllib.parse import urlparse
from uuid import uuid4
import aiohttp
import attr
import keyring
@ -163,6 +163,7 @@ class ProxyDaemon:
pan_client.user_id = user_id
pan_client.access_token = token
pan_client.load_store()
pan_client.store.forgetful = self.conf.store_forgetful
self.pan_clients[user_id] = pan_client
loop.create_task(
@ -172,8 +173,8 @@ class ProxyDaemon:
)
loop.create_task(pan_client.send_update_devices(pan_client.device_store))
pan_client.start_loop()
if self.conf.sync_on_startup:
pan_client.start_loop()
async def _find_client(self, access_token):
client_info = self.client_info.get(access_token, None)
@ -756,6 +757,8 @@ class ProxyDaemon:
if not client:
return self._unknown_token
client.ensure_sync_running()
sync_filter = request.query.get("filter", None)
query = CIMultiDict(request.query)

View file

@ -27,6 +27,9 @@ from nio.store import (
DeviceTrustState,
use_database,
use_database_atomic,
SqliteStore,
MegolmInboundSessions,
ForwardedChains
)
from peewee import SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField
from cachetools import LRUCache
@ -454,6 +457,22 @@ class PanStore:
return store
class PanSqliteStore(SqliteStore):
forgetful = False
@use_database
def save_inbound_group_session(self, session):
"""Save the provided Megolm inbound group session to the database.
Args:
session (InboundGroupSession): The session to save.
"""
# Delete previous sessions
if self.forgetful:
MegolmInboundSessions.delete().where(
(MegolmInboundSessions.sender_key == session.sender_key) |
(MegolmInboundSessions.room_id == session.room_id)
).execute()
super().save_inbound_group_session(session)
class KeyDroppingSqliteStore(SqliteStore):
@use_database