mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-02-11 20:38:49 -05:00
daemon: Store and restore client info of our children.
This commit is contained in:
parent
700510aa36
commit
1bbf38e240
@ -20,13 +20,7 @@ from nio import GroupEncryptionError, LoginResponse
|
|||||||
|
|
||||||
from pantalaimon.client import PanClient
|
from pantalaimon.client import PanClient
|
||||||
from pantalaimon.log import logger
|
from pantalaimon.log import logger
|
||||||
from pantalaimon.store import PanStore
|
from pantalaimon.store import ClientInfo, PanStore
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
|
||||||
class Client:
|
|
||||||
user_id = attr.ib(type=str)
|
|
||||||
access_token = attr.ib(type=str)
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
@ -50,6 +44,8 @@ class ProxyDaemon:
|
|||||||
self.store = PanStore(self.data_dir)
|
self.store = PanStore(self.data_dir)
|
||||||
accounts = self.store.get_users()
|
accounts = self.store.get_users()
|
||||||
|
|
||||||
|
self.client_info = self.store.load_clients()
|
||||||
|
|
||||||
for user_id, device_id in accounts:
|
for user_id, device_id in accounts:
|
||||||
token = self.store.load_access_token(user_id, device_id)
|
token = self.store.load_access_token(user_id, device_id)
|
||||||
|
|
||||||
@ -159,8 +155,9 @@ class ProxyDaemon:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
async def start_pan_client(self, access_token, user, user_id, password):
|
async def start_pan_client(self, access_token, user, user_id, password):
|
||||||
client = Client(user_id, access_token)
|
client = ClientInfo(user_id, access_token)
|
||||||
self.client_info[access_token] = client
|
self.client_info[access_token] = client
|
||||||
|
self.store.save_client(client)
|
||||||
|
|
||||||
if user_id in self.pan_clients:
|
if user_id in self.pan_clients:
|
||||||
logger.info(f"Background sync client already exists for {user_id},"
|
logger.info(f"Background sync client already exists for {user_id},"
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
from typing import List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from nio.store import Accounts, use_database
|
from nio.store import Accounts, use_database
|
||||||
from peewee import (DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
|
from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
|
||||||
TextField)
|
TextField)
|
||||||
|
|
||||||
|
|
||||||
@ -17,13 +17,27 @@ class AccessTokens(Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Clients(Model):
|
||||||
|
user_id = TextField()
|
||||||
|
token = TextField()
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
constraints = [SQL("UNIQUE(user_id,token)")]
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
class ClientInfo:
|
||||||
|
user_id = attr.ib(type=str)
|
||||||
|
access_token = attr.ib(type=str)
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
class PanStore:
|
class PanStore:
|
||||||
store_path = attr.ib(type=str)
|
store_path = attr.ib(type=str)
|
||||||
database_name = attr.ib(type=str, default="pan.db")
|
database_name = attr.ib(type=str, default="pan.db")
|
||||||
database = attr.ib(type=SqliteDatabase, init=False)
|
database = attr.ib(type=SqliteDatabase, init=False)
|
||||||
database_path = attr.ib(type=str, init=False)
|
database_path = attr.ib(type=str, init=False)
|
||||||
models = [Accounts, AccessTokens]
|
models = [Accounts, AccessTokens, Clients]
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self):
|
||||||
self.database_path = os.path.join(
|
self.database_path = os.path.join(
|
||||||
@ -93,3 +107,24 @@ class PanStore:
|
|||||||
return account.access_token[0].token
|
return account.access_token[0].token
|
||||||
except IndexError:
|
except IndexError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@use_database
|
||||||
|
def save_client(self, client):
|
||||||
|
# type: (ClientInfo) -> None
|
||||||
|
Clients.replace(
|
||||||
|
user_id=client.user_id,
|
||||||
|
token=client.access_token
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
@use_database
|
||||||
|
def load_clients(self):
|
||||||
|
# type: () -> Dict[str, ClientInfo]
|
||||||
|
clients = dict()
|
||||||
|
|
||||||
|
query = Clients.select()
|
||||||
|
|
||||||
|
for c in query:
|
||||||
|
client = ClientInfo(c.user_id, c.token)
|
||||||
|
clients[c.token] = client
|
||||||
|
|
||||||
|
return clients
|
||||||
|
@ -10,7 +10,7 @@ from faker.providers import BaseProvider
|
|||||||
from nio.crypto import OlmAccount
|
from nio.crypto import OlmAccount
|
||||||
from nio.store import SqliteStore
|
from nio.store import SqliteStore
|
||||||
|
|
||||||
from pantalaimon.store import PanStore
|
from pantalaimon.store import ClientInfo, PanStore
|
||||||
|
|
||||||
faker = Faker()
|
faker = Faker()
|
||||||
|
|
||||||
@ -25,6 +25,9 @@ class Provider(BaseProvider):
|
|||||||
def access_token(self):
|
def access_token(self):
|
||||||
return "MDA" + "".join(choices(digits + ascii_letters, k=272))
|
return "MDA" + "".join(choices(digits + ascii_letters, k=272))
|
||||||
|
|
||||||
|
def client(self):
|
||||||
|
return ClientInfo(faker.mx_id(), faker.access_token())
|
||||||
|
|
||||||
|
|
||||||
faker.add_provider(Provider)
|
faker.add_provider(Provider)
|
||||||
|
|
||||||
@ -34,6 +37,11 @@ def access_token():
|
|||||||
return faker.access_token()
|
return faker.access_token()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
return faker.client()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tempdir():
|
def tempdir():
|
||||||
newpath = tempfile.mkdtemp()
|
newpath = tempfile.mkdtemp()
|
||||||
@ -73,3 +81,17 @@ class TestClass(object):
|
|||||||
|
|
||||||
token = panstore.load_access_token(user_id, device_id)
|
token = panstore.load_access_token(user_id, device_id)
|
||||||
access_token == token
|
access_token == token
|
||||||
|
|
||||||
|
def test_child_clinets_storing(self, panstore, client):
|
||||||
|
clients = panstore.load_clients()
|
||||||
|
assert not clients
|
||||||
|
|
||||||
|
panstore.save_client(client)
|
||||||
|
|
||||||
|
clients = panstore.load_clients()
|
||||||
|
assert clients
|
||||||
|
|
||||||
|
client2 = faker.client()
|
||||||
|
panstore.save_client(client2)
|
||||||
|
clients = panstore.load_clients()
|
||||||
|
assert len(clients) == 2
|
||||||
|
Loading…
x
Reference in New Issue
Block a user