mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-02-07 02:25:23 -05:00
daemon: Store and restore client info of our children.
This commit is contained in:
parent
700510aa36
commit
1bbf38e240
@ -20,13 +20,7 @@ from nio import GroupEncryptionError, LoginResponse
|
||||
|
||||
from pantalaimon.client import PanClient
|
||||
from pantalaimon.log import logger
|
||||
from pantalaimon.store import PanStore
|
||||
|
||||
|
||||
@attr.s
|
||||
class Client:
|
||||
user_id = attr.ib(type=str)
|
||||
access_token = attr.ib(type=str)
|
||||
from pantalaimon.store import ClientInfo, PanStore
|
||||
|
||||
|
||||
@attr.s
|
||||
@ -50,6 +44,8 @@ class ProxyDaemon:
|
||||
self.store = PanStore(self.data_dir)
|
||||
accounts = self.store.get_users()
|
||||
|
||||
self.client_info = self.store.load_clients()
|
||||
|
||||
for user_id, device_id in accounts:
|
||||
token = self.store.load_access_token(user_id, device_id)
|
||||
|
||||
@ -159,8 +155,9 @@ class ProxyDaemon:
|
||||
return user
|
||||
|
||||
async def start_pan_client(self, access_token, user, user_id, password):
|
||||
client = Client(user_id, access_token)
|
||||
client = ClientInfo(user_id, access_token)
|
||||
self.client_info[access_token] = client
|
||||
self.store.save_client(client)
|
||||
|
||||
if user_id in self.pan_clients:
|
||||
logger.info(f"Background sync client already exists for {user_id},"
|
||||
|
@ -1,9 +1,9 @@
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from nio.store import Accounts, use_database
|
||||
from peewee import (DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
|
||||
from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
|
||||
TextField)
|
||||
|
||||
|
||||
@ -17,13 +17,27 @@ class AccessTokens(Model):
|
||||
)
|
||||
|
||||
|
||||
class Clients(Model):
|
||||
user_id = TextField()
|
||||
token = TextField()
|
||||
|
||||
class Meta:
|
||||
constraints = [SQL("UNIQUE(user_id,token)")]
|
||||
|
||||
|
||||
@attr.s
|
||||
class ClientInfo:
|
||||
user_id = attr.ib(type=str)
|
||||
access_token = attr.ib(type=str)
|
||||
|
||||
|
||||
@attr.s
|
||||
class PanStore:
|
||||
store_path = attr.ib(type=str)
|
||||
database_name = attr.ib(type=str, default="pan.db")
|
||||
database = attr.ib(type=SqliteDatabase, init=False)
|
||||
database_path = attr.ib(type=str, init=False)
|
||||
models = [Accounts, AccessTokens]
|
||||
models = [Accounts, AccessTokens, Clients]
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
self.database_path = os.path.join(
|
||||
@ -93,3 +107,24 @@ class PanStore:
|
||||
return account.access_token[0].token
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
@use_database
|
||||
def save_client(self, client):
|
||||
# type: (ClientInfo) -> None
|
||||
Clients.replace(
|
||||
user_id=client.user_id,
|
||||
token=client.access_token
|
||||
).execute()
|
||||
|
||||
@use_database
|
||||
def load_clients(self):
|
||||
# type: () -> Dict[str, ClientInfo]
|
||||
clients = dict()
|
||||
|
||||
query = Clients.select()
|
||||
|
||||
for c in query:
|
||||
client = ClientInfo(c.user_id, c.token)
|
||||
clients[c.token] = client
|
||||
|
||||
return clients
|
||||
|
@ -10,7 +10,7 @@ from faker.providers import BaseProvider
|
||||
from nio.crypto import OlmAccount
|
||||
from nio.store import SqliteStore
|
||||
|
||||
from pantalaimon.store import PanStore
|
||||
from pantalaimon.store import ClientInfo, PanStore
|
||||
|
||||
faker = Faker()
|
||||
|
||||
@ -25,6 +25,9 @@ class Provider(BaseProvider):
|
||||
def access_token(self):
|
||||
return "MDA" + "".join(choices(digits + ascii_letters, k=272))
|
||||
|
||||
def client(self):
|
||||
return ClientInfo(faker.mx_id(), faker.access_token())
|
||||
|
||||
|
||||
faker.add_provider(Provider)
|
||||
|
||||
@ -34,6 +37,11 @@ def access_token():
|
||||
return faker.access_token()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return faker.client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tempdir():
|
||||
newpath = tempfile.mkdtemp()
|
||||
@ -73,3 +81,17 @@ class TestClass(object):
|
||||
|
||||
token = panstore.load_access_token(user_id, device_id)
|
||||
access_token == token
|
||||
|
||||
def test_child_clinets_storing(self, panstore, client):
|
||||
clients = panstore.load_clients()
|
||||
assert not clients
|
||||
|
||||
panstore.save_client(client)
|
||||
|
||||
clients = panstore.load_clients()
|
||||
assert clients
|
||||
|
||||
client2 = faker.client()
|
||||
panstore.save_client(client2)
|
||||
clients = panstore.load_clients()
|
||||
assert len(clients) == 2
|
||||
|
Loading…
x
Reference in New Issue
Block a user