Add StateMap type alias (#6715)

This commit is contained in:
Erik Johnston 2020-01-16 13:31:22 +00:00 committed by GitHub
parent 7b14c4a018
commit d386f2f339
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 115 additions and 93 deletions

1
changelog.d/6715.misc Normal file
View File

@ -0,0 +1 @@
Add StateMap type alias to simplify types.

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, Tuple
from six import itervalues from six import itervalues
@ -35,7 +34,7 @@ from synapse.api.errors import (
ResourceLimitError, ResourceLimitError,
) )
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.types import UserID from synapse.types import StateMap, UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -509,10 +508,7 @@ class Auth(object):
return self.store.is_server_admin(user) return self.store.is_server_admin(user)
def compute_auth_events( def compute_auth_events(
self, self, event, current_state_ids: StateMap[str], for_verification: bool = False,
event,
current_state_ids: Dict[Tuple[str, str], str],
for_verification: bool = False,
): ):
"""Given an event and current state return the list of event IDs used """Given an event and current state return the list of event IDs used
to auth an event. to auth an event.

View File

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, Optional, Tuple, Union from typing import Optional, Union
from six import iteritems from six import iteritems
@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap
@attr.s(slots=True) @attr.s(slots=True)
@ -106,13 +107,11 @@ class EventContext:
_state_group = attr.ib(default=None, type=Optional[int]) _state_group = attr.ib(default=None, type=Optional[int])
state_group_before_event = attr.ib(default=None, type=Optional[int]) state_group_before_event = attr.ib(default=None, type=Optional[int])
prev_group = attr.ib(default=None, type=Optional[int]) prev_group = attr.ib(default=None, type=Optional[int])
delta_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]]) delta_ids = attr.ib(default=None, type=Optional[StateMap[str]])
app_service = attr.ib(default=None, type=Optional[ApplicationService]) app_service = attr.ib(default=None, type=Optional[ApplicationService])
_current_state_ids = attr.ib( _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
default=None, type=Optional[Dict[Tuple[str, str], str]] _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
)
_prev_state_ids = attr.ib(default=None, type=Optional[Dict[Tuple[str, str], str]])
@staticmethod @staticmethod
def with_state( def with_state(

View File

@ -31,6 +31,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.types import StateMap
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
# This is defined in the Matrix spec and enforced by the receiver. # This is defined in the Matrix spec and enforced by the receiver.
@ -77,7 +78,7 @@ class PerDestinationQueue(object):
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id) # based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu # Map of (edu_type, key) -> Edu
self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu] self._pending_edus_keyed = {} # type: StateMap[Edu]
# Map of user_id -> UserPresenceState of pending presence to be sent to this # Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination # destination

View File

@ -14,9 +14,11 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.types import RoomStreamToken from synapse.events import FrozenEvent
from synapse.types import RoomStreamToken, StateMap
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -259,35 +261,26 @@ class ExfiltrationWriter(object):
"""Interface used to specify how to write exported data. """Interface used to specify how to write exported data.
""" """
def write_events(self, room_id, events): def write_events(self, room_id: str, events: List[FrozenEvent]):
"""Write a batch of events for a room. """Write a batch of events for a room.
Args:
room_id (str)
events (list[FrozenEvent])
""" """
pass pass
def write_state(self, room_id, event_id, state): def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
"""Write the state at the given event in the room. """Write the state at the given event in the room.
This only gets called for backward extremities rather than for each This only gets called for backward extremities rather than for each
event. event.
Args:
room_id (str)
event_id (str)
state (dict[tuple[str, str], FrozenEvent])
""" """
pass pass
def write_invite(self, room_id, event, state): def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
"""Write an invite for the room, with associated invite state. """Write an invite for the room, with associated invite state.
Args: Args:
room_id (str) room_id
event (FrozenEvent) event
state (dict[tuple[str, str], dict]): A subset of the state at the state: A subset of the state at the
invite, with a subset of the event keys (type, state_key invite, with a subset of the event keys (type, state_key
content and sender) content and sender)
""" """

View File

@ -64,7 +64,7 @@ from synapse.replication.http.federation import (
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id from synapse.types import StateMap, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
@ -89,7 +89,7 @@ class _NewEventInfo:
event = attr.ib(type=EventBase) event = attr.ib(type=EventBase)
state = attr.ib(type=Optional[Sequence[EventBase]], default=None) state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None) auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
def shortstr(iterable, maxitems=5): def shortstr(iterable, maxitems=5):
@ -352,9 +352,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen) ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id # state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list( state_maps = list(ours.values()) # type: list[StateMap[str]]
ours.values()
) # type: list[dict[tuple[str, str], str]]
# we don't need this any more, let's delete it. # we don't need this any more, let's delete it.
del ours del ours
@ -1912,7 +1910,7 @@ class FederationHandler(BaseHandler):
origin: str, origin: str,
event: EventBase, event: EventBase,
state: Optional[Iterable[EventBase]], state: Optional[Iterable[EventBase]],
auth_events: Optional[Dict[Tuple[str, str], EventBase]], auth_events: Optional[StateMap[EventBase]],
backfilled: bool, backfilled: bool,
): ):
""" """

View File

@ -32,7 +32,15 @@ from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, Syna
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID from synapse.types import (
Requester,
RoomAlias,
RoomID,
RoomStreamToken,
StateMap,
StreamToken,
UserID,
)
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -207,15 +215,19 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _update_upgraded_room_pls( def _update_upgraded_room_pls(
self, requester, old_room_id, new_room_id, old_room_state, self,
requester: Requester,
old_room_id: str,
new_room_id: str,
old_room_state: StateMap[str],
): ):
"""Send updated power levels in both rooms after an upgrade """Send updated power levels in both rooms after an upgrade
Args: Args:
requester (synapse.types.Requester): the user requesting the upgrade requester: the user requesting the upgrade
old_room_id (str): the id of the room to be replaced old_room_id: the id of the room to be replaced
new_room_id (str): the id of the replacement room new_room_id: the id of the replacement room
old_room_state (dict[tuple[str, str], str]): the state map for the old room old_room_state: the state map for the old room
Returns: Returns:
Deferred Deferred

View File

@ -16,7 +16,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Tuple from typing import Dict, Iterable, List, Optional
from six import iteritems, itervalues from six import iteritems, itervalues
@ -33,6 +33,7 @@ from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function from synapse.logging.utils import log_function
from synapse.state import v1, v2 from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -594,7 +595,7 @@ def _make_state_cache_entry(new_state, state_groups_ids):
def resolve_events_with_store( def resolve_events_with_store(
room_id: str, room_id: str,
room_version: str, room_version: str,
state_sets: List[Dict[Tuple[str, str], str]], state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]], event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore", state_res_store: "StateResolutionStore",
): ):

View File

@ -15,7 +15,7 @@
import hashlib import hashlib
import logging import logging
from typing import Callable, Dict, List, Optional, Tuple from typing import Callable, Dict, List, Optional
from six import iteritems, iterkeys, itervalues from six import iteritems, iterkeys, itervalues
@ -26,6 +26,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import StateMap
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,7 +37,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
@defer.inlineCallbacks @defer.inlineCallbacks
def resolve_events_with_store( def resolve_events_with_store(
room_id: str, room_id: str,
state_sets: List[Dict[Tuple[str, str], str]], state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]], event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable, state_map_factory: Callable,
): ):

View File

@ -16,7 +16,7 @@
import heapq import heapq
import itertools import itertools
import logging import logging
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional
from six import iteritems, itervalues from six import iteritems, itervalues
@ -27,6 +27,7 @@ from synapse import event_auth
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import StateMap
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,7 +36,7 @@ logger = logging.getLogger(__name__)
def resolve_events_with_store( def resolve_events_with_store(
room_id: str, room_id: str,
room_version: str, room_version: str,
state_sets: List[Dict[Tuple[str, str], str]], state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]], event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore", state_res_store: "synapse.state.StateResolutionStore",
): ):
@ -393,12 +394,12 @@ def _iterative_auth_checks(
room_id (str) room_id (str)
room_version (str) room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to event_ids (list[str]): Ordered list of events to apply auth checks to
base_state (dict[tuple[str, str], str]): The set of state to start with base_state (StateMap[str]): The set of state to start with
event_map (dict[str,FrozenEvent]) event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore) state_res_store (StateResolutionStore)
Returns: Returns:
Deferred[dict[tuple[str, str], str]]: Returns the final updated state Deferred[StateMap[str]]: Returns the final updated state
""" """
resolved_state = base_state.copy() resolved_state = base_state.copy()

View File

@ -165,19 +165,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
) )
# FIXME: how should this be cached? # FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()): def get_filtered_current_state_ids(
self, room_id: str, state_filter: StateFilter = StateFilter.all()
):
"""Get the current state event of a given type for a room based on the """Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state of doing a fresh state resolution as per state_handler.get_current_state
Args: Args:
room_id (str) room_id
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[dict[tuple[str, str], str]]: Map from type/state_key to defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
event ID.
""" """
where_clause, where_args = state_filter.make_sql_filter_clause() where_clause, where_args = state_filter.make_sql_filter_clause()

View File

@ -15,6 +15,7 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple
from six import iteritems from six import iteritems
from six.moves import range from six.moves import range
@ -26,6 +27,7 @@ from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
from synapse.storage.database import Database from synapse.storage.database import Database
from synapse.storage.state import StateFilter from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.caches import get_cache_factor_for from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
@ -133,17 +135,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, state_filter): def _get_state_groups_from_groups(
"""Returns the state groups for a given set of groups, filtering on self, groups: List[int], state_filter: StateFilter
types of state events. ):
"""Returns the state groups for a given set of groups from the
database, filtering on types of state events.
Args: Args:
groups(list[int]): list of state group IDs to query groups: list of state group IDs to query
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]: Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
dict of state_group_id -> (dict of (type, state_key) -> event id)
""" """
results = {} results = {}
@ -199,18 +202,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types return state_filter.filter_state(state_dict_ids), not missing_types
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
):
"""Gets the state at each of a list of state groups, optionally """Gets the state at each of a list of state groups, optionally
filtering by type/state_key filtering by type/state_key
Args: Args:
groups (iterable[int]): list of state groups for which we want groups: list of state groups for which we want
to get the state. to get the state.
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]: Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
dict of state_group_id -> (dict of (type, state_key) -> event id)
""" """
member_filter, non_member_filter = state_filter.get_member_split() member_filter, non_member_filter = state_filter.get_member_split()
@ -268,24 +272,24 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state return state
def _get_state_for_groups_using_cache(self, groups, cache, state_filter): def _get_state_for_groups_using_cache(
self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
"""Gets the state at each of a list of state groups, optionally """Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache. filtering by type/state_key, querying from a specific cache.
Args: Args:
groups (iterable[int]): list of state groups for which we want groups: list of state groups for which we want to get the state.
to get the state. cache: the cache of group ids to state dicts which
cache (DictionaryCache): the cache of group ids to state dicts which we will pass through - either the normal state cache or the
we will pass through - either the normal state cache or the specific specific members state cache.
members state cache. state_filter: The state filter used to fetch state from the
state_filter (StateFilter): The state filter used to fetch state database.
from the database.
Returns: Returns:
tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of Tuple of dict of state_group_id to state map of entries in the
dict of state_group_id -> (dict of (type, state_key) -> event id) cache, and the state group ids either missing from the cache or
of entries in the cache, and the state group ids either missing incomplete.
from the cache or incomplete.
""" """
results = {} results = {}
incomplete_groups = set() incomplete_groups = set()

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Iterable, List, TypeVar
from six import iteritems, itervalues from six import iteritems, itervalues
@ -22,9 +23,13 @@ import attr
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import StateMap
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Used for generic functions below
T = TypeVar("T")
@attr.s(slots=True) @attr.s(slots=True)
class StateFilter(object): class StateFilter(object):
@ -233,14 +238,14 @@ class StateFilter(object):
return len(self.concrete_types()) return len(self.concrete_types())
def filter_state(self, state_dict): def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
"""Returns the state filtered with by this StateFilter """Returns the state filtered with by this StateFilter
Args: Args:
state (dict[tuple[str, str], Any]): The state map to filter state: The state map to filter
Returns: Returns:
dict[tuple[str, str], Any]: The filtered state map The filtered state map
""" """
if self.is_full(): if self.is_full():
return dict(state_dict) return dict(state_dict)
@ -333,12 +338,12 @@ class StateGroupStorage(object):
def __init__(self, hs, stores): def __init__(self, hs, stores):
self.stores = stores self.stores = stores
def get_state_group_delta(self, state_group): def get_state_group_delta(self, state_group: int):
"""Given a state group try to return a previous group and a delta between """Given a state group try to return a previous group and a delta between
the old and the new. the old and the new.
Returns: Returns:
Deferred[Tuple[Optional[int], Optional[list[dict[tuple[str, str], str]]]]]): Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
(prev_group, delta_ids) (prev_group, delta_ids)
""" """
@ -353,7 +358,7 @@ class StateGroupStorage(object):
event_ids (iterable[str]): ids of the events event_ids (iterable[str]): ids of the events
Returns: Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]: Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id) dict of state_group_id -> (dict of (type, state_key) -> event id)
""" """
if not event_ids: if not event_ids:
@ -410,17 +415,18 @@ class StateGroupStorage(object):
for group, event_id_map in iteritems(group_to_ids) for group, event_id_map in iteritems(group_to_ids)
} }
def _get_state_groups_from_groups(self, groups, state_filter): def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
):
"""Returns the state groups for a given set of groups, filtering on """Returns the state groups for a given set of groups, filtering on
types of state events. types of state events.
Args: Args:
groups(list[int]): list of state group IDs to query groups: list of state group IDs to query
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]: Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
dict of state_group_id -> (dict of (type, state_key) -> event id)
""" """
return self.stores.state._get_state_groups_from_groups(groups, state_filter) return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@ -519,7 +525,9 @@ class StateGroupStorage(object):
state_map = yield self.get_state_ids_for_events([event_id], state_filter) state_map = yield self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id] return state_map[event_id]
def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
):
"""Gets the state at each of a list of state groups, optionally """Gets the state at each of a list of state groups, optionally
filtering by type/state_key filtering by type/state_key
@ -529,8 +537,7 @@ class StateGroupStorage(object):
state_filter (StateFilter): The state filter used to fetch state state_filter (StateFilter): The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]: Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
dict of state_group_id -> (dict of (type, state_key) -> event id)
""" """
return self.stores.state._get_state_for_groups(groups, state_filter) return self.stores.state._get_state_for_groups(groups, state_filter)

View File

@ -17,6 +17,7 @@ import re
import string import string
import sys import sys
from collections import namedtuple from collections import namedtuple
from typing import Dict, Tuple, TypeVar
import attr import attr
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -28,7 +29,7 @@ from synapse.api.errors import SynapseError
if sys.version_info[:3] >= (3, 6, 0): if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection from typing import Collection
else: else:
from typing import Sized, Iterable, Container, TypeVar from typing import Sized, Iterable, Container
T_co = TypeVar("T_co", covariant=True) T_co = TypeVar("T_co", covariant=True)
@ -36,6 +37,12 @@ else:
__slots__ = () __slots__ = ()
# Define a state map type from type/state_key to T (usually an event ID or
# event)
T = TypeVar("T")
StateMap = Dict[Tuple[str, str], T]
class Requester( class Requester(
namedtuple( namedtuple(
"Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"] "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]