mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Move and rename get_devices_with_keys_by_user
(#8204)
* Move `get_devices_with_keys_by_user` to `EndToEndKeyWorkerStore` this seems a better fit for it. This commit simply moves the existing code: no other changes at all. * Rename `get_devices_with_keys_by_user` to better reflect what it does. * get_device_stream_token abstract method To avoid referencing fields which are declared in the derived classes, make `get_device_stream_token` abstract, and define that in the classes which define `_device_list_id_gen`.
This commit is contained in:
parent
45e8f7726f
commit
aa07c37cf0
1
changelog.d/8204.misc
Normal file
1
changelog.d/8204.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Refactor queries for device keys and cross-signatures.
|
@ -234,7 +234,9 @@ class DeviceWorkerHandler(BaseHandler):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
async def on_federation_query_user_devices(self, user_id):
|
async def on_federation_query_user_devices(self, user_id):
|
||||||
stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
|
stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
|
||||||
|
user_id
|
||||||
|
)
|
||||||
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
|
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
|
||||||
self_signing_key = await self.store.get_e2e_cross_signing_key(
|
self_signing_key = await self.store.get_e2e_cross_signing_key(
|
||||||
user_id, "self_signing"
|
user_id, "self_signing"
|
||||||
|
@ -48,6 +48,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
|||||||
"DeviceListFederationStreamChangeCache", device_list_max
|
"DeviceListFederationStreamChangeCache", device_list_max
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_device_stream_token(self) -> int:
|
||||||
|
return self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||||
if stream_name == DeviceListsStream.NAME:
|
if stream_name == DeviceListsStream.NAME:
|
||||||
self._device_list_id_gen.advance(instance_name, token)
|
self._device_list_id_gen.advance(instance_name, token)
|
||||||
|
@ -264,6 +264,9 @@ class DataStore(
|
|||||||
# Used in _generate_user_daily_visits to keep track of progress
|
# Used in _generate_user_daily_visits to keep track of progress
|
||||||
self._last_user_visit_update = self._get_start_of_day()
|
self._last_user_visit_update = self._get_start_of_day()
|
||||||
|
|
||||||
|
def get_device_stream_token(self) -> int:
|
||||||
|
return self._device_list_id_gen.get_current_token()
|
||||||
|
|
||||||
def take_presence_startup_info(self):
|
def take_presence_startup_info(self):
|
||||||
active_on_startup = self._presence_on_startup
|
active_on_startup = self._presence_on_startup
|
||||||
self._presence_on_startup = None
|
self._presence_on_startup = None
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import abc
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
|
|
||||||
@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
update included in the response), and the list of updates, where
|
update included in the response), and the list of updates, where
|
||||||
each update is a pair of EDU type and EDU contents.
|
each update is a pair of EDU type and EDU contents.
|
||||||
"""
|
"""
|
||||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
now_stream_id = self.get_device_stream_token()
|
||||||
|
|
||||||
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
|
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
|
||||||
destination, int(from_stream_id)
|
destination, int(from_stream_id)
|
||||||
@ -412,8 +413,10 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
def get_device_stream_token(self) -> int:
|
def get_device_stream_token(self) -> int:
|
||||||
return self._device_list_id_gen.get_current_token()
|
"""Get the current stream id from the _device_list_id_gen"""
|
||||||
|
...
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def get_user_devices_from_cache(
|
async def get_user_devices_from_cache(
|
||||||
@ -481,51 +484,6 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||||||
device["device_id"]: db_to_json(device["content"]) for device in devices
|
device["device_id"]: db_to_json(device["content"]) for device in devices
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_devices_with_keys_by_user(self, user_id: str):
|
|
||||||
"""Get all devices (with any device keys) for a user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred which resolves to (stream_id, devices)
|
|
||||||
"""
|
|
||||||
return self.db_pool.runInteraction(
|
|
||||||
"get_devices_with_keys_by_user",
|
|
||||||
self._get_devices_with_keys_by_user_txn,
|
|
||||||
user_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_devices_with_keys_by_user_txn(
|
|
||||||
self, txn: LoggingTransaction, user_id: str
|
|
||||||
) -> Tuple[int, List[JsonDict]]:
|
|
||||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
|
||||||
|
|
||||||
devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
|
|
||||||
|
|
||||||
if devices:
|
|
||||||
user_devices = devices[user_id]
|
|
||||||
results = []
|
|
||||||
for device_id, device in user_devices.items():
|
|
||||||
result = {"device_id": device_id}
|
|
||||||
|
|
||||||
key_json = device.get("key_json", None)
|
|
||||||
if key_json:
|
|
||||||
result["keys"] = db_to_json(key_json)
|
|
||||||
|
|
||||||
if "signatures" in device:
|
|
||||||
for sig_user_id, sigs in device["signatures"].items():
|
|
||||||
result["keys"].setdefault("signatures", {}).setdefault(
|
|
||||||
sig_user_id, {}
|
|
||||||
).update(sigs)
|
|
||||||
|
|
||||||
device_display_name = device.get("device_display_name", None)
|
|
||||||
if device_display_name:
|
|
||||||
result["device_display_name"] = device_display_name
|
|
||||||
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
return now_stream_id, results
|
|
||||||
|
|
||||||
return now_stream_id, []
|
|
||||||
|
|
||||||
async def get_users_whose_devices_changed(
|
async def get_users_whose_devices_changed(
|
||||||
self, from_key: str, user_ids: Iterable[str]
|
self, from_key: str, user_ids: Iterable[str]
|
||||||
) -> Set[str]:
|
) -> Set[str]:
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import abc
|
||||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
@ -22,7 +23,7 @@ from twisted.enterprise.adbapi import Connection
|
|||||||
|
|
||||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||||
from synapse.storage.database import make_in_list_sql_clause
|
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
@ -33,6 +34,51 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||||
|
def get_e2e_device_keys_for_federation_query(self, user_id: str):
|
||||||
|
"""Get all devices (with any device keys) for a user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred which resolves to (stream_id, devices)
|
||||||
|
"""
|
||||||
|
return self.db_pool.runInteraction(
|
||||||
|
"get_e2e_device_keys_for_federation_query",
|
||||||
|
self._get_e2e_device_keys_for_federation_query_txn,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_e2e_device_keys_for_federation_query_txn(
|
||||||
|
self, txn: LoggingTransaction, user_id: str
|
||||||
|
) -> Tuple[int, List[JsonDict]]:
|
||||||
|
now_stream_id = self.get_device_stream_token()
|
||||||
|
|
||||||
|
devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
|
||||||
|
|
||||||
|
if devices:
|
||||||
|
user_devices = devices[user_id]
|
||||||
|
results = []
|
||||||
|
for device_id, device in user_devices.items():
|
||||||
|
result = {"device_id": device_id}
|
||||||
|
|
||||||
|
key_json = device.get("key_json", None)
|
||||||
|
if key_json:
|
||||||
|
result["keys"] = db_to_json(key_json)
|
||||||
|
|
||||||
|
if "signatures" in device:
|
||||||
|
for sig_user_id, sigs in device["signatures"].items():
|
||||||
|
result["keys"].setdefault("signatures", {}).setdefault(
|
||||||
|
sig_user_id, {}
|
||||||
|
).update(sigs)
|
||||||
|
|
||||||
|
device_display_name = device.get("device_display_name", None)
|
||||||
|
if device_display_name:
|
||||||
|
result["device_display_name"] = device_display_name
|
||||||
|
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return now_stream_id, results
|
||||||
|
|
||||||
|
return now_stream_id, []
|
||||||
|
|
||||||
@trace
|
@trace
|
||||||
async def get_e2e_device_keys_for_cs_api(
|
async def get_e2e_device_keys_for_cs_api(
|
||||||
self, query_list: List[Tuple[str, Optional[str]]]
|
self, query_list: List[Tuple[str, Optional[str]]]
|
||||||
@ -533,6 +579,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
|||||||
_get_all_user_signature_changes_for_remotes_txn,
|
_get_all_user_signature_changes_for_remotes_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def get_device_stream_token(self) -> int:
|
||||||
|
"""Get the current stream id from the _device_list_id_gen"""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||||
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
||||||
|
Loading…
Reference in New Issue
Block a user