mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Convert events worker database to async/await. (#8071)
This commit is contained in:
parent
acfb7c3b5d
commit
f40645e60b
1
changelog.d/8071.misc
Normal file
1
changelog.d/8071.misc
Normal file
@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
@ -47,7 +47,7 @@ def check(
|
||||
Args:
|
||||
room_version_obj: the version of the room
|
||||
event: the event being checked.
|
||||
auth_events (dict: event-key -> event): the existing room state.
|
||||
auth_events: the existing room state.
|
||||
|
||||
Raises:
|
||||
AuthError if the checks fail
|
||||
|
@ -1777,9 +1777,7 @@ class FederationHandler(BaseHandler):
|
||||
"""Returns the state at the event. i.e. not including said event.
|
||||
"""
|
||||
|
||||
event = await self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id
|
||||
)
|
||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||
|
||||
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
|
||||
|
||||
@ -1805,9 +1803,7 @@ class FederationHandler(BaseHandler):
|
||||
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
|
||||
"""Returns the state at the event. i.e. not including said event.
|
||||
"""
|
||||
event = await self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id
|
||||
)
|
||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||
|
||||
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
|
||||
|
||||
@ -2155,9 +2151,9 @@ class FederationHandler(BaseHandler):
|
||||
auth_types = auth_types_for_event(event)
|
||||
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
|
||||
|
||||
current_auth_events = await self.store.get_events(current_state_ids)
|
||||
auth_events_map = await self.store.get_events(current_state_ids)
|
||||
current_auth_events = {
|
||||
(e.type, e.state_key): e for e in current_auth_events.values()
|
||||
(e.type, e.state_key): e for e in auth_events_map.values()
|
||||
}
|
||||
|
||||
try:
|
||||
@ -2173,9 +2169,7 @@ class FederationHandler(BaseHandler):
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
event = await self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id
|
||||
)
|
||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||
|
||||
# Just go through and process each event in `remote_auth_chain`. We
|
||||
# don't want to fall into the trap of `missing` being wrong.
|
||||
|
@ -960,7 +960,7 @@ class EventCreationHandler(object):
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
is_admin_redaction = (
|
||||
is_admin_redaction = bool(
|
||||
original_event and event.sender != original_event.sender
|
||||
)
|
||||
|
||||
@ -1080,8 +1080,8 @@ class EventCreationHandler(object):
|
||||
auth_events_ids = self.auth.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=True
|
||||
)
|
||||
auth_events = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
||||
auth_events_map = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
|
||||
|
||||
room_version = await self.store.get_room_version_id(event.room_id)
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
|
@ -716,7 +716,7 @@ class RoomMemberHandler(object):
|
||||
|
||||
guest_access = await self.store.get_event(guest_access_id)
|
||||
|
||||
return (
|
||||
return bool(
|
||||
guest_access
|
||||
and guest_access.content
|
||||
and "guest_access" in guest_access.content
|
||||
|
@ -51,5 +51,5 @@ class SpamCheckerApi(object):
|
||||
state_ids = yield self._store.get_filtered_current_state_ids(
|
||||
room_id=room_id, state_filter=StateFilter.from_types(types)
|
||||
)
|
||||
state = yield self._store.get_events(state_ids.values())
|
||||
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
|
||||
return state.values()
|
||||
|
@ -641,7 +641,7 @@ class StateResolutionStore(object):
|
||||
allow_rejected (bool): If True return rejected events.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
|
||||
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
|
||||
"""
|
||||
|
||||
return self.store.get_events(
|
||||
|
@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
|
||||
def get_auth_chain(self, event_ids, include_given=False):
|
||||
async def get_auth_chain(self, event_ids, include_given=False):
|
||||
"""Get auth events for given event_ids. The events *must* be state events.
|
||||
|
||||
Args:
|
||||
@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
Returns:
|
||||
list of events
|
||||
"""
|
||||
return self.get_auth_chain_ids(
|
||||
event_ids = await self.get_auth_chain_ids(
|
||||
event_ids, include_given=include_given
|
||||
).addCallback(self.get_events_as_list)
|
||||
)
|
||||
return await self.get_events_as_list(event_ids)
|
||||
|
||||
def get_auth_chain_ids(
|
||||
self,
|
||||
@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
|
||||
)
|
||||
|
||||
def get_backfill_events(self, room_id, event_list, limit):
|
||||
async def get_backfill_events(self, room_id, event_list, limit):
|
||||
"""Get a list of Events for a given topic that occurred before (and
|
||||
including) the events in event_list. Return a list of max size `limit`
|
||||
|
||||
@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
event_list (list)
|
||||
limit (int)
|
||||
"""
|
||||
return (
|
||||
self.db_pool.runInteraction(
|
||||
"get_backfill_events",
|
||||
self._get_backfill_events,
|
||||
room_id,
|
||||
event_list,
|
||||
limit,
|
||||
)
|
||||
.addCallback(self.get_events_as_list)
|
||||
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
|
||||
event_ids = await self.db_pool.runInteraction(
|
||||
"get_backfill_events",
|
||||
self._get_backfill_events,
|
||||
room_id,
|
||||
event_list,
|
||||
limit,
|
||||
)
|
||||
events = await self.get_events_as_list(event_ids)
|
||||
return sorted(events, key=lambda e: -e.depth)
|
||||
|
||||
def _get_backfill_events(self, txn, room_id, event_list, limit):
|
||||
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
|
||||
@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
latest_events,
|
||||
limit,
|
||||
)
|
||||
events = await self.get_events_as_list(ids)
|
||||
return events
|
||||
return await self.get_events_as_list(ids)
|
||||
|
||||
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
|
||||
|
||||
|
@ -19,9 +19,10 @@ import itertools
|
||||
import logging
|
||||
import threading
|
||||
from collections import namedtuple
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, overload
|
||||
|
||||
from constantly import NamedConstant, Names
|
||||
from typing_extensions import Literal
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
@ -32,7 +33,7 @@ from synapse.api.room_versions import (
|
||||
EventFormatVersions,
|
||||
RoomVersions,
|
||||
)
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.logging.context import PreserveLoggingContext, current_context
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
|
||||
from synapse.types import Collection, get_domain_from_id
|
||||
from synapse.util.caches.descriptors import Cache, cached
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
desc="get_received_ts",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(
|
||||
# Inform mypy that if allow_none is False (the default) then get_event
|
||||
# always returns an EventBase.
|
||||
@overload
|
||||
async def get_event(
|
||||
self,
|
||||
event_id: str,
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
get_prev_content: bool = False,
|
||||
allow_rejected: bool = False,
|
||||
allow_none: Literal[False] = False,
|
||||
check_room_id: Optional[str] = None,
|
||||
) -> EventBase:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get_event(
|
||||
self,
|
||||
event_id: str,
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
get_prev_content: bool = False,
|
||||
allow_rejected: bool = False,
|
||||
allow_none: Literal[True] = False,
|
||||
check_room_id: Optional[str] = None,
|
||||
) -> Optional[EventBase]:
|
||||
...
|
||||
|
||||
async def get_event(
|
||||
self,
|
||||
event_id: str,
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
allow_rejected: bool = False,
|
||||
allow_none: bool = False,
|
||||
check_room_id: Optional[str] = None,
|
||||
):
|
||||
) -> Optional[EventBase]:
|
||||
"""Get an event from the database by event_id.
|
||||
|
||||
Args:
|
||||
@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
If there is a mismatch, behave as per allow_none.
|
||||
|
||||
Returns:
|
||||
Deferred[EventBase|None]
|
||||
The event, or None if the event was not found.
|
||||
"""
|
||||
if not isinstance(event_id, str):
|
||||
raise TypeError("Invalid event event_id %r" % (event_id,))
|
||||
|
||||
events = yield self.get_events_as_list(
|
||||
events = await self.get_events_as_list(
|
||||
[event_id],
|
||||
redact_behaviour=redact_behaviour,
|
||||
get_prev_content=get_prev_content,
|
||||
@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return event
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events(
|
||||
async def get_events(
|
||||
self,
|
||||
event_ids: List[str],
|
||||
event_ids: Iterable[str],
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
get_prev_content: bool = False,
|
||||
allow_rejected: bool = False,
|
||||
):
|
||||
) -> Dict[str, EventBase]:
|
||||
"""Get events from the database
|
||||
|
||||
Args:
|
||||
@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
omits rejeted events from the response.
|
||||
|
||||
Returns:
|
||||
Deferred : Dict from event_id to event.
|
||||
A mapping from event_id to event.
|
||||
"""
|
||||
events = yield self.get_events_as_list(
|
||||
events = await self.get_events_as_list(
|
||||
event_ids,
|
||||
redact_behaviour=redact_behaviour,
|
||||
get_prev_content=get_prev_content,
|
||||
@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return {e.event_id: e for e in events}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events_as_list(
|
||||
async def get_events_as_list(
|
||||
self,
|
||||
event_ids: List[str],
|
||||
event_ids: Collection[str],
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
get_prev_content: bool = False,
|
||||
allow_rejected: bool = False,
|
||||
):
|
||||
) -> List[EventBase]:
|
||||
"""Get events from the database and return in a list in the same order
|
||||
as given by `event_ids` arg.
|
||||
|
||||
@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
omits rejected events from the response.
|
||||
|
||||
Returns:
|
||||
Deferred[list[EventBase]]: List of events fetched from the database. The
|
||||
events are in the same order as `event_ids` arg.
|
||||
List of events fetched from the database. The events are in the same
|
||||
order as `event_ids` arg.
|
||||
|
||||
Note that the returned list may be smaller than the list of event
|
||||
IDs if not all events could be fetched.
|
||||
@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
return []
|
||||
|
||||
# there may be duplicates so we cast the list to a set
|
||||
event_entry_map = yield self._get_events_from_cache_or_db(
|
||||
event_entry_map = await self._get_events_from_cache_or_db(
|
||||
set(event_ids), allow_rejected=allow_rejected
|
||||
)
|
||||
|
||||
@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
continue
|
||||
|
||||
redacted_event_id = entry.event.redacts
|
||||
event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
|
||||
event_map = await self._get_events_from_cache_or_db([redacted_event_id])
|
||||
original_event_entry = event_map.get(redacted_event_id)
|
||||
if not original_event_entry:
|
||||
# we don't have the redacted event (or it was rejected).
|
||||
@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
if get_prev_content:
|
||||
if "replaces_state" in event.unsigned:
|
||||
prev = yield self.get_event(
|
||||
prev = await self.get_event(
|
||||
event.unsigned["replaces_state"],
|
||||
get_prev_content=False,
|
||||
allow_none=True,
|
||||
@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return events
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
|
||||
async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
|
||||
"""Fetch a bunch of events from the cache or the database.
|
||||
|
||||
If events are pulled from the database, they will be cached for future lookups.
|
||||
@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
rejected events are omitted from the response.
|
||||
|
||||
Returns:
|
||||
Deferred[Dict[str, _EventCacheEntry]]:
|
||||
Dict[str, _EventCacheEntry]:
|
||||
map from event id to result
|
||||
"""
|
||||
event_entry_map = self._get_events_from_cache(
|
||||
@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
# the events have been redacted, and if so pulling the redaction event out
|
||||
# of the database to check it.
|
||||
#
|
||||
missing_events = yield self._get_events_from_db(
|
||||
missing_events = await self._get_events_from_db(
|
||||
missing_events_ids, allow_rejected=allow_rejected
|
||||
)
|
||||
|
||||
@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
with PreserveLoggingContext():
|
||||
self.hs.get_reactor().callFromThread(fire, event_list, e)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_events_from_db(self, event_ids, allow_rejected=False):
|
||||
async def _get_events_from_db(self, event_ids, allow_rejected=False):
|
||||
"""Fetch a bunch of events from the database.
|
||||
|
||||
Returned events will be added to the cache for future lookups.
|
||||
@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
rejected events are omitted from the response.
|
||||
|
||||
Returns:
|
||||
Deferred[Dict[str, _EventCacheEntry]]:
|
||||
Dict[str, _EventCacheEntry]:
|
||||
map from event id to result. May return extra events which
|
||||
weren't asked for.
|
||||
"""
|
||||
@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
events_to_fetch = event_ids
|
||||
|
||||
while events_to_fetch:
|
||||
row_map = yield self._enqueue_events(events_to_fetch)
|
||||
row_map = await self._enqueue_events(events_to_fetch)
|
||||
|
||||
# we need to recursively fetch any redactions of those events
|
||||
redaction_ids = set()
|
||||
@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return result_map
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _enqueue_events(self, events):
|
||||
async def _enqueue_events(self, events):
|
||||
"""Fetches events from the database using the _event_fetch_list. This
|
||||
allows batch and bulk fetching of events - it allows us to fetch events
|
||||
without having to create a new transaction for each request for events.
|
||||
@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
events (Iterable[str]): events to be fetched.
|
||||
|
||||
Returns:
|
||||
Deferred[Dict[str, Dict]]: map from event id to row data from the database.
|
||||
Dict[str, Dict]: map from event id to row data from the database.
|
||||
May contain events that weren't requested.
|
||||
"""
|
||||
|
||||
@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
logger.debug("Loading %d events: %s", len(events), events)
|
||||
with PreserveLoggingContext():
|
||||
row_map = yield events_d
|
||||
row_map = await events_d
|
||||
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
|
||||
|
||||
return row_map
|
||||
@ -842,33 +863,29 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
# no valid redaction found for this event
|
||||
return None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def have_events_in_timeline(self, event_ids):
|
||||
async def have_events_in_timeline(self, event_ids):
|
||||
"""Given a list of event ids, check if we have already processed and
|
||||
stored them as non outliers.
|
||||
"""
|
||||
rows = yield defer.ensureDeferred(
|
||||
self.db_pool.simple_select_many_batch(
|
||||
table="events",
|
||||
retcols=("event_id",),
|
||||
column="event_id",
|
||||
iterable=list(event_ids),
|
||||
keyvalues={"outlier": False},
|
||||
desc="have_events_in_timeline",
|
||||
)
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="events",
|
||||
retcols=("event_id",),
|
||||
column="event_id",
|
||||
iterable=list(event_ids),
|
||||
keyvalues={"outlier": False},
|
||||
desc="have_events_in_timeline",
|
||||
)
|
||||
|
||||
return {r["event_id"] for r in rows}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def have_seen_events(self, event_ids):
|
||||
async def have_seen_events(self, event_ids):
|
||||
"""Given a list of event ids, check if we have already processed them.
|
||||
|
||||
Args:
|
||||
event_ids (iterable[str]):
|
||||
|
||||
Returns:
|
||||
Deferred[set[str]]: The events we have already seen.
|
||||
set[str]: The events we have already seen.
|
||||
"""
|
||||
results = set()
|
||||
|
||||
@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
# break the input up into chunks of 100
|
||||
input_iterator = iter(event_ids)
|
||||
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"have_seen_events", have_seen_events_txn, chunk
|
||||
)
|
||||
return results
|
||||
@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
room_id,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_complexity(self, room_id):
|
||||
async def get_room_complexity(self, room_id):
|
||||
"""
|
||||
Get a rough approximation of the complexity of the room. This is used by
|
||||
remote servers to decide whether they wish to join the room or not.
|
||||
@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
room_id (str)
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str:int]] of complexity version to complexity.
|
||||
dict[str:int] of complexity version to complexity.
|
||||
"""
|
||||
state_events = yield self.get_current_state_event_counts(room_id)
|
||||
state_events = await self.get_current_state_event_counts(room_id)
|
||||
|
||||
# Call this one "v1", so we can introduce new ones as we want to develop
|
||||
# it.
|
||||
@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
to_2, so_2 = await self.get_event_ordering(event_id2)
|
||||
return (to_1, so_1) > (to_2, so_2)
|
||||
|
||||
@cachedInlineCallbacks(max_entries=5000)
|
||||
def get_event_ordering(self, event_id):
|
||||
res = yield self.db_pool.simple_select_one(
|
||||
@cached(max_entries=5000)
|
||||
async def get_event_ordering(self, event_id):
|
||||
res = await self.db_pool.simple_select_one(
|
||||
table="events",
|
||||
retcols=["topological_ordering", "stream_ordering"],
|
||||
keyvalues={"event_id": event_id},
|
||||
|
@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
limit: int = 0,
|
||||
order: str = "DESC",
|
||||
) -> Tuple[List[EventBase], str]:
|
||||
|
||||
"""Get new room events in stream ordering since `from_key`.
|
||||
|
||||
Args:
|
||||
|
@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||
)
|
||||
self._rlsn._store.get_events = Mock(
|
||||
return_value=defer.succeed({"123": mock_event})
|
||||
return_value=make_awaitable({"123": mock_event})
|
||||
)
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
# Would be better to check the content, but once == remove blocking event
|
||||
@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||
)
|
||||
self._rlsn._store.get_events = Mock(
|
||||
return_value=defer.succeed({"123": mock_event})
|
||||
return_value=make_awaitable({"123": mock_event})
|
||||
)
|
||||
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||
)
|
||||
self._rlsn._store.get_events = Mock(
|
||||
return_value=defer.succeed({"123": mock_event})
|
||||
return_value=make_awaitable({"123": mock_event})
|
||||
)
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
|
||||
|
@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
|
||||
)
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
|
||||
@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
||||
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
|
||||
|
||||
# we aren't testing store._base stuff here, so mock this out
|
||||
self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
|
||||
self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
|
||||
|
||||
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
|
||||
yield self._insert_txn(service.id, 10, events)
|
||||
|
Loading…
Reference in New Issue
Block a user