client: Add/delete a continued/old fetcher task atomically.

This commit is contained in:
Damir Jelić 2019-07-04 14:54:27 +02:00
parent 98dfda9496
commit bf64f4cbae
2 changed files with 27 additions and 5 deletions

View File

@ -306,13 +306,14 @@ class PanClient(AsyncClient):
# There may be even more events to fetch, add a new task to # There may be even more events to fetch, add a new task to
# the queue. # the queue.
task = FetchTask(room.room_id, response.end) task = FetchTask(room.room_id, response.end)
self.pan_store.save_fetcher_task( self.pan_store.replace_fetcher_task(
self.server_name, self.user_id, task self.server_name, self.user_id, fetch_task, task
) )
await self.history_fetch_queue.put(task) await self.history_fetch_queue.put(task)
else:
await self.index.commit_events()
self.delete_fetcher_task(fetch_task)
await self.index.commit_events()
self.delete_fetcher_task(fetch_task)
except asyncio.CancelledError: except asyncio.CancelledError:
return return

View File

@ -19,7 +19,13 @@ from typing import List, Optional, Tuple
import attr import attr
from nio.crypto import TrustState from nio.crypto import TrustState
from nio.store import Accounts, DeviceKeys, DeviceTrustState, use_database from nio.store import (
Accounts,
DeviceKeys,
DeviceTrustState,
use_database,
use_database_atomic,
)
from peewee import SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField from peewee import SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField
@ -128,6 +134,21 @@ class PanStore:
except DoesNotExist: except DoesNotExist:
return None return None
@use_database_atomic
def replace_fetcher_task(self, server, pan_user, old_task, new_task):
server = Servers.get(name=server)
user = ServerUsers.get(server=server, user_id=pan_user)
PanFetcherTasks.delete().where(
PanFetcherTasks.user == user,
PanFetcherTasks.room_id == old_task.room_id,
PanFetcherTasks.token == old_task.token,
).execute()
PanFetcherTasks.replace(
user=user, room_id=new_task.room_id, token=new_task.token
).execute()
@use_database @use_database
def save_fetcher_task(self, server, pan_user, task): def save_fetcher_task(self, server, pan_user, task):
server = Servers.get(name=server) server = Servers.get(name=server)