mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2026-01-04 01:25:54 -05:00
Merge 7312e57d85 into 21fb28d090
This commit is contained in:
commit
f4cc85247a
4 changed files with 81 additions and 9 deletions
|
|
@ -14,10 +14,10 @@
|
|||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from pprint import pformat
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from aiohttp.client_exceptions import ClientConnectionError
|
||||
from jsonschema import Draft4Validator, FormatChecker, validators
|
||||
from playhouse.sqliteq import SqliteQueueDatabase
|
||||
|
|
@ -49,10 +49,9 @@ from nio import (
|
|||
)
|
||||
from nio.crypto import Sas
|
||||
from nio.store import SqliteStore
|
||||
|
||||
from pantalaimon.index import INDEXING_ENABLED
|
||||
from pantalaimon.log import logger
|
||||
from pantalaimon.store import FetchTask, MediaInfo
|
||||
from pantalaimon.store import FetchTask, MediaInfo, PanSqliteStore
|
||||
from pantalaimon.thread_messages import (
|
||||
DaemonResponse,
|
||||
InviteSasSignal,
|
||||
|
|
@ -161,7 +160,7 @@ class PanClient(AsyncClient):
|
|||
media_info=None,
|
||||
):
|
||||
config = config or AsyncClientConfig(
|
||||
store=store_class or SqliteStore, store_name="pan.db"
|
||||
store=store_class or PanSqliteStore, store_name="pan.db"
|
||||
)
|
||||
super().__init__(homeserver, user_id, device_id, store_path, config, ssl, proxy)
|
||||
|
||||
|
|
@ -176,6 +175,7 @@ class PanClient(AsyncClient):
|
|||
self.pan_store = pan_store
|
||||
self.pan_conf = pan_conf
|
||||
self.media_info = media_info
|
||||
self.last_sync_request_ts = 0
|
||||
|
||||
if INDEXING_ENABLED:
|
||||
logger.info("Indexing enabled.")
|
||||
|
|
@ -198,6 +198,7 @@ class PanClient(AsyncClient):
|
|||
self.send_semaphores = defaultdict(asyncio.Semaphore)
|
||||
self.send_decision_queues = dict() # type: asyncio.Queue
|
||||
self.last_sync_token = None
|
||||
self.last_sync_task = None
|
||||
|
||||
self.history_fetcher_task = None
|
||||
self.history_fetch_queue = asyncio.Queue()
|
||||
|
|
@ -526,6 +527,22 @@ class PanClient(AsyncClient):
|
|||
)
|
||||
await self.send_update_device(device)
|
||||
|
||||
def ensure_sync_running(self, loop_sleep_time=100):
|
||||
self.last_sync_request_ts = int(time.time())
|
||||
if self.task is None:
|
||||
self.start_loop(loop_sleep_time)
|
||||
|
||||
async def can_stop_sync(self):
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(self.pan_conf.sync_stop_after)
|
||||
if time.time() - self.last_sync_request_ts > self.pan_conf.sync_stop_after:
|
||||
await self.loop_stop()
|
||||
break
|
||||
except (asyncio.CancelledError, KeyboardInterrupt):
|
||||
return
|
||||
|
||||
|
||||
def start_loop(self, loop_sleep_time=100):
|
||||
"""Start a loop that runs forever and keeps on syncing with the server.
|
||||
|
||||
|
|
@ -543,6 +560,8 @@ class PanClient(AsyncClient):
|
|||
timeout = 30000
|
||||
sync_filter = {"room": {"state": {"lazy_load_members": True}}}
|
||||
next_batch = self.pan_store.load_token(self.server_name, self.user_id)
|
||||
self.last_sync_token = next_batch
|
||||
self.last_sync_request_ts = int(time.time())
|
||||
|
||||
# We don't store any room state so initial sync needs to be with the
|
||||
# full_state parameter. Subsequent ones are normal.
|
||||
|
|
@ -558,6 +577,10 @@ class PanClient(AsyncClient):
|
|||
)
|
||||
self.task = task
|
||||
|
||||
# if self.pan_conf.sync_stop_after > 0:
|
||||
# self.last_sync_task = loop.create_task(self.can_stop_sync())
|
||||
|
||||
|
||||
return task
|
||||
|
||||
async def start_sas(self, message, device):
|
||||
|
|
@ -776,7 +799,7 @@ class PanClient(AsyncClient):
|
|||
|
||||
async def loop_stop(self):
|
||||
"""Stop the client loop."""
|
||||
logger.info("Stopping the sync loop")
|
||||
logger.info(f"Stopping the sync loop for {self.user_id}")
|
||||
|
||||
if self.task and not self.task.done():
|
||||
self.task.cancel()
|
||||
|
|
@ -788,6 +811,16 @@ class PanClient(AsyncClient):
|
|||
|
||||
self.task = None
|
||||
|
||||
if self.last_sync_task and not self.last_sync_task.done():
|
||||
self.last_sync_task.cancel()
|
||||
|
||||
try:
|
||||
await self.last_sync_task
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
self.last_sync_task = None
|
||||
|
||||
if self.history_fetcher_task and not self.history_fetcher_task.done():
|
||||
self.history_fetcher_task.cancel()
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ class PanConfigParser(configparser.ConfigParser):
|
|||
"IndexingBatchSize": "100",
|
||||
"HistoryFetchDelay": "3000",
|
||||
"DebugEncryption": "False",
|
||||
"SyncOnStartup": "False",
|
||||
"StopSyncingTimeout": "600",
|
||||
"DropOldKeys": "False",
|
||||
},
|
||||
converters={
|
||||
|
|
@ -122,6 +124,12 @@ class ServerConfig:
|
|||
the room history.
|
||||
history_fetch_delay (int): The delay between room history fetching
|
||||
requests in seconds.
|
||||
sync_on_startup (bool): Begin syncing all accounts registered with
|
||||
pantalaimon on startup.
|
||||
sync_stop_after (int): The number of seconds to wait since the
|
||||
client has requested a /sync, before stopping a sync.
|
||||
store_forgetful (bool): Enable or disable discarding of previous sessions
|
||||
from the store.
|
||||
drop_old_keys (bool): Should Pantalaimon only keep the most recent
|
||||
decryption key around.
|
||||
"""
|
||||
|
|
@ -140,6 +148,9 @@ class ServerConfig:
|
|||
index_encrypted_only = attr.ib(type=bool, default=True)
|
||||
indexing_batch_size = attr.ib(type=int, default=100)
|
||||
history_fetch_delay = attr.ib(type=int, default=3)
|
||||
sync_on_startup = attr.ib(type=bool, default=False)
|
||||
sync_stop_after = attr.ib(type=int, default=600)
|
||||
store_forgetful = attr.ib(type=bool, default=True)
|
||||
drop_old_keys = attr.ib(type=bool, default=False)
|
||||
|
||||
|
||||
|
|
@ -204,9 +215,12 @@ class PanConfig:
|
|||
proxy = section.geturl("Proxy")
|
||||
search_requests = section.getboolean("SearchRequests")
|
||||
index_encrypted_only = section.getboolean("IndexEncryptedOnly")
|
||||
|
||||
store_forgetful = config["Default"].getboolean("StoreForgetful")
|
||||
indexing_batch_size = section.getint("IndexingBatchSize")
|
||||
|
||||
sync_on_startup = False #section.getboolean("SyncOnStartup")
|
||||
sync_stop_after = 600 #section.getint("SyncStopAfter")
|
||||
|
||||
if not 1 < indexing_batch_size <= 1000:
|
||||
raise PanConfigError(
|
||||
"The indexing batch size needs to be "
|
||||
|
|
@ -247,6 +261,9 @@ class PanConfig:
|
|||
index_encrypted_only,
|
||||
indexing_batch_size,
|
||||
history_fetch_delay / 1000,
|
||||
sync_on_startup,
|
||||
sync_stop_after,
|
||||
store_forgetful,
|
||||
drop_old_keys,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import urllib.parse
|
||||
import concurrent.futures
|
||||
from io import BufferedReader, BytesIO
|
||||
|
|
@ -22,7 +23,6 @@ from json import JSONDecodeError
|
|||
from typing import Any, Dict
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
import aiohttp
|
||||
import attr
|
||||
import keyring
|
||||
|
|
@ -163,6 +163,7 @@ class ProxyDaemon:
|
|||
pan_client.user_id = user_id
|
||||
pan_client.access_token = token
|
||||
pan_client.load_store()
|
||||
pan_client.store.forgetful = self.conf.store_forgetful
|
||||
self.pan_clients[user_id] = pan_client
|
||||
|
||||
loop.create_task(
|
||||
|
|
@ -172,8 +173,8 @@ class ProxyDaemon:
|
|||
)
|
||||
|
||||
loop.create_task(pan_client.send_update_devices(pan_client.device_store))
|
||||
|
||||
pan_client.start_loop()
|
||||
if self.conf.sync_on_startup:
|
||||
pan_client.start_loop()
|
||||
|
||||
async def _find_client(self, access_token):
|
||||
client_info = self.client_info.get(access_token, None)
|
||||
|
|
@ -756,6 +757,8 @@ class ProxyDaemon:
|
|||
if not client:
|
||||
return self._unknown_token
|
||||
|
||||
client.ensure_sync_running()
|
||||
|
||||
sync_filter = request.query.get("filter", None)
|
||||
query = CIMultiDict(request.query)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,9 @@ from nio.store import (
|
|||
DeviceTrustState,
|
||||
use_database,
|
||||
use_database_atomic,
|
||||
SqliteStore,
|
||||
MegolmInboundSessions,
|
||||
ForwardedChains
|
||||
)
|
||||
from peewee import SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField
|
||||
from cachetools import LRUCache
|
||||
|
|
@ -454,6 +457,22 @@ class PanStore:
|
|||
|
||||
return store
|
||||
|
||||
class PanSqliteStore(SqliteStore):
|
||||
forgetful = False
|
||||
|
||||
@use_database
|
||||
def save_inbound_group_session(self, session):
|
||||
"""Save the provided Megolm inbound group session to the database.
|
||||
Args:
|
||||
session (InboundGroupSession): The session to save.
|
||||
"""
|
||||
# Delete previous sessions
|
||||
if self.forgetful:
|
||||
MegolmInboundSessions.delete().where(
|
||||
(MegolmInboundSessions.sender_key == session.sender_key) |
|
||||
(MegolmInboundSessions.room_id == session.room_id)
|
||||
).execute()
|
||||
super().save_inbound_group_session(session)
|
||||
|
||||
class KeyDroppingSqliteStore(SqliteStore):
|
||||
@use_database
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue