Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2017-01-30 14:36:46 +00:00
commit 717e4448c4
37 changed files with 1341 additions and 432 deletions

View file

@ -6,11 +6,9 @@ media.
The API is:: The API is::
POST /_matrix/client/r0/admin/purge_media_cache POST /_matrix/client/r0/admin/purge_media_cache?before_ts=<unix_timestamp_in_ms>&access_token=<access_token>
{ {}
"before_ts": <unix_timestamp_in_ms>
}
Which will remove all cached media that was last accessed before Which will remove all cached media that was last accessed before
``<unix_timestamp_in_ms>``. ``<unix_timestamp_in_ms>``.

View file

@ -30,6 +30,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.util.async import sleep from synapse.util.async import sleep
@ -56,7 +57,7 @@ logger = logging.getLogger("synapse.app.appservice")
class FederationSenderSlaveStore( class FederationSenderSlaveStore(
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore, SlavedRegistrationStore, SlavedDeviceStore,
): ):
pass pass

View file

@ -39,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore from synapse.storage.client_ips import ClientIpStore
@ -77,6 +78,7 @@ class SynchrotronSlavedStore(
SlavedFilteringStore, SlavedFilteringStore,
SlavedPresenceStore, SlavedPresenceStore,
SlavedDeviceInboxStore, SlavedDeviceInboxStore,
SlavedDeviceStore,
RoomStore, RoomStore,
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
@ -380,6 +382,27 @@ class SynchrotronServer(HomeServer):
stream_key, position, users=users, rooms=rooms stream_key, position, users=users, rooms=rooms
) )
@defer.inlineCallbacks
def notify_device_list_update(result):
stream = result.get("device_lists")
if not stream:
return
position_index = stream["field_names"].index("position")
user_index = stream["field_names"].index("user_id")
for row in stream["rows"]:
position = row[position_index]
user_id = row[user_index]
rooms = yield store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
@defer.inlineCallbacks
def notify(result): def notify(result):
stream = result.get("events") stream = result.get("events")
if stream: if stream:
@ -417,6 +440,7 @@ class SynchrotronServer(HomeServer):
notify_from_stream( notify_from_stream(
result, "to_device", "to_device_key", user="user_id" result, "to_device", "to_device_key", user="user_id"
) )
yield notify_device_list_update(result)
while True: while True:
try: try:
@ -427,7 +451,7 @@ class SynchrotronServer(HomeServer):
yield store.process_replication(result) yield store.process_replication(result)
typing_handler.process_replication(result) typing_handler.process_replication(result)
yield presence_handler.process_replication(result) yield presence_handler.process_replication(result)
notify(result) yield notify(result)
except: except:
logger.exception("Error replicating from %r", replication_url) logger.exception("Error replicating from %r", replication_url)
yield sleep(5) yield sleep(5)

View file

@ -126,6 +126,16 @@ class FederationClient(FederationBase):
destination, content, timeout destination, content, timeout
) )
@log_function
def query_user_devices(self, destination, user_id, timeout=30000):
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
sent_queries_counter.inc("user_devices")
return self.transport_layer.query_user_devices(
destination, user_id, timeout
)
@log_function @log_function
def claim_client_keys(self, destination, content, timeout): def claim_client_keys(self, destination, content, timeout):
"""Claims one-time keys for a device hosted on a remote server. """Claims one-time keys for a device hosted on a remote server.

View file

@ -416,6 +416,9 @@ class FederationServer(FederationBase):
def on_query_client_keys(self, origin, content): def on_query_client_keys(self, origin, content):
return self.on_query_request("client_keys", content) return self.on_query_request("client_keys", content)
def on_query_user_devices(self, origin, user_id):
return self.on_query_request("user_devices", user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_claim_client_keys(self, origin, content): def on_claim_client_keys(self, origin, content):

View file

@ -100,6 +100,7 @@ class TransactionQueue(object):
self.pending_failures_by_dest = {} self.pending_failures_by_dest = {}
self.last_device_stream_id_by_dest = {} self.last_device_stream_id_by_dest = {}
self.last_device_list_stream_id_by_dest = {}
# HACK to get unique tx id # HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec()) self._next_txn_id = int(self.clock.time_msec())
@ -320,7 +321,7 @@ class TransactionQueue(object):
self.store, self.store,
) )
device_message_edus, device_stream_id = ( device_message_edus, device_stream_id, dev_list_id = (
yield self._get_new_device_messages(destination) yield self._get_new_device_messages(destination)
) )
@ -355,11 +356,23 @@ class TransactionQueue(object):
success = yield self._send_new_transaction( success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures, destination, pending_pdus, pending_edus, pending_failures,
device_stream_id,
should_delete_from_device_stream=bool(device_message_edus),
limiter=limiter, limiter=limiter,
) )
if not success: if success:
# Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages
if device_message_edus:
yield self.store.delete_device_msgs_for_remote(
destination, device_stream_id
)
logger.info("Marking as sent %r %r", destination, dev_list_id)
yield self.store.mark_as_sent_devices_by_remote(
destination, dev_list_id
)
self.last_device_stream_id_by_dest[destination] = device_stream_id
self.last_device_list_stream_id_by_dest[destination] = dev_list_id
else:
break break
except NotRetryingDestination: except NotRetryingDestination:
logger.debug( logger.debug(
@ -387,13 +400,26 @@ class TransactionQueue(object):
) )
for content in contents for content in contents
] ]
defer.returnValue((edus, stream_id))
last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0)
now_stream_id, results = yield self.store.get_devices_by_remote(
destination, last_device_list
)
edus.extend(
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.device_list_update",
content=content,
)
for content in results
)
defer.returnValue((edus, stream_id, now_stream_id))
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus, def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures, device_stream_id, pending_failures, limiter):
should_delete_from_device_stream, limiter):
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[1]) pending_pdus.sort(key=lambda t: t[1])
@ -504,13 +530,6 @@ class TransactionQueue(object):
"Failed to send event %s to %s", p.event_id, destination "Failed to send event %s to %s", p.event_id, destination
) )
success = False success = False
else:
# Remove the acknowledged device messages from the database
if should_delete_from_device_stream:
yield self.store.delete_device_msgs_for_remote(
destination, device_stream_id
)
self.last_device_stream_id_by_dest[destination] = device_stream_id
except RuntimeError as e: except RuntimeError as e:
# We capture this here as there as nothing actually listens # We capture this here as there as nothing actually listens
# for this finishing functions deferred. # for this finishing functions deferred.

View file

@ -346,6 +346,32 @@ class TransportLayerClient(object):
) )
defer.returnValue(content) defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def query_user_devices(self, destination, user_id, timeout):
"""Query the devices for a user id hosted on a remote server.
Response:
{
"stream_id": "...",
"devices": [ { ... } ]
}
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/user/devices/" + user_id
content = yield self.client.get_json(
destination=destination,
path=path,
timeout=timeout,
)
defer.returnValue(content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def claim_client_keys(self, destination, query_content, timeout): def claim_client_keys(self, destination, query_content, timeout):

View file

@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
return self.handler.on_query_client_keys(origin, content) return self.handler.on_query_client_keys(origin, content)
class FederationUserDevicesQueryServlet(BaseFederationServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
def on_GET(self, origin, content, query, user_id):
return self.handler.on_query_user_devices(origin, user_id)
class FederationClientKeysClaimServlet(BaseFederationServlet): class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim" PATH = "/user/keys/claim"
@ -613,6 +620,7 @@ SERVLET_CLASSES = (
FederationGetMissingEventsServlet, FederationGetMissingEventsServlet,
FederationEventAuthServlet, FederationEventAuthServlet,
FederationClientKeysQueryServlet, FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet, FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet, FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet, On3pidBindServlet,

View file

@ -88,9 +88,13 @@ class BaseHandler(object):
current_state = yield self.store.get_events( current_state = yield self.store.get_events(
context.current_state_ids.values() context.current_state_ids.values()
) )
current_state = current_state.values()
else: else:
current_state = yield self.store.get_current_state(event.room_id) current_state = yield self.state_handler.get_current_state(
event.room_id
)
current_state = current_state.values()
logger.info("maybe_kick_guest_users %r", current_state) logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state) yield self.kick_guest_users(current_state)

View file

@ -15,6 +15,8 @@
from synapse.api import errors from synapse.api import errors
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import Linearizer
from synapse.types import get_domain_from_id
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
@ -27,6 +29,21 @@ class DeviceHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(DeviceHandler, self).__init__(hs) super(DeviceHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
self._remote_edue_linearizer = Linearizer(name="remote_device_list")
self.federation.register_edu_handler(
"m.device_list_update", self._incoming_device_list_update,
)
self.federation.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_device_registered(self, user_id, device_id, def check_device_registered(self, user_id, device_id,
initial_device_display_name=None): initial_device_display_name=None):
@ -45,28 +62,28 @@ class DeviceHandler(BaseHandler):
str: device id (generated if none was supplied) str: device id (generated if none was supplied)
""" """
if device_id is not None: if device_id is not None:
yield self.store.store_device( new_device = yield self.store.store_device(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
initial_device_display_name=initial_device_display_name, initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
) )
if new_device:
yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id) defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few # if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash. # times in case of a clash.
attempts = 0 attempts = 0
while attempts < 5: while attempts < 5:
try:
device_id = stringutils.random_string(10).upper() device_id = stringutils.random_string(10).upper()
yield self.store.store_device( new_device = yield self.store.store_device(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
initial_device_display_name=initial_device_display_name, initial_device_display_name=initial_device_display_name,
ignore_if_known=False,
) )
if new_device:
yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id) defer.returnValue(device_id)
except errors.StoreError:
attempts += 1 attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.") raise errors.StoreError(500, "Couldn't generate a device ID.")
@ -147,6 +164,8 @@ class DeviceHandler(BaseHandler):
user_id=user_id, device_id=device_id user_id=user_id, device_id=device_id
) )
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks @defer.inlineCallbacks
def update_device(self, user_id, device_id, content): def update_device(self, user_id, device_id, content):
""" Update the given device """ Update the given device
@ -166,12 +185,110 @@ class DeviceHandler(BaseHandler):
device_id, device_id,
new_display_name=content.get("display_name") new_display_name=content.get("display_name")
) )
yield self.notify_device_update(user_id, [device_id])
except errors.StoreError, e: except errors.StoreError, e:
if e.code == 404: if e.code == 404:
raise errors.NotFoundError() raise errors.NotFoundError()
else: else:
raise raise
@defer.inlineCallbacks
def notify_device_update(self, user_id, device_ids):
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
hosts = set()
if self.hs.is_mine_id(user_id):
for room_id in room_ids:
users = yield self.state.get_current_user_in_room(room_id)
hosts.update(get_domain_from_id(u) for u in users)
hosts.discard(self.server_name)
position = yield self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
)
yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
if hosts:
logger.info("Sending device list update notif to: %r", hosts)
for host in hosts:
self.federation_sender.send_device_messages(host)
@defer.inlineCallbacks
def _incoming_device_list_update(self, origin, edu_content):
user_id = edu_content["user_id"]
device_id = edu_content["device_id"]
stream_id = edu_content["stream_id"]
prev_ids = edu_content.get("prev_id", [])
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning("Got device list update edu for %r from %r", user_id, origin)
return
rooms = yield self.store.get_rooms_for_user(user_id)
if not rooms:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return
with (yield self._remote_edue_linearizer.queue(user_id)):
# If the prev id matches whats in our cache table, then we don't need
# to resync the users device list, otherwise we do.
resync = True
if len(prev_ids) == 1:
extremity = yield self.store.get_device_list_last_stream_id_for_remote(
user_id
)
logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
if str(extremity) == str(prev_ids[0]):
resync = False
if resync:
# Fetch all devices for the user.
result = yield self.federation.query_user_devices(origin, user_id)
stream_id = result["stream_id"]
devices = result["devices"]
yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id,
)
device_ids = [device["device_id"] for device in devices]
yield self.notify_device_update(user_id, device_ids)
else:
# Simply update the single device, since we know that is the only
# change (becuase of the single prev_id matching the current cache)
content = dict(edu_content)
for key in ("user_id", "device_id", "stream_id", "prev_ids"):
content.pop(key, None)
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
)
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
defer.returnValue({
"user_id": user_id,
"stream_id": stream_id,
"devices": devices,
})
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
user_id = user.to_string()
rooms = yield self.store.get_rooms_for_user(user_id)
if not rooms:
# We no longer share rooms with this user, so we'll no longer
# receive device updates. Mark this in DB.
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
def _update_device_from_client_ips(device, client_ips): def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {}) ip = client_ips.get((device["user_id"], device["device_id"]), {})

View file

@ -73,10 +73,9 @@ class E2eKeysHandler(object):
if self.is_mine_id(user_id): if self.is_mine_id(user_id):
local_query[user_id] = device_ids local_query[user_id] = device_ids
else: else:
domain = get_domain_from_id(user_id) remote_queries[user_id] = device_ids
remote_queries.setdefault(domain, {})[user_id] = device_ids
# do the queries # Firt get local devices.
failures = {} failures = {}
results = {} results = {}
if local_query: if local_query:
@ -85,9 +84,42 @@ class E2eKeysHandler(object):
if user_id in local_query: if user_id in local_query:
results[user_id] = keys results[user_id] = keys
# Now attempt to get any remote devices from our local cache.
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
for user_id, device_ids in remote_queries.iteritems():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
query_list.append((user_id, None))
user_ids_not_in_cache, remote_results = (
yield self.store.get_user_devices_from_cache(
query_list
)
)
for user_id, devices in remote_results.iteritems():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.iteritems():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
user_devices[device_id] = result
for user_id in user_ids_not_in_cache:
domain = get_domain_from_id(user_id)
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
@defer.inlineCallbacks @defer.inlineCallbacks
def do_remote_query(destination): def do_remote_query(destination):
destination_query = remote_queries[destination] destination_query = remote_queries_not_in_cache[destination]
try: try:
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
destination, self.clock, self.store destination, self.clock, self.store
@ -119,7 +151,7 @@ class E2eKeysHandler(object):
yield preserve_context_over_deferred(defer.gatherResults([ yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination) preserve_fn(do_remote_query)(destination)
for destination in remote_queries for destination in remote_queries_not_in_cache
])) ]))
defer.returnValue({ defer.returnValue({
@ -259,6 +291,7 @@ class E2eKeysHandler(object):
user_id, device_id, time_now, user_id, device_id, time_now,
encode_canonical_json(device_keys) encode_canonical_json(device_keys)
) )
yield self.device_handler.notify_device_update(user_id, [device_id])
one_time_keys = keys.get("one_time_keys", None) one_time_keys = keys.get("one_time_keys", None)
if one_time_keys: if one_time_keys:

View file

@ -1319,7 +1319,6 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event( event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context, event, new_event_context,
current_state=state,
) )
defer.returnValue((event_stream_id, max_stream_id)) defer.returnValue((event_stream_id, max_stream_id))

View file

@ -208,7 +208,9 @@ class MessageHandler(BaseHandler):
content = builder.content content = builder.content
try: try:
if "displayname" not in content:
content["displayname"] = yield profile.get_displayname(target) content["displayname"] = yield profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = yield profile.get_avatar_url(target) content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e: except Exception as e:
logger.info( logger.info(

View file

@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"invited", # InvitedSyncResult for each invited room. "invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room. "archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device. "to_device", # List of direct messages for the device.
"device_lists", # List of user_ids whose devices have chanegd
])): ])):
__slots__ = [] __slots__ = []
@ -544,6 +545,10 @@ class SyncHandler(object):
yield self._generate_sync_entry_for_to_device(sync_result_builder) yield self._generate_sync_entry_for_to_device(sync_result_builder)
device_lists = yield self._generate_sync_entry_for_device_list(
sync_result_builder
)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=sync_result_builder.presence, presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data, account_data=sync_result_builder.account_data,
@ -551,9 +556,32 @@ class SyncHandler(object):
invited=sync_result_builder.invited, invited=sync_result_builder.invited,
archived=sync_result_builder.archived, archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device, to_device=sync_result_builder.to_device,
device_lists=device_lists,
next_batch=sync_result_builder.now_token, next_batch=sync_result_builder.now_token,
)) ))
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and since_token.device_list_key:
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(r.room_id for r in rooms)
user_ids_changed = set()
changed = yield self.store.get_user_whose_devices_changed(
since_token.device_list_key
)
for other_user_id in changed:
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
defer.returnValue(user_ids_changed)
else:
defer.returnValue([])
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder): def _generate_sync_entry_for_to_device(self, sync_result_builder):
"""Generates the portion of the sync response. Populates """Generates the portion of the sync response. Populates

View file

@ -46,6 +46,7 @@ STREAM_NAMES = (
("to_device",), ("to_device",),
("public_rooms",), ("public_rooms",),
("federation",), ("federation",),
("device_lists",),
) )
@ -140,6 +141,7 @@ class ReplicationResource(Resource):
caches_token = self.store.get_cache_stream_token() caches_token = self.store.get_cache_stream_token()
public_rooms_token = self.store.get_current_public_room_stream_id() public_rooms_token = self.store.get_current_public_room_stream_id()
federation_token = self.federation_sender.get_current_token() federation_token = self.federation_sender.get_current_token()
device_list_token = self.store.get_device_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
room_stream_token, room_stream_token,
@ -155,6 +157,7 @@ class ReplicationResource(Resource):
int(stream_token.to_device_key), int(stream_token.to_device_key),
int(public_rooms_token), int(public_rooms_token),
int(federation_token), int(federation_token),
int(device_list_token),
)) ))
@request_handler() @request_handler()
@ -214,6 +217,7 @@ class ReplicationResource(Resource):
yield self.caches(writer, current_token, limit, request_streams) yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams)
yield self.public_rooms(writer, current_token, limit, request_streams) yield self.public_rooms(writer, current_token, limit, request_streams)
yield self.device_lists(writer, current_token, limit, request_streams)
self.federation(writer, current_token, limit, request_streams, federation_ack) self.federation(writer, current_token, limit, request_streams, federation_ack)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
@ -495,6 +499,20 @@ class ReplicationResource(Resource):
"position", "type", "content", "position", "type", "content",
), position=upto_token) ), position=upto_token)
@defer.inlineCallbacks
def device_lists(self, writer, current_token, limit, request_streams):
current_position = current_token.device_lists
device_lists = request_streams.get("device_lists")
if device_lists is not None and device_lists != current_position:
changes = yield self.store.get_all_device_list_changes_for_remotes(
device_lists,
)
writer.write_header_and_rows("device_lists", changes, (
"position", "user_id", "destination",
), position=current_position)
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
@ -527,7 +545,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "to_device", "public_rooms", "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
"federation", "federation", "device_lists",
))): ))):
__slots__ = [] __slots__ = []

View file

@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
db_conn, "device_lists_stream", "stream_id",
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max,
)
get_device_stream_token = DataStore.get_device_stream_token.__func__
get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__
get_devices_by_remote = DataStore.get_devices_by_remote.__func__
_get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__
_get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__
mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__
_mark_as_sent_devices_by_remote_txn = (
DataStore._mark_as_sent_devices_by_remote_txn.__func__
)
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("device_lists")
if stream:
self._device_list_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
user_id = row[1]
destination = row[2]
self._device_list_stream_cache.entity_has_changed(
user_id, stream_id
)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, stream_id
)
return super(SlavedDeviceStore, self).process_replication(result)

View file

@ -76,9 +76,6 @@ class SlavedEventStore(BaseSlavedStore):
get_latest_event_ids_in_room = EventFederationStore.__dict__[ get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room" "get_latest_event_ids_in_room"
] ]
_get_current_state_for_key = StateStore.__dict__[
"_get_current_state_for_key"
]
get_invited_rooms_for_user = RoomMemberStore.__dict__[ get_invited_rooms_for_user = RoomMemberStore.__dict__[
"get_invited_rooms_for_user" "get_invited_rooms_for_user"
] ]
@ -115,8 +112,6 @@ class SlavedEventStore(BaseSlavedStore):
) )
get_event = DataStore.get_event.__func__ get_event = DataStore.get_event.__func__
get_events = DataStore.get_events.__func__ get_events = DataStore.get_events.__func__
get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = ( get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__ DataStore.get_rooms_for_user_where_membership_is.__func__
) )
@ -248,7 +243,6 @@ class SlavedEventStore(BaseSlavedStore):
def invalidate_caches_for_event(self, event, backfilled, reset_state): def invalidate_caches_for_event(self, event, backfilled, reset_state):
if reset_state: if reset_state:
self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all() self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,)) self.get_users_in_room.invalidate((event.room_id,))
@ -289,7 +283,3 @@ class SlavedEventStore(BaseSlavedStore):
if (not event.internal_metadata.is_invite_from_remote() if (not event.internal_metadata.is_invite_from_remote()
and event.internal_metadata.is_outlier()): and event.internal_metadata.is_outlier()):
return return
self._get_current_state_for_key.invalidate((
event.room_id, event.type, event.state_key
))

View file

@ -170,12 +170,16 @@ class SyncRestServlet(RestServlet):
) )
archived = self.encode_archived( archived = self.encode_archived(
sync_result.archived, time_now, requester.access_token_id, filter.event_fields sync_result.archived, time_now, requester.access_token_id,
filter.event_fields,
) )
response_content = { response_content = {
"account_data": {"events": sync_result.account_data}, "account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device}, "to_device": {"events": sync_result.to_device},
"device_lists": {
"changed": list(sync_result.device_lists),
},
"presence": self.encode_presence( "presence": self.encode_presence(
sync_result.presence, time_now sync_result.presence, time_now
), ),

View file

@ -429,6 +429,9 @@ def resolve_events(state_sets, state_map_factory):
dict[(str, str), synapse.events.FrozenEvent] is a map from dict[(str, str), synapse.events.FrozenEvent] is a map from
(type, state_key) to event. (type, state_key) to event.
""" """
if len(state_sets) == 1:
return state_sets[0]
unconflicted_state, conflicted_state = _seperate( unconflicted_state, conflicted_state = _seperate(
state_sets, state_sets,
) )

View file

@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._public_room_id_gen = StreamIdGenerator( self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id" db_conn, "public_room_list_stream", "stream_id"
) )
self._device_list_id_gen = StreamIdGenerator(
db_conn, "device_lists_stream", "stream_id",
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
@ -210,6 +213,14 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=device_outbox_prefill, prefilled_cache=device_outbox_prefill,
) )
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max,
)
cur = LoggingTransaction( cur = LoggingTransaction(
db_conn.cursor(), db_conn.cursor(),
name="_find_stream_orderings_for_times_txn", name="_find_stream_orderings_for_times_txn",

View file

@ -387,6 +387,10 @@ class SQLBaseStore(object):
Args: Args:
table : string giving the table name table : string giving the table name
values : dict of new column names and values for them values : dict of new column names and values for them
Returns:
bool: Whether the row was inserted or not. Only useful when
`or_ignore` is True
""" """
try: try:
yield self.runInteraction( yield self.runInteraction(
@ -398,6 +402,8 @@ class SQLBaseStore(object):
# a cursor after we receive an error from the db. # a cursor after we receive an error from the db.
if not or_ignore: if not or_ignore:
raise raise
defer.returnValue(False)
defer.returnValue(True)
@staticmethod @staticmethod
def _simple_insert_txn(txn, table, values): def _simple_insert_txn(txn, table, values):

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import ujson as json
from twisted.internet import defer from twisted.internet import defer
@ -23,27 +24,29 @@ logger = logging.getLogger(__name__)
class DeviceStore(SQLBaseStore): class DeviceStore(SQLBaseStore):
def __init__(self, hs):
super(DeviceStore, self).__init__(hs)
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
@defer.inlineCallbacks @defer.inlineCallbacks
def store_device(self, user_id, device_id, def store_device(self, user_id, device_id,
initial_device_display_name, initial_device_display_name):
ignore_if_known=True):
"""Ensure the given device is known; add it to the store if not """Ensure the given device is known; add it to the store if not
Args: Args:
user_id (str): id of user associated with the device user_id (str): id of user associated with the device
device_id (str): id of device device_id (str): id of device
initial_device_display_name (str): initial displayname of the initial_device_display_name (str): initial displayname of the
device device. Ignored if device exists.
ignore_if_known (bool): ignore integrity errors which mean the
device is already known
Returns: Returns:
defer.Deferred defer.Deferred: boolean whether the device was inserted or an
Raises: existing device existed with that ID.
StoreError: if ignore_if_known is False and the device was already
known
""" """
try: try:
yield self._simple_insert( inserted = yield self._simple_insert(
"devices", "devices",
values={ values={
"user_id": user_id, "user_id": user_id,
@ -51,8 +54,9 @@ class DeviceStore(SQLBaseStore):
"display_name": initial_device_display_name "display_name": initial_device_display_name
}, },
desc="store_device", desc="store_device",
or_ignore=ignore_if_known, or_ignore=True,
) )
defer.returnValue(inserted)
except Exception as e: except Exception as e:
logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
" display_name=%s(%r) failed: %s", " display_name=%s(%r) failed: %s",
@ -139,3 +143,432 @@ class DeviceStore(SQLBaseStore):
) )
defer.returnValue({d["device_id"]: d for d in devices}) defer.returnValue({d["device_id"]: d for d in devices})
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self._simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
desc="get_device_list_remote_extremity",
allow_none=True,
)
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
return self._simple_delete(
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
)
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
stream_id):
"""Updates a single user's device in the cache.
"""
return self.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id, device_id, content, stream_id,
)
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id):
self._simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"content": json.dumps(content),
}
)
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
}
)
def update_remote_device_list_cache(self, user_id, devices, stream_id):
"""Replace the cache of the remote user's devices.
"""
return self.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id, devices, stream_id,
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
stream_id):
self._simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
)
self._simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
{
"user_id": user_id,
"device_id": content["device_id"],
"content": json.dumps(content),
}
for content in devices
]
)
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
}
)
def get_devices_by_remote(self, destination, from_stream_id):
"""Get stream of updates to send to remote servers
Returns:
(now_stream_id, [ { updates }, .. ])
"""
now_stream_id = self._device_list_id_gen.get_current_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])
return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn,
destination, from_stream_id, now_stream_id,
)
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
now_stream_id):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
rows = txn.fetchall()
if not rows:
return (now_stream_id, [])
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in rows}
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
)
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_pokes
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
results = []
for user_id, user_devices in devices.iteritems():
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, device in user_devices.iteritems():
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
}
prev_id = stream_id
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return (now_stream_id, results)
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
Args:
query_list(list): List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
return self.runInteraction(
"get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
query_list,
)
def _get_user_devices_from_cache_txn(self, txn, query_list):
user_ids = {user_id for user_id, _ in query_list}
user_ids_in_cache = set()
for user_id in user_ids:
stream_ids = self._simple_select_onecol_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
retcol="stream_id",
)
if stream_ids:
user_ids_in_cache.add(user_id)
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
if device_id:
content = self._simple_select_one_onecol_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="content",
)
results.setdefault(user_id, {})[device_id] = json.loads(content)
else:
devices = self._simple_select_list_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
retcols=("device_id", "content"),
)
results[user_id] = {
device["device_id"]: json.loads(device["content"])
for device in devices
}
user_ids_in_cache.discard(user_id)
return user_ids_not_in_cache, results
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
Returns:
(stream_id, devices)
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
txn, [(user_id, None)], include_all_devices=True
)
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.iteritems():
result = {
"device_id": device_id,
}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
destination, stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
sql = """
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id < (
SELECT coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
)
"""
txn.execute(sql, (destination, destination, stream_id,))
sql = """
UPDATE device_lists_outbound_pokes SET sent = ?
WHERE destination = ? AND stream_id <= ?
"""
txn.execute(sql, (True, destination, stream_id,))
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
"""Get set of users whose devices have changed since `from_key`.
"""
from_key = int(from_key)
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
if changed is not None:
defer.returnValue(set(changed))
sql = """
SELECT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row["user_id"] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
sql = """
SELECT stream_id, user_id, destination FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE stream_id > ?
"""
return self._execute(
"get_users_and_hosts_device_list", None,
sql, from_key,
)
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction(
"add_device_change_to_streams", self._add_device_change_txn,
user_id, device_ids, hosts, stream_id,
)
defer.returnValue(stream_id)
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
now = self._clock.time_msec()
txn.call_after(
self._device_list_stream_cache.entity_has_changed,
user_id, stream_id,
)
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host, stream_id,
)
self._simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
{
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
}
for device_id in device_ids
]
)
self._simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
{
"destination": destination,
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
"sent": False,
"ts": now,
}
for destination in hosts
for device_id in device_ids
]
)
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
def _prune_old_outbound_device_pokes(self):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers. We keep one entry per
(destination, user_id) tuple to ensure that the prev_ids remain correct
if the server does come back.
"""
now = self._clock.time_msec()
def _prune_txn(txn):
select_sql = """
SELECT destination, user_id, max(stream_id) as stream_id
FROM device_lists_outbound_pokes
GROUP BY destination, user_id
"""
txn.execute(select_sql)
rows = txn.fetchall()
delete_sql = """
DELETE FROM device_lists_outbound_pokes
WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
"""
txn.executemany(
delete_sql,
(
(now, row["destination"], row["user_id"], row["stream_id"])
for row in rows
)
)
return self.runInteraction(
"_prune_old_outbound_device_pokes", _prune_txn
)

View file

@ -12,9 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections from twisted.internet import defer
import twisted.internet.defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
@ -33,10 +31,12 @@ class EndToEndKeyStore(SQLBaseStore):
} }
) )
def get_e2e_device_keys(self, query_list): def get_e2e_device_keys(self, query_list, include_all_devices=False):
"""Fetch a list of device keys. """Fetch a list of device keys.
Args: Args:
query_list(list): List of pairs of user_ids and device_ids. query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices
that don't have device keys
Returns: Returns:
Dict mapping from user-id to dict mapping from device_id to Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name". dict containing "key_json", "device_display_name".
@ -45,41 +45,42 @@ class EndToEndKeyStore(SQLBaseStore):
return {} return {}
return self.runInteraction( return self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list "get_e2e_device_keys", self._get_e2e_device_keys_txn,
query_list, include_all_devices,
) )
def _get_e2e_device_keys_txn(self, txn, query_list): def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
query_clauses = [] query_clauses = []
query_params = [] query_params = []
for (user_id, device_id) in query_list: for (user_id, device_id) in query_list:
query_clause = "k.user_id = ?" query_clause = "user_id = ?"
query_params.append(user_id) query_params.append(user_id)
if device_id: if device_id:
query_clause += " AND k.device_id = ?" query_clause += " AND device_id = ?"
query_params.append(device_id) query_params.append(device_id)
query_clauses.append(query_clause) query_clauses.append(query_clause)
sql = ( sql = (
"SELECT k.user_id, k.device_id, " "SELECT user_id, device_id, "
" d.display_name AS device_display_name, " " d.display_name AS device_display_name, "
" k.key_json" " k.key_json"
" FROM e2e_device_keys_json k" " FROM devices d"
" LEFT JOIN devices d ON d.user_id = k.user_id" " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
" AND d.device_id = k.device_id"
" WHERE %s" " WHERE %s"
) % ( ) % (
"LEFT" if include_all_devices else "INNER",
" OR ".join("(" + q + ")" for q in query_clauses) " OR ".join("(" + q + ")" for q in query_clauses)
) )
txn.execute(sql, query_params) txn.execute(sql, query_params)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
result = collections.defaultdict(dict) result = {}
for row in rows: for row in rows:
result[row["user_id"]][row["device_id"]] = row result.setdefault(row["user_id"], {})[row["device_id"]] = row
return result return result
@ -152,7 +153,7 @@ class EndToEndKeyStore(SQLBaseStore):
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
) )
@twisted.internet.defer.inlineCallbacks @defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id): def delete_e2e_keys_by_device(self, user_id, device_id):
yield self._simple_delete( yield self._simple_delete(
table="e2e_device_keys_json", table="e2e_device_keys_json",

View file

@ -235,80 +235,21 @@ class EventFederationStore(SQLBaseStore):
], ],
) )
self._update_extremeties(txn, events) self._update_backward_extremeties(txn, events)
def _update_extremeties(self, txn, events): def _update_backward_extremeties(self, txn, events):
"""Updates the event_*_extremities tables based on the new/updated """Updates the event_backward_extremities tables based on the new/updated
events being persisted. events being persisted.
This is called for new events *and* for events that were outliers, but This is called for new events *and* for events that were outliers, but
are are now being persisted as non-outliers. are now being persisted as non-outliers.
Forward extremities are handled when we first start persisting the events.
""" """
events_by_room = {} events_by_room = {}
for ev in events: for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev) events_by_room.setdefault(ev.room_id, []).append(ev)
for room_id, room_events in events_by_room.items():
prevs = [
e_id for ev in room_events for e_id, _ in ev.prev_events
if not ev.internal_metadata.is_outlier()
]
if prevs:
txn.execute(
"DELETE FROM event_forward_extremities"
" WHERE room_id = ?"
" AND event_id in (%s)" % (
",".join(["?"] * len(prevs)),
),
[room_id] + prevs,
)
query = (
"INSERT INTO event_forward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
" SELECT 1 FROM event_edges WHERE prev_event_id = ?"
" )"
)
txn.executemany(
query,
[
(ev.event_id, ev.room_id, ev.event_id) for ev in events
if not ev.internal_metadata.is_outlier()
]
)
# We now insert into stream_ordering_to_exterm a mapping from room_id,
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
max_stream_ord = max(
ev.internal_metadata.stream_ordering for ev in events
)
new_extrem = {}
for room_id in events_by_room:
event_ids = self._simple_select_onecol_txn(
txn,
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
)
new_extrem[room_id] = event_ids
self._simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
{
"room_id": room_id,
"event_id": event_id,
"stream_ordering": max_stream_ord,
}
for room_id, extrem_evs in new_extrem.items()
for event_id in extrem_evs
]
)
query = ( query = (
"INSERT INTO event_backward_extremities (event_id, room_id)" "INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS (" " SELECT ?, ? WHERE NOT EXISTS ("
@ -339,11 +280,6 @@ class EventFederationStore(SQLBaseStore):
] ]
) )
for room_id in events_by_room:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
def get_forward_extremeties_for_room(self, room_id, stream_ordering): def get_forward_extremeties_for_room(self, room_id, stream_ordering):
# We want to make the cache more effective, so we clamp to the last # We want to make the cache more effective, so we clamp to the last
# change before the given ordering. # change before the given ordering.

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, _RollbackButIsFineException from ._base import SQLBaseStore
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
@ -27,6 +27,7 @@ from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.state import resolve_events
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from collections import deque, namedtuple, OrderedDict from collections import deque, namedtuple, OrderedDict
@ -71,22 +72,19 @@ class _EventPeristenceQueue(object):
""" """
_EventPersistQueueItem = namedtuple("_EventPersistQueueItem", ( _EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
"events_and_contexts", "current_state", "backfilled", "deferred", "events_and_contexts", "backfilled", "deferred",
)) ))
def __init__(self): def __init__(self):
self._event_persist_queues = {} self._event_persist_queues = {}
self._currently_persisting_rooms = set() self._currently_persisting_rooms = set()
def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state): def add_to_queue(self, room_id, events_and_contexts, backfilled):
"""Add events to the queue, with the given persist_event options. """Add events to the queue, with the given persist_event options.
""" """
queue = self._event_persist_queues.setdefault(room_id, deque()) queue = self._event_persist_queues.setdefault(room_id, deque())
if queue: if queue:
end_item = queue[-1] end_item = queue[-1]
if end_item.current_state or current_state:
# We perist events with current_state set to True one at a time
pass
if end_item.backfilled == backfilled: if end_item.backfilled == backfilled:
end_item.events_and_contexts.extend(events_and_contexts) end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe() return end_item.deferred.observe()
@ -96,7 +94,6 @@ class _EventPeristenceQueue(object):
queue.append(self._EventPersistQueueItem( queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts, events_and_contexts=events_and_contexts,
backfilled=backfilled, backfilled=backfilled,
current_state=current_state,
deferred=deferred, deferred=deferred,
)) ))
@ -216,7 +213,6 @@ class EventsStore(SQLBaseStore):
d = preserve_fn(self._event_persist_queue.add_to_queue)( d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs, room_id, evs_ctxs,
backfilled=backfilled, backfilled=backfilled,
current_state=None,
) )
deferreds.append(d) deferreds.append(d)
@ -229,11 +225,10 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, current_state=None, backfilled=False): def persist_event(self, event, context, backfilled=False):
deferred = self._event_persist_queue.add_to_queue( deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], event.room_id, [(event, context)],
backfilled=backfilled, backfilled=backfilled,
current_state=current_state,
) )
self._maybe_start_persisting(event.room_id) self._maybe_start_persisting(event.room_id)
@ -246,17 +241,6 @@ class EventsStore(SQLBaseStore):
def _maybe_start_persisting(self, room_id): def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks @defer.inlineCallbacks
def persisting_queue(item): def persisting_queue(item):
if item.current_state:
for event, context in item.events_and_contexts:
# There should only ever be one item in
# events_and_contexts when current_state is
# not None
yield self._persist_event(
event, context,
current_state=item.current_state,
backfilled=item.backfilled,
)
else:
yield self._persist_events( yield self._persist_events(
item.events_and_contexts, item.events_and_contexts,
backfilled=item.backfilled, backfilled=item.backfilled,
@ -294,35 +278,183 @@ class EventsStore(SQLBaseStore):
for chunk in chunks: for chunk in chunks:
# We can't easily parallelize these since different chunks # We can't easily parallelize these since different chunks
# might contain the same event. :( # might contain the same event. :(
# NB: Assumes that we are only persisting events for one room
# at a time.
new_forward_extremeties = {}
current_state_for_room = {}
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
events_by_room = {}
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
)
for room_id, ev_ctx_rm in events_by_room.items():
# Work out new extremities by recursively adding and removing
# the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
room_id
)
new_latest_event_ids = yield self._calculate_new_extremeties(
room_id, [ev for ev, _ in ev_ctx_rm]
)
if new_latest_event_ids == set(latest_event_ids):
# No change in extremities, so no change in state
continue
new_forward_extremeties[room_id] = new_latest_event_ids
state = yield self._calculate_state_delta(
room_id, ev_ctx_rm, new_latest_event_ids
)
if state:
current_state_for_room[room_id] = state
yield self.runInteraction( yield self.runInteraction(
"persist_events", "persist_events",
self._persist_events_txn, self._persist_events_txn,
events_and_contexts=chunk, events_and_contexts=chunk,
backfilled=backfilled, backfilled=backfilled,
delete_existing=delete_existing, delete_existing=delete_existing,
current_state_for_room=current_state_for_room,
new_forward_extremeties=new_forward_extremeties,
) )
persist_event_counter.inc_by(len(chunk)) persist_event_counter.inc_by(len(chunk))
@_retry_on_integrity_error
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function def _calculate_new_extremeties(self, room_id, events):
def _persist_event(self, event, context, current_state=None, backfilled=False, """Calculates the new forward extremeties for a room given events to
delete_existing=False): persist.
try:
with self._stream_id_gen.get_next() as stream_ordering: Assumes that we are only persisting events for one room at a time.
event.internal_metadata.stream_ordering = stream_ordering """
yield self.runInteraction( latest_event_ids = yield self.get_latest_event_ids_in_room(
"persist_event", room_id
self._persist_event_txn,
event=event,
context=context,
current_state=current_state,
backfilled=backfilled,
delete_existing=delete_existing,
) )
persist_event_counter.inc() new_latest_event_ids = set(latest_event_ids)
except _RollbackButIsFineException: # First, add all the new events to the list
pass new_latest_event_ids.update(
event.event_id for event in events
if not event.internal_metadata.is_outlier()
)
# Now remove all events that are referenced by the to-be-added events
new_latest_event_ids.difference_update(
e_id
for event in events
for e_id, _ in event.prev_events
if not event.internal_metadata.is_outlier()
)
# And finally remove any events that are referenced by previously added
# events.
rows = yield self._simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=list(new_latest_event_ids),
retcols=["prev_event_id"],
keyvalues={
"room_id": room_id,
"is_state": False,
},
desc="_calculate_new_extremeties",
)
new_latest_event_ids.difference_update(
row["prev_event_id"] for row in rows
)
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
"""Calculate the new state deltas for a room.
Assumes that we are only persisting events for one room at a time.
Returns:
2-tuple (to_delete, to_insert) where both are state dicts, i.e.
(type, state_key) -> event_id. `to_delete` are the entries to
first be deleted from current_state_events, `to_insert` are entries
to insert.
May return None if there are no changes to be applied.
"""
# Now we need to work out the different state sets for
# each state extremities
state_sets = []
missing_event_ids = []
was_updated = False
for event_id in new_latest_event_ids:
# First search in the list of new events we're adding,
# and then use the current state from that
for ev, ctx in events_context:
if event_id == ev.event_id:
if ctx.current_state_ids is None:
raise Exception("Unknown current state")
state_sets.append(ctx.current_state_ids)
if ctx.delta_ids or hasattr(ev, "state_key"):
was_updated = True
break
else:
# If we couldn't find it, then we'll need to pull
# the state from the database
was_updated = True
missing_event_ids.append(event_id)
if missing_event_ids:
# Now pull out the state for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups)
state_sets.extend(group_to_state.values())
if not new_latest_event_ids:
current_state = {}
elif was_updated:
current_state = yield resolve_events(
state_sets,
state_map_factory=lambda ev_ids: self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
),
)
else:
return
existing_state_rows = yield self._simple_select_list(
table="current_state_events",
keyvalues={"room_id": room_id},
retcols=["event_id", "type", "state_key"],
desc="_calculate_state_delta",
)
existing_events = set(row["event_id"] for row in existing_state_rows)
new_events = set(ev_id for ev_id in current_state.itervalues())
changed_events = existing_events ^ new_events
if not changed_events:
return
to_delete = {
(row["type"], row["state_key"]): row["event_id"]
for row in existing_state_rows
if row["event_id"] in changed_events
}
events_to_insert = (new_events - existing_events)
to_insert = {
key: ev_id for key, ev_id in current_state.iteritems()
if ev_id in events_to_insert
}
defer.returnValue((to_delete, to_insert))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True, def get_event(self, event_id, check_redacted=True,
@ -380,53 +512,10 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events}) defer.returnValue({e.event_id: e for e in events})
@log_function
def _persist_event_txn(self, txn, event, context, current_state, backfilled=False,
delete_existing=False):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
stream_order = event.internal_metadata.stream_ordering
self._simple_insert_txn(
txn,
table="current_state_resets",
values={"event_stream_ordering": stream_order}
)
self._simple_delete_txn(
txn,
table="current_state_events",
keyvalues={"room_id": event.room_id},
)
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
}
)
return self._persist_events_txn(
txn,
[(event, context)],
backfilled=backfilled,
delete_existing=delete_existing,
)
@log_function @log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled, def _persist_events_txn(self, txn, events_and_contexts, backfilled,
delete_existing=False): delete_existing=False, current_state_for_room={},
new_forward_extremeties={}):
"""Insert some number of room events into the necessary database tables. """Insert some number of room events into the necessary database tables.
Rejected events are only inserted into the events table, the events_json table, Rejected events are only inserted into the events table, the events_json table,
@ -436,6 +525,97 @@ class EventsStore(SQLBaseStore):
If delete_existing is True then existing events will be purged from the If delete_existing is True then existing events will be purged from the
database before insertion. This is useful when retrying due to IntegrityError. database before insertion. This is useful when retrying due to IntegrityError.
""" """
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
for room_id, current_state_tuple in current_state_for_room.iteritems():
to_delete, to_insert = current_state_tuple
txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?",
[(ev_id,) for ev_id in to_delete.itervalues()],
)
self._simple_insert_many_txn(
txn,
table="current_state_events",
values=[
{
"event_id": ev_id,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
}
for key, ev_id in to_insert.iteritems()
],
)
# Invalidate the various caches
# Figure out the changes of membership to invalidate the
# `get_rooms_for_user` cache.
# We find out which membership events we may have deleted
# and which we have added, then we invlidate the caches for all
# those users.
members_changed = set(
state_key for ev_type, state_key in to_delete.iterkeys()
if ev_type == EventTypes.Member
)
members_changed.update(
state_key for ev_type, state_key in to_insert.iterkeys()
if ev_type == EventTypes.Member
)
for member in members_changed:
txn.call_after(self.get_rooms_for_user.invalidate, (member,))
txn.call_after(self.get_users_in_room.invalidate, (room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
self._simple_insert_txn(
txn,
table="current_state_resets",
values={"event_stream_ordering": max_stream_order}
)
for room_id, new_extrem in new_forward_extremeties.items():
self._simple_delete_txn(
txn,
table="event_forward_extremities",
keyvalues={"room_id": room_id},
)
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
self._simple_insert_many_txn(
txn,
table="event_forward_extremities",
values=[
{
"event_id": ev_id,
"room_id": room_id,
}
for room_id, new_extrem in new_forward_extremeties.items()
for ev_id in new_extrem
],
)
# We now insert into stream_ordering_to_exterm a mapping from room_id,
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
self._simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
{
"room_id": room_id,
"event_id": event_id,
"stream_ordering": max_stream_order,
}
for room_id, new_extrem in new_forward_extremeties.items()
for event_id in new_extrem
]
)
# Ensure that we don't have the same event twice. # Ensure that we don't have the same event twice.
# Pick the earliest non-outlier if there is one, else the earliest one. # Pick the earliest non-outlier if there is one, else the earliest one.
new_events_and_contexts = OrderedDict() new_events_and_contexts = OrderedDict()
@ -550,7 +730,7 @@ class EventsStore(SQLBaseStore):
# Update the event_backward_extremities table now that this # Update the event_backward_extremities table now that this
# event isn't an outlier any more. # event isn't an outlier any more.
self._update_extremeties(txn, [event]) self._update_backward_extremeties(txn, [event])
events_and_contexts = [ events_and_contexts = [
ec for ec in events_and_contexts if ec[0] not in to_remove ec for ec in events_and_contexts if ec[0] not in to_remove
@ -804,29 +984,6 @@ class EventsStore(SQLBaseStore):
# to update the current state table # to update the current state table
return return
for event, _ in state_events_and_contexts:
if event.internal_metadata.is_outlier():
# Outlier events shouldn't clobber the current state.
continue
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
self._simple_upsert_txn(
txn,
"current_state_events",
keyvalues={
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
return return
def _add_to_cache(self, txn, events_and_contexts): def _add_to_cache(self, txn, events_and_contexts):

View file

@ -0,0 +1,59 @@
/* Copyright 2017 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Cache of remote devices.
CREATE TABLE device_lists_remote_cache (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
-- The last update we got for a user. Empty if we're not receiving updates for
-- that user.
CREATE TABLE device_lists_remote_extremeties (
user_id TEXT NOT NULL,
stream_id TEXT NOT NULL
);
CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
-- Stream of device lists updates. Includes both local and remotes
CREATE TABLE device_lists_stream (
stream_id BIGINT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL
);
CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id);
-- The stream of updates to send to other servers. We keep at least one row
-- per user that was sent so that the prev_id for any new updates can be
-- calculated
CREATE TABLE device_lists_outbound_pokes (
destination TEXT NOT NULL,
stream_id BIGINT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
sent BOOLEAN NOT NULL,
ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers
);
CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id);
CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id);

View file

@ -232,58 +232,6 @@ class StateStore(SQLBaseStore):
return count return count
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
if event_type and state_key is not None:
result = yield self.get_current_state_for_key(
room_id, event_type, state_key
)
defer.returnValue(result)
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
" WHERE room_id = ? "
)
if event_type and state_key is not None:
sql += " AND type = ? AND state_key = ? "
args = (room_id, event_type, state_key)
elif event_type:
sql += " AND type = ?"
args = (room_id, event_type)
else:
args = (room_id, )
txn.execute(sql, args)
results = txn.fetchall()
return [r[0] for r in results]
event_ids = yield self.runInteraction("get_current_state", f)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key):
event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
@cached(num_args=3)
def _get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?"
)
args = (room_id, event_type, state_key)
txn.execute(sql, args)
results = txn.fetchall()
return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, max_entries=100000, iterable=True) @cached(num_args=2, max_entries=100000, iterable=True)
def _get_state_group_from_group(self, group, types): def _get_state_group_from_group(self, group, types):
raise NotImplementedError() raise NotImplementedError()

View file

@ -44,6 +44,7 @@ class EventSources(object):
def get_current_token(self): def get_current_token(self):
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token() to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
@ -63,6 +64,7 @@ class EventSources(object):
), ),
push_rules_key=push_rules_key, push_rules_key=push_rules_key,
to_device_key=to_device_key, to_device_key=to_device_key,
device_list_key=device_list_key,
) )
defer.returnValue(token) defer.returnValue(token)
@ -70,6 +72,7 @@ class EventSources(object):
def get_current_token_for_room(self, room_id): def get_current_token_for_room(self, room_id):
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token() to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
@ -89,5 +92,6 @@ class EventSources(object):
), ),
push_rules_key=push_rules_key, push_rules_key=push_rules_key,
to_device_key=to_device_key, to_device_key=to_device_key,
device_list_key=device_list_key,
) )
defer.returnValue(token) defer.returnValue(token)

View file

@ -158,6 +158,7 @@ class StreamToken(
"account_data_key", "account_data_key",
"push_rules_key", "push_rules_key",
"to_device_key", "to_device_key",
"device_list_key",
)) ))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -195,6 +196,7 @@ class StreamToken(
or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.account_data_key) < int(self.account_data_key))
or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.push_rules_key) < int(self.push_rules_key))
or (int(other.to_device_key) < int(self.to_device_key)) or (int(other.to_device_key) < int(self.to_device_key))
or (int(other.device_list_key) < int(self.device_list_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):

View file

@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield utils.setup_test_homeserver(handlers=None) hs = yield utils.setup_test_homeserver()
self.handler = synapse.handlers.device.DeviceHandler(hs) self.handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self): def test_device_is_created_if_doesnt_exist(self):
res = yield self.handler.check_device_registered( res = yield self.handler.check_device_registered(
user_id="boris", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="display name" initial_device_display_name="display name"
) )
self.assertEqual(res, "fco") self.assertEqual(res, "fco")
dev = yield self.handler.store.get_device("boris", "fco") dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name") self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_device_is_preserved_if_exists(self): def test_device_is_preserved_if_exists(self):
res1 = yield self.handler.check_device_registered( res1 = yield self.handler.check_device_registered(
user_id="boris", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="display name" initial_device_display_name="display name"
) )
self.assertEqual(res1, "fco") self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered( res2 = yield self.handler.check_device_registered(
user_id="boris", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="new display name" initial_device_display_name="new display name"
) )
self.assertEqual(res2, "fco") self.assertEqual(res2, "fco")
dev = yield self.handler.store.get_device("boris", "fco") dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name") self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self): def test_device_id_is_made_up_if_unspecified(self):
device_id = yield self.handler.check_device_registered( device_id = yield self.handler.check_device_registered(
user_id="theresa", user_id="@theresa:foo",
device_id=None, device_id=None,
initial_device_display_name="display" initial_device_display_name="display"
) )
dev = yield self.handler.store.get_device("theresa", device_id) dev = yield self.handler.store.get_device("@theresa:foo", device_id)
self.assertEqual(dev["display_name"], "display") self.assertEqual(dev["display_name"], "display")
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock(spec=[
"make_query", "make_query",
"register_edu_handler",
]) ])
self.query_handlers = {} self.query_handlers = {}

View file

@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock(spec=[
"make_query", "make_query",
"register_edu_handler",
]) ])
self.query_handlers = {} self.query_handlers = {}

View file

@ -75,6 +75,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"get_received_txn_response", "get_received_txn_response",
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings", "get_destination_retry_timings",
"get_devices_by_remote",
]), ]),
state_handler=self.state_handler, state_handler=self.state_handler,
handlers=None, handlers=None,
@ -99,6 +100,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
defer.succeed(retry_timings_res) defer.succeed(retry_timings_res)
) )
self.datastore.get_devices_by_remote.return_value = (0, [])
def get_received_txn_response(*args): def get_received_txn_response(*args):
return defer.succeed(None) return defer.succeed(None)
self.datastore.get_received_txn_response = get_received_txn_response self.datastore.get_received_txn_response = get_received_txn_response

View file

@ -60,7 +60,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_room_members(self): def test_room_members(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate() yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), []) yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), []) yield self.check("get_users_in_room", (ROOM_ID,), [])
@ -95,15 +95,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)]) )])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2]) yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2])
# Join the room clobbering the state.
# This should remove any evidence of the other user being in the room.
yield self.persist( yield self.persist(
type="m.room.member", key=USER_ID, membership="join", type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
) )
yield self.replicate() yield self.replicate()
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID]) yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2, USER_ID])
yield self.check("get_rooms_for_user", (USER_ID_2,), [])
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_latest_event_ids_in_room(self): def test_get_latest_event_ids_in_room(self):
@ -122,51 +118,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
"get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id] "get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]
) )
@defer.inlineCallbacks
def test_get_current_state(self):
# Create the room.
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
)
# Join the room.
join1 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join1]
)
# Add some other user to the room.
join2 = yield self.persist(
type="m.room.member", key=USER_ID_2, membership="join",
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[join2]
)
# Leave the room, then rejoin the room clobbering state.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
join3 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[]
)
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join3]
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_redactions(self): def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID) yield self.persist(type="m.room.create", key="", creator=USER_ID)
@ -283,6 +234,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if depth is None: if depth is None:
depth = self.event_id depth = self.event_id
if not prev_events:
latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
room_id
)
prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
event_dict = { event_dict = {
"sender": sender, "sender": sender,
"type": type, "type": type,
@ -309,12 +266,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_ids = { state_ids = {
key: e.event_id for key, e in state.items() key: e.event_id for key, e in state.items()
} }
else:
state_ids = None
context = EventContext() context = EventContext()
context.current_state_ids = state_ids context.current_state_ids = state_ids
context.prev_state_ids = state_ids context.prev_state_ids = state_ids
elif not backfill:
state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event)
else:
context = EventContext()
context.push_actions = push_actions context.push_actions = push_actions
ordering = None ordering = None
@ -324,7 +284,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
) )
else: else:
ordering, _ = yield self.master_store.persist_event( ordering, _ = yield self.master_store.persist_event(
event, context, current_state=reset_state event, context,
) )
if ordering: if ordering:

View file

@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_topo_token_is_accepted(self): def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0" token = "t1-0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self): def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0" token = "s0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))

View file

@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
event_cache_size=1, event_cache_size=1,
password_providers=[], password_providers=[],
) )
hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
)
self.as_token = "token1" self.as_token = "token1"
self.as_url = "some_url" self.as_url = "some_url"
@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
event_cache_size=1, event_cache_size=1,
password_providers=[], password_providers=[],
) )
hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
)
self.db_pool = hs.get_db_pool() self.db_pool = hs.get_db_pool()
self.as_list = [ self.as_list = [
@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
datastore=Mock(), datastore=Mock(),
federation_sender=Mock() federation_sender=Mock(),
replication_layer=Mock(),
) )
ApplicationServiceStore(hs) ApplicationServiceStore(hs)
@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
datastore=Mock(), datastore=Mock(),
federation_sender=Mock() federation_sender=Mock(),
replication_layer=Mock(),
) )
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
datastore=Mock(), datastore=Mock(),
federation_sender=Mock() federation_sender=Mock(),
replication_layer=Mock(),
) )
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:

View file

@ -35,6 +35,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070 now = 1470174257070
json = '{ "key": "value" }' json = '{ "key": "value" }'
yield self.store.store_device(
"user", "device", None
)
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys(
"user", "device", now, json) "user", "device", now, json)
@ -71,6 +75,19 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
def test_multiple_devices(self): def test_multiple_devices(self):
now = 1470174257070 now = 1470174257070
yield self.store.store_device(
"user1", "device1", None
)
yield self.store.store_device(
"user1", "device2", None
)
yield self.store.store_device(
"user2", "device1", None
)
yield self.store.store_device(
"user2", "device2", None
)
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys(
"user1", "device1", now, 'json11') "user1", "device1", now, 'json11')
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys(