Add type hints to synapse.events.*. (#11066)

Except `synapse/events/__init__.py`, which will be done in a follow-up.
This commit is contained in:
Patrick Cloke 2021-10-13 07:24:07 -04:00 committed by GitHub
parent cdd308845b
commit 1f9d0b8a7a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 208 additions and 145 deletions

View file

@ -13,18 +13,32 @@
# limitations under the License.
import collections.abc
import re
from typing import Any, Mapping, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Union,
)
from frozendict import frozendict
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.types import JsonDict
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.frozenutils import unfreeze
from . import EventBase
if TYPE_CHECKING:
from synapse.server import HomeServer
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
@ -65,7 +79,7 @@ def prune_event(event: EventBase) -> EventBase:
return pruned_event
def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
"""Redacts the event_dict in the same way as `prune_event`, except it
operates on dicts rather than event objects
@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
new_content = {}
def add_fields(*fields):
def add_fields(*fields: str) -> None:
for field in fields:
if field in event_dict["content"]:
new_content[field] = event_dict["content"][field]
@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
allowed_fields["content"] = new_content
unsigned = {}
unsigned: JsonDict = {}
allowed_fields["unsigned"] = unsigned
event_unsigned = event_dict.get("unsigned", {})
@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
return allowed_fields
def _copy_field(src, dst, field):
def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
"""Copy the field in 'src' to 'dst'.
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
then dst={"foo":{"bar":5}}.
Args:
src(dict): The dict to read from.
dst(dict): The dict to modify.
field(list<str>): List of keys to drill down to in 'src'.
src: The dict to read from.
dst: The dict to modify.
field: List of keys to drill down to in 'src'.
"""
if len(field) == 0: # this should be impossible
return
@ -205,7 +219,7 @@ def _copy_field(src, dst, field):
sub_out_dict[key_to_move] = sub_dict[key_to_move]
def only_fields(dictionary, fields):
def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
"""Return a new dict with only the fields in 'dictionary' which are present
in 'fields'.
@ -215,11 +229,11 @@ def only_fields(dictionary, fields):
A literal '.' character in a field name may be escaped using a '\'.
Args:
dictionary(dict): The dictionary to read from.
fields(list<str>): A list of fields to copy over. Only shallow refs are
dictionary: The dictionary to read from.
fields: A list of fields to copy over. Only shallow refs are
taken.
Returns:
dict: A new dictionary with only the given fields. If fields was empty,
A new dictionary with only the given fields. If fields was empty,
the same dictionary is returned.
"""
if len(fields) == 0:
@ -235,17 +249,17 @@ def only_fields(dictionary, fields):
[f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
]
output = {}
output: JsonDict = {}
for field_array in split_fields:
_copy_field(dictionary, output, field_array)
return output
def format_event_raw(d):
def format_event_raw(d: JsonDict) -> JsonDict:
return d
def format_event_for_client_v1(d):
def format_event_for_client_v1(d: JsonDict) -> JsonDict:
d = format_event_for_client_v2(d)
sender = d.get("sender")
@ -267,7 +281,7 @@ def format_event_for_client_v1(d):
return d
def format_event_for_client_v2(d):
def format_event_for_client_v2(d: JsonDict) -> JsonDict:
drop_keys = (
"auth_events",
"prev_events",
@ -282,37 +296,37 @@ def format_event_for_client_v2(d):
return d
def format_event_for_client_v2_without_room_id(d):
def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
d = format_event_for_client_v2(d)
d.pop("room_id", None)
return d
def serialize_event(
e,
time_now_ms,
as_client_event=True,
event_format=format_event_for_client_v1,
token_id=None,
only_event_fields=None,
include_stripped_room_state=False,
):
e: Union[JsonDict, EventBase],
time_now_ms: int,
as_client_event: bool = True,
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
token_id: Optional[str] = None,
only_event_fields: Optional[List[str]] = None,
include_stripped_room_state: bool = False,
) -> JsonDict:
"""Serialize event for clients
Args:
e (EventBase)
time_now_ms (int)
as_client_event (bool)
e
time_now_ms
as_client_event
event_format
token_id
only_event_fields
include_stripped_room_state (bool): Some events can have stripped room state
include_stripped_room_state: Some events can have stripped room state
stored in the `unsigned` field. This is required for invite and knock
functionality. If this option is False, that state will be removed from the
event before it is returned. Otherwise, it will be kept.
Returns:
dict
The serialized event dictionary.
"""
# FIXME(erikj): To handle the case of presence events and the like
@ -369,25 +383,29 @@ class EventClientSerializer:
clients.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.experimental_msc1849_support_enabled = (
hs.config.server.experimental_msc1849_support_enabled
)
async def serialize_event(
self, event, time_now, bundle_aggregations=True, **kwargs
):
self,
event: Union[JsonDict, EventBase],
time_now: int,
bundle_aggregations: bool = True,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
Args:
event (EventBase)
time_now (int): The current time in milliseconds
bundle_aggregations (bool): Whether to bundle in related events
event
time_now: The current time in milliseconds
bundle_aggregations: Whether to bundle in related events
**kwargs: Arguments to pass to `serialize_event`
Returns:
dict: The serialized event
The serialized event
"""
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
@ -448,25 +466,27 @@ class EventClientSerializer:
return serialized_event
def serialize_events(self, events, time_now, **kwargs):
async def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
) -> List[JsonDict]:
"""Serializes multiple events.
Args:
event (iter[EventBase])
time_now (int): The current time in milliseconds
event
time_now: The current time in milliseconds
**kwargs: Arguments to pass to `serialize_event`
Returns:
Deferred[list[dict]]: The list of serialized events
The list of serialized events
"""
return yieldable_gather_results(
return await yieldable_gather_results(
self.serialize_event, events, time_now=time_now, **kwargs
)
def copy_power_levels_contents(
old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
):
) -> Dict[str, Union[int, Dict[str, int]]]:
"""Copy the content of a power_levels event, unfreezing frozendicts along the way
Raises:
@ -475,7 +495,7 @@ def copy_power_levels_contents(
if not isinstance(old_power_levels, collections.abc.Mapping):
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
power_levels = {}
power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
for k, v in old_power_levels.items():
if isinstance(v, int):
@ -483,7 +503,8 @@ def copy_power_levels_contents(
continue
if isinstance(v, collections.abc.Mapping):
power_levels[k] = h = {}
h: Dict[str, int] = {}
power_levels[k] = h
for k1, v1 in v.items():
# we should only have one level of nesting
if not isinstance(v1, int):
@ -498,7 +519,7 @@ def copy_power_levels_contents(
return power_levels
def validate_canonicaljson(value: Any):
def validate_canonicaljson(value: Any) -> None:
"""
Ensure that the JSON object is valid according to the rules of canonical JSON.