mirror of
https://github.com/matrix-org/pantalaimon.git
synced 2025-01-24 06:11:16 -05:00
store: The daemon clients need to be per homeserver.
This commit is contained in:
parent
d1090a714a
commit
ea33359daa
@ -32,6 +32,7 @@ class ProxyDaemon:
|
|||||||
ssl = attr.ib(default=None)
|
ssl = attr.ib(default=None)
|
||||||
|
|
||||||
store = attr.ib(type=PanStore, init=False)
|
store = attr.ib(type=PanStore, init=False)
|
||||||
|
homeserver_url = attr.ib(init=False, default=attr.Factory(dict))
|
||||||
pan_clients = attr.ib(init=False, default=attr.Factory(dict))
|
pan_clients = attr.ib(init=False, default=attr.Factory(dict))
|
||||||
client_info = attr.ib(
|
client_info = attr.ib(
|
||||||
init=False,
|
init=False,
|
||||||
@ -42,10 +43,12 @@ class ProxyDaemon:
|
|||||||
database_name = "pan.db"
|
database_name = "pan.db"
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self):
|
||||||
|
self.homeserver_url = self.homeserver.geturl()
|
||||||
|
self.hostname = self.homeserver.hostname
|
||||||
self.store = PanStore(self.data_dir)
|
self.store = PanStore(self.data_dir)
|
||||||
accounts = self.store.get_users()
|
accounts = self.store.load_users(self.hostname)
|
||||||
|
|
||||||
self.client_info = self.store.load_clients()
|
self.client_info = self.store.load_clients(self.hostname)
|
||||||
|
|
||||||
for user_id, device_id in accounts:
|
for user_id, device_id in accounts:
|
||||||
token = self.store.load_access_token(user_id, device_id)
|
token = self.store.load_access_token(user_id, device_id)
|
||||||
@ -58,7 +61,7 @@ class ProxyDaemon:
|
|||||||
logger.info(f"Restoring client for {user_id} {device_id}")
|
logger.info(f"Restoring client for {user_id} {device_id}")
|
||||||
|
|
||||||
pan_client = PanClient(
|
pan_client = PanClient(
|
||||||
self.homeserver,
|
self.homeserver_url,
|
||||||
user_id,
|
user_id,
|
||||||
device_id,
|
device_id,
|
||||||
store_path=self.data_dir,
|
store_path=self.data_dir,
|
||||||
@ -126,7 +129,7 @@ class ProxyDaemon:
|
|||||||
|
|
||||||
return await session.request(
|
return await session.request(
|
||||||
method,
|
method,
|
||||||
self.homeserver + path,
|
self.homeserver_url + path,
|
||||||
data=data,
|
data=data,
|
||||||
params=params,
|
params=params,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
@ -158,7 +161,8 @@ class ProxyDaemon:
|
|||||||
async def start_pan_client(self, access_token, user, user_id, password):
|
async def start_pan_client(self, access_token, user, user_id, password):
|
||||||
client = ClientInfo(user_id, access_token)
|
client = ClientInfo(user_id, access_token)
|
||||||
self.client_info[access_token] = client
|
self.client_info[access_token] = client
|
||||||
self.store.save_client(client)
|
self.store.save_client(self.hostname, client)
|
||||||
|
self.store.save_server_user(self.hostname, user_id)
|
||||||
|
|
||||||
if user_id in self.pan_clients:
|
if user_id in self.pan_clients:
|
||||||
logger.info(f"Background sync client already exists for {user_id},"
|
logger.info(f"Background sync client already exists for {user_id},"
|
||||||
@ -166,7 +170,7 @@ class ProxyDaemon:
|
|||||||
return
|
return
|
||||||
|
|
||||||
pan_client = PanClient(
|
pan_client = PanClient(
|
||||||
self.homeserver,
|
self.homeserver_url,
|
||||||
user,
|
user,
|
||||||
store_path=self.data_dir,
|
store_path=self.data_dir,
|
||||||
ssl=self.ssl,
|
ssl=self.ssl,
|
||||||
@ -479,7 +483,7 @@ def cli():
|
|||||||
def _find_device(user):
|
def _find_device(user):
|
||||||
data_dir = user_data_dir("pantalaimon", "")
|
data_dir = user_data_dir("pantalaimon", "")
|
||||||
store = PanStore(data_dir)
|
store = PanStore(data_dir)
|
||||||
accounts = store.get_users()
|
accounts = store.load_all_users()
|
||||||
|
|
||||||
for user_id, device in accounts:
|
for user_id, device in accounts:
|
||||||
if user == user_id:
|
if user == user_id:
|
||||||
@ -549,7 +553,7 @@ def keys_export(user, outfile, passphrase):
|
|||||||
def list_users():
|
def list_users():
|
||||||
data_dir = user_data_dir("pantalaimon", "")
|
data_dir = user_data_dir("pantalaimon", "")
|
||||||
store = PanStore(data_dir)
|
store = PanStore(data_dir)
|
||||||
accounts = store.get_users()
|
accounts = store.load_all_users()
|
||||||
|
|
||||||
click.echo(f"Pantalaimon users:")
|
click.echo(f"Pantalaimon users:")
|
||||||
for user, device in accounts:
|
for user, device in accounts:
|
||||||
@ -617,7 +621,7 @@ def start(
|
|||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
proxy, app = loop.run_until_complete(init(
|
proxy, app = loop.run_until_complete(init(
|
||||||
homeserver.geturl(),
|
homeserver,
|
||||||
proxy.geturl() if proxy else None,
|
proxy.geturl() if proxy else None,
|
||||||
ssl
|
ssl
|
||||||
))
|
))
|
||||||
|
@ -17,12 +17,38 @@ class AccessTokens(Model):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Servers(Model):
|
||||||
|
hostname = TextField()
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
constraints = [SQL("UNIQUE(hostname)")]
|
||||||
|
|
||||||
|
|
||||||
|
class ServerUsers(Model):
|
||||||
|
user_id = TextField()
|
||||||
|
server = ForeignKeyField(
|
||||||
|
model=Servers,
|
||||||
|
column_name="server_id",
|
||||||
|
backref="users",
|
||||||
|
on_delete="CASCADE"
|
||||||
|
)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
constraints = [SQL("UNIQUE(user_id,server_id)")]
|
||||||
|
|
||||||
|
|
||||||
class Clients(Model):
|
class Clients(Model):
|
||||||
user_id = TextField()
|
user_id = TextField()
|
||||||
token = TextField()
|
token = TextField()
|
||||||
|
server = ForeignKeyField(
|
||||||
|
model=Servers,
|
||||||
|
column_name="server_id",
|
||||||
|
backref="clients",
|
||||||
|
on_delete="CASCADE"
|
||||||
|
)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
constraints = [SQL("UNIQUE(user_id,token)")]
|
constraints = [SQL("UNIQUE(user_id,token,server_id)")]
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
@attr.s
|
||||||
@ -37,7 +63,7 @@ class PanStore:
|
|||||||
database_name = attr.ib(type=str, default="pan.db")
|
database_name = attr.ib(type=str, default="pan.db")
|
||||||
database = attr.ib(type=SqliteDatabase, init=False)
|
database = attr.ib(type=SqliteDatabase, init=False)
|
||||||
database_path = attr.ib(type=str, init=False)
|
database_path = attr.ib(type=str, init=False)
|
||||||
models = [Accounts, AccessTokens, Clients]
|
models = [Accounts, AccessTokens, Clients, Servers, ServerUsers]
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self):
|
||||||
self.database_path = os.path.join(
|
self.database_path = os.path.join(
|
||||||
@ -71,8 +97,17 @@ class PanStore:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@use_database
|
@use_database
|
||||||
def get_users(self):
|
def save_server_user(self, homeserver, user_id):
|
||||||
# type: () -> List[Tuple[str, str]]
|
# type: (ClientInfo) -> None
|
||||||
|
server, _ = Servers.get_or_create(hostname=homeserver)
|
||||||
|
|
||||||
|
ServerUsers.replace(
|
||||||
|
user_id=user_id,
|
||||||
|
server=server
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
@use_database
|
||||||
|
def load_all_users(self):
|
||||||
users = []
|
users = []
|
||||||
|
|
||||||
query = Accounts.select(
|
query = Accounts.select(
|
||||||
@ -85,6 +120,31 @@ class PanStore:
|
|||||||
|
|
||||||
return users
|
return users
|
||||||
|
|
||||||
|
@use_database
|
||||||
|
def load_users(self, homeserver):
|
||||||
|
# type: () -> List[Tuple[str, str]]
|
||||||
|
users = []
|
||||||
|
|
||||||
|
server = Servers.get_or_none(Servers.hostname == homeserver)
|
||||||
|
|
||||||
|
if not server:
|
||||||
|
return []
|
||||||
|
|
||||||
|
server_users = []
|
||||||
|
|
||||||
|
for u in server.users:
|
||||||
|
server_users.append(u.user_id)
|
||||||
|
|
||||||
|
query = Accounts.select(
|
||||||
|
Accounts.user_id,
|
||||||
|
Accounts.device_id,
|
||||||
|
).where(Accounts.user_id.in_(server_users))
|
||||||
|
|
||||||
|
for account in query:
|
||||||
|
users.append((account.user_id, account.device_id))
|
||||||
|
|
||||||
|
return users
|
||||||
|
|
||||||
@use_database
|
@use_database
|
||||||
def save_access_token(self, user_id, device_id, access_token):
|
def save_access_token(self, user_id, device_id, access_token):
|
||||||
account = self._get_account(user_id, device_id)
|
account = self._get_account(user_id, device_id)
|
||||||
@ -109,21 +169,24 @@ class PanStore:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@use_database
|
@use_database
|
||||||
def save_client(self, client):
|
def save_client(self, homeserver, client):
|
||||||
# type: (ClientInfo) -> None
|
# type: (ClientInfo) -> None
|
||||||
|
server, _ = Servers.get_or_create(hostname=homeserver)
|
||||||
|
|
||||||
Clients.replace(
|
Clients.replace(
|
||||||
user_id=client.user_id,
|
user_id=client.user_id,
|
||||||
token=client.access_token
|
token=client.access_token,
|
||||||
|
server=server.id
|
||||||
).execute()
|
).execute()
|
||||||
|
|
||||||
@use_database
|
@use_database
|
||||||
def load_clients(self):
|
def load_clients(self, homeserver):
|
||||||
# type: () -> Dict[str, ClientInfo]
|
# type: () -> Dict[str, ClientInfo]
|
||||||
clients = dict()
|
clients = dict()
|
||||||
|
|
||||||
query = Clients.select()
|
server, _ = Servers.get_or_create(hostname=homeserver)
|
||||||
|
|
||||||
for c in query:
|
for c in server.clients:
|
||||||
client = ClientInfo(c.user_id, c.token)
|
client = ClientInfo(c.user_id, c.token)
|
||||||
clients[c.token] = client
|
clients[c.token] = client
|
||||||
|
|
||||||
|
@ -68,12 +68,12 @@ def panstore(tempdir):
|
|||||||
|
|
||||||
class TestClass(object):
|
class TestClass(object):
|
||||||
def test_account_loading(self, panstore):
|
def test_account_loading(self, panstore):
|
||||||
accounts = panstore.get_users()
|
accounts = panstore.load_all_users()
|
||||||
# pdb.set_trace()
|
# pdb.set_trace()
|
||||||
assert len(accounts) == 10
|
assert len(accounts) == 10
|
||||||
|
|
||||||
def test_token_saving(self, panstore, access_token):
|
def test_token_saving(self, panstore, access_token):
|
||||||
accounts = panstore.get_users()
|
accounts = panstore.load_all_users()
|
||||||
user_id = accounts[0][0]
|
user_id = accounts[0][0]
|
||||||
device_id = accounts[0][1]
|
device_id = accounts[0][1]
|
||||||
|
|
||||||
@ -83,15 +83,32 @@ class TestClass(object):
|
|||||||
access_token == token
|
access_token == token
|
||||||
|
|
||||||
def test_child_clinets_storing(self, panstore, client):
|
def test_child_clinets_storing(self, panstore, client):
|
||||||
clients = panstore.load_clients()
|
server = faker.hostname()
|
||||||
|
clients = panstore.load_clients(server)
|
||||||
assert not clients
|
assert not clients
|
||||||
|
|
||||||
panstore.save_client(client)
|
panstore.save_client(server, client)
|
||||||
|
|
||||||
clients = panstore.load_clients()
|
clients = panstore.load_clients(server)
|
||||||
assert clients
|
assert clients
|
||||||
|
|
||||||
client2 = faker.client()
|
client2 = faker.client()
|
||||||
panstore.save_client(client2)
|
panstore.save_client(server, client2)
|
||||||
clients = panstore.load_clients()
|
clients = panstore.load_clients(server)
|
||||||
assert len(clients) == 2
|
assert len(clients) == 2
|
||||||
|
|
||||||
|
def test_server_account_storing(self, panstore):
|
||||||
|
accounts = panstore.load_all_users()
|
||||||
|
|
||||||
|
user_id, device_id = accounts[0]
|
||||||
|
server = faker.hostname()
|
||||||
|
|
||||||
|
panstore.save_server_user(server, user_id)
|
||||||
|
|
||||||
|
server2 = faker.hostname()
|
||||||
|
user_id2, device_id2 = accounts[1]
|
||||||
|
panstore.save_server_user(server2, user_id2)
|
||||||
|
|
||||||
|
server_users = panstore.load_users(server)
|
||||||
|
assert (user_id, device_id) in server_users
|
||||||
|
assert (user_id2, device_id2) not in server_users
|
||||||
|
Loading…
Reference in New Issue
Block a user