daemon: Store and restore client info of our children.

This commit is contained in:
Damir Jelić 2019-04-11 16:59:37 +02:00
parent 700510aa36
commit 1bbf38e240
3 changed files with 66 additions and 12 deletions

View File

@ -20,13 +20,7 @@ from nio import GroupEncryptionError, LoginResponse
from pantalaimon.client import PanClient from pantalaimon.client import PanClient
from pantalaimon.log import logger from pantalaimon.log import logger
from pantalaimon.store import PanStore from pantalaimon.store import ClientInfo, PanStore
@attr.s
class Client:
user_id = attr.ib(type=str)
access_token = attr.ib(type=str)
@attr.s @attr.s
@ -50,6 +44,8 @@ class ProxyDaemon:
self.store = PanStore(self.data_dir) self.store = PanStore(self.data_dir)
accounts = self.store.get_users() accounts = self.store.get_users()
self.client_info = self.store.load_clients()
for user_id, device_id in accounts: for user_id, device_id in accounts:
token = self.store.load_access_token(user_id, device_id) token = self.store.load_access_token(user_id, device_id)
@ -159,8 +155,9 @@ class ProxyDaemon:
return user return user
async def start_pan_client(self, access_token, user, user_id, password): async def start_pan_client(self, access_token, user, user_id, password):
client = Client(user_id, access_token) client = ClientInfo(user_id, access_token)
self.client_info[access_token] = client self.client_info[access_token] = client
self.store.save_client(client)
if user_id in self.pan_clients: if user_id in self.pan_clients:
logger.info(f"Background sync client already exists for {user_id}," logger.info(f"Background sync client already exists for {user_id},"

View File

@ -1,9 +1,9 @@
import os import os
from typing import List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import attr import attr
from nio.store import Accounts, use_database from nio.store import Accounts, use_database
from peewee import (DoesNotExist, ForeignKeyField, Model, SqliteDatabase, from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
TextField) TextField)
@ -17,13 +17,27 @@ class AccessTokens(Model):
) )
class Clients(Model):
user_id = TextField()
token = TextField()
class Meta:
constraints = [SQL("UNIQUE(user_id,token)")]
@attr.s
class ClientInfo:
user_id = attr.ib(type=str)
access_token = attr.ib(type=str)
@attr.s @attr.s
class PanStore: class PanStore:
store_path = attr.ib(type=str) store_path = attr.ib(type=str)
database_name = attr.ib(type=str, default="pan.db") database_name = attr.ib(type=str, default="pan.db")
database = attr.ib(type=SqliteDatabase, init=False) database = attr.ib(type=SqliteDatabase, init=False)
database_path = attr.ib(type=str, init=False) database_path = attr.ib(type=str, init=False)
models = [Accounts, AccessTokens] models = [Accounts, AccessTokens, Clients]
def __attrs_post_init__(self): def __attrs_post_init__(self):
self.database_path = os.path.join( self.database_path = os.path.join(
@ -93,3 +107,24 @@ class PanStore:
return account.access_token[0].token return account.access_token[0].token
except IndexError: except IndexError:
return None return None
@use_database
def save_client(self, client):
# type: (ClientInfo) -> None
Clients.replace(
user_id=client.user_id,
token=client.access_token
).execute()
@use_database
def load_clients(self):
# type: () -> Dict[str, ClientInfo]
clients = dict()
query = Clients.select()
for c in query:
client = ClientInfo(c.user_id, c.token)
clients[c.token] = client
return clients

View File

@ -10,7 +10,7 @@ from faker.providers import BaseProvider
from nio.crypto import OlmAccount from nio.crypto import OlmAccount
from nio.store import SqliteStore from nio.store import SqliteStore
from pantalaimon.store import PanStore from pantalaimon.store import ClientInfo, PanStore
faker = Faker() faker = Faker()
@ -25,6 +25,9 @@ class Provider(BaseProvider):
def access_token(self): def access_token(self):
return "MDA" + "".join(choices(digits + ascii_letters, k=272)) return "MDA" + "".join(choices(digits + ascii_letters, k=272))
def client(self):
return ClientInfo(faker.mx_id(), faker.access_token())
faker.add_provider(Provider) faker.add_provider(Provider)
@ -34,6 +37,11 @@ def access_token():
return faker.access_token() return faker.access_token()
@pytest.fixture
def client():
return faker.client()
@pytest.fixture @pytest.fixture
def tempdir(): def tempdir():
newpath = tempfile.mkdtemp() newpath = tempfile.mkdtemp()
@ -73,3 +81,17 @@ class TestClass(object):
token = panstore.load_access_token(user_id, device_id) token = panstore.load_access_token(user_id, device_id)
access_token == token access_token == token
def test_child_clinets_storing(self, panstore, client):
clients = panstore.load_clients()
assert not clients
panstore.save_client(client)
clients = panstore.load_clients()
assert clients
client2 = faker.client()
panstore.save_client(client2)
clients = panstore.load_clients()
assert len(clients) == 2