Convert appservice, group server, profile and more databases to async (#8066)

This commit is contained in:
Patrick Cloke 2020-08-12 09:28:48 -04:00 committed by GitHub
parent 9d1e4942ab
commit a3a59bab7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 91 additions and 116 deletions

1
changelog.d/8066.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -18,8 +18,6 @@ import re
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.appservice import AppServiceTransaction from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
@ -124,17 +122,15 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
class ApplicationServiceTransactionWorkerStore( class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore ApplicationServiceWorkerStore, EventsWorkerStore
): ):
@defer.inlineCallbacks async def get_appservices_by_state(self, state):
def get_appservices_by_state(self, state):
"""Get a list of application services based on their state. """Get a list of application services based on their state.
Args: Args:
state(ApplicationServiceState): The state to filter on. state(ApplicationServiceState): The state to filter on.
Returns: Returns:
A Deferred which resolves to a list of ApplicationServices, which A list of ApplicationServices, which may be empty.
may be empty.
""" """
results = yield self.db_pool.simple_select_list( results = await self.db_pool.simple_select_list(
"application_services_state", {"state": state}, ["as_id"] "application_services_state", {"state": state}, ["as_id"]
) )
# NB: This assumes this class is linked with ApplicationServiceStore # NB: This assumes this class is linked with ApplicationServiceStore
@ -147,16 +143,15 @@ class ApplicationServiceTransactionWorkerStore(
services.append(service) services.append(service)
return services return services
@defer.inlineCallbacks async def get_appservice_state(self, service):
def get_appservice_state(self, service):
"""Get the application service state. """Get the application service state.
Args: Args:
service(ApplicationService): The service whose state to set. service(ApplicationService): The service whose state to set.
Returns: Returns:
A Deferred which resolves to ApplicationServiceState. An ApplicationServiceState.
""" """
result = yield self.db_pool.simple_select_one( result = await self.db_pool.simple_select_one(
"application_services_state", "application_services_state",
{"as_id": service.id}, {"as_id": service.id},
["state"], ["state"],
@ -270,16 +265,14 @@ class ApplicationServiceTransactionWorkerStore(
"complete_appservice_txn", _complete_appservice_txn "complete_appservice_txn", _complete_appservice_txn
) )
@defer.inlineCallbacks async def get_oldest_unsent_txn(self, service):
def get_oldest_unsent_txn(self, service):
"""Get the oldest transaction which has not been sent for this """Get the oldest transaction which has not been sent for this
service. service.
Args: Args:
service(ApplicationService): The app service to get the oldest txn. service(ApplicationService): The app service to get the oldest txn.
Returns: Returns:
A Deferred which resolves to an AppServiceTransaction or An AppServiceTransaction or None.
None.
""" """
def _get_oldest_unsent_txn(txn): def _get_oldest_unsent_txn(txn):
@ -298,7 +291,7 @@ class ApplicationServiceTransactionWorkerStore(
return entry return entry
entry = yield self.db_pool.runInteraction( entry = await self.db_pool.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn "get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
) )
@ -307,7 +300,7 @@ class ApplicationServiceTransactionWorkerStore(
event_ids = db_to_json(entry["event_ids"]) event_ids = db_to_json(entry["event_ids"])
events = yield self.get_events_as_list(event_ids) events = await self.get_events_as_list(event_ids)
return AppServiceTransaction(service=service, id=entry["txn_id"], events=events) return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
@ -332,8 +325,7 @@ class ApplicationServiceTransactionWorkerStore(
"set_appservice_last_pos", set_appservice_last_pos_txn "set_appservice_last_pos", set_appservice_last_pos_txn
) )
@defer.inlineCallbacks async def get_new_events_for_appservice(self, current_id, limit):
def get_new_events_for_appservice(self, current_id, limit):
"""Get all new evnets""" """Get all new evnets"""
def get_new_events_for_appservice_txn(txn): def get_new_events_for_appservice_txn(txn):
@ -357,11 +349,11 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, [row[1] for row in rows] return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.db_pool.runInteraction( upper_bound, event_ids = await self.db_pool.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn "get_new_events_for_appservice", get_new_events_for_appservice_txn
) )
events = yield self.get_events_as_list(event_ids) events = await self.get_events_as_list(event_ids)
return upper_bound, events return upper_bound, events

View File

@ -17,12 +17,12 @@ from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cached
class FilteringStore(SQLBaseStore): class FilteringStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2) @cached(num_args=2)
def get_user_filter(self, user_localpart, filter_id): async def get_user_filter(self, user_localpart, filter_id):
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN. # with a coherent error message rather than 500 M_UNKNOWN.
try: try:
@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore):
except ValueError: except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM) raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
def_json = yield self.db_pool.simple_select_one_onecol( def_json = await self.db_pool.simple_select_one_onecol(
table="user_filters", table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id}, keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json", retcol="filter_json",

View File

@ -14,12 +14,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Tuple from typing import List, Optional, Tuple
from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
# The category ID for the "default" category. We don't store as null in the # The category ID for the "default" category. We don't store as null in the
@ -210,9 +209,8 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_rooms_for_summary", _get_rooms_for_summary_txn "get_rooms_for_summary", _get_rooms_for_summary_txn
) )
@defer.inlineCallbacks async def get_group_categories(self, group_id):
def get_group_categories(self, group_id): rows = await self.db_pool.simple_select_list(
rows = yield self.db_pool.simple_select_list(
table="group_room_categories", table="group_room_categories",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
retcols=("category_id", "is_public", "profile"), retcols=("category_id", "is_public", "profile"),
@ -227,9 +225,8 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows for row in rows
} }
@defer.inlineCallbacks async def get_group_category(self, group_id, category_id):
def get_group_category(self, group_id, category_id): category = await self.db_pool.simple_select_one(
category = yield self.db_pool.simple_select_one(
table="group_room_categories", table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id}, keyvalues={"group_id": group_id, "category_id": category_id},
retcols=("is_public", "profile"), retcols=("is_public", "profile"),
@ -240,9 +237,8 @@ class GroupServerWorkerStore(SQLBaseStore):
return category return category
@defer.inlineCallbacks async def get_group_roles(self, group_id):
def get_group_roles(self, group_id): rows = await self.db_pool.simple_select_list(
rows = yield self.db_pool.simple_select_list(
table="group_roles", table="group_roles",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
retcols=("role_id", "is_public", "profile"), retcols=("role_id", "is_public", "profile"),
@ -257,9 +253,8 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in rows for row in rows
} }
@defer.inlineCallbacks async def get_group_role(self, group_id, role_id):
def get_group_role(self, group_id, role_id): role = await self.db_pool.simple_select_one(
role = yield self.db_pool.simple_select_one(
table="group_roles", table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id}, keyvalues={"group_id": group_id, "role_id": role_id},
retcols=("is_public", "profile"), retcols=("is_public", "profile"),
@ -448,12 +443,11 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_attestations_need_renewals", _get_attestations_need_renewals_txn "get_attestations_need_renewals", _get_attestations_need_renewals_txn
) )
@defer.inlineCallbacks async def get_remote_attestation(self, group_id, user_id):
def get_remote_attestation(self, group_id, user_id):
"""Get the attestation that proves the remote agrees that the user is """Get the attestation that proves the remote agrees that the user is
in the group. in the group.
""" """
row = yield self.db_pool.simple_select_one( row = await self.db_pool.simple_select_one(
table="group_attestations_remote", table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id}, keyvalues={"group_id": group_id, "user_id": user_id},
retcols=("valid_until_ms", "attestation_json"), retcols=("valid_until_ms", "attestation_json"),
@ -499,13 +493,13 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_all_groups_for_user", _get_all_groups_for_user_txn "get_all_groups_for_user", _get_all_groups_for_user_txn
) )
def get_groups_changes_for_user(self, user_id, from_token, to_token): async def get_groups_changes_for_user(self, user_id, from_token, to_token):
from_token = int(from_token) from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_entity_changed( has_changed = self._group_updates_stream_cache.has_entity_changed(
user_id, from_token user_id, from_token
) )
if not has_changed: if not has_changed:
return defer.succeed([]) return []
def _get_groups_changes_for_user_txn(txn): def _get_groups_changes_for_user_txn(txn):
sql = """ sql = """
@ -525,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for group_id, membership, gtype, content_json in txn for group_id, membership, gtype, content_json in txn
] ]
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_groups_changes_for_user", _get_groups_changes_for_user_txn "get_groups_changes_for_user", _get_groups_changes_for_user_txn
) )
@ -1087,31 +1081,31 @@ class GroupServerStore(GroupServerWorkerStore):
desc="update_group_publicity", desc="update_group_publicity",
) )
@defer.inlineCallbacks async def register_user_group_membership(
def register_user_group_membership(
self, self,
group_id, group_id: str,
user_id, user_id: str,
membership, membership: str,
is_admin=False, is_admin: bool = False,
content={}, content: JsonDict = {},
local_attestation=None, local_attestation: Optional[dict] = None,
remote_attestation=None, remote_attestation: Optional[dict] = None,
is_publicised=False, is_publicised: bool = False,
): ) -> int:
"""Registers that a local user is a member of a (local or remote) group. """Registers that a local user is a member of a (local or remote) group.
Args: Args:
group_id (str) group_id: The group the member is being added to.
user_id (str) user_id: THe user ID to add to the group.
membership (str) membership: The type of group membership.
is_admin (bool) is_admin: Whether the user should be added as a group admin.
content (dict): Content of the membership, e.g. includes the inviter content: Content of the membership, e.g. includes the inviter
if the user has been invited. if the user has been invited.
local_attestation (dict): If remote group then store the fact that we local_attestation: If remote group then store the fact that we
have given out an attestation, else None. have given out an attestation, else None.
remote_attestation (dict): If remote group then store the remote remote_attestation: If remote group then store the remote
attestation from the group, else None. attestation from the group, else None.
is_publicised: Whether this should be publicised.
""" """
def _register_user_group_membership_txn(txn, next_id): def _register_user_group_membership_txn(txn, next_id):
@ -1188,18 +1182,17 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id return next_id
with self._group_updates_id_gen.get_next() as next_id: with self._group_updates_id_gen.get_next() as next_id:
res = yield self.db_pool.runInteraction( res = await self.db_pool.runInteraction(
"register_user_group_membership", "register_user_group_membership",
_register_user_group_membership_txn, _register_user_group_membership_txn,
next_id, next_id,
) )
return res return res
@defer.inlineCallbacks async def create_group(
def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description self, group_id, user_id, name, avatar_url, short_description, long_description
): ) -> None:
yield self.db_pool.simple_insert( await self.db_pool.simple_insert(
table="groups", table="groups",
values={ values={
"group_id": group_id, "group_id": group_id,
@ -1212,9 +1205,8 @@ class GroupServerStore(GroupServerWorkerStore):
desc="create_group", desc="create_group",
) )
@defer.inlineCallbacks async def update_group_profile(self, group_id, profile):
def update_group_profile(self, group_id, profile): await self.db_pool.simple_update_one(
yield self.db_pool.simple_update_one(
table="groups", table="groups",
keyvalues={"group_id": group_id}, keyvalues={"group_id": group_id},
updatevalues=profile, updatevalues=profile,

View File

@ -15,8 +15,6 @@
from typing import List, Tuple from typing import List, Tuple
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
@ -24,14 +22,13 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore): class PresenceStore(SQLBaseStore):
@defer.inlineCallbacks async def update_presence(self, presence_states):
def update_presence(self, presence_states):
stream_ordering_manager = self._presence_id_gen.get_next_mult( stream_ordering_manager = self._presence_id_gen.get_next_mult(
len(presence_states) len(presence_states)
) )
with stream_ordering_manager as stream_orderings: with stream_ordering_manager as stream_orderings:
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"update_presence", "update_presence",
self._update_presence_txn, self._update_presence_txn,
stream_orderings, stream_orderings,

View File

@ -13,18 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.databases.main.roommember import ProfileInfo from synapse.storage.databases.main.roommember import ProfileInfo
class ProfileWorkerStore(SQLBaseStore): class ProfileWorkerStore(SQLBaseStore):
@defer.inlineCallbacks async def get_profileinfo(self, user_localpart):
def get_profileinfo(self, user_localpart):
try: try:
profile = yield self.db_pool.simple_select_one( profile = await self.db_pool.simple_select_one(
table="profiles", table="profiles",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"), retcols=("displayname", "avatar_url"),
@ -118,14 +115,13 @@ class ProfileStore(ProfileWorkerStore):
desc="update_remote_profile_cache", desc="update_remote_profile_cache",
) )
@defer.inlineCallbacks async def maybe_delete_remote_profile_cache(self, user_id):
def maybe_delete_remote_profile_cache(self, user_id):
"""Check if we still care about the remote user's profile, and if we """Check if we still care about the remote user's profile, and if we
don't then remove their profile from the cache don't then remove their profile from the cache
""" """
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) subscribed = await self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed: if not subscribed:
yield self.db_pool.simple_delete( await self.db_pool.simple_delete(
table="remote_profile_cache", table="remote_profile_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache", desc="delete_remote_profile_cache",
@ -151,11 +147,10 @@ class ProfileStore(ProfileWorkerStore):
_get_remote_profile_cache_entries_that_expire_txn, _get_remote_profile_cache_entries_that_expire_txn,
) )
@defer.inlineCallbacks async def is_subscribed_remote_profile_for_user(self, user_id):
def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile. """Check whether we are interested in a remote user's profile.
""" """
res = yield self.db_pool.simple_select_one_onecol( res = await self.db_pool.simple_select_one_onecol(
table="group_users", table="group_users",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="user_id", retcol="user_id",
@ -166,7 +161,7 @@ class ProfileStore(ProfileWorkerStore):
if res: if res:
return True return True
res = yield self.db_pool.simple_select_one_onecol( res = await self.db_pool.simple_select_one_onecol(
table="group_invites", table="group_invites",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcol="user_id", retcol="user_id",

View File

@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional
import attr import attr
from synapse.api.constants import RelationTypes from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.relations import ( from synapse.storage.relations import (
@ -25,7 +27,7 @@ from synapse.storage.relations import (
PaginationChunk, PaginationChunk,
RelationPaginationToken, RelationPaginationToken,
) )
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -227,18 +229,18 @@ class RelationsWorkerStore(SQLBaseStore):
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
) )
@cachedInlineCallbacks() @cached()
def get_applicable_edit(self, event_id): async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
"""Get the most recent edit (if any) that has happened for the given """Get the most recent edit (if any) that has happened for the given
event. event.
Correctly handles checking whether edits were allowed to happen. Correctly handles checking whether edits were allowed to happen.
Args: Args:
event_id (str): The original event ID event_id: The original event ID
Returns: Returns:
Deferred[EventBase|None]: Returns the most recent edit, if any. The most recent edit, if any.
""" """
# We only allow edits for `m.room.message` events that have the same sender # We only allow edits for `m.room.message` events that have the same sender
@ -268,15 +270,14 @@ class RelationsWorkerStore(SQLBaseStore):
if row: if row:
return row[0] return row[0]
edit_id = yield self.db_pool.runInteraction( edit_id = await self.db_pool.runInteraction(
"get_applicable_edit", _get_applicable_edit_txn "get_applicable_edit", _get_applicable_edit_txn
) )
if not edit_id: if not edit_id:
return return None
edit_event = yield self.get_event(edit_id, allow_none=True) return await self.get_event(edit_id, allow_none=True)
return edit_event
def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
"""Check if a user has already annotated an event with the same key """Check if a user has already annotated an event with the same key

View File

@ -18,8 +18,6 @@ from collections import namedtuple
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
@ -126,8 +124,7 @@ class TransactionStore(SQLBaseStore):
desc="set_received_txn_response", desc="set_received_txn_response",
) )
@defer.inlineCallbacks async def get_destination_retry_timings(self, destination):
def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination. """Gets the current retry timings (if any) for a given destination.
Args: Args:
@ -142,7 +139,7 @@ class TransactionStore(SQLBaseStore):
if result is not SENTINEL: if result is not SENTINEL:
return result return result
result = yield self.db_pool.runInteraction( result = await self.db_pool.runInteraction(
"get_destination_retry_timings", "get_destination_retry_timings",
self._get_destination_retry_timings, self._get_destination_retry_timings,
destination, destination,

View File

@ -178,14 +178,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_appservice_state_none(self): def test_get_appservice_state_none(self):
service = Mock(id="999") service = Mock(id="999")
state = yield self.store.get_appservice_state(service) state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(None, state) self.assertEquals(None, state)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_appservice_state_up(self): def test_get_appservice_state_up(self):
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP) yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.UP)
service = Mock(id=self.as_list[0]["id"]) service = Mock(id=self.as_list[0]["id"])
state = yield self.store.get_appservice_state(service) state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(ApplicationServiceState.UP, state) self.assertEquals(ApplicationServiceState.UP, state)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -194,13 +194,13 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
service = Mock(id=self.as_list[1]["id"]) service = Mock(id=self.as_list[1]["id"])
state = yield self.store.get_appservice_state(service) state = yield defer.ensureDeferred(self.store.get_appservice_state(service))
self.assertEquals(ApplicationServiceState.DOWN, state) self.assertEquals(ApplicationServiceState.DOWN, state)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_appservices_by_state_none(self): def test_get_appservices_by_state_none(self):
services = yield self.store.get_appservices_by_state( services = yield defer.ensureDeferred(
ApplicationServiceState.DOWN self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
) )
self.assertEquals(0, len(services)) self.assertEquals(0, len(services))
@ -339,7 +339,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def test_get_oldest_unsent_txn_none(self): def test_get_oldest_unsent_txn_none(self):
service = Mock(id=self.as_list[0]["id"]) service = Mock(id=self.as_list[0]["id"])
txn = yield self.store.get_oldest_unsent_txn(service) txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
self.assertEquals(None, txn) self.assertEquals(None, txn)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -349,14 +349,14 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")] other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out # we aren't testing store._base stuff here, so mock this out
self.store.get_events_as_list = Mock(return_value=events) self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
yield self._insert_txn(self.as_list[1]["id"], 9, other_events) yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events) yield self._insert_txn(service.id, 10, events)
yield self._insert_txn(service.id, 11, other_events) yield self._insert_txn(service.id, 11, other_events)
yield self._insert_txn(service.id, 12, other_events) yield self._insert_txn(service.id, 12, other_events)
txn = yield self.store.get_oldest_unsent_txn(service) txn = yield defer.ensureDeferred(self.store.get_oldest_unsent_txn(service))
self.assertEquals(service, txn.service) self.assertEquals(service, txn.service)
self.assertEquals(10, txn.id) self.assertEquals(10, txn.id)
self.assertEquals(events, txn.events) self.assertEquals(events, txn.events)
@ -366,8 +366,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[0]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP) yield self._set_state(self.as_list[1]["id"], ApplicationServiceState.UP)
services = yield self.store.get_appservices_by_state( services = yield defer.ensureDeferred(
ApplicationServiceState.DOWN self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
) )
self.assertEquals(1, len(services)) self.assertEquals(1, len(services))
self.assertEquals(self.as_list[0]["id"], services[0].id) self.assertEquals(self.as_list[0]["id"], services[0].id)
@ -379,8 +379,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN) yield self._set_state(self.as_list[2]["id"], ApplicationServiceState.DOWN)
yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP) yield self._set_state(self.as_list[3]["id"], ApplicationServiceState.UP)
services = yield self.store.get_appservices_by_state( services = yield defer.ensureDeferred(
ApplicationServiceState.DOWN self.store.get_appservices_by_state(ApplicationServiceState.DOWN)
) )
self.assertEquals(2, len(services)) self.assertEquals(2, len(services))
self.assertEquals( self.assertEquals(