Convert events worker database to async/await. (#8071)

This commit is contained in:
Patrick Cloke 2020-08-18 16:20:49 -04:00 committed by GitHub
parent acfb7c3b5d
commit f40645e60b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 106 additions and 97 deletions

View file

@ -19,9 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
from typing import List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple, overload
from constantly import NamedConstant, Names
from typing_extensions import Literal
from twisted.internet import defer
@ -32,7 +33,7 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersions,
)
from synapse.events import make_event_from_dict
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_received_ts",
)
@defer.inlineCallbacks
def get_event(
# Inform mypy that if allow_none is False (the default) then get_event
# always returns an EventBase.
@overload
async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: Literal[False] = False,
check_room_id: Optional[str] = None,
) -> EventBase:
...
@overload
async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: Literal[True] = False,
check_room_id: Optional[str] = None,
) -> Optional[EventBase]:
...
async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected: bool = False,
allow_none: bool = False,
check_room_id: Optional[str] = None,
):
) -> Optional[EventBase]:
"""Get an event from the database by event_id.
Args:
@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
Deferred[EventBase|None]
The event, or None if the event was not found.
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))
events = yield self.get_events_as_list(
events = await self.get_events_as_list(
[event_id],
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
return event
@defer.inlineCallbacks
def get_events(
async def get_events(
self,
event_ids: List[str],
event_ids: Iterable[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
):
) -> Dict[str, EventBase]:
"""Get events from the database
Args:
@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
omits rejeted events from the response.
Returns:
Deferred : Dict from event_id to event.
A mapping from event_id to event.
"""
events = yield self.get_events_as_list(
events = await self.get_events_as_list(
event_ids,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
@defer.inlineCallbacks
def get_events_as_list(
async def get_events_as_list(
self,
event_ids: List[str],
event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
):
) -> List[EventBase]:
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
Deferred[list[EventBase]]: List of events fetched from the database. The
events are in the same order as `event_ids` arg.
List of events fetched from the database. The events are in the same
order as `event_ids` arg.
Note that the returned list may be smaller than the list of event
IDs if not all events could be fetched.
@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
event_entry_map = yield self._get_events_from_cache_or_db(
event_entry_map = await self._get_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
event_map = await self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
if get_prev_content:
if "replaces_state" in event.unsigned:
prev = yield self.get_event(
prev = await self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
return events
@defer.inlineCallbacks
def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
Deferred[Dict[str, _EventCacheEntry]]:
Dict[str, _EventCacheEntry]:
map from event id to result
"""
event_entry_map = self._get_events_from_cache(
@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
missing_events = yield self._get_events_from_db(
missing_events = await self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)
@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
def _get_events_from_db(self, event_ids, allow_rejected=False):
async def _get_events_from_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the database.
Returned events will be added to the cache for future lookups.
@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
Deferred[Dict[str, _EventCacheEntry]]:
Dict[str, _EventCacheEntry]:
map from event id to result. May return extra events which
weren't asked for.
"""
@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
events_to_fetch = event_ids
while events_to_fetch:
row_map = yield self._enqueue_events(events_to_fetch)
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
@defer.inlineCallbacks
def _enqueue_events(self, events):
async def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore):
events (Iterable[str]): events to be fetched.
Returns:
Deferred[Dict[str, Dict]]: map from event id to row data from the database.
Dict[str, Dict]: map from event id to row data from the database.
May contain events that weren't requested.
"""
@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore):
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
row_map = yield events_d
row_map = await events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
return row_map
@ -842,33 +863,29 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
@defer.inlineCallbacks
def have_events_in_timeline(self, event_ids):
async def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield defer.ensureDeferred(
self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
rows = await self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
return {r["event_id"] for r in rows}
@defer.inlineCallbacks
def have_seen_events(self, event_ids):
async def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
Deferred[set[str]]: The events we have already seen.
set[str]: The events we have already seen.
"""
results = set()
@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
@defer.inlineCallbacks
def get_room_complexity(self, room_id):
async def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore):
room_id (str)
Returns:
Deferred[dict[str:int]] of complexity version to complexity.
dict[str:int] of complexity version to complexity.
"""
state_events = yield self.get_current_state_event_counts(room_id)
state_events = await self.get_current_state_event_counts(room_id)
# Call this one "v1", so we can introduce new ones as we want to develop
# it.
@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore):
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cachedInlineCallbacks(max_entries=5000)
def get_event_ordering(self, event_id):
res = yield self.db_pool.simple_select_one(
@cached(max_entries=5000)
async def get_event_ordering(self, event_id):
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},