pantalaimon/pantalaimon/store.py
2019-05-09 13:47:29 +02:00

245 lines
5.9 KiB
Python

import os
from collections import defaultdict
from typing import Dict, List, Optional, Tuple
import attr
from nio.store import (Accounts, DeviceKeys, DeviceTrustState, TrustState,
use_database)
from peewee import (SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase,
TextField)
class AccessTokens(Model):
    """Access token saved for an account (at most one per Accounts row)."""

    # The access token string, stored in a TEXT column.
    token = TextField()
    # The owning account is also this table's primary key, so an account can
    # have at most one stored token; PanStore.save_access_token relies on this
    # with REPLACE. The row is removed together with its account
    # (ON DELETE CASCADE; SQLite only enforces this with the foreign_keys
    # pragma enabled).
    account = ForeignKeyField(
        model=Accounts,
        primary_key=True,
        backref="access_token",
        on_delete="CASCADE"
    )
class Servers(Model):
    """A homeserver, identified by its hostname."""

    hostname = TextField()

    class Meta:
        # One row per hostname, enforced by the database; this is what lets
        # Servers.get_or_create(hostname=...) act as a lookup-or-insert.
        constraints = [SQL("UNIQUE(hostname)")]
class ServerUsers(Model):
    """Association of a user id with the homeserver it was saved for."""

    user_id = TextField()
    # Stored in the "server_id" column (referenced by the UNIQUE constraint
    # below); rows are deleted together with their server.
    server = ForeignKeyField(
        model=Servers,
        column_name="server_id",
        backref="users",
        on_delete="CASCADE"
    )

    class Meta:
        # A given user is recorded at most once per server.
        constraints = [SQL("UNIQUE(user_id,server_id)")]
class Clients(Model):
    """A client (user id + access token) registered for a homeserver."""

    user_id = TextField()
    token = TextField()
    # Stored in the "server_id" column (referenced by the UNIQUE constraint
    # below); rows are deleted together with their server.
    server = ForeignKeyField(
        model=Servers,
        column_name="server_id",
        backref="clients",
        on_delete="CASCADE"
    )

    class Meta:
        # The same user/token pair is stored at most once per server.
        constraints = [SQL("UNIQUE(user_id,token,server_id)")]
@attr.s
class ClientInfo:
    """In-memory description of a client: its user id and access token.

    PanStore.save_client persists instances of this class and
    PanStore.load_clients rebuilds them, keyed by access token.
    """

    user_id = attr.ib(type=str)
    access_token = attr.ib(type=str)
@attr.s
class OlmDevice:
    """Plain record describing a device and the trust placed in it.

    NOTE(review): not used by the code visible in this module; the field
    names appear to mirror the DeviceKeys columns read in
    PanStore.load_all_devices (fp_key, sender_key) -- confirm with callers.
    """

    user_id = attr.ib()
    # Device id; the name shadows the builtin but is part of the public
    # constructor/attribute interface.
    id = attr.ib()
    fp_key = attr.ib()
    sender_key = attr.ib()
    # presumably a TrustState value -- TODO confirm with callers
    trust_state = attr.ib()
@attr.s
class PanStore:
    """Pantalaimon's SQLite-backed persistent store.

    Holds the homeservers, the users and clients saved for each homeserver,
    and the access tokens of the accounts, all in one SQLite database located
    at ``<store_path>/<database_name>``. The Accounts, DeviceKeys and
    DeviceTrustState tables come from nio's store models.
    """

    store_path = attr.ib(type=str)
    database_name = attr.ib(type=str, default="pan.db")
    # Both are derived in __attrs_post_init__ rather than passed by callers.
    database = attr.ib(type=SqliteDatabase, init=False)
    database_path = attr.ib(type=str, init=False)

    # Every peewee model this store creates tables for and binds to its
    # database (a plain class attribute, not an attrs field).
    models = [
        Accounts,
        AccessTokens,
        Clients,
        Servers,
        ServerUsers,
        DeviceKeys,
        DeviceTrustState,
    ]

    def __attrs_post_init__(self):
        # type: () -> None
        """Open the database file and create the tables of all models."""
        self.database_path = os.path.join(
            os.path.abspath(self.store_path),
            self.database_name
        )

        self.database = self._create_database()
        self.database.connect()

        # Bind the models to this database only while creating the tables
        # (peewee's create_tables defaults to safe=True, i.e. IF NOT EXISTS).
        with self.database.bind_ctx(self.models):
            self.database.create_tables(self.models)

    def _create_database(self):
        # type: () -> SqliteDatabase
        """Build the SqliteDatabase object for ``self.database_path``.

        ``foreign_keys`` makes SQLite enforce the models' foreign keys and
        their ON DELETE CASCADE clauses; ``secure_delete`` makes SQLite
        overwrite deleted content instead of leaving it in free pages.
        """
        return SqliteDatabase(
            self.database_path,
            pragmas={
                "foreign_keys": 1,
                "secure_delete": 1,
            }
        )

    @use_database
    def _get_account(self, user_id, device_id):
        # type: (str, str) -> Optional[Accounts]
        """Return the Accounts row for ``user_id``/``device_id``, or None."""
        return Accounts.get_or_none(
            Accounts.user_id == user_id,
            Accounts.device_id == device_id,
        )

    @use_database
    def save_server_user(self, homeserver, user_id):
        # type: (str, str) -> None
        """Record that ``user_id`` was used with ``homeserver``.

        The server row is created on first use. INSERT OR REPLACE keeps a
        repeated save from violating the UNIQUE(user_id, server_id)
        constraint.
        """
        server, _ = Servers.get_or_create(hostname=homeserver)

        ServerUsers.replace(
            user_id=user_id,
            server=server
        ).execute()

    @use_database
    def load_all_users(self):
        # type: () -> List[Tuple[str, str]]
        """Return ``(user_id, device_id)`` for every stored account."""
        query = Accounts.select(Accounts.user_id, Accounts.device_id)
        return [(account.user_id, account.device_id) for account in query]

    @use_database
    def load_users(self, homeserver):
        # type: (str) -> List[Tuple[str, str]]
        """Return ``(user_id, device_id)`` of the accounts whose user id was
        saved for ``homeserver`` with save_server_user.

        Returns an empty list for an unknown homeserver. Unlike
        save_server_user, this never creates the server row.
        """
        server = Servers.get_or_none(Servers.hostname == homeserver)

        if not server:
            return []

        server_users = [u.user_id for u in server.users]

        query = Accounts.select(
            Accounts.user_id,
            Accounts.device_id,
        ).where(Accounts.user_id.in_(server_users))

        return [(account.user_id, account.device_id) for account in query]

    @use_database
    def save_access_token(self, user_id, device_id, access_token):
        # type: (str, str, str) -> None
        """Store ``access_token`` for an existing account.

        Any previously stored token is replaced, because the account is the
        primary key of AccessTokens.
        """
        account = self._get_account(user_id, device_id)
        # The account must already exist. NOTE(review): this assert is
        # stripped under `python -O`; consider an explicit exception.
        assert account

        AccessTokens.replace(
            account=account,
            token=access_token
        ).execute()

    @use_database
    def load_access_token(self, user_id, device_id):
        # type: (str, str) -> Optional[str]
        """Return the stored access token of an account.

        Returns None when either the account or its token is missing.
        """
        account = self._get_account(user_id, device_id)

        if not account:
            return None

        try:
            # `access_token` is the backref query of AccessTokens.account.
            return account.access_token[0].token
        except IndexError:
            return None

    @use_database
    def save_client(self, homeserver, client):
        # type: (str, ClientInfo) -> None
        """Persist ``client`` under ``homeserver``, creating the server row
        on first use."""
        server, _ = Servers.get_or_create(hostname=homeserver)

        Clients.replace(
            user_id=client.user_id,
            token=client.access_token,
            server=server.id
        ).execute()

    @use_database
    def load_clients(self, homeserver):
        # type: (str) -> Dict[str, ClientInfo]
        """Return the saved clients of ``homeserver``, keyed by access token.

        Note: like save_client, this creates the server row if it does not
        exist yet; the result is then empty.
        """
        server, _ = Servers.get_or_create(hostname=homeserver)

        return {
            c.token: ClientInfo(c.user_id, c.token)
            for c in server.clients
        }

    @use_database
    def load_all_devices(self):
        # type: () -> Dict[str, Dict[str, Dict[str, Dict[str, str]]]]
        """Return the non-deleted device keys known to every account.

        The result is nested as
        ``store[account_user_id][device_user_id][device_id]``; the middle
        level is a ``defaultdict(dict)``. Each leaf dict holds ``user_id``,
        ``device_id``, ``fingerprint_key``, ``sender_key`` and
        ``trust_state``: the name of the device's TrustState, or of
        ``TrustState.unset`` when no trust state row exists.
        """
        store = {}

        for account in Accounts.select():
            device_store = defaultdict(dict)

            for d in account.device_keys:
                if d.deleted:
                    continue

                try:
                    trust_state = d.trust_state[0].state
                except IndexError:
                    trust_state = TrustState.unset

                device_store[d.user_id][d.device_id] = {
                    "user_id": d.user_id,
                    "device_id": d.device_id,
                    "fingerprint_key": d.fp_key,
                    "sender_key": d.sender_key,
                    "trust_state": trust_state.name
                }

            store[account.user_id] = device_store

        return store