diff --git a/pantalaimon/daemon.py b/pantalaimon/daemon.py index 5113a0d..7233d42 100755 --- a/pantalaimon/daemon.py +++ b/pantalaimon/daemon.py @@ -32,6 +32,7 @@ class ProxyDaemon: ssl = attr.ib(default=None) store = attr.ib(type=PanStore, init=False) + homeserver_url = attr.ib(init=False, default=attr.Factory(dict)) pan_clients = attr.ib(init=False, default=attr.Factory(dict)) client_info = attr.ib( init=False, @@ -42,10 +43,12 @@ class ProxyDaemon: database_name = "pan.db" def __attrs_post_init__(self): + self.homeserver_url = self.homeserver.geturl() + self.hostname = self.homeserver.hostname self.store = PanStore(self.data_dir) - accounts = self.store.get_users() + accounts = self.store.load_users(self.hostname) - self.client_info = self.store.load_clients() + self.client_info = self.store.load_clients(self.hostname) for user_id, device_id in accounts: token = self.store.load_access_token(user_id, device_id) @@ -58,7 +61,7 @@ class ProxyDaemon: logger.info(f"Restoring client for {user_id} {device_id}") pan_client = PanClient( - self.homeserver, + self.homeserver_url, user_id, device_id, store_path=self.data_dir, @@ -126,7 +129,7 @@ class ProxyDaemon: return await session.request( method, - self.homeserver + path, + self.homeserver_url + path, data=data, params=params, headers=headers, @@ -158,7 +161,8 @@ class ProxyDaemon: async def start_pan_client(self, access_token, user, user_id, password): client = ClientInfo(user_id, access_token) self.client_info[access_token] = client - self.store.save_client(client) + self.store.save_client(self.hostname, client) + self.store.save_server_user(self.hostname, user_id) if user_id in self.pan_clients: logger.info(f"Background sync client already exists for {user_id}," @@ -166,7 +170,7 @@ class ProxyDaemon: return pan_client = PanClient( - self.homeserver, + self.homeserver_url, user, store_path=self.data_dir, ssl=self.ssl, @@ -479,7 +483,7 @@ def cli(): def _find_device(user): data_dir = user_data_dir("pantalaimon", "") store = PanStore(data_dir) - accounts = store.get_users() + accounts = store.load_all_users() for user_id, device in accounts: if user == user_id: @@ -549,7 +553,7 @@ def keys_export(user, outfile, passphrase): def list_users(): data_dir = user_data_dir("pantalaimon", "") store = PanStore(data_dir) - accounts = store.get_users() + accounts = store.load_all_users() click.echo(f"Pantalaimon users:") for user, device in accounts: @@ -617,7 +621,7 @@ def start( loop = asyncio.get_event_loop() proxy, app = loop.run_until_complete(init( - homeserver.geturl(), + homeserver, proxy.geturl() if proxy else None, ssl )) diff --git a/pantalaimon/store.py b/pantalaimon/store.py index 79b9e38..6d67ce7 100644 --- a/pantalaimon/store.py +++ b/pantalaimon/store.py @@ -17,12 +17,38 @@ class AccessTokens(Model): ) +class Servers(Model): + hostname = TextField() + + class Meta: + constraints = [SQL("UNIQUE(hostname)")] + + +class ServerUsers(Model): + user_id = TextField() + server = ForeignKeyField( + model=Servers, + column_name="server_id", + backref="users", + on_delete="CASCADE" + ) + + class Meta: + constraints = [SQL("UNIQUE(user_id,server_id)")] + + class Clients(Model): user_id = TextField() token = TextField() + server = ForeignKeyField( + model=Servers, + column_name="server_id", + backref="clients", + on_delete="CASCADE" + ) class Meta: - constraints = [SQL("UNIQUE(user_id,token)")] + constraints = [SQL("UNIQUE(user_id,token,server_id)")] @attr.s @@ -37,7 +63,7 @@ class PanStore: database_name = attr.ib(type=str, default="pan.db") database = attr.ib(type=SqliteDatabase, init=False) database_path = attr.ib(type=str, init=False) - models = [Accounts, AccessTokens, Clients] + models = [Accounts, AccessTokens, Clients, Servers, ServerUsers] def __attrs_post_init__(self): self.database_path = os.path.join( @@ -71,8 +97,17 @@ class PanStore: return None @use_database - def get_users(self): - # type: () -> List[Tuple[str, str]] + def save_server_user(self, homeserver, user_id): + # type: (ClientInfo) -> None + server, _ = Servers.get_or_create(hostname=homeserver) + + ServerUsers.replace( + user_id=user_id, + server=server + ).execute() + + @use_database + def load_all_users(self): users = [] query = Accounts.select( @@ -85,6 +120,31 @@ class PanStore: return users + @use_database + def load_users(self, homeserver): + # type: () -> List[Tuple[str, str]] + users = [] + + server = Servers.get_or_none(Servers.hostname == homeserver) + + if not server: + return [] + + server_users = [] + + for u in server.users: + server_users.append(u.user_id) + + query = Accounts.select( + Accounts.user_id, + Accounts.device_id, + ).where(Accounts.user_id.in_(server_users)) + + for account in query: + users.append((account.user_id, account.device_id)) + + return users + @use_database def save_access_token(self, user_id, device_id, access_token): account = self._get_account(user_id, device_id) @@ -109,21 +169,24 @@ class PanStore: return None @use_database - def save_client(self, client): + def save_client(self, homeserver, client): # type: (ClientInfo) -> None + server, _ = Servers.get_or_create(hostname=homeserver) + Clients.replace( user_id=client.user_id, - token=client.access_token + token=client.access_token, + server=server.id ).execute() @use_database - def load_clients(self): + def load_clients(self, homeserver): # type: () -> Dict[str, ClientInfo] clients = dict() - query = Clients.select() + server, _ = Servers.get_or_create(hostname=homeserver) - for c in query: + for c in server.clients: client = ClientInfo(c.user_id, c.token) clients[c.token] = client diff --git a/tests/store_test.py b/tests/store_test.py index 291ef13..08f4057 100644 --- a/tests/store_test.py +++ b/tests/store_test.py @@ -68,12 +68,12 @@ def panstore(tempdir): class TestClass(object): def test_account_loading(self, panstore): - accounts = panstore.get_users() + accounts = panstore.load_all_users() # pdb.set_trace() assert len(accounts) == 10 def test_token_saving(self, panstore, access_token): - accounts = panstore.get_users() + accounts = panstore.load_all_users() user_id = accounts[0][0] device_id = accounts[0][1] @@ -83,15 +83,32 @@ class TestClass(object): access_token == token def test_child_clinets_storing(self, panstore, client): - clients = panstore.load_clients() + server = faker.hostname() + clients = panstore.load_clients(server) assert not clients - panstore.save_client(client) + panstore.save_client(server, client) - clients = panstore.load_clients() + clients = panstore.load_clients(server) assert clients client2 = faker.client() - panstore.save_client(client2) - clients = panstore.load_clients() + panstore.save_client(server, client2) + clients = panstore.load_clients(server) assert len(clients) == 2 + + def test_server_account_storing(self, panstore): + accounts = panstore.load_all_users() + + user_id, device_id = accounts[0] + server = faker.hostname() + + panstore.save_server_user(server, user_id) + + server2 = faker.hostname() + user_id2, device_id2 = accounts[1] + panstore.save_server_user(server2, user_id2) + + server_users = panstore.load_users(server) + assert (user_id, device_id) in server_users + assert (user_id2, device_id2) not in server_users