Do not assume calls to runInteraction return Deferreds. (#8133)

This commit is contained in:
Patrick Cloke 2020-08-20 06:39:55 -04:00 committed by GitHub
parent 12aebdfa5a
commit 76c43f086a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 41 additions and 31 deletions

1
changelog.d/8133.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -757,9 +757,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception: except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name) logger.exception("Error getting keys %s from %s", key_ids, server_name)
return await yieldable_gather_results( await yieldable_gather_results(get_key, keys_to_fetch.items())
get_key, keys_to_fetch.items() return results
).addCallback(lambda _: results)
async def get_server_verify_key_v2_direct(self, server_name, key_ids): async def get_server_verify_key_v2_direct(self, server_name, key_ids):
""" """
@ -769,7 +768,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
key_ids (iterable[str]): key_ids (iterable[str]):
Returns: Returns:
Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result dict[str, FetchKeyResult]: map from key ID to lookup result
Raises: Raises:
KeyLookupError if there was a problem making the lookup KeyLookupError if there was a problem making the lookup

View File

@ -167,8 +167,10 @@ class ModuleApi(object):
external_id: id on that system external_id: id on that system
user_id: complete mxid that it is mapped to user_id: complete mxid that it is mapped to
""" """
return self._store.record_user_external_id( return defer.ensureDeferred(
auth_provider_id, remote_user_id, registered_user_id self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
) )
def generate_short_term_login_token( def generate_short_term_login_token(
@ -223,7 +225,9 @@ class ModuleApi(object):
Returns: Returns:
Deferred[object]: result of func Deferred[object]: result of func
""" """
return self._store.db_pool.runInteraction(desc, func, *args, **kwargs) return defer.ensureDeferred(
self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
)
def complete_sso_login( def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str

View File

@ -48,8 +48,10 @@ class SpamCheckerApi(object):
twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]: twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
The filtered state events in the room. The filtered state events in the room.
""" """
state_ids = yield self._store.get_filtered_current_state_ids( state_ids = yield defer.ensureDeferred(
room_id=room_id, state_filter=StateFilter.from_types(types) self._store.get_filtered_current_state_ids(
room_id=room_id, state_filter=StateFilter.from_types(types)
)
) )
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values())) state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
return state.values() return state.values()

View File

@ -341,14 +341,15 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_users_for_summary_by_role", _get_users_for_summary_txn "get_users_for_summary_by_role", _get_users_for_summary_txn
) )
def is_user_in_group(self, user_id, group_id): async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
return self.db_pool.simple_select_one_onecol( result = await self.db_pool.simple_select_one_onecol(
table="group_users", table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id", retcol="user_id",
allow_none=True, allow_none=True,
desc="is_user_in_group", desc="is_user_in_group",
).addCallback(lambda r: bool(r)) )
return bool(result)
def is_user_admin_in_group(self, group_id, user_id): def is_user_admin_in_group(self, group_id, user_id):
return self.db_pool.simple_select_one_onecol( return self.db_pool.simple_select_one_onecol(

View File

@ -16,6 +16,7 @@
import itertools import itertools
import logging import logging
from typing import Iterable, Tuple
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -88,12 +89,17 @@ class KeyStore(SQLBaseStore):
return self.db_pool.runInteraction("get_server_verify_keys", _txn) return self.db_pool.runInteraction("get_server_verify_keys", _txn)
def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys): async def store_server_verify_keys(
self,
from_server: str,
ts_added_ms: int,
verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
) -> None:
"""Stores NACL verification keys for remote servers. """Stores NACL verification keys for remote servers.
Args: Args:
from_server (str): Where the verification keys were looked up from_server: Where the verification keys were looked up
ts_added_ms (int): The time to record that the key was added ts_added_ms: The time to record that the key was added
verify_keys (iterable[tuple[str, str, FetchKeyResult]]): verify_keys:
keys to be stored. Each entry is a triplet of keys to be stored. Each entry is a triplet of
(server_name, key_id, key). (server_name, key_id, key).
""" """
@ -115,13 +121,7 @@ class KeyStore(SQLBaseStore):
# param, which is itself the 2-tuple (server_name, key_id). # param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id)) invalidations.append((server_name, key_id))
def _invalidate(res): await self.db_pool.runInteraction(
f = self._get_server_verify_key.invalidate
for i in invalidations:
f((i,))
return res
return self.db_pool.runInteraction(
"store_server_verify_keys", "store_server_verify_keys",
self.db_pool.simple_upsert_many_txn, self.db_pool.simple_upsert_many_txn,
table="server_signature_keys", table="server_signature_keys",
@ -134,7 +134,11 @@ class KeyStore(SQLBaseStore):
"verify_key", "verify_key",
), ),
value_values=value_values, value_values=value_values,
).addCallback(_invalidate) )
invalidate = self._get_server_verify_key.invalidate
for i in invalidations:
invalidate((i,))
def store_server_keys_json( def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes

View File

@ -13,30 +13,29 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import operator
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
class UserErasureWorkerStore(SQLBaseStore): class UserErasureWorkerStore(SQLBaseStore):
@cached() @cached()
def is_user_erased(self, user_id): async def is_user_erased(self, user_id: str) -> bool:
""" """
Check if the given user id has requested erasure Check if the given user id has requested erasure
Args: Args:
user_id (str): full user id to check user_id: full user id to check
Returns: Returns:
Deferred[bool]: True if the user has requested erasure True if the user has requested erasure
""" """
return self.db_pool.simple_select_onecol( result = await self.db_pool.simple_select_onecol(
table="erased_users", table="erased_users",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="1", retcol="1",
desc="is_user_erased", desc="is_user_erased",
).addCallback(operator.truth) )
return bool(result)
@cachedList(cached_method_name="is_user_erased", list_name="user_ids") @cachedList(cached_method_name="is_user_erased", list_name="user_ids")
async def are_users_erased(self, user_ids): async def are_users_erased(self, user_ids):