Add type hints for state. (#8140)

This commit is contained in:
Patrick Cloke 2020-08-24 14:25:27 -04:00 committed by GitHub
parent cbd8d83da7
commit 5758dcf30c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 420 additions and 203 deletions

1
changelog.d/8140.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `synapse.state`.

47
stubs/frozendict.pyi Normal file
View File

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Stub for frozendict.
from typing import (
Any,
Hashable,
Iterable,
Iterator,
Mapping,
overload,
Tuple,
TypeVar,
)
_KT = TypeVar("_KT", bound=Hashable) # Key type.
_VT = TypeVar("_VT") # Value type.
class frozendict(Mapping[_KT, _VT]):
@overload
def __init__(self, **kwargs: _VT) -> None: ...
@overload
def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
@overload
def __init__(
self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
def __getitem__(self, key: _KT) -> _VT: ...
def __contains__(self, key: Any) -> bool: ...
def copy(self, **add_or_replace: Any) -> frozendict: ...
def __iter__(self) -> Iterator[_KT]: ...
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...

View File

@ -329,10 +329,10 @@ class FederationSender(object):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
domains = await self.state.get_current_hosts_in_room(room_id)
domains_set = await self.state.get_current_hosts_in_room(room_id)
domains = [
d
for d in domains
for d in domains_set
if d != self.server_name
and self._federation_shard_config.should_handle(self._instance_name, d)
]

View File

@ -2134,10 +2134,10 @@ class FederationHandler(BaseHandler):
)
state_sets = list(state_sets.values())
state_sets.append(state)
current_state_ids = await self.state_handler.resolve_events(
current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
current_state_ids = {k: e.event_id for k, e in current_states.items()}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
@ -2149,9 +2149,11 @@ class FederationHandler(BaseHandler):
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]
auth_events_map = await self.store.get_events(current_state_ids)
auth_events_map = await self.store.get_events(current_state_ids_list)
current_auth_events = {
(e.type, e.state_key): e for e in auth_events_map.values()
}

View File

@ -40,7 +40,7 @@ from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
@ -1318,7 +1318,7 @@ async def get_interested_parties(
async def get_interested_remotes(
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
) -> List[Tuple[List[str], List[UserPresenceState]]]:
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
@ -1334,7 +1334,7 @@ async def get_interested_remotes(
each tuple the list of UserPresenceState should be sent to each
destination
"""
hosts_and_states = []
hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]]
# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote

View File

@ -17,7 +17,7 @@ import abc
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64
@ -38,7 +38,15 @@ from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
from synapse.types import (
Collection,
JsonDict,
Requester,
RoomAlias,
RoomID,
StateMap,
UserID,
)
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@ -738,9 +746,7 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)
async def _can_guest_join(
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
"""
Returns whether a guest can join a room based on its current state.
"""
@ -969,9 +975,7 @@ class RoomMemberHandler(object):
)
return stream_id
async def _is_host_in_room(
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))

View File

@ -16,11 +16,22 @@
import logging
from collections import namedtuple
from typing import Awaitable, Dict, Iterable, List, Optional, Set
from typing import (
Awaitable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Union,
overload,
)
import attr
from frozendict import frozendict
from prometheus_client import Histogram
from typing_extensions import Literal
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@ -30,7 +41,7 @@ from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap
from synapse.types import Collection, StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@ -68,8 +79,14 @@ def _gen_state_id():
class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
# dict[(str, str), str] map from (type, state_key) to event_id
def __init__(
self,
state: StateMap[str],
state_group: Optional[int],
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
):
# A map from (type, state_key) to event_id.
self.state = frozendict(state)
# the ID of a state group if one and only one is involved.
@ -107,24 +124,49 @@ class StateHandler(object):
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
@overload
async def get_current_state(
self, room_id, event_type=None, state_key="", latest_event_ids=None
):
""" Retrieves the current state for the room. This is done by
self,
room_id: str,
event_type: Literal[None] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> StateMap[EventBase]:
...
@overload
async def get_current_state(
self,
room_id: str,
event_type: str,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Optional[EventBase]:
...
async def get_current_state(
self,
room_id: str,
event_type: Optional[str] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Union[Optional[EventBase], StateMap[EventBase]]:
"""Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
This is equivalent to getting the state of an event that were to send
next before receiving any new events.
Returns:
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
Returns:
map from (type, state_key) to event
Otherwise, a map from (type, state_key) to event.
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
@ -140,34 +182,30 @@ class StateHandler(object):
state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
state = {
return {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
return state
async def get_current_state_ids(self, room_id, latest_event_ids=None):
async def get_current_state_ids(
self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room
Args:
room_id (str):
latest_event_ids (iterable[str]|None): if given, the forward
extremities to resolve. If None, we look them up from the
database (via a cache)
room_id:
latest_event_ids: if given, the forward extremities to resolve. If
None, we look them up from the database (via a cache).
Returns:
Deferred[dict[(str, str), str)]]: the state dict, mapping from
(event_type, state_key) -> event_id
the state dict, mapping from (event_type, state_key) -> event_id
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
return state
return dict(ret.state)
async def get_current_users_in_room(
self, room_id: str, latest_event_ids: Optional[List[str]] = None
@ -183,32 +221,34 @@ class StateHandler(object):
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = await self.store.get_joined_users_from_state(room_id, entry)
return joined_users
return await self.store.get_joined_users_from_state(room_id, entry)
async def get_current_hosts_in_room(self, room_id):
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events(self, room_id, event_ids):
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: List[str]
) -> Set[str]:
"""Get the hosts that were in a room at the given event ids
Args:
room_id (str):
event_ids (list[str]):
room_id:
event_ids:
Returns:
Deferred[list[str]]: the hosts in the room at the given events
The hosts in the room at the given events
"""
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = await self.store.get_joined_hosts(room_id, entry)
return joined_hosts
return await self.store.get_joined_hosts(room_id, entry)
async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
):
) -> EventContext:
"""Build an EventContext structure for the event.
This works out what the current state should be for the event, and
@ -221,7 +261,7 @@ class StateHandler(object):
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
Returns:
synapse.events.snapshot.EventContext:
The event context.
"""
if event.internal_metadata.is_outlier():
@ -275,7 +315,7 @@ class StateHandler(object):
event.room_id, event.prev_event_ids()
)
state_ids_before_event = entry.state
state_ids_before_event = dict(entry.state)
state_group_before_event = entry.state_group
state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids
@ -346,19 +386,18 @@ class StateHandler(object):
)
@measure_func()
async def resolve_state_groups_for_events(self, room_id, event_ids):
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Iterable[str]
) -> _StateCacheEntry:
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
Args:
room_id (str)
event_ids (list[str])
explicit_room_version (str|None): If set uses the the given room
version to choose the resolution algorithm. If None, then
checks the database for room version.
room_id
event_ids
Returns:
Deferred[_StateCacheEntry]: resolved state
The resolved state
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
@ -394,7 +433,12 @@ class StateHandler(object):
)
return result
async def resolve_events(self, room_version, state_sets, event):
async def resolve_events(
self,
room_version: str,
state_sets: Collection[Iterable[EventBase]],
event: EventBase,
) -> StateMap[EventBase]:
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
@ -414,9 +458,7 @@ class StateHandler(object):
state_res_store=StateResolutionStore(self.store),
)
new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()}
return new_state
return {key: state_map[ev_id] for key, ev_id in new_state.items()}
class StateResolutionHandler(object):
@ -444,7 +486,12 @@ class StateResolutionHandler(object):
@log_function
async def resolve_state_groups(
self, room_id, room_version, state_groups_ids, event_map, state_res_store
self,
room_id: str,
room_version: str,
state_groups_ids: Dict[int, StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
):
"""Resolves conflicts between a set of state groups
@ -452,13 +499,13 @@ class StateResolutionHandler(object):
not be called for a single state group
Args:
room_id (str): room we are resolving for (used for logging and sanity checks)
room_version (str): version of the room
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
room_id: room we are resolving for (used for logging and sanity checks)
room_version: version of the room
state_groups_ids:
A map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
event_map(dict[str,FrozenEvent]|None):
event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
@ -466,10 +513,10 @@ class StateResolutionHandler(object):
If None, all events will be fetched via state_res_store.
state_res_store (StateResolutionStore)
state_res_store
Returns:
_StateCacheEntry: resolved state
The resolved state
"""
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
@ -530,21 +577,22 @@ class StateResolutionHandler(object):
return cache
def _make_state_cache_entry(new_state, state_groups_ids):
def _make_state_cache_entry(
new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
) -> _StateCacheEntry:
"""Given a resolved state, and a set of input state groups, pick one to base
a new state group on (if any), and return an appropriately-constructed
_StateCacheEntry.
Args:
new_state (dict[(str, str), str]): resolved state map (mapping from
(type, state_key) to event_id)
new_state: resolved state map (mapping from (type, state_key) to event_id)
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
state_groups_ids:
map from state group id to the state in that state group (where
'state' is a map from state key to event id)
Returns:
_StateCacheEntry
The cache entry.
"""
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
@ -585,7 +633,7 @@ def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
state_sets: List[StateMap[str]],
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
) -> Awaitable[StateMap[str]]:
@ -633,15 +681,17 @@ class StateResolutionStore(object):
store = attr.ib()
def get_events(self, event_ids, allow_rejected=False):
def get_events(
self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database
Args:
event_ids (list): The event_ids of the events to fetch
allow_rejected (bool): If True return rejected events.
event_ids: The event_ids of the events to fetch
allow_rejected: If True return rejected events.
Returns:
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
An awaitable which resolves to a dict from event_id to event.
"""
return self.store.get_events(
@ -651,7 +701,9 @@ class StateResolutionStore(object):
allow_rejected=allow_rejected,
)
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
def get_auth_chain_difference(
self, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@ -660,7 +712,7 @@ class StateResolutionStore(object):
chain.
Returns:
Deferred[Set[str]]: Set of event IDs.
An awaitable that resolves to a set of event IDs.
"""
return self.store.get_auth_chain_difference(state_sets)

View File

@ -15,7 +15,17 @@
import hashlib
import logging
from typing import Awaitable, Callable, Dict, List, Optional
from typing import (
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
)
from synapse import event_auth
from synapse.api.constants import EventTypes
@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "")
async def resolve_events_with_store(
room_id: str,
state_sets: List[StateMap[str]],
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable[[List[str]], Awaitable],
):
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
) -> StateMap[str]:
"""
Args:
room_id: the room we are working in
@ -56,8 +66,7 @@ async def resolve_events_with_store(
an Awaitable that resolves to a dict of event_id to event.
Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
A map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
@ -75,8 +84,8 @@ async def resolve_events_with_store(
"Asking for %d/%d conflicted events", len(needed_events), needed_event_count
)
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
# A map from state event id to event. Only includes the state events which
# are in conflict (and those in event_map).
state_map = await state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
@ -91,8 +100,6 @@ async def resolve_events_with_store(
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
#
# dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
@ -122,18 +129,19 @@ async def resolve_events_with_store(
)
def _seperate(state_sets):
def _seperate(
state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
Args:
state_sets(iterable[dict[(str, str), str]]):
state_sets:
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
(dict[(str, str), str], dict[(str, str), set[str]]):
A tuple of (unconflicted_state, conflicted_state), where:
unconflicted_state is a dict mapping (type, state_key)->event_id
@ -144,7 +152,7 @@ def _seperate(state_sets):
"""
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {}
conflicted_state = {} # type: StateMap[Set[str]]
for state_set in state_set_iterator:
for key, value in state_set.items():
@ -171,7 +179,21 @@ def _seperate(state_sets):
return unconflicted_state, conflicted_state
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
def _create_auth_events_from_maps(
unconflicted_state: StateMap[str],
conflicted_state: StateMap[Set[str]],
state_map: Dict[str, EventBase],
) -> StateMap[str]:
"""
Args:
unconflicted_state: The unconflicted state map.
conflicted_state: The conflicted state map.
state_map:
Returns:
A map from state key to event id.
"""
auth_events = {}
for event_ids in conflicted_state.values():
for event_id in event_ids:
@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
auth_event_id = unconflicted_state.get(key, None)
if auth_event_id:
auth_events[key] = auth_event_id
return auth_events
def _resolve_with_state(
unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
unconflicted_state_ids: StateMap[str],
conflicted_state_ids: StateMap[Set[str]],
auth_event_ids: StateMap[str],
state_map: Dict[str, EventBase],
):
conflicted_state = {}
for key, event_ids in conflicted_state_ids.items():
@ -215,7 +240,9 @@ def _resolve_with_state(
return new_state
def _resolve_state_events(conflicted_state, auth_events):
def _resolve_state_events(
conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase]
) -> StateMap[EventBase]:
""" This is where we actually decide which of the conflicted state to
use.
@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events):
return resolved_state
def _resolve_auth_events(events, auth_events):
def _resolve_auth_events(
events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
reverse = list(reversed(_ordered_events(events)))
auth_keys = {
@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events):
return event
def _resolve_normal_events(events, auth_events):
def _resolve_normal_events(
events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events):
return event
def _ordered_events(events):
def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
def key_func(e):
# we have to use utf-8 rather than ascii here because it turns out we allow
# people to send us events with non-ascii event IDs :/

View File

@ -16,7 +16,21 @@
import heapq
import itertools
import logging
from typing import Dict, List, Optional
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
overload,
)
from typing_extensions import Literal
import synapse.state
from synapse import event_auth
@ -40,10 +54,10 @@ async def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
state_sets: List[StateMap[str]],
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
):
) -> StateMap[str]:
"""Resolves the state using the v2 state resolution algorithm
Args:
@ -63,8 +77,7 @@ async def resolve_events_with_store(
state_res_store:
Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
A map from (type, state_key) to event_id.
"""
logger.debug("Computing conflicted state")
@ -171,18 +184,23 @@ async def resolve_events_with_store(
return resolved_state
async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
async def _get_power_level_for_sender(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> int:
"""Return the power level of the sender of the given event according to
their auth events.
Args:
room_id (str)
event_id (str)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
room_id
event_id
event_map
state_res_store
Returns:
Deferred[int]
The power level.
"""
event = await _get_event(room_id, event_id, event_map, state_res_store)
@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st
return int(level)
async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
async def _get_auth_chain_difference(
state_sets: Sequence[StateMap[str]],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains.
Args:
state_sets (list)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
state_sets
event_map
state_res_store
Returns:
Deferred[set[str]]: Set of event IDs
Set of event IDs
"""
difference = await state_res_store.get_auth_chain_difference(
@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
return difference
def _seperate(state_sets):
def _seperate(
state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
"""Return the unconflicted and conflicted state. This is different than in
the original algorithm, as this defines a key to be conflicted if one of
the state sets doesn't have that key.
Args:
state_sets (list)
state_sets
Returns:
tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
conflicted state dict is a map from type/state_key to set of event IDs
A tuple of unconflicted and conflicted state. The conflicted state dict
is a map from type/state_key to set of event IDs
"""
unconflicted_state = {}
conflicted_state = {}
@ -260,18 +284,20 @@ def _seperate(state_sets):
event_ids.discard(None)
conflicted_state[key] = event_ids
return unconflicted_state, conflicted_state
# mypy doesn't understand that discarding None above means that conflicted
# state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
return unconflicted_state, conflicted_state # type: ignore
def _is_power_event(event):
def _is_power_event(event: EventBase) -> bool:
"""Return whether or not the event is a "power event", as defined by the
v2 state resolution algorithm
Args:
event (FrozenEvent)
event
Returns:
boolean
True if the event is a power event.
"""
if (event.type, event.state_key) in (
(EventTypes.PowerLevels, ""),
@ -288,19 +314,23 @@ def _is_power_event(event):
async def _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
):
graph: Dict[str, Set[str]],
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
auth_diff: Set[str],
) -> None:
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
Args:
graph (dict[str, set[str]]): A map from event ID to the events auth
event IDs
room_id (str): the room we are working in
event_id (str): Event to add to the graph
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
auth_diff (set[str]): Set of event IDs that are in the auth difference.
graph: A map from event ID to the events auth event IDs
room_id: the room we are working in
event_id: Event to add to the graph
event_map
state_res_store
auth_diff: Set of event IDs that are in the auth difference.
"""
state = [event_id]
@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph(
async def _reverse_topological_power_sort(
clock, room_id, event_ids, event_map, state_res_store, auth_diff
):
clock: Clock,
room_id: str,
event_ids: Iterable[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
auth_diff: Set[str],
) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
Args:
clock (Clock)
room_id (str): the room we are working in
event_ids (list[str]): The events to sort
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
auth_diff (set[str]): Set of event IDs that are in the auth difference.
clock
room_id: the room we are working in
event_ids: The events to sort
event_map
state_res_store
auth_diff: Set of event IDs that are in the auth difference.
Returns:
Deferred[list[str]]: The sorted list
The sorted list
"""
graph = {}
graph = {} # type: Dict[str, Set[str]]
for idx, event_id in enumerate(event_ids, start=1):
await _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
@ -372,22 +407,28 @@ async def _reverse_topological_power_sort(
async def _iterative_auth_checks(
clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
):
clock: Clock,
room_id: str,
room_version: str,
event_ids: List[str],
base_state: StateMap[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> StateMap[str]:
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
Args:
clock (Clock)
room_id (str)
room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to
base_state (StateMap[str]): The set of state to start with
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
clock
room_id
room_version
event_ids: Ordered list of events to apply auth checks to
base_state: The set of state to start with
event_map
state_res_store
Returns:
Deferred[StateMap[str]]: Returns the final updated state
Returns the final updated state
"""
resolved_state = base_state.copy()
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@ -439,21 +480,26 @@ async def _iterative_auth_checks(
async def _mainline_sort(
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
):
clock: Clock,
room_id: str,
event_ids: List[str],
resolved_power_event_id: Optional[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> List[str]:
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Args:
clock (Clock)
room_id (str): room we're working in
event_ids (list[str]): Events to sort
resolved_power_event_id (str): The final resolved power level event ID
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
clock
room_id: room we're working in
event_ids: Events to sort
resolved_power_event_id: The final resolved power level event ID
event_map
state_res_store
Returns:
Deferred[list[str]]: The sorted list
The sorted list
"""
if not event_ids:
# It's possible for there to be no event IDs here to sort, so we can
@ -505,59 +551,90 @@ async def _mainline_sort(
async def _get_mainline_depth_for_event(
event, mainline_map, event_map, state_res_store
):
event: EventBase,
mainline_map: Dict[str, int],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> int:
"""Get the mainline depths for the given event based on the mainline map
Args:
event (FrozenEvent)
mainline_map (dict[str, int]): Map from event_id to mainline depth for
events in the mainline.
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
event
mainline_map: Map from event_id to mainline depth for events in the mainline.
event_map
state_res_store
Returns:
Deferred[int]
The mainline depth
"""
room_id = event.room_id
tmp_event = event # type: Optional[EventBase]
# We do an iterative search, replacing `event with the power level in its
# auth events (if any)
while event:
while tmp_event:
depth = mainline_map.get(event.event_id)
if depth is not None:
return depth
auth_events = event.auth_event_ids()
event = None
auth_events = tmp_event.auth_event_ids()
tmp_event = None
for aid in auth_events:
aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
event = aev
tmp_event = aev
break
# Didn't find a power level auth event, so we just return 0
return 0
async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
@overload
async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
allow_none: Literal[False] = False,
) -> EventBase:
...
@overload
async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
allow_none: Literal[True],
) -> Optional[EventBase]:
...
async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
allow_none: bool = False,
) -> Optional[EventBase]:
"""Helper function to look up event in event_map, falling back to looking
it up in the store
Args:
room_id (str)
event_id (str)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
allow_none (bool): if the event is not found, return None rather than raising
room_id
event_id
event_map
state_res_store
allow_none: if the event is not found, return None rather than raising
an exception
Returns:
Deferred[Optional[FrozenEvent]]
The event, or none if the event does not exist (and allow_none is True).
"""
if event_id not in event_map:
events = await state_res_store.get_events([event_id], allow_rejected=True)
@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F
return event
def lexicographical_topological_sort(graph, key):
def lexicographical_topological_sort(
graph: Dict[str, Set[str]], key: Callable[[str], Any]
) -> Generator[str, None, None]:
"""Performs a lexicographic reverse topological sort on the graph.
This returns a reverse topological sort (i.e. if node A references B then B
@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key):
NOTE: `graph` is modified during the sort.
Args:
graph (dict[str, set[str]]): A representation of the graph where each
node is a key in the dict and its value are the nodes edges.
key (func): A function that takes a node and returns a value that is
comparable and used to order nodes
graph: A representation of the graph where each node is a key in the
dict and its value are the nodes edges.
key: A function that takes a node and returns a value that is comparable
and used to order nodes
Yields:
str: The next node in the topological sort
The next node in the topological sort
"""
# Note, this is basically Kahn's algorithm except we look at nodes with no
# outgoing edges, c.f.
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
outdegree_map = graph
reverse_graph = {}
reverse_graph = {} # type: Dict[str, Set[str]]
# Lists of nodes with zero out degree. Is actually a tuple of
# `(key(node), node)` so that sorting does the right thing

View File

@ -209,6 +209,7 @@ commands = mypy \
synapse/server.py \
synapse/server_notices \
synapse/spam_checker_api \
synapse/state \
synapse/storage/databases/main/ui_auth.py \
synapse/storage/database.py \
synapse/storage/engines \