Convert simple_update* and simple_select* to async (#8173)

This commit is contained in:
Patrick Cloke 2020-08-27 07:08:38 -04:00 committed by GitHub
parent a466b67972
commit 4a739c73b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 164 additions and 133 deletions

1
changelog.d/8173.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -51,7 +51,7 @@ from synapse.types import (
create_requester, create_requester,
) )
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer, maybe_awaitable from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -1329,9 +1329,7 @@ class RoomShutdownHandler(object):
ratelimit=False, ratelimit=False,
) )
aliases_for_room = await maybe_awaitable( aliases_for_room = await self.store.get_aliases_for_room(room_id)
self.store.get_aliases_for_room(room_id)
)
await self.store.update_aliases_for_room( await self.store.update_aliases_for_room(
room_id, new_room_id, requester_user_id room_id, new_room_id, requester_user_id

View File

@ -1132,13 +1132,13 @@ class DatabasePool(object):
return [r[0] for r in txn] return [r[0] for r in txn]
def simple_select_onecol( async def simple_select_onecol(
self, self,
table: str, table: str,
keyvalues: Optional[Dict[str, Any]], keyvalues: Optional[Dict[str, Any]],
retcol: str, retcol: str,
desc: str = "simple_select_onecol", desc: str = "simple_select_onecol",
) -> defer.Deferred: ) -> List[Any]:
"""Executes a SELECT query on the named table, which returns a list """Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows. comprising of the values of the named column from the selected rows.
@ -1148,19 +1148,19 @@ class DatabasePool(object):
retcol: column whos value we wish to retrieve. retcol: column whos value we wish to retrieve.
Returns: Returns:
Deferred: Results in a list Results in a list
""" """
return self.runInteraction( return await self.runInteraction(
desc, self.simple_select_onecol_txn, table, keyvalues, retcol desc, self.simple_select_onecol_txn, table, keyvalues, retcol
) )
def simple_select_list( async def simple_select_list(
self, self,
table: str, table: str,
keyvalues: Optional[Dict[str, Any]], keyvalues: Optional[Dict[str, Any]],
retcols: Iterable[str], retcols: Iterable[str],
desc: str = "simple_select_list", desc: str = "simple_select_list",
) -> defer.Deferred: ) -> List[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which may return zero or """Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts. more rows, returning the result as a list of dicts.
@ -1170,10 +1170,11 @@ class DatabasePool(object):
column names and values to select the rows with, or None to not column names and values to select the rows with, or None to not
apply a WHERE clause. apply a WHERE clause.
retcols: the names of the columns to return retcols: the names of the columns to return
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] A list of dictionaries.
""" """
return self.runInteraction( return await self.runInteraction(
desc, self.simple_select_list_txn, table, keyvalues, retcols desc, self.simple_select_list_txn, table, keyvalues, retcols
) )
@ -1299,14 +1300,14 @@ class DatabasePool(object):
txn.execute(sql, values) txn.execute(sql, values)
return cls.cursor_to_dict(txn) return cls.cursor_to_dict(txn)
def simple_update( async def simple_update(
self, self,
table: str, table: str,
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any], updatevalues: Dict[str, Any],
desc: str, desc: str,
) -> defer.Deferred: ) -> int:
return self.runInteraction( return await self.runInteraction(
desc, self.simple_update_txn, table, keyvalues, updatevalues desc, self.simple_update_txn, table, keyvalues, updatevalues
) )
@ -1332,13 +1333,13 @@ class DatabasePool(object):
return txn.rowcount return txn.rowcount
def simple_update_one( async def simple_update_one(
self, self,
table: str, table: str,
keyvalues: Dict[str, Any], keyvalues: Dict[str, Any],
updatevalues: Dict[str, Any], updatevalues: Dict[str, Any],
desc: str = "simple_update_one", desc: str = "simple_update_one",
) -> defer.Deferred: ) -> None:
"""Executes an UPDATE query on the named table, setting new values for """Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values. columns in a row matching the key values.
@ -1347,7 +1348,7 @@ class DatabasePool(object):
keyvalues: dict of column names and values to select the row with keyvalues: dict of column names and values to select the row with
updatevalues: dict giving column names and values to update updatevalues: dict giving column names and values to update
""" """
return self.runInteraction( await self.runInteraction(
desc, self.simple_update_one_txn, table, keyvalues, updatevalues desc, self.simple_update_one_txn, table, keyvalues, updatevalues
) )

View File

@ -18,6 +18,7 @@
import calendar import calendar
import logging import logging
import time import time
from typing import Any, Dict, List
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
@ -476,14 +477,13 @@ class DataStore(
"generate_user_daily_visits", _generate_user_daily_visits "generate_user_daily_visits", _generate_user_daily_visits
) )
def get_users(self): async def get_users(self) -> List[Dict[str, Any]]:
"""Function to retrieve a list of users in users table. """Function to retrieve a list of users in users table.
Args:
Returns: Returns:
defer.Deferred: resolves to list[dict[str, Any]] A list of dictionaries representing users.
""" """
return self.db_pool.simple_select_list( return await self.db_pool.simple_select_list(
table="users", table="users",
keyvalues={}, keyvalues={},
retcols=[ retcols=[

View File

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from collections import namedtuple from collections import namedtuple
from typing import Iterable, Optional from typing import Iterable, List, Optional
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
@ -68,8 +68,8 @@ class DirectoryWorkerStore(SQLBaseStore):
) )
@cached(max_entries=5000) @cached(max_entries=5000)
def get_aliases_for_room(self, room_id): async def get_aliases_for_room(self, room_id: str) -> List[str]:
return self.db_pool.simple_select_onecol( return await self.db_pool.simple_select_onecol(
"room_aliases", "room_aliases",
{"room_id": room_id}, {"room_id": room_id},
"room_alias", "room_alias",

View File

@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
@ -368,18 +370,22 @@ class EndToEndRoomKeyStore(SQLBaseStore):
) )
@trace @trace
def update_e2e_room_keys_version( async def update_e2e_room_keys_version(
self, user_id, version, info=None, version_etag=None self,
): user_id: str,
version: str,
info: Optional[dict] = None,
version_etag: Optional[int] = None,
) -> None:
"""Update a given backup version """Update a given backup version
Args: Args:
user_id(str): the user whose backup version we're updating user_id: the user whose backup version we're updating
version(str): the version ID of the backup version we're updating version: the version ID of the backup version we're updating
info (dict): the new backup version info to store. If None, then info: the new backup version info to store. If None, then the backup
the backup version info is not updated version info is not updated.
version_etag (Optional[int]): etag of the keys in the backup. If version_etag: etag of the keys in the backup. If None, then the etag
None, then the etag is not updated is not updated.
""" """
updatevalues = {} updatevalues = {}
@ -389,7 +395,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
updatevalues["etag"] = version_etag updatevalues["etag"] = version_etag
if updatevalues: if updatevalues:
return self.db_pool.simple_update( await self.db_pool.simple_update(
table="e2e_room_keys_versions", table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": version}, keyvalues={"user_id": user_id, "version": version},
updatevalues=updatevalues, updatevalues=updatevalues,

View File

@ -368,8 +368,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
) )
@cached(max_entries=5000, iterable=True) @cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id): async def get_latest_event_ids_in_room(self, room_id: str) -> List[str]:
return self.db_pool.simple_select_onecol( return await self.db_pool.simple_select_onecol(
table="event_forward_extremities", table="event_forward_extremities",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcol="event_id", retcol="event_id",

View File

@ -44,24 +44,26 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_group", desc="get_group",
) )
def get_users_in_group(self, group_id, include_private=False): async def get_users_in_group(
self, group_id: str, include_private: bool = False
) -> List[Dict[str, Any]]:
# TODO: Pagination # TODO: Pagination
keyvalues = {"group_id": group_id} keyvalues = {"group_id": group_id}
if not include_private: if not include_private:
keyvalues["is_public"] = True keyvalues["is_public"] = True
return self.db_pool.simple_select_list( return await self.db_pool.simple_select_list(
table="group_users", table="group_users",
keyvalues=keyvalues, keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"), retcols=("user_id", "is_public", "is_admin"),
desc="get_users_in_group", desc="get_users_in_group",
) )
def get_invited_users_in_group(self, group_id): async def get_invited_users_in_group(self, group_id: str) -> List[str]:
# TODO: Pagination # TODO: Pagination
return self.db_pool.simple_select_onecol( return await self.db_pool.simple_select_onecol(
table="group_invites", table="group_invites",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
retcol="user_id", retcol="user_id",
@ -265,15 +267,14 @@ class GroupServerWorkerStore(SQLBaseStore):
return role return role
def get_local_groups_for_room(self, room_id): async def get_local_groups_for_room(self, room_id: str) -> List[str]:
"""Get all of the local group that contain a given room """Get all of the local group that contain a given room
Args: Args:
room_id (str): The ID of a room room_id: The ID of a room
Returns: Returns:
Deferred[list[str]]: A twisted.Deferred containing a list of group ids A list of group ids containing this room
containing this room
""" """
return self.db_pool.simple_select_onecol( return await self.db_pool.simple_select_onecol(
table="group_rooms", table="group_rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
retcol="group_id", retcol="group_id",
@ -422,10 +423,10 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_users_membership_info_in_group", _get_users_membership_in_group_txn "get_users_membership_info_in_group", _get_users_membership_in_group_txn
) )
def get_publicised_groups_for_user(self, user_id): async def get_publicised_groups_for_user(self, user_id: str) -> List[str]:
"""Get all groups a user is publicising """Get all groups a user is publicising
""" """
return self.db_pool.simple_select_onecol( return await self.db_pool.simple_select_onecol(
table="local_group_membership", table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True}, keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
retcol="group_id", retcol="group_id",
@ -466,8 +467,8 @@ class GroupServerWorkerStore(SQLBaseStore):
return None return None
def get_joined_groups(self, user_id): async def get_joined_groups(self, user_id: str) -> List[str]:
return self.db_pool.simple_select_onecol( return await self.db_pool.simple_select_onecol(
table="local_group_membership", table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join"}, keyvalues={"user_id": user_id, "membership": "join"},
retcol="group_id", retcol="group_id",
@ -585,14 +586,14 @@ class GroupServerWorkerStore(SQLBaseStore):
class GroupServerStore(GroupServerWorkerStore): class GroupServerStore(GroupServerWorkerStore):
def set_group_join_policy(self, group_id, join_policy): async def set_group_join_policy(self, group_id: str, join_policy: str) -> None:
"""Set the join policy of a group. """Set the join policy of a group.
join_policy can be one of: join_policy can be one of:
* "invite" * "invite"
* "open" * "open"
""" """
return self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="groups", table="groups",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
updatevalues={"join_policy": join_policy}, updatevalues={"join_policy": join_policy},
@ -1050,8 +1051,10 @@ class GroupServerStore(GroupServerWorkerStore):
desc="add_room_to_group", desc="add_room_to_group",
) )
def update_room_in_group_visibility(self, group_id, room_id, is_public): async def update_room_in_group_visibility(
return self.db_pool.simple_update( self, group_id: str, room_id: str, is_public: bool
) -> int:
return await self.db_pool.simple_update(
table="group_rooms", table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id}, keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public}, updatevalues={"is_public": is_public},
@ -1076,10 +1079,12 @@ class GroupServerStore(GroupServerWorkerStore):
"remove_room_from_group", _remove_room_from_group_txn "remove_room_from_group", _remove_room_from_group_txn
) )
def update_group_publicity(self, group_id, user_id, publicise): async def update_group_publicity(
self, group_id: str, user_id: str, publicise: bool
) -> None:
"""Update whether the user is publicising their membership of the group """Update whether the user is publicising their membership of the group
""" """
return self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="local_group_membership", table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise}, updatevalues={"is_publicised": publicise},
@ -1218,20 +1223,24 @@ class GroupServerStore(GroupServerWorkerStore):
desc="update_group_profile", desc="update_group_profile",
) )
def update_attestation_renewal(self, group_id, user_id, attestation): async def update_attestation_renewal(
self, group_id: str, user_id: str, attestation: dict
) -> None:
"""Update an attestation that we have renewed """Update an attestation that we have renewed
""" """
return self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="group_attestations_renewals", table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]}, updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
desc="update_attestation_renewal", desc="update_attestation_renewal",
) )
def update_remote_attestion(self, group_id, user_id, attestation): async def update_remote_attestion(
self, group_id: str, user_id: str, attestation: dict
) -> None:
"""Update an attestation that a remote has renewed """Update an attestation that a remote has renewed
""" """
return self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="group_attestations_remote", table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={ updatevalues={

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
@ -84,9 +84,9 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_media", desc="store_local_media",
) )
def mark_local_media_as_safe(self, media_id: str): async def mark_local_media_as_safe(self, media_id: str) -> None:
"""Mark a local media as safe from quarantining.""" """Mark a local media as safe from quarantining."""
return self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="local_media_repository", table="local_media_repository",
keyvalues={"media_id": media_id}, keyvalues={"media_id": media_id},
updatevalues={"safe_from_quarantine": True}, updatevalues={"safe_from_quarantine": True},
@ -158,8 +158,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_url_cache", desc="store_url_cache",
) )
def get_local_media_thumbnails(self, media_id): async def get_local_media_thumbnails(self, media_id: str) -> List[Dict[str, Any]]:
return self.db_pool.simple_select_list( return await self.db_pool.simple_select_list(
"local_media_repository_thumbnails", "local_media_repository_thumbnails",
{"media_id": media_id}, {"media_id": media_id},
( (
@ -271,8 +271,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"update_cached_last_access_time", update_cache_txn "update_cached_last_access_time", update_cache_txn
) )
def get_remote_media_thumbnails(self, origin, media_id): async def get_remote_media_thumbnails(
return self.db_pool.simple_select_list( self, origin: str, media_id: str
) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
"remote_media_cache_thumbnails", "remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id}, {"media_origin": origin, "media_id": media_id},
( (

View File

@ -71,16 +71,20 @@ class ProfileWorkerStore(SQLBaseStore):
table="profiles", values={"user_id": user_localpart}, desc="create_profile" table="profiles", values={"user_id": user_localpart}, desc="create_profile"
) )
def set_profile_displayname(self, user_localpart, new_displayname): async def set_profile_displayname(
return self.db_pool.simple_update_one( self, user_localpart: str, new_displayname: str
) -> None:
await self.db_pool.simple_update_one(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname}, updatevalues={"displayname": new_displayname},
desc="set_profile_displayname", desc="set_profile_displayname",
) )
def set_profile_avatar_url(self, user_localpart, new_avatar_url): async def set_profile_avatar_url(
return self.db_pool.simple_update_one( self, user_localpart: str, new_avatar_url: str
) -> None:
await self.db_pool.simple_update_one(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url}, updatevalues={"avatar_url": new_avatar_url},
@ -106,8 +110,10 @@ class ProfileStore(ProfileWorkerStore):
desc="add_remote_profile_cache", desc="add_remote_profile_cache",
) )
def update_remote_profile_cache(self, user_id, displayname, avatar_url): async def update_remote_profile_cache(
return self.db_pool.simple_update( self, user_id: str, displayname: str, avatar_url: str
) -> int:
return await self.db_pool.simple_update(
table="remote_profile_cache", table="remote_profile_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
updatevalues={ updatevalues={

View File

@ -16,7 +16,7 @@
import abc import abc
import logging import logging
from typing import List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from twisted.internet import defer from twisted.internet import defer
@ -62,8 +62,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {r["user_id"] for r in receipts} return {r["user_id"] for r in receipts}
@cached(num_args=2) @cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type): async def get_receipts_for_room(
return self.db_pool.simple_select_list( self, room_id: str, receipt_type: str
) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
table="receipts_linearized", table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type}, keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"), retcols=("user_id", "event_id"),

View File

@ -578,20 +578,20 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="add_user_bound_threepid", desc="add_user_bound_threepid",
) )
def user_get_bound_threepids(self, user_id): async def user_get_bound_threepids(self, user_id: str) -> List[Dict[str, Any]]:
"""Get the threepids that a user has bound to an identity server through the homeserver """Get the threepids that a user has bound to an identity server through the homeserver
The homeserver remembers where binds to an identity server occurred. Using this The homeserver remembers where binds to an identity server occurred. Using this
method can retrieve those threepids. method can retrieve those threepids.
Args: Args:
user_id (str): The ID of the user to retrieve threepids for user_id: The ID of the user to retrieve threepids for
Returns: Returns:
Deferred[list[dict]]: List of dictionaries containing the following: List of dictionaries containing the following keys:
medium (str): The medium of the threepid (e.g "email") medium (str): The medium of the threepid (e.g "email")
address (str): The address of the threepid (e.g "bob@example.com") address (str): The address of the threepid (e.g "bob@example.com")
""" """
return self.db_pool.simple_select_list( return await self.db_pool.simple_select_list(
table="user_threepid_id_server", table="user_threepid_id_server",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=["medium", "address"], retcols=["medium", "address"],
@ -623,19 +623,21 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="remove_user_bound_threepid", desc="remove_user_bound_threepid",
) )
def get_id_servers_user_bound(self, user_id, medium, address): async def get_id_servers_user_bound(
self, user_id: str, medium: str, address: str
) -> List[str]:
"""Get the list of identity servers that the server proxied bind """Get the list of identity servers that the server proxied bind
requests to for given user and threepid requests to for given user and threepid
Args: Args:
user_id (str) user_id: The user to query for identity servers.
medium (str) medium: The medium to query for identity servers.
address (str) address: The address to query for identity servers.
Returns: Returns:
Deferred[list[str]]: Resolves to a list of identity servers A list of identity servers
""" """
return self.db_pool.simple_select_onecol( return await self.db_pool.simple_select_onecol(
table="user_threepid_id_server", table="user_threepid_id_server",
keyvalues={"user_id": user_id, "medium": medium, "address": address}, keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server", retcol="id_server",

View File

@ -125,8 +125,8 @@ class RoomWorkerStore(SQLBaseStore):
"get_room_with_stats", get_room_with_stats_txn, room_id "get_room_with_stats", get_room_with_stats_txn, room_id
) )
def get_public_room_ids(self): async def get_public_room_ids(self) -> List[str]:
return self.db_pool.simple_select_onecol( return await self.db_pool.simple_select_onecol(
table="rooms", table="rooms",
keyvalues={"is_public": True}, keyvalues={"is_public": True},
retcol="room_id", retcol="room_id",

View File

@ -537,8 +537,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory", desc="get_user_in_directory",
) )
def update_user_directory_stream_pos(self, stream_id): async def update_user_directory_stream_pos(self, stream_id: str) -> None:
return self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="user_directory_stream_pos", table="user_directory_stream_pos",
keyvalues={}, keyvalues={},
updatevalues={"stream_id": stream_id}, updatevalues={"stream_id": stream_id},

View File

@ -81,8 +81,8 @@ class StatsRoomTests(unittest.HomeserverTestCase):
) )
) )
def get_all_room_state(self): async def get_all_room_state(self):
return self.store.db_pool.simple_select_list( return await self.store.db_pool.simple_select_list(
"room_stats_state", None, retcols=("name", "topic", "canonical_alias") "room_stats_state", None, retcols=("name", "topic", "canonical_alias")
) )

View File

@ -148,8 +148,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)])) self.mock_txn.__iter__ = Mock(return_value=iter([(1,), (2,), (3,)]))
self.mock_txn.description = (("colA", None, None, None, None, None, None),) self.mock_txn.description = (("colA", None, None, None, None, None, None),)
ret = yield self.datastore.db_pool.simple_select_list( ret = yield defer.ensureDeferred(
table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"] self.datastore.db_pool.simple_select_list(
table="tablename", keyvalues={"keycol": "A set"}, retcols=["colA"]
)
) )
self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret) self.assertEquals([{"colA": 1}, {"colA": 2}, {"colA": 3}], ret)
@ -161,10 +163,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_1col(self): def test_update_one_1col(self):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
yield self.datastore.db_pool.simple_update_one( yield defer.ensureDeferred(
table="tablename", self.datastore.db_pool.simple_update_one(
keyvalues={"keycol": "TheKey"}, table="tablename",
updatevalues={"columnname": "New Value"}, keyvalues={"keycol": "TheKey"},
updatevalues={"columnname": "New Value"},
)
) )
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(
@ -176,10 +180,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
def test_update_one_4cols(self): def test_update_one_4cols(self):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1
yield self.datastore.db_pool.simple_update_one( yield defer.ensureDeferred(
table="tablename", self.datastore.db_pool.simple_update_one(
keyvalues=OrderedDict([("colA", 1), ("colB", 2)]), table="tablename",
updatevalues=OrderedDict([("colC", 3), ("colD", 4)]), keyvalues=OrderedDict([("colA", 1), ("colB", 2)]),
updatevalues=OrderedDict([("colC", 3), ("colD", 4)]),
)
) )
self.mock_txn.execute.assert_called_with( self.mock_txn.execute.assert_called_with(

View File

@ -42,7 +42,11 @@ class DirectoryStoreTestCase(unittest.TestCase):
self.assertEquals( self.assertEquals(
["#my-room:test"], ["#my-room:test"],
(yield self.store.get_aliases_for_room(self.room.to_string())), (
yield defer.ensureDeferred(
self.store.get_aliases_for_room(self.room.to_string())
)
),
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -36,7 +36,9 @@ class DataStoreTestCase(unittest.TestCase):
def test_get_users_paginate(self): def test_get_users_paginate(self):
yield self.store.register_user(self.user.to_string(), "pass") yield self.store.register_user(self.user.to_string(), "pass")
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart)) yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
yield self.store.set_profile_displayname(self.user.localpart, self.displayname) yield defer.ensureDeferred(
self.store.set_profile_displayname(self.user.localpart, self.displayname)
)
users, total = yield self.store.get_users_paginate( users, total = yield self.store.get_users_paginate(
0, 10, name="bc", guests=False 0, 10, name="bc", guests=False

View File

@ -15,8 +15,9 @@
from mock import Mock from mock import Mock
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed from twisted.internet.defer import succeed
from synapse.api.errors import FederationError
from synapse.events import make_event_from_dict from synapse.events import make_event_from_dict
from synapse.logging.context import LoggingContext from synapse.logging.context import LoggingContext
from synapse.types import Requester, UserID from synapse.types import Requester, UserID
@ -44,22 +45,17 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
user_id = UserID("us", "test") user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, False, None, None) our_user = Requester(user_id, None, False, False, None, None)
room_creator = self.homeserver.get_room_creation_handler() room_creator = self.homeserver.get_room_creation_handler()
room_deferred = ensureDeferred( self.room_id = self.get_success(
room_creator.create_room( room_creator.create_room(
our_user, room_creator._presets_dict["public_chat"], ratelimit=False our_user, room_creator._presets_dict["public_chat"], ratelimit=False
) )
) )[0]["room_id"]
self.reactor.advance(0.1)
self.room_id = self.successResultOf(room_deferred)[0]["room_id"]
self.store = self.homeserver.get_datastore() self.store = self.homeserver.get_datastore()
# Figure out what the most recent event is # Figure out what the most recent event is
most_recent = self.successResultOf( most_recent = self.get_success(
maybeDeferred( self.homeserver.get_datastore().get_latest_event_ids_in_room(self.room_id)
self.homeserver.get_datastore().get_latest_event_ids_in_room,
self.room_id,
)
)[0] )[0]
join_event = make_event_from_dict( join_event = make_event_from_dict(
@ -89,19 +85,18 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
) )
# Send the join, it should return None (which is not an error) # Send the join, it should return None (which is not an error)
d = ensureDeferred( self.assertEqual(
self.handler.on_receive_pdu( self.get_success(
"test.serv", join_event, sent_to_us_directly=True self.handler.on_receive_pdu(
) "test.serv", join_event, sent_to_us_directly=True
)
),
None,
) )
self.reactor.advance(1)
self.assertEqual(self.successResultOf(d), None)
# Make sure we actually joined the room # Make sure we actually joined the room
self.assertEqual( self.assertEqual(
self.successResultOf( self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id)
)[0],
"$join:test.serv", "$join:test.serv",
) )
@ -119,8 +114,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
self.http_client.post_json = post_json self.http_client.post_json = post_json
# Figure out what the most recent event is # Figure out what the most recent event is
most_recent = self.successResultOf( most_recent = self.get_success(
maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) self.store.get_latest_event_ids_in_room(self.room_id)
)[0] )[0]
# Now lie about an event # Now lie about an event
@ -140,17 +135,14 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
) )
with LoggingContext(request="lying_event"): with LoggingContext(request="lying_event"):
d = ensureDeferred( failure = self.get_failure(
self.handler.on_receive_pdu( self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True "test.serv", lying_event, sent_to_us_directly=True
) ),
FederationError,
) )
# Step the reactor, so the database fetches come back
self.reactor.advance(1)
# on_receive_pdu should throw an error # on_receive_pdu should throw an error
failure = self.failureResultOf(d)
self.assertEqual( self.assertEqual(
failure.value.args[0], failure.value.args[0],
( (
@ -160,8 +152,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
) )
# Make sure the invalid event isn't there # Make sure the invalid event isn't there
extrem = maybeDeferred(self.store.get_latest_event_ids_in_room, self.room_id) extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") self.assertEqual(extrem[0], "$join:test.serv")
def test_retry_device_list_resync(self): def test_retry_device_list_resync(self):
"""Tests that device lists are marked as stale if they couldn't be synced, and """Tests that device lists are marked as stale if they couldn't be synced, and