mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2025-01-03 16:40:47 -05:00
Add type hints to synapse.events.*. (#11066)
Except `synapse/events/__init__.py`, which will be done in a follow-up.
This commit is contained in:
parent
cdd308845b
commit
1f9d0b8a7a
1
changelog.d/11066.misc
Normal file
1
changelog.d/11066.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add type hints to `synapse.events`.
|
6
mypy.ini
6
mypy.ini
@ -22,8 +22,11 @@ files =
|
|||||||
synapse/crypto,
|
synapse/crypto,
|
||||||
synapse/event_auth.py,
|
synapse/event_auth.py,
|
||||||
synapse/events/builder.py,
|
synapse/events/builder.py,
|
||||||
|
synapse/events/presence_router.py,
|
||||||
|
synapse/events/snapshot.py,
|
||||||
synapse/events/spamcheck.py,
|
synapse/events/spamcheck.py,
|
||||||
synapse/events/third_party_rules.py,
|
synapse/events/third_party_rules.py,
|
||||||
|
synapse/events/utils.py,
|
||||||
synapse/events/validator.py,
|
synapse/events/validator.py,
|
||||||
synapse/federation,
|
synapse/federation,
|
||||||
synapse/groups,
|
synapse/groups,
|
||||||
@ -96,6 +99,9 @@ files =
|
|||||||
tests/util/test_itertools.py,
|
tests/util/test_itertools.py,
|
||||||
tests/util/test_stream_change_cache.py
|
tests/util/test_stream_change_cache.py
|
||||||
|
|
||||||
|
[mypy-synapse.events.*]
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
[mypy-synapse.handlers.*]
|
[mypy-synapse.handlers.*]
|
||||||
disallow_untyped_defs = True
|
disallow_untyped_defs = True
|
||||||
|
|
||||||
|
@ -90,13 +90,13 @@ class EventBuilder:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state_key(self):
|
def state_key(self) -> str:
|
||||||
if self._state_key is not None:
|
if self._state_key is not None:
|
||||||
return self._state_key
|
return self._state_key
|
||||||
|
|
||||||
raise AttributeError("state_key")
|
raise AttributeError("state_key")
|
||||||
|
|
||||||
def is_state(self):
|
def is_state(self) -> bool:
|
||||||
return self._state_key is not None
|
return self._state_key is not None
|
||||||
|
|
||||||
async def build(
|
async def build(
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
@ -33,14 +34,13 @@ if TYPE_CHECKING:
|
|||||||
GET_USERS_FOR_STATES_CALLBACK = Callable[
|
GET_USERS_FOR_STATES_CALLBACK = Callable[
|
||||||
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
|
[Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
|
||||||
]
|
]
|
||||||
GET_INTERESTED_USERS_CALLBACK = Callable[
|
# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
|
||||||
[str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
|
GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
|
||||||
]
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_legacy_presence_router(hs: "HomeServer"):
|
def load_legacy_presence_router(hs: "HomeServer") -> None:
|
||||||
"""Wrapper that loads a presence router module configured using the old
|
"""Wrapper that loads a presence router module configured using the old
|
||||||
configuration, and registers the hooks they implement.
|
configuration, and registers the hooks they implement.
|
||||||
"""
|
"""
|
||||||
@ -69,9 +69,10 @@ def load_legacy_presence_router(hs: "HomeServer"):
|
|||||||
if f is None:
|
if f is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def run(*args, **kwargs):
|
def run(*args: Any, **kwargs: Any) -> Awaitable:
|
||||||
# mypy doesn't do well across function boundaries so we need to tell it
|
# Assertion required because mypy can't prove we won't change `f`
|
||||||
# f is definitely not None.
|
# back to `None`. See
|
||||||
|
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||||
assert f is not None
|
assert f is not None
|
||||||
|
|
||||||
return maybe_awaitable(f(*args, **kwargs))
|
return maybe_awaitable(f(*args, **kwargs))
|
||||||
@ -104,7 +105,7 @@ class PresenceRouter:
|
|||||||
self,
|
self,
|
||||||
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
|
get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
|
||||||
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
|
get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
|
||||||
):
|
) -> None:
|
||||||
# PresenceRouter modules are required to implement both of these methods
|
# PresenceRouter modules are required to implement both of these methods
|
||||||
# or neither of them as they are assumed to act in a complementary manner
|
# or neither of them as they are assumed to act in a complementary manner
|
||||||
paired_methods = [get_users_for_states, get_interested_users]
|
paired_methods = [get_users_for_states, get_interested_users]
|
||||||
@ -142,7 +143,7 @@ class PresenceRouter:
|
|||||||
# Don't include any extra destinations for presence updates
|
# Don't include any extra destinations for presence updates
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
users_for_states = {}
|
users_for_states: Dict[str, Set[UserPresenceState]] = {}
|
||||||
# run all the callbacks for get_users_for_states and combine the results
|
# run all the callbacks for get_users_for_states and combine the results
|
||||||
for callback in self._get_users_for_states_callbacks:
|
for callback in self._get_users_for_states_callbacks:
|
||||||
try:
|
try:
|
||||||
@ -171,7 +172,7 @@ class PresenceRouter:
|
|||||||
|
|
||||||
return users_for_states
|
return users_for_states
|
||||||
|
|
||||||
async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
|
async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
|
||||||
"""
|
"""
|
||||||
Retrieve a list of users that `user_id` is interested in receiving the
|
Retrieve a list of users that `user_id` is interested in receiving the
|
||||||
presence of. This will be in addition to those they share a room with.
|
presence of. This will be in addition to those they share a room with.
|
||||||
|
@ -11,17 +11,20 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
|
from twisted.internet.defer import Deferred
|
||||||
|
|
||||||
from synapse.appservice import ApplicationService
|
from synapse.appservice import ApplicationService
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||||
from synapse.types import StateMap
|
from synapse.types import JsonDict, StateMap
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from synapse.storage import Storage
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
|
|
||||||
|
|
||||||
@ -112,13 +115,13 @@ class EventContext:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def with_state(
|
def with_state(
|
||||||
state_group,
|
state_group: Optional[int],
|
||||||
state_group_before_event,
|
state_group_before_event: Optional[int],
|
||||||
current_state_ids,
|
current_state_ids: Optional[StateMap[str]],
|
||||||
prev_state_ids,
|
prev_state_ids: Optional[StateMap[str]],
|
||||||
prev_group=None,
|
prev_group: Optional[int] = None,
|
||||||
delta_ids=None,
|
delta_ids: Optional[StateMap[str]] = None,
|
||||||
):
|
) -> "EventContext":
|
||||||
return EventContext(
|
return EventContext(
|
||||||
current_state_ids=current_state_ids,
|
current_state_ids=current_state_ids,
|
||||||
prev_state_ids=prev_state_ids,
|
prev_state_ids=prev_state_ids,
|
||||||
@ -129,22 +132,22 @@ class EventContext:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def for_outlier():
|
def for_outlier() -> "EventContext":
|
||||||
"""Return an EventContext instance suitable for persisting an outlier event"""
|
"""Return an EventContext instance suitable for persisting an outlier event"""
|
||||||
return EventContext(
|
return EventContext(
|
||||||
current_state_ids={},
|
current_state_ids={},
|
||||||
prev_state_ids={},
|
prev_state_ids={},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def serialize(self, event: EventBase, store: "DataStore") -> dict:
|
async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
|
||||||
"""Converts self to a type that can be serialized as JSON, and then
|
"""Converts self to a type that can be serialized as JSON, and then
|
||||||
deserialized by `deserialize`
|
deserialized by `deserialize`
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (FrozenEvent): The event that this context relates to
|
event: The event that this context relates to
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict
|
The serialized event.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# We don't serialize the full state dicts, instead they get pulled out
|
# We don't serialize the full state dicts, instead they get pulled out
|
||||||
@ -170,17 +173,16 @@ class EventContext:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def deserialize(storage, input):
|
def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
|
||||||
"""Converts a dict that was produced by `serialize` back into a
|
"""Converts a dict that was produced by `serialize` back into a
|
||||||
EventContext.
|
EventContext.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
storage (Storage): Used to convert AS ID to AS object and fetch
|
storage: Used to convert AS ID to AS object and fetch state.
|
||||||
state.
|
input: A dict produced by `serialize`
|
||||||
input (dict): A dict produced by `serialize`
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
EventContext
|
The event context.
|
||||||
"""
|
"""
|
||||||
context = _AsyncEventContextImpl(
|
context = _AsyncEventContextImpl(
|
||||||
# We use the state_group and prev_state_id stuff to pull the
|
# We use the state_group and prev_state_id stuff to pull the
|
||||||
@ -241,22 +243,25 @@ class EventContext:
|
|||||||
await self._ensure_fetched()
|
await self._ensure_fetched()
|
||||||
return self._current_state_ids
|
return self._current_state_ids
|
||||||
|
|
||||||
async def get_prev_state_ids(self):
|
async def get_prev_state_ids(self) -> StateMap[str]:
|
||||||
"""
|
"""
|
||||||
Gets the room state map, excluding this event.
|
Gets the room state map, excluding this event.
|
||||||
|
|
||||||
For a non-state event, this will be the same as get_current_state_ids().
|
For a non-state event, this will be the same as get_current_state_ids().
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[(str, str), str]|None: Returns None if state_group
|
Returns {} if state_group is None, which happens when the associated
|
||||||
is None, which happens when the associated event is an outlier.
|
event is an outlier.
|
||||||
|
|
||||||
Maps a (type, state_key) to the event ID of the state event matching
|
Maps a (type, state_key) to the event ID of the state event matching
|
||||||
this tuple.
|
this tuple.
|
||||||
"""
|
"""
|
||||||
await self._ensure_fetched()
|
await self._ensure_fetched()
|
||||||
|
# There *should* be previous state IDs now.
|
||||||
|
assert self._prev_state_ids is not None
|
||||||
return self._prev_state_ids
|
return self._prev_state_ids
|
||||||
|
|
||||||
def get_cached_current_state_ids(self):
|
def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
|
||||||
"""Gets the current state IDs if we have them already cached.
|
"""Gets the current state IDs if we have them already cached.
|
||||||
|
|
||||||
It is an error to access this for a rejected event, since rejected state should
|
It is an error to access this for a rejected event, since rejected state should
|
||||||
@ -264,16 +269,17 @@ class EventContext:
|
|||||||
``rejected`` is set.
|
``rejected`` is set.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict[(str, str), str]|None: Returns None if we haven't cached the
|
Returns None if we haven't cached the state or if state_group is None
|
||||||
state or if state_group is None, which happens when the associated
|
(which happens when the associated event is an outlier).
|
||||||
event is an outlier.
|
|
||||||
|
Otherwise, returns the the current state IDs.
|
||||||
"""
|
"""
|
||||||
if self.rejected:
|
if self.rejected:
|
||||||
raise RuntimeError("Attempt to access state_ids of rejected event")
|
raise RuntimeError("Attempt to access state_ids of rejected event")
|
||||||
|
|
||||||
return self._current_state_ids
|
return self._current_state_ids
|
||||||
|
|
||||||
async def _ensure_fetched(self):
|
async def _ensure_fetched(self) -> None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@ -285,46 +291,46 @@ class _AsyncEventContextImpl(EventContext):
|
|||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
|
|
||||||
_storage (Storage)
|
_storage
|
||||||
|
|
||||||
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
_fetching_state_deferred: Resolves when *_state_ids have been calculated.
|
||||||
been calculated. None if we haven't started calculating yet
|
None if we haven't started calculating yet
|
||||||
|
|
||||||
_event_type (str): The type of the event the context is associated with.
|
_event_type: The type of the event the context is associated with.
|
||||||
|
|
||||||
_event_state_key (str): The state_key of the event the context is
|
_event_state_key: The state_key of the event the context is associated with.
|
||||||
associated with.
|
|
||||||
|
|
||||||
_prev_state_id (str|None): If the event associated with the context is
|
_prev_state_id: If the event associated with the context is a state event,
|
||||||
a state event, then `_prev_state_id` is the event_id of the state
|
then `_prev_state_id` is the event_id of the state that was replaced.
|
||||||
that was replaced.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# This needs to have a default as we're inheriting
|
# This needs to have a default as we're inheriting
|
||||||
_storage = attr.ib(default=None)
|
_storage: "Storage" = attr.ib(default=None)
|
||||||
_prev_state_id = attr.ib(default=None)
|
_prev_state_id: Optional[str] = attr.ib(default=None)
|
||||||
_event_type = attr.ib(default=None)
|
_event_type: str = attr.ib(default=None)
|
||||||
_event_state_key = attr.ib(default=None)
|
_event_state_key: Optional[str] = attr.ib(default=None)
|
||||||
_fetching_state_deferred = attr.ib(default=None)
|
_fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
|
||||||
|
|
||||||
async def _ensure_fetched(self):
|
async def _ensure_fetched(self) -> None:
|
||||||
if not self._fetching_state_deferred:
|
if not self._fetching_state_deferred:
|
||||||
self._fetching_state_deferred = run_in_background(self._fill_out_state)
|
self._fetching_state_deferred = run_in_background(self._fill_out_state)
|
||||||
|
|
||||||
return await make_deferred_yieldable(self._fetching_state_deferred)
|
await make_deferred_yieldable(self._fetching_state_deferred)
|
||||||
|
|
||||||
async def _fill_out_state(self):
|
async def _fill_out_state(self) -> None:
|
||||||
"""Called to populate the _current_state_ids and _prev_state_ids
|
"""Called to populate the _current_state_ids and _prev_state_ids
|
||||||
attributes by loading from the database.
|
attributes by loading from the database.
|
||||||
"""
|
"""
|
||||||
if self.state_group is None:
|
if self.state_group is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._current_state_ids = await self._storage.state.get_state_ids_for_group(
|
current_state_ids = await self._storage.state.get_state_ids_for_group(
|
||||||
self.state_group
|
self.state_group
|
||||||
)
|
)
|
||||||
|
# Set this separately so mypy knows current_state_ids is not None.
|
||||||
|
self._current_state_ids = current_state_ids
|
||||||
if self._event_state_key is not None:
|
if self._event_state_key is not None:
|
||||||
self._prev_state_ids = dict(self._current_state_ids)
|
self._prev_state_ids = dict(current_state_ids)
|
||||||
|
|
||||||
key = (self._event_type, self._event_state_key)
|
key = (self._event_type, self._event_state_key)
|
||||||
if self._prev_state_id:
|
if self._prev_state_id:
|
||||||
@ -332,10 +338,12 @@ class _AsyncEventContextImpl(EventContext):
|
|||||||
else:
|
else:
|
||||||
self._prev_state_ids.pop(key, None)
|
self._prev_state_ids.pop(key, None)
|
||||||
else:
|
else:
|
||||||
self._prev_state_ids = self._current_state_ids
|
self._prev_state_ids = current_state_ids
|
||||||
|
|
||||||
|
|
||||||
def _encode_state_dict(state_dict):
|
def _encode_state_dict(
|
||||||
|
state_dict: Optional[StateMap[str]],
|
||||||
|
) -> Optional[List[Tuple[str, str, str]]]:
|
||||||
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
|
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
|
||||||
JSON we need to convert them to a form that can.
|
JSON we need to convert them to a form that can.
|
||||||
"""
|
"""
|
||||||
@ -345,7 +353,9 @@ def _encode_state_dict(state_dict):
|
|||||||
return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]
|
return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]
|
||||||
|
|
||||||
|
|
||||||
def _decode_state_dict(input):
|
def _decode_state_dict(
|
||||||
|
input: Optional[List[Tuple[str, str, str]]]
|
||||||
|
) -> Optional[StateMap[str]]:
|
||||||
"""Decodes a state dict encoded using `_encode_state_dict` above"""
|
"""Decodes a state dict encoded using `_encode_state_dict` above"""
|
||||||
if input is None:
|
if input is None:
|
||||||
return None
|
return None
|
||||||
|
@ -77,7 +77,7 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
|
def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
|
||||||
"""Wrapper that loads spam checkers configured using the old configuration, and
|
"""Wrapper that loads spam checkers configured using the old configuration, and
|
||||||
registers the spam checker hooks they implement.
|
registers the spam checker hooks they implement.
|
||||||
"""
|
"""
|
||||||
@ -129,9 +129,9 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
|
|||||||
request_info: Collection[Tuple[str, str]],
|
request_info: Collection[Tuple[str, str]],
|
||||||
auth_provider_id: Optional[str],
|
auth_provider_id: Optional[str],
|
||||||
) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
|
) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
|
||||||
# We've already made sure f is not None above, but mypy doesn't
|
# Assertion required because mypy can't prove we won't
|
||||||
# do well across function boundaries so we need to tell it f is
|
# change `f` back to `None`. See
|
||||||
# definitely not None.
|
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||||
assert f is not None
|
assert f is not None
|
||||||
|
|
||||||
return f(
|
return f(
|
||||||
@ -146,9 +146,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
|
|||||||
"Bad signature for callback check_registration_for_spam",
|
"Bad signature for callback check_registration_for_spam",
|
||||||
)
|
)
|
||||||
|
|
||||||
def run(*args, **kwargs):
|
def run(*args: Any, **kwargs: Any) -> Awaitable:
|
||||||
# mypy doesn't do well across function boundaries so we need to tell it
|
# Assertion required because mypy can't prove we won't change `f`
|
||||||
# wrapped_func is definitely not None.
|
# back to `None`. See
|
||||||
|
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||||
assert wrapped_func is not None
|
assert wrapped_func is not None
|
||||||
|
|
||||||
return maybe_awaitable(wrapped_func(*args, **kwargs))
|
return maybe_awaitable(wrapped_func(*args, **kwargs))
|
||||||
@ -165,7 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
|
|||||||
|
|
||||||
|
|
||||||
class SpamChecker:
|
class SpamChecker:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
|
self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
|
||||||
self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
|
self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
|
||||||
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
|
self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
|
||||||
@ -209,7 +210,7 @@ class SpamChecker:
|
|||||||
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
|
CHECK_REGISTRATION_FOR_SPAM_CALLBACK
|
||||||
] = None,
|
] = None,
|
||||||
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
|
check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
|
||||||
):
|
) -> None:
|
||||||
"""Register callbacks from module for each hook."""
|
"""Register callbacks from module for each hook."""
|
||||||
if check_event_for_spam is not None:
|
if check_event_for_spam is not None:
|
||||||
self._check_event_for_spam_callbacks.append(check_event_for_spam)
|
self._check_event_for_spam_callbacks.append(check_event_for_spam)
|
||||||
@ -275,7 +276,9 @@ class SpamChecker:
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool):
|
async def user_may_join_room(
|
||||||
|
self, user_id: str, room_id: str, is_invited: bool
|
||||||
|
) -> bool:
|
||||||
"""Checks if a given users is allowed to join a room.
|
"""Checks if a given users is allowed to join a room.
|
||||||
Not called when a user creates a room.
|
Not called when a user creates a room.
|
||||||
|
|
||||||
@ -285,7 +288,7 @@ class SpamChecker:
|
|||||||
is_invited: Whether the user is invited into the room
|
is_invited: Whether the user is invited into the room
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Whether the user may join the room
|
Whether the user may join the room
|
||||||
"""
|
"""
|
||||||
for callback in self._user_may_join_room_callbacks:
|
for callback in self._user_may_join_room_callbacks:
|
||||||
if await callback(user_id, room_id, is_invited) is False:
|
if await callback(user_id, room_id, is_invited) is False:
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
@ -38,7 +38,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def load_legacy_third_party_event_rules(hs: "HomeServer"):
|
def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
|
||||||
"""Wrapper that loads a third party event rules module configured using the old
|
"""Wrapper that loads a third party event rules module configured using the old
|
||||||
configuration, and registers the hooks they implement.
|
configuration, and registers the hooks they implement.
|
||||||
"""
|
"""
|
||||||
@ -77,9 +77,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
|
|||||||
event: EventBase,
|
event: EventBase,
|
||||||
state_events: StateMap[EventBase],
|
state_events: StateMap[EventBase],
|
||||||
) -> Tuple[bool, Optional[dict]]:
|
) -> Tuple[bool, Optional[dict]]:
|
||||||
# We've already made sure f is not None above, but mypy doesn't do well
|
# Assertion required because mypy can't prove we won't change
|
||||||
# across function boundaries so we need to tell it f is definitely not
|
# `f` back to `None`. See
|
||||||
# None.
|
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||||
assert f is not None
|
assert f is not None
|
||||||
|
|
||||||
res = await f(event, state_events)
|
res = await f(event, state_events)
|
||||||
@ -98,9 +98,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
|
|||||||
async def wrap_on_create_room(
|
async def wrap_on_create_room(
|
||||||
requester: Requester, config: dict, is_requester_admin: bool
|
requester: Requester, config: dict, is_requester_admin: bool
|
||||||
) -> None:
|
) -> None:
|
||||||
# We've already made sure f is not None above, but mypy doesn't do well
|
# Assertion required because mypy can't prove we won't change
|
||||||
# across function boundaries so we need to tell it f is definitely not
|
# `f` back to `None`. See
|
||||||
# None.
|
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||||
assert f is not None
|
assert f is not None
|
||||||
|
|
||||||
res = await f(requester, config, is_requester_admin)
|
res = await f(requester, config, is_requester_admin)
|
||||||
@ -112,9 +112,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
|
|||||||
|
|
||||||
return wrap_on_create_room
|
return wrap_on_create_room
|
||||||
|
|
||||||
def run(*args, **kwargs):
|
def run(*args: Any, **kwargs: Any) -> Awaitable:
|
||||||
# mypy doesn't do well across function boundaries so we need to tell it
|
# Assertion required because mypy can't prove we won't change `f`
|
||||||
# f is definitely not None.
|
# back to `None`. See
|
||||||
|
# https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
|
||||||
assert f is not None
|
assert f is not None
|
||||||
|
|
||||||
return maybe_awaitable(f(*args, **kwargs))
|
return maybe_awaitable(f(*args, **kwargs))
|
||||||
@ -162,7 +163,7 @@ class ThirdPartyEventRules:
|
|||||||
check_visibility_can_be_modified: Optional[
|
check_visibility_can_be_modified: Optional[
|
||||||
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
|
CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
|
||||||
] = None,
|
] = None,
|
||||||
):
|
) -> None:
|
||||||
"""Register callbacks from modules for each hook."""
|
"""Register callbacks from modules for each hook."""
|
||||||
if check_event_allowed is not None:
|
if check_event_allowed is not None:
|
||||||
self._check_event_allowed_callbacks.append(check_event_allowed)
|
self._check_event_allowed_callbacks.append(check_event_allowed)
|
||||||
|
@ -13,18 +13,32 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import collections.abc
|
import collections.abc
|
||||||
import re
|
import re
|
||||||
from typing import Any, Mapping, Union
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from frozendict import frozendict
|
from frozendict import frozendict
|
||||||
|
|
||||||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.api.room_versions import RoomVersion
|
from synapse.api.room_versions import RoomVersion
|
||||||
|
from synapse.types import JsonDict
|
||||||
from synapse.util.async_helpers import yieldable_gather_results
|
from synapse.util.async_helpers import yieldable_gather_results
|
||||||
from synapse.util.frozenutils import unfreeze
|
from synapse.util.frozenutils import unfreeze
|
||||||
|
|
||||||
from . import EventBase
|
from . import EventBase
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
|
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
|
||||||
# (?<!stuff) matches if the current position in the string is not preceded
|
# (?<!stuff) matches if the current position in the string is not preceded
|
||||||
# by a match for 'stuff'.
|
# by a match for 'stuff'.
|
||||||
@ -65,7 +79,7 @@ def prune_event(event: EventBase) -> EventBase:
|
|||||||
return pruned_event
|
return pruned_event
|
||||||
|
|
||||||
|
|
||||||
def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
|
def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
|
||||||
"""Redacts the event_dict in the same way as `prune_event`, except it
|
"""Redacts the event_dict in the same way as `prune_event`, except it
|
||||||
operates on dicts rather than event objects
|
operates on dicts rather than event objects
|
||||||
|
|
||||||
@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
|
|||||||
|
|
||||||
new_content = {}
|
new_content = {}
|
||||||
|
|
||||||
def add_fields(*fields):
|
def add_fields(*fields: str) -> None:
|
||||||
for field in fields:
|
for field in fields:
|
||||||
if field in event_dict["content"]:
|
if field in event_dict["content"]:
|
||||||
new_content[field] = event_dict["content"][field]
|
new_content[field] = event_dict["content"][field]
|
||||||
@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
|
|||||||
|
|
||||||
allowed_fields["content"] = new_content
|
allowed_fields["content"] = new_content
|
||||||
|
|
||||||
unsigned = {}
|
unsigned: JsonDict = {}
|
||||||
allowed_fields["unsigned"] = unsigned
|
allowed_fields["unsigned"] = unsigned
|
||||||
|
|
||||||
event_unsigned = event_dict.get("unsigned", {})
|
event_unsigned = event_dict.get("unsigned", {})
|
||||||
@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
|
|||||||
return allowed_fields
|
return allowed_fields
|
||||||
|
|
||||||
|
|
||||||
def _copy_field(src, dst, field):
|
def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
|
||||||
"""Copy the field in 'src' to 'dst'.
|
"""Copy the field in 'src' to 'dst'.
|
||||||
|
|
||||||
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
|
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
|
||||||
then dst={"foo":{"bar":5}}.
|
then dst={"foo":{"bar":5}}.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
src(dict): The dict to read from.
|
src: The dict to read from.
|
||||||
dst(dict): The dict to modify.
|
dst: The dict to modify.
|
||||||
field(list<str>): List of keys to drill down to in 'src'.
|
field: List of keys to drill down to in 'src'.
|
||||||
"""
|
"""
|
||||||
if len(field) == 0: # this should be impossible
|
if len(field) == 0: # this should be impossible
|
||||||
return
|
return
|
||||||
@ -205,7 +219,7 @@ def _copy_field(src, dst, field):
|
|||||||
sub_out_dict[key_to_move] = sub_dict[key_to_move]
|
sub_out_dict[key_to_move] = sub_dict[key_to_move]
|
||||||
|
|
||||||
|
|
||||||
def only_fields(dictionary, fields):
|
def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
|
||||||
"""Return a new dict with only the fields in 'dictionary' which are present
|
"""Return a new dict with only the fields in 'dictionary' which are present
|
||||||
in 'fields'.
|
in 'fields'.
|
||||||
|
|
||||||
@ -215,11 +229,11 @@ def only_fields(dictionary, fields):
|
|||||||
A literal '.' character in a field name may be escaped using a '\'.
|
A literal '.' character in a field name may be escaped using a '\'.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dictionary(dict): The dictionary to read from.
|
dictionary: The dictionary to read from.
|
||||||
fields(list<str>): A list of fields to copy over. Only shallow refs are
|
fields: A list of fields to copy over. Only shallow refs are
|
||||||
taken.
|
taken.
|
||||||
Returns:
|
Returns:
|
||||||
dict: A new dictionary with only the given fields. If fields was empty,
|
A new dictionary with only the given fields. If fields was empty,
|
||||||
the same dictionary is returned.
|
the same dictionary is returned.
|
||||||
"""
|
"""
|
||||||
if len(fields) == 0:
|
if len(fields) == 0:
|
||||||
@ -235,17 +249,17 @@ def only_fields(dictionary, fields):
|
|||||||
[f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
|
[f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
|
||||||
]
|
]
|
||||||
|
|
||||||
output = {}
|
output: JsonDict = {}
|
||||||
for field_array in split_fields:
|
for field_array in split_fields:
|
||||||
_copy_field(dictionary, output, field_array)
|
_copy_field(dictionary, output, field_array)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def format_event_raw(d):
|
def format_event_raw(d: JsonDict) -> JsonDict:
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def format_event_for_client_v1(d):
|
def format_event_for_client_v1(d: JsonDict) -> JsonDict:
|
||||||
d = format_event_for_client_v2(d)
|
d = format_event_for_client_v2(d)
|
||||||
|
|
||||||
sender = d.get("sender")
|
sender = d.get("sender")
|
||||||
@ -267,7 +281,7 @@ def format_event_for_client_v1(d):
|
|||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def format_event_for_client_v2(d):
|
def format_event_for_client_v2(d: JsonDict) -> JsonDict:
|
||||||
drop_keys = (
|
drop_keys = (
|
||||||
"auth_events",
|
"auth_events",
|
||||||
"prev_events",
|
"prev_events",
|
||||||
@ -282,37 +296,37 @@ def format_event_for_client_v2(d):
|
|||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def format_event_for_client_v2_without_room_id(d):
|
def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
|
||||||
d = format_event_for_client_v2(d)
|
d = format_event_for_client_v2(d)
|
||||||
d.pop("room_id", None)
|
d.pop("room_id", None)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
def serialize_event(
|
def serialize_event(
|
||||||
e,
|
e: Union[JsonDict, EventBase],
|
||||||
time_now_ms,
|
time_now_ms: int,
|
||||||
as_client_event=True,
|
as_client_event: bool = True,
|
||||||
event_format=format_event_for_client_v1,
|
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
|
||||||
token_id=None,
|
token_id: Optional[str] = None,
|
||||||
only_event_fields=None,
|
only_event_fields: Optional[List[str]] = None,
|
||||||
include_stripped_room_state=False,
|
include_stripped_room_state: bool = False,
|
||||||
):
|
) -> JsonDict:
|
||||||
"""Serialize event for clients
|
"""Serialize event for clients
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
e (EventBase)
|
e
|
||||||
time_now_ms (int)
|
time_now_ms
|
||||||
as_client_event (bool)
|
as_client_event
|
||||||
event_format
|
event_format
|
||||||
token_id
|
token_id
|
||||||
only_event_fields
|
only_event_fields
|
||||||
include_stripped_room_state (bool): Some events can have stripped room state
|
include_stripped_room_state: Some events can have stripped room state
|
||||||
stored in the `unsigned` field. This is required for invite and knock
|
stored in the `unsigned` field. This is required for invite and knock
|
||||||
functionality. If this option is False, that state will be removed from the
|
functionality. If this option is False, that state will be removed from the
|
||||||
event before it is returned. Otherwise, it will be kept.
|
event before it is returned. Otherwise, it will be kept.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict
|
The serialized event dictionary.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# FIXME(erikj): To handle the case of presence events and the like
|
# FIXME(erikj): To handle the case of presence events and the like
|
||||||
@ -369,25 +383,29 @@ class EventClientSerializer:
|
|||||||
clients.
|
clients.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.experimental_msc1849_support_enabled = (
|
self.experimental_msc1849_support_enabled = (
|
||||||
hs.config.server.experimental_msc1849_support_enabled
|
hs.config.server.experimental_msc1849_support_enabled
|
||||||
)
|
)
|
||||||
|
|
||||||
async def serialize_event(
|
async def serialize_event(
|
||||||
self, event, time_now, bundle_aggregations=True, **kwargs
|
self,
|
||||||
):
|
event: Union[JsonDict, EventBase],
|
||||||
|
time_now: int,
|
||||||
|
bundle_aggregations: bool = True,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> JsonDict:
|
||||||
"""Serializes a single event.
|
"""Serializes a single event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (EventBase)
|
event
|
||||||
time_now (int): The current time in milliseconds
|
time_now: The current time in milliseconds
|
||||||
bundle_aggregations (bool): Whether to bundle in related events
|
bundle_aggregations: Whether to bundle in related events
|
||||||
**kwargs: Arguments to pass to `serialize_event`
|
**kwargs: Arguments to pass to `serialize_event`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: The serialized event
|
The serialized event
|
||||||
"""
|
"""
|
||||||
# To handle the case of presence events and the like
|
# To handle the case of presence events and the like
|
||||||
if not isinstance(event, EventBase):
|
if not isinstance(event, EventBase):
|
||||||
@ -448,25 +466,27 @@ class EventClientSerializer:
|
|||||||
|
|
||||||
return serialized_event
|
return serialized_event
|
||||||
|
|
||||||
def serialize_events(self, events, time_now, **kwargs):
|
async def serialize_events(
|
||||||
|
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
|
||||||
|
) -> List[JsonDict]:
|
||||||
"""Serializes multiple events.
|
"""Serializes multiple events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event (iter[EventBase])
|
event
|
||||||
time_now (int): The current time in milliseconds
|
time_now: The current time in milliseconds
|
||||||
**kwargs: Arguments to pass to `serialize_event`
|
**kwargs: Arguments to pass to `serialize_event`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[list[dict]]: The list of serialized events
|
The list of serialized events
|
||||||
"""
|
"""
|
||||||
return yieldable_gather_results(
|
return await yieldable_gather_results(
|
||||||
self.serialize_event, events, time_now=time_now, **kwargs
|
self.serialize_event, events, time_now=time_now, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def copy_power_levels_contents(
|
def copy_power_levels_contents(
|
||||||
old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
|
old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
|
||||||
):
|
) -> Dict[str, Union[int, Dict[str, int]]]:
|
||||||
"""Copy the content of a power_levels event, unfreezing frozendicts along the way
|
"""Copy the content of a power_levels event, unfreezing frozendicts along the way
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -475,7 +495,7 @@ def copy_power_levels_contents(
|
|||||||
if not isinstance(old_power_levels, collections.abc.Mapping):
|
if not isinstance(old_power_levels, collections.abc.Mapping):
|
||||||
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
|
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
|
||||||
|
|
||||||
power_levels = {}
|
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
|
||||||
for k, v in old_power_levels.items():
|
for k, v in old_power_levels.items():
|
||||||
|
|
||||||
if isinstance(v, int):
|
if isinstance(v, int):
|
||||||
@ -483,7 +503,8 @@ def copy_power_levels_contents(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(v, collections.abc.Mapping):
|
if isinstance(v, collections.abc.Mapping):
|
||||||
power_levels[k] = h = {}
|
h: Dict[str, int] = {}
|
||||||
|
power_levels[k] = h
|
||||||
for k1, v1 in v.items():
|
for k1, v1 in v.items():
|
||||||
# we should only have one level of nesting
|
# we should only have one level of nesting
|
||||||
if not isinstance(v1, int):
|
if not isinstance(v1, int):
|
||||||
@ -498,7 +519,7 @@ def copy_power_levels_contents(
|
|||||||
return power_levels
|
return power_levels
|
||||||
|
|
||||||
|
|
||||||
def validate_canonicaljson(value: Any):
|
def validate_canonicaljson(value: Any) -> None:
|
||||||
"""
|
"""
|
||||||
Ensure that the JSON object is valid according to the rules of canonical JSON.
|
Ensure that the JSON object is valid according to the rules of canonical JSON.
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import collections.abc
|
import collections.abc
|
||||||
from typing import Union
|
from typing import Iterable, Union
|
||||||
|
|
||||||
import jsonschema
|
import jsonschema
|
||||||
|
|
||||||
@ -28,11 +28,11 @@ from synapse.events.utils import (
|
|||||||
validate_canonicaljson,
|
validate_canonicaljson,
|
||||||
)
|
)
|
||||||
from synapse.federation.federation_server import server_matches_acl_event
|
from synapse.federation.federation_server import server_matches_acl_event
|
||||||
from synapse.types import EventID, RoomID, UserID
|
from synapse.types import EventID, JsonDict, RoomID, UserID
|
||||||
|
|
||||||
|
|
||||||
class EventValidator:
|
class EventValidator:
|
||||||
def validate_new(self, event: EventBase, config: HomeServerConfig):
|
def validate_new(self, event: EventBase, config: HomeServerConfig) -> None:
|
||||||
"""Validates the event has roughly the right format
|
"""Validates the event has roughly the right format
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -116,7 +116,7 @@ class EventValidator:
|
|||||||
errcode=Codes.BAD_JSON,
|
errcode=Codes.BAD_JSON,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _validate_retention(self, event: EventBase):
|
def _validate_retention(self, event: EventBase) -> None:
|
||||||
"""Checks that an event that defines the retention policy for a room respects the
|
"""Checks that an event that defines the retention policy for a room respects the
|
||||||
format enforced by the spec.
|
format enforced by the spec.
|
||||||
|
|
||||||
@ -156,7 +156,7 @@ class EventValidator:
|
|||||||
errcode=Codes.BAD_JSON,
|
errcode=Codes.BAD_JSON,
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_builder(self, event: Union[EventBase, EventBuilder]):
|
def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None:
|
||||||
"""Validates that the builder/event has roughly the right format. Only
|
"""Validates that the builder/event has roughly the right format. Only
|
||||||
checks values that we expect a proto event to have, rather than all the
|
checks values that we expect a proto event to have, rather than all the
|
||||||
fields an event would have
|
fields an event would have
|
||||||
@ -204,14 +204,14 @@ class EventValidator:
|
|||||||
|
|
||||||
self._ensure_state_event(event)
|
self._ensure_state_event(event)
|
||||||
|
|
||||||
def _ensure_strings(self, d, keys):
|
def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
|
||||||
for s in keys:
|
for s in keys:
|
||||||
if s not in d:
|
if s not in d:
|
||||||
raise SynapseError(400, "'%s' not in content" % (s,))
|
raise SynapseError(400, "'%s' not in content" % (s,))
|
||||||
if not isinstance(d[s], str):
|
if not isinstance(d[s], str):
|
||||||
raise SynapseError(400, "'%s' not a string type" % (s,))
|
raise SynapseError(400, "'%s' not a string type" % (s,))
|
||||||
|
|
||||||
def _ensure_state_event(self, event):
|
def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None:
|
||||||
if not event.is_state():
|
if not event.is_state():
|
||||||
raise SynapseError(400, "'%s' must be state events" % (event.type,))
|
raise SynapseError(400, "'%s' must be state events" % (event.type,))
|
||||||
|
|
||||||
@ -244,7 +244,9 @@ POWER_LEVELS_SCHEMA = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _create_power_level_validator():
|
# This could return something newer than Draft 7, but that's the current "latest"
|
||||||
|
# validator.
|
||||||
|
def _create_power_level_validator() -> jsonschema.Draft7Validator:
|
||||||
validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
|
validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
|
||||||
|
|
||||||
# by default jsonschema does not consider a frozendict to be an object so
|
# by default jsonschema does not consider a frozendict to be an object so
|
||||||
|
@ -465,17 +465,35 @@ class RoomCreationHandler:
|
|||||||
# the room has been created
|
# the room has been created
|
||||||
# Calculate the minimum power level needed to clone the room
|
# Calculate the minimum power level needed to clone the room
|
||||||
event_power_levels = power_levels.get("events", {})
|
event_power_levels = power_levels.get("events", {})
|
||||||
|
if not isinstance(event_power_levels, dict):
|
||||||
|
event_power_levels = {}
|
||||||
state_default = power_levels.get("state_default", 50)
|
state_default = power_levels.get("state_default", 50)
|
||||||
|
try:
|
||||||
|
state_default_int = int(state_default) # type: ignore[arg-type]
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
state_default_int = 50
|
||||||
ban = power_levels.get("ban", 50)
|
ban = power_levels.get("ban", 50)
|
||||||
needed_power_level = max(state_default, ban, max(event_power_levels.values()))
|
try:
|
||||||
|
ban = int(ban) # type: ignore[arg-type]
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
ban = 50
|
||||||
|
needed_power_level = max(
|
||||||
|
state_default_int, ban, max(event_power_levels.values())
|
||||||
|
)
|
||||||
|
|
||||||
# Get the user's current power level, this matches the logic in get_user_power_level,
|
# Get the user's current power level, this matches the logic in get_user_power_level,
|
||||||
# but without the entire state map.
|
# but without the entire state map.
|
||||||
user_power_levels = power_levels.setdefault("users", {})
|
user_power_levels = power_levels.setdefault("users", {})
|
||||||
|
if not isinstance(user_power_levels, dict):
|
||||||
|
user_power_levels = {}
|
||||||
users_default = power_levels.get("users_default", 0)
|
users_default = power_levels.get("users_default", 0)
|
||||||
current_power_level = user_power_levels.get(user_id, users_default)
|
current_power_level = user_power_levels.get(user_id, users_default)
|
||||||
|
try:
|
||||||
|
current_power_level_int = int(current_power_level) # type: ignore[arg-type]
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
current_power_level_int = 0
|
||||||
# Raise the requester's power level in the new room if necessary
|
# Raise the requester's power level in the new room if necessary
|
||||||
if current_power_level < needed_power_level:
|
if current_power_level_int < needed_power_level:
|
||||||
user_power_levels[user_id] = needed_power_level
|
user_power_levels[user_id] = needed_power_level
|
||||||
|
|
||||||
await self._send_events_for_new_room(
|
await self._send_events_for_new_room(
|
||||||
|
@ -232,12 +232,12 @@ class RelationPaginationServlet(RestServlet):
|
|||||||
# Similarly, we don't allow relations to be applied to relations, so we
|
# Similarly, we don't allow relations to be applied to relations, so we
|
||||||
# return the original relations without any aggregations on top of them
|
# return the original relations without any aggregations on top of them
|
||||||
# here.
|
# here.
|
||||||
events = await self._event_serializer.serialize_events(
|
serialized_events = await self._event_serializer.serialize_events(
|
||||||
events, now, bundle_aggregations=False
|
events, now, bundle_aggregations=False
|
||||||
)
|
)
|
||||||
|
|
||||||
return_value = pagination_chunk.to_dict()
|
return_value = pagination_chunk.to_dict()
|
||||||
return_value["chunk"] = events
|
return_value["chunk"] = serialized_events
|
||||||
return_value["original_event"] = original_event
|
return_value["original_event"] = original_event
|
||||||
|
|
||||||
return 200, return_value
|
return 200, return_value
|
||||||
@ -416,10 +416,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
events = await self._event_serializer.serialize_events(events, now)
|
serialized_events = await self._event_serializer.serialize_events(events, now)
|
||||||
|
|
||||||
return_value = result.to_dict()
|
return_value = result.to_dict()
|
||||||
return_value["chunk"] = events
|
return_value["chunk"] = serialized_events
|
||||||
|
|
||||||
return 200, return_value
|
return 200, return_value
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user