mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-16 11:27:10 -05:00
Convert a synapse.events to async/await. (#7949)
This commit is contained in:
parent
5f65e62681
commit
8553f46498
@ -1 +1 @@
|
||||
Convert push to async/await.
|
||||
Convert various parts of the codebase to async/await.
|
||||
|
1
changelog.d/7949.misc
Normal file
1
changelog.d/7949.misc
Normal file
@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
@ -1 +1 @@
|
||||
Convert groups and visibility code to async / await.
|
||||
Convert various parts of the codebase to async/await.
|
||||
|
@ -82,7 +82,7 @@ class Auth(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
|
||||
prev_state_ids = yield context.get_prev_state_ids()
|
||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||
auth_events_ids = yield self.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=True
|
||||
)
|
||||
|
@ -17,8 +17,6 @@ from typing import Optional
|
||||
import attr
|
||||
from nacl.signing import SigningKey
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import MAX_DEPTH
|
||||
from synapse.api.errors import UnsupportedRoomVersionError
|
||||
from synapse.api.room_versions import (
|
||||
@ -95,31 +93,30 @@ class EventBuilder(object):
|
||||
def is_state(self):
|
||||
return self._state_key is not None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def build(self, prev_event_ids):
|
||||
async def build(self, prev_event_ids):
|
||||
"""Transform into a fully signed and hashed event
|
||||
|
||||
Args:
|
||||
prev_event_ids (list[str]): The event IDs to use as the prev events
|
||||
|
||||
Returns:
|
||||
Deferred[FrozenEvent]
|
||||
FrozenEvent
|
||||
"""
|
||||
|
||||
state_ids = yield defer.ensureDeferred(
|
||||
self._state.get_current_state_ids(self.room_id, prev_event_ids)
|
||||
state_ids = await self._state.get_current_state_ids(
|
||||
self.room_id, prev_event_ids
|
||||
)
|
||||
auth_ids = yield self._auth.compute_auth_events(self, state_ids)
|
||||
auth_ids = await self._auth.compute_auth_events(self, state_ids)
|
||||
|
||||
format_version = self.room_version.event_format
|
||||
if format_version == EventFormatVersions.V1:
|
||||
auth_events = yield self._store.add_event_hashes(auth_ids)
|
||||
prev_events = yield self._store.add_event_hashes(prev_event_ids)
|
||||
auth_events = await self._store.add_event_hashes(auth_ids)
|
||||
prev_events = await self._store.add_event_hashes(prev_event_ids)
|
||||
else:
|
||||
auth_events = auth_ids
|
||||
prev_events = prev_event_ids
|
||||
|
||||
old_depth = yield self._store.get_max_depth_of(prev_event_ids)
|
||||
old_depth = await self._store.get_max_depth_of(prev_event_ids)
|
||||
depth = old_depth + 1
|
||||
|
||||
# we cap depth of generated events, to ensure that they are not
|
||||
|
@ -12,17 +12,19 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.types import StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.data_stores.main import DataStore
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class EventContext:
|
||||
@ -129,8 +131,7 @@ class EventContext:
|
||||
delta_ids=delta_ids,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def serialize(self, event, store):
|
||||
async def serialize(self, event: EventBase, store: "DataStore") -> dict:
|
||||
"""Converts self to a type that can be serialized as JSON, and then
|
||||
deserialized by `deserialize`
|
||||
|
||||
@ -146,7 +147,7 @@ class EventContext:
|
||||
# the prev_state_ids, so if we're a state event we include the event
|
||||
# id that we replaced in the state.
|
||||
if event.is_state():
|
||||
prev_state_ids = yield self.get_prev_state_ids()
|
||||
prev_state_ids = await self.get_prev_state_ids()
|
||||
prev_state_id = prev_state_ids.get((event.type, event.state_key))
|
||||
else:
|
||||
prev_state_id = None
|
||||
@ -214,8 +215,7 @@ class EventContext:
|
||||
|
||||
return self._state_group
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state_ids(self):
|
||||
async def get_current_state_ids(self) -> Optional[StateMap[str]]:
|
||||
"""
|
||||
Gets the room state map, including this event - ie, the state in ``state_group``
|
||||
|
||||
@ -224,32 +224,31 @@ class EventContext:
|
||||
``rejected`` is set.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str]|None]: Returns None if state_group
|
||||
is None, which happens when the associated event is an outlier.
|
||||
Returns None if state_group is None, which happens when the associated
|
||||
event is an outlier.
|
||||
|
||||
Maps a (type, state_key) to the event ID of the state event matching
|
||||
this tuple.
|
||||
Maps a (type, state_key) to the event ID of the state event matching
|
||||
this tuple.
|
||||
"""
|
||||
if self.rejected:
|
||||
raise RuntimeError("Attempt to access state_ids of rejected event")
|
||||
|
||||
yield self._ensure_fetched()
|
||||
await self._ensure_fetched()
|
||||
return self._current_state_ids
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_prev_state_ids(self):
|
||||
async def get_prev_state_ids(self):
|
||||
"""
|
||||
Gets the room state map, excluding this event.
|
||||
|
||||
For a non-state event, this will be the same as get_current_state_ids().
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str]|None]: Returns None if state_group
|
||||
dict[(str, str), str]|None: Returns None if state_group
|
||||
is None, which happens when the associated event is an outlier.
|
||||
Maps a (type, state_key) to the event ID of the state event matching
|
||||
this tuple.
|
||||
"""
|
||||
yield self._ensure_fetched()
|
||||
await self._ensure_fetched()
|
||||
return self._prev_state_ids
|
||||
|
||||
def get_cached_current_state_ids(self):
|
||||
@ -269,8 +268,8 @@ class EventContext:
|
||||
|
||||
return self._current_state_ids
|
||||
|
||||
def _ensure_fetched(self):
|
||||
return defer.succeed(None)
|
||||
async def _ensure_fetched(self):
|
||||
return None
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext):
|
||||
_event_state_key = attr.ib(default=None)
|
||||
_fetching_state_deferred = attr.ib(default=None)
|
||||
|
||||
def _ensure_fetched(self):
|
||||
async def _ensure_fetched(self):
|
||||
if not self._fetching_state_deferred:
|
||||
self._fetching_state_deferred = run_in_background(self._fill_out_state)
|
||||
|
||||
return make_deferred_yieldable(self._fetching_state_deferred)
|
||||
return await make_deferred_yieldable(self._fetching_state_deferred)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _fill_out_state(self):
|
||||
async def _fill_out_state(self):
|
||||
"""Called to populate the _current_state_ids and _prev_state_ids
|
||||
attributes by loading from the database.
|
||||
"""
|
||||
if self.state_group is None:
|
||||
return
|
||||
|
||||
self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
|
||||
self._current_state_ids = await self._storage.state.get_state_ids_for_group(
|
||||
self.state_group
|
||||
)
|
||||
if self._event_state_key is not None:
|
||||
|
@ -13,7 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.types import Requester
|
||||
|
||||
|
||||
class ThirdPartyEventRules(object):
|
||||
@ -39,76 +41,79 @@ class ThirdPartyEventRules(object):
|
||||
config=config, http_client=hs.get_simple_http_client()
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_event_allowed(self, event, context):
|
||||
async def check_event_allowed(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> bool:
|
||||
"""Check if a provided event should be allowed in the given context.
|
||||
|
||||
Args:
|
||||
event (synapse.events.EventBase): The event to be checked.
|
||||
context (synapse.events.snapshot.EventContext): The context of the event.
|
||||
event: The event to be checked.
|
||||
context: The context of the event.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[bool]: True if the event should be allowed, False if not.
|
||||
True if the event should be allowed, False if not.
|
||||
"""
|
||||
if self.third_party_rules is None:
|
||||
return True
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids()
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
|
||||
# Retrieve the state events from the database.
|
||||
state_events = {}
|
||||
for key, event_id in prev_state_ids.items():
|
||||
state_events[key] = yield self.store.get_event(event_id, allow_none=True)
|
||||
state_events[key] = await self.store.get_event(event_id, allow_none=True)
|
||||
|
||||
ret = yield self.third_party_rules.check_event_allowed(event, state_events)
|
||||
ret = await self.third_party_rules.check_event_allowed(event, state_events)
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_create_room(self, requester, config, is_requester_admin):
|
||||
async def on_create_room(
|
||||
self, requester: Requester, config: dict, is_requester_admin: bool
|
||||
) -> bool:
|
||||
"""Intercept requests to create room to allow, deny or update the
|
||||
request config.
|
||||
|
||||
Args:
|
||||
requester (Requester)
|
||||
config (dict): The creation config from the client.
|
||||
is_requester_admin (bool): If the requester is an admin
|
||||
requester
|
||||
config: The creation config from the client.
|
||||
is_requester_admin: If the requester is an admin
|
||||
|
||||
Returns:
|
||||
defer.Deferred[bool]: Whether room creation is allowed or denied.
|
||||
Whether room creation is allowed or denied.
|
||||
"""
|
||||
|
||||
if self.third_party_rules is None:
|
||||
return True
|
||||
|
||||
ret = yield self.third_party_rules.on_create_room(
|
||||
ret = await self.third_party_rules.on_create_room(
|
||||
requester, config, is_requester_admin
|
||||
)
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_threepid_can_be_invited(self, medium, address, room_id):
|
||||
async def check_threepid_can_be_invited(
|
||||
self, medium: str, address: str, room_id: str
|
||||
) -> bool:
|
||||
"""Check if a provided 3PID can be invited in the given room.
|
||||
|
||||
Args:
|
||||
medium (str): The 3PID's medium.
|
||||
address (str): The 3PID's address.
|
||||
room_id (str): The room we want to invite the threepid to.
|
||||
medium: The 3PID's medium.
|
||||
address: The 3PID's address.
|
||||
room_id: The room we want to invite the threepid to.
|
||||
|
||||
Returns:
|
||||
defer.Deferred[bool], True if the 3PID can be invited, False if not.
|
||||
True if the 3PID can be invited, False if not.
|
||||
"""
|
||||
|
||||
if self.third_party_rules is None:
|
||||
return True
|
||||
|
||||
state_ids = yield self.store.get_filtered_current_state_ids(room_id)
|
||||
room_state_events = yield self.store.get_events(state_ids.values())
|
||||
state_ids = await self.store.get_filtered_current_state_ids(room_id)
|
||||
room_state_events = await self.store.get_events(state_ids.values())
|
||||
|
||||
state_events = {}
|
||||
for key, event_id in state_ids.items():
|
||||
state_events[key] = room_state_events[event_id]
|
||||
|
||||
ret = yield self.third_party_rules.check_threepid_can_be_invited(
|
||||
ret = await self.third_party_rules.check_threepid_can_be_invited(
|
||||
medium, address, state_events
|
||||
)
|
||||
return ret
|
||||
|
@ -18,8 +18,6 @@ from typing import Any, Mapping, Union
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, RelationTypes
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
@ -337,8 +335,9 @@ class EventClientSerializer(object):
|
||||
hs.config.experimental_msc1849_support_enabled
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
|
||||
async def serialize_event(
|
||||
self, event, time_now, bundle_aggregations=True, **kwargs
|
||||
):
|
||||
"""Serializes a single event.
|
||||
|
||||
Args:
|
||||
@ -348,7 +347,7 @@ class EventClientSerializer(object):
|
||||
**kwargs: Arguments to pass to `serialize_event`
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: The serialized event
|
||||
dict: The serialized event
|
||||
"""
|
||||
# To handle the case of presence events and the like
|
||||
if not isinstance(event, EventBase):
|
||||
@ -363,8 +362,8 @@ class EventClientSerializer(object):
|
||||
if not event.internal_metadata.is_redacted() and (
|
||||
self.experimental_msc1849_support_enabled and bundle_aggregations
|
||||
):
|
||||
annotations = yield self.store.get_aggregation_groups_for_event(event_id)
|
||||
references = yield self.store.get_relations_for_event(
|
||||
annotations = await self.store.get_aggregation_groups_for_event(event_id)
|
||||
references = await self.store.get_relations_for_event(
|
||||
event_id, RelationTypes.REFERENCE, direction="f"
|
||||
)
|
||||
|
||||
@ -378,7 +377,7 @@ class EventClientSerializer(object):
|
||||
|
||||
edit = None
|
||||
if event.type == EventTypes.Message:
|
||||
edit = yield self.store.get_applicable_edit(event_id)
|
||||
edit = await self.store.get_applicable_edit(event_id)
|
||||
|
||||
if edit:
|
||||
# If there is an edit replace the content, preserving existing
|
||||
|
@ -2470,7 +2470,7 @@ class FederationHandler(BaseHandler):
|
||||
}
|
||||
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
current_state_ids = dict(current_state_ids)
|
||||
current_state_ids = dict(current_state_ids) # type: ignore
|
||||
|
||||
current_state_ids.update(state_updates)
|
||||
|
||||
|
@ -78,7 +78,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
|
||||
"""
|
||||
event_payloads = []
|
||||
for event, context in event_and_contexts:
|
||||
serialized_context = yield context.serialize(event, store)
|
||||
serialized_context = yield defer.ensureDeferred(
|
||||
context.serialize(event, store)
|
||||
)
|
||||
|
||||
event_payloads.append(
|
||||
{
|
||||
|
@ -77,7 +77,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
|
||||
extra_users (list(UserID)): Any extra users to notify about event
|
||||
"""
|
||||
|
||||
serialized_context = yield context.serialize(event, store)
|
||||
serialized_context = yield defer.ensureDeferred(context.serialize(event, store))
|
||||
|
||||
payload = {
|
||||
"event": event.get_pdu_json(),
|
||||
|
@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def build(self, prev_event_ids):
|
||||
built_event = yield self._base_builder.build(prev_event_ids)
|
||||
built_event = yield defer.ensureDeferred(
|
||||
self._base_builder.build(prev_event_ids)
|
||||
)
|
||||
|
||||
built_event._event_id = self._event_id
|
||||
built_event._dict["event_id"] = self._event_id
|
||||
|
@ -213,7 +213,7 @@ class StateTestCase(unittest.TestCase):
|
||||
ctx_c = context_store["C"]
|
||||
ctx_d = context_store["D"]
|
||||
|
||||
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
||||
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
|
||||
self.assertEqual(2, len(prev_state_ids))
|
||||
|
||||
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
|
||||
@ -259,7 +259,7 @@ class StateTestCase(unittest.TestCase):
|
||||
ctx_c = context_store["C"]
|
||||
ctx_d = context_store["D"]
|
||||
|
||||
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
||||
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
|
||||
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
|
||||
|
||||
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
|
||||
@ -318,7 +318,7 @@ class StateTestCase(unittest.TestCase):
|
||||
ctx_c = context_store["C"]
|
||||
ctx_e = context_store["E"]
|
||||
|
||||
prev_state_ids = yield ctx_e.get_prev_state_ids()
|
||||
prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
|
||||
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
|
||||
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
|
||||
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
|
||||
@ -393,7 +393,7 @@ class StateTestCase(unittest.TestCase):
|
||||
ctx_b = context_store["B"]
|
||||
ctx_d = context_store["D"]
|
||||
|
||||
prev_state_ids = yield ctx_d.get_prev_state_ids()
|
||||
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
|
||||
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
|
||||
|
||||
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
|
||||
@ -425,7 +425,7 @@ class StateTestCase(unittest.TestCase):
|
||||
self.state.compute_event_context(event, old_state=old_state)
|
||||
)
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids()
|
||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
|
||||
|
||||
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
||||
@ -450,7 +450,7 @@ class StateTestCase(unittest.TestCase):
|
||||
self.state.compute_event_context(event, old_state=old_state)
|
||||
)
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids()
|
||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
|
||||
|
||||
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
|
||||
@ -519,7 +519,7 @@ class StateTestCase(unittest.TestCase):
|
||||
|
||||
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
|
||||
|
||||
prev_state_ids = yield context.get_prev_state_ids()
|
||||
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
|
||||
|
||||
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user