mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-11 04:14:19 -05:00
Add type hints to filtering classes. (#10958)
This commit is contained in:
parent
9e5a429c8b
commit
7e440520c9
1
changelog.d/10958.misc
Normal file
1
changelog.d/10958.misc
Normal file
@ -0,0 +1 @@
|
|||||||
|
Add type hints to filtering classes.
|
@ -15,7 +15,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import json
|
import json
|
||||||
from typing import List
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Awaitable,
|
||||||
|
Container,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import jsonschema
|
import jsonschema
|
||||||
from jsonschema import FormatChecker
|
from jsonschema import FormatChecker
|
||||||
@ -23,7 +33,11 @@ from jsonschema import FormatChecker
|
|||||||
from synapse.api.constants import EventContentFields
|
from synapse.api.constants import EventContentFields
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
from synapse.api.presence import UserPresenceState
|
from synapse.api.presence import UserPresenceState
|
||||||
from synapse.types import RoomID, UserID
|
from synapse.events import EventBase
|
||||||
|
from synapse.types import JsonDict, RoomID, UserID
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
FILTER_SCHEMA = {
|
FILTER_SCHEMA = {
|
||||||
"additionalProperties": False,
|
"additionalProperties": False,
|
||||||
@ -120,25 +134,29 @@ USER_FILTER_SCHEMA = {
|
|||||||
|
|
||||||
|
|
||||||
@FormatChecker.cls_checks("matrix_room_id")
|
@FormatChecker.cls_checks("matrix_room_id")
|
||||||
def matrix_room_id_validator(room_id_str):
|
def matrix_room_id_validator(room_id_str: str) -> RoomID:
|
||||||
return RoomID.from_string(room_id_str)
|
return RoomID.from_string(room_id_str)
|
||||||
|
|
||||||
|
|
||||||
@FormatChecker.cls_checks("matrix_user_id")
|
@FormatChecker.cls_checks("matrix_user_id")
|
||||||
def matrix_user_id_validator(user_id_str):
|
def matrix_user_id_validator(user_id_str: str) -> UserID:
|
||||||
return UserID.from_string(user_id_str)
|
return UserID.from_string(user_id_str)
|
||||||
|
|
||||||
|
|
||||||
class Filtering:
|
class Filtering:
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
async def get_user_filter(self, user_localpart, filter_id):
|
async def get_user_filter(
|
||||||
|
self, user_localpart: str, filter_id: Union[int, str]
|
||||||
|
) -> "FilterCollection":
|
||||||
result = await self.store.get_user_filter(user_localpart, filter_id)
|
result = await self.store.get_user_filter(user_localpart, filter_id)
|
||||||
return FilterCollection(result)
|
return FilterCollection(result)
|
||||||
|
|
||||||
def add_user_filter(self, user_localpart, user_filter):
|
def add_user_filter(
|
||||||
|
self, user_localpart: str, user_filter: JsonDict
|
||||||
|
) -> Awaitable[int]:
|
||||||
self.check_valid_filter(user_filter)
|
self.check_valid_filter(user_filter)
|
||||||
return self.store.add_user_filter(user_localpart, user_filter)
|
return self.store.add_user_filter(user_localpart, user_filter)
|
||||||
|
|
||||||
@ -146,13 +164,13 @@ class Filtering:
|
|||||||
# replace_user_filter at some point? There's no REST API specified for
|
# replace_user_filter at some point? There's no REST API specified for
|
||||||
# them however
|
# them however
|
||||||
|
|
||||||
def check_valid_filter(self, user_filter_json):
|
def check_valid_filter(self, user_filter_json: JsonDict) -> None:
|
||||||
"""Check if the provided filter is valid.
|
"""Check if the provided filter is valid.
|
||||||
|
|
||||||
This inspects all definitions contained within the filter.
|
This inspects all definitions contained within the filter.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_filter_json(dict): The filter
|
user_filter_json: The filter
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError: If the filter is not valid.
|
SynapseError: If the filter is not valid.
|
||||||
"""
|
"""
|
||||||
@ -167,8 +185,12 @@ class Filtering:
|
|||||||
raise SynapseError(400, str(e))
|
raise SynapseError(400, str(e))
|
||||||
|
|
||||||
|
|
||||||
|
# Filters work across events, presence EDUs, and account data.
|
||||||
|
FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
|
||||||
|
|
||||||
|
|
||||||
class FilterCollection:
|
class FilterCollection:
|
||||||
def __init__(self, filter_json):
|
def __init__(self, filter_json: JsonDict):
|
||||||
self._filter_json = filter_json
|
self._filter_json = filter_json
|
||||||
|
|
||||||
room_filter_json = self._filter_json.get("room", {})
|
room_filter_json = self._filter_json.get("room", {})
|
||||||
@ -188,25 +210,25 @@ class FilterCollection:
|
|||||||
self.event_fields = filter_json.get("event_fields", [])
|
self.event_fields = filter_json.get("event_fields", [])
|
||||||
self.event_format = filter_json.get("event_format", "client")
|
self.event_format = filter_json.get("event_format", "client")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
|
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
|
||||||
|
|
||||||
def get_filter_json(self):
|
def get_filter_json(self) -> JsonDict:
|
||||||
return self._filter_json
|
return self._filter_json
|
||||||
|
|
||||||
def timeline_limit(self):
|
def timeline_limit(self) -> int:
|
||||||
return self._room_timeline_filter.limit()
|
return self._room_timeline_filter.limit()
|
||||||
|
|
||||||
def presence_limit(self):
|
def presence_limit(self) -> int:
|
||||||
return self._presence_filter.limit()
|
return self._presence_filter.limit()
|
||||||
|
|
||||||
def ephemeral_limit(self):
|
def ephemeral_limit(self) -> int:
|
||||||
return self._room_ephemeral_filter.limit()
|
return self._room_ephemeral_filter.limit()
|
||||||
|
|
||||||
def lazy_load_members(self):
|
def lazy_load_members(self) -> bool:
|
||||||
return self._room_state_filter.lazy_load_members()
|
return self._room_state_filter.lazy_load_members()
|
||||||
|
|
||||||
def include_redundant_members(self):
|
def include_redundant_members(self) -> bool:
|
||||||
return self._room_state_filter.include_redundant_members()
|
return self._room_state_filter.include_redundant_members()
|
||||||
|
|
||||||
def filter_presence(self, events):
|
def filter_presence(self, events):
|
||||||
@ -218,29 +240,31 @@ class FilterCollection:
|
|||||||
def filter_room_state(self, events):
|
def filter_room_state(self, events):
|
||||||
return self._room_state_filter.filter(self._room_filter.filter(events))
|
return self._room_state_filter.filter(self._room_filter.filter(events))
|
||||||
|
|
||||||
def filter_room_timeline(self, events):
|
def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
||||||
return self._room_timeline_filter.filter(self._room_filter.filter(events))
|
return self._room_timeline_filter.filter(self._room_filter.filter(events))
|
||||||
|
|
||||||
def filter_room_ephemeral(self, events):
|
def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
||||||
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
|
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
|
||||||
|
|
||||||
def filter_room_account_data(self, events):
|
def filter_room_account_data(
|
||||||
|
self, events: Iterable[FilterEvent]
|
||||||
|
) -> List[FilterEvent]:
|
||||||
return self._room_account_data.filter(self._room_filter.filter(events))
|
return self._room_account_data.filter(self._room_filter.filter(events))
|
||||||
|
|
||||||
def blocks_all_presence(self):
|
def blocks_all_presence(self) -> bool:
|
||||||
return (
|
return (
|
||||||
self._presence_filter.filters_all_types()
|
self._presence_filter.filters_all_types()
|
||||||
or self._presence_filter.filters_all_senders()
|
or self._presence_filter.filters_all_senders()
|
||||||
)
|
)
|
||||||
|
|
||||||
def blocks_all_room_ephemeral(self):
|
def blocks_all_room_ephemeral(self) -> bool:
|
||||||
return (
|
return (
|
||||||
self._room_ephemeral_filter.filters_all_types()
|
self._room_ephemeral_filter.filters_all_types()
|
||||||
or self._room_ephemeral_filter.filters_all_senders()
|
or self._room_ephemeral_filter.filters_all_senders()
|
||||||
or self._room_ephemeral_filter.filters_all_rooms()
|
or self._room_ephemeral_filter.filters_all_rooms()
|
||||||
)
|
)
|
||||||
|
|
||||||
def blocks_all_room_timeline(self):
|
def blocks_all_room_timeline(self) -> bool:
|
||||||
return (
|
return (
|
||||||
self._room_timeline_filter.filters_all_types()
|
self._room_timeline_filter.filters_all_types()
|
||||||
or self._room_timeline_filter.filters_all_senders()
|
or self._room_timeline_filter.filters_all_senders()
|
||||||
@ -249,7 +273,7 @@ class FilterCollection:
|
|||||||
|
|
||||||
|
|
||||||
class Filter:
|
class Filter:
|
||||||
def __init__(self, filter_json):
|
def __init__(self, filter_json: JsonDict):
|
||||||
self.filter_json = filter_json
|
self.filter_json = filter_json
|
||||||
|
|
||||||
self.types = self.filter_json.get("types", None)
|
self.types = self.filter_json.get("types", None)
|
||||||
@ -266,20 +290,20 @@ class Filter:
|
|||||||
self.labels = self.filter_json.get("org.matrix.labels", None)
|
self.labels = self.filter_json.get("org.matrix.labels", None)
|
||||||
self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
|
self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
|
||||||
|
|
||||||
def filters_all_types(self):
|
def filters_all_types(self) -> bool:
|
||||||
return "*" in self.not_types
|
return "*" in self.not_types
|
||||||
|
|
||||||
def filters_all_senders(self):
|
def filters_all_senders(self) -> bool:
|
||||||
return "*" in self.not_senders
|
return "*" in self.not_senders
|
||||||
|
|
||||||
def filters_all_rooms(self):
|
def filters_all_rooms(self) -> bool:
|
||||||
return "*" in self.not_rooms
|
return "*" in self.not_rooms
|
||||||
|
|
||||||
def check(self, event):
|
def check(self, event: FilterEvent) -> bool:
|
||||||
"""Checks whether the filter matches the given event.
|
"""Checks whether the filter matches the given event.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the event matches
|
True if the event matches
|
||||||
"""
|
"""
|
||||||
# We usually get the full "events" as dictionaries coming through,
|
# We usually get the full "events" as dictionaries coming through,
|
||||||
# except for presence which actually gets passed around as its own
|
# except for presence which actually gets passed around as its own
|
||||||
@ -305,18 +329,25 @@ class Filter:
|
|||||||
room_id = event.get("room_id", None)
|
room_id = event.get("room_id", None)
|
||||||
ev_type = event.get("type", None)
|
ev_type = event.get("type", None)
|
||||||
|
|
||||||
content = event.get("content", {})
|
content = event.get("content") or {}
|
||||||
# check if there is a string url field in the content for filtering purposes
|
# check if there is a string url field in the content for filtering purposes
|
||||||
contains_url = isinstance(content.get("url"), str)
|
contains_url = isinstance(content.get("url"), str)
|
||||||
labels = content.get(EventContentFields.LABELS, [])
|
labels = content.get(EventContentFields.LABELS, [])
|
||||||
|
|
||||||
return self.check_fields(room_id, sender, ev_type, labels, contains_url)
|
return self.check_fields(room_id, sender, ev_type, labels, contains_url)
|
||||||
|
|
||||||
def check_fields(self, room_id, sender, event_type, labels, contains_url):
|
def check_fields(
|
||||||
|
self,
|
||||||
|
room_id: Optional[str],
|
||||||
|
sender: Optional[str],
|
||||||
|
event_type: Optional[str],
|
||||||
|
labels: Container[str],
|
||||||
|
contains_url: bool,
|
||||||
|
) -> bool:
|
||||||
"""Checks whether the filter matches the given event fields.
|
"""Checks whether the filter matches the given event fields.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the event fields match
|
True if the event fields match
|
||||||
"""
|
"""
|
||||||
literal_keys = {
|
literal_keys = {
|
||||||
"rooms": lambda v: room_id == v,
|
"rooms": lambda v: room_id == v,
|
||||||
@ -343,14 +374,14 @@ class Filter:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def filter_rooms(self, room_ids):
|
def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
|
||||||
"""Apply the 'rooms' filter to a given list of rooms.
|
"""Apply the 'rooms' filter to a given list of rooms.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_ids (list): A list of room_ids.
|
room_ids: A list of room_ids.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: A list of room_ids that match the filter
|
A list of room_ids that match the filter
|
||||||
"""
|
"""
|
||||||
room_ids = set(room_ids)
|
room_ids = set(room_ids)
|
||||||
|
|
||||||
@ -363,23 +394,23 @@ class Filter:
|
|||||||
|
|
||||||
return room_ids
|
return room_ids
|
||||||
|
|
||||||
def filter(self, events):
|
def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
||||||
return list(filter(self.check, events))
|
return list(filter(self.check, events))
|
||||||
|
|
||||||
def limit(self):
|
def limit(self) -> int:
|
||||||
return self.filter_json.get("limit", 10)
|
return self.filter_json.get("limit", 10)
|
||||||
|
|
||||||
def lazy_load_members(self):
|
def lazy_load_members(self) -> bool:
|
||||||
return self.filter_json.get("lazy_load_members", False)
|
return self.filter_json.get("lazy_load_members", False)
|
||||||
|
|
||||||
def include_redundant_members(self):
|
def include_redundant_members(self) -> bool:
|
||||||
return self.filter_json.get("include_redundant_members", False)
|
return self.filter_json.get("include_redundant_members", False)
|
||||||
|
|
||||||
def with_room_ids(self, room_ids):
|
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
|
||||||
"""Returns a new filter with the given room IDs appended.
|
"""Returns a new filter with the given room IDs appended.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_ids (iterable[unicode]): The room_ids to add
|
room_ids: The room_ids to add
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
filter: A new filter including the given rooms and the old
|
filter: A new filter including the given rooms and the old
|
||||||
@ -390,8 +421,8 @@ class Filter:
|
|||||||
return newFilter
|
return newFilter
|
||||||
|
|
||||||
|
|
||||||
def _matches_wildcard(actual_value, filter_value):
|
def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
|
||||||
if filter_value.endswith("*"):
|
if filter_value.endswith("*") and isinstance(actual_value, str):
|
||||||
type_prefix = filter_value[:-1]
|
type_prefix = filter_value[:-1]
|
||||||
return actual_value.startswith(type_prefix)
|
return actual_value.startswith(type_prefix)
|
||||||
else:
|
else:
|
||||||
|
@ -12,6 +12,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
|
||||||
from synapse.api.errors import Codes, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
@ -22,7 +24,9 @@ from synapse.util.caches.descriptors import cached
|
|||||||
|
|
||||||
class FilteringStore(SQLBaseStore):
|
class FilteringStore(SQLBaseStore):
|
||||||
@cached(num_args=2)
|
@cached(num_args=2)
|
||||||
async def get_user_filter(self, user_localpart, filter_id):
|
async def get_user_filter(
|
||||||
|
self, user_localpart: str, filter_id: Union[int, str]
|
||||||
|
) -> JsonDict:
|
||||||
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
|
||||||
# with a coherent error message rather than 500 M_UNKNOWN.
|
# with a coherent error message rather than 500 M_UNKNOWN.
|
||||||
try:
|
try:
|
||||||
@ -40,7 +44,7 @@ class FilteringStore(SQLBaseStore):
|
|||||||
|
|
||||||
return db_to_json(def_json)
|
return db_to_json(def_json)
|
||||||
|
|
||||||
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
|
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
|
||||||
def_json = encode_canonical_json(user_filter)
|
def_json = encode_canonical_json(user_filter)
|
||||||
|
|
||||||
# Need an atomic transaction to SELECT the maximal ID so far then
|
# Need an atomic transaction to SELECT the maximal ID so far then
|
||||||
|
Loading…
Reference in New Issue
Block a user