Add type hints to synapse.events.*. (#11066)

Except `synapse/events/__init__.py`, which will be done in a follow-up.
This commit is contained in:
Patrick Cloke 2021-10-13 07:24:07 -04:00 committed by GitHub
parent cdd308845b
commit 1f9d0b8a7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 208 additions and 145 deletions

1
changelog.d/11066.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `synapse.events`.

View File

@ -22,8 +22,11 @@ files =
synapse/crypto, synapse/crypto,
synapse/event_auth.py, synapse/event_auth.py,
synapse/events/builder.py, synapse/events/builder.py,
synapse/events/presence_router.py,
synapse/events/snapshot.py,
synapse/events/spamcheck.py, synapse/events/spamcheck.py,
synapse/events/third_party_rules.py, synapse/events/third_party_rules.py,
synapse/events/utils.py,
synapse/events/validator.py, synapse/events/validator.py,
synapse/federation, synapse/federation,
synapse/groups, synapse/groups,
@ -96,6 +99,9 @@ files =
tests/util/test_itertools.py, tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py tests/util/test_stream_change_cache.py
[mypy-synapse.events.*]
disallow_untyped_defs = True
[mypy-synapse.handlers.*] [mypy-synapse.handlers.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View File

@ -90,13 +90,13 @@ class EventBuilder:
) )
@property @property
def state_key(self): def state_key(self) -> str:
if self._state_key is not None: if self._state_key is not None:
return self._state_key return self._state_key
raise AttributeError("state_key") raise AttributeError("state_key")
def is_state(self): def is_state(self) -> bool:
return self._state_key is not None return self._state_key is not None
async def build( async def build(

View File

@ -14,6 +14,7 @@
import logging import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any,
Awaitable, Awaitable,
Callable, Callable,
Dict, Dict,
@ -33,14 +34,13 @@ if TYPE_CHECKING:
GET_USERS_FOR_STATES_CALLBACK = Callable[ GET_USERS_FOR_STATES_CALLBACK = Callable[
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]] [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
] ]
GET_INTERESTED_USERS_CALLBACK = Callable[ # This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
[str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]] GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def load_legacy_presence_router(hs: "HomeServer"): def load_legacy_presence_router(hs: "HomeServer") -> None:
"""Wrapper that loads a presence router module configured using the old """Wrapper that loads a presence router module configured using the old
configuration, and registers the hooks they implement. configuration, and registers the hooks they implement.
""" """
@ -69,9 +69,10 @@ def load_legacy_presence_router(hs: "HomeServer"):
if f is None: if f is None:
return None return None
def run(*args, **kwargs): def run(*args: Any, **kwargs: Any) -> Awaitable:
# mypy doesn't do well across function boundaries so we need to tell it # Assertion required because mypy can't prove we won't change `f`
# f is definitely not None. # back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None assert f is not None
return maybe_awaitable(f(*args, **kwargs)) return maybe_awaitable(f(*args, **kwargs))
@ -104,7 +105,7 @@ class PresenceRouter:
self, self,
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None, get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None, get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
): ) -> None:
# PresenceRouter modules are required to implement both of these methods # PresenceRouter modules are required to implement both of these methods
# or neither of them as they are assumed to act in a complementary manner # or neither of them as they are assumed to act in a complementary manner
paired_methods = [get_users_for_states, get_interested_users] paired_methods = [get_users_for_states, get_interested_users]
@ -142,7 +143,7 @@ class PresenceRouter:
# Don't include any extra destinations for presence updates # Don't include any extra destinations for presence updates
return {} return {}
users_for_states = {} users_for_states: Dict[str, Set[UserPresenceState]] = {}
# run all the callbacks for get_users_for_states and combine the results # run all the callbacks for get_users_for_states and combine the results
for callback in self._get_users_for_states_callbacks: for callback in self._get_users_for_states_callbacks:
try: try:
@ -171,7 +172,7 @@ class PresenceRouter:
return users_for_states return users_for_states
async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]: async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
""" """
Retrieve a list of users that `user_id` is interested in receiving the Retrieve a list of users that `user_id` is interested in receiving the
presence of. This will be in addition to those they share a room with. presence of. This will be in addition to those they share a room with.

View File

@ -11,17 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING, Optional, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import attr import attr
from frozendict import frozendict from frozendict import frozendict
from twisted.internet.defer import Deferred
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap from synapse.types import JsonDict, StateMap
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
@ -112,13 +115,13 @@ class EventContext:
@staticmethod @staticmethod
def with_state( def with_state(
state_group, state_group: Optional[int],
state_group_before_event, state_group_before_event: Optional[int],
current_state_ids, current_state_ids: Optional[StateMap[str]],
prev_state_ids, prev_state_ids: Optional[StateMap[str]],
prev_group=None, prev_group: Optional[int] = None,
delta_ids=None, delta_ids: Optional[StateMap[str]] = None,
): ) -> "EventContext":
return EventContext( return EventContext(
current_state_ids=current_state_ids, current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids, prev_state_ids=prev_state_ids,
@ -129,22 +132,22 @@ class EventContext:
) )
@staticmethod @staticmethod
def for_outlier(): def for_outlier() -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event""" """Return an EventContext instance suitable for persisting an outlier event"""
return EventContext( return EventContext(
current_state_ids={}, current_state_ids={},
prev_state_ids={}, prev_state_ids={},
) )
async def serialize(self, event: EventBase, store: "DataStore") -> dict: async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
"""Converts self to a type that can be serialized as JSON, and then """Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize` deserialized by `deserialize`
Args: Args:
event (FrozenEvent): The event that this context relates to event: The event that this context relates to
Returns: Returns:
dict The serialized event.
""" """
# We don't serialize the full state dicts, instead they get pulled out # We don't serialize the full state dicts, instead they get pulled out
@ -170,17 +173,16 @@ class EventContext:
} }
@staticmethod @staticmethod
def deserialize(storage, input): def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
"""Converts a dict that was produced by `serialize` back into a """Converts a dict that was produced by `serialize` back into a
EventContext. EventContext.
Args: Args:
storage (Storage): Used to convert AS ID to AS object and fetch storage: Used to convert AS ID to AS object and fetch state.
state. input: A dict produced by `serialize`
input (dict): A dict produced by `serialize`
Returns: Returns:
EventContext The event context.
""" """
context = _AsyncEventContextImpl( context = _AsyncEventContextImpl(
# We use the state_group and prev_state_id stuff to pull the # We use the state_group and prev_state_id stuff to pull the
@ -241,22 +243,25 @@ class EventContext:
await self._ensure_fetched() await self._ensure_fetched()
return self._current_state_ids return self._current_state_ids
async def get_prev_state_ids(self): async def get_prev_state_ids(self) -> StateMap[str]:
""" """
Gets the room state map, excluding this event. Gets the room state map, excluding this event.
For a non-state event, this will be the same as get_current_state_ids(). For a non-state event, this will be the same as get_current_state_ids().
Returns: Returns:
dict[(str, str), str]|None: Returns None if state_group Returns {} if state_group is None, which happens when the associated
is None, which happens when the associated event is an outlier. event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching
this tuple. Maps a (type, state_key) to the event ID of the state event matching
this tuple.
""" """
await self._ensure_fetched() await self._ensure_fetched()
# There *should* be previous state IDs now.
assert self._prev_state_ids is not None
return self._prev_state_ids return self._prev_state_ids
def get_cached_current_state_ids(self): def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
"""Gets the current state IDs if we have them already cached. """Gets the current state IDs if we have them already cached.
It is an error to access this for a rejected event, since rejected state should It is an error to access this for a rejected event, since rejected state should
@ -264,16 +269,17 @@ class EventContext:
``rejected`` is set. ``rejected`` is set.
Returns: Returns:
dict[(str, str), str]|None: Returns None if we haven't cached the Returns None if we haven't cached the state or if state_group is None
state or if state_group is None, which happens when the associated (which happens when the associated event is an outlier).
event is an outlier.
Otherwise, returns the the current state IDs.
""" """
if self.rejected: if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event") raise RuntimeError("Attempt to access state_ids of rejected event")
return self._current_state_ids return self._current_state_ids
async def _ensure_fetched(self): async def _ensure_fetched(self) -> None:
return None return None
@ -285,46 +291,46 @@ class _AsyncEventContextImpl(EventContext):
Attributes: Attributes:
_storage (Storage) _storage
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have _fetching_state_deferred: Resolves when *_state_ids have been calculated.
been calculated. None if we haven't started calculating yet None if we haven't started calculating yet
_event_type (str): The type of the event the context is associated with. _event_type: The type of the event the context is associated with.
_event_state_key (str): The state_key of the event the context is _event_state_key: The state_key of the event the context is associated with.
associated with.
_prev_state_id (str|None): If the event associated with the context is _prev_state_id: If the event associated with the context is a state event,
a state event, then `_prev_state_id` is the event_id of the state then `_prev_state_id` is the event_id of the state that was replaced.
that was replaced.
""" """
# This needs to have a default as we're inheriting # This needs to have a default as we're inheriting
_storage = attr.ib(default=None) _storage: "Storage" = attr.ib(default=None)
_prev_state_id = attr.ib(default=None) _prev_state_id: Optional[str] = attr.ib(default=None)
_event_type = attr.ib(default=None) _event_type: str = attr.ib(default=None)
_event_state_key = attr.ib(default=None) _event_state_key: Optional[str] = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None) _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
async def _ensure_fetched(self): async def _ensure_fetched(self) -> None:
if not self._fetching_state_deferred: if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(self._fill_out_state) self._fetching_state_deferred = run_in_background(self._fill_out_state)
return await make_deferred_yieldable(self._fetching_state_deferred) await make_deferred_yieldable(self._fetching_state_deferred)
async def _fill_out_state(self): async def _fill_out_state(self) -> None:
"""Called to populate the _current_state_ids and _prev_state_ids """Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database. attributes by loading from the database.
""" """
if self.state_group is None: if self.state_group is None:
return return
self._current_state_ids = await self._storage.state.get_state_ids_for_group( current_state_ids = await self._storage.state.get_state_ids_for_group(
self.state_group self.state_group
) )
# Set this separately so mypy knows current_state_ids is not None.
self._current_state_ids = current_state_ids
if self._event_state_key is not None: if self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids) self._prev_state_ids = dict(current_state_ids)
key = (self._event_type, self._event_state_key) key = (self._event_type, self._event_state_key)
if self._prev_state_id: if self._prev_state_id:
@ -332,10 +338,12 @@ class _AsyncEventContextImpl(EventContext):
else: else:
self._prev_state_ids.pop(key, None) self._prev_state_ids.pop(key, None)
else: else:
self._prev_state_ids = self._current_state_ids self._prev_state_ids = current_state_ids
def _encode_state_dict(state_dict): def _encode_state_dict(
state_dict: Optional[StateMap[str]],
) -> Optional[List[Tuple[str, str, str]]]:
"""Since dicts of (type, state_key) -> event_id cannot be serialized in """Since dicts of (type, state_key) -> event_id cannot be serialized in
JSON we need to convert them to a form that can. JSON we need to convert them to a form that can.
""" """
@ -345,7 +353,9 @@ def _encode_state_dict(state_dict):
return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()] return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]
def _decode_state_dict(input): def _decode_state_dict(
input: Optional[List[Tuple[str, str, str]]]
) -> Optional[StateMap[str]]:
"""Decodes a state dict encoded using `_encode_state_dict` above""" """Decodes a state dict encoded using `_encode_state_dict` above"""
if input is None: if input is None:
return None return None

View File

@ -77,7 +77,7 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
] ]
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
"""Wrapper that loads spam checkers configured using the old configuration, and """Wrapper that loads spam checkers configured using the old configuration, and
registers the spam checker hooks they implement. registers the spam checker hooks they implement.
""" """
@ -129,9 +129,9 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
request_info: Collection[Tuple[str, str]], request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str], auth_provider_id: Optional[str],
) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]: ) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
# We've already made sure f is not None above, but mypy doesn't # Assertion required because mypy can't prove we won't
# do well across function boundaries so we need to tell it f is # change `f` back to `None`. See
# definitely not None. # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None assert f is not None
return f( return f(
@ -146,9 +146,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
"Bad signature for callback check_registration_for_spam", "Bad signature for callback check_registration_for_spam",
) )
def run(*args, **kwargs): def run(*args: Any, **kwargs: Any) -> Awaitable:
# mypy doesn't do well across function boundaries so we need to tell it # Assertion required because mypy can't prove we won't change `f`
# wrapped_func is definitely not None. # back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert wrapped_func is not None assert wrapped_func is not None
return maybe_awaitable(wrapped_func(*args, **kwargs)) return maybe_awaitable(wrapped_func(*args, **kwargs))
@ -165,7 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
class SpamChecker: class SpamChecker:
def __init__(self): def __init__(self) -> None:
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = [] self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = [] self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
@ -209,7 +210,7 @@ class SpamChecker:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None, ] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None, check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
): ) -> None:
"""Register callbacks from module for each hook.""" """Register callbacks from module for each hook."""
if check_event_for_spam is not None: if check_event_for_spam is not None:
self._check_event_for_spam_callbacks.append(check_event_for_spam) self._check_event_for_spam_callbacks.append(check_event_for_spam)
@ -275,7 +276,9 @@ class SpamChecker:
return False return False
async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool): async def user_may_join_room(
self, user_id: str, room_id: str, is_invited: bool
) -> bool:
"""Checks if a given users is allowed to join a room. """Checks if a given users is allowed to join a room.
Not called when a user creates a room. Not called when a user creates a room.
@ -285,7 +288,7 @@ class SpamChecker:
is_invited: Whether the user is invited into the room is_invited: Whether the user is invited into the room
Returns: Returns:
bool: Whether the user may join the room Whether the user may join the room
""" """
for callback in self._user_may_join_room_callbacks: for callback in self._user_may_join_room_callbacks:
if await callback(user_id, room_id, is_invited) is False: if await callback(user_id, room_id, is_invited) is False:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events import EventBase from synapse.events import EventBase
@ -38,7 +38,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
] ]
def load_legacy_third_party_event_rules(hs: "HomeServer"): def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
"""Wrapper that loads a third party event rules module configured using the old """Wrapper that loads a third party event rules module configured using the old
configuration, and registers the hooks they implement. configuration, and registers the hooks they implement.
""" """
@ -77,9 +77,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
event: EventBase, event: EventBase,
state_events: StateMap[EventBase], state_events: StateMap[EventBase],
) -> Tuple[bool, Optional[dict]]: ) -> Tuple[bool, Optional[dict]]:
# We've already made sure f is not None above, but mypy doesn't do well # Assertion required because mypy can't prove we won't change
# across function boundaries so we need to tell it f is definitely not # `f` back to `None`. See
# None. # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None assert f is not None
res = await f(event, state_events) res = await f(event, state_events)
@ -98,9 +98,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
async def wrap_on_create_room( async def wrap_on_create_room(
requester: Requester, config: dict, is_requester_admin: bool requester: Requester, config: dict, is_requester_admin: bool
) -> None: ) -> None:
# We've already made sure f is not None above, but mypy doesn't do well # Assertion required because mypy can't prove we won't change
# across function boundaries so we need to tell it f is definitely not # `f` back to `None`. See
# None. # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None assert f is not None
res = await f(requester, config, is_requester_admin) res = await f(requester, config, is_requester_admin)
@ -112,9 +112,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
return wrap_on_create_room return wrap_on_create_room
def run(*args, **kwargs): def run(*args: Any, **kwargs: Any) -> Awaitable:
# mypy doesn't do well across function boundaries so we need to tell it # Assertion required because mypy can't prove we won't change `f`
# f is definitely not None. # back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None assert f is not None
return maybe_awaitable(f(*args, **kwargs)) return maybe_awaitable(f(*args, **kwargs))
@ -162,7 +163,7 @@ class ThirdPartyEventRules:
check_visibility_can_be_modified: Optional[ check_visibility_can_be_modified: Optional[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
] = None, ] = None,
): ) -> None:
"""Register callbacks from modules for each hook.""" """Register callbacks from modules for each hook."""
if check_event_allowed is not None: if check_event_allowed is not None:
self._check_event_allowed_callbacks.append(check_event_allowed) self._check_event_allowed_callbacks.append(check_event_allowed)

View File

@ -13,18 +13,32 @@
# limitations under the License. # limitations under the License.
import collections.abc import collections.abc
import re import re
from typing import Any, Mapping, Union from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Union,
)
from frozendict import frozendict from frozendict import frozendict
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion from synapse.api.room_versions import RoomVersion
from synapse.types import JsonDict
from synapse.util.async_helpers import yieldable_gather_results from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.frozenutils import unfreeze from synapse.util.frozenutils import unfreeze
from . import EventBase from . import EventBase
if TYPE_CHECKING:
from synapse.server import HomeServer
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded # (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'. # by a match for 'stuff'.
@ -65,7 +79,7 @@ def prune_event(event: EventBase) -> EventBase:
return pruned_event return pruned_event
def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict: def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
"""Redacts the event_dict in the same way as `prune_event`, except it """Redacts the event_dict in the same way as `prune_event`, except it
operates on dicts rather than event objects operates on dicts rather than event objects
@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
new_content = {} new_content = {}
def add_fields(*fields): def add_fields(*fields: str) -> None:
for field in fields: for field in fields:
if field in event_dict["content"]: if field in event_dict["content"]:
new_content[field] = event_dict["content"][field] new_content[field] = event_dict["content"][field]
@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
allowed_fields["content"] = new_content allowed_fields["content"] = new_content
unsigned = {} unsigned: JsonDict = {}
allowed_fields["unsigned"] = unsigned allowed_fields["unsigned"] = unsigned
event_unsigned = event_dict.get("unsigned", {}) event_unsigned = event_dict.get("unsigned", {})
@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
return allowed_fields return allowed_fields
def _copy_field(src, dst, field): def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
"""Copy the field in 'src' to 'dst'. """Copy the field in 'src' to 'dst'.
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"] For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
then dst={"foo":{"bar":5}}. then dst={"foo":{"bar":5}}.
Args: Args:
src(dict): The dict to read from. src: The dict to read from.
dst(dict): The dict to modify. dst: The dict to modify.
field(list<str>): List of keys to drill down to in 'src'. field: List of keys to drill down to in 'src'.
""" """
if len(field) == 0: # this should be impossible if len(field) == 0: # this should be impossible
return return
@ -205,7 +219,7 @@ def _copy_field(src, dst, field):
sub_out_dict[key_to_move] = sub_dict[key_to_move] sub_out_dict[key_to_move] = sub_dict[key_to_move]
def only_fields(dictionary, fields): def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
"""Return a new dict with only the fields in 'dictionary' which are present """Return a new dict with only the fields in 'dictionary' which are present
in 'fields'. in 'fields'.
@ -215,11 +229,11 @@ def only_fields(dictionary, fields):
A literal '.' character in a field name may be escaped using a '\'. A literal '.' character in a field name may be escaped using a '\'.
Args: Args:
dictionary(dict): The dictionary to read from. dictionary: The dictionary to read from.
fields(list<str>): A list of fields to copy over. Only shallow refs are fields: A list of fields to copy over. Only shallow refs are
taken. taken.
Returns: Returns:
dict: A new dictionary with only the given fields. If fields was empty, A new dictionary with only the given fields. If fields was empty,
the same dictionary is returned. the same dictionary is returned.
""" """
if len(fields) == 0: if len(fields) == 0:
@ -235,17 +249,17 @@ def only_fields(dictionary, fields):
[f.replace(r"\.", r".") for f in field_array] for field_array in split_fields [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
] ]
output = {} output: JsonDict = {}
for field_array in split_fields: for field_array in split_fields:
_copy_field(dictionary, output, field_array) _copy_field(dictionary, output, field_array)
return output return output
def format_event_raw(d): def format_event_raw(d: JsonDict) -> JsonDict:
return d return d
def format_event_for_client_v1(d): def format_event_for_client_v1(d: JsonDict) -> JsonDict:
d = format_event_for_client_v2(d) d = format_event_for_client_v2(d)
sender = d.get("sender") sender = d.get("sender")
@ -267,7 +281,7 @@ def format_event_for_client_v1(d):
return d return d
def format_event_for_client_v2(d): def format_event_for_client_v2(d: JsonDict) -> JsonDict:
drop_keys = ( drop_keys = (
"auth_events", "auth_events",
"prev_events", "prev_events",
@ -282,37 +296,37 @@ def format_event_for_client_v2(d):
return d return d
def format_event_for_client_v2_without_room_id(d): def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
d = format_event_for_client_v2(d) d = format_event_for_client_v2(d)
d.pop("room_id", None) d.pop("room_id", None)
return d return d
def serialize_event( def serialize_event(
e, e: Union[JsonDict, EventBase],
time_now_ms, time_now_ms: int,
as_client_event=True, as_client_event: bool = True,
event_format=format_event_for_client_v1, event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
token_id=None, token_id: Optional[str] = None,
only_event_fields=None, only_event_fields: Optional[List[str]] = None,
include_stripped_room_state=False, include_stripped_room_state: bool = False,
): ) -> JsonDict:
"""Serialize event for clients """Serialize event for clients
Args: Args:
e (EventBase) e
time_now_ms (int) time_now_ms
as_client_event (bool) as_client_event
event_format event_format
token_id token_id
only_event_fields only_event_fields
include_stripped_room_state (bool): Some events can have stripped room state include_stripped_room_state: Some events can have stripped room state
stored in the `unsigned` field. This is required for invite and knock stored in the `unsigned` field. This is required for invite and knock
functionality. If this option is False, that state will be removed from the functionality. If this option is False, that state will be removed from the
event before it is returned. Otherwise, it will be kept. event before it is returned. Otherwise, it will be kept.
Returns: Returns:
dict The serialized event dictionary.
""" """
# FIXME(erikj): To handle the case of presence events and the like # FIXME(erikj): To handle the case of presence events and the like
@ -369,25 +383,29 @@ class EventClientSerializer:
clients. clients.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.experimental_msc1849_support_enabled = ( self.experimental_msc1849_support_enabled = (
hs.config.server.experimental_msc1849_support_enabled hs.config.server.experimental_msc1849_support_enabled
) )
async def serialize_event( async def serialize_event(
self, event, time_now, bundle_aggregations=True, **kwargs self,
): event: Union[JsonDict, EventBase],
time_now: int,
bundle_aggregations: bool = True,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event. """Serializes a single event.
Args: Args:
event (EventBase) event
time_now (int): The current time in milliseconds time_now: The current time in milliseconds
bundle_aggregations (bool): Whether to bundle in related events bundle_aggregations: Whether to bundle in related events
**kwargs: Arguments to pass to `serialize_event` **kwargs: Arguments to pass to `serialize_event`
Returns: Returns:
dict: The serialized event The serialized event
""" """
# To handle the case of presence events and the like # To handle the case of presence events and the like
if not isinstance(event, EventBase): if not isinstance(event, EventBase):
@ -448,25 +466,27 @@ class EventClientSerializer:
return serialized_event return serialized_event
def serialize_events(self, events, time_now, **kwargs): async def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
) -> List[JsonDict]:
"""Serializes multiple events. """Serializes multiple events.
Args: Args:
event (iter[EventBase]) event
time_now (int): The current time in milliseconds time_now: The current time in milliseconds
**kwargs: Arguments to pass to `serialize_event` **kwargs: Arguments to pass to `serialize_event`
Returns: Returns:
Deferred[list[dict]]: The list of serialized events The list of serialized events
""" """
return yieldable_gather_results( return await yieldable_gather_results(
self.serialize_event, events, time_now=time_now, **kwargs self.serialize_event, events, time_now=time_now, **kwargs
) )
def copy_power_levels_contents( def copy_power_levels_contents(
old_power_levels: Mapping[str, Union[int, Mapping[str, int]]] old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
): ) -> Dict[str, Union[int, Dict[str, int]]]:
"""Copy the content of a power_levels event, unfreezing frozendicts along the way """Copy the content of a power_levels event, unfreezing frozendicts along the way
Raises: Raises:
@ -475,7 +495,7 @@ def copy_power_levels_contents(
if not isinstance(old_power_levels, collections.abc.Mapping): if not isinstance(old_power_levels, collections.abc.Mapping):
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,)) raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
power_levels = {} power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
for k, v in old_power_levels.items(): for k, v in old_power_levels.items():
if isinstance(v, int): if isinstance(v, int):
@ -483,7 +503,8 @@ def copy_power_levels_contents(
continue continue
if isinstance(v, collections.abc.Mapping): if isinstance(v, collections.abc.Mapping):
power_levels[k] = h = {} h: Dict[str, int] = {}
power_levels[k] = h
for k1, v1 in v.items(): for k1, v1 in v.items():
# we should only have one level of nesting # we should only have one level of nesting
if not isinstance(v1, int): if not isinstance(v1, int):
@ -498,7 +519,7 @@ def copy_power_levels_contents(
return power_levels return power_levels
def validate_canonicaljson(value: Any): def validate_canonicaljson(value: Any) -> None:
""" """
Ensure that the JSON object is valid according to the rules of canonical JSON. Ensure that the JSON object is valid according to the rules of canonical JSON.

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections.abc import collections.abc
from typing import Union from typing import Iterable, Union
import jsonschema import jsonschema
@ -28,11 +28,11 @@ from synapse.events.utils import (
validate_canonicaljson, validate_canonicaljson,
) )
from synapse.federation.federation_server import server_matches_acl_event from synapse.federation.federation_server import server_matches_acl_event
from synapse.types import EventID, RoomID, UserID from synapse.types import EventID, JsonDict, RoomID, UserID
class EventValidator: class EventValidator:
def validate_new(self, event: EventBase, config: HomeServerConfig): def validate_new(self, event: EventBase, config: HomeServerConfig) -> None:
"""Validates the event has roughly the right format """Validates the event has roughly the right format
Args: Args:
@ -116,7 +116,7 @@ class EventValidator:
errcode=Codes.BAD_JSON, errcode=Codes.BAD_JSON,
) )
def _validate_retention(self, event: EventBase): def _validate_retention(self, event: EventBase) -> None:
"""Checks that an event that defines the retention policy for a room respects the """Checks that an event that defines the retention policy for a room respects the
format enforced by the spec. format enforced by the spec.
@ -156,7 +156,7 @@ class EventValidator:
errcode=Codes.BAD_JSON, errcode=Codes.BAD_JSON,
) )
def validate_builder(self, event: Union[EventBase, EventBuilder]): def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None:
"""Validates that the builder/event has roughly the right format. Only """Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the checks values that we expect a proto event to have, rather than all the
fields an event would have fields an event would have
@ -204,14 +204,14 @@ class EventValidator:
self._ensure_state_event(event) self._ensure_state_event(event)
def _ensure_strings(self, d, keys): def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
for s in keys: for s in keys:
if s not in d: if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,)) raise SynapseError(400, "'%s' not in content" % (s,))
if not isinstance(d[s], str): if not isinstance(d[s], str):
raise SynapseError(400, "'%s' not a string type" % (s,)) raise SynapseError(400, "'%s' not a string type" % (s,))
def _ensure_state_event(self, event): def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None:
if not event.is_state(): if not event.is_state():
raise SynapseError(400, "'%s' must be state events" % (event.type,)) raise SynapseError(400, "'%s' must be state events" % (event.type,))
@ -244,7 +244,9 @@ POWER_LEVELS_SCHEMA = {
} }
def _create_power_level_validator(): # This could return something newer than Draft 7, but that's the current "latest"
# validator.
def _create_power_level_validator() -> jsonschema.Draft7Validator:
validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA) validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
# by default jsonschema does not consider a frozendict to be an object so # by default jsonschema does not consider a frozendict to be an object so

View File

@ -465,17 +465,35 @@ class RoomCreationHandler:
# the room has been created # the room has been created
# Calculate the minimum power level needed to clone the room # Calculate the minimum power level needed to clone the room
event_power_levels = power_levels.get("events", {}) event_power_levels = power_levels.get("events", {})
if not isinstance(event_power_levels, dict):
event_power_levels = {}
state_default = power_levels.get("state_default", 50) state_default = power_levels.get("state_default", 50)
try:
state_default_int = int(state_default) # type: ignore[arg-type]
except (TypeError, ValueError):
state_default_int = 50
ban = power_levels.get("ban", 50) ban = power_levels.get("ban", 50)
needed_power_level = max(state_default, ban, max(event_power_levels.values())) try:
ban = int(ban) # type: ignore[arg-type]
except (TypeError, ValueError):
ban = 50
needed_power_level = max(
state_default_int, ban, max(event_power_levels.values())
)
# Get the user's current power level, this matches the logic in get_user_power_level, # Get the user's current power level, this matches the logic in get_user_power_level,
# but without the entire state map. # but without the entire state map.
user_power_levels = power_levels.setdefault("users", {}) user_power_levels = power_levels.setdefault("users", {})
if not isinstance(user_power_levels, dict):
user_power_levels = {}
users_default = power_levels.get("users_default", 0) users_default = power_levels.get("users_default", 0)
current_power_level = user_power_levels.get(user_id, users_default) current_power_level = user_power_levels.get(user_id, users_default)
try:
current_power_level_int = int(current_power_level) # type: ignore[arg-type]
except (TypeError, ValueError):
current_power_level_int = 0
# Raise the requester's power level in the new room if necessary # Raise the requester's power level in the new room if necessary
if current_power_level < needed_power_level: if current_power_level_int < needed_power_level:
user_power_levels[user_id] = needed_power_level user_power_levels[user_id] = needed_power_level
await self._send_events_for_new_room( await self._send_events_for_new_room(

View File

@ -232,12 +232,12 @@ class RelationPaginationServlet(RestServlet):
# Similarly, we don't allow relations to be applied to relations, so we # Similarly, we don't allow relations to be applied to relations, so we
# return the original relations without any aggregations on top of them # return the original relations without any aggregations on top of them
# here. # here.
events = await self._event_serializer.serialize_events( serialized_events = await self._event_serializer.serialize_events(
events, now, bundle_aggregations=False events, now, bundle_aggregations=False
) )
return_value = pagination_chunk.to_dict() return_value = pagination_chunk.to_dict()
return_value["chunk"] = events return_value["chunk"] = serialized_events
return_value["original_event"] = original_event return_value["original_event"] = original_event
return 200, return_value return 200, return_value
@ -416,10 +416,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
) )
now = self.clock.time_msec() now = self.clock.time_msec()
events = await self._event_serializer.serialize_events(events, now) serialized_events = await self._event_serializer.serialize_events(events, now)
return_value = result.to_dict() return_value = result.to_dict()
return_value["chunk"] = events return_value["chunk"] = serialized_events
return 200, return_value return 200, return_value