2019-04-10 06:03:17 -04:00
|
|
|
import os
|
2019-04-11 10:59:37 -04:00
|
|
|
from typing import Dict, List, Optional, Tuple
|
2019-04-10 06:03:17 -04:00
|
|
|
|
2019-04-10 06:21:14 -04:00
|
|
|
import attr
|
2019-04-10 06:03:17 -04:00
|
|
|
from nio.store import Accounts, use_database
|
2019-04-11 10:59:37 -04:00
|
|
|
from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
|
2019-04-10 06:21:14 -04:00
|
|
|
TextField)
|
2019-04-10 06:03:17 -04:00
|
|
|
|
|
|
|
|
|
|
|
class AccessTokens(Model):
    """Access token stored for an account.

    The ``account`` foreign key is also the primary key of the table, so
    there is at most one token row per ``Accounts`` row.
    """

    # The token string itself.
    token = TextField()
    # Owning account; ON DELETE CASCADE is declared on the foreign key and
    # the row is reachable from an account through ``access_token``.
    account = ForeignKeyField(
        model=Accounts,
        primary_key=True,
        backref="access_token",
        on_delete="CASCADE",
    )
|
|
|
|
|
|
|
|
|
2019-04-12 11:59:30 -04:00
|
|
|
class Servers(Model):
    """A homeserver, identified by its hostname."""

    hostname = TextField()

    class Meta:
        # Table-level constraint: a hostname may be stored only once.
        constraints = [SQL("UNIQUE(hostname)")]
|
|
|
|
|
|
|
|
|
|
|
|
class ServerUsers(Model):
    """Association between a user id and the server it belongs to."""

    user_id = TextField()
    # Stored in the ``server_id`` column; the rows of a server are reachable
    # through ``Servers.users`` and the foreign key declares
    # ON DELETE CASCADE.
    server = ForeignKeyField(
        model=Servers,
        column_name="server_id",
        backref="users",
        on_delete="CASCADE",
    )

    class Meta:
        # Table-level constraint: one row per (user, server) pair.
        constraints = [SQL("UNIQUE(user_id,server_id)")]
|
|
|
|
|
|
|
|
|
2019-04-11 10:59:37 -04:00
|
|
|
class Clients(Model):
    """A client, recorded as a user id and token on a given server."""

    user_id = TextField()
    token = TextField()
    # Stored in the ``server_id`` column; the rows of a server are reachable
    # through ``Servers.clients`` and the foreign key declares
    # ON DELETE CASCADE.
    server = ForeignKeyField(
        model=Servers,
        column_name="server_id",
        backref="clients",
        on_delete="CASCADE",
    )

    class Meta:
        # Table-level constraint: one row per (user, token, server) triple.
        constraints = [SQL("UNIQUE(user_id,token,server_id)")]
|
2019-04-11 10:59:37 -04:00
|
|
|
|
|
|
|
|
|
|
|
@attr.s
class ClientInfo:
    """Plain value object pairing a user id with an access token."""

    # Both attributes are required positional init arguments, in this order.
    user_id = attr.ib(type=str)
    access_token = attr.ib(type=str)
|
|
|
|
|
|
|
|
|
2019-04-10 06:03:17 -04:00
|
|
|
@attr.s
class PanStore:
    """SQLite-backed store for accounts, access tokens, servers and clients.

    The database file is opened (and its tables created if necessary) as
    soon as the object is constructed.

    Methods that touch the models are wrapped in ``use_database`` from
    ``nio.store``.
    NOTE(review): presumably that decorator binds ``models`` to
    ``database`` for the duration of the call -- confirm against nio.

    Attributes:
        store_path: Directory that holds the database file.
        database_name: File name of the database inside ``store_path``.
        database: The connected ``SqliteDatabase``; set in
            ``__attrs_post_init__``.
        database_path: Absolute path of the database file; set in
            ``__attrs_post_init__``.
    """

    store_path = attr.ib(type=str)
    database_name = attr.ib(type=str, default="pan.db")
    database = attr.ib(type=SqliteDatabase, init=False)
    database_path = attr.ib(type=str, init=False)

    # Plain class attribute (not an attrs field): every model whose table
    # is created at start-up.
    models = [Accounts, AccessTokens, Clients, Servers, ServerUsers]

    def __attrs_post_init__(self):
        # type: () -> None
        """Resolve the database path, connect and create the tables."""
        self.database_path = os.path.join(
            os.path.abspath(self.store_path),
            self.database_name
        )

        self.database = self._create_database()
        self.database.connect()

        # The models are not declared with a database of their own, so they
        # are bound to this one while the tables are created.
        with self.database.bind_ctx(self.models):
            self.database.create_tables(self.models)

    def _create_database(self):
        # type: () -> SqliteDatabase
        """Return a ``SqliteDatabase`` for ``database_path``.

        The connection enables the ``foreign_keys`` and ``secure_delete``
        SQLite pragmas.
        """
        return SqliteDatabase(
            self.database_path,
            pragmas={
                # SQLite only enforces foreign keys (and therefore the
                # ON DELETE CASCADE declared on the models) when this is on.
                "foreign_keys": 1,
                "secure_delete": 1,
            }
        )

    @use_database
    def _get_account(self, user_id, device_id):
        # type: (str, str) -> Optional[Accounts]
        """Return the account matching the user/device pair, or None."""
        try:
            return Accounts.get(
                Accounts.user_id == user_id,
                Accounts.device_id == device_id,
            )
        except DoesNotExist:
            return None

    @use_database
    def save_server_user(self, homeserver, user_id):
        # type: (str, str) -> None
        """Record that ``user_id`` belongs to ``homeserver``.

        The server row is created on first use.

        Args:
            homeserver: Hostname of the server.
            user_id: Id of the user to associate with the server.
        """
        server, _ = Servers.get_or_create(hostname=homeserver)

        ServerUsers.replace(
            user_id=user_id,
            server=server
        ).execute()

    @use_database
    def load_all_users(self):
        # type: () -> List[Tuple[str, str]]
        """Return ``(user_id, device_id)`` for every stored account."""
        query = Accounts.select(
            Accounts.user_id,
            Accounts.device_id,
        )

        return [(account.user_id, account.device_id) for account in query]

    @use_database
    def load_users(self, homeserver):
        # type: (str) -> List[Tuple[str, str]]
        """Return ``(user_id, device_id)`` for accounts on ``homeserver``.

        Args:
            homeserver: Hostname of the server whose users are wanted.

        Returns:
            The accounts whose user id is registered for the server; an
            empty list when the server is unknown.
        """
        server = Servers.get_or_none(Servers.hostname == homeserver)

        if not server:
            return []

        server_users = [user.user_id for user in server.users]

        query = Accounts.select(
            Accounts.user_id,
            Accounts.device_id,
        ).where(Accounts.user_id.in_(server_users))

        return [(account.user_id, account.device_id) for account in query]

    @use_database
    def save_access_token(self, user_id, device_id, access_token):
        # type: (str, str, str) -> None
        """Store (or overwrite) the access token of an existing account.

        Args:
            user_id: User id of the account.
            device_id: Device id of the account.
            access_token: Token to store.

        Raises:
            AssertionError: If no matching account exists.
        """
        account = self._get_account(user_id, device_id)
        # NOTE(review): assertions are stripped under ``python -O``; kept
        # as-is so callers still see AssertionError for an unknown account.
        assert account

        AccessTokens.replace(
            account=account,
            token=access_token
        ).execute()

    @use_database
    def load_access_token(self, user_id, device_id):
        # type: (str, str) -> Optional[str]
        """Return the stored access token of an account.

        Returns:
            The token, or None when the account does not exist or has no
            token stored.
        """
        account = self._get_account(user_id, device_id)

        if not account:
            return None

        try:
            # ``access_token`` is the backref of AccessTokens.account;
            # indexing it raises IndexError when no token row exists.
            return account.access_token[0].token
        except IndexError:
            return None

    @use_database
    def save_client(self, homeserver, client):
        # type: (str, ClientInfo) -> None
        """Store a client for ``homeserver``.

        The server row is created on first use.

        Args:
            homeserver: Hostname of the server the client belongs to.
            client: User id and access token of the client.
        """
        server, _ = Servers.get_or_create(hostname=homeserver)

        Clients.replace(
            user_id=client.user_id,
            token=client.access_token,
            server=server.id
        ).execute()

    @use_database
    def load_clients(self, homeserver):
        # type: (str) -> Dict[str, ClientInfo]
        """Return the clients stored for ``homeserver``, keyed by token.

        Args:
            homeserver: Hostname of the server whose clients are wanted.

        Returns:
            A dict mapping each access token to its ``ClientInfo``; empty
            when the server is unknown.
        """
        # Look the server up without creating it: loading must not insert
        # rows, mirroring what load_users does.
        server = Servers.get_or_none(Servers.hostname == homeserver)

        if not server:
            return {}

        return {
            row.token: ClientInfo(row.user_id, row.token)
            for row in server.clients
        }
|