store: The daemon clients need to be per homeserver.

This commit is contained in:
Damir Jelić 2019-04-12 17:59:30 +02:00
parent d1090a714a
commit ea33359daa
3 changed files with 109 additions and 25 deletions

View File

@ -32,6 +32,7 @@ class ProxyDaemon:
ssl = attr.ib(default=None) ssl = attr.ib(default=None)
store = attr.ib(type=PanStore, init=False) store = attr.ib(type=PanStore, init=False)
homeserver_url = attr.ib(init=False, default=attr.Factory(dict))
pan_clients = attr.ib(init=False, default=attr.Factory(dict)) pan_clients = attr.ib(init=False, default=attr.Factory(dict))
client_info = attr.ib( client_info = attr.ib(
init=False, init=False,
@ -42,10 +43,12 @@ class ProxyDaemon:
database_name = "pan.db" database_name = "pan.db"
def __attrs_post_init__(self): def __attrs_post_init__(self):
self.homeserver_url = self.homeserver.geturl()
self.hostname = self.homeserver.hostname
self.store = PanStore(self.data_dir) self.store = PanStore(self.data_dir)
accounts = self.store.get_users() accounts = self.store.load_users(self.hostname)
self.client_info = self.store.load_clients() self.client_info = self.store.load_clients(self.hostname)
for user_id, device_id in accounts: for user_id, device_id in accounts:
token = self.store.load_access_token(user_id, device_id) token = self.store.load_access_token(user_id, device_id)
@ -58,7 +61,7 @@ class ProxyDaemon:
logger.info(f"Restoring client for {user_id} {device_id}") logger.info(f"Restoring client for {user_id} {device_id}")
pan_client = PanClient( pan_client = PanClient(
self.homeserver, self.homeserver_url,
user_id, user_id,
device_id, device_id,
store_path=self.data_dir, store_path=self.data_dir,
@ -126,7 +129,7 @@ class ProxyDaemon:
return await session.request( return await session.request(
method, method,
self.homeserver + path, self.homeserver_url + path,
data=data, data=data,
params=params, params=params,
headers=headers, headers=headers,
@ -158,7 +161,8 @@ class ProxyDaemon:
async def start_pan_client(self, access_token, user, user_id, password): async def start_pan_client(self, access_token, user, user_id, password):
client = ClientInfo(user_id, access_token) client = ClientInfo(user_id, access_token)
self.client_info[access_token] = client self.client_info[access_token] = client
self.store.save_client(client) self.store.save_client(self.hostname, client)
self.store.save_server_user(self.hostname, user_id)
if user_id in self.pan_clients: if user_id in self.pan_clients:
logger.info(f"Background sync client already exists for {user_id}," logger.info(f"Background sync client already exists for {user_id},"
@ -166,7 +170,7 @@ class ProxyDaemon:
return return
pan_client = PanClient( pan_client = PanClient(
self.homeserver, self.homeserver_url,
user, user,
store_path=self.data_dir, store_path=self.data_dir,
ssl=self.ssl, ssl=self.ssl,
@ -479,7 +483,7 @@ def cli():
def _find_device(user): def _find_device(user):
data_dir = user_data_dir("pantalaimon", "") data_dir = user_data_dir("pantalaimon", "")
store = PanStore(data_dir) store = PanStore(data_dir)
accounts = store.get_users() accounts = store.load_all_users()
for user_id, device in accounts: for user_id, device in accounts:
if user == user_id: if user == user_id:
@ -549,7 +553,7 @@ def keys_export(user, outfile, passphrase):
def list_users(): def list_users():
data_dir = user_data_dir("pantalaimon", "") data_dir = user_data_dir("pantalaimon", "")
store = PanStore(data_dir) store = PanStore(data_dir)
accounts = store.get_users() accounts = store.load_all_users()
click.echo(f"Pantalaimon users:") click.echo(f"Pantalaimon users:")
for user, device in accounts: for user, device in accounts:
@ -617,7 +621,7 @@ def start(
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
proxy, app = loop.run_until_complete(init( proxy, app = loop.run_until_complete(init(
homeserver.geturl(), homeserver,
proxy.geturl() if proxy else None, proxy.geturl() if proxy else None,
ssl ssl
)) ))

View File

@ -17,12 +17,38 @@ class AccessTokens(Model):
) )
class Servers(Model):
hostname = TextField()
class Meta:
constraints = [SQL("UNIQUE(hostname)")]
class ServerUsers(Model):
user_id = TextField()
server = ForeignKeyField(
model=Servers,
column_name="server_id",
backref="users",
on_delete="CASCADE"
)
class Meta:
constraints = [SQL("UNIQUE(user_id,server_id)")]
class Clients(Model): class Clients(Model):
user_id = TextField() user_id = TextField()
token = TextField() token = TextField()
server = ForeignKeyField(
model=Servers,
column_name="server_id",
backref="clients",
on_delete="CASCADE"
)
class Meta: class Meta:
constraints = [SQL("UNIQUE(user_id,token)")] constraints = [SQL("UNIQUE(user_id,token,server_id)")]
@attr.s @attr.s
@ -37,7 +63,7 @@ class PanStore:
database_name = attr.ib(type=str, default="pan.db") database_name = attr.ib(type=str, default="pan.db")
database = attr.ib(type=SqliteDatabase, init=False) database = attr.ib(type=SqliteDatabase, init=False)
database_path = attr.ib(type=str, init=False) database_path = attr.ib(type=str, init=False)
models = [Accounts, AccessTokens, Clients] models = [Accounts, AccessTokens, Clients, Servers, ServerUsers]
def __attrs_post_init__(self): def __attrs_post_init__(self):
self.database_path = os.path.join( self.database_path = os.path.join(
@ -71,8 +97,17 @@ class PanStore:
return None return None
@use_database @use_database
def get_users(self): def save_server_user(self, homeserver, user_id):
# type: () -> List[Tuple[str, str]] # type: (ClientInfo) -> None
server, _ = Servers.get_or_create(hostname=homeserver)
ServerUsers.replace(
user_id=user_id,
server=server
).execute()
@use_database
def load_all_users(self):
users = [] users = []
query = Accounts.select( query = Accounts.select(
@ -85,6 +120,31 @@ class PanStore:
return users return users
@use_database
def load_users(self, homeserver):
# type: () -> List[Tuple[str, str]]
users = []
server = Servers.get_or_none(Servers.hostname == homeserver)
if not server:
return []
server_users = []
for u in server.users:
server_users.append(u.user_id)
query = Accounts.select(
Accounts.user_id,
Accounts.device_id,
).where(Accounts.user_id.in_(server_users))
for account in query:
users.append((account.user_id, account.device_id))
return users
@use_database @use_database
def save_access_token(self, user_id, device_id, access_token): def save_access_token(self, user_id, device_id, access_token):
account = self._get_account(user_id, device_id) account = self._get_account(user_id, device_id)
@ -109,21 +169,24 @@ class PanStore:
return None return None
@use_database @use_database
def save_client(self, client): def save_client(self, homeserver, client):
# type: (ClientInfo) -> None # type: (ClientInfo) -> None
server, _ = Servers.get_or_create(hostname=homeserver)
Clients.replace( Clients.replace(
user_id=client.user_id, user_id=client.user_id,
token=client.access_token token=client.access_token,
server=server.id
).execute() ).execute()
@use_database @use_database
def load_clients(self): def load_clients(self, homeserver):
# type: () -> Dict[str, ClientInfo] # type: () -> Dict[str, ClientInfo]
clients = dict() clients = dict()
query = Clients.select() server, _ = Servers.get_or_create(hostname=homeserver)
for c in query: for c in server.clients:
client = ClientInfo(c.user_id, c.token) client = ClientInfo(c.user_id, c.token)
clients[c.token] = client clients[c.token] = client

View File

@ -68,12 +68,12 @@ def panstore(tempdir):
class TestClass(object): class TestClass(object):
def test_account_loading(self, panstore): def test_account_loading(self, panstore):
accounts = panstore.get_users() accounts = panstore.load_all_users()
# pdb.set_trace() # pdb.set_trace()
assert len(accounts) == 10 assert len(accounts) == 10
def test_token_saving(self, panstore, access_token): def test_token_saving(self, panstore, access_token):
accounts = panstore.get_users() accounts = panstore.load_all_users()
user_id = accounts[0][0] user_id = accounts[0][0]
device_id = accounts[0][1] device_id = accounts[0][1]
@ -83,15 +83,32 @@ class TestClass(object):
access_token == token access_token == token
def test_child_clinets_storing(self, panstore, client): def test_child_clinets_storing(self, panstore, client):
clients = panstore.load_clients() server = faker.hostname()
clients = panstore.load_clients(server)
assert not clients assert not clients
panstore.save_client(client) panstore.save_client(server, client)
clients = panstore.load_clients() clients = panstore.load_clients(server)
assert clients assert clients
client2 = faker.client() client2 = faker.client()
panstore.save_client(client2) panstore.save_client(server, client2)
clients = panstore.load_clients() clients = panstore.load_clients(server)
assert len(clients) == 2 assert len(clients) == 2
def test_server_account_storing(self, panstore):
accounts = panstore.load_all_users()
user_id, device_id = accounts[0]
server = faker.hostname()
panstore.save_server_user(server, user_id)
server2 = faker.hostname()
user_id2, device_id2 = accounts[1]
panstore.save_server_user(server2, user_id2)
server_users = panstore.load_users(server)
assert (user_id, device_id) in server_users
assert (user_id2, device_id2) not in server_users