Add type hints for state. (#8140)

This commit is contained in:
Patrick Cloke 2020-08-24 14:25:27 -04:00 committed by GitHub
parent cbd8d83da7
commit 5758dcf30c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 420 additions and 203 deletions

1
changelog.d/8140.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to `synapse.state`.

47
stubs/frozendict.pyi Normal file
View File

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Stub for frozendict.
from typing import (
Any,
Hashable,
Iterable,
Iterator,
Mapping,
overload,
Tuple,
TypeVar,
)
_KT = TypeVar("_KT", bound=Hashable) # Key type.
_VT = TypeVar("_VT") # Value type.
class frozendict(Mapping[_KT, _VT]):
@overload
def __init__(self, **kwargs: _VT) -> None: ...
@overload
def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
@overload
def __init__(
self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
def __getitem__(self, key: _KT) -> _VT: ...
def __contains__(self, key: Any) -> bool: ...
def copy(self, **add_or_replace: Any) -> frozendict: ...
def __iter__(self) -> Iterator[_KT]: ...
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...

View File

@ -329,10 +329,10 @@ class FederationSender(object):
room_id = receipt.room_id room_id = receipt.room_id
# Work out which remote servers should be poked and poke them. # Work out which remote servers should be poked and poke them.
domains = await self.state.get_current_hosts_in_room(room_id) domains_set = await self.state.get_current_hosts_in_room(room_id)
domains = [ domains = [
d d
for d in domains for d in domains_set
if d != self.server_name if d != self.server_name
and self._federation_shard_config.should_handle(self._instance_name, d) and self._federation_shard_config.should_handle(self._instance_name, d)
] ]

View File

@ -2134,10 +2134,10 @@ class FederationHandler(BaseHandler):
) )
state_sets = list(state_sets.values()) state_sets = list(state_sets.values())
state_sets.append(state) state_sets.append(state)
current_state_ids = await self.state_handler.resolve_events( current_states = await self.state_handler.resolve_events(
room_version, state_sets, event room_version, state_sets, event
) )
current_state_ids = {k: e.event_id for k, e in current_state_ids.items()} current_state_ids = {k: e.event_id for k, e in current_states.items()}
else: else:
current_state_ids = await self.state_handler.get_current_state_ids( current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids event.room_id, latest_event_ids=extrem_ids
@ -2149,9 +2149,11 @@ class FederationHandler(BaseHandler):
# Now check if event pass auth against said current state # Now check if event pass auth against said current state
auth_types = auth_types_for_event(event) auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types] current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]
auth_events_map = await self.store.get_events(current_state_ids) auth_events_map = await self.store.get_events(current_state_ids_list)
current_auth_events = { current_auth_events = {
(e.type, e.state_key): e for e in auth_events_map.values() (e.type, e.state_key): e for e in auth_events_map.values()
} }

View File

@ -40,7 +40,7 @@ from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -1318,7 +1318,7 @@ async def get_interested_parties(
async def get_interested_remotes( async def get_interested_remotes(
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
) -> List[Tuple[List[str], List[UserPresenceState]]]: ) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers """Given a list of presence states figure out which remote servers
should be sent which. should be sent which.
@ -1334,7 +1334,7 @@ async def get_interested_remotes(
each tuple the list of UserPresenceState should be sent to each each tuple the list of UserPresenceState should be sent to each
destination destination
""" """
hosts_and_states = [] hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]]
# First we look up the rooms each user is in (as well as any explicit # First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote # subscriptions), then for each distinct room we look up the remote

View File

@ -17,7 +17,7 @@ import abc
import logging import logging
import random import random
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
@ -38,7 +38,15 @@ from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser from synapse.storage.roommember import RoomsForUser
from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID from synapse.types import (
Collection,
JsonDict,
Requester,
RoomAlias,
RoomID,
StateMap,
UserID,
)
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room from synapse.util.distributor import user_joined_room, user_left_room
@ -738,9 +746,7 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN: if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id) await self._user_left_room(target_user, room_id)
async def _can_guest_join( async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
""" """
Returns whether a guest can join a room based on its current state. Returns whether a guest can join a room based on its current state.
""" """
@ -969,9 +975,7 @@ class RoomMemberHandler(object):
) )
return stream_id return stream_id
async def _is_host_in_room( async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
# Have we just created the room, and is this about to be the very # Have we just created the room, and is this about to be the very
# first member event? # first member event?
create_event_id = current_state_ids.get(("m.room.create", "")) create_event_id = current_state_ids.get(("m.room.create", ""))

View File

@ -16,11 +16,22 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Awaitable, Dict, Iterable, List, Optional, Set from typing import (
Awaitable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Union,
overload,
)
import attr import attr
from frozendict import frozendict from frozendict import frozendict
from prometheus_client import Histogram from prometheus_client import Histogram
from typing_extensions import Literal
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@ -30,7 +41,7 @@ from synapse.logging.utils import log_function
from synapse.state import v1, v2 from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap from synapse.types import Collection, StateMap
from synapse.util import Clock from synapse.util import Clock
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -68,8 +79,14 @@ def _gen_state_id():
class _StateCacheEntry(object): class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None): def __init__(
# dict[(str, str), str] map from (type, state_key) to event_id self,
state: StateMap[str],
state_group: Optional[int],
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
):
# A map from (type, state_key) to event_id.
self.state = frozendict(state) self.state = frozendict(state)
# the ID of a state group if one and only one is involved. # the ID of a state group if one and only one is involved.
@ -107,24 +124,49 @@ class StateHandler(object):
self.hs = hs self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
@overload
async def get_current_state( async def get_current_state(
self, room_id, event_type=None, state_key="", latest_event_ids=None self,
): room_id: str,
""" Retrieves the current state for the room. This is done by event_type: Literal[None] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> StateMap[EventBase]:
...
@overload
async def get_current_state(
self,
room_id: str,
event_type: str,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Optional[EventBase]:
...
async def get_current_state(
self,
room_id: str,
event_type: Optional[str] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Union[Optional[EventBase], StateMap[EventBase]]:
"""Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts. event graph and then resolving any of the state conflicts.
This is equivalent to getting the state of an event that were to send This is equivalent to getting the state of an event that were to send
next before receiving any new events. next before receiving any new events.
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
Returns: Returns:
map from (type, state_key) to event If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
Otherwise, a map from (type, state_key) to event.
""" """
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state") logger.debug("calling resolve_state_groups from get_current_state")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
@ -140,34 +182,30 @@ class StateHandler(object):
state_map = await self.store.get_events( state_map = await self.store.get_events(
list(state.values()), get_prev_content=False list(state.values()), get_prev_content=False
) )
state = { return {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
} }
return state async def get_current_state_ids(
self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
async def get_current_state_ids(self, room_id, latest_event_ids=None): ) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room """Get the current state, or the state at a set of events, for a room
Args: Args:
room_id (str): room_id:
latest_event_ids: if given, the forward extremities to resolve. If
latest_event_ids (iterable[str]|None): if given, the forward None, we look them up from the database (via a cache).
extremities to resolve. If None, we look them up from the
database (via a cache)
Returns: Returns:
Deferred[dict[(str, str), str)]]: the state dict, mapping from the state dict, mapping from (event_type, state_key) -> event_id
(event_type, state_key) -> event_id
""" """
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state_ids") logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state return dict(ret.state)
return state
async def get_current_users_in_room( async def get_current_users_in_room(
self, room_id: str, latest_event_ids: Optional[List[str]] = None self, room_id: str, latest_event_ids: Optional[List[str]] = None
@ -183,32 +221,34 @@ class StateHandler(object):
""" """
if not latest_event_ids: if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_users_in_room") logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = await self.store.get_joined_users_from_state(room_id, entry) return await self.store.get_joined_users_from_state(room_id, entry)
return joined_users
async def get_current_hosts_in_room(self, room_id): async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id) event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids) return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events(self, room_id, event_ids): async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: List[str]
) -> Set[str]:
"""Get the hosts that were in a room at the given event ids """Get the hosts that were in a room at the given event ids
Args: Args:
room_id (str): room_id:
event_ids (list[str]): event_ids:
Returns: Returns:
Deferred[list[str]]: the hosts in the room at the given events The hosts in the room at the given events
""" """
entry = await self.resolve_state_groups_for_events(room_id, event_ids) entry = await self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = await self.store.get_joined_hosts(room_id, entry) return await self.store.get_joined_hosts(room_id, entry)
return joined_hosts
async def compute_event_context( async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
): ) -> EventContext:
"""Build an EventContext structure for the event. """Build an EventContext structure for the event.
This works out what the current state should be for the event, and This works out what the current state should be for the event, and
@ -221,7 +261,7 @@ class StateHandler(object):
when receiving an event from federation where we don't have the when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling. prev events for, e.g. when backfilling.
Returns: Returns:
synapse.events.snapshot.EventContext: The event context.
""" """
if event.internal_metadata.is_outlier(): if event.internal_metadata.is_outlier():
@ -275,7 +315,7 @@ class StateHandler(object):
event.room_id, event.prev_event_ids() event.room_id, event.prev_event_ids()
) )
state_ids_before_event = entry.state state_ids_before_event = dict(entry.state)
state_group_before_event = entry.state_group state_group_before_event = entry.state_group
state_group_before_event_prev_group = entry.prev_group state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids deltas_to_state_group_before_event = entry.delta_ids
@ -346,19 +386,18 @@ class StateHandler(object):
) )
@measure_func() @measure_func()
async def resolve_state_groups_for_events(self, room_id, event_ids): async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Iterable[str]
) -> _StateCacheEntry:
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
Args: Args:
room_id (str) room_id
event_ids (list[str]) event_ids
explicit_room_version (str|None): If set uses the the given room
version to choose the resolution algorithm. If None, then
checks the database for room version.
Returns: Returns:
Deferred[_StateCacheEntry]: resolved state The resolved state
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)
@ -394,7 +433,12 @@ class StateHandler(object):
) )
return result return result
async def resolve_events(self, room_version, state_sets, event): async def resolve_events(
self,
room_version: str,
state_sets: Collection[Iterable[EventBase]],
event: EventBase,
) -> StateMap[EventBase]:
logger.info( logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets) "Resolving state for %s with %d groups", event.room_id, len(state_sets)
) )
@ -414,9 +458,7 @@ class StateHandler(object):
state_res_store=StateResolutionStore(self.store), state_res_store=StateResolutionStore(self.store),
) )
new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()} return {key: state_map[ev_id] for key, ev_id in new_state.items()}
return new_state
class StateResolutionHandler(object): class StateResolutionHandler(object):
@ -444,7 +486,12 @@ class StateResolutionHandler(object):
@log_function @log_function
async def resolve_state_groups( async def resolve_state_groups(
self, room_id, room_version, state_groups_ids, event_map, state_res_store self,
room_id: str,
room_version: str,
state_groups_ids: Dict[int, StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
): ):
"""Resolves conflicts between a set of state groups """Resolves conflicts between a set of state groups
@ -452,13 +499,13 @@ class StateResolutionHandler(object):
not be called for a single state group not be called for a single state group
Args: Args:
room_id (str): room we are resolving for (used for logging and sanity checks) room_id: room we are resolving for (used for logging and sanity checks)
room_version (str): version of the room room_version: version of the room
state_groups_ids (dict[int, dict[(str, str), str]]): state_groups_ids:
map from state group id to the state in that state group A map from state group id to the state in that state group
(where 'state' is a map from state key to event id) (where 'state' is a map from state key to event id)
event_map(dict[str,FrozenEvent]|None): event_map:
a dict from event_id to event, for any events that we happen to a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing used as a starting point fof finding the state we need; any missing
@ -466,10 +513,10 @@ class StateResolutionHandler(object):
If None, all events will be fetched via state_res_store. If None, all events will be fetched via state_res_store.
state_res_store (StateResolutionStore) state_res_store
Returns: Returns:
_StateCacheEntry: resolved state The resolved state
""" """
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys()) logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
@ -530,21 +577,22 @@ class StateResolutionHandler(object):
return cache return cache
def _make_state_cache_entry(new_state, state_groups_ids): def _make_state_cache_entry(
new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
) -> _StateCacheEntry:
"""Given a resolved state, and a set of input state groups, pick one to base """Given a resolved state, and a set of input state groups, pick one to base
a new state group on (if any), and return an appropriately-constructed a new state group on (if any), and return an appropriately-constructed
_StateCacheEntry. _StateCacheEntry.
Args: Args:
new_state (dict[(str, str), str]): resolved state map (mapping from new_state: resolved state map (mapping from (type, state_key) to event_id)
(type, state_key) to event_id)
state_groups_ids (dict[int, dict[(str, str), str]]): state_groups_ids:
map from state group id to the state in that state group map from state group id to the state in that state group (where
(where 'state' is a map from state key to event id) 'state' is a map from state key to event id)
Returns: Returns:
_StateCacheEntry The cache entry.
""" """
# if the new state matches any of the input state groups, we can # if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id # use that state group again. Otherwise we will generate a state_id
@ -585,7 +633,7 @@ def resolve_events_with_store(
clock: Clock, clock: Clock,
room_id: str, room_id: str,
room_version: str, room_version: str,
state_sets: List[StateMap[str]], state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]], event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore", state_res_store: "StateResolutionStore",
) -> Awaitable[StateMap[str]]: ) -> Awaitable[StateMap[str]]:
@ -633,15 +681,17 @@ class StateResolutionStore(object):
store = attr.ib() store = attr.ib()
def get_events(self, event_ids, allow_rejected=False): def get_events(
self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database """Get events from the database
Args: Args:
event_ids (list): The event_ids of the events to fetch event_ids: The event_ids of the events to fetch
allow_rejected (bool): If True return rejected events. allow_rejected: If True return rejected events.
Returns: Returns:
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event. An awaitable which resolves to a dict from event_id to event.
""" """
return self.store.get_events( return self.store.get_events(
@ -651,7 +701,9 @@ class StateResolutionStore(object):
allow_rejected=allow_rejected, allow_rejected=allow_rejected,
) )
def get_auth_chain_difference(self, state_sets: List[Set[str]]): def get_auth_chain_difference(
self, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]:
"""Given sets of state events figure out the auth chain difference (as """Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm). per state res v2 algorithm).
@ -660,7 +712,7 @@ class StateResolutionStore(object):
chain. chain.
Returns: Returns:
Deferred[Set[str]]: Set of event IDs. An awaitable that resolves to a set of event IDs.
""" """
return self.store.get_auth_chain_difference(state_sets) return self.store.get_auth_chain_difference(state_sets)

View File

@ -15,7 +15,17 @@
import hashlib import hashlib
import logging import logging
from typing import Awaitable, Callable, Dict, List, Optional from typing import (
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
)
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "")
async def resolve_events_with_store( async def resolve_events_with_store(
room_id: str, room_id: str,
state_sets: List[StateMap[str]], state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]], event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable[[List[str]], Awaitable], state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
): ) -> StateMap[str]:
""" """
Args: Args:
room_id: the room we are working in room_id: the room we are working in
@ -56,8 +66,7 @@ async def resolve_events_with_store(
an Awaitable that resolves to a dict of event_id to event. an Awaitable that resolves to a dict of event_id to event.
Returns: Returns:
Deferred[dict[(str, str), str]]: A map from (type, state_key) to event_id.
a map from (type, state_key) to event_id.
""" """
if len(state_sets) == 1: if len(state_sets) == 1:
return state_sets[0] return state_sets[0]
@ -75,8 +84,8 @@ async def resolve_events_with_store(
"Asking for %d/%d conflicted events", len(needed_events), needed_event_count "Asking for %d/%d conflicted events", len(needed_events), needed_event_count
) )
# dict[str, FrozenEvent]: a map from state event id to event. Only includes # A map from state event id to event. Only includes the state events which
# the state events which are in conflict (and those in event_map) # are in conflict (and those in event_map).
state_map = await state_map_factory(needed_events) state_map = await state_map_factory(needed_events)
if event_map is not None: if event_map is not None:
state_map.update(event_map) state_map.update(event_map)
@ -91,8 +100,6 @@ async def resolve_events_with_store(
# get the ids of the auth events which allow us to authenticate the # get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state. # conflicted state, picking only from the unconflicting state.
#
# dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps( auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map unconflicted_state, conflicted_state, state_map
) )
@ -122,29 +129,30 @@ async def resolve_events_with_store(
) )
def _seperate(state_sets): def _seperate(
state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
"""Takes the state_sets and figures out which keys are conflicted and """Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated which aren't. i.e., which have multiple different event_ids associated
with them in different state sets. with them in different state sets.
Args: Args:
state_sets(iterable[dict[(str, str), str]]): state_sets:
List of dicts of (type, state_key) -> event_id, which are the List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve. different state groups to resolve.
Returns: Returns:
(dict[(str, str), str], dict[(str, str), set[str]]): A tuple of (unconflicted_state, conflicted_state), where:
A tuple of (unconflicted_state, conflicted_state), where:
unconflicted_state is a dict mapping (type, state_key)->event_id unconflicted_state is a dict mapping (type, state_key)->event_id
for unconflicted state keys. for unconflicted state keys.
conflicted_state is a dict mapping (type, state_key) to a set of conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys. event ids for conflicted state keys.
""" """
state_set_iterator = iter(state_sets) state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator)) unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {} conflicted_state = {} # type: StateMap[Set[str]]
for state_set in state_set_iterator: for state_set in state_set_iterator:
for key, value in state_set.items(): for key, value in state_set.items():
@ -171,7 +179,21 @@ def _seperate(state_sets):
return unconflicted_state, conflicted_state return unconflicted_state, conflicted_state
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): def _create_auth_events_from_maps(
unconflicted_state: StateMap[str],
conflicted_state: StateMap[Set[str]],
state_map: Dict[str, EventBase],
) -> StateMap[str]:
"""
Args:
unconflicted_state: The unconflicted state map.
conflicted_state: The conflicted state map.
state_map:
Returns:
A map from state key to event id.
"""
auth_events = {} auth_events = {}
for event_ids in conflicted_state.values(): for event_ids in conflicted_state.values():
for event_id in event_ids: for event_id in event_ids:
@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
keys = event_auth.auth_types_for_event(state_map[event_id]) keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys: for key in keys:
if key not in auth_events: if key not in auth_events:
event_id = unconflicted_state.get(key, None) auth_event_id = unconflicted_state.get(key, None)
if event_id: if auth_event_id:
auth_events[key] = event_id auth_events[key] = auth_event_id
return auth_events return auth_events
def _resolve_with_state( def _resolve_with_state(
unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map unconflicted_state_ids: StateMap[str],
conflicted_state_ids: StateMap[Set[str]],
auth_event_ids: StateMap[str],
state_map: Dict[str, EventBase],
): ):
conflicted_state = {} conflicted_state = {}
for key, event_ids in conflicted_state_ids.items(): for key, event_ids in conflicted_state_ids.items():
@ -215,7 +240,9 @@ def _resolve_with_state(
return new_state return new_state
def _resolve_state_events(conflicted_state, auth_events): def _resolve_state_events(
conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase]
) -> StateMap[EventBase]:
""" This is where we actually decide which of the conflicted state to """ This is where we actually decide which of the conflicted state to
use. use.
@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events):
return resolved_state return resolved_state
def _resolve_auth_events(events, auth_events): def _resolve_auth_events(
events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
reverse = list(reversed(_ordered_events(events))) reverse = list(reversed(_ordered_events(events)))
auth_keys = { auth_keys = {
@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events):
return event return event
def _resolve_normal_events(events, auth_events): def _resolve_normal_events(
events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
for event in _ordered_events(events): for event in _ordered_events(events):
try: try:
# The signatures have already been checked at this point # The signatures have already been checked at this point
@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events):
return event return event
def _ordered_events(events): def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
def key_func(e): def key_func(e):
# we have to use utf-8 rather than ascii here because it turns out we allow # we have to use utf-8 rather than ascii here because it turns out we allow
# people to send us events with non-ascii event IDs :/ # people to send us events with non-ascii event IDs :/

View File

@ -16,7 +16,21 @@
import heapq import heapq
import itertools import itertools
import logging import logging
from typing import Dict, List, Optional from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
overload,
)
from typing_extensions import Literal
import synapse.state import synapse.state
from synapse import event_auth from synapse import event_auth
@ -40,10 +54,10 @@ async def resolve_events_with_store(
clock: Clock, clock: Clock,
room_id: str, room_id: str,
room_version: str, room_version: str,
state_sets: List[StateMap[str]], state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]], event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: "synapse.state.StateResolutionStore",
): ) -> StateMap[str]:
"""Resolves the state using the v2 state resolution algorithm """Resolves the state using the v2 state resolution algorithm
Args: Args:
@ -63,8 +77,7 @@ async def resolve_events_with_store(
state_res_store: state_res_store:
Returns: Returns:
Deferred[dict[(str, str), str]]: A map from (type, state_key) to event_id.
a map from (type, state_key) to event_id.
""" """
logger.debug("Computing conflicted state") logger.debug("Computing conflicted state")
@ -171,18 +184,23 @@ async def resolve_events_with_store(
return resolved_state return resolved_state
async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): async def _get_power_level_for_sender(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> int:
"""Return the power level of the sender of the given event according to """Return the power level of the sender of the given event according to
their auth events. their auth events.
Args: Args:
room_id (str) room_id
event_id (str) event_id
event_map (dict[str,FrozenEvent]) event_map
state_res_store (StateResolutionStore) state_res_store
Returns: Returns:
Deferred[int] The power level.
""" """
event = await _get_event(room_id, event_id, event_map, state_res_store) event = await _get_event(room_id, event_id, event_map, state_res_store)
@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st
return int(level) return int(level)
async def _get_auth_chain_difference(state_sets, event_map, state_res_store): async def _get_auth_chain_difference(
state_sets: Sequence[StateMap[str]],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events """Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains. that only appear in some but not all of the auth chains.
Args: Args:
state_sets (list) state_sets
event_map (dict[str,FrozenEvent]) event_map
state_res_store (StateResolutionStore) state_res_store
Returns: Returns:
Deferred[set[str]]: Set of event IDs Set of event IDs
""" """
difference = await state_res_store.get_auth_chain_difference( difference = await state_res_store.get_auth_chain_difference(
@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
return difference return difference
def _seperate(state_sets): def _seperate(
state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
"""Return the unconflicted and conflicted state. This is different than in """Return the unconflicted and conflicted state. This is different than in
the original algorithm, as this defines a key to be conflicted if one of the original algorithm, as this defines a key to be conflicted if one of
the state sets doesn't have that key. the state sets doesn't have that key.
Args: Args:
state_sets (list) state_sets
Returns: Returns:
tuple[dict, dict]: A tuple of unconflicted and conflicted state. The A tuple of unconflicted and conflicted state. The conflicted state dict
conflicted state dict is a map from type/state_key to set of event IDs is a map from type/state_key to set of event IDs
""" """
unconflicted_state = {} unconflicted_state = {}
conflicted_state = {} conflicted_state = {}
@ -260,18 +284,20 @@ def _seperate(state_sets):
event_ids.discard(None) event_ids.discard(None)
conflicted_state[key] = event_ids conflicted_state[key] = event_ids
return unconflicted_state, conflicted_state # mypy doesn't understand that discarding None above means that conflicted
# state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
return unconflicted_state, conflicted_state # type: ignore
def _is_power_event(event): def _is_power_event(event: EventBase) -> bool:
"""Return whether or not the event is a "power event", as defined by the """Return whether or not the event is a "power event", as defined by the
v2 state resolution algorithm v2 state resolution algorithm
Args: Args:
event (FrozenEvent) event
Returns: Returns:
boolean True if the event is a power event.
""" """
if (event.type, event.state_key) in ( if (event.type, event.state_key) in (
(EventTypes.PowerLevels, ""), (EventTypes.PowerLevels, ""),
@ -288,19 +314,23 @@ def _is_power_event(event):
async def _add_event_and_auth_chain_to_graph( async def _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff graph: Dict[str, Set[str]],
): room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
auth_diff: Set[str],
) -> None:
"""Helper function for _reverse_topological_power_sort that add the event """Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph and its auth chain (that is in the auth diff) to the graph
Args: Args:
graph (dict[str, set[str]]): A map from event ID to the events auth graph: A map from event ID to the events auth event IDs
event IDs room_id: the room we are working in
room_id (str): the room we are working in event_id: Event to add to the graph
event_id (str): Event to add to the graph event_map
event_map (dict[str,FrozenEvent]) state_res_store
state_res_store (StateResolutionStore) auth_diff: Set of event IDs that are in the auth difference.
auth_diff (set[str]): Set of event IDs that are in the auth difference.
""" """
state = [event_id] state = [event_id]
@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph(
async def _reverse_topological_power_sort( async def _reverse_topological_power_sort(
clock, room_id, event_ids, event_map, state_res_store, auth_diff clock: Clock,
): room_id: str,
event_ids: Iterable[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
auth_diff: Set[str],
) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering, """Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts and then by power level and origin_server_ts
Args: Args:
clock (Clock) clock
room_id (str): the room we are working in room_id: the room we are working in
event_ids (list[str]): The events to sort event_ids: The events to sort
event_map (dict[str,FrozenEvent]) event_map
state_res_store (StateResolutionStore) state_res_store
auth_diff (set[str]): Set of event IDs that are in the auth difference. auth_diff: Set of event IDs that are in the auth difference.
Returns: Returns:
Deferred[list[str]]: The sorted list The sorted list
""" """
graph = {} graph = {} # type: Dict[str, Set[str]]
for idx, event_id in enumerate(event_ids, start=1): for idx, event_id in enumerate(event_ids, start=1):
await _add_event_and_auth_chain_to_graph( await _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff graph, room_id, event_id, event_map, state_res_store, auth_diff
@ -372,22 +407,28 @@ async def _reverse_topological_power_sort(
async def _iterative_auth_checks( async def _iterative_auth_checks(
clock, room_id, room_version, event_ids, base_state, event_map, state_res_store clock: Clock,
): room_id: str,
room_version: str,
event_ids: List[str],
base_state: StateMap[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> StateMap[str]:
"""Sequentially apply auth checks to each event in given list, updating the """Sequentially apply auth checks to each event in given list, updating the
state as it goes along. state as it goes along.
Args: Args:
clock (Clock) clock
room_id (str) room_id
room_version (str) room_version
event_ids (list[str]): Ordered list of events to apply auth checks to event_ids: Ordered list of events to apply auth checks to
base_state (StateMap[str]): The set of state to start with base_state: The set of state to start with
event_map (dict[str,FrozenEvent]) event_map
state_res_store (StateResolutionStore) state_res_store
Returns: Returns:
Deferred[StateMap[str]]: Returns the final updated state Returns the final updated state
""" """
resolved_state = base_state.copy() resolved_state = base_state.copy()
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@ -439,21 +480,26 @@ async def _iterative_auth_checks(
async def _mainline_sort( async def _mainline_sort(
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store clock: Clock,
): room_id: str,
event_ids: List[str],
resolved_power_event_id: Optional[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> List[str]:
"""Returns a sorted list of event_ids sorted by mainline ordering based on """Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id the given event resolved_power_event_id
Args: Args:
clock (Clock) clock
room_id (str): room we're working in room_id: room we're working in
event_ids (list[str]): Events to sort event_ids: Events to sort
resolved_power_event_id (str): The final resolved power level event ID resolved_power_event_id: The final resolved power level event ID
event_map (dict[str,FrozenEvent]) event_map
state_res_store (StateResolutionStore) state_res_store
Returns: Returns:
Deferred[list[str]]: The sorted list The sorted list
""" """
if not event_ids: if not event_ids:
# It's possible for there to be no event IDs here to sort, so we can # It's possible for there to be no event IDs here to sort, so we can
@ -505,59 +551,90 @@ async def _mainline_sort(
async def _get_mainline_depth_for_event( async def _get_mainline_depth_for_event(
event, mainline_map, event_map, state_res_store event: EventBase,
): mainline_map: Dict[str, int],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> int:
"""Get the mainline depths for the given event based on the mainline map """Get the mainline depths for the given event based on the mainline map
Args: Args:
event (FrozenEvent) event
mainline_map (dict[str, int]): Map from event_id to mainline depth for mainline_map: Map from event_id to mainline depth for events in the mainline.
events in the mainline. event_map
event_map (dict[str,FrozenEvent]) state_res_store
state_res_store (StateResolutionStore)
Returns: Returns:
Deferred[int] The mainline depth
""" """
room_id = event.room_id room_id = event.room_id
tmp_event = event # type: Optional[EventBase]
# We do an iterative search, replacing `event with the power level in its # We do an iterative search, replacing `event with the power level in its
# auth events (if any) # auth events (if any)
while event: while tmp_event:
depth = mainline_map.get(event.event_id) depth = mainline_map.get(event.event_id)
if depth is not None: if depth is not None:
return depth return depth
auth_events = event.auth_event_ids() auth_events = tmp_event.auth_event_ids()
event = None tmp_event = None
for aid in auth_events: for aid in auth_events:
aev = await _get_event( aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True room_id, aid, event_map, state_res_store, allow_none=True
) )
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
event = aev tmp_event = aev
break break
# Didn't find a power level auth event, so we just return 0 # Didn't find a power level auth event, so we just return 0
return 0 return 0
async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): @overload
async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
allow_none: Literal[False] = False,
) -> EventBase:
...
@overload
async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
allow_none: Literal[True],
) -> Optional[EventBase]:
...
async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
allow_none: bool = False,
) -> Optional[EventBase]:
"""Helper function to look up event in event_map, falling back to looking """Helper function to look up event in event_map, falling back to looking
it up in the store it up in the store
Args: Args:
room_id (str) room_id
event_id (str) event_id
event_map (dict[str,FrozenEvent]) event_map
state_res_store (StateResolutionStore) state_res_store
allow_none (bool): if the event is not found, return None rather than raising allow_none: if the event is not found, return None rather than raising
an exception an exception
Returns: Returns:
Deferred[Optional[FrozenEvent]] The event, or none if the event does not exist (and allow_none is True).
""" """
if event_id not in event_map: if event_id not in event_map:
events = await state_res_store.get_events([event_id], allow_rejected=True) events = await state_res_store.get_events([event_id], allow_rejected=True)
@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F
return event return event
def lexicographical_topological_sort(graph, key): def lexicographical_topological_sort(
graph: Dict[str, Set[str]], key: Callable[[str], Any]
) -> Generator[str, None, None]:
"""Performs a lexicographic reverse topological sort on the graph. """Performs a lexicographic reverse topological sort on the graph.
This returns a reverse topological sort (i.e. if node A references B then B This returns a reverse topological sort (i.e. if node A references B then B
@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key):
NOTE: `graph` is modified during the sort. NOTE: `graph` is modified during the sort.
Args: Args:
graph (dict[str, set[str]]): A representation of the graph where each graph: A representation of the graph where each node is a key in the
node is a key in the dict and its value are the nodes edges. dict and its value are the nodes edges.
key (func): A function that takes a node and returns a value that is key: A function that takes a node and returns a value that is comparable
comparable and used to order nodes and used to order nodes
Yields: Yields:
str: The next node in the topological sort The next node in the topological sort
""" """
# Note, this is basically Kahn's algorithm except we look at nodes with no # Note, this is basically Kahn's algorithm except we look at nodes with no
# outgoing edges, c.f. # outgoing edges, c.f.
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
outdegree_map = graph outdegree_map = graph
reverse_graph = {} reverse_graph = {} # type: Dict[str, Set[str]]
# Lists of nodes with zero out degree. Is actually a tuple of # Lists of nodes with zero out degree. Is actually a tuple of
# `(key(node), node)` so that sorting does the right thing # `(key(node), node)` so that sorting does the right thing

View File

@ -209,6 +209,7 @@ commands = mypy \
synapse/server.py \ synapse/server.py \
synapse/server_notices \ synapse/server_notices \
synapse/spam_checker_api \ synapse/spam_checker_api \
synapse/state \
synapse/storage/databases/main/ui_auth.py \ synapse/storage/databases/main/ui_auth.py \
synapse/storage/database.py \ synapse/storage/database.py \
synapse/storage/engines \ synapse/storage/engines \