Convert a synapse.events to async/await. (#7949)

This commit is contained in:
Patrick Cloke 2020-07-27 13:40:22 -04:00 committed by GitHub
parent 5f65e62681
commit 8553f46498
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 86 additions and 82 deletions

View File

@ -1 +1 @@
Convert push to async/await. Convert various parts of the codebase to async/await.

1
changelog.d/7949.misc Normal file
View File

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@ -1 +1 @@
Convert groups and visibility code to async / await. Convert various parts of the codebase to async/await.

View File

@ -82,7 +82,7 @@ class Auth(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True): def check_from_context(self, room_version: str, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
auth_events_ids = yield self.compute_auth_events( auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True event, prev_state_ids, for_verification=True
) )

View File

@ -17,8 +17,6 @@ from typing import Optional
import attr import attr
from nacl.signing import SigningKey from nacl.signing import SigningKey
from twisted.internet import defer
from synapse.api.constants import MAX_DEPTH from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import UnsupportedRoomVersionError from synapse.api.errors import UnsupportedRoomVersionError
from synapse.api.room_versions import ( from synapse.api.room_versions import (
@ -95,31 +93,30 @@ class EventBuilder(object):
def is_state(self): def is_state(self):
return self._state_key is not None return self._state_key is not None
@defer.inlineCallbacks async def build(self, prev_event_ids):
def build(self, prev_event_ids):
"""Transform into a fully signed and hashed event """Transform into a fully signed and hashed event
Args: Args:
prev_event_ids (list[str]): The event IDs to use as the prev events prev_event_ids (list[str]): The event IDs to use as the prev events
Returns: Returns:
Deferred[FrozenEvent] FrozenEvent
""" """
state_ids = yield defer.ensureDeferred( state_ids = await self._state.get_current_state_ids(
self._state.get_current_state_ids(self.room_id, prev_event_ids) self.room_id, prev_event_ids
) )
auth_ids = yield self._auth.compute_auth_events(self, state_ids) auth_ids = await self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1: if format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids) auth_events = await self._store.add_event_hashes(auth_ids)
prev_events = yield self._store.add_event_hashes(prev_event_ids) prev_events = await self._store.add_event_hashes(prev_event_ids)
else: else:
auth_events = auth_ids auth_events = auth_ids
prev_events = prev_event_ids prev_events = prev_event_ids
old_depth = yield self._store.get_max_depth_of(prev_event_ids) old_depth = await self._store.get_max_depth_of(prev_event_ids)
depth = old_depth + 1 depth = old_depth + 1
# we cap depth of generated events, to ensure that they are not # we cap depth of generated events, to ensure that they are not

View File

@ -12,17 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional, Union from typing import TYPE_CHECKING, Optional, Union
import attr import attr
from frozendict import frozendict from frozendict import frozendict
from twisted.internet import defer
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap from synapse.types import StateMap
if TYPE_CHECKING:
from synapse.storage.data_stores.main import DataStore
@attr.s(slots=True) @attr.s(slots=True)
class EventContext: class EventContext:
@ -129,8 +131,7 @@ class EventContext:
delta_ids=delta_ids, delta_ids=delta_ids,
) )
@defer.inlineCallbacks async def serialize(self, event: EventBase, store: "DataStore") -> dict:
def serialize(self, event, store):
"""Converts self to a type that can be serialized as JSON, and then """Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize` deserialized by `deserialize`
@ -146,7 +147,7 @@ class EventContext:
# the prev_state_ids, so if we're a state event we include the event # the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state. # id that we replaced in the state.
if event.is_state(): if event.is_state():
prev_state_ids = yield self.get_prev_state_ids() prev_state_ids = await self.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key)) prev_state_id = prev_state_ids.get((event.type, event.state_key))
else: else:
prev_state_id = None prev_state_id = None
@ -214,8 +215,7 @@ class EventContext:
return self._state_group return self._state_group
@defer.inlineCallbacks async def get_current_state_ids(self) -> Optional[StateMap[str]]:
def get_current_state_ids(self):
""" """
Gets the room state map, including this event - ie, the state in ``state_group`` Gets the room state map, including this event - ie, the state in ``state_group``
@ -224,8 +224,8 @@ class EventContext:
``rejected`` is set. ``rejected`` is set.
Returns: Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group Returns None if state_group is None, which happens when the associated
is None, which happens when the associated event is an outlier. event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching Maps a (type, state_key) to the event ID of the state event matching
this tuple. this tuple.
@ -233,23 +233,22 @@ class EventContext:
if self.rejected: if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event") raise RuntimeError("Attempt to access state_ids of rejected event")
yield self._ensure_fetched() await self._ensure_fetched()
return self._current_state_ids return self._current_state_ids
@defer.inlineCallbacks async def get_prev_state_ids(self):
def get_prev_state_ids(self):
""" """
Gets the room state map, excluding this event. Gets the room state map, excluding this event.
For a non-state event, this will be the same as get_current_state_ids(). For a non-state event, this will be the same as get_current_state_ids().
Returns: Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group dict[(str, str), str]|None: Returns None if state_group
is None, which happens when the associated event is an outlier. is None, which happens when the associated event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching Maps a (type, state_key) to the event ID of the state event matching
this tuple. this tuple.
""" """
yield self._ensure_fetched() await self._ensure_fetched()
return self._prev_state_ids return self._prev_state_ids
def get_cached_current_state_ids(self): def get_cached_current_state_ids(self):
@ -269,8 +268,8 @@ class EventContext:
return self._current_state_ids return self._current_state_ids
def _ensure_fetched(self): async def _ensure_fetched(self):
return defer.succeed(None) return None
@attr.s(slots=True) @attr.s(slots=True)
@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext):
_event_state_key = attr.ib(default=None) _event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None)
def _ensure_fetched(self): async def _ensure_fetched(self):
if not self._fetching_state_deferred: if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(self._fill_out_state) self._fetching_state_deferred = run_in_background(self._fill_out_state)
return make_deferred_yieldable(self._fetching_state_deferred) return await make_deferred_yieldable(self._fetching_state_deferred)
@defer.inlineCallbacks async def _fill_out_state(self):
def _fill_out_state(self):
"""Called to populate the _current_state_ids and _prev_state_ids """Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database. attributes by loading from the database.
""" """
if self.state_group is None: if self.state_group is None:
return return
self._current_state_ids = yield self._storage.state.get_state_ids_for_group( self._current_state_ids = await self._storage.state.get_state_ids_for_group(
self.state_group self.state_group
) )
if self._event_state_key is not None: if self._event_state_key is not None:

View File

@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Requester
class ThirdPartyEventRules(object): class ThirdPartyEventRules(object):
@ -39,76 +41,79 @@ class ThirdPartyEventRules(object):
config=config, http_client=hs.get_simple_http_client() config=config, http_client=hs.get_simple_http_client()
) )
@defer.inlineCallbacks async def check_event_allowed(
def check_event_allowed(self, event, context): self, event: EventBase, context: EventContext
) -> bool:
"""Check if a provided event should be allowed in the given context. """Check if a provided event should be allowed in the given context.
Args: Args:
event (synapse.events.EventBase): The event to be checked. event: The event to be checked.
context (synapse.events.snapshot.EventContext): The context of the event. context: The context of the event.
Returns: Returns:
defer.Deferred[bool]: True if the event should be allowed, False if not. True if the event should be allowed, False if not.
""" """
if self.third_party_rules is None: if self.third_party_rules is None:
return True return True
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
# Retrieve the state events from the database. # Retrieve the state events from the database.
state_events = {} state_events = {}
for key, event_id in prev_state_ids.items(): for key, event_id in prev_state_ids.items():
state_events[key] = yield self.store.get_event(event_id, allow_none=True) state_events[key] = await self.store.get_event(event_id, allow_none=True)
ret = yield self.third_party_rules.check_event_allowed(event, state_events) ret = await self.third_party_rules.check_event_allowed(event, state_events)
return ret return ret
@defer.inlineCallbacks async def on_create_room(
def on_create_room(self, requester, config, is_requester_admin): self, requester: Requester, config: dict, is_requester_admin: bool
) -> bool:
"""Intercept requests to create room to allow, deny or update the """Intercept requests to create room to allow, deny or update the
request config. request config.
Args: Args:
requester (Requester) requester
config (dict): The creation config from the client. config: The creation config from the client.
is_requester_admin (bool): If the requester is an admin is_requester_admin: If the requester is an admin
Returns: Returns:
defer.Deferred[bool]: Whether room creation is allowed or denied. Whether room creation is allowed or denied.
""" """
if self.third_party_rules is None: if self.third_party_rules is None:
return True return True
ret = yield self.third_party_rules.on_create_room( ret = await self.third_party_rules.on_create_room(
requester, config, is_requester_admin requester, config, is_requester_admin
) )
return ret return ret
@defer.inlineCallbacks async def check_threepid_can_be_invited(
def check_threepid_can_be_invited(self, medium, address, room_id): self, medium: str, address: str, room_id: str
) -> bool:
"""Check if a provided 3PID can be invited in the given room. """Check if a provided 3PID can be invited in the given room.
Args: Args:
medium (str): The 3PID's medium. medium: The 3PID's medium.
address (str): The 3PID's address. address: The 3PID's address.
room_id (str): The room we want to invite the threepid to. room_id: The room we want to invite the threepid to.
Returns: Returns:
defer.Deferred[bool], True if the 3PID can be invited, False if not. True if the 3PID can be invited, False if not.
""" """
if self.third_party_rules is None: if self.third_party_rules is None:
return True return True
state_ids = yield self.store.get_filtered_current_state_ids(room_id) state_ids = await self.store.get_filtered_current_state_ids(room_id)
room_state_events = yield self.store.get_events(state_ids.values()) room_state_events = await self.store.get_events(state_ids.values())
state_events = {} state_events = {}
for key, event_id in state_ids.items(): for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id] state_events[key] = room_state_events[event_id]
ret = yield self.third_party_rules.check_threepid_can_be_invited( ret = await self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events medium, address, state_events
) )
return ret return ret

View File

@ -18,8 +18,6 @@ from typing import Any, Mapping, Union
from frozendict import frozendict from frozendict import frozendict
from twisted.internet import defer
from synapse.api.constants import EventTypes, RelationTypes from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion from synapse.api.room_versions import RoomVersion
@ -337,8 +335,9 @@ class EventClientSerializer(object):
hs.config.experimental_msc1849_support_enabled hs.config.experimental_msc1849_support_enabled
) )
@defer.inlineCallbacks async def serialize_event(
def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs): self, event, time_now, bundle_aggregations=True, **kwargs
):
"""Serializes a single event. """Serializes a single event.
Args: Args:
@ -348,7 +347,7 @@ class EventClientSerializer(object):
**kwargs: Arguments to pass to `serialize_event` **kwargs: Arguments to pass to `serialize_event`
Returns: Returns:
Deferred[dict]: The serialized event dict: The serialized event
""" """
# To handle the case of presence events and the like # To handle the case of presence events and the like
if not isinstance(event, EventBase): if not isinstance(event, EventBase):
@ -363,8 +362,8 @@ class EventClientSerializer(object):
if not event.internal_metadata.is_redacted() and ( if not event.internal_metadata.is_redacted() and (
self.experimental_msc1849_support_enabled and bundle_aggregations self.experimental_msc1849_support_enabled and bundle_aggregations
): ):
annotations = yield self.store.get_aggregation_groups_for_event(event_id) annotations = await self.store.get_aggregation_groups_for_event(event_id)
references = yield self.store.get_relations_for_event( references = await self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f" event_id, RelationTypes.REFERENCE, direction="f"
) )
@ -378,7 +377,7 @@ class EventClientSerializer(object):
edit = None edit = None
if event.type == EventTypes.Message: if event.type == EventTypes.Message:
edit = yield self.store.get_applicable_edit(event_id) edit = await self.store.get_applicable_edit(event_id)
if edit: if edit:
# If there is an edit replace the content, preserving existing # If there is an edit replace the content, preserving existing

View File

@ -2470,7 +2470,7 @@ class FederationHandler(BaseHandler):
} }
current_state_ids = await context.get_current_state_ids() current_state_ids = await context.get_current_state_ids()
current_state_ids = dict(current_state_ids) current_state_ids = dict(current_state_ids) # type: ignore
current_state_ids.update(state_updates) current_state_ids.update(state_updates)

View File

@ -78,7 +78,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
""" """
event_payloads = [] event_payloads = []
for event, context in event_and_contexts: for event, context in event_and_contexts:
serialized_context = yield context.serialize(event, store) serialized_context = yield defer.ensureDeferred(
context.serialize(event, store)
)
event_payloads.append( event_payloads.append(
{ {

View File

@ -77,7 +77,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
extra_users (list(UserID)): Any extra users to notify about event extra_users (list(UserID)): Any extra users to notify about event
""" """
serialized_context = yield context.serialize(event, store) serialized_context = yield defer.ensureDeferred(context.serialize(event, store))
payload = { payload = {
"event": event.get_pdu_json(), "event": event.get_pdu_json(),

View File

@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def build(self, prev_event_ids): def build(self, prev_event_ids):
built_event = yield self._base_builder.build(prev_event_ids) built_event = yield defer.ensureDeferred(
self._base_builder.build(prev_event_ids)
)
built_event._event_id = self._event_id built_event._event_id = self._event_id
built_event._dict["event_id"] = self._event_id built_event._dict["event_id"] = self._event_id

View File

@ -213,7 +213,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"] ctx_c = context_store["C"]
ctx_d = context_store["D"] ctx_d = context_store["D"]
prev_state_ids = yield ctx_d.get_prev_state_ids() prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertEqual(2, len(prev_state_ids)) self.assertEqual(2, len(prev_state_ids))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@ -259,7 +259,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"] ctx_c = context_store["C"]
ctx_d = context_store["D"] ctx_d = context_store["D"]
prev_state_ids = yield ctx_d.get_prev_state_ids() prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values())) self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@ -318,7 +318,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"] ctx_c = context_store["C"]
ctx_e = context_store["E"] ctx_e = context_store["E"]
prev_state_ids = yield ctx_e.get_prev_state_ids() prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values())) self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event) self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group) self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@ -393,7 +393,7 @@ class StateTestCase(unittest.TestCase):
ctx_b = context_store["B"] ctx_b = context_store["B"]
ctx_d = context_store["D"] ctx_d = context_store["D"]
prev_state_ids = yield ctx_d.get_prev_state_ids() prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())) self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event) self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
@ -425,7 +425,7 @@ class StateTestCase(unittest.TestCase):
self.state.compute_event_context(event, old_state=old_state) self.state.compute_event_context(event, old_state=old_state)
) )
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@ -450,7 +450,7 @@ class StateTestCase(unittest.TestCase):
self.state.compute_event_context(event, old_state=old_state) self.state.compute_event_context(event, old_state=old_state)
) )
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids()) current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@ -519,7 +519,7 @@ class StateTestCase(unittest.TestCase):
context = yield defer.ensureDeferred(self.state.compute_event_context(event)) context = yield defer.ensureDeferred(self.state.compute_event_context(event))
prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values())) self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))