Convert events worker database to async/await. (#8071)

Patrick Cloke 2020-08-18 16:20:49 -04:00 committed by GitHub
parent acfb7c3b5d
commit f40645e60b
12 changed files with 106 additions and 97 deletions

changelog.d/8071.misc (new file)

@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
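The diff below applies one pattern throughout: `@defer.inlineCallbacks` generators become native coroutines, `yield` becomes `await`, and `Deferred[...]` wrappers drop out of the documented return types. A minimal, runnable sketch of the before/after shape (the function names here are illustrative, not taken from the diff):

```python
from twisted.internet import defer
from twisted.internet.task import react


# Before: an inlineCallbacks generator; it yields Deferreds and itself
# returns a Deferred.
@defer.inlineCallbacks
def get_value_old():
    value = yield defer.succeed(42)
    return value


# After: a native coroutine; `await` replaces `yield`, and callers must be
# async themselves (or wrap the coroutine, e.g. with defer.ensureDeferred).
async def get_value():
    value = await defer.succeed(42)
    return value


async def main(reactor):
    old = await get_value_old()  # Deferreds can be awaited directly
    new = await get_value()
    print(old, new)              # 42 42


if __name__ == "__main__":
    react(lambda reactor: defer.ensureDeferred(main(reactor)))
```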


@@ -47,7 +47,7 @@ def check(
     Args:
         room_version_obj: the version of the room
         event: the event being checked.
-        auth_events (dict: event-key -> event): the existing room state.
+        auth_events: the existing room state.

     Raises:
         AuthError if the checks fail


@@ -1777,9 +1777,7 @@ class FederationHandler(BaseHandler):
         """Returns the state at the event. i.e. not including said event.
         """
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)

         state_groups = await self.state_store.get_state_groups(room_id, [event_id])
@@ -1805,9 +1803,7 @@ class FederationHandler(BaseHandler):
     async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         """Returns the state at the event. i.e. not including said event.
         """
-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)

         state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@@ -2155,9 +2151,9 @@ class FederationHandler(BaseHandler):
         auth_types = auth_types_for_event(event)
         current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]

-        current_auth_events = await self.store.get_events(current_state_ids)
+        auth_events_map = await self.store.get_events(current_state_ids)
         current_auth_events = {
-            (e.type, e.state_key): e for e in current_auth_events.values()
+            (e.type, e.state_key): e for e in auth_events_map.values()
         }

         try:
@@ -2173,9 +2169,7 @@ class FederationHandler(BaseHandler):
         if not in_room:
             raise AuthError(403, "Host not in room.")

-        event = await self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id
-        )
+        event = await self.store.get_event(event_id, check_room_id=room_id)

         # Just go through and process each event in `remote_auth_chain`. We
         # don't want to fall into the trap of `missing` being wrong.


@@ -960,7 +960,7 @@ class EventCreationHandler(object):
                 allow_none=True,
             )
-            is_admin_redaction = (
+            is_admin_redaction = bool(
                 original_event and event.sender != original_event.sender
             )
@@ -1080,8 +1080,8 @@ class EventCreationHandler(object):
             auth_events_ids = self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
-            auth_events = await self.store.get_events(auth_events_ids)
-            auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+            auth_events_map = await self.store.get_events(auth_events_ids)
+            auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}

         room_version = await self.store.get_room_version_id(event.room_id)
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]


@@ -716,7 +716,7 @@ class RoomMemberHandler(object):
         guest_access = await self.store.get_event(guest_access_id)

-        return (
+        return bool(
             guest_access
             and guest_access.content
             and "guest_access" in guest_access.content


@@ -51,5 +51,5 @@ class SpamCheckerApi(object):
         state_ids = yield self._store.get_filtered_current_state_ids(
             room_id=room_id, state_filter=StateFilter.from_types(types)
         )
-        state = yield self._store.get_events(state_ids.values())
+        state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
         return state.values()
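The SpamCheckerApi code above is still an `@defer.inlineCallbacks` generator, and such generators only know how to wait on Deferreds, so once `get_events` is a coroutine its result has to be wrapped with `defer.ensureDeferred` before being yielded. A self-contained sketch of the same bridge (the store method here is a stand-in, not Synapse's):

```python
from twisted.internet import defer
from twisted.internet.task import react


async def get_events(event_ids):
    # Stand-in for the now-async store method.
    return {eid: {"event_id": eid} for eid in event_ids}


# Legacy inlineCallbacks code can only wait on Deferreds, so the coroutine
# is wrapped with defer.ensureDeferred before being yielded.
@defer.inlineCallbacks
def legacy_caller():
    events = yield defer.ensureDeferred(get_events(["$a", "$b"]))
    return list(events.values())


if __name__ == "__main__":
    react(lambda reactor: legacy_caller().addCallback(print))
```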


@@ -641,7 +641,7 @@ class StateResolutionStore(object):
             allow_rejected (bool): If True return rejected events.

         Returns:
-            Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+            Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
         """

         return self.store.get_events(
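`StateResolutionStore.get_events` itself stays a plain method: it forwards the coroutine returned by the store, so only the documented type changes from `Deferred` to `Awaitable` and callers keep awaiting the result. A hedged sketch of that shape (class and helper names here are illustrative):

```python
from typing import Awaitable, Dict


async def _fetch_events(event_ids) -> Dict[str, dict]:
    # Stand-in for an async storage call.
    return {eid: {"event_id": eid} for eid in event_ids}


class ResolutionStore:
    # Not declared async: it returns the coroutine unchanged, so the caller
    # is the one that awaits it, e.g. `events = await store.get_events(ids)`.
    def get_events(self, event_ids) -> Awaitable[Dict[str, dict]]:
        return _fetch_events(event_ids)
```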


@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
 class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
-    def get_auth_chain(self, event_ids, include_given=False):
+    async def get_auth_chain(self, event_ids, include_given=False):
         """Get auth events for given event_ids. The events *must* be state events.

         Args:
@@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
         Returns:
             list of events
         """
-        return self.get_auth_chain_ids(
-            event_ids, include_given=include_given
-        ).addCallback(self.get_events_as_list)
+        event_ids = await self.get_auth_chain_ids(
+            event_ids, include_given=include_given
+        )
+        return await self.get_events_as_list(event_ids)

     def get_auth_chain_ids(
         self,
@@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
         )

-    def get_backfill_events(self, room_id, event_list, limit):
+    async def get_backfill_events(self, room_id, event_list, limit):
         """Get a list of Events for a given topic that occurred before (and
         including) the events in event_list. Return a list of max size `limit`
@@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
             event_list (list)
             limit (int)
         """
-        return (
-            self.db_pool.runInteraction(
-                "get_backfill_events",
-                self._get_backfill_events,
-                room_id,
-                event_list,
-                limit,
-            )
-            .addCallback(self.get_events_as_list)
-            .addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
-        )
+        event_ids = await self.db_pool.runInteraction(
+            "get_backfill_events",
+            self._get_backfill_events,
+            room_id,
+            event_list,
+            limit,
+        )
+        events = await self.get_events_as_list(event_ids)
+        return sorted(events, key=lambda e: -e.depth)

     def _get_backfill_events(self, txn, room_id, event_list, limit):
         logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
             latest_events,
             limit,
         )
-        events = await self.get_events_as_list(ids)
-        return events
+        return await self.get_events_as_list(ids)

     def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
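The `get_auth_chain` and `get_backfill_events` rewrites above show the other recurring shape: a Deferred `.addCallback(...)` pipeline becomes straight-line awaits, which also lets the intermediate values be named. A small runnable sketch of the same transformation with toy functions (nothing here is Synapse's API):

```python
from twisted.internet import defer
from twisted.internet.task import react


def fetch_ids():
    return defer.succeed(["$b", "$a"])


async def fetch_events(event_ids):
    return [{"event_id": eid} for eid in event_ids]


# Before: a callback pipeline on a Deferred.
def backfill_old():
    return (
        fetch_ids()
        .addCallback(lambda ids: defer.ensureDeferred(fetch_events(ids)))
        .addCallback(lambda events: sorted(events, key=lambda e: e["event_id"]))
    )


# After: the same flow written as sequential awaits.
async def backfill():
    ids = await fetch_ids()
    events = await fetch_events(ids)
    return sorted(events, key=lambda e: e["event_id"])


if __name__ == "__main__":
    react(lambda reactor: defer.ensureDeferred(backfill()).addCallback(print))
```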


@@ -19,9 +19,10 @@ import itertools
 import logging
 import threading
 from collections import namedtuple
-from typing import List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple, overload

 from constantly import NamedConstant, Names
+from typing_extensions import Literal

 from twisted.internet import defer
@@ -32,7 +33,7 @@ from synapse.api.room_versions import (
     EventFormatVersions,
     RoomVersions,
 )
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
 from synapse.storage.util.id_generators import StreamIdGenerator
-from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.caches.descriptors import Cache, cached
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
@@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore):
             desc="get_received_ts",
         )

-    @defer.inlineCallbacks
-    def get_event(
+    # Inform mypy that if allow_none is False (the default) then get_event
+    # always returns an EventBase.
+    @overload
+    async def get_event(
+        self,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: Literal[False] = False,
+        check_room_id: Optional[str] = None,
+    ) -> EventBase:
+        ...
+
+    @overload
+    async def get_event(
+        self,
+        event_id: str,
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: Literal[True] = False,
+        check_room_id: Optional[str] = None,
+    ) -> Optional[EventBase]:
+        ...
+
+    async def get_event(
         self,
         event_id: str,
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
         allow_rejected: bool = False,
         allow_none: bool = False,
         check_room_id: Optional[str] = None,
-    ):
+    ) -> Optional[EventBase]:
         """Get an event from the database by event_id.

         Args:
@@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
             If there is a mismatch, behave as per allow_none.

         Returns:
-            Deferred[EventBase|None]
+            The event, or None if the event was not found.
         """
         if not isinstance(event_id, str):
             raise TypeError("Invalid event event_id %r" % (event_id,))

-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             [event_id],
             redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
@@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):

         return event
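The two `@overload` stubs added above let mypy pick the return type from the value of `allow_none`: the default `False` yields a plain `EventBase`, while `allow_none=True` keeps the `Optional`. A standalone sketch of the same trick with a toy event type follows; note the diff gives both stubs a `= False` default to keep the signatures aligned, whereas this sketch simply omits the default on the `Literal[True]` overload (and on Python 3.8+ `Literal` can come from `typing` directly):

```python
from typing import Optional, overload

from typing_extensions import Literal


class Event:
    pass


_DB = {"$abc": Event()}


@overload
def get_event(event_id: str, allow_none: Literal[False] = False) -> Event:
    ...


@overload
def get_event(event_id: str, allow_none: Literal[True]) -> Optional[Event]:
    ...


def get_event(event_id: str, allow_none: bool = False) -> Optional[Event]:
    event = _DB.get(event_id)
    if event is None and not allow_none:
        raise KeyError(event_id)
    return event


found = get_event("$abc")                   # mypy infers Event
maybe = get_event("$xyz", allow_none=True)  # mypy infers Optional[Event]
```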
-    @defer.inlineCallbacks
-    def get_events(
+    async def get_events(
         self,
-        event_ids: List[str],
+        event_ids: Iterable[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
-    ):
+    ) -> Dict[str, EventBase]:
         """Get events from the database

         Args:
@@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
             omits rejeted events from the response.

         Returns:
-            Deferred : Dict from event_id to event.
+            A mapping from event_id to event.
         """
-        events = yield self.get_events_as_list(
+        events = await self.get_events_as_list(
             event_ids,
             redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
@@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):

         return {e.event_id: e for e in events}

-    @defer.inlineCallbacks
-    def get_events_as_list(
+    async def get_events_as_list(
         self,
-        event_ids: List[str],
+        event_ids: Collection[str],
         redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
         get_prev_content: bool = False,
         allow_rejected: bool = False,
-    ):
+    ) -> List[EventBase]:
         """Get events from the database and return in a list in the same order
         as given by `event_ids` arg.
@@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
             omits rejected events from the response.

         Returns:
-            Deferred[list[EventBase]]: List of events fetched from the database. The
-            events are in the same order as `event_ids` arg.
+            List of events fetched from the database. The events are in the same
+            order as `event_ids` arg.

             Note that the returned list may be smaller than the list of event
             IDs if not all events could be fetched.
@@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
             return []

         # there may be duplicates so we cast the list to a set
-        event_entry_map = yield self._get_events_from_cache_or_db(
+        event_entry_map = await self._get_events_from_cache_or_db(
             set(event_ids), allow_rejected=allow_rejected
         )
@@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
                 continue

             redacted_event_id = entry.event.redacts
-            event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+            event_map = await self._get_events_from_cache_or_db([redacted_event_id])
             original_event_entry = event_map.get(redacted_event_id)
             if not original_event_entry:
                 # we don't have the redacted event (or it was rejected).
@@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
             if get_prev_content:
                 if "replaces_state" in event.unsigned:
-                    prev = yield self.get_event(
+                    prev = await self.get_event(
                         event.unsigned["replaces_state"],
                         get_prev_content=False,
                         allow_none=True,
@@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):

         return events

-    @defer.inlineCallbacks
-    def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
         """Fetch a bunch of events from the cache or the database.

         If events are pulled from the database, they will be cached for future lookups.
@@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
             rejected events are omitted from the response.

         Returns:
-            Deferred[Dict[str, _EventCacheEntry]]:
+            Dict[str, _EventCacheEntry]:
                 map from event id to result
         """
         event_entry_map = self._get_events_from_cache(
@@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
         # the events have been redacted, and if so pulling the redaction event out
         # of the database to check it.
         #
-        missing_events = yield self._get_events_from_db(
+        missing_events = await self._get_events_from_db(
             missing_events_ids, allow_rejected=allow_rejected
         )
@@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
             with PreserveLoggingContext():
                 self.hs.get_reactor().callFromThread(fire, event_list, e)

-    @defer.inlineCallbacks
-    def _get_events_from_db(self, event_ids, allow_rejected=False):
+    async def _get_events_from_db(self, event_ids, allow_rejected=False):
         """Fetch a bunch of events from the database.

         Returned events will be added to the cache for future lookups.
@@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
             rejected events are omitted from the response.

         Returns:
-            Deferred[Dict[str, _EventCacheEntry]]:
+            Dict[str, _EventCacheEntry]:
                 map from event id to result. May return extra events which
                 weren't asked for.
         """
@@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
         events_to_fetch = event_ids
         while events_to_fetch:
-            row_map = yield self._enqueue_events(events_to_fetch)
+            row_map = await self._enqueue_events(events_to_fetch)

             # we need to recursively fetch any redactions of those events
             redaction_ids = set()
@@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore):

         return result_map

-    @defer.inlineCallbacks
-    def _enqueue_events(self, events):
+    async def _enqueue_events(self, events):
         """Fetches events from the database using the _event_fetch_list. This
         allows batch and bulk fetching of events - it allows us to fetch events
         without having to create a new transaction for each request for events.
@@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore):
             events (Iterable[str]): events to be fetched.

         Returns:
-            Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+            Dict[str, Dict]: map from event id to row data from the database.
                 May contain events that weren't requested.
         """

@@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore):
         logger.debug("Loading %d events: %s", len(events), events)

         with PreserveLoggingContext():
-            row_map = yield events_d
+            row_map = await events_d
         logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))

         return row_map
@@ -842,33 +863,29 @@ class EventsWorkerStore(SQLBaseStore):
         # no valid redaction found for this event
         return None

-    @defer.inlineCallbacks
-    def have_events_in_timeline(self, event_ids):
+    async def have_events_in_timeline(self, event_ids):
         """Given a list of event ids, check if we have already processed and
         stored them as non outliers.
         """
-        rows = yield defer.ensureDeferred(
-            self.db_pool.simple_select_many_batch(
-                table="events",
-                retcols=("event_id",),
-                column="event_id",
-                iterable=list(event_ids),
-                keyvalues={"outlier": False},
-                desc="have_events_in_timeline",
-            )
-        )
+        rows = await self.db_pool.simple_select_many_batch(
+            table="events",
+            retcols=("event_id",),
+            column="event_id",
+            iterable=list(event_ids),
+            keyvalues={"outlier": False},
+            desc="have_events_in_timeline",
+        )

         return {r["event_id"] for r in rows}

-    @defer.inlineCallbacks
-    def have_seen_events(self, event_ids):
+    async def have_seen_events(self, event_ids):
         """Given a list of event ids, check if we have already processed them.

         Args:
             event_ids (iterable[str]):

         Returns:
-            Deferred[set[str]]: The events we have already seen.
+            set[str]: The events we have already seen.
         """
         results = set()
@@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore):
         # break the input up into chunks of 100
         input_iterator = iter(event_ids)
         for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
-            yield self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "have_seen_events", have_seen_events_txn, chunk
             )

         return results
@@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore):
             room_id,
         )

-    @defer.inlineCallbacks
-    def get_room_complexity(self, room_id):
+    async def get_room_complexity(self, room_id):
         """
         Get a rough approximation of the complexity of the room. This is used by
         remote servers to decide whether they wish to join the room or not.
@@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore):
             room_id (str)

         Returns:
-            Deferred[dict[str:int]] of complexity version to complexity.
+            dict[str:int] of complexity version to complexity.
         """
-        state_events = yield self.get_current_state_event_counts(room_id)
+        state_events = await self.get_current_state_event_counts(room_id)

         # Call this one "v1", so we can introduce new ones as we want to develop
         # it.
@@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore):
         to_2, so_2 = await self.get_event_ordering(event_id2)
         return (to_1, so_1) > (to_2, so_2)

-    @cachedInlineCallbacks(max_entries=5000)
-    def get_event_ordering(self, event_id):
-        res = yield self.db_pool.simple_select_one(
+    @cached(max_entries=5000)
+    async def get_event_ordering(self, event_id):
+        res = await self.db_pool.simple_select_one(
             table="events",
             retcols=["topological_ordering", "stream_ordering"],
             keyvalues={"event_id": event_id},


@@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         limit: int = 0,
         order: str = "DESC",
     ) -> Tuple[List[EventBase], str]:
-
         """Get new room events in stream ordering since `from_key`.

         Args:


@@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
         self._rlsn._store.get_events = Mock(
-            return_value=defer.succeed({"123": mock_event})
+            return_value=make_awaitable({"123": mock_event})
         )

         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         # Would be better to check the content, but once == remove blocking event
@@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
         self._rlsn._store.get_events = Mock(
-            return_value=defer.succeed({"123": mock_event})
+            return_value=make_awaitable({"123": mock_event})
         )

         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
             type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
         )
         self._rlsn._store.get_events = Mock(
-            return_value=defer.succeed({"123": mock_event})
+            return_value=make_awaitable({"123": mock_event})
        )

         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))


@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
 )

 from tests import unittest
+from tests.test_utils import make_awaitable
 from tests.utils import setup_test_homeserver
@@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         other_events = [Mock(event_id="e5"), Mock(event_id="e6")]

         # we aren't testing store._base stuff here, so mock this out
-        self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
+        self.store.get_events_as_list = Mock(return_value=make_awaitable(events))

         yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
         yield self._insert_txn(service.id, 10, events)
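Both test files now hand the mocked store methods an awaitable built by `make_awaitable(...)` rather than a Deferred from `defer.succeed(...)`, since the converted callers await those methods as coroutines. The real helper lives in `tests/test_utils` (imported above); a minimal stand-in with the same behaviour, shown with a usage mirroring the mocked calls, might look like this (the implementation details are an assumption, not copied from the Synapse tree):

```python
import asyncio
from typing import Any, Awaitable
from unittest.mock import Mock


def make_awaitable(result: Any) -> Awaitable[Any]:
    # An already-completed Future can be awaited (even repeatedly) and
    # immediately produces `result`, which is all a mocked async method needs.
    future: "asyncio.Future[Any]" = asyncio.Future()
    future.set_result(result)
    return future


async def demo() -> None:
    store = Mock()
    store.get_events = Mock(return_value=make_awaitable({"123": "event"}))
    events = await store.get_events(["123"])
    assert events == {"123": "event"}


if __name__ == "__main__":
    asyncio.run(demo())
```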