diff --git a/changelog.d/6715.misc b/changelog.d/6715.misc new file mode 100644 index 000000000..8876b0446 --- /dev/null +++ b/changelog.d/6715.misc @@ -0,0 +1 @@ +Add StateMap type alias to simplify types. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index abbc7079a..2cbfab256 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -14,7 +14,6 @@ # limitations under the License. import logging -from typing import Dict, Tuple from six import itervalues @@ -35,7 +34,7 @@ from synapse.api.errors import ( ResourceLimitError, ) from synapse.config.server import is_threepid_reserved -from synapse.types import UserID +from synapse.types import StateMap, UserID from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -509,10 +508,7 @@ class Auth(object): return self.store.is_server_admin(user) def compute_auth_events( - self, - event, - current_state_ids: Dict[Tuple[str, str], str], - for_verification: bool = False, + self, event, current_state_ids: StateMap[str], for_verification: bool = False, ): """Given an event and current state return the list of event IDs used to auth an event. diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index a44baea36..9ea85e93e 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Tuple, Union +from typing import Optional, Union from six import iteritems @@ -23,6 +23,7 @@ from twisted.internet import defer from synapse.appservice import ApplicationService from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import StateMap @attr.s(slots=True) @@ -106,13 +107,11 @@ class EventContext: _state_group = attr.ib(default=None, type=Optional[int]) state_group_before_event = attr.ib(default=None, type=Optional[int]) prev_group = attr.ib(default=None, type=Optional[int]) - delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) + delta_ids = attr.ib(default=None, type=Optional[StateMap[str]]) app_service = attr.ib(default=None, type=Optional[ApplicationService]) - _current_state_ids = attr.ib( - default=None, type=Optional[Dict[Tuple[str, str], str]] - ) - _prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) + _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]]) + _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]]) @staticmethod def with_state( diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index a5b36b182..5012aaea3 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -31,6 +31,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.metrics import sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.presence import UserPresenceState +from synapse.types import StateMap from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter # This is defined in the Matrix spec and enforced by the receiver. @@ -77,7 +78,7 @@ class PerDestinationQueue(object): # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # based on their key (e.g. typing events by room_id) # Map of (edu_type, key) -> Edu - self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu] + self._pending_edus_keyed = {} # type: StateMap[Edu] # Map of user_id -> UserPresenceState of pending presence to be sent to this # destination diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index a9407553b..60a7c938b 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -14,9 +14,11 @@ # limitations under the License. import logging +from typing import List from synapse.api.constants import Membership -from synapse.types import RoomStreamToken +from synapse.events import FrozenEvent +from synapse.types import RoomStreamToken, StateMap from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -259,35 +261,26 @@ class ExfiltrationWriter(object): """Interface used to specify how to write exported data. """ - def write_events(self, room_id, events): + def write_events(self, room_id: str, events: List[FrozenEvent]): """Write a batch of events for a room. - - Args: - room_id (str) - events (list[FrozenEvent]) """ pass - def write_state(self, room_id, event_id, state): + def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]): """Write the state at the given event in the room. This only gets called for backward extremities rather than for each event. - - Args: - room_id (str) - event_id (str) - state (dict[tuple[str, str], FrozenEvent]) """ pass - def write_invite(self, room_id, event, state): + def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]): """Write an invite for the room, with associated invite state. Args: - room_id (str) - event (FrozenEvent) - state (dict[tuple[str, str], dict]): A subset of the state at the + room_id + event + state: A subset of the state at the invite, with a subset of the event keys (type, state_key content and sender) """ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 61b6713c8..d4f9a792f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -64,7 +64,7 @@ from synapse.replication.http.federation import ( from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour -from synapse.types import UserID, get_domain_from_id +from synapse.types import StateMap, UserID, get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -89,7 +89,7 @@ class _NewEventInfo: event = attr.ib(type=EventBase) state = attr.ib(type=Optional[Sequence[EventBase]], default=None) - auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None) + auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None) def shortstr(iterable, maxitems=5): @@ -352,9 +352,7 @@ class FederationHandler(BaseHandler): ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id - state_maps = list( - ours.values() - ) # type: list[dict[tuple[str, str], str]] + state_maps = list(ours.values()) # type: list[StateMap[str]] # we don't need this any more, let's delete it. del ours @@ -1912,7 +1910,7 @@ class FederationHandler(BaseHandler): origin: str, event: EventBase, state: Optional[Iterable[EventBase]], - auth_events: Optional[Dict[Tuple[str, str], EventBase]], + auth_events: Optional[StateMap[EventBase]], backfilled: bool, ): """ diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9cab2adbf..9f50196ea 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -32,7 +32,15 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, Syna from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter -from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID +from synapse.types import ( + Requester, + RoomAlias, + RoomID, + RoomStreamToken, + StateMap, + StreamToken, + UserID, +) from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache @@ -207,15 +215,19 @@ class RoomCreationHandler(BaseHandler): @defer.inlineCallbacks def _update_upgraded_room_pls( - self, requester, old_room_id, new_room_id, old_room_state, + self, + requester: Requester, + old_room_id: str, + new_room_id: str, + old_room_state: StateMap[str], ): """Send updated power levels in both rooms after an upgrade Args: - requester (synapse.types.Requester): the user requesting the upgrade - old_room_id (str): the id of the room to be replaced - new_room_id (str): the id of the replacement room - old_room_state (dict[tuple[str, str], str]): the state map for the old room + requester: the user requesting the upgrade + old_room_id: the id of the room to be replaced + new_room_id: the id of the replacement room + old_room_state: the state map for the old room Returns: Deferred diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 5accc071a..cacd0c0c2 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, List, Optional from six import iteritems, itervalues @@ -33,6 +33,7 @@ from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour +from synapse.types import StateMap from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -594,7 +595,7 @@ def _make_state_cache_entry(new_state, state_groups_ids): def resolve_events_with_store( room_id: str, room_version: str, - state_sets: List[Dict[Tuple[str, str], str]], + state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "StateResolutionStore", ): diff --git a/synapse/state/v1.py b/synapse/state/v1.py index b2f9865f3..d6c34ce3b 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,7 +15,7 @@ import hashlib import logging -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional from six import iteritems, iterkeys, itervalues @@ -26,6 +26,7 @@ from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase +from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -36,7 +37,7 @@ POWER_KEY = (EventTypes.PowerLevels, "") @defer.inlineCallbacks def resolve_events_with_store( room_id: str, - state_sets: List[Dict[Tuple[str, str], str]], + state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_map_factory: Callable, ): diff --git a/synapse/state/v2.py b/synapse/state/v2.py index 72fb8a631..6216fdd20 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,7 +16,7 @@ import heapq import itertools import logging -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional from six import iteritems, itervalues @@ -27,6 +27,7 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.events import EventBase +from synapse.types import StateMap logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ logger = logging.getLogger(__name__) def resolve_events_with_store( room_id: str, room_version: str, - state_sets: List[Dict[Tuple[str, str], str]], + state_sets: List[StateMap[str]], event_map: Optional[Dict[str, EventBase]], state_res_store: "synapse.state.StateResolutionStore", ): @@ -393,12 +394,12 @@ def _iterative_auth_checks( room_id (str) room_version (str) event_ids (list[str]): Ordered list of events to apply auth checks to - base_state (dict[tuple[str, str], str]): The set of state to start with + base_state (StateMap[str]): The set of state to start with event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) Returns: - Deferred[dict[tuple[str, str], str]]: Returns the final updated state + Deferred[StateMap[str]]: Returns the final updated state """ resolved_state = base_state.copy() diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index d07440e3e..33bebd1c4 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -165,19 +165,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) # FIXME: how should this be cached? - def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): + def get_filtered_current_state_ids( + self, room_id: str, state_filter: StateFilter = StateFilter.all() + ): """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state Args: - room_id (str) - state_filter (StateFilter): The state filter used to fetch state + room_id + state_filter: The state filter used to fetch state from the database. Returns: - Deferred[dict[tuple[str, str], str]]: Map from type/state_key to - event ID. + defer.Deferred[StateMap[str]]: Map from type/state_key to event ID. """ where_clause, where_args = state_filter.make_sql_filter_clause() diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index d53695f23..c4ee9b7cc 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -15,6 +15,7 @@ import logging from collections import namedtuple +from typing import Dict, Iterable, List, Set, Tuple from six import iteritems from six.moves import range @@ -26,6 +27,7 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.database import Database from synapse.storage.state import StateFilter +from synapse.types import StateMap from synapse.util.caches import get_cache_factor_for from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -133,17 +135,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, state_filter): - """Returns the state groups for a given set of groups, filtering on - types of state events. + def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ): + """Returns the state groups for a given set of groups from the + database, filtering on types of state events. Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state + groups: list of state group IDs to query + state_filter: The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ results = {} @@ -199,18 +202,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state_filter.filter_state(state_dict_ids), not missing_types @defer.inlineCallbacks - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + def _get_state_for_groups( + self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + ): """Gets the state at each of a list of state groups, optionally filtering by type/state_key Args: - groups (iterable[int]): list of state groups for which we want + groups: list of state groups for which we want to get the state. - state_filter (StateFilter): The state filter used to fetch state + state_filter: The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ member_filter, non_member_filter = state_filter.get_member_split() @@ -268,24 +272,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return state - def _get_state_for_groups_using_cache(self, groups, cache, state_filter): + def _get_state_for_groups_using_cache( + self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter + ) -> Tuple[Dict[int, StateMap[str]], Set[int]]: """Gets the state at each of a list of state groups, optionally filtering by type/state_key, querying from a specific cache. Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - cache (DictionaryCache): the cache of group ids to state dicts which - we will pass through - either the normal state cache or the specific - members state cache. - state_filter (StateFilter): The state filter used to fetch state - from the database. + groups: list of state groups for which we want to get the state. + cache: the cache of group ids to state dicts which + we will pass through - either the normal state cache or the + specific members state cache. + state_filter: The state filter used to fetch state from the + database. Returns: - tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of - dict of state_group_id -> (dict of (type, state_key) -> event id) - of entries in the cache, and the state group ids either missing - from the cache or incomplete. + Tuple of dict of state_group_id to state map of entries in the + cache, and the state group ids either missing from the cache or + incomplete. """ results = {} incomplete_groups = set() diff --git a/synapse/storage/state.py b/synapse/storage/state.py index cbeb58601..c522c8092 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Iterable, List, TypeVar from six import iteritems, itervalues @@ -22,9 +23,13 @@ import attr from twisted.internet import defer from synapse.api.constants import EventTypes +from synapse.types import StateMap logger = logging.getLogger(__name__) +# Used for generic functions below +T = TypeVar("T") + @attr.s(slots=True) class StateFilter(object): @@ -233,14 +238,14 @@ class StateFilter(object): return len(self.concrete_types()) - def filter_state(self, state_dict): + def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]: """Returns the state filtered with by this StateFilter Args: - state (dict[tuple[str, str], Any]): The state map to filter + state: The state map to filter Returns: - dict[tuple[str, str], Any]: The filtered state map + The filtered state map """ if self.is_full(): return dict(state_dict) @@ -333,12 +338,12 @@ class StateGroupStorage(object): def __init__(self, hs, stores): self.stores = stores - def get_state_group_delta(self, state_group): + def get_state_group_delta(self, state_group: int): """Given a state group try to return a previous group and a delta between the old and the new. Returns: - Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]): + Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: (prev_group, delta_ids) """ @@ -353,7 +358,7 @@ class StateGroupStorage(object): event_ids (iterable[str]): ids of the events Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: + Deferred[dict[int, StateMap[str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ if not event_ids: @@ -410,17 +415,18 @@ class StateGroupStorage(object): for group, event_id_map in iteritems(group_to_ids) } - def _get_state_groups_from_groups(self, groups, state_filter): + def _get_state_groups_from_groups( + self, groups: List[int], state_filter: StateFilter + ): """Returns the state groups for a given set of groups, filtering on types of state events. Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state + groups: list of state group IDs to query + state_filter: The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. """ return self.stores.state._get_state_groups_from_groups(groups, state_filter) @@ -519,7 +525,9 @@ class StateGroupStorage(object): state_map = yield self.get_state_ids_for_events([event_id], state_filter) return state_map[event_id] - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + def _get_state_for_groups( + self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() + ): """Gets the state at each of a list of state groups, optionally filtering by type/state_key @@ -529,8 +537,7 @@ class StateGroupStorage(object): state_filter (StateFilter): The state filter used to fetch state from the database. Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) + Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. """ return self.stores.state._get_state_for_groups(groups, state_filter) diff --git a/synapse/types.py b/synapse/types.py index cd996c0b5..65e4d8c18 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -17,6 +17,7 @@ import re import string import sys from collections import namedtuple +from typing import Dict, Tuple, TypeVar import attr from signedjson.key import decode_verify_key_bytes @@ -28,7 +29,7 @@ from synapse.api.errors import SynapseError if sys.version_info[:3] >= (3, 6, 0): from typing import Collection else: - from typing import Sized, Iterable, Container, TypeVar + from typing import Sized, Iterable, Container T_co = TypeVar("T_co", covariant=True) @@ -36,6 +37,12 @@ else: __slots__ = () +# Define a state map type from type/state_key to T (usually an event ID or +# event) +T = TypeVar("T") +StateMap = Dict[Tuple[str, str], T] + + class Requester( namedtuple( "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]