mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-12-26 10:59:22 -05:00
Merge pull request #6254 from matrix-org/uhoreg/cross_signing_fix_workers_notify
make notification of signatures work with workers
This commit is contained in:
commit
3b4216f961
1
changelog.d/6254.bugfix
Normal file
1
changelog.d/6254.bugfix
Normal file
@ -0,0 +1 @@
|
|||||||
|
Make notification of cross-signing signatures work with workers.
|
@ -15,6 +15,7 @@
|
|||||||
|
|
||||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||||
|
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
|
||||||
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
|
from synapse.storage.data_stores.main.devices import DeviceWorkerStore
|
||||||
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
|
from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
@ -42,14 +43,22 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
|||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
result = super(SlavedDeviceStore, self).stream_positions()
|
result = super(SlavedDeviceStore, self).stream_positions()
|
||||||
result["device_lists"] = self._device_list_id_gen.get_current_token()
|
# The user signature stream uses the same stream ID generator as the
|
||||||
|
# device list stream, so set them both to the device list ID
|
||||||
|
# generator's current token.
|
||||||
|
current_token = self._device_list_id_gen.get_current_token()
|
||||||
|
result[DeviceListsStream.NAME] = current_token
|
||||||
|
result[UserSignatureStream.NAME] = current_token
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, token, rows):
|
def process_replication_rows(self, stream_name, token, rows):
|
||||||
if stream_name == "device_lists":
|
if stream_name == DeviceListsStream.NAME:
|
||||||
self._device_list_id_gen.advance(token)
|
self._device_list_id_gen.advance(token)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
|
self._invalidate_caches_for_devices(token, row.user_id, row.destination)
|
||||||
|
elif stream_name == UserSignatureStream.NAME:
|
||||||
|
for row in rows:
|
||||||
|
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
|
||||||
return super(SlavedDeviceStore, self).process_replication_rows(
|
return super(SlavedDeviceStore, self).process_replication_rows(
|
||||||
stream_name, token, rows
|
stream_name, token, rows
|
||||||
)
|
)
|
||||||
|
@ -45,5 +45,6 @@ STREAMS_MAP = {
|
|||||||
_base.TagAccountDataStream,
|
_base.TagAccountDataStream,
|
||||||
_base.AccountDataStream,
|
_base.AccountDataStream,
|
||||||
_base.GroupServerStream,
|
_base.GroupServerStream,
|
||||||
|
_base.UserSignatureStream,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -95,6 +95,7 @@ GroupsStreamRow = namedtuple(
|
|||||||
"GroupsStreamRow",
|
"GroupsStreamRow",
|
||||||
("group_id", "user_id", "type", "content"), # str # str # str # dict
|
("group_id", "user_id", "type", "content"), # str # str # str # dict
|
||||||
)
|
)
|
||||||
|
UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
|
||||||
|
|
||||||
|
|
||||||
class Stream(object):
|
class Stream(object):
|
||||||
@ -438,3 +439,20 @@ class GroupServerStream(Stream):
|
|||||||
self.update_function = store.get_all_groups_changes
|
self.update_function = store.get_all_groups_changes
|
||||||
|
|
||||||
super(GroupServerStream, self).__init__(hs)
|
super(GroupServerStream, self).__init__(hs)
|
||||||
|
|
||||||
|
|
||||||
|
class UserSignatureStream(Stream):
|
||||||
|
"""A user has signed their own device with their user-signing key
|
||||||
|
"""
|
||||||
|
|
||||||
|
NAME = "user_signature"
|
||||||
|
_LIMITED = False
|
||||||
|
ROW_TYPE = UserSignatureStreamRow
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
store = hs.get_datastore()
|
||||||
|
|
||||||
|
self.current_token = store.get_device_stream_token
|
||||||
|
self.update_function = store.get_all_user_signature_changes_for_remotes
|
||||||
|
|
||||||
|
super(UserSignatureStream, self).__init__(hs)
|
||||||
|
@ -139,7 +139,10 @@ class DataStore(
|
|||||||
db_conn, "public_room_list_stream", "stream_id"
|
db_conn, "public_room_list_stream", "stream_id"
|
||||||
)
|
)
|
||||||
self._device_list_id_gen = StreamIdGenerator(
|
self._device_list_id_gen = StreamIdGenerator(
|
||||||
db_conn, "device_lists_stream", "stream_id"
|
db_conn,
|
||||||
|
"device_lists_stream",
|
||||||
|
"stream_id",
|
||||||
|
extra_tables=[("user_signature_stream", "stream_id")],
|
||||||
)
|
)
|
||||||
self._cross_signing_id_gen = StreamIdGenerator(
|
self._cross_signing_id_gen = StreamIdGenerator(
|
||||||
db_conn, "e2e_cross_signing_keys", "stream_id"
|
db_conn, "e2e_cross_signing_keys", "stream_id"
|
||||||
|
@ -315,6 +315,30 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||||||
from_user_id,
|
from_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
|
||||||
|
"""Return a list of changes from the user signature stream to notify remotes.
|
||||||
|
Note that the user signature stream represents when a user signs their
|
||||||
|
device with their user-signing key, which is not published to other
|
||||||
|
users or servers, so no `destination` is needed in the returned
|
||||||
|
list. However, this is needed to poke workers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
from_key (int): the stream ID to start at (exclusive)
|
||||||
|
to_key (int): the stream ID to end at (inclusive)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
|
||||||
|
"""
|
||||||
|
sql = """
|
||||||
|
SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
|
||||||
|
FROM user_signature_stream
|
||||||
|
WHERE ? < stream_id AND stream_id <= ?
|
||||||
|
GROUP BY user_id
|
||||||
|
"""
|
||||||
|
return self._execute(
|
||||||
|
"get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
||||||
|
Loading…
Reference in New Issue
Block a user