Convert state and stream stores and related code to async (#8194)

Patrick Cloke 2020-08-28 09:37:55 -04:00 committed by GitHub
parent b055dc9322
commit aec7085179
7 changed files with 51 additions and 45 deletions

changelog.d/8194.misc (new file)

@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
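
Every hunk below follows the same shape: a method that used to hand back the Deferred produced by db_pool.runInteraction becomes an async def that awaits it, and the Deferred[...] wrappers disappear from docstrings and return types. A minimal before/after sketch of that pattern, using an invented get_thing method rather than code from this commit:

from typing import Optional

# Before: the caller receives whatever Deferred runInteraction produced.
def get_thing(self, thing_id: str):
    def _get_thing_txn(txn):
        txn.execute("SELECT value FROM things WHERE id = ?", (thing_id,))
        return txn.fetchone()

    return self.db_pool.runInteraction("get_thing", _get_thing_txn)

# After: the method is a coroutine, awaits the interaction itself, and can
# carry an ordinary return annotation instead of Deferred[...].
async def get_thing(self, thing_id: str) -> Optional[str]:
    def _get_thing_txn(txn):
        txn.execute("SELECT value FROM things WHERE id = ?", (thing_id,))
        row = txn.fetchone()
        return row[0] if row else None

    return await self.db_pool.runInteraction("get_thing", _get_thing_txn)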

synapse/handlers/room.py

@@ -451,7 +451,7 @@ class RoomCreationHandler(BaseHandler):
         old_room_member_state_events = await self.store.get_events(
             old_room_member_state_ids.values()
         )
-        for k, old_event in old_room_member_state_events.items():
+        for old_event in old_room_member_state_events.values():
             # Only transfer ban events
             if (
                 "membership" in old_event.content

synapse/storage/databases/main/state.py

@@ -27,6 +27,7 @@ from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
+from synapse.types import StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
@@ -163,15 +164,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         return create_event

     @cached(max_entries=100000, iterable=True)
-    def get_current_state_ids(self, room_id):
+    async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
         """Get the current state event ids for a room based on the
         current_state_events table.

         Args:
-            room_id (str)
+            room_id: The room to get the state IDs of.

         Returns:
-            deferred: dict of (type, state_key) -> event_id
+            The current state of the room.
         """

         def _get_current_state_ids_txn(txn):
@@ -184,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_current_state_ids", _get_current_state_ids_txn
         )

     # FIXME: how should this be cached?
-    def get_filtered_current_state_ids(
+    async def get_filtered_current_state_ids(
         self, room_id: str, state_filter: StateFilter = StateFilter.all()
-    ):
+    ) -> StateMap[str]:
         """Get the current state event of a given type for a room based on the
         current_state_events table. This may not be as up-to-date as the result
         of doing a fresh state resolution as per state_handler.get_current_state
@@ -202,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
                 from the database.

         Returns:
-            defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
+            Map from type/state_key to event ID.
         """
         where_clause, where_args = state_filter.make_sql_filter_clause()

         if not where_clause:
             # We delegate to the cached version
-            return self.get_current_state_ids(room_id)
+            return await self.get_current_state_ids(room_id)

         def _get_filtered_current_state_ids_txn(txn):
             results = {}
@@ -231,7 +232,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             return results

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
         )
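
Note that the @cached descriptor is applied directly to the new async def, so callers simply await the method and still hit the cache on repeat lookups. A hypothetical caller, not taken from this commit:

async def is_room_encrypted(store, room_id: str) -> bool:
    # StateMap[str] keys are (event type, state_key) pairs; values are event IDs.
    state_ids = await store.get_current_state_ids(room_id)
    return ("m.room.encryption", "") in state_ids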

synapse/storage/databases/main/state_deltas.py

@@ -14,8 +14,7 @@
 # limitations under the License.

 import logging
-
-from twisted.internet import defer
+from typing import Any, Dict, List, Tuple

 from synapse.storage._base import SQLBaseStore
@@ -23,7 +22,9 @@ logger = logging.getLogger(__name__)


 class StateDeltasStore(SQLBaseStore):
-    def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
+    async def get_current_state_deltas(
+        self, prev_stream_id: int, max_stream_id: int
+    ) -> Tuple[int, List[Dict[str, Any]]]:
         """Fetch a list of room state changes since the given stream id

         Each entry in the result contains the following fields:
@@ -37,12 +38,12 @@ class StateDeltasStore(SQLBaseStore):
               if it's new state.

         Args:
-            prev_stream_id (int): point to get changes since (exclusive)
-            max_stream_id (int): the point that we know has been correctly persisted
+            prev_stream_id: point to get changes since (exclusive)
+            max_stream_id: the point that we know has been correctly persisted
                - ie, an upper limit to return changes from.

         Returns:
-            Deferred[tuple[int, list[dict]]: A tuple consisting of:
+            A tuple consisting of:
                - the stream id which these results go up to
                - list of current_state_delta_stream rows. If it is empty, we are
                  up to date.
@@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore):
            # if the CSDs haven't changed between prev_stream_id and now, we
            # know for certain that they haven't changed between prev_stream_id and
            # max_stream_id.
-            return defer.succeed((max_stream_id, []))
+            return (max_stream_id, [])

         def get_current_state_deltas_txn(txn):
             # First we calculate the max stream id that will give us less than
@@ -102,7 +103,7 @@ class StateDeltasStore(SQLBaseStore):
             txn.execute(sql, (prev_stream_id, clipped_stream_id))
             return clipped_stream_id, self.db_pool.cursor_to_dict(txn)

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_current_state_deltas", get_current_state_deltas_txn
         )
@@ -114,8 +115,8 @@ class StateDeltasStore(SQLBaseStore):
             retcol="COALESCE(MAX(stream_id), -1)",
         )

-    def get_max_stream_id_in_current_state_deltas(self):
-        return self.db_pool.runInteraction(
+    async def get_max_stream_id_in_current_state_deltas(self):
+        return await self.db_pool.runInteraction(
             "get_max_stream_id_in_current_state_deltas",
             self._get_max_stream_id_in_current_state_deltas_txn,
         )
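
Because the method is now a coroutine, the fast path can return the bare (max_stream_id, []) tuple; async def wraps return values automatically, so defer.succeed is no longer needed and both paths return the same shape. A sketch of a consumer of these two methods (the helper names are made up):

async def advance_state_deltas(store, last_seen: int) -> int:
    max_pos = await store.get_max_stream_id_in_current_state_deltas()
    clipped_pos, deltas = await store.get_current_state_deltas(last_seen, max_pos)
    for delta in deltas:
        # Each delta is a current_state_delta_stream row rendered as a dict.
        handle_delta(delta)  # hypothetical processing step
    return clipped_pos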

synapse/storage/databases/main/stream.py

@@ -539,7 +539,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return rows, token

-    def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
+    async def get_room_event_before_stream_ordering(
+        self, room_id: str, stream_ordering: int
+    ) -> Tuple[int, int, str]:
         """Gets details of the first event in a room at or before a stream ordering

         Args:
@@ -547,8 +549,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             stream_ordering:

         Returns:
-            Deferred[(int, int, str)]:
-                (stream ordering, topological ordering, event_id)
+            A tuple of (stream ordering, topological ordering, event_id)
         """

         def _f(txn):
@@ -563,7 +564,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             txn.execute(sql, (room_id, stream_ordering))
             return txn.fetchone()

-        return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)
+        return await self.db_pool.runInteraction(
+            "get_room_event_before_stream_ordering", _f
+        )

     async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
         """Returns the current token for rooms stream.

synapse/storage/databases/state/store.py

@@ -17,8 +17,6 @@ import logging
 from collections import namedtuple
 from typing import Dict, Iterable, List, Set, Tuple

-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
@@ -103,7 +101,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )

     @cached(max_entries=10000, iterable=True)
-    def get_state_group_delta(self, state_group):
+    async def get_state_group_delta(self, state_group):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
@@ -135,7 +133,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
             )

-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_state_group_delta", _get_state_group_delta_txn
         )
@@ -367,9 +365,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             fetched_keys=non_member_types,
         )

-    def store_state_group(
+    async def store_state_group(
         self, event_id, room_id, prev_group, delta_ids, current_state_ids
-    ):
+    ) -> int:
         """Store a new set of state, returning a newly assigned state group.

         Args:
@@ -383,7 +381,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 to event_id.

         Returns:
-            Deferred[int]: The state group ID
+            The state group ID
         """

         def _store_state_group_txn(txn):
@@ -484,11 +482,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             return state_group

-        return self.db_pool.runInteraction("store_state_group", _store_state_group_txn)
+        return await self.db_pool.runInteraction(
+            "store_state_group", _store_state_group_txn
+        )

-    def purge_unreferenced_state_groups(
+    async def purge_unreferenced_state_groups(
         self, room_id: str, state_groups_to_delete
-    ) -> defer.Deferred:
+    ) -> None:
         """Deletes no longer referenced state groups and de-deltas any state
         groups that reference them.
@@ -499,7 +499,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
                 to delete.
         """

-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "purge_unreferenced_state_groups",
             self._purge_unreferenced_state_groups,
             room_id,
@@ -594,7 +594,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         return {row["state_group"]: row["prev_state_group"] for row in rows}

-    def purge_room_state(self, room_id, state_groups_to_delete):
+    async def purge_room_state(self, room_id, state_groups_to_delete):
         """Deletes all record of a room from state tables

         Args:
@@ -602,7 +602,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             state_groups_to_delete (list[int]): State groups to delete
         """

-        return self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "purge_room_state",
             self._purge_room_state_txn,
             room_id,
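
The purge methods now return None rather than a Deferred, so callers await them purely for their side effects. A hypothetical purge helper:

from typing import List

async def purge_room_from_state(state_store, room_id: str, groups: List[int]) -> None:
    # Awaited purely for its side effect; the method returns None on success.
    await state_store.purge_room_state(room_id, groups)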

synapse/storage/state.py

@@ -333,7 +333,7 @@ class StateGroupStorage(object):
     def __init__(self, hs, stores):
         self.stores = stores

-    def get_state_group_delta(self, state_group: int):
+    async def get_state_group_delta(self, state_group: int):
         """Given a state group try to return a previous group and a delta between
         the old and the new.
@@ -341,11 +341,11 @@ class StateGroupStorage(object):
             state_group: The state group used to retrieve state deltas.

         Returns:
-            Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
+            Tuple[Optional[int], Optional[StateMap[str]]]:
                 (prev_group, delta_ids)
         """

-        return self.stores.state.get_state_group_delta(state_group)
+        return await self.stores.state.get_state_group_delta(state_group)

     async def get_state_groups_ids(
         self, _room_id: str, event_ids: Iterable[str]
@@ -525,7 +525,7 @@ class StateGroupStorage(object):
             state_filter: The state filter used to fetch state from the database.

         Returns:
-            A deferred dict from (type, state_key) -> state_event
+            A dict from (type, state_key) -> state_event
         """
         state_map = await self.get_state_ids_for_events([event_id], state_filter)
         return state_map[event_id]
@@ -546,14 +546,14 @@ class StateGroupStorage(object):
         """
         return self.stores.state._get_state_for_groups(groups, state_filter)

-    def store_state_group(
+    async def store_state_group(
         self,
         event_id: str,
         room_id: str,
         prev_group: Optional[int],
         delta_ids: Optional[dict],
         current_state_ids: dict,
-    ):
+    ) -> int:
         """Store a new set of state, returning a newly assigned state group.

         Args:
@@ -567,8 +567,8 @@ class StateGroupStorage(object):
                 to event_id.

         Returns:
-            Deferred[int]: The state group ID
+            The state group ID
         """
-        return self.stores.state.store_state_group(
+        return await self.stores.state.store_state_group(
             event_id, room_id, prev_group, delta_ids, current_state_ids
         )
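
StateGroupStorage stays a thin wrapper: its methods become coroutines that await the underlying state store, so the Deferred wrappers drop out of the public return types as well. A hypothetical caller persisting a state group through the wrapper:

async def persist_initial_state(
    state_storage, event_id: str, room_id: str, current_state_ids: dict
) -> int:
    # No previous group to delta against, so prev_group and delta_ids are None.
    return await state_storage.store_state_group(
        event_id,
        room_id,
        prev_group=None,
        delta_ids=None,
        current_state_ids=current_state_ids,
    )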