mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-22 21:21:03 -05:00
Implement device key caching over federation
This commit is contained in:
parent
51e9fe36e4
commit
c974116f19
@ -126,6 +126,16 @@ class FederationClient(FederationBase):
|
|||||||
destination, content, timeout
|
destination, content, timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def query_user_devices(self, destination, user_id, timeout=30000):
|
||||||
|
"""Query the device keys for a list of user ids hosted on a remote
|
||||||
|
server.
|
||||||
|
"""
|
||||||
|
sent_queries_counter.inc("user_devices")
|
||||||
|
return self.transport_layer.query_user_devices(
|
||||||
|
destination, user_id, timeout
|
||||||
|
)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def claim_client_keys(self, destination, content, timeout):
|
def claim_client_keys(self, destination, content, timeout):
|
||||||
"""Claims one-time keys for a device hosted on a remote server.
|
"""Claims one-time keys for a device hosted on a remote server.
|
||||||
|
@ -416,6 +416,9 @@ class FederationServer(FederationBase):
|
|||||||
def on_query_client_keys(self, origin, content):
|
def on_query_client_keys(self, origin, content):
|
||||||
return self.on_query_request("client_keys", content)
|
return self.on_query_request("client_keys", content)
|
||||||
|
|
||||||
|
def on_query_user_devices(self, origin, user_id):
|
||||||
|
return self.on_query_request("user_devices", user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def on_claim_client_keys(self, origin, content):
|
def on_claim_client_keys(self, origin, content):
|
||||||
|
@ -346,6 +346,32 @@ class TransportLayerClient(object):
|
|||||||
)
|
)
|
||||||
defer.returnValue(content)
|
defer.returnValue(content)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def query_user_devices(self, destination, user_id, timeout):
|
||||||
|
"""Query the devices for a user id hosted on a remote server.
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{
|
||||||
|
"stream_id": "...",
|
||||||
|
"devices": [ { ... } ]
|
||||||
|
}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination(str): The server to query.
|
||||||
|
query_content(dict): The user ids to query.
|
||||||
|
Returns:
|
||||||
|
A dict containg the device keys.
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/user/devices/" + user_id
|
||||||
|
|
||||||
|
content = yield self.client.get_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
defer.returnValue(content)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def claim_client_keys(self, destination, query_content, timeout):
|
def claim_client_keys(self, destination, query_content, timeout):
|
||||||
|
@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
|
|||||||
return self.handler.on_query_client_keys(origin, content)
|
return self.handler.on_query_client_keys(origin, content)
|
||||||
|
|
||||||
|
|
||||||
|
class FederationUserDevicesQueryServlet(BaseFederationServlet):
|
||||||
|
PATH = "/user/devices/(?P<user_id>[^/]*)"
|
||||||
|
|
||||||
|
def on_GET(self, origin, content, query, user_id):
|
||||||
|
return self.handler.on_query_user_devices(origin, user_id)
|
||||||
|
|
||||||
|
|
||||||
class FederationClientKeysClaimServlet(BaseFederationServlet):
|
class FederationClientKeysClaimServlet(BaseFederationServlet):
|
||||||
PATH = "/user/keys/claim"
|
PATH = "/user/keys/claim"
|
||||||
|
|
||||||
@ -613,6 +620,7 @@ SERVLET_CLASSES = (
|
|||||||
FederationGetMissingEventsServlet,
|
FederationGetMissingEventsServlet,
|
||||||
FederationEventAuthServlet,
|
FederationEventAuthServlet,
|
||||||
FederationClientKeysQueryServlet,
|
FederationClientKeysQueryServlet,
|
||||||
|
FederationUserDevicesQueryServlet,
|
||||||
FederationClientKeysClaimServlet,
|
FederationClientKeysClaimServlet,
|
||||||
FederationThirdPartyInviteExchangeServlet,
|
FederationThirdPartyInviteExchangeServlet,
|
||||||
On3pidBindServlet,
|
On3pidBindServlet,
|
||||||
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
from synapse.api import errors
|
from synapse.api import errors
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
|
from synapse.util.async import Linearizer
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
@ -28,8 +29,18 @@ class DeviceHandler(BaseHandler):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(DeviceHandler, self).__init__(hs)
|
super(DeviceHandler, self).__init__(hs)
|
||||||
|
|
||||||
|
self.hs = hs
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
self.federation = hs.get_federation_sender()
|
self.federation_sender = hs.get_federation_sender()
|
||||||
|
self.federation = hs.get_replication_layer()
|
||||||
|
self._remote_edue_linearizer = Linearizer(name="remote_device_list")
|
||||||
|
|
||||||
|
self.federation.register_edu_handler(
|
||||||
|
"m.device_list_update", self._incoming_device_list_update,
|
||||||
|
)
|
||||||
|
self.federation.register_query_handler(
|
||||||
|
"user_devices", self.on_federation_query_user_devices,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_device_registered(self, user_id, device_id,
|
def check_device_registered(self, user_id, device_id,
|
||||||
@ -55,7 +66,7 @@ class DeviceHandler(BaseHandler):
|
|||||||
initial_device_display_name=initial_device_display_name,
|
initial_device_display_name=initial_device_display_name,
|
||||||
)
|
)
|
||||||
if new_device:
|
if new_device:
|
||||||
yield self.notify_device_update(user_id, device_id)
|
yield self.notify_device_update(user_id, [device_id])
|
||||||
defer.returnValue(device_id)
|
defer.returnValue(device_id)
|
||||||
|
|
||||||
# if the device id is not specified, we'll autogen one, but loop a few
|
# if the device id is not specified, we'll autogen one, but loop a few
|
||||||
@ -69,7 +80,7 @@ class DeviceHandler(BaseHandler):
|
|||||||
initial_device_display_name=initial_device_display_name,
|
initial_device_display_name=initial_device_display_name,
|
||||||
)
|
)
|
||||||
if new_device:
|
if new_device:
|
||||||
yield self.notify_device_update(user_id, device_id)
|
yield self.notify_device_update(user_id, [device_id])
|
||||||
defer.returnValue(device_id)
|
defer.returnValue(device_id)
|
||||||
attempts += 1
|
attempts += 1
|
||||||
|
|
||||||
@ -151,7 +162,7 @@ class DeviceHandler(BaseHandler):
|
|||||||
user_id=user_id, device_id=device_id
|
user_id=user_id, device_id=device_id
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.notify_device_update(user_id, device_id)
|
yield self.notify_device_update(user_id, [device_id])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def update_device(self, user_id, device_id, content):
|
def update_device(self, user_id, device_id, content):
|
||||||
@ -172,7 +183,7 @@ class DeviceHandler(BaseHandler):
|
|||||||
device_id,
|
device_id,
|
||||||
new_display_name=content.get("display_name")
|
new_display_name=content.get("display_name")
|
||||||
)
|
)
|
||||||
yield self.notify_device_update(user_id, device_id)
|
yield self.notify_device_update(user_id, [device_id])
|
||||||
except errors.StoreError, e:
|
except errors.StoreError, e:
|
||||||
if e.code == 404:
|
if e.code == 404:
|
||||||
raise errors.NotFoundError()
|
raise errors.NotFoundError()
|
||||||
@ -180,26 +191,28 @@ class DeviceHandler(BaseHandler):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def notify_device_update(self, user_id, device_id):
|
def notify_device_update(self, user_id, device_ids):
|
||||||
rooms = yield self.store.get_rooms_for_user(user_id)
|
rooms = yield self.store.get_rooms_for_user(user_id)
|
||||||
room_ids = [r.room_id for r in rooms]
|
room_ids = [r.room_id for r in rooms]
|
||||||
|
|
||||||
hosts = set()
|
hosts = set()
|
||||||
for room_id in room_ids:
|
if self.hs.is_mine_id(user_id):
|
||||||
users = yield self.state.get_current_user_in_room(room_id)
|
for room_id in room_ids:
|
||||||
hosts.update(get_domain_from_id(u) for u in users)
|
users = yield self.state.get_current_user_in_room(room_id)
|
||||||
hosts.discard(self.server_name)
|
hosts.update(get_domain_from_id(u) for u in users)
|
||||||
|
hosts.discard(self.server_name)
|
||||||
|
|
||||||
position = yield self.store.add_device_change_to_streams(
|
position = yield self.store.add_device_change_to_streams(
|
||||||
user_id, device_id, list(hosts)
|
user_id, device_ids, list(hosts)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.notifier.on_new_event(
|
yield self.notifier.on_new_event(
|
||||||
"device_list_key", position, rooms=room_ids,
|
"device_list_key", position, rooms=room_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info("Sending device list update notif to: %r", hosts)
|
||||||
for host in hosts:
|
for host in hosts:
|
||||||
self.federation.send_device_messages(host)
|
self.federation_sender.send_device_messages(host)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_device_list_changes(self, user_id, room_ids, from_key):
|
def get_device_list_changes(self, user_id, room_ids, from_key):
|
||||||
@ -214,6 +227,54 @@ class DeviceHandler(BaseHandler):
|
|||||||
|
|
||||||
defer.returnValue(user_ids_changed)
|
defer.returnValue(user_ids_changed)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _incoming_device_list_update(self, origin, edu_content):
|
||||||
|
user_id = edu_content["user_id"]
|
||||||
|
device_id = edu_content["device_id"]
|
||||||
|
stream_id = edu_content["stream_id"]
|
||||||
|
prev_ids = edu_content.get("prev_id", [])
|
||||||
|
|
||||||
|
if get_domain_from_id(user_id) != origin:
|
||||||
|
# TODO: Raise?
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("Got edu: %r", edu_content)
|
||||||
|
|
||||||
|
with (yield self._remote_edue_linearizer.queue(user_id)):
|
||||||
|
resync = True
|
||||||
|
if len(prev_ids) == 1:
|
||||||
|
extremity = yield self.store.get_device_list_remote_extremity(user_id)
|
||||||
|
logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
|
||||||
|
if str(extremity) == str(prev_ids[0]):
|
||||||
|
resync = False
|
||||||
|
|
||||||
|
if resync:
|
||||||
|
result = yield self.federation.query_user_devices(origin, user_id)
|
||||||
|
stream_id = result["stream_id"]
|
||||||
|
devices = result["devices"]
|
||||||
|
yield self.store.update_remote_device_list_cache(
|
||||||
|
user_id, devices, stream_id,
|
||||||
|
)
|
||||||
|
device_ids = [device["device_id"] for device in devices]
|
||||||
|
yield self.notify_device_update(user_id, device_ids)
|
||||||
|
else:
|
||||||
|
content = dict(edu_content)
|
||||||
|
for key in ("user_id", "device_id", "stream_id", "prev_ids"):
|
||||||
|
content.pop(key, None)
|
||||||
|
yield self.store.update_remote_device_list_cache_entry(
|
||||||
|
user_id, device_id, content, stream_id,
|
||||||
|
)
|
||||||
|
yield self.notify_device_update(user_id, [device_id])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_federation_query_user_devices(self, user_id):
|
||||||
|
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
|
||||||
|
defer.returnValue({
|
||||||
|
"user_id": user_id,
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"devices": devices,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
def _update_device_from_client_ips(device, client_ips):
|
def _update_device_from_client_ips(device, client_ips):
|
||||||
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
ip = client_ips.get((device["user_id"], device["device_id"]), {})
|
||||||
|
@ -73,8 +73,7 @@ class E2eKeysHandler(object):
|
|||||||
if self.is_mine_id(user_id):
|
if self.is_mine_id(user_id):
|
||||||
local_query[user_id] = device_ids
|
local_query[user_id] = device_ids
|
||||||
else:
|
else:
|
||||||
domain = get_domain_from_id(user_id)
|
remote_queries[user_id] = device_ids
|
||||||
remote_queries.setdefault(domain, {})[user_id] = device_ids
|
|
||||||
|
|
||||||
# do the queries
|
# do the queries
|
||||||
failures = {}
|
failures = {}
|
||||||
@ -85,9 +84,40 @@ class E2eKeysHandler(object):
|
|||||||
if user_id in local_query:
|
if user_id in local_query:
|
||||||
results[user_id] = keys
|
results[user_id] = keys
|
||||||
|
|
||||||
|
remote_queries_not_in_cache = {}
|
||||||
|
if remote_queries:
|
||||||
|
query_list = []
|
||||||
|
for user_id, device_ids in remote_queries.iteritems():
|
||||||
|
if device_ids:
|
||||||
|
query_list.extend((user_id, device_id) for device_id in device_ids)
|
||||||
|
else:
|
||||||
|
query_list.append((user_id, None))
|
||||||
|
|
||||||
|
user_ids_not_in_cache, remote_results = (
|
||||||
|
yield self.store.get_user_devices_from_cache(
|
||||||
|
query_list
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for user_id, devices in remote_results.iteritems():
|
||||||
|
user_devices = results.setdefault(user_id, {})
|
||||||
|
for device_id, device in devices.iteritems():
|
||||||
|
keys = device.get("keys", None)
|
||||||
|
device_display_name = device.get("device_display_name", None)
|
||||||
|
if keys:
|
||||||
|
result = dict(keys)
|
||||||
|
unsigned = result.setdefault("unsigned", {})
|
||||||
|
if device_display_name:
|
||||||
|
unsigned["device_display_name"] = device_display_name
|
||||||
|
user_devices[device_id] = result
|
||||||
|
|
||||||
|
for user_id in user_ids_not_in_cache:
|
||||||
|
domain = get_domain_from_id(user_id)
|
||||||
|
r = remote_queries_not_in_cache.setdefault(domain, {})
|
||||||
|
r[user_id] = remote_queries[user_id]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_remote_query(destination):
|
def do_remote_query(destination):
|
||||||
destination_query = remote_queries[destination]
|
destination_query = remote_queries_not_in_cache[destination]
|
||||||
try:
|
try:
|
||||||
limiter = yield get_retry_limiter(
|
limiter = yield get_retry_limiter(
|
||||||
destination, self.clock, self.store
|
destination, self.clock, self.store
|
||||||
@ -119,7 +149,7 @@ class E2eKeysHandler(object):
|
|||||||
|
|
||||||
yield preserve_context_over_deferred(defer.gatherResults([
|
yield preserve_context_over_deferred(defer.gatherResults([
|
||||||
preserve_fn(do_remote_query)(destination)
|
preserve_fn(do_remote_query)(destination)
|
||||||
for destination in remote_queries
|
for destination in remote_queries_not_in_cache
|
||||||
]))
|
]))
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
@ -259,7 +289,7 @@ class E2eKeysHandler(object):
|
|||||||
user_id, device_id, time_now,
|
user_id, device_id, time_now,
|
||||||
encode_canonical_json(device_keys)
|
encode_canonical_json(device_keys)
|
||||||
)
|
)
|
||||||
yield self.device_handler.notify_device_update(user_id, device_id)
|
yield self.device_handler.notify_device_update(user_id, [device_id])
|
||||||
|
|
||||||
one_time_keys = keys.get("one_time_keys", None)
|
one_time_keys = keys.get("one_time_keys", None)
|
||||||
if one_time_keys:
|
if one_time_keys:
|
||||||
|
@ -138,6 +138,89 @@ class DeviceStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue({d["device_id"]: d for d in devices})
|
defer.returnValue({d["device_id"]: d for d in devices})
|
||||||
|
|
||||||
|
def get_device_list_remote_extremity(self, user_id):
|
||||||
|
return self._simple_select_one_onecol(
|
||||||
|
table="device_lists_remote_extremeties",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
retcol="stream_id",
|
||||||
|
desc="get_device_list_remote_extremity",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
|
||||||
|
stream_id):
|
||||||
|
return self.runInteraction(
|
||||||
|
"update_remote_device_list_cache_entry",
|
||||||
|
self._update_remote_device_list_cache_entry_txn,
|
||||||
|
user_id, device_id, content, stream_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
|
||||||
|
content, stream_id):
|
||||||
|
self._simple_upsert_txn(
|
||||||
|
txn,
|
||||||
|
table="device_lists_remote_cache",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
},
|
||||||
|
values={
|
||||||
|
"content": json.dumps(content),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self._simple_upsert_txn(
|
||||||
|
txn,
|
||||||
|
table="device_lists_remote_extremeties",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
values={
|
||||||
|
"stream_id": stream_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_remote_device_list_cache(self, user_id, devices, stream_id):
|
||||||
|
return self.runInteraction(
|
||||||
|
"update_remote_device_list_cache",
|
||||||
|
self._update_remote_device_list_cache_txn,
|
||||||
|
user_id, devices, stream_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
|
||||||
|
stream_id):
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="device_lists_remote_cache",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="device_lists_remote_cache",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": content["device_id"],
|
||||||
|
"content": json.dumps(content),
|
||||||
|
}
|
||||||
|
for content in devices
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self._simple_upsert_txn(
|
||||||
|
txn,
|
||||||
|
table="device_lists_remote_extremeties",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
values={
|
||||||
|
"stream_id": stream_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
def get_devices_by_remote(self, destination, from_stream_id):
|
def get_devices_by_remote(self, destination, from_stream_id):
|
||||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
@ -184,7 +267,7 @@ class DeviceStore(SQLBaseStore):
|
|||||||
txn.execute(prev_sent_id_sql, (destination, user_id, True))
|
txn.execute(prev_sent_id_sql, (destination, user_id, True))
|
||||||
rows = txn.fetchall()
|
rows = txn.fetchall()
|
||||||
prev_id = rows[0][0]
|
prev_id = rows[0][0]
|
||||||
for device_id, result in user_devices.iteritems():
|
for device_id, device in user_devices.iteritems():
|
||||||
stream_id = query_map[(user_id, device_id)]
|
stream_id = query_map[(user_id, device_id)]
|
||||||
result = {
|
result = {
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
@ -195,10 +278,10 @@ class DeviceStore(SQLBaseStore):
|
|||||||
|
|
||||||
prev_id = stream_id
|
prev_id = stream_id
|
||||||
|
|
||||||
key_json = result.get("key_json", None)
|
key_json = device.get("key_json", None)
|
||||||
if key_json:
|
if key_json:
|
||||||
result["keys"] = json.loads(key_json)
|
result["keys"] = json.loads(key_json)
|
||||||
device_display_name = result.get("device_display_name", None)
|
device_display_name = device.get("device_display_name", None)
|
||||||
if device_display_name:
|
if device_display_name:
|
||||||
result["device_display_name"] = device_display_name
|
result["device_display_name"] = device_display_name
|
||||||
|
|
||||||
@ -206,6 +289,96 @@ class DeviceStore(SQLBaseStore):
|
|||||||
|
|
||||||
return (now_stream_id, results)
|
return (now_stream_id, results)
|
||||||
|
|
||||||
|
def get_user_devices_from_cache(self, query_list):
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
|
||||||
|
query_list,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_user_devices_from_cache_txn(self, txn, query_list):
|
||||||
|
user_ids = {user_id for user_id, _ in query_list}
|
||||||
|
|
||||||
|
user_ids_in_cache = set()
|
||||||
|
for user_id in user_ids:
|
||||||
|
stream_ids = self._simple_select_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="device_lists_remote_extremeties",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
retcol="stream_id",
|
||||||
|
)
|
||||||
|
if stream_ids:
|
||||||
|
user_ids_in_cache.add(user_id)
|
||||||
|
|
||||||
|
user_ids_not_in_cache = user_ids - user_ids_in_cache
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
for user_id, device_id in query_list:
|
||||||
|
if user_id not in user_ids_in_cache:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if device_id:
|
||||||
|
content = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="device_lists_remote_cache",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
},
|
||||||
|
retcol="content",
|
||||||
|
)
|
||||||
|
results.setdefault(user_id, {})[device_id] = json.loads(content)
|
||||||
|
else:
|
||||||
|
devices = self._simple_select_list_txn(
|
||||||
|
txn,
|
||||||
|
table="device_lists_remote_cache",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
retcols=("device_id", "content"),
|
||||||
|
)
|
||||||
|
results[user_id] = {
|
||||||
|
device["device_id"]: json.loads(device["content"])
|
||||||
|
for device in devices
|
||||||
|
}
|
||||||
|
user_ids_in_cache.discard(user_id)
|
||||||
|
|
||||||
|
return user_ids_not_in_cache, results
|
||||||
|
|
||||||
|
def get_devices_with_keys_by_user(self, user_id):
|
||||||
|
return self.runInteraction(
|
||||||
|
"get_devices_with_keys_by_user",
|
||||||
|
self._get_devices_with_keys_by_user_txn, user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
|
||||||
|
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
|
devices = self._get_e2e_device_keys_txn(
|
||||||
|
txn, [(user_id, None)], include_all_devices=True
|
||||||
|
)
|
||||||
|
|
||||||
|
for user_id, user_devices in devices.iteritems():
|
||||||
|
results = []
|
||||||
|
for device_id, device in user_devices.iteritems():
|
||||||
|
result = {
|
||||||
|
"device_id": device_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
key_json = device.get("key_json", None)
|
||||||
|
if key_json:
|
||||||
|
result["keys"] = json.loads(key_json)
|
||||||
|
device_display_name = device.get("device_display_name", None)
|
||||||
|
if device_display_name:
|
||||||
|
result["device_display_name"] = device_display_name
|
||||||
|
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return now_stream_id, results
|
||||||
|
|
||||||
|
return now_stream_id, []
|
||||||
|
|
||||||
def mark_as_sent_devices_by_remote(self, destination, stream_id):
|
def mark_as_sent_devices_by_remote(self, destination, stream_id):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
|
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
|
||||||
@ -242,17 +415,17 @@ class DeviceStore(SQLBaseStore):
|
|||||||
defer.returnValue(set(row["user_id"] for row in rows))
|
defer.returnValue(set(row["user_id"] for row in rows))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_device_change_to_streams(self, user_id, device_id, hosts):
|
def add_device_change_to_streams(self, user_id, device_ids, hosts):
|
||||||
# device_lists_stream
|
# device_lists_stream
|
||||||
# device_lists_outbound_pokes
|
# device_lists_outbound_pokes
|
||||||
with self._device_list_id_gen.get_next() as stream_id:
|
with self._device_list_id_gen.get_next() as stream_id:
|
||||||
yield self.runInteraction(
|
yield self.runInteraction(
|
||||||
"add_device_change_to_streams", self._add_device_change_txn,
|
"add_device_change_to_streams", self._add_device_change_txn,
|
||||||
user_id, device_id, hosts, stream_id,
|
user_id, device_ids, hosts, stream_id,
|
||||||
)
|
)
|
||||||
defer.returnValue(stream_id)
|
defer.returnValue(stream_id)
|
||||||
|
|
||||||
def _add_device_change_txn(self, txn, user_id, device_id, hosts, stream_id):
|
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self._device_list_stream_cache.entity_has_changed,
|
self._device_list_stream_cache.entity_has_changed,
|
||||||
user_id, stream_id,
|
user_id, stream_id,
|
||||||
@ -263,14 +436,17 @@ class DeviceStore(SQLBaseStore):
|
|||||||
host, stream_id,
|
host, stream_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._simple_insert_txn(
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="device_lists_stream",
|
table="device_lists_stream",
|
||||||
values={
|
values=[
|
||||||
"stream_id": stream_id,
|
{
|
||||||
"user_id": user_id,
|
"stream_id": stream_id,
|
||||||
"device_id": device_id,
|
"user_id": user_id,
|
||||||
}
|
"device_id": device_id,
|
||||||
|
}
|
||||||
|
for device_id in device_ids
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
self._simple_insert_many_txn(
|
||||||
@ -285,6 +461,7 @@ class DeviceStore(SQLBaseStore):
|
|||||||
"sent": False,
|
"sent": False,
|
||||||
}
|
}
|
||||||
for destination in hosts
|
for destination in hosts
|
||||||
|
for device_id in device_ids
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -52,11 +52,11 @@ class EndToEndKeyStore(SQLBaseStore):
|
|||||||
query_params = []
|
query_params = []
|
||||||
|
|
||||||
for (user_id, device_id) in query_list:
|
for (user_id, device_id) in query_list:
|
||||||
query_clause = "k.user_id = ?"
|
query_clause = "user_id = ?"
|
||||||
query_params.append(user_id)
|
query_params.append(user_id)
|
||||||
|
|
||||||
if device_id:
|
if device_id:
|
||||||
query_clause += " AND k.device_id = ?"
|
query_clause += " AND device_id = ?"
|
||||||
query_params.append(device_id)
|
query_params.append(device_id)
|
||||||
|
|
||||||
query_clauses.append(query_clause)
|
query_clauses.append(query_clause)
|
||||||
|
@ -13,18 +13,6 @@
|
|||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
CREATE TABLE device_list_streams_remote (
|
|
||||||
list_id TEXT NOT NULL,
|
|
||||||
origin TEXT NOT NULL,
|
|
||||||
user_id TEXT NOT NULL,
|
|
||||||
is_full BOOLEAN NOT NULL,
|
|
||||||
ts BIGINT NOT NULL
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE INDEX device_list_streams_remote_id_origin ON device_list_streams_remote(
|
|
||||||
origin, list_id, user_id
|
|
||||||
);
|
|
||||||
|
|
||||||
|
|
||||||
CREATE TABLE device_lists_remote_cache (
|
CREATE TABLE device_lists_remote_cache (
|
||||||
user_id TEXT NOT NULL,
|
user_id TEXT NOT NULL,
|
||||||
@ -35,6 +23,14 @@ CREATE TABLE device_lists_remote_cache (
|
|||||||
CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
|
CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE device_lists_remote_extremeties (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
stream_id TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
|
||||||
|
|
||||||
|
|
||||||
CREATE TABLE device_lists_stream (
|
CREATE TABLE device_lists_stream (
|
||||||
stream_id BIGINT NOT NULL,
|
stream_id BIGINT NOT NULL,
|
||||||
user_id TEXT NOT NULL,
|
user_id TEXT NOT NULL,
|
||||||
|
@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
hs = yield utils.setup_test_homeserver(handlers=None)
|
hs = yield utils.setup_test_homeserver()
|
||||||
self.handler = synapse.handlers.device.DeviceHandler(hs)
|
self.handler = hs.get_device_handler()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_device_is_created_if_doesnt_exist(self):
|
def test_device_is_created_if_doesnt_exist(self):
|
||||||
res = yield self.handler.check_device_registered(
|
res = yield self.handler.check_device_registered(
|
||||||
user_id="boris",
|
user_id="@boris:foo",
|
||||||
device_id="fco",
|
device_id="fco",
|
||||||
initial_device_display_name="display name"
|
initial_device_display_name="display name"
|
||||||
)
|
)
|
||||||
self.assertEqual(res, "fco")
|
self.assertEqual(res, "fco")
|
||||||
|
|
||||||
dev = yield self.handler.store.get_device("boris", "fco")
|
dev = yield self.handler.store.get_device("@boris:foo", "fco")
|
||||||
self.assertEqual(dev["display_name"], "display name")
|
self.assertEqual(dev["display_name"], "display name")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_device_is_preserved_if_exists(self):
|
def test_device_is_preserved_if_exists(self):
|
||||||
res1 = yield self.handler.check_device_registered(
|
res1 = yield self.handler.check_device_registered(
|
||||||
user_id="boris",
|
user_id="@boris:foo",
|
||||||
device_id="fco",
|
device_id="fco",
|
||||||
initial_device_display_name="display name"
|
initial_device_display_name="display name"
|
||||||
)
|
)
|
||||||
self.assertEqual(res1, "fco")
|
self.assertEqual(res1, "fco")
|
||||||
|
|
||||||
res2 = yield self.handler.check_device_registered(
|
res2 = yield self.handler.check_device_registered(
|
||||||
user_id="boris",
|
user_id="@boris:foo",
|
||||||
device_id="fco",
|
device_id="fco",
|
||||||
initial_device_display_name="new display name"
|
initial_device_display_name="new display name"
|
||||||
)
|
)
|
||||||
self.assertEqual(res2, "fco")
|
self.assertEqual(res2, "fco")
|
||||||
|
|
||||||
dev = yield self.handler.store.get_device("boris", "fco")
|
dev = yield self.handler.store.get_device("@boris:foo", "fco")
|
||||||
self.assertEqual(dev["display_name"], "display name")
|
self.assertEqual(dev["display_name"], "display name")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_device_id_is_made_up_if_unspecified(self):
|
def test_device_id_is_made_up_if_unspecified(self):
|
||||||
device_id = yield self.handler.check_device_registered(
|
device_id = yield self.handler.check_device_registered(
|
||||||
user_id="theresa",
|
user_id="@theresa:foo",
|
||||||
device_id=None,
|
device_id=None,
|
||||||
initial_device_display_name="display"
|
initial_device_display_name="display"
|
||||||
)
|
)
|
||||||
|
|
||||||
dev = yield self.handler.store.get_device("theresa", device_id)
|
dev = yield self.handler.store.get_device("@theresa:foo", device_id)
|
||||||
self.assertEqual(dev["display_name"], "display")
|
self.assertEqual(dev["display_name"], "display")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.mock_federation = Mock(spec=[
|
self.mock_federation = Mock(spec=[
|
||||||
"make_query",
|
"make_query",
|
||||||
|
"register_edu_handler",
|
||||||
])
|
])
|
||||||
|
|
||||||
self.query_handlers = {}
|
self.query_handlers = {}
|
||||||
|
@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.mock_federation = Mock(spec=[
|
self.mock_federation = Mock(spec=[
|
||||||
"make_query",
|
"make_query",
|
||||||
|
"register_edu_handler",
|
||||||
])
|
])
|
||||||
|
|
||||||
self.query_handlers = {}
|
self.query_handlers = {}
|
||||||
|
@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
|
|||||||
event_cache_size=1,
|
event_cache_size=1,
|
||||||
password_providers=[],
|
password_providers=[],
|
||||||
)
|
)
|
||||||
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
|
hs = yield setup_test_homeserver(
|
||||||
|
config=config,
|
||||||
|
federation_sender=Mock(),
|
||||||
|
replication_layer=Mock(),
|
||||||
|
)
|
||||||
|
|
||||||
self.as_token = "token1"
|
self.as_token = "token1"
|
||||||
self.as_url = "some_url"
|
self.as_url = "some_url"
|
||||||
@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
event_cache_size=1,
|
event_cache_size=1,
|
||||||
password_providers=[],
|
password_providers=[],
|
||||||
)
|
)
|
||||||
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
|
hs = yield setup_test_homeserver(
|
||||||
|
config=config,
|
||||||
|
federation_sender=Mock(),
|
||||||
|
replication_layer=Mock(),
|
||||||
|
)
|
||||||
self.db_pool = hs.get_db_pool()
|
self.db_pool = hs.get_db_pool()
|
||||||
|
|
||||||
self.as_list = [
|
self.as_list = [
|
||||||
@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||||||
hs = yield setup_test_homeserver(
|
hs = yield setup_test_homeserver(
|
||||||
config=config,
|
config=config,
|
||||||
datastore=Mock(),
|
datastore=Mock(),
|
||||||
federation_sender=Mock()
|
federation_sender=Mock(),
|
||||||
|
replication_layer=Mock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
ApplicationServiceStore(hs)
|
ApplicationServiceStore(hs)
|
||||||
@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||||||
hs = yield setup_test_homeserver(
|
hs = yield setup_test_homeserver(
|
||||||
config=config,
|
config=config,
|
||||||
datastore=Mock(),
|
datastore=Mock(),
|
||||||
federation_sender=Mock()
|
federation_sender=Mock(),
|
||||||
|
replication_layer=Mock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaises(ConfigError) as cm:
|
with self.assertRaises(ConfigError) as cm:
|
||||||
@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
|
|||||||
hs = yield setup_test_homeserver(
|
hs = yield setup_test_homeserver(
|
||||||
config=config,
|
config=config,
|
||||||
datastore=Mock(),
|
datastore=Mock(),
|
||||||
federation_sender=Mock()
|
federation_sender=Mock(),
|
||||||
|
replication_layer=Mock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
with self.assertRaises(ConfigError) as cm:
|
with self.assertRaises(ConfigError) as cm:
|
||||||
|
Loading…
Reference in New Issue
Block a user