mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Make StateFilter frozen so we can hash it (#10816)
Also enables Mypy for related tests.
This commit is contained in:
parent
14b8c0476f
commit
8eb7cb2e0d
1
changelog.d/10816.misc
Normal file
1
changelog.d/10816.misc
Normal file
@ -0,0 +1 @@
|
||||
Make `StateFilter` frozen so it is hashable.
|
1
mypy.ini
1
mypy.ini
@ -86,6 +86,7 @@ files =
|
||||
tests/handlers/test_sync.py,
|
||||
tests/rest/client/test_login.py,
|
||||
tests/rest/client/test_auth.py,
|
||||
tests/storage/test_state.py,
|
||||
tests/util/test_itertools.py,
|
||||
tests/util/test_stream_change_cache.py
|
||||
|
||||
|
@ -25,12 +25,15 @@ from typing import (
|
||||
)
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases import Databases
|
||||
|
||||
@ -40,7 +43,7 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class StateFilter:
|
||||
"""A filter used when querying for state.
|
||||
|
||||
@ -53,14 +56,19 @@ class StateFilter:
|
||||
appear in `types`.
|
||||
"""
|
||||
|
||||
types = attr.ib(type=Dict[str, Optional[Set[str]]])
|
||||
types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
|
||||
include_others = attr.ib(default=False, type=bool)
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
# If `include_others` is set we canonicalise the filter by removing
|
||||
# wildcards from the types dictionary
|
||||
if self.include_others:
|
||||
self.types = {k: v for k, v in self.types.items() if v is not None}
|
||||
# this is needed to work around the fact that StateFilter is frozen
|
||||
object.__setattr__(
|
||||
self,
|
||||
"types",
|
||||
frozendict({k: v for k, v in self.types.items() if v is not None}),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def all() -> "StateFilter":
|
||||
@ -69,7 +77,7 @@ class StateFilter:
|
||||
Returns:
|
||||
The new state filter.
|
||||
"""
|
||||
return StateFilter(types={}, include_others=True)
|
||||
return StateFilter(types=frozendict(), include_others=True)
|
||||
|
||||
@staticmethod
|
||||
def none() -> "StateFilter":
|
||||
@ -78,7 +86,7 @@ class StateFilter:
|
||||
Returns:
|
||||
The new state filter.
|
||||
"""
|
||||
return StateFilter(types={}, include_others=False)
|
||||
return StateFilter(types=frozendict(), include_others=False)
|
||||
|
||||
@staticmethod
|
||||
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
|
||||
@ -103,7 +111,12 @@ class StateFilter:
|
||||
|
||||
type_dict.setdefault(typ, set()).add(s) # type: ignore
|
||||
|
||||
return StateFilter(types=type_dict)
|
||||
return StateFilter(
|
||||
types=frozendict(
|
||||
(k, frozenset(v) if v is not None else None)
|
||||
for k, v in type_dict.items()
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
|
||||
@ -116,7 +129,10 @@ class StateFilter:
|
||||
Returns:
|
||||
The new state filter
|
||||
"""
|
||||
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
|
||||
return StateFilter(
|
||||
types=frozendict({EventTypes.Member: frozenset(members)}),
|
||||
include_others=True,
|
||||
)
|
||||
|
||||
def return_expanded(self) -> "StateFilter":
|
||||
"""Creates a new StateFilter where type wild cards have been removed
|
||||
@ -173,7 +189,7 @@ class StateFilter:
|
||||
# We want to return all non-members, but only particular
|
||||
# memberships
|
||||
return StateFilter(
|
||||
types={EventTypes.Member: self.types[EventTypes.Member]},
|
||||
types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
|
||||
include_others=True,
|
||||
)
|
||||
|
||||
@ -245,14 +261,15 @@ class StateFilter:
|
||||
|
||||
return len(self.concrete_types())
|
||||
|
||||
def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
|
||||
"""Returns the state filtered with by this StateFilter
|
||||
def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
|
||||
"""Returns the state filtered with by this StateFilter.
|
||||
|
||||
Args:
|
||||
state: The state map to filter
|
||||
|
||||
Returns:
|
||||
The filtered state map
|
||||
The filtered state map.
|
||||
This is a copy, so it's safe to mutate.
|
||||
"""
|
||||
if self.is_full():
|
||||
return dict(state_dict)
|
||||
@ -324,14 +341,16 @@ class StateFilter:
|
||||
if state_keys is None:
|
||||
member_filter = StateFilter.all()
|
||||
else:
|
||||
member_filter = StateFilter({EventTypes.Member: state_keys})
|
||||
member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
|
||||
elif self.include_others:
|
||||
member_filter = StateFilter.all()
|
||||
else:
|
||||
member_filter = StateFilter.none()
|
||||
|
||||
non_member_filter = StateFilter(
|
||||
types={k: v for k, v in self.types.items() if k != EventTypes.Member},
|
||||
types=frozendict(
|
||||
{k: v for k, v in self.types.items() if k != EventTypes.Member}
|
||||
),
|
||||
include_others=self.include_others,
|
||||
)
|
||||
|
||||
|
@ -14,6 +14,8 @@
|
||||
|
||||
import logging
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.storage.state import StateFilter
|
||||
@ -183,7 +185,9 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.storage.state.get_state_for_event(
|
||||
e5.event_id,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: {self.u_alice.to_string()}},
|
||||
types=frozendict(
|
||||
{EventTypes.Member: frozenset({self.u_alice.to_string()})}
|
||||
),
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
@ -203,7 +207,8 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.storage.state.get_state_for_event(
|
||||
e5.event_id,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: set()}, include_others=True
|
||||
types=frozendict({EventTypes.Member: frozenset()}),
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
)
|
||||
@ -228,7 +233,7 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: set()}, include_others=True
|
||||
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
|
||||
),
|
||||
)
|
||||
|
||||
@ -245,7 +250,7 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_members_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: set()}, include_others=True
|
||||
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
|
||||
),
|
||||
)
|
||||
|
||||
@ -258,7 +263,7 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: None}, include_others=True
|
||||
types=frozendict({EventTypes.Member: None}), include_others=True
|
||||
),
|
||||
)
|
||||
|
||||
@ -275,7 +280,7 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_members_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: None}, include_others=True
|
||||
types=frozendict({EventTypes.Member: None}), include_others=True
|
||||
),
|
||||
)
|
||||
|
||||
@ -295,7 +300,8 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: {e5.state_key}}, include_others=True
|
||||
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
@ -312,7 +318,8 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_members_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: {e5.state_key}}, include_others=True
|
||||
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
@ -325,7 +332,8 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_members_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: {e5.state_key}}, include_others=False
|
||||
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
@ -375,7 +383,7 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: set()}, include_others=True
|
||||
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
|
||||
),
|
||||
)
|
||||
|
||||
@ -387,7 +395,7 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_members_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: set()}, include_others=True
|
||||
types=frozendict({EventTypes.Member: frozenset()}), include_others=True
|
||||
),
|
||||
)
|
||||
|
||||
@ -400,7 +408,7 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: None}, include_others=True
|
||||
types=frozendict({EventTypes.Member: None}), include_others=True
|
||||
),
|
||||
)
|
||||
|
||||
@ -411,7 +419,7 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_members_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: None}, include_others=True
|
||||
types=frozendict({EventTypes.Member: None}), include_others=True
|
||||
),
|
||||
)
|
||||
|
||||
@ -430,7 +438,8 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: {e5.state_key}}, include_others=True
|
||||
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
@ -441,7 +450,8 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_members_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: {e5.state_key}}, include_others=True
|
||||
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
@ -454,7 +464,8 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: {e5.state_key}}, include_others=False
|
||||
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
@ -465,7 +476,8 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
self.state_datastore._state_group_members_cache,
|
||||
group,
|
||||
state_filter=StateFilter(
|
||||
types={EventTypes.Member: {e5.state_key}}, include_others=False
|
||||
types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user