mirror of
https://mau.dev/maunium/synapse.git
synced 2024-10-01 01:36:05 -04:00
Add type hints to the push module. (#8901)
This commit is contained in:
parent
a8eceb01e5
commit
5d34f40d49
1
changelog.d/8901.misc
Normal file
1
changelog.d/8901.misc
Normal file
@ -0,0 +1 @@
|
||||
Add type hints to push module.
|
7
mypy.ini
7
mypy.ini
@ -56,12 +56,7 @@ files =
|
||||
synapse/metrics,
|
||||
synapse/module_api,
|
||||
synapse/notifier.py,
|
||||
synapse/push/emailpusher.py,
|
||||
synapse/push/httppusher.py,
|
||||
synapse/push/mailer.py,
|
||||
synapse/push/pusher.py,
|
||||
synapse/push/pusherpool.py,
|
||||
synapse/push/push_rule_evaluator.py,
|
||||
synapse/push,
|
||||
synapse/replication,
|
||||
synapse/rest,
|
||||
synapse/server.py,
|
||||
|
@ -31,6 +31,8 @@ class SynapsePlugin(Plugin):
|
||||
) -> Optional[Callable[[MethodSigContext], CallableType]]:
|
||||
if fullname.startswith(
|
||||
"synapse.util.caches.descriptors._CachedFunction.__call__"
|
||||
) or fullname.startswith(
|
||||
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
|
||||
):
|
||||
return cached_function_method_signature
|
||||
return None
|
||||
|
@ -14,19 +14,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionGenerator:
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.bulk_evaluator = BulkPushRuleEvaluator(hs)
|
||||
# really we want to get all user ids and all profile tags too,
|
||||
# since we want the actions for each profile tag for every user and
|
||||
@ -35,6 +38,8 @@ class ActionGenerator:
|
||||
# event stream, so we just run the rules for a client with no profile
|
||||
# tag (ie. we just need all the users).
|
||||
|
||||
async def handle_push_actions_for_event(self, event, context):
|
||||
async def handle_push_actions_for_event(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> None:
|
||||
with Measure(self.clock, "action_for_event_by_user"):
|
||||
await self.bulk_evaluator.action_for_event_by_user(event, context)
|
||||
|
@ -15,16 +15,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
|
||||
|
||||
|
||||
def list_with_base_rules(rawrules, use_new_defaults=False):
|
||||
def list_with_base_rules(
|
||||
rawrules: List[Dict[str, Any]], use_new_defaults: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Combine the list of rules set by the user with the default push rules
|
||||
|
||||
Args:
|
||||
rawrules(list): The rules the user has modified or set.
|
||||
use_new_defaults(bool): Whether to use the new experimental default rules when
|
||||
rawrules: The rules the user has modified or set.
|
||||
use_new_defaults: Whether to use the new experimental default rules when
|
||||
appending or prepending default rules.
|
||||
|
||||
Returns:
|
||||
@ -94,7 +97,11 @@ def list_with_base_rules(rawrules, use_new_defaults=False):
|
||||
return ruleslist
|
||||
|
||||
|
||||
def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||
def make_base_append_rules(
|
||||
kind: str,
|
||||
modified_base_rules: Dict[str, Dict[str, Any]],
|
||||
use_new_defaults: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
rules = []
|
||||
|
||||
if kind == "override":
|
||||
@ -116,6 +123,7 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||
rules = copy.deepcopy(rules)
|
||||
for r in rules:
|
||||
# Only modify the actions, keep the conditions the same.
|
||||
assert isinstance(r["rule_id"], str)
|
||||
modified = modified_base_rules.get(r["rule_id"])
|
||||
if modified:
|
||||
r["actions"] = modified["actions"]
|
||||
@ -123,7 +131,11 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||
return rules
|
||||
|
||||
|
||||
def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||
def make_base_prepend_rules(
|
||||
kind: str,
|
||||
modified_base_rules: Dict[str, Dict[str, Any]],
|
||||
use_new_defaults: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
rules = []
|
||||
|
||||
if kind == "override":
|
||||
@ -133,6 +145,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
|
||||
rules = copy.deepcopy(rules)
|
||||
for r in rules:
|
||||
# Only modify the actions, keep the conditions the same.
|
||||
assert isinstance(r["rule_id"], str)
|
||||
modified = modified_base_rules.get(r["rule_id"])
|
||||
if modified:
|
||||
r["actions"] = modified["actions"]
|
||||
|
@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
@ -25,18 +26,18 @@ from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.state import POWER_KEY
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import register_cache
|
||||
from synapse.util.caches import CacheMetric, register_cache
|
||||
from synapse.util.caches.descriptors import lru_cache
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
rules_by_room = {}
|
||||
|
||||
|
||||
push_rules_invalidation_counter = Counter(
|
||||
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
|
||||
)
|
||||
@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
|
||||
room at once.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
|
||||
resizable=False,
|
||||
)
|
||||
|
||||
async def _get_rules_for_event(self, event, context):
|
||||
async def _get_rules_for_event(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""This gets the rules for all users in the room at the time of the event,
|
||||
as well as the push rules for the invitee if the event is an invite.
|
||||
|
||||
@ -140,11 +143,8 @@ class BulkPushRuleEvaluator:
|
||||
return rules_by_user
|
||||
|
||||
@lru_cache()
|
||||
def _get_rules_for_room(self, room_id):
|
||||
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
|
||||
"""Get the current RulesForRoom object for the given room id
|
||||
|
||||
Returns:
|
||||
RulesForRoom
|
||||
"""
|
||||
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
|
||||
# before any lookup methods get called on it as otherwise there may be
|
||||
@ -156,20 +156,21 @@ class BulkPushRuleEvaluator:
|
||||
self.room_push_rule_cache_metrics,
|
||||
)
|
||||
|
||||
async def _get_power_levels_and_sender_level(self, event, context):
|
||||
async def _get_power_levels_and_sender_level(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Tuple[dict, int]:
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
pl_event_id = prev_state_ids.get(POWER_KEY)
|
||||
if pl_event_id:
|
||||
# fastpath: if there's a power level event, that's all we need, and
|
||||
# not having a power level event is an extreme edge case
|
||||
pl_event = await self.store.get_event(pl_event_id)
|
||||
auth_events = {POWER_KEY: pl_event}
|
||||
auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
|
||||
else:
|
||||
auth_events_ids = self.auth.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=False
|
||||
)
|
||||
auth_events = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
||||
auth_events_dict = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
|
||||
|
||||
sender_level = get_user_power_level(event.sender, auth_events)
|
||||
|
||||
@ -177,7 +178,9 @@ class BulkPushRuleEvaluator:
|
||||
|
||||
return pl_event.content if pl_event else {}, sender_level
|
||||
|
||||
async def action_for_event_by_user(self, event, context) -> None:
|
||||
async def action_for_event_by_user(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> None:
|
||||
"""Given an event and context, evaluate the push rules, check if the message
|
||||
should increment the unread count, and insert the results into the
|
||||
event_push_actions_staging table.
|
||||
@ -185,7 +188,7 @@ class BulkPushRuleEvaluator:
|
||||
count_as_unread = _should_count_as_unread(event, context)
|
||||
|
||||
rules_by_user = await self._get_rules_for_event(event, context)
|
||||
actions_by_user = {}
|
||||
actions_by_user = {} # type: Dict[str, List[Union[dict, str]]]
|
||||
|
||||
room_members = await self.store.get_joined_users_from_context(event, context)
|
||||
|
||||
@ -198,7 +201,7 @@ class BulkPushRuleEvaluator:
|
||||
event, len(room_members), sender_power_level, power_levels
|
||||
)
|
||||
|
||||
condition_cache = {}
|
||||
condition_cache = {} # type: Dict[str, bool]
|
||||
|
||||
for uid, rules in rules_by_user.items():
|
||||
if event.sender == uid:
|
||||
@ -249,7 +252,13 @@ class BulkPushRuleEvaluator:
|
||||
)
|
||||
|
||||
|
||||
def _condition_checker(evaluator, conditions, uid, display_name, cache):
|
||||
def _condition_checker(
|
||||
evaluator: PushRuleEvaluatorForEvent,
|
||||
conditions: List[dict],
|
||||
uid: str,
|
||||
display_name: str,
|
||||
cache: Dict[str, bool],
|
||||
) -> bool:
|
||||
for cond in conditions:
|
||||
_id = cond.get("_id", None)
|
||||
if _id:
|
||||
@ -277,15 +286,19 @@ class RulesForRoom:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
room_id: str,
|
||||
rules_for_room_cache: LruCache,
|
||||
room_push_rule_cache_metrics: CacheMetric,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hs (HomeServer)
|
||||
room_id (str)
|
||||
hs: The HomeServer object.
|
||||
room_id: The room ID.
|
||||
rules_for_room_cache: The cache object that caches these
|
||||
RoomsForUser objects.
|
||||
room_push_rule_cache_metrics (CacheMetric)
|
||||
room_push_rule_cache_metrics: The metrics object
|
||||
"""
|
||||
self.room_id = room_id
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
@ -294,8 +307,10 @@ class RulesForRoom:
|
||||
|
||||
self.linearizer = Linearizer(name="rules_for_room")
|
||||
|
||||
self.member_map = {} # event_id -> (user_id, state)
|
||||
self.rules_by_user = {} # user_id -> rules
|
||||
# event_id -> (user_id, state)
|
||||
self.member_map = {} # type: Dict[str, Tuple[str, str]]
|
||||
# user_id -> rules
|
||||
self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
|
||||
|
||||
# The last state group we updated the caches for. If the state_group of
|
||||
# a new event comes along, we know that we can just return the cached
|
||||
@ -315,7 +330,7 @@ class RulesForRoom:
|
||||
# calculate push for)
|
||||
# These never need to be invalidated as we will never set up push for
|
||||
# them.
|
||||
self.uninteresting_user_set = set()
|
||||
self.uninteresting_user_set = set() # type: Set[str]
|
||||
|
||||
# We need to be clever on the invalidating caches callbacks, as
|
||||
# otherwise the invalidation callback holds a reference to the object,
|
||||
@ -325,7 +340,9 @@ class RulesForRoom:
|
||||
# to self around in the callback.
|
||||
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
|
||||
|
||||
async def get_rules(self, event, context):
|
||||
async def get_rules(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Dict[str, List[Dict[str, dict]]]:
|
||||
"""Given an event context return the rules for all users who are
|
||||
currently in the room.
|
||||
"""
|
||||
@ -356,6 +373,8 @@ class RulesForRoom:
|
||||
else:
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
push_rules_delta_state_cache_metric.inc_misses()
|
||||
# Ensure the state IDs exist.
|
||||
assert current_state_ids is not None
|
||||
|
||||
push_rules_state_size_counter.inc(len(current_state_ids))
|
||||
|
||||
@ -420,18 +439,23 @@ class RulesForRoom:
|
||||
return ret_rules_by_user
|
||||
|
||||
async def _update_rules_with_member_event_ids(
|
||||
self, ret_rules_by_user, member_event_ids, state_group, event
|
||||
):
|
||||
self,
|
||||
ret_rules_by_user: Dict[str, list],
|
||||
member_event_ids: Dict[str, str],
|
||||
state_group: Optional[int],
|
||||
event: EventBase,
|
||||
) -> None:
|
||||
"""Update the partially filled rules_by_user dict by fetching rules for
|
||||
any newly joined users in the `member_event_ids` list.
|
||||
|
||||
Args:
|
||||
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
|
||||
ret_rules_by_user: Partially filled dict of push rules. Gets
|
||||
updated with any new rules.
|
||||
member_event_ids (dict): Dict of user id to event id for membership events
|
||||
member_event_ids: Dict of user id to event id for membership events
|
||||
that have happened since the last time we filled rules_by_user
|
||||
state_group: The state group we are currently computing push rules
|
||||
for. Used when updating the cache.
|
||||
event: The event we are currently computing push rules for.
|
||||
"""
|
||||
sequence = self.sequence
|
||||
|
||||
@ -449,19 +473,19 @@ class RulesForRoom:
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("Found members %r: %r", self.room_id, members.values())
|
||||
|
||||
user_ids = {
|
||||
joined_user_ids = {
|
||||
user_id
|
||||
for user_id, membership in members.values()
|
||||
if membership == Membership.JOIN
|
||||
}
|
||||
|
||||
logger.debug("Joined: %r", user_ids)
|
||||
logger.debug("Joined: %r", joined_user_ids)
|
||||
|
||||
# Previously we only considered users with pushers or read receipts in that
|
||||
# room. We can't do this anymore because we use push actions to calculate unread
|
||||
# counts, which don't rely on the user having pushers or sent a read receipt into
|
||||
# the room. Therefore we just need to filter for local users here.
|
||||
user_ids = list(filter(self.is_mine_id, user_ids))
|
||||
user_ids = list(filter(self.is_mine_id, joined_user_ids))
|
||||
|
||||
rules_by_user = await self.store.bulk_get_push_rules(
|
||||
user_ids, on_invalidate=self.invalidate_all_cb
|
||||
@ -473,7 +497,7 @@ class RulesForRoom:
|
||||
|
||||
self.update_cache(sequence, members, ret_rules_by_user, state_group)
|
||||
|
||||
def invalidate_all(self):
|
||||
def invalidate_all(self) -> None:
|
||||
# Note: Don't hand this function directly to an invalidation callback
|
||||
# as it keeps a reference to self and will stop this instance from being
|
||||
# GC'd if it gets dropped from the rules_to_user cache. Instead use
|
||||
@ -485,7 +509,7 @@ class RulesForRoom:
|
||||
self.rules_by_user = {}
|
||||
push_rules_invalidation_counter.inc()
|
||||
|
||||
def update_cache(self, sequence, members, rules_by_user, state_group):
|
||||
def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
|
||||
if sequence == self.sequence:
|
||||
self.member_map.update(members)
|
||||
self.rules_by_user = rules_by_user
|
||||
@ -506,7 +530,7 @@ class _Invalidation:
|
||||
cache = attr.ib(type=LruCache)
|
||||
room_id = attr.ib(type=str)
|
||||
|
||||
def __call__(self):
|
||||
def __call__(self) -> None:
|
||||
rules = self.cache.get(self.room_id, None, update_metrics=False)
|
||||
if rules:
|
||||
rules.invalidate_all()
|
||||
|
@ -14,24 +14,27 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
|
||||
from synapse.types import UserID
|
||||
|
||||
|
||||
def format_push_rules_for_user(user, ruleslist):
|
||||
def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
|
||||
"""Converts a list of rawrules and a enabled map into nested dictionaries
|
||||
to match the Matrix client-server format for push rules"""
|
||||
|
||||
# We're going to be mutating this a lot, so do a deep copy
|
||||
ruleslist = copy.deepcopy(ruleslist)
|
||||
|
||||
rules = {"global": {}, "device": {}}
|
||||
rules = {
|
||||
"global": {},
|
||||
"device": {},
|
||||
} # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
|
||||
|
||||
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
|
||||
|
||||
for r in ruleslist:
|
||||
rulearray = None
|
||||
|
||||
template_name = _priority_class_to_template_name(r["priority_class"])
|
||||
|
||||
# Remove internal stuff.
|
||||
@ -57,13 +60,13 @@ def format_push_rules_for_user(user, ruleslist):
|
||||
return rules
|
||||
|
||||
|
||||
def _add_empty_priority_class_arrays(d):
|
||||
def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
|
||||
for pc in PRIORITY_CLASS_MAP.keys():
|
||||
d[pc] = []
|
||||
return d
|
||||
|
||||
|
||||
def _rule_to_template(rule):
|
||||
def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
unscoped_rule_id = None
|
||||
if "rule_id" in rule:
|
||||
unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
|
||||
@ -82,6 +85,10 @@ def _rule_to_template(rule):
|
||||
return None
|
||||
templaterule = {"actions": rule["actions"]}
|
||||
templaterule["pattern"] = thecond["pattern"]
|
||||
else:
|
||||
# This should not be reached unless this function is not kept in sync
|
||||
# with PRIORITY_CLASS_INVERSE_MAP.
|
||||
raise ValueError("Unexpected template_name: %s" % (template_name,))
|
||||
|
||||
if unscoped_rule_id:
|
||||
templaterule["rule_id"] = unscoped_rule_id
|
||||
@ -90,9 +97,9 @@ def _rule_to_template(rule):
|
||||
return templaterule
|
||||
|
||||
|
||||
def _rule_id_from_namespaced(in_rule_id):
|
||||
def _rule_id_from_namespaced(in_rule_id: str) -> str:
|
||||
return in_rule_id.split("/")[-1]
|
||||
|
||||
|
||||
def _priority_class_to_template_name(pc):
|
||||
def _priority_class_to_template_name(pc: int) -> str:
|
||||
return PRIORITY_CLASS_INVERSE_MAP[pc]
|
||||
|
@ -15,8 +15,14 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, Optional
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -28,25 +34,29 @@ ALL_ALONE = "Empty Room"
|
||||
|
||||
|
||||
async def calculate_room_name(
|
||||
store,
|
||||
room_state_ids,
|
||||
user_id,
|
||||
fallback_to_members=True,
|
||||
fallback_to_single_member=True,
|
||||
):
|
||||
store: "DataStore",
|
||||
room_state_ids: StateMap[str],
|
||||
user_id: str,
|
||||
fallback_to_members: bool = True,
|
||||
fallback_to_single_member: bool = True,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Works out a user-facing name for the given room as per Matrix
|
||||
spec recommendations.
|
||||
Does not yet support internationalisation.
|
||||
Args:
|
||||
room_state: Dictionary of the room's state
|
||||
store: The data store to query.
|
||||
room_state_ids: Dictionary of the room's state IDs.
|
||||
user_id: The ID of the user to whom the room name is being presented
|
||||
fallback_to_members: If False, return None instead of generating a name
|
||||
based on the room's members if the room has no
|
||||
title or aliases.
|
||||
fallback_to_single_member: If False, return None instead of generating a
|
||||
name based on the user who invited this user to the room if the room
|
||||
has no title or aliases.
|
||||
|
||||
Returns:
|
||||
(string or None) A human readable name for the room.
|
||||
A human readable name for the room, if possible.
|
||||
"""
|
||||
# does it have a name?
|
||||
if (EventTypes.Name, "") in room_state_ids:
|
||||
@ -97,7 +107,7 @@ async def calculate_room_name(
|
||||
name_from_member_event(inviter_member_event),
|
||||
)
|
||||
else:
|
||||
return
|
||||
return None
|
||||
else:
|
||||
return "Room Invite"
|
||||
|
||||
@ -150,19 +160,19 @@ async def calculate_room_name(
|
||||
else:
|
||||
return ALL_ALONE
|
||||
elif len(other_members) == 1 and not fallback_to_single_member:
|
||||
return
|
||||
else:
|
||||
return descriptor_from_member_events(other_members)
|
||||
return None
|
||||
|
||||
return descriptor_from_member_events(other_members)
|
||||
|
||||
|
||||
def descriptor_from_member_events(member_events):
|
||||
def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
|
||||
"""Get a description of the room based on the member events.
|
||||
|
||||
Args:
|
||||
member_events (Iterable[FrozenEvent])
|
||||
member_events: The events of a room.
|
||||
|
||||
Returns:
|
||||
str
|
||||
The room description
|
||||
"""
|
||||
|
||||
member_events = list(member_events)
|
||||
@ -183,7 +193,7 @@ def descriptor_from_member_events(member_events):
|
||||
)
|
||||
|
||||
|
||||
def name_from_member_event(member_event):
|
||||
def name_from_member_event(member_event: EventBase) -> str:
|
||||
if (
|
||||
member_event.content
|
||||
and "displayname" in member_event.content
|
||||
@ -193,12 +203,12 @@ def name_from_member_event(member_event):
|
||||
return member_event.state_key
|
||||
|
||||
|
||||
def _state_as_two_level_dict(state):
|
||||
ret = {}
|
||||
def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
|
||||
ret = {} # type: Dict[str, Dict[str, str]]
|
||||
for k, v in state.items():
|
||||
ret.setdefault(k[0], {})[k[1]] = v
|
||||
return ret
|
||||
|
||||
|
||||
def _looks_like_an_alias(string):
|
||||
def _looks_like_an_alias(string: str) -> bool:
|
||||
return ALIAS_RE.match(string) is not None
|
||||
|
@ -30,22 +30,30 @@ IS_GLOB = re.compile(r"[\?\*\[\]]")
|
||||
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
||||
|
||||
|
||||
def _room_member_count(ev, condition, room_member_count):
|
||||
def _room_member_count(
|
||||
ev: EventBase, condition: Dict[str, Any], room_member_count: int
|
||||
) -> bool:
|
||||
return _test_ineq_condition(condition, room_member_count)
|
||||
|
||||
|
||||
def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
|
||||
def _sender_notification_permission(
|
||||
ev: EventBase,
|
||||
condition: Dict[str, Any],
|
||||
sender_power_level: int,
|
||||
power_levels: Dict[str, Union[int, Dict[str, int]]],
|
||||
) -> bool:
|
||||
notif_level_key = condition.get("key")
|
||||
if notif_level_key is None:
|
||||
return False
|
||||
|
||||
notif_levels = power_levels.get("notifications", {})
|
||||
assert isinstance(notif_levels, dict)
|
||||
room_notif_level = notif_levels.get(notif_level_key, 50)
|
||||
|
||||
return sender_power_level >= room_notif_level
|
||||
|
||||
|
||||
def _test_ineq_condition(condition, number):
|
||||
def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
|
||||
if "is" not in condition:
|
||||
return False
|
||||
m = INEQUALITY_EXPR.match(condition["is"])
|
||||
@ -110,7 +118,7 @@ class PushRuleEvaluatorForEvent:
|
||||
event: EventBase,
|
||||
room_member_count: int,
|
||||
sender_power_level: int,
|
||||
power_levels: dict,
|
||||
power_levels: Dict[str, Union[int, Dict[str, int]]],
|
||||
):
|
||||
self._event = event
|
||||
self._room_member_count = room_member_count
|
||||
@ -120,7 +128,9 @@ class PushRuleEvaluatorForEvent:
|
||||
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
||||
self._value_cache = _flatten_dict(event)
|
||||
|
||||
def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
|
||||
def matches(
|
||||
self, condition: Dict[str, Any], user_id: str, display_name: str
|
||||
) -> bool:
|
||||
if condition["kind"] == "event_match":
|
||||
return self._event_match(condition, user_id)
|
||||
elif condition["kind"] == "contains_display_name":
|
||||
@ -261,7 +271,13 @@ def _re_word_boundary(r: str) -> str:
|
||||
return r"(^|\W)%s(\W|$)" % (r,)
|
||||
|
||||
|
||||
def _flatten_dict(d, prefix=[], result=None):
|
||||
def _flatten_dict(
|
||||
d: Union[EventBase, dict],
|
||||
prefix: Optional[List[str]] = None,
|
||||
result: Optional[Dict[str, str]] = None,
|
||||
) -> Dict[str, str]:
|
||||
if prefix is None:
|
||||
prefix = []
|
||||
if result is None:
|
||||
result = {}
|
||||
for key, value in d.items():
|
||||
|
Loading…
Reference in New Issue
Block a user