mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2024-12-25 14:39:22 -05:00
Add type hints to synapse.events.*. (#11066)
Except `synapse/events/__init__.py`, which will be done in a follow-up.
This commit is contained in:
parent
cdd308845b
commit
1f9d0b8a7a
1
changelog.d/11066.misc
Normal file
1
changelog.d/11066.misc
Normal file
@ -0,0 +1 @@
|
||||
Add type hints to `synapse.events`.
|
6
mypy.ini
6
mypy.ini
@ -22,8 +22,11 @@ files =
|
||||
synapse/crypto,
|
||||
synapse/event_auth.py,
|
||||
synapse/events/builder.py,
|
||||
synapse/events/presence_router.py,
|
||||
synapse/events/snapshot.py,
|
||||
synapse/events/spamcheck.py,
|
||||
synapse/events/third_party_rules.py,
|
||||
synapse/events/utils.py,
|
||||
synapse/events/validator.py,
|
||||
synapse/federation,
|
||||
synapse/groups,
|
||||
@ -96,6 +99,9 @@ files =
|
||||
tests/util/test_itertools.py,
|
||||
tests/util/test_stream_change_cache.py
|
||||
|
||||
[mypy-synapse.events.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-synapse.handlers.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
|
@ -90,13 +90,13 @@ class EventBuilder:
|
||||
)
|
||||
|
||||
@property
|
||||
def state_key(self):
|
||||
def state_key(self) -> str:
|
||||
if self._state_key is not None:
|
||||
return self._state_key
|
||||
|
||||
raise AttributeError("state_key")
|
||||
|
||||
def is_state(self):
|
||||
def is_state(self) -> bool:
|
||||
return self._state_key is not None
|
||||
|
||||
async def build(
|
||||
|
@ -14,6 +14,7 @@
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
@ -33,14 +34,13 @@ if TYPE_CHECKING:
|
||||
GET_USERS_FOR_STATES_CALLBACK = Callable[
|
||||
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
|
||||
]
|
||||
GET_INTERESTED_USERS_CALLBACK = Callable[
|
||||
[str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
|
||||
]
|
||||
# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
|
||||
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_legacy_presence_router(hs: "HomeServer"):
|
||||
def load_legacy_presence_router(hs: "HomeServer") -> None:
|
||||
"""Wrapper that loads a presence router module configured using the old
|
||||
configuration, and registers the hooks they implement.
|
||||
"""
|
||||
@ -69,9 +69,10 @@ def load_legacy_presence_router(hs: "HomeServer"):
|
||||
if f is None:
|
||||
return None
|
||||
|
||||
def run(*args, **kwargs):
|
||||
# mypy doesn't do well across function boundaries so we need to tell it
|
||||
# f is definitely not None.
|
||||
def run(*args: Any, **kwargs: Any) -> Awaitable:
|
||||
# Assertion required because mypy can't prove we won't change `f`
|
||||
# back to `None`. See
|
||||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||
assert f is not None
|
||||
|
||||
return maybe_awaitable(f(*args, **kwargs))
|
||||
@ -104,7 +105,7 @@ class PresenceRouter:
|
||||
self,
|
||||
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
|
||||
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
|
||||
):
|
||||
) -> None:
|
||||
# PresenceRouter modules are required to implement both of these methods
|
||||
# or neither of them as they are assumed to act in a complementary manner
|
||||
paired_methods = [get_users_for_states, get_interested_users]
|
||||
@ -142,7 +143,7 @@ class PresenceRouter:
|
||||
# Don't include any extra destinations for presence updates
|
||||
return {}
|
||||
|
||||
users_for_states = {}
|
||||
users_for_states: Dict[str, Set[UserPresenceState]] = {}
|
||||
# run all the callbacks for get_users_for_states and combine the results
|
||||
for callback in self._get_users_for_states_callbacks:
|
||||
try:
|
||||
@ -171,7 +172,7 @@ class PresenceRouter:
|
||||
|
||||
return users_for_states
|
||||
|
||||
async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
|
||||
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
|
||||
"""
|
||||
Retrieve a list of users that `user_id` is interested in receiving the
|
||||
presence of. This will be in addition to those they share a room with.
|
||||
|
@ -11,17 +11,20 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.types import StateMap
|
||||
from synapse.types import JsonDict, StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage import Storage
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
|
||||
@ -112,13 +115,13 @@ class EventContext:
|
||||
|
||||
@staticmethod
|
||||
def with_state(
|
||||
state_group,
|
||||
state_group_before_event,
|
||||
current_state_ids,
|
||||
prev_state_ids,
|
||||
prev_group=None,
|
||||
delta_ids=None,
|
||||
):
|
||||
state_group: Optional[int],
|
||||
state_group_before_event: Optional[int],
|
||||
current_state_ids: Optional[StateMap[str]],
|
||||
prev_state_ids: Optional[StateMap[str]],
|
||||
prev_group: Optional[int] = None,
|
||||
delta_ids: Optional[StateMap[str]] = None,
|
||||
) -> "EventContext":
|
||||
return EventContext(
|
||||
current_state_ids=current_state_ids,
|
||||
prev_state_ids=prev_state_ids,
|
||||
@ -129,22 +132,22 @@ class EventContext:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def for_outlier():
|
||||
def for_outlier() -> "EventContext":
|
||||
"""Return an EventContext instance suitable for persisting an outlier event"""
|
||||
return EventContext(
|
||||
current_state_ids={},
|
||||
prev_state_ids={},
|
||||
)
|
||||
|
||||
async def serialize(self, event: EventBase, store: "DataStore") -> dict:
|
||||
async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
|
||||
"""Converts self to a type that can be serialized as JSON, and then
|
||||
deserialized by `deserialize`
|
||||
|
||||
Args:
|
||||
event (FrozenEvent): The event that this context relates to
|
||||
event: The event that this context relates to
|
||||
|
||||
Returns:
|
||||
dict
|
||||
The serialized event.
|
||||
"""
|
||||
|
||||
# We don't serialize the full state dicts, instead they get pulled out
|
||||
@ -170,17 +173,16 @@ class EventContext:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def deserialize(storage, input):
|
||||
def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
|
||||
"""Converts a dict that was produced by `serialize` back into a
|
||||
EventContext.
|
||||
|
||||
Args:
|
||||
storage (Storage): Used to convert AS ID to AS object and fetch
|
||||
state.
|
||||
input (dict): A dict produced by `serialize`
|
||||
storage: Used to convert AS ID to AS object and fetch state.
|
||||
input: A dict produced by `serialize`
|
||||
|
||||
Returns:
|
||||
EventContext
|
||||
The event context.
|
||||
"""
|
||||
context = _AsyncEventContextImpl(
|
||||
# We use the state_group and prev_state_id stuff to pull the
|
||||
@ -241,22 +243,25 @@ class EventContext:
|
||||
await self._ensure_fetched()
|
||||
return self._current_state_ids
|
||||
|
||||
async def get_prev_state_ids(self):
|
||||
async def get_prev_state_ids(self) -> StateMap[str]:
|
||||
"""
|
||||
Gets the room state map, excluding this event.
|
||||
|
||||
For a non-state event, this will be the same as get_current_state_ids().
|
||||
|
||||
Returns:
|
||||
dict[(str, str), str]|None: Returns None if state_group
|
||||
is None, which happens when the associated event is an outlier.
|
||||
Returns {} if state_group is None, which happens when the associated
|
||||
event is an outlier.
|
||||
|
||||
Maps a (type, state_key) to the event ID of the state event matching
|
||||
this tuple.
|
||||
"""
|
||||
await self._ensure_fetched()
|
||||
# There *should* be previous state IDs now.
|
||||
assert self._prev_state_ids is not None
|
||||
return self._prev_state_ids
|
||||
|
||||
def get_cached_current_state_ids(self):
|
||||
def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
|
||||
"""Gets the current state IDs if we have them already cached.
|
||||
|
||||
It is an error to access this for a rejected event, since rejected state should
|
||||
@ -264,16 +269,17 @@ class EventContext:
|
||||
``rejected`` is set.
|
||||
|
||||
Returns:
|
||||
dict[(str, str), str]|None: Returns None if we haven't cached the
|
||||
state or if state_group is None, which happens when the associated
|
||||
event is an outlier.
|
||||
Returns None if we haven't cached the state or if state_group is None
|
||||
(which happens when the associated event is an outlier).
|
||||
|
||||
Otherwise, returns the the current state IDs.
|
||||
"""
|
||||
if self.rejected:
|
||||
raise RuntimeError("Attempt to access state_ids of rejected event")
|
||||
|
||||
return self._current_state_ids
|
||||
|
||||
async def _ensure_fetched(self):
|
||||
async def _ensure_fetched(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@ -285,46 +291,46 @@ class _AsyncEventContextImpl(EventContext):
|
||||
|
||||
Attributes:
|
||||
|
||||
_storage (Storage)
|
||||
_storage
|
||||
|
||||
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
||||
been calculated. None if we haven't started calculating yet
|
||||
_fetching_state_deferred: Resolves when *_state_ids have been calculated.
|
||||
None if we haven't started calculating yet
|
||||
|
||||
_event_type (str): The type of the event the context is associated with.
|
||||
_event_type: The type of the event the context is associated with.
|
||||
|
||||
_event_state_key (str): The state_key of the event the context is
|
||||
associated with.
|
||||
_event_state_key: The state_key of the event the context is associated with.
|
||||
|
||||
_prev_state_id (str|None): If the event associated with the context is
|
||||
a state event, then `_prev_state_id` is the event_id of the state
|
||||
that was replaced.
|
||||
_prev_state_id: If the event associated with the context is a state event,
|
||||
then `_prev_state_id` is the event_id of the state that was replaced.
|
||||
"""
|
||||
|
||||
# This needs to have a default as we're inheriting
|
||||
_storage = attr.ib(default=None)
|
||||
_prev_state_id = attr.ib(default=None)
|
||||
_event_type = attr.ib(default=None)
|
||||
_event_state_key = attr.ib(default=None)
|
||||
_fetching_state_deferred = attr.ib(default=None)
|
||||
_storage: "Storage" = attr.ib(default=None)
|
||||
_prev_state_id: Optional[str] = attr.ib(default=None)
|
||||
_event_type: str = attr.ib(default=None)
|
||||
_event_state_key: Optional[str] = attr.ib(default=None)
|
||||
_fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
|
||||
|
||||
async def _ensure_fetched(self):
|
||||
async def _ensure_fetched(self) -> None:
|
||||
if not self._fetching_state_deferred:
|
||||
self._fetching_state_deferred = run_in_background(self._fill_out_state)
|
||||
|
||||
return await make_deferred_yieldable(self._fetching_state_deferred)
|
||||
await make_deferred_yieldable(self._fetching_state_deferred)
|
||||
|
||||
async def _fill_out_state(self):
|
||||
async def _fill_out_state(self) -> None:
|
||||
"""Called to populate the _current_state_ids and _prev_state_ids
|
||||
attributes by loading from the database.
|
||||
"""
|
||||
if self.state_group is None:
|
||||
return
|
||||
|
||||
self._current_state_ids = await self._storage.state.get_state_ids_for_group(
|
||||
current_state_ids = await self._storage.state.get_state_ids_for_group(
|
||||
self.state_group
|
||||
)
|
||||
# Set this separately so mypy knows current_state_ids is not None.
|
||||
self._current_state_ids = current_state_ids
|
||||
if self._event_state_key is not None:
|
||||
self._prev_state_ids = dict(self._current_state_ids)
|
||||
self._prev_state_ids = dict(current_state_ids)
|
||||
|
||||
key = (self._event_type, self._event_state_key)
|
||||
if self._prev_state_id:
|
||||
@ -332,10 +338,12 @@ class _AsyncEventContextImpl(EventContext):
|
||||
else:
|
||||
self._prev_state_ids.pop(key, None)
|
||||
else:
|
||||
self._prev_state_ids = self._current_state_ids
|
||||
self._prev_state_ids = current_state_ids
|
||||
|
||||
|
||||
def _encode_state_dict(state_dict):
|
||||
def _encode_state_dict(
|
||||
state_dict: Optional[StateMap[str]],
|
||||
) -> Optional[List[Tuple[str, str, str]]]:
|
||||
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
|
||||
JSON we need to convert them to a form that can.
|
||||
"""
|
||||
@ -345,7 +353,9 @@ def _encode_state_dict(state_dict):
|
||||
return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]
|
||||
|
||||
|
||||
def _decode_state_dict(input):
|
||||
def _decode_state_dict(
|
||||
input: Optional[List[Tuple[str, str, str]]]
|
||||
) -> Optional[StateMap[str]]:
|
||||
"""Decodes a state dict encoded using `_encode_state_dict` above"""
|
||||
if input is None:
|
||||
return None
|
||||
|
@ -77,7 +77,7 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
|
||||
]
|
||||
|
||||
|
||||
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
|
||||
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
|
||||
"""Wrapper that loads spam checkers configured using the old configuration, and
|
||||
registers the spam checker hooks they implement.
|
||||
"""
|
||||
@ -129,9 +129,9 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
|
||||
request_info: Collection[Tuple[str, str]],
|
||||
auth_provider_id: Optional[str],
|
||||
) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
|
||||
# We've already made sure f is not None above, but mypy doesn't
|
||||
# do well across function boundaries so we need to tell it f is
|
||||
# definitely not None.
|
||||
# Assertion required because mypy can't prove we won't
|
||||
# change `f` back to `None`. See
|
||||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||
assert f is not None
|
||||
|
||||
return f(
|
||||
@ -146,9 +146,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
|
||||
"Bad signature for callback check_registration_for_spam",
|
||||
)
|
||||
|
||||
def run(*args, **kwargs):
|
||||
# mypy doesn't do well across function boundaries so we need to tell it
|
||||
# wrapped_func is definitely not None.
|
||||
def run(*args: Any, **kwargs: Any) -> Awaitable:
|
||||
# Assertion required because mypy can't prove we won't change `f`
|
||||
# back to `None`. See
|
||||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||
assert wrapped_func is not None
|
||||
|
||||
return maybe_awaitable(wrapped_func(*args, **kwargs))
|
||||
@ -165,7 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
|
||||
|
||||
|
||||
class SpamChecker:
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
|
||||
self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
|
||||
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
|
||||
@ -209,7 +210,7 @@ class SpamChecker:
|
||||
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
|
||||
] = None,
|
||||
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Register callbacks from module for each hook."""
|
||||
if check_event_for_spam is not None:
|
||||
self._check_event_for_spam_callbacks.append(check_event_for_spam)
|
||||
@ -275,7 +276,9 @@ class SpamChecker:
|
||||
|
||||
return False
|
||||
|
||||
async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool):
|
||||
async def user_may_join_room(
|
||||
self, user_id: str, room_id: str, is_invited: bool
|
||||
) -> bool:
|
||||
"""Checks if a given users is allowed to join a room.
|
||||
Not called when a user creates a room.
|
||||
|
||||
@ -285,7 +288,7 @@ class SpamChecker:
|
||||
is_invited: Whether the user is invited into the room
|
||||
|
||||
Returns:
|
||||
bool: Whether the user may join the room
|
||||
Whether the user may join the room
|
||||
"""
|
||||
for callback in self._user_may_join_room_callbacks:
|
||||
if await callback(user_id, room_id, is_invited) is False:
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events import EventBase
|
||||
@ -38,7 +38,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
|
||||
]
|
||||
|
||||
|
||||
def load_legacy_third_party_event_rules(hs: "HomeServer"):
|
||||
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
|
||||
"""Wrapper that loads a third party event rules module configured using the old
|
||||
configuration, and registers the hooks they implement.
|
||||
"""
|
||||
@ -77,9 +77,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
|
||||
event: EventBase,
|
||||
state_events: StateMap[EventBase],
|
||||
) -> Tuple[bool, Optional[dict]]:
|
||||
# We've already made sure f is not None above, but mypy doesn't do well
|
||||
# across function boundaries so we need to tell it f is definitely not
|
||||
# None.
|
||||
# Assertion required because mypy can't prove we won't change
|
||||
# `f` back to `None`. See
|
||||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||
assert f is not None
|
||||
|
||||
res = await f(event, state_events)
|
||||
@ -98,9 +98,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
|
||||
async def wrap_on_create_room(
|
||||
requester: Requester, config: dict, is_requester_admin: bool
|
||||
) -> None:
|
||||
# We've already made sure f is not None above, but mypy doesn't do well
|
||||
# across function boundaries so we need to tell it f is definitely not
|
||||
# None.
|
||||
# Assertion required because mypy can't prove we won't change
|
||||
# `f` back to `None`. See
|
||||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||
assert f is not None
|
||||
|
||||
res = await f(requester, config, is_requester_admin)
|
||||
@ -112,9 +112,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
|
||||
|
||||
return wrap_on_create_room
|
||||
|
||||
def run(*args, **kwargs):
|
||||
# mypy doesn't do well across function boundaries so we need to tell it
|
||||
# f is definitely not None.
|
||||
def run(*args: Any, **kwargs: Any) -> Awaitable:
|
||||
# Assertion required because mypy can't prove we won't change `f`
|
||||
# back to `None`. See
|
||||
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||
assert f is not None
|
||||
|
||||
return maybe_awaitable(f(*args, **kwargs))
|
||||
@ -162,7 +163,7 @@ class ThirdPartyEventRules:
|
||||
check_visibility_can_be_modified: Optional[
|
||||
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
|
||||
] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Register callbacks from modules for each hook."""
|
||||
if check_event_allowed is not None:
|
||||
self._check_event_allowed_callbacks.append(check_event_allowed)
|
||||
|
@ -13,18 +13,32 @@
|
||||
# limitations under the License.
|
||||
import collections.abc
|
||||
import re
|
||||
from typing import Any, Mapping, Union
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.async_helpers import yieldable_gather_results
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
|
||||
from . import EventBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
|
||||
# (?<!stuff) matches if the current position in the string is not preceded
|
||||
# by a match for 'stuff'.
|
||||
@ -65,7 +79,7 @@ def prune_event(event: EventBase) -> EventBase:
|
||||
return pruned_event
|
||||
|
||||
|
||||
def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
|
||||
def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
|
||||
"""Redacts the event_dict in the same way as `prune_event`, except it
|
||||
operates on dicts rather than event objects
|
||||
|
||||
@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
|
||||
|
||||
new_content = {}
|
||||
|
||||
def add_fields(*fields):
|
||||
def add_fields(*fields: str) -> None:
|
||||
for field in fields:
|
||||
if field in event_dict["content"]:
|
||||
new_content[field] = event_dict["content"][field]
|
||||
@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
|
||||
|
||||
allowed_fields["content"] = new_content
|
||||
|
||||
unsigned = {}
|
||||
unsigned: JsonDict = {}
|
||||
allowed_fields["unsigned"] = unsigned
|
||||
|
||||
event_unsigned = event_dict.get("unsigned", {})
|
||||
@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
|
||||
return allowed_fields
|
||||
|
||||
|
||||
def _copy_field(src, dst, field):
|
||||
def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
|
||||
"""Copy the field in 'src' to 'dst'.
|
||||
|
||||
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
|
||||
then dst={"foo":{"bar":5}}.
|
||||
|
||||
Args:
|
||||
src(dict): The dict to read from.
|
||||
dst(dict): The dict to modify.
|
||||
field(list<str>): List of keys to drill down to in 'src'.
|
||||
src: The dict to read from.
|
||||
dst: The dict to modify.
|
||||
field: List of keys to drill down to in 'src'.
|
||||
"""
|
||||
if len(field) == 0: # this should be impossible
|
||||
return
|
||||
@ -205,7 +219,7 @@ def _copy_field(src, dst, field):
|
||||
sub_out_dict[key_to_move] = sub_dict[key_to_move]
|
||||
|
||||
|
||||
def only_fields(dictionary, fields):
|
||||
def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
|
||||
"""Return a new dict with only the fields in 'dictionary' which are present
|
||||
in 'fields'.
|
||||
|
||||
@ -215,11 +229,11 @@ def only_fields(dictionary, fields):
|
||||
A literal '.' character in a field name may be escaped using a '\'.
|
||||
|
||||
Args:
|
||||
dictionary(dict): The dictionary to read from.
|
||||
fields(list<str>): A list of fields to copy over. Only shallow refs are
|
||||
dictionary: The dictionary to read from.
|
||||
fields: A list of fields to copy over. Only shallow refs are
|
||||
taken.
|
||||
Returns:
|
||||
dict: A new dictionary with only the given fields. If fields was empty,
|
||||
A new dictionary with only the given fields. If fields was empty,
|
||||
the same dictionary is returned.
|
||||
"""
|
||||
if len(fields) == 0:
|
||||
@ -235,17 +249,17 @@ def only_fields(dictionary, fields):
|
||||
[f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
|
||||
]
|
||||
|
||||
output = {}
|
||||
output: JsonDict = {}
|
||||
for field_array in split_fields:
|
||||
_copy_field(dictionary, output, field_array)
|
||||
return output
|
||||
|
||||
|
||||
def format_event_raw(d):
|
||||
def format_event_raw(d: JsonDict) -> JsonDict:
|
||||
return d
|
||||
|
||||
|
||||
def format_event_for_client_v1(d):
|
||||
def format_event_for_client_v1(d: JsonDict) -> JsonDict:
|
||||
d = format_event_for_client_v2(d)
|
||||
|
||||
sender = d.get("sender")
|
||||
@ -267,7 +281,7 @@ def format_event_for_client_v1(d):
|
||||
return d
|
||||
|
||||
|
||||
def format_event_for_client_v2(d):
|
||||
def format_event_for_client_v2(d: JsonDict) -> JsonDict:
|
||||
drop_keys = (
|
||||
"auth_events",
|
||||
"prev_events",
|
||||
@ -282,37 +296,37 @@ def format_event_for_client_v2(d):
|
||||
return d
|
||||
|
||||
|
||||
def format_event_for_client_v2_without_room_id(d):
|
||||
def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
|
||||
d = format_event_for_client_v2(d)
|
||||
d.pop("room_id", None)
|
||||
return d
|
||||
|
||||
|
||||
def serialize_event(
|
||||
e,
|
||||
time_now_ms,
|
||||
as_client_event=True,
|
||||
event_format=format_event_for_client_v1,
|
||||
token_id=None,
|
||||
only_event_fields=None,
|
||||
include_stripped_room_state=False,
|
||||
):
|
||||
e: Union[JsonDict, EventBase],
|
||||
time_now_ms: int,
|
||||
as_client_event: bool = True,
|
||||
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
|
||||
token_id: Optional[str] = None,
|
||||
only_event_fields: Optional[List[str]] = None,
|
||||
include_stripped_room_state: bool = False,
|
||||
) -> JsonDict:
|
||||
"""Serialize event for clients
|
||||
|
||||
Args:
|
||||
e (EventBase)
|
||||
time_now_ms (int)
|
||||
as_client_event (bool)
|
||||
e
|
||||
time_now_ms
|
||||
as_client_event
|
||||
event_format
|
||||
token_id
|
||||
only_event_fields
|
||||
include_stripped_room_state (bool): Some events can have stripped room state
|
||||
include_stripped_room_state: Some events can have stripped room state
|
||||
stored in the `unsigned` field. This is required for invite and knock
|
||||
functionality. If this option is False, that state will be removed from the
|
||||
event before it is returned. Otherwise, it will be kept.
|
||||
|
||||
Returns:
|
||||
dict
|
||||
The serialized event dictionary.
|
||||
"""
|
||||
|
||||
# FIXME(erikj): To handle the case of presence events and the like
|
||||
@ -369,25 +383,29 @@ class EventClientSerializer:
|
||||
clients.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
self.experimental_msc1849_support_enabled = (
|
||||
hs.config.server.experimental_msc1849_support_enabled
|
||||
)
|
||||
|
||||
async def serialize_event(
|
||||
self, event, time_now, bundle_aggregations=True, **kwargs
|
||||
):
|
||||
self,
|
||||
event: Union[JsonDict, EventBase],
|
||||
time_now: int,
|
||||
bundle_aggregations: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> JsonDict:
|
||||
"""Serializes a single event.
|
||||
|
||||
Args:
|
||||
event (EventBase)
|
||||
time_now (int): The current time in milliseconds
|
||||
bundle_aggregations (bool): Whether to bundle in related events
|
||||
event
|
||||
time_now: The current time in milliseconds
|
||||
bundle_aggregations: Whether to bundle in related events
|
||||
**kwargs: Arguments to pass to `serialize_event`
|
||||
|
||||
Returns:
|
||||
dict: The serialized event
|
||||
The serialized event
|
||||
"""
|
||||
# To handle the case of presence events and the like
|
||||
if not isinstance(event, EventBase):
|
||||
@ -448,25 +466,27 @@ class EventClientSerializer:
|
||||
|
||||
return serialized_event
|
||||
|
||||
def serialize_events(self, events, time_now, **kwargs):
|
||||
async def serialize_events(
|
||||
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
|
||||
) -> List[JsonDict]:
|
||||
"""Serializes multiple events.
|
||||
|
||||
Args:
|
||||
event (iter[EventBase])
|
||||
time_now (int): The current time in milliseconds
|
||||
event
|
||||
time_now: The current time in milliseconds
|
||||
**kwargs: Arguments to pass to `serialize_event`
|
||||
|
||||
Returns:
|
||||
Deferred[list[dict]]: The list of serialized events
|
||||
The list of serialized events
|
||||
"""
|
||||
return yieldable_gather_results(
|
||||
return await yieldable_gather_results(
|
||||
self.serialize_event, events, time_now=time_now, **kwargs
|
||||
)
|
||||
|
||||
|
||||
def copy_power_levels_contents(
|
||||
old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
|
||||
):
|
||||
) -> Dict[str, Union[int, Dict[str, int]]]:
|
||||
"""Copy the content of a power_levels event, unfreezing frozendicts along the way
|
||||
|
||||
Raises:
|
||||
@ -475,7 +495,7 @@ def copy_power_levels_contents(
|
||||
if not isinstance(old_power_levels, collections.abc.Mapping):
|
||||
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
|
||||
|
||||
power_levels = {}
|
||||
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
|
||||
for k, v in old_power_levels.items():
|
||||
|
||||
if isinstance(v, int):
|
||||
@ -483,7 +503,8 @@ def copy_power_levels_contents(
|
||||
continue
|
||||
|
||||
if isinstance(v, collections.abc.Mapping):
|
||||
power_levels[k] = h = {}
|
||||
h: Dict[str, int] = {}
|
||||
power_levels[k] = h
|
||||
for k1, v1 in v.items():
|
||||
# we should only have one level of nesting
|
||||
if not isinstance(v1, int):
|
||||
@ -498,7 +519,7 @@ def copy_power_levels_contents(
|
||||
return power_levels
|
||||
|
||||
|
||||
def validate_canonicaljson(value: Any):
|
||||
def validate_canonicaljson(value: Any) -> None:
|
||||
"""
|
||||
Ensure that the JSON object is valid according to the rules of canonical JSON.
|
||||
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import collections.abc
|
||||
from typing import Union
|
||||
from typing import Iterable, Union
|
||||
|
||||
import jsonschema
|
||||
|
||||
@ -28,11 +28,11 @@ from synapse.events.utils import (
|
||||
validate_canonicaljson,
|
||||
)
|
||||
from synapse.federation.federation_server import server_matches_acl_event
|
||||
from synapse.types import EventID, RoomID, UserID
|
||||
from synapse.types import EventID, JsonDict, RoomID, UserID
|
||||
|
||||
|
||||
class EventValidator:
|
||||
def validate_new(self, event: EventBase, config: HomeServerConfig):
|
||||
def validate_new(self, event: EventBase, config: HomeServerConfig) -> None:
|
||||
"""Validates the event has roughly the right format
|
||||
|
||||
Args:
|
||||
@ -116,7 +116,7 @@ class EventValidator:
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
def _validate_retention(self, event: EventBase):
|
||||
def _validate_retention(self, event: EventBase) -> None:
|
||||
"""Checks that an event that defines the retention policy for a room respects the
|
||||
format enforced by the spec.
|
||||
|
||||
@ -156,7 +156,7 @@ class EventValidator:
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
def validate_builder(self, event: Union[EventBase, EventBuilder]):
|
||||
def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None:
|
||||
"""Validates that the builder/event has roughly the right format. Only
|
||||
checks values that we expect a proto event to have, rather than all the
|
||||
fields an event would have
|
||||
@ -204,14 +204,14 @@ class EventValidator:
|
||||
|
||||
self._ensure_state_event(event)
|
||||
|
||||
def _ensure_strings(self, d, keys):
|
||||
def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
|
||||
for s in keys:
|
||||
if s not in d:
|
||||
raise SynapseError(400, "'%s' not in content" % (s,))
|
||||
if not isinstance(d[s], str):
|
||||
raise SynapseError(400, "'%s' not a string type" % (s,))
|
||||
|
||||
def _ensure_state_event(self, event):
|
||||
def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None:
|
||||
if not event.is_state():
|
||||
raise SynapseError(400, "'%s' must be state events" % (event.type,))
|
||||
|
||||
@ -244,7 +244,9 @@ POWER_LEVELS_SCHEMA = {
|
||||
}
|
||||
|
||||
|
||||
def _create_power_level_validator():
|
||||
# This could return something newer than Draft 7, but that's the current "latest"
|
||||
# validator.
|
||||
def _create_power_level_validator() -> jsonschema.Draft7Validator:
|
||||
validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
|
||||
|
||||
# by default jsonschema does not consider a frozendict to be an object so
|
||||
|
@ -465,17 +465,35 @@ class RoomCreationHandler:
|
||||
# the room has been created
|
||||
# Calculate the minimum power level needed to clone the room
|
||||
event_power_levels = power_levels.get("events", {})
|
||||
if not isinstance(event_power_levels, dict):
|
||||
event_power_levels = {}
|
||||
state_default = power_levels.get("state_default", 50)
|
||||
try:
|
||||
state_default_int = int(state_default) # type: ignore[arg-type]
|
||||
except (TypeError, ValueError):
|
||||
state_default_int = 50
|
||||
ban = power_levels.get("ban", 50)
|
||||
needed_power_level = max(state_default, ban, max(event_power_levels.values()))
|
||||
try:
|
||||
ban = int(ban) # type: ignore[arg-type]
|
||||
except (TypeError, ValueError):
|
||||
ban = 50
|
||||
needed_power_level = max(
|
||||
state_default_int, ban, max(event_power_levels.values())
|
||||
)
|
||||
|
||||
# Get the user's current power level, this matches the logic in get_user_power_level,
|
||||
# but without the entire state map.
|
||||
user_power_levels = power_levels.setdefault("users", {})
|
||||
if not isinstance(user_power_levels, dict):
|
||||
user_power_levels = {}
|
||||
users_default = power_levels.get("users_default", 0)
|
||||
current_power_level = user_power_levels.get(user_id, users_default)
|
||||
try:
|
||||
current_power_level_int = int(current_power_level) # type: ignore[arg-type]
|
||||
except (TypeError, ValueError):
|
||||
current_power_level_int = 0
|
||||
# Raise the requester's power level in the new room if necessary
|
||||
if current_power_level < needed_power_level:
|
||||
if current_power_level_int < needed_power_level:
|
||||
user_power_levels[user_id] = needed_power_level
|
||||
|
||||
await self._send_events_for_new_room(
|
||||
|
@ -232,12 +232,12 @@ class RelationPaginationServlet(RestServlet):
|
||||
# Similarly, we don't allow relations to be applied to relations, so we
|
||||
# return the original relations without any aggregations on top of them
|
||||
# here.
|
||||
events = await self._event_serializer.serialize_events(
|
||||
serialized_events = await self._event_serializer.serialize_events(
|
||||
events, now, bundle_aggregations=False
|
||||
)
|
||||
|
||||
return_value = pagination_chunk.to_dict()
|
||||
return_value["chunk"] = events
|
||||
return_value["chunk"] = serialized_events
|
||||
return_value["original_event"] = original_event
|
||||
|
||||
return 200, return_value
|
||||
@ -416,10 +416,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
||||
)
|
||||
|
||||
now = self.clock.time_msec()
|
||||
events = await self._event_serializer.serialize_events(events, now)
|
||||
serialized_events = await self._event_serializer.serialize_events(events, now)
|
||||
|
||||
return_value = result.to_dict()
|
||||
return_value["chunk"] = events
|
||||
return_value["chunk"] = serialized_events
|
||||
|
||||
return 200, return_value
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user