diff --git a/pantalaimon/client.py b/pantalaimon/client.py index 52e09b7..1077e53 100644 --- a/pantalaimon/client.py +++ b/pantalaimon/client.py @@ -19,7 +19,6 @@ from functools import partial from pprint import pformat from typing import Any, Dict, Optional -import attr from aiohttp.client_exceptions import ClientConnectionError from jsonschema import Draft4Validator, FormatChecker, validators from nio import (AsyncClient, ClientConfig, EncryptionError, KeysQueryResponse, @@ -33,6 +32,7 @@ from nio.store import SqliteStore from pantalaimon.index import Index from pantalaimon.log import logger +from pantalaimon.store import FetchTask from pantalaimon.thread_messages import (DaemonResponse, InviteSasSignal, SasDoneSignal, ShowSasSignal, UpdateDevicesMessage) @@ -104,12 +104,6 @@ class InvalidOrderByError(Exception): pass -@attr.s -class FetchTask: - room_id = attr.ib(type=str) - token = attr.ib(type=str) - - class PanClient(AsyncClient): """A wrapper class around a nio AsyncClient extending its functionality.""" @@ -216,10 +210,23 @@ class PanClient(AsyncClient): message = UpdateDevicesMessage() await self.queue.put(message) + def delete_fetcher_task(self, task): + self.pan_store.delete_fetcher_task( + self.server_name, + self.user_id, + task + ) + async def fetcher_loop(self): + for t in self.pan_store.load_fetcher_tasks( + self.server_name, + self.user_id + ): + await self.history_fetch_queue.put(t) + while True: try: - await asyncio.sleep(5) + await asyncio.sleep(3) fetch_task = await self.history_fetch_queue.get() @@ -228,6 +235,7 @@ class PanClient(AsyncClient): except KeyError: # The room is missing from our client, we probably left the # room. + self.delete_fetcher_task(fetch_task) continue try: @@ -239,8 +247,9 @@ class PanClient(AsyncClient): except ClientConnectionError: self.history_fetch_queue.put(fetch_task) - # The chunk was empyt, we're at the start of the timeline. + # The chunk was empty, we're at the start of the timeline. if not response.chunk: + self.delete_fetcher_task(fetch_task) continue for event in response.chunk: @@ -259,10 +268,13 @@ class PanClient(AsyncClient): else: # There may be even more events to fetch, add a new task to # the queue. - await self.history_fetch_queue.put( - FetchTask(room.room_id, response.end) - ) - except asyncio.CancelledError: + task = FetchTask(room.room_id, response.end) + self.pan_store.save_fetcher_task(self.server_name, + self.user_id, task) + await self.history_fetch_queue.put(task) + + self.delete_fetcher_task(fetch_task) + except (asyncio.CancelledError, KeyboardInterrupt): return async def sync_tasks(self, response): @@ -290,9 +302,11 @@ class PanClient(AsyncClient): "room for history fetching.".format( self.rooms[room_id].display_name )) - await self.history_fetch_queue.put( - FetchTask(room_id, room.timeline.prev_batch) - ) + task = FetchTask(room_id, room.timeline.prev_batch) + self.pan_store.save_fetcher_task(self.server_name, + self.user_id, task) + + await self.history_fetch_queue.put(task) async def keys_query_cb(self, response): await self.send_update_devcies() @@ -593,6 +607,8 @@ class PanClient(AsyncClient): await self.history_fetcher_task self.history_fetcher_task = None + self.history_fetch_queue = asyncio.Queue() + def pan_decrypt_event( self, event_dict, diff --git a/pantalaimon/store.py b/pantalaimon/store.py index 7cb536d..332d00f 100644 --- a/pantalaimon/store.py +++ b/pantalaimon/store.py @@ -26,6 +26,12 @@ from peewee import (SQL, DateTimeField, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField) +@attr.s +class FetchTask: + room_id = attr.ib(type=str) + token = attr.ib(type=str) + + class DictField(TextField): def python_value(self, value): # pragma: no cover return json.loads(value) @@ -109,6 +115,18 @@ class PanSyncTokens(Model): constraints = [SQL("UNIQUE(user_id)")] +class PanFetcherTasks(Model): + user = ForeignKeyField( + model=ServerUsers, + column_name="user_id", + backref="fetcher_tasks") + room_id = TextField() + token = TextField() + + class Meta: + constraints = [SQL("UNIQUE(user_id, room_id, token)")] + + @attr.s class ClientInfo: user_id = attr.ib(type=str) @@ -131,7 +149,8 @@ class PanStore: Profile, Event, UserMessages, - PanSyncTokens + PanSyncTokens, + PanFetcherTasks ] def __attrs_post_init__(self): @@ -165,6 +184,38 @@ class PanStore: except DoesNotExist: return None + @use_database + def save_fetcher_task(self, server, pan_user, task): + server = Servers.get(name=server) + user = ServerUsers.get(server=server, user_id=pan_user) + + PanFetcherTasks.replace( + user=user, + room_id=task.room_id, + token=task.token + ).execute() + + def load_fetcher_tasks(self, server, pan_user): + server = Servers.get(name=server) + user = ServerUsers.get(server=server, user_id=pan_user) + + tasks = [] + + for t in user.fetcher_tasks: + tasks.append(FetchTask(t.room_id, t.token)) + + return tasks + + def delete_fetcher_task(self, server, pan_user, task): + server = Servers.get(name=server) + user = ServerUsers.get(server=server, user_id=pan_user) + + PanFetcherTasks.delete().where( + PanFetcherTasks.user == user, + PanFetcherTasks.room_id == task.room_id, + PanFetcherTasks.token == task.token + ).execute() + @use_database def save_token(self, server, pan_user, token): # type: (str, str, str) -> None diff --git a/tests/store_test.py b/tests/store_test.py index ee9bbcb..9a5d150 100644 --- a/tests/store_test.py +++ b/tests/store_test.py @@ -4,6 +4,7 @@ from nio import RoomMessage from conftest import faker from pantalaimon.index import Index +from pantalaimon.store import FetchTask TEST_ROOM = "!SVkFJHzfwvuaIEawgC:localhost" TEST_ROOM2 = "!testroom:localhost" @@ -125,4 +126,28 @@ class TestClass(object): assert not panstore.load_token("example", user) panstore.save_token("example", user, "abc123") - assert "abc123" == panstore.load_token("example", user) + assert panstore.load_token("example", user) == "abc123" + + def test_fetcher_tasks(self, panstore_with_users): + panstore = panstore_with_users + accounts = panstore.load_all_users() + user, _ = accounts[0] + + task = FetchTask(TEST_ROOM, "abc1234") + task2 = FetchTask(TEST_ROOM2, "abc1234") + + assert not panstore.load_fetcher_tasks("example", user) + + panstore.save_fetcher_task("example", user, task) + panstore.save_fetcher_task("example", user, task2) + + tasks = panstore.load_fetcher_tasks("example", user) + + assert task in tasks + assert task2 in tasks + + panstore.delete_fetcher_task("example", user, task) + tasks = panstore.load_fetcher_tasks("example", user) + + assert task not in tasks + assert task2 in tasks