store: Use the server name to store info.

This way multiple proxy servers that point to the same homeserver are
allowed.
This commit is contained in:
Damir Jelić 2019-05-17 16:49:02 +02:00
parent e0b5d3e2b6
commit 4f90e578ef
3 changed files with 15 additions and 17 deletions

View File

@ -56,9 +56,9 @@ class ProxyDaemon:
self.homeserver_url = self.homeserver.geturl() self.homeserver_url = self.homeserver.geturl()
self.hostname = self.homeserver.hostname self.hostname = self.homeserver.hostname
self.store = PanStore(self.data_dir) self.store = PanStore(self.data_dir)
accounts = self.store.load_users(self.hostname) accounts = self.store.load_users(self.name)
self.client_info = self.store.load_clients(self.hostname) self.client_info = self.store.load_clients(self.name)
for user_id, device_id in accounts: for user_id, device_id in accounts:
token = keyring.get_password( token = keyring.get_password(
@ -421,8 +421,8 @@ class ProxyDaemon:
async def start_pan_client(self, access_token, user, user_id, password): async def start_pan_client(self, access_token, user, user_id, password):
client = ClientInfo(user_id, access_token) client = ClientInfo(user_id, access_token)
self.client_info[access_token] = client self.client_info[access_token] = client
self.store.save_client(self.hostname, client) self.store.save_client(self.name, client)
self.store.save_server_user(self.hostname, user_id) self.store.save_server_user(self.name, user_id)
if user_id in self.pan_clients: if user_id in self.pan_clients:
logger.info(f"Background sync client already exists for {user_id}," logger.info(f"Background sync client already exists for {user_id},"

View File

@ -20,10 +20,10 @@ class AccessTokens(Model):
class Servers(Model): class Servers(Model):
hostname = TextField() name = TextField()
class Meta: class Meta:
constraints = [SQL("UNIQUE(hostname)")] constraints = [SQL("UNIQUE(name)")]
class ServerUsers(Model): class ServerUsers(Model):
@ -116,9 +116,9 @@ class PanStore:
return None return None
@use_database @use_database
def save_server_user(self, homeserver, user_id): def save_server_user(self, server_name, user_id):
# type: (ClientInfo) -> None # type: (ClientInfo) -> None
server, _ = Servers.get_or_create(hostname=homeserver) server, _ = Servers.get_or_create(name=server_name)
ServerUsers.replace( ServerUsers.replace(
user_id=user_id, user_id=user_id,
@ -140,11 +140,11 @@ class PanStore:
return users return users
@use_database @use_database
def load_users(self, homeserver): def load_users(self, server_name):
# type: () -> List[Tuple[str, str]] # type: () -> List[Tuple[str, str]]
users = [] users = []
server = Servers.get_or_none(Servers.hostname == homeserver) server = Servers.get_or_none(Servers.name == server_name)
if not server: if not server:
return [] return []
@ -188,9 +188,9 @@ class PanStore:
return None return None
@use_database @use_database
def save_client(self, homeserver, client): def save_client(self, server_name, client):
# type: (ClientInfo) -> None # type: (ClientInfo) -> None
server, _ = Servers.get_or_create(hostname=homeserver) server, _ = Servers.get_or_create(name=server_name)
Clients.replace( Clients.replace(
user_id=client.user_id, user_id=client.user_id,
@ -199,11 +199,11 @@ class PanStore:
).execute() ).execute()
@use_database @use_database
def load_clients(self, homeserver): def load_clients(self, server_name):
# type: () -> Dict[str, ClientInfo] # type: () -> Dict[str, ClientInfo]
clients = dict() clients = dict()
server, _ = Servers.get_or_create(hostname=homeserver) server, _ = Servers.get_or_create(name=server_name)
for c in server.clients: for c in server.clients:
client = ClientInfo(c.user_id, c.token) client = ClientInfo(c.user_id, c.token)

View File

@ -76,9 +76,7 @@ class Control:
def update_users(self): def update_users(self):
for server in self.server_list: for server in self.server_list:
self.users[server.name] = self.store.load_users( self.users[server.name] = self.store.load_users(server.name)
server.homeserver.hostname
)
@property @property
def message_id(self): def message_id(self):