make user signatures a separate stream

This commit is contained in:
Hubert Chathi 2019-10-30 17:22:52 -04:00
parent 670972c0e1
commit 998f7fe7d4
5 changed files with 50 additions and 14 deletions

View File

@ -42,7 +42,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def stream_positions(self): def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions() result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token() result["user_signature"] = result[
"device_lists"
] = self._device_list_id_gen.get_current_token()
return result return result
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
@ -50,13 +52,15 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self._device_list_id_gen.advance(token) self._device_list_id_gen.advance(token)
for row in rows: for row in rows:
self._invalidate_caches_for_devices(token, row.user_id, row.destination) self._invalidate_caches_for_devices(token, row.user_id, row.destination)
elif stream_name == "user_signature":
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows( return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows stream_name, token, rows
) )
def _invalidate_caches_for_devices(self, token, user_id, destination): def _invalidate_caches_for_devices(self, token, user_id, destination):
self._device_list_stream_cache.entity_has_changed(user_id, token) self._device_list_stream_cache.entity_has_changed(user_id, token)
self._user_signature_stream_cache.entity_has_changed(user_id, token)
if destination: if destination:
self._device_list_federation_stream_cache.entity_has_changed( self._device_list_federation_stream_cache.entity_has_changed(

View File

@ -45,5 +45,6 @@ STREAMS_MAP = {
_base.TagAccountDataStream, _base.TagAccountDataStream,
_base.AccountDataStream, _base.AccountDataStream,
_base.GroupServerStream, _base.GroupServerStream,
_base.UserSignatureStream,
) )
} }

View File

@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple(
"GroupsStreamRow", "GroupsStreamRow",
("group_id", "user_id", "type", "content"), # str # str # str # dict ("group_id", "user_id", "type", "content"), # str # str # str # dict
) )
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
class Stream(object): class Stream(object):
@ -438,3 +439,20 @@ class GroupServerStream(Stream):
self.update_function = store.get_all_groups_changes self.update_function = store.get_all_groups_changes
super(GroupServerStream, self).__init__(hs) super(GroupServerStream, self).__init__(hs)
class UserSignatureStream(Stream):
"""A user has signed their own device with their user-signing key
"""
NAME = "user_signature"
_LIMITED = False
ROW_TYPE = UserSignatureStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_device_stream_token
self.update_function = store.get_all_user_signature_changes_for_remotes
super(UserSignatureStream, self).__init__(hs)

View File

@ -543,20 +543,9 @@ class DeviceWorkerStore(SQLBaseStore):
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE ? < stream_id AND stream_id <= ? WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination GROUP BY user_id, destination
UNION
SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id, NULL AS destination
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id
""" """
return self._execute( return self._execute(
"get_all_device_list_changes_for_remotes", "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
None,
sql,
from_key,
to_key,
from_key,
to_key,
) )
@cached(max_entries=10000) @cached(max_entries=10000)

View File

@ -315,6 +315,30 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
from_user_id, from_user_id,
) )
def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
users or servers, so no `destination` is needed in the returned
list. However, this is needed to poke workers.
Args:
from_key (int): the stream ID to start at (exclusive)
to_key (int): the stream ID to end at (inclusive)
Returns:
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
"""
sql = """
SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id
"""
return self._execute(
"get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
)
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):