Convert storage layer to async/await. (#7963)

This commit is contained in:
Patrick Cloke 2020-07-28 16:09:53 -04:00 committed by GitHub
parent e866e3b896
commit 3345c166a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 210 additions and 185 deletions

1
changelog.d/7963.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import FrozenEvent
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
@ -192,12 +192,11 @@ class EventsPersistenceStorage(object):
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks
def persist_events(
async def persist_events(
self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
):
) -> int:
"""
Write events to the database
Args:
@ -207,7 +206,7 @@ class EventsPersistenceStorage(object):
which might update the current state etc.
Returns:
Deferred[int]: the stream ordering of the latest persisted event
the stream ordering of the latest persisted event
"""
partitioned = {}
for event, ctx in events_and_contexts:
@ -223,22 +222,19 @@ class EventsPersistenceStorage(object):
for room_id in partitioned:
self._maybe_start_persisting(room_id)
yield make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
max_persisted_id = yield self.main_store.get_current_events_token()
return self.main_store.get_current_events_token()
return max_persisted_id
@defer.inlineCallbacks
def persist_event(
self, event: FrozenEvent, context: EventContext, backfilled: bool = False
):
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[int, int]:
"""
Returns:
Deferred[Tuple[int, int]]: the stream ordering of ``event``,
and the stream ordering of the latest persisted event
The stream ordering of `event`, and the stream ordering of the
latest persisted event
"""
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled
@ -246,9 +242,9 @@ class EventsPersistenceStorage(object):
self._maybe_start_persisting(event.room_id)
yield make_deferred_yieldable(deferred)
await make_deferred_yieldable(deferred)
max_persisted_id = yield self.main_store.get_current_events_token()
max_persisted_id = self.main_store.get_current_events_token()
return (event.internal_metadata.stream_ordering, max_persisted_id)
def _maybe_start_persisting(self, room_id: str):
@ -262,7 +258,7 @@ class EventsPersistenceStorage(object):
async def _persist_events(
self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
):
"""Calculates the change to current state and forward extremities, and
@ -439,7 +435,7 @@ class EventsPersistenceStorage(object):
async def _calculate_new_extremities(
self,
room_id: str,
event_contexts: List[Tuple[FrozenEvent, EventContext]],
event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: List[str],
):
"""Calculates the new forward extremities for a room given events to
@ -497,7 +493,7 @@ class EventsPersistenceStorage(object):
async def _get_new_state_after_events(
self,
room_id: str,
events_context: List[Tuple[FrozenEvent, EventContext]],
events_context: List[Tuple[EventBase, EventContext]],
old_latest_event_ids: Iterable[str],
new_latest_event_ids: Iterable[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
@ -683,7 +679,7 @@ class EventsPersistenceStorage(object):
async def _is_server_still_joined(
self,
room_id: str,
ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
ev_ctx_rm: List[Tuple[EventBase, EventContext]],
delta: DeltaState,
current_state: Optional[StateMap[str]],
potentially_left_users: Set[str],

View File

@ -15,8 +15,7 @@
import itertools
import logging
from twisted.internet import defer
from typing import Set
logger = logging.getLogger(__name__)
@ -28,49 +27,48 @@ class PurgeEventsStorage(object):
def __init__(self, hs, stores):
self.stores = stores
@defer.inlineCallbacks
def purge_room(self, room_id: str):
async def purge_room(self, room_id: str):
"""Deletes all record of a room
"""
state_groups_to_delete = yield self.stores.main.purge_room(room_id)
yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
state_groups_to_delete = await self.stores.main.purge_room(room_id)
await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
@defer.inlineCallbacks
def purge_history(self, room_id, token, delete_local_events):
async def purge_history(
self, room_id: str, token: str, delete_local_events: bool
) -> None:
"""Deletes room history before a certain point
Args:
room_id (str):
room_id: The room ID
token (str): A topological token to delete events before
token: A topological token to delete events before
delete_local_events (bool):
delete_local_events:
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
"""
state_groups = yield self.stores.main.purge_history(
state_groups = await self.stores.main.purge_history(
room_id, token, delete_local_events
)
logger.info("[purge] finding state groups that can be deleted")
sg_to_delete = yield self._find_unreferenced_groups(state_groups)
sg_to_delete = await self._find_unreferenced_groups(state_groups)
yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
@defer.inlineCallbacks
def _find_unreferenced_groups(self, state_groups):
async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
"""Used when purging history to figure out which state groups can be
deleted.
Args:
state_groups (set[int]): Set of state groups referenced by events
state_groups: Set of state groups referenced by events
that are going to be deleted.
Returns:
Deferred[set[int]] The set of state groups that can be deleted.
The set of state groups that can be deleted.
"""
# Graph of state group -> previous group
graph = {}
@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
current_search = set(itertools.islice(next_to_search, 100))
next_to_search -= current_search
referenced = yield self.stores.main.get_referenced_state_groups(
referenced = await self.stores.main.get_referenced_state_groups(
current_search
)
referenced_groups |= referenced
@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
# groups that are referenced.
current_search -= referenced
edges = yield self.stores.state.get_previous_state_groups(current_search)
edges = await self.stores.state.get_previous_state_groups(current_search)
prevs = set(edges.values())
# We don't bother re-handling groups we've already seen

View File

@ -14,13 +14,12 @@
# limitations under the License.
import logging
from typing import Iterable, List, TypeVar
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar
import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import StateMap
logger = logging.getLogger(__name__)
@ -34,16 +33,16 @@ class StateFilter(object):
"""A filter used when querying for state.
Attributes:
types (dict[str, set[str]|None]): Map from type to set of state keys (or
None). This specifies which state_keys for the given type to fetch
from the DB. If None then all events with that type are fetched. If
the set is empty then no events with that type are fetched.
include_others (bool): Whether to fetch events with types that do not
types: Map from type to set of state keys (or None). This specifies
which state_keys for the given type to fetch from the DB. If None
then all events with that type are fetched. If the set is empty
then no events with that type are fetched.
include_others: Whether to fetch events with types that do not
appear in `types`.
"""
types = attr.ib()
include_others = attr.ib(default=False)
types = attr.ib(type=Dict[str, Optional[Set[str]]])
include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
@ -52,36 +51,35 @@ class StateFilter(object):
self.types = {k: v for k, v in self.types.items() if v is not None}
@staticmethod
def all():
def all() -> "StateFilter":
"""Creates a filter that fetches everything.
Returns:
StateFilter
The new state filter.
"""
return StateFilter(types={}, include_others=True)
@staticmethod
def none():
def none() -> "StateFilter":
"""Creates a filter that fetches nothing.
Returns:
StateFilter
The new state filter.
"""
return StateFilter(types={}, include_others=False)
@staticmethod
def from_types(types):
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
"""Creates a filter that only fetches the given types
Args:
types (Iterable[tuple[str, str|None]]): A list of type and state
keys to fetch. A state_key of None fetches everything for
that type
types: A list of type and state keys to fetch. A state_key of None
fetches everything for that type
Returns:
StateFilter
The new state filter.
"""
type_dict = {}
type_dict = {} # type: Dict[str, Optional[Set[str]]]
for typ, s in types:
if typ in type_dict:
if type_dict[typ] is None:
@ -91,24 +89,24 @@ class StateFilter(object):
type_dict[typ] = None
continue
type_dict.setdefault(typ, set()).add(s)
type_dict.setdefault(typ, set()).add(s) # type: ignore
return StateFilter(types=type_dict)
@staticmethod
def from_lazy_load_member_list(members):
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member
events for the given users
Args:
members (iterable[str]): Set of user IDs
members: Set of user IDs
Returns:
StateFilter
The new state filter
"""
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
def return_expanded(self):
def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
(except for memberships). The returned filter is a superset of the
current one, i.e. anything that passes the current filter will pass
@ -130,7 +128,7 @@ class StateFilter(object):
return all non-member events
Returns:
StateFilter
The new state filter.
"""
if self.is_full():
@ -167,7 +165,7 @@ class StateFilter(object):
include_others=True,
)
def make_sql_filter_clause(self):
def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
"""Converts the filter to an SQL clause.
For example:
@ -179,13 +177,12 @@ class StateFilter(object):
Returns:
tuple[str, list]: The SQL string (may be empty) and arguments. An
empty SQL string is returned when the filter matches everything
(i.e. is "full").
The SQL string (may be empty) and arguments. An empty SQL string is
returned when the filter matches everything (i.e. is "full").
"""
where_clause = ""
where_args = []
where_args = [] # type: List[str]
if self.is_full():
return where_clause, where_args
@ -221,7 +218,7 @@ class StateFilter(object):
return where_clause, where_args
def max_entries_returned(self):
def max_entries_returned(self) -> Optional[int]:
"""Returns the maximum number of entries this filter will return if
known, otherwise returns None.
@ -260,33 +257,33 @@ class StateFilter(object):
return filtered_state
def is_full(self):
def is_full(self) -> bool:
"""Whether this filter fetches everything or not
Returns:
bool
True if the filter fetches everything.
"""
return self.include_others and not self.types
def has_wildcards(self):
def has_wildcards(self) -> bool:
"""Whether the filter includes wildcards or is attempting to fetch
specific state.
Returns:
bool
True if the filter includes wildcards.
"""
return self.include_others or any(
state_keys is None for state_keys in self.types.values()
)
def concrete_types(self):
def concrete_types(self) -> List[Tuple[str, str]]:
"""Returns a list of concrete type/state_keys (i.e. not None) that
will be fetched. This will be a complete list if `has_wildcards`
returns False, but otherwise will be a subset (or even empty).
Returns:
list[tuple[str,str]]
A list of type/state_keys tuples.
"""
return [
(t, s)
@ -295,7 +292,7 @@ class StateFilter(object):
for s in state_keys
]
def get_member_split(self):
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching
against non member state.
@ -307,7 +304,7 @@ class StateFilter(object):
state caches).
Returns:
tuple[StateFilter, StateFilter]: The member and non member filters
The member and non member filters
"""
if EventTypes.Member in self.types:
@ -340,6 +337,9 @@ class StateGroupStorage(object):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Args:
state_group: The state group used to retrieve state deltas.
Returns:
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
(prev_group, delta_ids)
@ -347,55 +347,59 @@ class StateGroupStorage(object):
return self.stores.state.get_state_group_delta(state_group)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
) -> Dict[int, StateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id (str): id of the room for these events
event_ids (iterable[str]): ids of the events
_room_id: id of the room for these events
event_ids: ids of the events
Returns:
Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
return {}
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(groups)
group_to_state = await self.stores.state._get_state_for_groups(groups)
return group_to_state
@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
"""Get the event IDs of all the state in the given state group
Args:
state_group (int)
state_group: A state group for which we want to get the state IDs.
Returns:
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = yield self._get_state_for_groups((state_group,))
group_to_state = await self._get_state_for_groups((state_group,))
return group_to_state[state_group]
@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
async def get_state_groups(
self, room_id: str, event_ids: Iterable[str]
) -> Dict[int, List[EventBase]]:
""" Get the state groups for the given list of event_ids
Args:
room_id: ID of the room for these events.
event_ids: The event IDs to retrieve state for.
Returns:
Deferred[dict[int, list[EventBase]]]:
dict of state_group_id -> list of state events.
"""
if not event_ids:
return {}
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
state_event_map = yield self.stores.main.get_events(
state_event_map = await self.stores.main.get_events(
[
ev_id
for group_ids in group_to_ids.values()
@ -423,31 +427,34 @@ class StateGroupStorage(object):
groups: list of state group IDs to query
state_filter: The state filter used to fetch state
from the database.
Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
Args:
event_ids (list[string])
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_ids: The events to fetch the state of.
state_filter: The state filter used to fetch state.
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
)
state_event_map = yield self.stores.main.get_events(
state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False,
)
@ -463,24 +470,24 @@ class StateGroupStorage(object):
return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
A dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
)
@ -491,36 +498,36 @@ class StateGroupStorage(object):
return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
A dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], state_filter)
state_map = await self.get_state_for_events([event_id], state_filter)
return state_map[event_id]
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
def _get_state_for_groups(
@ -530,9 +537,8 @@ class StateGroupStorage(object):
filtering by type/state_key
Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
state_filter (StateFilter): The state filter used to fetch state
groups: list of state groups for which we want to get the state.
state_filter: The state filter used to fetch state.
from the database.
Returns:
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
@ -540,18 +546,23 @@ class StateGroupStorage(object):
return self.stores.state._get_state_for_groups(groups, state_filter)
def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[dict],
current_state_ids: dict,
):
"""Store a new set of state, returning a newly assigned state group.
Args:
event_id (str): The event ID for which the state was calculated
room_id (str)
prev_group (int|None): A previous state group for the room, optional.
delta_ids (dict|None): The delta between state at `prev_group` and
event_id: The event ID for which the state was calculated.
room_id: ID of the room for which the state was calculated.
prev_group: A previous state group for the room, optional.
delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids (dict): The state to store. Map of (type, state_key)
current_state_ids: The state to store. Map of (type, state_key)
to event_id.
Returns:

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
event = self.successResultOf(event)
# Purge everything before this topological token
purge = storage.purge_events.purge_history(self.room_id, event, True)
purge = defer.ensureDeferred(
storage.purge_events.purge_history(self.room_id, event, True)
)
self.pump()
self.assertEqual(self.successResultOf(purge), None)
@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase):
)
# Purge everything before this topological token
purge = storage.purge_history(self.room_id, event, True)
purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
self.pump()
f = self.failureResultOf(purge)
self.assertIn("greater than forward", f.value.args[0])

View File

@ -97,9 +97,11 @@ class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
yield self.storage.persistence.persist_event(
yield defer.ensureDeferred(
self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
@defer.inlineCallbacks
def STALE_test_room_name(self):

View File

@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event
@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = yield self.storage.state.get_state_groups_ids(
self.room, [e2.event_id]
state_group_map = yield defer.ensureDeferred(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = yield self.storage.state.get_state_groups(
self.room, [e2.event_id]
state_group_map = yield defer.ensureDeferred(
self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
state = yield self.storage.state.get_state_for_event(e5.event_id)
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(e5.event_id)
)
self.assertIsNotNone(e4)
@ -164,23 +168,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
state = yield self.storage.state.get_state_for_event(
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.storage.state.get_state_for_event(
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.storage.state.get_state_for_event(
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
)
self.assertStateMapEqual(
{(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
@ -188,13 +198,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
state = yield self.storage.state.get_state_for_event(
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
include_others=True,
),
)
)
self.assertStateMapEqual(
{
@ -206,12 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
state = yield self.storage.state.get_state_for_event(
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
),
)
)
self.assertStateMapEqual(
{(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
group_ids = yield self.storage.state.get_state_groups_ids(
room_id, [e5.event_id]
group_ids = yield defer.ensureDeferred(
self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]

View File

@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage()
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@defer.inlineCallbacks
def test_filtering(self):
@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
event, context = yield defer.ensureDeferred(
self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event
@defer.inlineCallbacks
@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event
@defer.inlineCallbacks
@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event
@defer.inlineCallbacks

View File

@ -638,14 +638,8 @@ class DeferredMockCallable(object):
)
@defer.inlineCallbacks
def create_room(hs, room_id, creator_id):
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room
Args:
hs
room_id (str)
creator_id (str)
"""
persistence_store = hs.get_storage().persistence
@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id):
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
yield store.store_room(
await store.store_room(
room_id=room_id,
room_creator_user_id=creator_id,
is_public=False,
@ -671,8 +665,6 @@ def create_room(hs, room_id, creator_id):
},
)
event, context = yield defer.ensureDeferred(
event_creation_handler.create_new_client_event(builder)
)
event, context = await event_creation_handler.create_new_client_event(builder)
yield persistence_store.persist_event(event, context)
await persistence_store.persist_event(event, context)

View File

@ -206,6 +206,7 @@ commands = mypy \
synapse/storage/data_stores/main/ui_auth.py \
synapse/storage/database.py \
synapse/storage/engines \
synapse/storage/state.py \
synapse/storage/util \
synapse/streams \
synapse/util/caches/stream_change_cache.py \