Move and rename get_devices_with_keys_by_user (#8204)

* Move `get_devices_with_keys_by_user` to `EndToEndKeyWorkerStore`

this seems a better fit for it.

This commit simply moves the existing code: no other changes at all.

* Rename `get_devices_with_keys_by_user`

to better reflect what it does.

* get_device_stream_token abstract method

To avoid referencing fields which are declared in the derived classes, make
`get_device_stream_token` abstract, and define that in the classes which define
`_device_list_id_gen`.
This commit is contained in:
Richard van der Hoff 2020-09-01 12:41:21 +01:00 committed by GitHub
parent 45e8f7726f
commit aa07c37cf0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 67 additions and 49 deletions

View file

@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
@ -22,7 +23,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
@ -33,6 +34,51 @@ if TYPE_CHECKING:
class EndToEndKeyWorkerStore(SQLBaseStore):
def get_e2e_device_keys_for_federation_query(self, user_id: str):
"""Get all devices (with any device keys) for a user
Returns:
Deferred which resolves to (stream_id, devices)
"""
return self.db_pool.runInteraction(
"get_e2e_device_keys_for_federation_query",
self._get_e2e_device_keys_for_federation_query_txn,
user_id,
)
def _get_e2e_device_keys_for_federation_query_txn(
self, txn: LoggingTransaction, user_id: str
) -> Tuple[int, List[JsonDict]]:
now_stream_id = self.get_device_stream_token()
devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
result = {"device_id": device_id}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
@trace
async def get_e2e_device_keys_for_cs_api(
self, query_list: List[Tuple[str, Optional[str]]]
@ -533,6 +579,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
_get_all_user_signature_changes_for_remotes_txn,
)
@abc.abstractmethod
def get_device_stream_token(self) -> int:
"""Get the current stream id from the _device_list_id_gen"""
...
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):