Add type hints to synapse.events.*. (#11066)

Except `synapse/events/__init__.py`, which will be done in a follow-up.
This commit is contained in:
Patrick Cloke 2021-10-13 07:24:07 -04:00 committed by GitHub
parent cdd308845b
commit 1f9d0b8a7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 208 additions and 145 deletions

1
changelog.d/11066.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `synapse.events`.

View File

@ -22,8 +22,11 @@ files =
synapse/crypto,
synapse/event_auth.py,
synapse/events/builder.py,
synapse/events/presence_router.py,
synapse/events/snapshot.py,
synapse/events/spamcheck.py,
synapse/events/third_party_rules.py,
synapse/events/utils.py,
synapse/events/validator.py,
synapse/federation,
synapse/groups,
@ -96,6 +99,9 @@ files =
tests/util/test_itertools.py,
tests/util/test_stream_change_cache.py
[mypy-synapse.events.*]
disallow_untyped_defs = True
[mypy-synapse.handlers.*]
disallow_untyped_defs = True

View File

@ -90,13 +90,13 @@ class EventBuilder:
)
@property
def state_key(self):
def state_key(self) -> str:
if self._state_key is not None:
return self._state_key
raise AttributeError("state_key")
def is_state(self):
def is_state(self) -> bool:
return self._state_key is not None
async def build(

View File

@ -14,6 +14,7 @@
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
@ -33,14 +34,13 @@ if TYPE_CHECKING:
GET_USERS_FOR_STATES_CALLBACK = Callable[
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
]
GET_INTERESTED_USERS_CALLBACK = Callable[
[str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
]
# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
logger = logging.getLogger(__name__)
def load_legacy_presence_router(hs: "HomeServer"):
def load_legacy_presence_router(hs: "HomeServer") -> None:
"""Wrapper that loads a presence router module configured using the old
configuration, and registers the hooks they implement.
"""
@ -69,9 +69,10 @@ def load_legacy_presence_router(hs: "HomeServer"):
if f is None:
return None
def run(*args, **kwargs):
# mypy doesn't do well across function boundaries so we need to tell it
# f is definitely not None.
def run(*args: Any, **kwargs: Any) -> Awaitable:
# Assertion required because mypy can't prove we won't change `f`
# back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
@ -104,7 +105,7 @@ class PresenceRouter:
self,
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
):
) -> None:
# PresenceRouter modules are required to implement both of these methods
# or neither of them as they are assumed to act in a complementary manner
paired_methods = [get_users_for_states, get_interested_users]
@ -142,7 +143,7 @@ class PresenceRouter:
# Don't include any extra destinations for presence updates
return {}
users_for_states = {}
users_for_states: Dict[str, Set[UserPresenceState]] = {}
# run all the callbacks for get_users_for_states and combine the results
for callback in self._get_users_for_states_callbacks:
try:
@ -171,7 +172,7 @@ class PresenceRouter:
return users_for_states
async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
"""
Retrieve a list of users that `user_id` is interested in receiving the
presence of. This will be in addition to those they share a room with.

View File

@ -11,17 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import attr
from frozendict import frozendict
from twisted.internet.defer import Deferred
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap
from synapse.types import JsonDict, StateMap
if TYPE_CHECKING:
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore
@ -112,13 +115,13 @@ class EventContext:
@staticmethod
def with_state(
state_group,
state_group_before_event,
current_state_ids,
prev_state_ids,
prev_group=None,
delta_ids=None,
):
state_group: Optional[int],
state_group_before_event: Optional[int],
current_state_ids: Optional[StateMap[str]],
prev_state_ids: Optional[StateMap[str]],
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
) -> "EventContext":
return EventContext(
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
@ -129,22 +132,22 @@ class EventContext:
)
@staticmethod
def for_outlier():
def for_outlier() -> "EventContext":
"""Return an EventContext instance suitable for persisting an outlier event"""
return EventContext(
current_state_ids={},
prev_state_ids={},
)
async def serialize(self, event: EventBase, store: "DataStore") -> dict:
async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
Args:
event (FrozenEvent): The event that this context relates to
event: The event that this context relates to
Returns:
dict
The serialized event.
"""
# We don't serialize the full state dicts, instead they get pulled out
@ -170,17 +173,16 @@ class EventContext:
}
@staticmethod
def deserialize(storage, input):
def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
"""Converts a dict that was produced by `serialize` back into a
EventContext.
Args:
storage (Storage): Used to convert AS ID to AS object and fetch
state.
input (dict): A dict produced by `serialize`
storage: Used to convert AS ID to AS object and fetch state.
input: A dict produced by `serialize`
Returns:
EventContext
The event context.
"""
context = _AsyncEventContextImpl(
# We use the state_group and prev_state_id stuff to pull the
@ -241,22 +243,25 @@ class EventContext:
await self._ensure_fetched()
return self._current_state_ids
async def get_prev_state_ids(self):
async def get_prev_state_ids(self) -> StateMap[str]:
"""
Gets the room state map, excluding this event.
For a non-state event, this will be the same as get_current_state_ids().
Returns:
dict[(str, str), str]|None: Returns None if state_group
is None, which happens when the associated event is an outlier.
Returns {} if state_group is None, which happens when the associated
event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
await self._ensure_fetched()
# There *should* be previous state IDs now.
assert self._prev_state_ids is not None
return self._prev_state_ids
def get_cached_current_state_ids(self):
def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
"""Gets the current state IDs if we have them already cached.
It is an error to access this for a rejected event, since rejected state should
@ -264,16 +269,17 @@ class EventContext:
``rejected`` is set.
Returns:
dict[(str, str), str]|None: Returns None if we haven't cached the
state or if state_group is None, which happens when the associated
event is an outlier.
Returns None if we haven't cached the state or if state_group is None
(which happens when the associated event is an outlier).
Otherwise, returns the the current state IDs.
"""
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")
return self._current_state_ids
async def _ensure_fetched(self):
async def _ensure_fetched(self) -> None:
return None
@ -285,46 +291,46 @@ class _AsyncEventContextImpl(EventContext):
Attributes:
_storage (Storage)
_storage
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet
_fetching_state_deferred: Resolves when *_state_ids have been calculated.
None if we haven't started calculating yet
_event_type (str): The type of the event the context is associated with.
_event_type: The type of the event the context is associated with.
_event_state_key (str): The state_key of the event the context is
associated with.
_event_state_key: The state_key of the event the context is associated with.
_prev_state_id (str|None): If the event associated with the context is
a state event, then `_prev_state_id` is the event_id of the state
that was replaced.
_prev_state_id: If the event associated with the context is a state event,
then `_prev_state_id` is the event_id of the state that was replaced.
"""
# This needs to have a default as we're inheriting
_storage = attr.ib(default=None)
_prev_state_id = attr.ib(default=None)
_event_type = attr.ib(default=None)
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)
_storage: "Storage" = attr.ib(default=None)
_prev_state_id: Optional[str] = attr.ib(default=None)
_event_type: str = attr.ib(default=None)
_event_state_key: Optional[str] = attr.ib(default=None)
_fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
async def _ensure_fetched(self):
async def _ensure_fetched(self) -> None:
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(self._fill_out_state)
return await make_deferred_yieldable(self._fetching_state_deferred)
await make_deferred_yieldable(self._fetching_state_deferred)
async def _fill_out_state(self):
async def _fill_out_state(self) -> None:
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return
self._current_state_ids = await self._storage.state.get_state_ids_for_group(
current_state_ids = await self._storage.state.get_state_ids_for_group(
self.state_group
)
# Set this separately so mypy knows current_state_ids is not None.
self._current_state_ids = current_state_ids
if self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
self._prev_state_ids = dict(current_state_ids)
key = (self._event_type, self._event_state_key)
if self._prev_state_id:
@ -332,10 +338,12 @@ class _AsyncEventContextImpl(EventContext):
else:
self._prev_state_ids.pop(key, None)
else:
self._prev_state_ids = self._current_state_ids
self._prev_state_ids = current_state_ids
def _encode_state_dict(state_dict):
def _encode_state_dict(
state_dict: Optional[StateMap[str]],
) -> Optional[List[Tuple[str, str, str]]]:
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
JSON we need to convert them to a form that can.
"""
@ -345,7 +353,9 @@ def _encode_state_dict(state_dict):
return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]
def _decode_state_dict(input):
def _decode_state_dict(
input: Optional[List[Tuple[str, str, str]]]
) -> Optional[StateMap[str]]:
"""Decodes a state dict encoded using `_encode_state_dict` above"""
if input is None:
return None

View File

@ -77,7 +77,7 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
]
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
"""Wrapper that loads spam checkers configured using the old configuration, and
registers the spam checker hooks they implement.
"""
@ -129,9 +129,9 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
# We've already made sure f is not None above, but mypy doesn't
# do well across function boundaries so we need to tell it f is
# definitely not None.
# Assertion required because mypy can't prove we won't
# change `f` back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
return f(
@ -146,9 +146,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
"Bad signature for callback check_registration_for_spam",
)
def run(*args, **kwargs):
# mypy doesn't do well across function boundaries so we need to tell it
# wrapped_func is definitely not None.
def run(*args: Any, **kwargs: Any) -> Awaitable:
# Assertion required because mypy can't prove we won't change `f`
# back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert wrapped_func is not None
return maybe_awaitable(wrapped_func(*args, **kwargs))
@ -165,7 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
class SpamChecker:
def __init__(self):
def __init__(self) -> None:
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
@ -209,7 +210,7 @@ class SpamChecker:
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
] = None,
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
):
) -> None:
"""Register callbacks from module for each hook."""
if check_event_for_spam is not None:
self._check_event_for_spam_callbacks.append(check_event_for_spam)
@ -275,7 +276,9 @@ class SpamChecker:
return False
async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool):
async def user_may_join_room(
self, user_id: str, room_id: str, is_invited: bool
) -> bool:
"""Checks if a given users is allowed to join a room.
Not called when a user creates a room.
@ -285,7 +288,7 @@ class SpamChecker:
is_invited: Whether the user is invited into the room
Returns:
bool: Whether the user may join the room
Whether the user may join the room
"""
for callback in self._user_may_join_room_callbacks:
if await callback(user_id, room_id, is_invited) is False:

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.events import EventBase
@ -38,7 +38,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
]
def load_legacy_third_party_event_rules(hs: "HomeServer"):
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
"""Wrapper that loads a third party event rules module configured using the old
configuration, and registers the hooks they implement.
"""
@ -77,9 +77,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
event: EventBase,
state_events: StateMap[EventBase],
) -> Tuple[bool, Optional[dict]]:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
# Assertion required because mypy can't prove we won't change
# `f` back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
res = await f(event, state_events)
@ -98,9 +98,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
async def wrap_on_create_room(
requester: Requester, config: dict, is_requester_admin: bool
) -> None:
# We've already made sure f is not None above, but mypy doesn't do well
# across function boundaries so we need to tell it f is definitely not
# None.
# Assertion required because mypy can't prove we won't change
# `f` back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
res = await f(requester, config, is_requester_admin)
@ -112,9 +112,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
return wrap_on_create_room
def run(*args, **kwargs):
# mypy doesn't do well across function boundaries so we need to tell it
# f is definitely not None.
def run(*args: Any, **kwargs: Any) -> Awaitable:
# Assertion required because mypy can't prove we won't change `f`
# back to `None`. See
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
assert f is not None
return maybe_awaitable(f(*args, **kwargs))
@ -162,7 +163,7 @@ class ThirdPartyEventRules:
check_visibility_can_be_modified: Optional[
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
] = None,
):
) -> None:
"""Register callbacks from modules for each hook."""
if check_event_allowed is not None:
self._check_event_allowed_callbacks.append(check_event_allowed)

View File

@ -13,18 +13,32 @@
# limitations under the License.
import collections.abc
import re
from typing import Any, Mapping, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Union,
)
from frozendict import frozendict
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.types import JsonDict
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.frozenutils import unfreeze
from . import EventBase
if TYPE_CHECKING:
from synapse.server import HomeServer
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
@ -65,7 +79,7 @@ def prune_event(event: EventBase) -> EventBase:
return pruned_event
def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
"""Redacts the event_dict in the same way as `prune_event`, except it
operates on dicts rather than event objects
@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
new_content = {}
def add_fields(*fields):
def add_fields(*fields: str) -> None:
for field in fields:
if field in event_dict["content"]:
new_content[field] = event_dict["content"][field]
@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
allowed_fields["content"] = new_content
unsigned = {}
unsigned: JsonDict = {}
allowed_fields["unsigned"] = unsigned
event_unsigned = event_dict.get("unsigned", {})
@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
return allowed_fields
def _copy_field(src, dst, field):
def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
"""Copy the field in 'src' to 'dst'.
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
then dst={"foo":{"bar":5}}.
Args:
src(dict): The dict to read from.
dst(dict): The dict to modify.
field(list<str>): List of keys to drill down to in 'src'.
src: The dict to read from.
dst: The dict to modify.
field: List of keys to drill down to in 'src'.
"""
if len(field) == 0: # this should be impossible
return
@ -205,7 +219,7 @@ def _copy_field(src, dst, field):
sub_out_dict[key_to_move] = sub_dict[key_to_move]
def only_fields(dictionary, fields):
def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
"""Return a new dict with only the fields in 'dictionary' which are present
in 'fields'.
@ -215,11 +229,11 @@ def only_fields(dictionary, fields):
A literal '.' character in a field name may be escaped using a '\'.
Args:
dictionary(dict): The dictionary to read from.
fields(list<str>): A list of fields to copy over. Only shallow refs are
dictionary: The dictionary to read from.
fields: A list of fields to copy over. Only shallow refs are
taken.
Returns:
dict: A new dictionary with only the given fields. If fields was empty,
A new dictionary with only the given fields. If fields was empty,
the same dictionary is returned.
"""
if len(fields) == 0:
@ -235,17 +249,17 @@ def only_fields(dictionary, fields):
[f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
]
output = {}
output: JsonDict = {}
for field_array in split_fields:
_copy_field(dictionary, output, field_array)
return output
def format_event_raw(d):
def format_event_raw(d: JsonDict) -> JsonDict:
return d
def format_event_for_client_v1(d):
def format_event_for_client_v1(d: JsonDict) -> JsonDict:
d = format_event_for_client_v2(d)
sender = d.get("sender")
@ -267,7 +281,7 @@ def format_event_for_client_v1(d):
return d
def format_event_for_client_v2(d):
def format_event_for_client_v2(d: JsonDict) -> JsonDict:
drop_keys = (
"auth_events",
"prev_events",
@ -282,37 +296,37 @@ def format_event_for_client_v2(d):
return d
def format_event_for_client_v2_without_room_id(d):
def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
d = format_event_for_client_v2(d)
d.pop("room_id", None)
return d
def serialize_event(
e,
time_now_ms,
as_client_event=True,
event_format=format_event_for_client_v1,
token_id=None,
only_event_fields=None,
include_stripped_room_state=False,
):
e: Union[JsonDict, EventBase],
time_now_ms: int,
as_client_event: bool = True,
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
token_id: Optional[str] = None,
only_event_fields: Optional[List[str]] = None,
include_stripped_room_state: bool = False,
) -> JsonDict:
"""Serialize event for clients
Args:
e (EventBase)
time_now_ms (int)
as_client_event (bool)
e
time_now_ms
as_client_event
event_format
token_id
only_event_fields
include_stripped_room_state (bool): Some events can have stripped room state
include_stripped_room_state: Some events can have stripped room state
stored in the `unsigned` field. This is required for invite and knock
functionality. If this option is False, that state will be removed from the
event before it is returned. Otherwise, it will be kept.
Returns:
dict
The serialized event dictionary.
"""
# FIXME(erikj): To handle the case of presence events and the like
@ -369,25 +383,29 @@ class EventClientSerializer:
clients.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.experimental_msc1849_support_enabled = (
hs.config.server.experimental_msc1849_support_enabled
)
async def serialize_event(
self, event, time_now, bundle_aggregations=True, **kwargs
):
self,
event: Union[JsonDict, EventBase],
time_now: int,
bundle_aggregations: bool = True,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
Args:
event (EventBase)
time_now (int): The current time in milliseconds
bundle_aggregations (bool): Whether to bundle in related events
event
time_now: The current time in milliseconds
bundle_aggregations: Whether to bundle in related events
**kwargs: Arguments to pass to `serialize_event`
Returns:
dict: The serialized event
The serialized event
"""
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
@ -448,25 +466,27 @@ class EventClientSerializer:
return serialized_event
def serialize_events(self, events, time_now, **kwargs):
async def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
) -> List[JsonDict]:
"""Serializes multiple events.
Args:
event (iter[EventBase])
time_now (int): The current time in milliseconds
event
time_now: The current time in milliseconds
**kwargs: Arguments to pass to `serialize_event`
Returns:
Deferred[list[dict]]: The list of serialized events
The list of serialized events
"""
return yieldable_gather_results(
return await yieldable_gather_results(
self.serialize_event, events, time_now=time_now, **kwargs
)
def copy_power_levels_contents(
old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
):
) -> Dict[str, Union[int, Dict[str, int]]]:
"""Copy the content of a power_levels event, unfreezing frozendicts along the way
Raises:
@ -475,7 +495,7 @@ def copy_power_levels_contents(
if not isinstance(old_power_levels, collections.abc.Mapping):
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
power_levels = {}
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
for k, v in old_power_levels.items():
if isinstance(v, int):
@ -483,7 +503,8 @@ def copy_power_levels_contents(
continue
if isinstance(v, collections.abc.Mapping):
power_levels[k] = h = {}
h: Dict[str, int] = {}
power_levels[k] = h
for k1, v1 in v.items():
# we should only have one level of nesting
if not isinstance(v1, int):
@ -498,7 +519,7 @@ def copy_power_levels_contents(
return power_levels
def validate_canonicaljson(value: Any):
def validate_canonicaljson(value: Any) -> None:
"""
Ensure that the JSON object is valid according to the rules of canonical JSON.

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
from typing import Union
from typing import Iterable, Union
import jsonschema
@ -28,11 +28,11 @@ from synapse.events.utils import (
validate_canonicaljson,
)
from synapse.federation.federation_server import server_matches_acl_event
from synapse.types import EventID, RoomID, UserID
from synapse.types import EventID, JsonDict, RoomID, UserID
class EventValidator:
def validate_new(self, event: EventBase, config: HomeServerConfig):
def validate_new(self, event: EventBase, config: HomeServerConfig) -> None:
"""Validates the event has roughly the right format
Args:
@ -116,7 +116,7 @@ class EventValidator:
errcode=Codes.BAD_JSON,
)
def _validate_retention(self, event: EventBase):
def _validate_retention(self, event: EventBase) -> None:
"""Checks that an event that defines the retention policy for a room respects the
format enforced by the spec.
@ -156,7 +156,7 @@ class EventValidator:
errcode=Codes.BAD_JSON,
)
def validate_builder(self, event: Union[EventBase, EventBuilder]):
def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None:
"""Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the
fields an event would have
@ -204,14 +204,14 @@ class EventValidator:
self._ensure_state_event(event)
def _ensure_strings(self, d, keys):
def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
if not isinstance(d[s], str):
raise SynapseError(400, "'%s' not a string type" % (s,))
def _ensure_state_event(self, event):
def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None:
if not event.is_state():
raise SynapseError(400, "'%s' must be state events" % (event.type,))
@ -244,7 +244,9 @@ POWER_LEVELS_SCHEMA = {
}
def _create_power_level_validator():
# This could return something newer than Draft 7, but that's the current "latest"
# validator.
def _create_power_level_validator() -> jsonschema.Draft7Validator:
validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
# by default jsonschema does not consider a frozendict to be an object so

View File

@ -465,17 +465,35 @@ class RoomCreationHandler:
# the room has been created
# Calculate the minimum power level needed to clone the room
event_power_levels = power_levels.get("events", {})
if not isinstance(event_power_levels, dict):
event_power_levels = {}
state_default = power_levels.get("state_default", 50)
try:
state_default_int = int(state_default) # type: ignore[arg-type]
except (TypeError, ValueError):
state_default_int = 50
ban = power_levels.get("ban", 50)
needed_power_level = max(state_default, ban, max(event_power_levels.values()))
try:
ban = int(ban) # type: ignore[arg-type]
except (TypeError, ValueError):
ban = 50
needed_power_level = max(
state_default_int, ban, max(event_power_levels.values())
)
# Get the user's current power level, this matches the logic in get_user_power_level,
# but without the entire state map.
user_power_levels = power_levels.setdefault("users", {})
if not isinstance(user_power_levels, dict):
user_power_levels = {}
users_default = power_levels.get("users_default", 0)
current_power_level = user_power_levels.get(user_id, users_default)
try:
current_power_level_int = int(current_power_level) # type: ignore[arg-type]
except (TypeError, ValueError):
current_power_level_int = 0
# Raise the requester's power level in the new room if necessary
if current_power_level < needed_power_level:
if current_power_level_int < needed_power_level:
user_power_levels[user_id] = needed_power_level
await self._send_events_for_new_room(

View File

@ -232,12 +232,12 @@ class RelationPaginationServlet(RestServlet):
# Similarly, we don't allow relations to be applied to relations, so we
# return the original relations without any aggregations on top of them
# here.
events = await self._event_serializer.serialize_events(
serialized_events = await self._event_serializer.serialize_events(
events, now, bundle_aggregations=False
)
return_value = pagination_chunk.to_dict()
return_value["chunk"] = events
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
return 200, return_value
@ -416,10 +416,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
)
now = self.clock.time_msec()
events = await self._event_serializer.serialize_events(events, now)
serialized_events = await self._event_serializer.serialize_events(events, now)
return_value = result.to_dict()
return_value["chunk"] = events
return_value["chunk"] = serialized_events
return 200, return_value