mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-03 18:10:51 -05:00
Convert events worker database to async/await. (#8071)
This commit is contained in:
parent
acfb7c3b5d
commit
f40645e60b
1
changelog.d/8071.misc
Normal file
1
changelog.d/8071.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Convert various parts of the codebase to async/await.
|
@ -47,7 +47,7 @@ def check(
|
|||||||
Args:
|
Args:
|
||||||
room_version_obj: the version of the room
|
room_version_obj: the version of the room
|
||||||
event: the event being checked.
|
event: the event being checked.
|
||||||
auth_events (dict: event-key -> event): the existing room state.
|
auth_events: the existing room state.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if the checks fail
|
AuthError if the checks fail
|
||||||
|
@ -1777,9 +1777,7 @@ class FederationHandler(BaseHandler):
|
|||||||
"""Returns the state at the event. i.e. not including said event.
|
"""Returns the state at the event. i.e. not including said event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
event = await self.store.get_event(
|
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||||
event_id, allow_none=False, check_room_id=room_id
|
|
||||||
)
|
|
||||||
|
|
||||||
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
|
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
|
||||||
|
|
||||||
@ -1805,9 +1803,7 @@ class FederationHandler(BaseHandler):
|
|||||||
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
|
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
|
||||||
"""Returns the state at the event. i.e. not including said event.
|
"""Returns the state at the event. i.e. not including said event.
|
||||||
"""
|
"""
|
||||||
event = await self.store.get_event(
|
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||||
event_id, allow_none=False, check_room_id=room_id
|
|
||||||
)
|
|
||||||
|
|
||||||
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
|
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
|
||||||
|
|
||||||
@ -2155,9 +2151,9 @@ class FederationHandler(BaseHandler):
|
|||||||
auth_types = auth_types_for_event(event)
|
auth_types = auth_types_for_event(event)
|
||||||
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
|
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
|
||||||
|
|
||||||
current_auth_events = await self.store.get_events(current_state_ids)
|
auth_events_map = await self.store.get_events(current_state_ids)
|
||||||
current_auth_events = {
|
current_auth_events = {
|
||||||
(e.type, e.state_key): e for e in current_auth_events.values()
|
(e.type, e.state_key): e for e in auth_events_map.values()
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -2173,9 +2169,7 @@ class FederationHandler(BaseHandler):
|
|||||||
if not in_room:
|
if not in_room:
|
||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
event = await self.store.get_event(
|
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||||
event_id, allow_none=False, check_room_id=room_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Just go through and process each event in `remote_auth_chain`. We
|
# Just go through and process each event in `remote_auth_chain`. We
|
||||||
# don't want to fall into the trap of `missing` being wrong.
|
# don't want to fall into the trap of `missing` being wrong.
|
||||||
|
@ -960,7 +960,7 @@ class EventCreationHandler(object):
|
|||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
is_admin_redaction = (
|
is_admin_redaction = bool(
|
||||||
original_event and event.sender != original_event.sender
|
original_event and event.sender != original_event.sender
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1080,8 +1080,8 @@ class EventCreationHandler(object):
|
|||||||
auth_events_ids = self.auth.compute_auth_events(
|
auth_events_ids = self.auth.compute_auth_events(
|
||||||
event, prev_state_ids, for_verification=True
|
event, prev_state_ids, for_verification=True
|
||||||
)
|
)
|
||||||
auth_events = await self.store.get_events(auth_events_ids)
|
auth_events_map = await self.store.get_events(auth_events_ids)
|
||||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
|
||||||
|
|
||||||
room_version = await self.store.get_room_version_id(event.room_id)
|
room_version = await self.store.get_room_version_id(event.room_id)
|
||||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||||
|
@ -716,7 +716,7 @@ class RoomMemberHandler(object):
|
|||||||
|
|
||||||
guest_access = await self.store.get_event(guest_access_id)
|
guest_access = await self.store.get_event(guest_access_id)
|
||||||
|
|
||||||
return (
|
return bool(
|
||||||
guest_access
|
guest_access
|
||||||
and guest_access.content
|
and guest_access.content
|
||||||
and "guest_access" in guest_access.content
|
and "guest_access" in guest_access.content
|
||||||
|
@ -51,5 +51,5 @@ class SpamCheckerApi(object):
|
|||||||
state_ids = yield self._store.get_filtered_current_state_ids(
|
state_ids = yield self._store.get_filtered_current_state_ids(
|
||||||
room_id=room_id, state_filter=StateFilter.from_types(types)
|
room_id=room_id, state_filter=StateFilter.from_types(types)
|
||||||
)
|
)
|
||||||
state = yield self._store.get_events(state_ids.values())
|
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
|
||||||
return state.values()
|
return state.values()
|
||||||
|
@ -641,7 +641,7 @@ class StateResolutionStore(object):
|
|||||||
allow_rejected (bool): If True return rejected events.
|
allow_rejected (bool): If True return rejected events.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
|
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.store.get_events(
|
return self.store.get_events(
|
||||||
|
@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
|
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
|
||||||
def get_auth_chain(self, event_ids, include_given=False):
|
async def get_auth_chain(self, event_ids, include_given=False):
|
||||||
"""Get auth events for given event_ids. The events *must* be state events.
|
"""Get auth events for given event_ids. The events *must* be state events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||||||
Returns:
|
Returns:
|
||||||
list of events
|
list of events
|
||||||
"""
|
"""
|
||||||
return self.get_auth_chain_ids(
|
event_ids = await self.get_auth_chain_ids(
|
||||||
event_ids, include_given=include_given
|
event_ids, include_given=include_given
|
||||||
).addCallback(self.get_events_as_list)
|
)
|
||||||
|
return await self.get_events_as_list(event_ids)
|
||||||
|
|
||||||
def get_auth_chain_ids(
|
def get_auth_chain_ids(
|
||||||
self,
|
self,
|
||||||
@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||||||
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
|
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_backfill_events(self, room_id, event_list, limit):
|
async def get_backfill_events(self, room_id, event_list, limit):
|
||||||
"""Get a list of Events for a given topic that occurred before (and
|
"""Get a list of Events for a given topic that occurred before (and
|
||||||
including) the events in event_list. Return a list of max size `limit`
|
including) the events in event_list. Return a list of max size `limit`
|
||||||
|
|
||||||
@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||||||
event_list (list)
|
event_list (list)
|
||||||
limit (int)
|
limit (int)
|
||||||
"""
|
"""
|
||||||
return (
|
event_ids = await self.db_pool.runInteraction(
|
||||||
self.db_pool.runInteraction(
|
|
||||||
"get_backfill_events",
|
"get_backfill_events",
|
||||||
self._get_backfill_events,
|
self._get_backfill_events,
|
||||||
room_id,
|
room_id,
|
||||||
event_list,
|
event_list,
|
||||||
limit,
|
limit,
|
||||||
)
|
)
|
||||||
.addCallback(self.get_events_as_list)
|
events = await self.get_events_as_list(event_ids)
|
||||||
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
|
return sorted(events, key=lambda e: -e.depth)
|
||||||
)
|
|
||||||
|
|
||||||
def _get_backfill_events(self, txn, room_id, event_list, limit):
|
def _get_backfill_events(self, txn, room_id, event_list, limit):
|
||||||
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
|
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
|
||||||
@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
|||||||
latest_events,
|
latest_events,
|
||||||
limit,
|
limit,
|
||||||
)
|
)
|
||||||
events = await self.get_events_as_list(ids)
|
return await self.get_events_as_list(ids)
|
||||||
return events
|
|
||||||
|
|
||||||
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
|
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
|
||||||
|
|
||||||
|
@ -19,9 +19,10 @@ import itertools
|
|||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import List, Optional, Tuple
|
from typing import Dict, Iterable, List, Optional, Tuple, overload
|
||||||
|
|
||||||
from constantly import NamedConstant, Names
|
from constantly import NamedConstant, Names
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -32,7 +33,7 @@ from synapse.api.room_versions import (
|
|||||||
EventFormatVersions,
|
EventFormatVersions,
|
||||||
RoomVersions,
|
RoomVersions,
|
||||||
)
|
)
|
||||||
from synapse.events import make_event_from_dict
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
from synapse.logging.context import PreserveLoggingContext, current_context
|
from synapse.logging.context import PreserveLoggingContext, current_context
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
|
|||||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import Collection, get_domain_from_id
|
||||||
from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
|
from synapse.util.caches.descriptors import Cache, cached
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
desc="get_received_ts",
|
desc="get_received_ts",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
# Inform mypy that if allow_none is False (the default) then get_event
|
||||||
def get_event(
|
# always returns an EventBase.
|
||||||
|
@overload
|
||||||
|
async def get_event(
|
||||||
|
self,
|
||||||
|
event_id: str,
|
||||||
|
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||||
|
get_prev_content: bool = False,
|
||||||
|
allow_rejected: bool = False,
|
||||||
|
allow_none: Literal[False] = False,
|
||||||
|
check_room_id: Optional[str] = None,
|
||||||
|
) -> EventBase:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
async def get_event(
|
||||||
|
self,
|
||||||
|
event_id: str,
|
||||||
|
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||||
|
get_prev_content: bool = False,
|
||||||
|
allow_rejected: bool = False,
|
||||||
|
allow_none: Literal[True] = False,
|
||||||
|
check_room_id: Optional[str] = None,
|
||||||
|
) -> Optional[EventBase]:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_event(
|
||||||
self,
|
self,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||||
@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
allow_rejected: bool = False,
|
allow_rejected: bool = False,
|
||||||
allow_none: bool = False,
|
allow_none: bool = False,
|
||||||
check_room_id: Optional[str] = None,
|
check_room_id: Optional[str] = None,
|
||||||
):
|
) -> Optional[EventBase]:
|
||||||
"""Get an event from the database by event_id.
|
"""Get an event from the database by event_id.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
If there is a mismatch, behave as per allow_none.
|
If there is a mismatch, behave as per allow_none.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[EventBase|None]
|
The event, or None if the event was not found.
|
||||||
"""
|
"""
|
||||||
if not isinstance(event_id, str):
|
if not isinstance(event_id, str):
|
||||||
raise TypeError("Invalid event event_id %r" % (event_id,))
|
raise TypeError("Invalid event event_id %r" % (event_id,))
|
||||||
|
|
||||||
events = yield self.get_events_as_list(
|
events = await self.get_events_as_list(
|
||||||
[event_id],
|
[event_id],
|
||||||
redact_behaviour=redact_behaviour,
|
redact_behaviour=redact_behaviour,
|
||||||
get_prev_content=get_prev_content,
|
get_prev_content=get_prev_content,
|
||||||
@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_events(
|
||||||
def get_events(
|
|
||||||
self,
|
self,
|
||||||
event_ids: List[str],
|
event_ids: Iterable[str],
|
||||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||||
get_prev_content: bool = False,
|
get_prev_content: bool = False,
|
||||||
allow_rejected: bool = False,
|
allow_rejected: bool = False,
|
||||||
):
|
) -> Dict[str, EventBase]:
|
||||||
"""Get events from the database
|
"""Get events from the database
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
omits rejeted events from the response.
|
omits rejeted events from the response.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred : Dict from event_id to event.
|
A mapping from event_id to event.
|
||||||
"""
|
"""
|
||||||
events = yield self.get_events_as_list(
|
events = await self.get_events_as_list(
|
||||||
event_ids,
|
event_ids,
|
||||||
redact_behaviour=redact_behaviour,
|
redact_behaviour=redact_behaviour,
|
||||||
get_prev_content=get_prev_content,
|
get_prev_content=get_prev_content,
|
||||||
@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return {e.event_id: e for e in events}
|
return {e.event_id: e for e in events}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_events_as_list(
|
||||||
def get_events_as_list(
|
|
||||||
self,
|
self,
|
||||||
event_ids: List[str],
|
event_ids: Collection[str],
|
||||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||||
get_prev_content: bool = False,
|
get_prev_content: bool = False,
|
||||||
allow_rejected: bool = False,
|
allow_rejected: bool = False,
|
||||||
):
|
) -> List[EventBase]:
|
||||||
"""Get events from the database and return in a list in the same order
|
"""Get events from the database and return in a list in the same order
|
||||||
as given by `event_ids` arg.
|
as given by `event_ids` arg.
|
||||||
|
|
||||||
@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
omits rejected events from the response.
|
omits rejected events from the response.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[list[EventBase]]: List of events fetched from the database. The
|
List of events fetched from the database. The events are in the same
|
||||||
events are in the same order as `event_ids` arg.
|
order as `event_ids` arg.
|
||||||
|
|
||||||
Note that the returned list may be smaller than the list of event
|
Note that the returned list may be smaller than the list of event
|
||||||
IDs if not all events could be fetched.
|
IDs if not all events could be fetched.
|
||||||
@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# there may be duplicates so we cast the list to a set
|
# there may be duplicates so we cast the list to a set
|
||||||
event_entry_map = yield self._get_events_from_cache_or_db(
|
event_entry_map = await self._get_events_from_cache_or_db(
|
||||||
set(event_ids), allow_rejected=allow_rejected
|
set(event_ids), allow_rejected=allow_rejected
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
redacted_event_id = entry.event.redacts
|
redacted_event_id = entry.event.redacts
|
||||||
event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
|
event_map = await self._get_events_from_cache_or_db([redacted_event_id])
|
||||||
original_event_entry = event_map.get(redacted_event_id)
|
original_event_entry = event_map.get(redacted_event_id)
|
||||||
if not original_event_entry:
|
if not original_event_entry:
|
||||||
# we don't have the redacted event (or it was rejected).
|
# we don't have the redacted event (or it was rejected).
|
||||||
@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
if get_prev_content:
|
if get_prev_content:
|
||||||
if "replaces_state" in event.unsigned:
|
if "replaces_state" in event.unsigned:
|
||||||
prev = yield self.get_event(
|
prev = await self.get_event(
|
||||||
event.unsigned["replaces_state"],
|
event.unsigned["replaces_state"],
|
||||||
get_prev_content=False,
|
get_prev_content=False,
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return events
|
return events
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
|
||||||
def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
|
|
||||||
"""Fetch a bunch of events from the cache or the database.
|
"""Fetch a bunch of events from the cache or the database.
|
||||||
|
|
||||||
If events are pulled from the database, they will be cached for future lookups.
|
If events are pulled from the database, they will be cached for future lookups.
|
||||||
@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
rejected events are omitted from the response.
|
rejected events are omitted from the response.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[Dict[str, _EventCacheEntry]]:
|
Dict[str, _EventCacheEntry]:
|
||||||
map from event id to result
|
map from event id to result
|
||||||
"""
|
"""
|
||||||
event_entry_map = self._get_events_from_cache(
|
event_entry_map = self._get_events_from_cache(
|
||||||
@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
# the events have been redacted, and if so pulling the redaction event out
|
# the events have been redacted, and if so pulling the redaction event out
|
||||||
# of the database to check it.
|
# of the database to check it.
|
||||||
#
|
#
|
||||||
missing_events = yield self._get_events_from_db(
|
missing_events = await self._get_events_from_db(
|
||||||
missing_events_ids, allow_rejected=allow_rejected
|
missing_events_ids, allow_rejected=allow_rejected
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
self.hs.get_reactor().callFromThread(fire, event_list, e)
|
self.hs.get_reactor().callFromThread(fire, event_list, e)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _get_events_from_db(self, event_ids, allow_rejected=False):
|
||||||
def _get_events_from_db(self, event_ids, allow_rejected=False):
|
|
||||||
"""Fetch a bunch of events from the database.
|
"""Fetch a bunch of events from the database.
|
||||||
|
|
||||||
Returned events will be added to the cache for future lookups.
|
Returned events will be added to the cache for future lookups.
|
||||||
@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
rejected events are omitted from the response.
|
rejected events are omitted from the response.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[Dict[str, _EventCacheEntry]]:
|
Dict[str, _EventCacheEntry]:
|
||||||
map from event id to result. May return extra events which
|
map from event id to result. May return extra events which
|
||||||
weren't asked for.
|
weren't asked for.
|
||||||
"""
|
"""
|
||||||
@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
events_to_fetch = event_ids
|
events_to_fetch = event_ids
|
||||||
|
|
||||||
while events_to_fetch:
|
while events_to_fetch:
|
||||||
row_map = yield self._enqueue_events(events_to_fetch)
|
row_map = await self._enqueue_events(events_to_fetch)
|
||||||
|
|
||||||
# we need to recursively fetch any redactions of those events
|
# we need to recursively fetch any redactions of those events
|
||||||
redaction_ids = set()
|
redaction_ids = set()
|
||||||
@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
return result_map
|
return result_map
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _enqueue_events(self, events):
|
||||||
def _enqueue_events(self, events):
|
|
||||||
"""Fetches events from the database using the _event_fetch_list. This
|
"""Fetches events from the database using the _event_fetch_list. This
|
||||||
allows batch and bulk fetching of events - it allows us to fetch events
|
allows batch and bulk fetching of events - it allows us to fetch events
|
||||||
without having to create a new transaction for each request for events.
|
without having to create a new transaction for each request for events.
|
||||||
@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
events (Iterable[str]): events to be fetched.
|
events (Iterable[str]): events to be fetched.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[Dict[str, Dict]]: map from event id to row data from the database.
|
Dict[str, Dict]: map from event id to row data from the database.
|
||||||
May contain events that weren't requested.
|
May contain events that weren't requested.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
logger.debug("Loading %d events: %s", len(events), events)
|
logger.debug("Loading %d events: %s", len(events), events)
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
row_map = yield events_d
|
row_map = await events_d
|
||||||
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
|
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
|
||||||
|
|
||||||
return row_map
|
return row_map
|
||||||
@ -842,13 +863,11 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
# no valid redaction found for this event
|
# no valid redaction found for this event
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def have_events_in_timeline(self, event_ids):
|
||||||
def have_events_in_timeline(self, event_ids):
|
|
||||||
"""Given a list of event ids, check if we have already processed and
|
"""Given a list of event ids, check if we have already processed and
|
||||||
stored them as non outliers.
|
stored them as non outliers.
|
||||||
"""
|
"""
|
||||||
rows = yield defer.ensureDeferred(
|
rows = await self.db_pool.simple_select_many_batch(
|
||||||
self.db_pool.simple_select_many_batch(
|
|
||||||
table="events",
|
table="events",
|
||||||
retcols=("event_id",),
|
retcols=("event_id",),
|
||||||
column="event_id",
|
column="event_id",
|
||||||
@ -856,19 +875,17 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
keyvalues={"outlier": False},
|
keyvalues={"outlier": False},
|
||||||
desc="have_events_in_timeline",
|
desc="have_events_in_timeline",
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return {r["event_id"] for r in rows}
|
return {r["event_id"] for r in rows}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def have_seen_events(self, event_ids):
|
||||||
def have_seen_events(self, event_ids):
|
|
||||||
"""Given a list of event ids, check if we have already processed them.
|
"""Given a list of event ids, check if we have already processed them.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_ids (iterable[str]):
|
event_ids (iterable[str]):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[set[str]]: The events we have already seen.
|
set[str]: The events we have already seen.
|
||||||
"""
|
"""
|
||||||
results = set()
|
results = set()
|
||||||
|
|
||||||
@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
# break the input up into chunks of 100
|
# break the input up into chunks of 100
|
||||||
input_iterator = iter(event_ids)
|
input_iterator = iter(event_ids)
|
||||||
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
|
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
|
||||||
yield self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"have_seen_events", have_seen_events_txn, chunk
|
"have_seen_events", have_seen_events_txn, chunk
|
||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
room_id,
|
room_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_room_complexity(self, room_id):
|
||||||
def get_room_complexity(self, room_id):
|
|
||||||
"""
|
"""
|
||||||
Get a rough approximation of the complexity of the room. This is used by
|
Get a rough approximation of the complexity of the room. This is used by
|
||||||
remote servers to decide whether they wish to join the room or not.
|
remote servers to decide whether they wish to join the room or not.
|
||||||
@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
room_id (str)
|
room_id (str)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[str:int]] of complexity version to complexity.
|
dict[str:int] of complexity version to complexity.
|
||||||
"""
|
"""
|
||||||
state_events = yield self.get_current_state_event_counts(room_id)
|
state_events = await self.get_current_state_event_counts(room_id)
|
||||||
|
|
||||||
# Call this one "v1", so we can introduce new ones as we want to develop
|
# Call this one "v1", so we can introduce new ones as we want to develop
|
||||||
# it.
|
# it.
|
||||||
@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore):
|
|||||||
to_2, so_2 = await self.get_event_ordering(event_id2)
|
to_2, so_2 = await self.get_event_ordering(event_id2)
|
||||||
return (to_1, so_1) > (to_2, so_2)
|
return (to_1, so_1) > (to_2, so_2)
|
||||||
|
|
||||||
@cachedInlineCallbacks(max_entries=5000)
|
@cached(max_entries=5000)
|
||||||
def get_event_ordering(self, event_id):
|
async def get_event_ordering(self, event_id):
|
||||||
res = yield self.db_pool.simple_select_one(
|
res = await self.db_pool.simple_select_one(
|
||||||
table="events",
|
table="events",
|
||||||
retcols=["topological_ordering", "stream_ordering"],
|
retcols=["topological_ordering", "stream_ordering"],
|
||||||
keyvalues={"event_id": event_id},
|
keyvalues={"event_id": event_id},
|
||||||
|
@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
|||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
order: str = "DESC",
|
order: str = "DESC",
|
||||||
) -> Tuple[List[EventBase], str]:
|
) -> Tuple[List[EventBase], str]:
|
||||||
|
|
||||||
"""Get new room events in stream ordering since `from_key`.
|
"""Get new room events in stream ordering since `from_key`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||||
)
|
)
|
||||||
self._rlsn._store.get_events = Mock(
|
self._rlsn._store.get_events = Mock(
|
||||||
return_value=defer.succeed({"123": mock_event})
|
return_value=make_awaitable({"123": mock_event})
|
||||||
)
|
)
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
# Would be better to check the content, but once == remove blocking event
|
# Would be better to check the content, but once == remove blocking event
|
||||||
@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||||
)
|
)
|
||||||
self._rlsn._store.get_events = Mock(
|
self._rlsn._store.get_events = Mock(
|
||||||
return_value=defer.succeed({"123": mock_event})
|
return_value=make_awaitable({"123": mock_event})
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
|||||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||||
)
|
)
|
||||||
self._rlsn._store.get_events = Mock(
|
self._rlsn._store.get_events = Mock(
|
||||||
return_value=defer.succeed({"123": mock_event})
|
return_value=make_awaitable({"123": mock_event})
|
||||||
)
|
)
|
||||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
from tests.test_utils import make_awaitable
|
||||||
from tests.utils import setup_test_homeserver
|
from tests.utils import setup_test_homeserver
|
||||||
|
|
||||||
|
|
||||||
@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
|||||||
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
|
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
|
||||||
|
|
||||||
# we aren't testing store._base stuff here, so mock this out
|
# we aren't testing store._base stuff here, so mock this out
|
||||||
self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
|
self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
|
||||||
|
|
||||||
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
|
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
|
||||||
yield self._insert_txn(service.id, 10, events)
|
yield self._insert_txn(service.id, 10, events)
|
||||||
|
Loading…
Reference in New Issue
Block a user