Add remaining type hints to synapse.events. (#11098)

This commit is contained in:
Patrick Cloke 2021-11-02 09:55:52 -04:00 committed by GitHub
parent 4535532526
commit c01bc5f43d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 184 additions and 109 deletions

1
changelog.d/11098.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `synapse.events`.

View File

@ -22,13 +22,7 @@ files =
synapse/config, synapse/config,
synapse/crypto, synapse/crypto,
synapse/event_auth.py, synapse/event_auth.py,
synapse/events/builder.py, synapse/events,
synapse/events/presence_router.py,
synapse/events/snapshot.py,
synapse/events/spamcheck.py,
synapse/events/third_party_rules.py,
synapse/events/utils.py,
synapse/events/validator.py,
synapse/federation, synapse/federation,
synapse/groups, synapse/groups,
synapse/handlers, synapse/handlers,

View File

@ -16,8 +16,23 @@
import abc import abc
import os import os
from typing import Dict, Optional, Tuple, Type from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
overload,
)
from typing_extensions import Literal
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
@ -26,6 +41,9 @@ from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze from synapse.util.frozenutils import freeze
from synapse.util.stringutils import strtobool from synapse.util.stringutils import strtobool
if TYPE_CHECKING:
from synapse.events.builder import EventBuilder
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting a # bugs where we accidentally share e.g. signature dicts. However, converting a
# dict to frozen_dicts is expensive. # dict to frozen_dicts is expensive.
@ -37,7 +55,23 @@ from synapse.util.stringutils import strtobool
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
class DictProperty: T = TypeVar("T")
# DictProperty (and DefaultDictProperty) require the classes they're used with to
# have a _dict property to pull properties from.
#
# TODO _DictPropertyInstance should not include EventBuilder but due to
# https://github.com/python/mypy/issues/5570 it thinks the DictProperty and
# DefaultDictProperty get applied to EventBuilder when it is in a Union with
# EventBase. This is the least invasive hack to get mypy to comply.
#
# Note that DictProperty/DefaultDictProperty cannot actually be used with
# EventBuilder as it lacks a _dict property.
_DictPropertyInstance = Union["_EventInternalMetadata", "EventBase", "EventBuilder"]
class DictProperty(Generic[T]):
"""An object property which delegates to the `_dict` within its parent object.""" """An object property which delegates to the `_dict` within its parent object."""
__slots__ = ["key"] __slots__ = ["key"]
@ -45,12 +79,33 @@ class DictProperty:
def __init__(self, key: str): def __init__(self, key: str):
self.key = key self.key = key
def __get__(self, instance, owner=None): @overload
def __get__(
self,
instance: Literal[None],
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> "DictProperty":
...
@overload
def __get__(
self,
instance: _DictPropertyInstance,
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> T:
...
def __get__(
self,
instance: Optional[_DictPropertyInstance],
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> Union[T, "DictProperty"]:
# if the property is accessed as a class property rather than an instance # if the property is accessed as a class property rather than an instance
# property, return the property itself rather than the value # property, return the property itself rather than the value
if instance is None: if instance is None:
return self return self
try: try:
assert isinstance(instance, (EventBase, _EventInternalMetadata))
return instance._dict[self.key] return instance._dict[self.key]
except KeyError as e1: except KeyError as e1:
# We want this to look like a regular attribute error (mostly so that # We want this to look like a regular attribute error (mostly so that
@ -65,10 +120,12 @@ class DictProperty:
"'%s' has no '%s' property" % (type(instance), self.key) "'%s' has no '%s' property" % (type(instance), self.key)
) from e1.__context__ ) from e1.__context__
def __set__(self, instance, v): def __set__(self, instance: _DictPropertyInstance, v: T) -> None:
assert isinstance(instance, (EventBase, _EventInternalMetadata))
instance._dict[self.key] = v instance._dict[self.key] = v
def __delete__(self, instance): def __delete__(self, instance: _DictPropertyInstance) -> None:
assert isinstance(instance, (EventBase, _EventInternalMetadata))
try: try:
del instance._dict[self.key] del instance._dict[self.key]
except KeyError as e1: except KeyError as e1:
@ -77,7 +134,7 @@ class DictProperty:
) from e1.__context__ ) from e1.__context__
class DefaultDictProperty(DictProperty): class DefaultDictProperty(DictProperty, Generic[T]):
"""An extension of DictProperty which provides a default if the property is """An extension of DictProperty which provides a default if the property is
not present in the parent's _dict. not present in the parent's _dict.
@ -86,13 +143,34 @@ class DefaultDictProperty(DictProperty):
__slots__ = ["default"] __slots__ = ["default"]
def __init__(self, key, default): def __init__(self, key: str, default: T):
super().__init__(key) super().__init__(key)
self.default = default self.default = default
def __get__(self, instance, owner=None): @overload
def __get__(
self,
instance: Literal[None],
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> "DefaultDictProperty":
...
@overload
def __get__(
self,
instance: _DictPropertyInstance,
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> T:
...
def __get__(
self,
instance: Optional[_DictPropertyInstance],
owner: Optional[Type[_DictPropertyInstance]] = None,
) -> Union[T, "DefaultDictProperty"]:
if instance is None: if instance is None:
return self return self
assert isinstance(instance, (EventBase, _EventInternalMetadata))
return instance._dict.get(self.key, self.default) return instance._dict.get(self.key, self.default)
@ -111,22 +189,22 @@ class _EventInternalMetadata:
# in the DAG) # in the DAG)
self.outlier = False self.outlier = False
out_of_band_membership: bool = DictProperty("out_of_band_membership") out_of_band_membership: DictProperty[bool] = DictProperty("out_of_band_membership")
send_on_behalf_of: str = DictProperty("send_on_behalf_of") send_on_behalf_of: DictProperty[str] = DictProperty("send_on_behalf_of")
recheck_redaction: bool = DictProperty("recheck_redaction") recheck_redaction: DictProperty[bool] = DictProperty("recheck_redaction")
soft_failed: bool = DictProperty("soft_failed") soft_failed: DictProperty[bool] = DictProperty("soft_failed")
proactively_send: bool = DictProperty("proactively_send") proactively_send: DictProperty[bool] = DictProperty("proactively_send")
redacted: bool = DictProperty("redacted") redacted: DictProperty[bool] = DictProperty("redacted")
txn_id: str = DictProperty("txn_id") txn_id: DictProperty[str] = DictProperty("txn_id")
token_id: int = DictProperty("token_id") token_id: DictProperty[int] = DictProperty("token_id")
historical: bool = DictProperty("historical") historical: DictProperty[bool] = DictProperty("historical")
# XXX: These are set by StreamWorkerStore._set_before_and_after. # XXX: These are set by StreamWorkerStore._set_before_and_after.
# I'm pretty sure that these are never persisted to the database, so shouldn't # I'm pretty sure that these are never persisted to the database, so shouldn't
# be here # be here
before: RoomStreamToken = DictProperty("before") before: DictProperty[RoomStreamToken] = DictProperty("before")
after: RoomStreamToken = DictProperty("after") after: DictProperty[RoomStreamToken] = DictProperty("after")
order: Tuple[int, int] = DictProperty("order") order: DictProperty[Tuple[int, int]] = DictProperty("order")
def get_dict(self) -> JsonDict: def get_dict(self) -> JsonDict:
return dict(self._dict) return dict(self._dict)
@ -162,9 +240,6 @@ class _EventInternalMetadata:
If the sender of the redaction event is allowed to redact any event If the sender of the redaction event is allowed to redact any event
due to auth rules, then this will always return false. due to auth rules, then this will always return false.
Returns:
bool
""" """
return self._dict.get("recheck_redaction", False) return self._dict.get("recheck_redaction", False)
@ -176,32 +251,23 @@ class _EventInternalMetadata:
sent to clients. sent to clients.
2. They should not be added to the forward extremities (and 2. They should not be added to the forward extremities (and
therefore not to current state). therefore not to current state).
Returns:
bool
""" """
return self._dict.get("soft_failed", False) return self._dict.get("soft_failed", False)
def should_proactively_send(self): def should_proactively_send(self) -> bool:
"""Whether the event, if ours, should be sent to other clients and """Whether the event, if ours, should be sent to other clients and
servers. servers.
This is used for sending dummy events internally. Servers and clients This is used for sending dummy events internally. Servers and clients
can still explicitly fetch the event. can still explicitly fetch the event.
Returns:
bool
""" """
return self._dict.get("proactively_send", True) return self._dict.get("proactively_send", True)
def is_redacted(self): def is_redacted(self) -> bool:
"""Whether the event has been redacted. """Whether the event has been redacted.
This is used for efficiently checking whether an event has been This is used for efficiently checking whether an event has been
marked as redacted without needing to make another database call. marked as redacted without needing to make another database call.
Returns:
bool
""" """
return self._dict.get("redacted", False) return self._dict.get("redacted", False)
@ -241,29 +307,31 @@ class EventBase(metaclass=abc.ABCMeta):
self.internal_metadata = _EventInternalMetadata(internal_metadata_dict) self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
auth_events = DictProperty("auth_events") depth: DictProperty[int] = DictProperty("depth")
depth = DictProperty("depth") content: DictProperty[JsonDict] = DictProperty("content")
content = DictProperty("content") hashes: DictProperty[Dict[str, str]] = DictProperty("hashes")
hashes = DictProperty("hashes") origin: DictProperty[str] = DictProperty("origin")
origin = DictProperty("origin") origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts")
origin_server_ts = DictProperty("origin_server_ts") redacts: DefaultDictProperty[Optional[str]] = DefaultDictProperty("redacts", None)
prev_events = DictProperty("prev_events") room_id: DictProperty[str] = DictProperty("room_id")
redacts = DefaultDictProperty("redacts", None) sender: DictProperty[str] = DictProperty("sender")
room_id = DictProperty("room_id") # TODO state_key should be Optional[str], this is generally asserted in Synapse
sender = DictProperty("sender") # by calling is_state() first (which ensures this), but it is hard (not possible?)
state_key = DictProperty("state_key") # to properly annotate that calling is_state() asserts that state_key exists
type = DictProperty("type") # and is non-None.
user_id = DictProperty("sender") state_key: DictProperty[str] = DictProperty("state_key")
type: DictProperty[str] = DictProperty("type")
user_id: DictProperty[str] = DictProperty("sender")
@property @property
def event_id(self) -> str: def event_id(self) -> str:
raise NotImplementedError() raise NotImplementedError()
@property @property
def membership(self): def membership(self) -> str:
return self.content["membership"] return self.content["membership"]
def is_state(self): def is_state(self) -> bool:
return hasattr(self, "state_key") and self.state_key is not None return hasattr(self, "state_key") and self.state_key is not None
def get_dict(self) -> JsonDict: def get_dict(self) -> JsonDict:
@ -272,13 +340,13 @@ class EventBase(metaclass=abc.ABCMeta):
return d return d
def get(self, key, default=None): def get(self, key: str, default: Optional[Any] = None) -> Any:
return self._dict.get(key, default) return self._dict.get(key, default)
def get_internal_metadata_dict(self): def get_internal_metadata_dict(self) -> JsonDict:
return self.internal_metadata.get_dict() return self.internal_metadata.get_dict()
def get_pdu_json(self, time_now=None) -> JsonDict: def get_pdu_json(self, time_now: Optional[int] = None) -> JsonDict:
pdu_json = self.get_dict() pdu_json = self.get_dict()
if time_now is not None and "age_ts" in pdu_json["unsigned"]: if time_now is not None and "age_ts" in pdu_json["unsigned"]:
@ -305,49 +373,46 @@ class EventBase(metaclass=abc.ABCMeta):
return template_json return template_json
def __set__(self, instance, value): def __getitem__(self, field: str) -> Optional[Any]:
raise AttributeError("Unrecognized attribute %s" % (instance,))
def __getitem__(self, field):
return self._dict[field] return self._dict[field]
def __contains__(self, field): def __contains__(self, field: str) -> bool:
return field in self._dict return field in self._dict
def items(self): def items(self) -> List[Tuple[str, Optional[Any]]]:
return list(self._dict.items()) return list(self._dict.items())
def keys(self): def keys(self) -> Iterable[str]:
return self._dict.keys() return self._dict.keys()
def prev_event_ids(self): def prev_event_ids(self) -> Sequence[str]:
"""Returns the list of prev event IDs. The order matches the order """Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it. specified in the event, though there is no meaning to it.
Returns: Returns:
list[str]: The list of event IDs of this event's prev_events The list of event IDs of this event's prev_events
""" """
return [e for e, _ in self.prev_events] return [e for e, _ in self._dict["prev_events"]]
def auth_event_ids(self): def auth_event_ids(self) -> Sequence[str]:
"""Returns the list of auth event IDs. The order matches the order """Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it. specified in the event, though there is no meaning to it.
Returns: Returns:
list[str]: The list of event IDs of this event's auth_events The list of event IDs of this event's auth_events
""" """
return [e for e, _ in self.auth_events] return [e for e, _ in self._dict["auth_events"]]
def freeze(self): def freeze(self) -> None:
"""'Freeze' the event dict, so it cannot be modified by accident""" """'Freeze' the event dict, so it cannot be modified by accident"""
# this will be a no-op if the event dict is already frozen. # this will be a no-op if the event dict is already frozen.
self._dict = freeze(self._dict) self._dict = freeze(self._dict)
def __str__(self): def __str__(self) -> str:
return self.__repr__() return self.__repr__()
def __repr__(self): def __repr__(self) -> str:
rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else "" rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else ""
return ( return (
@ -443,7 +508,7 @@ class FrozenEventV2(EventBase):
else: else:
frozen_dict = event_dict frozen_dict = event_dict
self._event_id = None self._event_id: Optional[str] = None
super().__init__( super().__init__(
frozen_dict, frozen_dict,
@ -455,7 +520,7 @@ class FrozenEventV2(EventBase):
) )
@property @property
def event_id(self): def event_id(self) -> str:
# We have to import this here as otherwise we get an import loop which # We have to import this here as otherwise we get an import loop which
# is hard to break. # is hard to break.
from synapse.crypto.event_signing import compute_event_reference_hash from synapse.crypto.event_signing import compute_event_reference_hash
@ -465,23 +530,23 @@ class FrozenEventV2(EventBase):
self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1]) self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
return self._event_id return self._event_id
def prev_event_ids(self): def prev_event_ids(self) -> Sequence[str]:
"""Returns the list of prev event IDs. The order matches the order """Returns the list of prev event IDs. The order matches the order
specified in the event, though there is no meaning to it. specified in the event, though there is no meaning to it.
Returns: Returns:
list[str]: The list of event IDs of this event's prev_events The list of event IDs of this event's prev_events
""" """
return self.prev_events return self._dict["prev_events"]
def auth_event_ids(self): def auth_event_ids(self) -> Sequence[str]:
"""Returns the list of auth event IDs. The order matches the order """Returns the list of auth event IDs. The order matches the order
specified in the event, though there is no meaning to it. specified in the event, though there is no meaning to it.
Returns: Returns:
list[str]: The list of event IDs of this event's auth_events The list of event IDs of this event's auth_events
""" """
return self.auth_events return self._dict["auth_events"]
class FrozenEventV3(FrozenEventV2): class FrozenEventV3(FrozenEventV2):
@ -490,7 +555,7 @@ class FrozenEventV3(FrozenEventV2):
format_version = EventFormatVersions.V3 # All events of this type are V3 format_version = EventFormatVersions.V3 # All events of this type are V3
@property @property
def event_id(self): def event_id(self) -> str:
# We have to import this here as otherwise we get an import loop which # We have to import this here as otherwise we get an import loop which
# is hard to break. # is hard to break.
from synapse.crypto.event_signing import compute_event_reference_hash from synapse.crypto.event_signing import compute_event_reference_hash
@ -503,12 +568,14 @@ class FrozenEventV3(FrozenEventV2):
return self._event_id return self._event_id
def _event_type_from_format_version(format_version: int) -> Type[EventBase]: def _event_type_from_format_version(
format_version: int,
) -> Type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]:
"""Returns the python type to use to construct an Event object for the """Returns the python type to use to construct an Event object for the
given event format version. given event format version.
Args: Args:
format_version (int): The event format version format_version: The event format version
Returns: Returns:
type: A type that can be initialized as per the initializer of type: A type that can be initialized as per the initializer of

View File

@ -55,7 +55,7 @@ class EventValidator:
] ]
for k in required: for k in required:
if not hasattr(event, k): if k not in event:
raise SynapseError(400, "Event does not have key %s" % (k,)) raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values # Check that the following keys have string values

View File

@ -1643,7 +1643,7 @@ class FederationEventHandler:
event: the event whose auth_events we want event: the event whose auth_events we want
Returns: Returns:
all of the events in `event.auth_events`, after deduplication all of the events listed in `event.auth_events_ids`, after deduplication
Raises: Raises:
AuthError if we were unable to fetch the auth_events for any reason. AuthError if we were unable to fetch the auth_events for any reason.

View File

@ -1318,6 +1318,8 @@ class EventCreationHandler:
# user is actually admin or not). # user is actually admin or not).
is_admin_redaction = False is_admin_redaction = False
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
assert event.redacts is not None
original_event = await self.store.get_event( original_event = await self.store.get_event(
event.redacts, event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS, redact_behaviour=EventRedactBehaviour.AS_IS,
@ -1413,6 +1415,8 @@ class EventCreationHandler:
) )
if event.type == EventTypes.Redaction: if event.type == EventTypes.Redaction:
assert event.redacts is not None
original_event = await self.store.get_event( original_event = await self.store.get_event(
event.redacts, event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS, redact_behaviour=EventRedactBehaviour.AS_IS,
@ -1500,6 +1504,8 @@ class EventCreationHandler:
next_batch_id = event.content.get( next_batch_id = event.content.get(
EventContentFields.MSC2716_NEXT_BATCH_ID EventContentFields.MSC2716_NEXT_BATCH_ID
) )
conflicting_insertion_event_id = None
if next_batch_id:
conflicting_insertion_event_id = ( conflicting_insertion_event_id = (
await self.store.get_insertion_event_by_batch_id( await self.store.get_insertion_event_by_batch_id(
event.room_id, next_batch_id event.room_id, next_batch_id

View File

@ -525,7 +525,7 @@ class RoomCreationHandler:
): ):
await self.room_member_handler.update_membership( await self.room_member_handler.update_membership(
requester, requester,
UserID.from_string(old_event["state_key"]), UserID.from_string(old_event.state_key),
new_room_id, new_room_id,
"ban", "ban",
ratelimit=False, ratelimit=False,

View File

@ -355,7 +355,7 @@ class RoomBatchHandler:
for (event, context) in reversed(events_to_persist): for (event, context) in reversed(events_to_persist):
await self.event_creation_handler.handle_new_client_event( await self.event_creation_handler.handle_new_client_event(
await self.create_requester_for_user_id_from_app_service( await self.create_requester_for_user_id_from_app_service(
event["sender"], app_service_requester.app_service event.sender, app_service_requester.app_service
), ),
event=event, event=event,
context=context, context=context,

View File

@ -1669,7 +1669,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# #
# the prev_events consist solely of the previous membership event. # the prev_events consist solely of the previous membership event.
prev_event_ids = [previous_membership_event.event_id] prev_event_ids = [previous_membership_event.event_id]
auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids auth_event_ids = (
list(previous_membership_event.auth_event_ids()) + prev_event_ids
)
event, context = await self.event_creation_handler.create_event( event, context = await self.event_creation_handler.create_event(
requester, requester,

View File

@ -232,6 +232,8 @@ class BulkPushRuleEvaluator:
# that user, as they might not be already joined. # that user, as they might not be already joined.
if event.type == EventTypes.Member and event.state_key == uid: if event.type == EventTypes.Member and event.state_key == uid:
display_name = event.content.get("displayname", None) display_name = event.content.get("displayname", None)
if not isinstance(display_name, str):
display_name = None
if count_as_unread: if count_as_unread:
# Add an element for the current user if the event needs to be marked as # Add an element for the current user if the event needs to be marked as
@ -268,7 +270,7 @@ def _condition_checker(
evaluator: PushRuleEvaluatorForEvent, evaluator: PushRuleEvaluatorForEvent,
conditions: List[dict], conditions: List[dict],
uid: str, uid: str,
display_name: str, display_name: Optional[str],
cache: Dict[str, bool], cache: Dict[str, bool],
) -> bool: ) -> bool:
for cond in conditions: for cond in conditions:

View File

@ -18,7 +18,7 @@ import re
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import UserID from synapse.types import JsonDict, UserID
from synapse.util import glob_to_regex, re_word_boundary from synapse.util import glob_to_regex, re_word_boundary
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
@ -129,7 +129,7 @@ class PushRuleEvaluatorForEvent:
self._value_cache = _flatten_dict(event) self._value_cache = _flatten_dict(event)
def matches( def matches(
self, condition: Dict[str, Any], user_id: str, display_name: str self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
) -> bool: ) -> bool:
if condition["kind"] == "event_match": if condition["kind"] == "event_match":
return self._event_match(condition, user_id) return self._event_match(condition, user_id)
@ -172,7 +172,7 @@ class PushRuleEvaluatorForEvent:
return _glob_matches(pattern, haystack) return _glob_matches(pattern, haystack)
def _contains_display_name(self, display_name: str) -> bool: def _contains_display_name(self, display_name: Optional[str]) -> bool:
if not display_name: if not display_name:
return False return False
@ -222,7 +222,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
def _flatten_dict( def _flatten_dict(
d: Union[EventBase, dict], d: Union[EventBase, JsonDict],
prefix: Optional[List[str]] = None, prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None, result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]: ) -> Dict[str, str]:
@ -233,7 +233,7 @@ def _flatten_dict(
for key, value in d.items(): for key, value in d.items():
if isinstance(value, str): if isinstance(value, str):
result[".".join(prefix + [key])] = value.lower() result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"): elif isinstance(value, dict):
_flatten_dict(value, prefix=(prefix + [key]), result=result) _flatten_dict(value, prefix=(prefix + [key]), result=result)
return result return result

View File

@ -191,7 +191,7 @@ class RoomBatchSendEventRestServlet(RestServlet):
depth=inherited_depth, depth=inherited_depth,
) )
batch_id_to_connect_to = base_insertion_event["content"][ batch_id_to_connect_to = base_insertion_event.content[
EventContentFields.MSC2716_NEXT_BATCH_ID EventContentFields.MSC2716_NEXT_BATCH_ID
] ]

View File

@ -247,7 +247,7 @@ class StateHandler:
return await self.get_hosts_in_room_at_events(room_id, event_ids) return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events( async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: List[str] self, room_id: str, event_ids: Iterable[str]
) -> Set[str]: ) -> Set[str]:
"""Get the hosts that were in a room at the given event ids """Get the hosts that were in a room at the given event ids

View File

@ -24,6 +24,7 @@ from typing import (
Iterable, Iterable,
List, List,
Optional, Optional,
Sequence,
Set, Set,
Tuple, Tuple,
) )
@ -494,7 +495,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator, event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str], event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]], event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]], event_to_auth_chain: Dict[str, Sequence[str]],
) -> None: ) -> None:
"""Calculate the chain cover index for the given events. """Calculate the chain cover index for the given events.
@ -786,7 +787,7 @@ class PersistEventsStore:
event_chain_id_gen: SequenceGenerator, event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str], event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]], event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]], event_to_auth_chain: Dict[str, Sequence[str]],
events_to_calc_chain_id_for: Set[str], events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]], chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]: ) -> Dict[str, Tuple[int, int]]:
@ -1794,7 +1795,7 @@ class PersistEventsStore:
) )
# Insert an edge for every prev_event connection # Insert an edge for every prev_event connection
for prev_event_id in event.prev_events: for prev_event_id in event.prev_event_ids():
self.db_pool.simple_insert_txn( self.db_pool.simple_insert_txn(
txn, txn,
table="insertion_event_edges", table="insertion_event_edges",

View File

@ -570,7 +570,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
async def get_joined_users_from_context( async def get_joined_users_from_context(
self, event: EventBase, context: EventContext self, event: EventBase, context: EventContext
): ) -> Dict[str, ProfileInfo]:
state_group = context.state_group state_group = context.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
@ -584,7 +584,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event.room_id, state_group, current_state_ids, event=event, context=context event.room_id, state_group, current_state_ids, event=event, context=context
) )
async def get_joined_users_from_state(self, room_id, state_entry): async def get_joined_users_from_state(
self, room_id, state_entry
) -> Dict[str, ProfileInfo]:
state_group = state_entry.state_group state_group = state_entry.state_group
if not state_group: if not state_group:
# If state_group is None it means it has yet to be assigned a # If state_group is None it means it has yet to be assigned a
@ -607,7 +609,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
cache_context, cache_context,
event=None, event=None,
context=None, context=None,
): ) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based # We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states # on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different. # with a state_group of None are likely to be different.