ui: Don't load the devices from the store.

The device list in the UI thread is replicated so we can show UI clients
the list without the need for a lock.

The previous implementation relied on loading and reloading of the device
list from the store every time an event changed either the devices or their
trust state.

This leads to a couple of ineficiencies leading to timeouts while
waiting on the database lock if a user has a large number of devices.

The new implementation never loads devices in the UI thread from the
database, they get passed through the thread queue by the main thread
which already holds them in memory.
This commit is contained in:
Damir Jelić 2019-07-01 16:44:39 +02:00
parent 9308dfec3f
commit f2415738f3
7 changed files with 127 additions and 24 deletions

View File

@ -209,10 +209,30 @@ class PanClient(AsyncClient):
"""Send a thread message to the UI thread.""" """Send a thread message to the UI thread."""
await self.queue.put(message) await self.queue.put(message)
async def send_update_devcies(self): async def send_update_devices(self, devices):
message = UpdateDevicesMessage() """Send a dictionary of devices to the UI thread."""
dict_devices = defaultdict(dict)
for user_devices in devices.values():
for device in user_devices.values():
# Turn the OlmDevice type into a dictionary, flatten the
# keys dict and remove the deleted key/value.
# Since all the keys and values are strings this also
# copies them making it thread safe.
device_dict = device.as_dict()
device_dict = {**device_dict, **device_dict["keys"]}
device_dict.pop("keys")
display_name = device_dict.pop("display_name")
device_dict["device_display_name"] = display_name
dict_devices[device.user_id][device.id] = device_dict
message = UpdateDevicesMessage(self.user_id, dict_devices)
await self.queue.put(message) await self.queue.put(message)
async def send_update_device(self, device):
"""Send a single device to the UI thread to be updated."""
await self.send_update_devices({device.user_id: {device.id: device}})
def delete_fetcher_task(self, task): def delete_fetcher_task(self, task):
self.pan_store.delete_fetcher_task(self.server_name, self.user_id, task) self.pan_store.delete_fetcher_task(self.server_name, self.user_id, task)
@ -314,7 +334,7 @@ class PanClient(AsyncClient):
await self.history_fetch_queue.put(task) await self.history_fetch_queue.put(task)
async def keys_query_cb(self, response): async def keys_query_cb(self, response):
await self.send_update_devcies() await self.send_update_devices(response.changed)
def undecrypted_event_cb(self, room, event): def undecrypted_event_cb(self, room, event):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -377,7 +397,7 @@ class PanClient(AsyncClient):
) )
) )
self.key_verificatins_tasks.append(task) self.key_verificatins_tasks.append(task)
task = loop.create_task(self.send_update_devcies()) task = loop.create_task(self.send_update_device(device))
self.key_verificatins_tasks.append(task) self.key_verificatins_tasks.append(task)
def start_loop(self): def start_loop(self):
@ -532,7 +552,7 @@ class PanClient(AsyncClient):
device = sas.other_olm_device device = sas.other_olm_device
if sas.verified: if sas.verified:
await self.send_update_devcies() await self.send_update_device(device)
await self.send_message( await self.send_message(
SasDoneSignal( SasDoneSignal(
self.user_id, device.user_id, device.id, sas.transaction_id self.user_id, device.user_id, device.id, sas.transaction_id

View File

@ -55,7 +55,6 @@ from pantalaimon.thread_messages import (
StartSasMessage, StartSasMessage,
UnverifiedDevicesSignal, UnverifiedDevicesSignal,
UnverifiedResponse, UnverifiedResponse,
UpdateDevicesMessage,
UpdateUsersMessage, UpdateUsersMessage,
) )
@ -138,6 +137,8 @@ class ProxyDaemon:
) )
) )
loop.create_task(pan_client.send_update_devices(pan_client.device_store))
pan_client.start_loop() pan_client.start_loop()
async def _find_client(self, access_token): async def _find_client(self, access_token):
@ -194,7 +195,7 @@ class ProxyDaemon:
msg = ( msg = (
f"Device {device.id} of user " f"{device.user_id} succesfully verified." f"Device {device.id} of user " f"{device.user_id} succesfully verified."
) )
await self.send_update_devcies() await client.send_update_device(device)
else: else:
msg = f"Device {device.id} of user " f"{device.user_id} already verified." msg = f"Device {device.id} of user " f"{device.user_id} already verified."
@ -209,7 +210,7 @@ class ProxyDaemon:
f"Device {device.id} of user " f"Device {device.id} of user "
f"{device.user_id} succesfully unverified." f"{device.user_id} succesfully unverified."
) )
await self.send_update_devcies() await client.send_update_device(device)
else: else:
msg = f"Device {device.id} of user " f"{device.user_id} already unverified." msg = f"Device {device.id} of user " f"{device.user_id} already unverified."
@ -224,7 +225,7 @@ class ProxyDaemon:
f"Device {device.id} of user " f"Device {device.id} of user "
f"{device.user_id} succesfully blacklisted." f"{device.user_id} succesfully blacklisted."
) )
await self.send_update_devcies() await client.send_update_device(device)
else: else:
msg = ( msg = (
f"Device {device.id} of user " f"{device.user_id} already blacklisted." f"Device {device.id} of user " f"{device.user_id} already blacklisted."
@ -241,7 +242,7 @@ class ProxyDaemon:
f"Device {device.id} of user " f"Device {device.id} of user "
f"{device.user_id} succesfully unblacklisted." f"{device.user_id} succesfully unblacklisted."
) )
await self.send_update_devcies() await client.send_update_device(device)
else: else:
msg = ( msg = (
f"Device {device.id} of user " f"Device {device.id} of user "
@ -256,10 +257,6 @@ class ProxyDaemon:
message = DaemonResponse(message_id, pan_user, code, message) message = DaemonResponse(message_id, pan_user, code, message)
await self.send_queue.put(message) await self.send_queue.put(message)
async def send_update_devcies(self):
message = UpdateDevicesMessage()
await self.send_queue.put(message)
async def receive_message(self, message): async def receive_message(self, message):
client = self.pan_clients.get(message.pan_user) client = self.pan_clients.get(message.pan_user)
@ -864,7 +861,8 @@ class ProxyDaemon:
) )
ret = await _send(True) ret = await _send(True)
await self.send_update_devcies() # TODO send all the devices of a room to be updated
# await client.send_update_devices()
return ret return ret
except asyncio.TimeoutError: except asyncio.TimeoutError:

View File

@ -18,7 +18,8 @@ from collections import defaultdict
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import attr import attr
from nio.store import Accounts, DeviceKeys, DeviceTrustState, TrustState, use_database from nio.crypto import TrustState
from nio.store import Accounts, DeviceKeys, DeviceTrustState, use_database
from peewee import SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField from peewee import SQL, DoesNotExist, ForeignKeyField, Model, SqliteDatabase, TextField

View File

@ -61,7 +61,8 @@ class UpdateUsersMessage(Message):
@attr.s @attr.s
class UpdateDevicesMessage(Message): class UpdateDevicesMessage(Message):
pass pan_user = attr.ib(type=str)
devices = attr.ib(type=dict)
@attr.s @attr.s

View File

@ -14,6 +14,7 @@
from collections import defaultdict from collections import defaultdict
from queue import Empty from queue import Empty
from collections import defaultdict
import attr import attr
import dbus import dbus
@ -264,10 +265,9 @@ class Devices:
def __init__(self, queue, store, id_counter): def __init__(self, queue, store, id_counter):
self.store = store self.store = store
self.device_list = None self.device_list = dict()
self.queue = queue self.queue = queue
self.id_counter = id_counter self.id_counter = id_counter
self.update_devices()
@property @property
def message_id(self): def message_id(self):
@ -342,8 +342,22 @@ class Devices:
self.queue.put(message) self.queue.put(message)
return message.message_id return message.message_id
def update_devices(self): def update_devices(self, message):
self.device_list = self.store.load_all_devices() if message.pan_user not in self.device_list:
self.device_list[message.pan_user] = defaultdict(dict)
device_list = self.device_list.get(message.pan_user)
for user_devices in message.devices.values():
for device in user_devices.values():
if device["deleted"]:
try:
device_list[device["user_id"]].pop(device["device_id"])
except KeyError:
pass
device.pop("deleted")
device_list[device["user_id"]][device["device_id"]] = device
@attr.s @attr.s
@ -483,7 +497,7 @@ class GlibT:
logger.debug(f"UI loop received message {message}") logger.debug(f"UI loop received message {message}")
if isinstance(message, UpdateDevicesMessage): if isinstance(message, UpdateDevicesMessage):
self.device_if.update_devices() self.device_if.update_devices(message)
elif isinstance(message, UpdateUsersMessage): elif isinstance(message, UpdateUsersMessage):
self.control_if.update_users(message) self.control_if.update_users(message)

View File

@ -11,7 +11,7 @@ from aiohttp import web
from aioresponses import aioresponses from aioresponses import aioresponses
from faker import Faker from faker import Faker
from faker.providers import BaseProvider from faker.providers import BaseProvider
from nio.crypto import OlmAccount from nio.crypto import OlmAccount, OlmDevice
from nio.store import SqliteStore from nio.store import SqliteStore
from pantalaimon.config import ServerConfig from pantalaimon.config import ServerConfig
@ -35,6 +35,28 @@ class Provider(BaseProvider):
return ClientInfo(faker.mx_id(), faker.access_token()) return ClientInfo(faker.mx_id(), faker.access_token())
def avatar_url(self):
return "mxc://{}/{}#auto".format(
faker.hostname(),
"".join(choice(ascii_letters) for i in range(24))
)
def olm_key_pair(self):
return OlmAccount().identity_keys
def olm_device(self):
user_id = faker.mx_id()
device_id = faker.device_id()
key_pair = faker.olm_key_pair()
return OlmDevice(
user_id,
device_id,
key_pair,
)
faker.add_provider(Provider) faker.add_provider(Provider)

View File

@ -1,11 +1,18 @@
import asyncio import asyncio
import json import json
import re import re
from collections import defaultdict
from aiohttp import web from aiohttp import web
from nio.crypto import OlmDevice
from conftest import faker from conftest import faker
from pantalaimon.thread_messages import UpdateUsersMessage from pantalaimon.thread_messages import UpdateDevicesMessage, UpdateUsersMessage
BOB_ID = "@bob:example.org"
BOB_DEVICE = "AGMTSWVYML"
BOB_CURVE = "T9tOKF+TShsn6mk1zisW2IBsBbTtzDNvw99RBFMJOgI"
BOB_ONETIME = "6QlQw3mGUveS735k/JDaviuoaih5eEi6S1J65iHjfgU"
class TestClass(object): class TestClass(object):
@ -36,6 +43,25 @@ class TestClass(object):
} }
} }
@property
def example_devices(self):
devices = defaultdict(dict)
for _ in range(10):
device = faker.olm_device()
devices[device.user_id][device.id] = device
bob_device = OlmDevice(
BOB_ID,
BOB_DEVICE,
{"ed25519": BOB_ONETIME,
"curve25519": BOB_CURVE}
)
devices[BOB_ID][BOB_DEVICE] = bob_device
return devices
async def test_daemon_start(self, pan_proxy_server, aiohttp_client, aioresponse): async def test_daemon_start(self, pan_proxy_server, aiohttp_client, aioresponse):
server, daemon, _ = pan_proxy_server server, daemon, _ = pan_proxy_server
@ -161,3 +187,24 @@ class TestClass(object):
assert message.user_id == "@example:example.org" assert message.user_id == "@example:example.org"
assert message.device_id == "GHTYAJCE" assert message.device_id == "GHTYAJCE"
async def tests_server_devices_update(self, running_proxy):
_, _, proxy, queues = running_proxy
queue, _ = queues
queue = queue.sync_q
devices = self.example_devices
bob_device = devices[BOB_ID][BOB_DEVICE]
message = queue.get_nowait()
assert isinstance(message, UpdateUsersMessage)
client = list(proxy.pan_clients.values())[0]
client.store.save_device_keys(devices)
await client.send_update_device(bob_device)
message = queue.get_nowait()
assert isinstance(message, UpdateDevicesMessage)
assert BOB_DEVICE in message.devices[BOB_ID]