From 1bbf38e2404f081f538210f7b8a9b3e6007cff9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Damir=20Jeli=C4=87?= Date: Thu, 11 Apr 2019 16:59:37 +0200 Subject: [PATCH] daemon: Store and restore client info of our children. --- pantalaimon/daemon.py | 13 +++++-------- pantalaimon/store.py | 41 ++++++++++++++++++++++++++++++++++++++--- tests/store_test.py | 24 +++++++++++++++++++++++- 3 files changed, 66 insertions(+), 12 deletions(-) diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index 9481835..37dce79 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -20,13 +20,7 @@ from nio import GroupEncryptionError, LoginResponse from pantalaimon.client import PanClient from pantalaimon.log import logger -from pantalaimon.store import PanStore - - -@attr.s -class Client: - user_id = attr.ib(type=str) - access_token = attr.ib(type=str) +from pantalaimon.store import ClientInfo, PanStore @attr.s @@ -50,6 +44,8 @@ class ProxyDaemon: self.store = PanStore(self.data_dir) accounts = self.store.get_users() + self.client_info = self.store.load_clients() + for user_id, device_id in accounts: token = self.store.load_access_token(user_id, device_id) @@ -159,8 +155,9 @@ class ProxyDaemon: return user async def start_pan_client(self, access_token, user, user_id, password): - client = Client(user_id, access_token) + client = ClientInfo(user_id, access_token) self.client_info[access_token] = client + self.store.save_client(client) if user_id in self.pan_clients: logger.info(f"Background sync client already exists for {user_id}," diff --git a/pantalaimon/store.py b/pantalaimon/store.py index ecfd7e6..79b9e38 100644 --- a/pantalaimon/store.py +++ b/pantalaimon/store.py @@ -1,9 +1,9 @@ import os -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import attr from nio.store import Accounts, use_database -from peewee import (DoesNotExist, ForeignKeyField, Model, SqliteDatabase, +from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField) @@ -17,13 +17,27 @@ class AccessTokens(Model): ) +class Clients(Model): + user_id = TextField() + token = TextField() + + class Meta: + constraints = [SQL("UNIQUE(user_id,token)")] + + +@attr.s +class ClientInfo: + user_id = attr.ib(type=str) + access_token = attr.ib(type=str) + + @attr.s class PanStore: store_path = attr.ib(type=str) database_name = attr.ib(type=str, default="pan.db") database = attr.ib(type=SqliteDatabase, init=False) database_path = attr.ib(type=str, init=False) - models = [Accounts, AccessTokens] + models = [Accounts, AccessTokens, Clients] def __attrs_post_init__(self): self.database_path = os.path.join( @@ -93,3 +107,24 @@ class PanStore: return account.access_token[0].token except IndexError: return None + + @use_database + def save_client(self, client): + # type: (ClientInfo) -> None + Clients.replace( + user_id=client.user_id, + token=client.access_token + ).execute() + + @use_database + def load_clients(self): + # type: () -> Dict[str, ClientInfo] + clients = dict() + + query = Clients.select() + + for c in query: + client = ClientInfo(c.user_id, c.token) + clients[c.token] = client + + return clients diff --git a/tests/store_test.py b/tests/store_test.py index 0891d21..291ef13 100644 --- a/tests/store_test.py +++ b/tests/store_test.py @@ -10,7 +10,7 @@ from faker.providers import BaseProvider from nio.crypto import OlmAccount from nio.store import SqliteStore -from pantalaimon.store import PanStore +from pantalaimon.store import ClientInfo, PanStore faker = Faker() @@ -25,6 +25,9 @@ class Provider(BaseProvider): def access_token(self): return "MDA" + "".join(choices(digits + ascii_letters, k=272)) + def client(self): + return ClientInfo(faker.mx_id(), faker.access_token()) + faker.add_provider(Provider) @@ -34,6 +37,11 @@ def access_token(): return faker.access_token() +@pytest.fixture +def client(): + return faker.client() + + @pytest.fixture def tempdir(): newpath = tempfile.mkdtemp() @@ -73,3 +81,17 @@ class TestClass(object): token = panstore.load_access_token(user_id, device_id) access_token == token + + def test_child_clinets_storing(self, panstore, client): + clients = panstore.load_clients() + assert not clients + + panstore.save_client(client) + + clients = panstore.load_clients() + assert clients + + client2 = faker.client() + panstore.save_client(client2) + clients = panstore.load_clients() + assert len(clients) == 2