mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-12 02:44:20 -05:00
Do not assume calls to runInteraction return Deferreds. (#8133)
This commit is contained in:
parent
12aebdfa5a
commit
76c43f086a
1
changelog.d/8133.misc
Normal file
1
changelog.d/8133.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert various parts of the codebase to async/await.
|
@ -757,9 +757,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error getting keys %s from %s", key_ids, server_name)
|
logger.exception("Error getting keys %s from %s", key_ids, server_name)
|
||||||
|
|
||||||
return await yieldable_gather_results(
|
await yieldable_gather_results(get_key, keys_to_fetch.items())
|
||||||
get_key, keys_to_fetch.items()
|
return results
|
||||||
).addCallback(lambda _: results)
|
|
||||||
|
|
||||||
async def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
async def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
||||||
"""
|
"""
|
||||||
@ -769,7 +768,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
|
|||||||
key_ids (iterable[str]):
|
key_ids (iterable[str]):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
|
dict[str, FetchKeyResult]: map from key ID to lookup result
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
KeyLookupError if there was a problem making the lookup
|
KeyLookupError if there was a problem making the lookup
|
||||||
|
@ -167,8 +167,10 @@ class ModuleApi(object):
|
|||||||
external_id: id on that system
|
external_id: id on that system
|
||||||
user_id: complete mxid that it is mapped to
|
user_id: complete mxid that it is mapped to
|
||||||
"""
|
"""
|
||||||
return self._store.record_user_external_id(
|
return defer.ensureDeferred(
|
||||||
auth_provider_id, remote_user_id, registered_user_id
|
self._store.record_user_external_id(
|
||||||
|
auth_provider_id, remote_user_id, registered_user_id
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_short_term_login_token(
|
def generate_short_term_login_token(
|
||||||
@ -223,7 +225,9 @@ class ModuleApi(object):
|
|||||||
Returns:
|
Returns:
|
||||||
Deferred[object]: result of func
|
Deferred[object]: result of func
|
||||||
"""
|
"""
|
||||||
return self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
|
return defer.ensureDeferred(
|
||||||
|
self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
|
||||||
|
)
|
||||||
|
|
||||||
def complete_sso_login(
|
def complete_sso_login(
|
||||||
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
|
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
|
||||||
|
@ -48,8 +48,10 @@ class SpamCheckerApi(object):
|
|||||||
twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
|
twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
|
||||||
The filtered state events in the room.
|
The filtered state events in the room.
|
||||||
"""
|
"""
|
||||||
state_ids = yield self._store.get_filtered_current_state_ids(
|
state_ids = yield defer.ensureDeferred(
|
||||||
room_id=room_id, state_filter=StateFilter.from_types(types)
|
self._store.get_filtered_current_state_ids(
|
||||||
|
room_id=room_id, state_filter=StateFilter.from_types(types)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
|
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
|
||||||
return state.values()
|
return state.values()
|
||||||
|
@ -341,14 +341,15 @@ class GroupServerWorkerStore(SQLBaseStore):
|
|||||||
"get_users_for_summary_by_role", _get_users_for_summary_txn
|
"get_users_for_summary_by_role", _get_users_for_summary_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_user_in_group(self, user_id, group_id):
|
async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
|
||||||
return self.db_pool.simple_select_one_onecol(
|
result = await self.db_pool.simple_select_one_onecol(
|
||||||
table="group_users",
|
table="group_users",
|
||||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||||
retcol="user_id",
|
retcol="user_id",
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
desc="is_user_in_group",
|
desc="is_user_in_group",
|
||||||
).addCallback(lambda r: bool(r))
|
)
|
||||||
|
return bool(result)
|
||||||
|
|
||||||
def is_user_admin_in_group(self, group_id, user_id):
|
def is_user_admin_in_group(self, group_id, user_id):
|
||||||
return self.db_pool.simple_select_one_onecol(
|
return self.db_pool.simple_select_one_onecol(
|
||||||
|
@ -16,6 +16,7 @@
|
|||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Iterable, Tuple
|
||||||
|
|
||||||
from signedjson.key import decode_verify_key_bytes
|
from signedjson.key import decode_verify_key_bytes
|
||||||
|
|
||||||
@ -88,12 +89,17 @@ class KeyStore(SQLBaseStore):
|
|||||||
|
|
||||||
return self.db_pool.runInteraction("get_server_verify_keys", _txn)
|
return self.db_pool.runInteraction("get_server_verify_keys", _txn)
|
||||||
|
|
||||||
def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
|
async def store_server_verify_keys(
|
||||||
|
self,
|
||||||
|
from_server: str,
|
||||||
|
ts_added_ms: int,
|
||||||
|
verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
|
||||||
|
) -> None:
|
||||||
"""Stores NACL verification keys for remote servers.
|
"""Stores NACL verification keys for remote servers.
|
||||||
Args:
|
Args:
|
||||||
from_server (str): Where the verification keys were looked up
|
from_server: Where the verification keys were looked up
|
||||||
ts_added_ms (int): The time to record that the key was added
|
ts_added_ms: The time to record that the key was added
|
||||||
verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
|
verify_keys:
|
||||||
keys to be stored. Each entry is a triplet of
|
keys to be stored. Each entry is a triplet of
|
||||||
(server_name, key_id, key).
|
(server_name, key_id, key).
|
||||||
"""
|
"""
|
||||||
@ -115,13 +121,7 @@ class KeyStore(SQLBaseStore):
|
|||||||
# param, which is itself the 2-tuple (server_name, key_id).
|
# param, which is itself the 2-tuple (server_name, key_id).
|
||||||
invalidations.append((server_name, key_id))
|
invalidations.append((server_name, key_id))
|
||||||
|
|
||||||
def _invalidate(res):
|
await self.db_pool.runInteraction(
|
||||||
f = self._get_server_verify_key.invalidate
|
|
||||||
for i in invalidations:
|
|
||||||
f((i,))
|
|
||||||
return res
|
|
||||||
|
|
||||||
return self.db_pool.runInteraction(
|
|
||||||
"store_server_verify_keys",
|
"store_server_verify_keys",
|
||||||
self.db_pool.simple_upsert_many_txn,
|
self.db_pool.simple_upsert_many_txn,
|
||||||
table="server_signature_keys",
|
table="server_signature_keys",
|
||||||
@ -134,7 +134,11 @@ class KeyStore(SQLBaseStore):
|
|||||||
"verify_key",
|
"verify_key",
|
||||||
),
|
),
|
||||||
value_values=value_values,
|
value_values=value_values,
|
||||||
).addCallback(_invalidate)
|
)
|
||||||
|
|
||||||
|
invalidate = self._get_server_verify_key.invalidate
|
||||||
|
for i in invalidations:
|
||||||
|
invalidate((i,))
|
||||||
|
|
||||||
def store_server_keys_json(
|
def store_server_keys_json(
|
||||||
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
|
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
|
||||||
|
@ -13,30 +13,29 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import operator
|
|
||||||
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
|
|
||||||
class UserErasureWorkerStore(SQLBaseStore):
|
class UserErasureWorkerStore(SQLBaseStore):
|
||||||
@cached()
|
@cached()
|
||||||
def is_user_erased(self, user_id):
|
async def is_user_erased(self, user_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the given user id has requested erasure
|
Check if the given user id has requested erasure
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str): full user id to check
|
user_id: full user id to check
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[bool]: True if the user has requested erasure
|
True if the user has requested erasure
|
||||||
"""
|
"""
|
||||||
return self.db_pool.simple_select_onecol(
|
result = await self.db_pool.simple_select_onecol(
|
||||||
table="erased_users",
|
table="erased_users",
|
||||||
keyvalues={"user_id": user_id},
|
keyvalues={"user_id": user_id},
|
||||||
retcol="1",
|
retcol="1",
|
||||||
desc="is_user_erased",
|
desc="is_user_erased",
|
||||||
).addCallback(operator.truth)
|
)
|
||||||
|
return bool(result)
|
||||||
|
|
||||||
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
|
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
|
||||||
async def are_users_erased(self, user_ids):
|
async def are_users_erased(self, user_ids):
|
||||||
|
Loading…
Reference in New Issue
Block a user