Add type hints to the push module. (#8901)

This commit is contained in:
Patrick Cloke 2020-12-11 11:43:53 -05:00 committed by GitHub
parent a8eceb01e5
commit 5d34f40d49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 159 additions and 86 deletions

1
changelog.d/8901.misc Normal file
View File

@ -0,0 +1 @@
Add type hints to push module.

View File

@ -56,12 +56,7 @@ files =
synapse/metrics, synapse/metrics,
synapse/module_api, synapse/module_api,
synapse/notifier.py, synapse/notifier.py,
synapse/push/emailpusher.py, synapse/push,
synapse/push/httppusher.py,
synapse/push/mailer.py,
synapse/push/pusher.py,
synapse/push/pusherpool.py,
synapse/push/push_rule_evaluator.py,
synapse/replication, synapse/replication,
synapse/rest, synapse/rest,
synapse/server.py, synapse/server.py,

View File

@ -31,6 +31,8 @@ class SynapsePlugin(Plugin):
) -> Optional[Callable[[MethodSigContext], CallableType]]: ) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith( if fullname.startswith(
"synapse.util.caches.descriptors._CachedFunction.__call__" "synapse.util.caches.descriptors._CachedFunction.__call__"
) or fullname.startswith(
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
): ):
return cached_function_method_signature return cached_function_method_signature
return None return None

View File

@ -14,19 +14,22 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ActionGenerator: class ActionGenerator:
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.bulk_evaluator = BulkPushRuleEvaluator(hs) self.bulk_evaluator = BulkPushRuleEvaluator(hs)
# really we want to get all user ids and all profile tags too, # really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and # since we want the actions for each profile tag for every user and
@ -35,6 +38,8 @@ class ActionGenerator:
# event stream, so we just run the rules for a client with no profile # event stream, so we just run the rules for a client with no profile
# tag (ie. we just need all the users). # tag (ie. we just need all the users).
async def handle_push_actions_for_event(self, event, context): async def handle_push_actions_for_event(
self, event: EventBase, context: EventContext
) -> None:
with Measure(self.clock, "action_for_event_by_user"): with Measure(self.clock, "action_for_event_by_user"):
await self.bulk_evaluator.action_for_event_by_user(event, context) await self.bulk_evaluator.action_for_event_by_user(event, context)

View File

@ -15,16 +15,19 @@
# limitations under the License. # limitations under the License.
import copy import copy
from typing import Any, Dict, List
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
def list_with_base_rules(rawrules, use_new_defaults=False): def list_with_base_rules(
rawrules: List[Dict[str, Any]], use_new_defaults: bool = False
) -> List[Dict[str, Any]]:
"""Combine the list of rules set by the user with the default push rules """Combine the list of rules set by the user with the default push rules
Args: Args:
rawrules(list): The rules the user has modified or set. rawrules: The rules the user has modified or set.
use_new_defaults(bool): Whether to use the new experimental default rules when use_new_defaults: Whether to use the new experimental default rules when
appending or prepending default rules. appending or prepending default rules.
Returns: Returns:
@ -94,7 +97,11 @@ def list_with_base_rules(rawrules, use_new_defaults=False):
return ruleslist return ruleslist
def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False): def make_base_append_rules(
kind: str,
modified_base_rules: Dict[str, Dict[str, Any]],
use_new_defaults: bool = False,
) -> List[Dict[str, Any]]:
rules = [] rules = []
if kind == "override": if kind == "override":
@ -116,6 +123,7 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
rules = copy.deepcopy(rules) rules = copy.deepcopy(rules)
for r in rules: for r in rules:
# Only modify the actions, keep the conditions the same. # Only modify the actions, keep the conditions the same.
assert isinstance(r["rule_id"], str)
modified = modified_base_rules.get(r["rule_id"]) modified = modified_base_rules.get(r["rule_id"])
if modified: if modified:
r["actions"] = modified["actions"] r["actions"] = modified["actions"]
@ -123,7 +131,11 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
return rules return rules
def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False): def make_base_prepend_rules(
kind: str,
modified_base_rules: Dict[str, Dict[str, Any]],
use_new_defaults: bool = False,
) -> List[Dict[str, Any]]:
rules = [] rules = []
if kind == "override": if kind == "override":
@ -133,6 +145,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
rules = copy.deepcopy(rules) rules = copy.deepcopy(rules)
for r in rules: for r in rules:
# Only modify the actions, keep the conditions the same. # Only modify the actions, keep the conditions the same.
assert isinstance(r["rule_id"], str)
modified = modified_base_rules.get(r["rule_id"]) modified = modified_base_rules.get(r["rule_id"])
if modified: if modified:
r["actions"] = modified["actions"] r["actions"] = modified["actions"]

View File

@ -15,6 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
@ -25,18 +26,18 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache from synapse.util.caches.descriptors import lru_cache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from .push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
rules_by_room = {}
push_rules_invalidation_counter = Counter( push_rules_invalidation_counter = Counter(
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "" "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
) )
@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
room at once. room at once.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
resizable=False, resizable=False,
) )
async def _get_rules_for_event(self, event, context): async def _get_rules_for_event(
self, event: EventBase, context: EventContext
) -> Dict[str, List[Dict[str, Any]]]:
"""This gets the rules for all users in the room at the time of the event, """This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite. as well as the push rules for the invitee if the event is an invite.
@ -140,11 +143,8 @@ class BulkPushRuleEvaluator:
return rules_by_user return rules_by_user
@lru_cache() @lru_cache()
def _get_rules_for_room(self, room_id): def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
"""Get the current RulesForRoom object for the given room id """Get the current RulesForRoom object for the given room id
Returns:
RulesForRoom
""" """
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be # before any lookup methods get called on it as otherwise there may be
@ -156,20 +156,21 @@ class BulkPushRuleEvaluator:
self.room_push_rule_cache_metrics, self.room_push_rule_cache_metrics,
) )
async def _get_power_levels_and_sender_level(self, event, context): async def _get_power_levels_and_sender_level(
self, event: EventBase, context: EventContext
) -> Tuple[dict, int]:
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY) pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id: if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and # fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case # not having a power level event is an extreme edge case
pl_event = await self.store.get_event(pl_event_id) auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
auth_events = {POWER_KEY: pl_event}
else: else:
auth_events_ids = self.auth.compute_auth_events( auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False event, prev_state_ids, for_verification=False
) )
auth_events = await self.store.get_events(auth_events_ids) auth_events_dict = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()} auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
sender_level = get_user_power_level(event.sender, auth_events) sender_level = get_user_power_level(event.sender, auth_events)
@ -177,7 +178,9 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level return pl_event.content if pl_event else {}, sender_level
async def action_for_event_by_user(self, event, context) -> None: async def action_for_event_by_user(
self, event: EventBase, context: EventContext
) -> None:
"""Given an event and context, evaluate the push rules, check if the message """Given an event and context, evaluate the push rules, check if the message
should increment the unread count, and insert the results into the should increment the unread count, and insert the results into the
event_push_actions_staging table. event_push_actions_staging table.
@ -185,7 +188,7 @@ class BulkPushRuleEvaluator:
count_as_unread = _should_count_as_unread(event, context) count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context) rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {} actions_by_user = {} # type: Dict[str, List[Union[dict, str]]]
room_members = await self.store.get_joined_users_from_context(event, context) room_members = await self.store.get_joined_users_from_context(event, context)
@ -198,7 +201,7 @@ class BulkPushRuleEvaluator:
event, len(room_members), sender_power_level, power_levels event, len(room_members), sender_power_level, power_levels
) )
condition_cache = {} condition_cache = {} # type: Dict[str, bool]
for uid, rules in rules_by_user.items(): for uid, rules in rules_by_user.items():
if event.sender == uid: if event.sender == uid:
@ -249,7 +252,13 @@ class BulkPushRuleEvaluator:
) )
def _condition_checker(evaluator, conditions, uid, display_name, cache): def _condition_checker(
evaluator: PushRuleEvaluatorForEvent,
conditions: List[dict],
uid: str,
display_name: str,
cache: Dict[str, bool],
) -> bool:
for cond in conditions: for cond in conditions:
_id = cond.get("_id", None) _id = cond.get("_id", None)
if _id: if _id:
@ -277,15 +286,19 @@ class RulesForRoom:
""" """
def __init__( def __init__(
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics self,
hs: "HomeServer",
room_id: str,
rules_for_room_cache: LruCache,
room_push_rule_cache_metrics: CacheMetric,
): ):
""" """
Args: Args:
hs (HomeServer) hs: The HomeServer object.
room_id (str) room_id: The room ID.
rules_for_room_cache: The cache object that caches these rules_for_room_cache: The cache object that caches these
RoomsForUser objects. RoomsForUser objects.
room_push_rule_cache_metrics (CacheMetric) room_push_rule_cache_metrics: The metrics object
""" """
self.room_id = room_id self.room_id = room_id
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
@ -294,8 +307,10 @@ class RulesForRoom:
self.linearizer = Linearizer(name="rules_for_room") self.linearizer = Linearizer(name="rules_for_room")
self.member_map = {} # event_id -> (user_id, state) # event_id -> (user_id, state)
self.rules_by_user = {} # user_id -> rules self.member_map = {} # type: Dict[str, Tuple[str, str]]
# user_id -> rules
self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
# The last state group we updated the caches for. If the state_group of # The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached # a new event comes along, we know that we can just return the cached
@ -315,7 +330,7 @@ class RulesForRoom:
# calculate push for) # calculate push for)
# These never need to be invalidated as we will never set up push for # These never need to be invalidated as we will never set up push for
# them. # them.
self.uninteresting_user_set = set() self.uninteresting_user_set = set() # type: Set[str]
# We need to be clever on the invalidating caches callbacks, as # We need to be clever on the invalidating caches callbacks, as
# otherwise the invalidation callback holds a reference to the object, # otherwise the invalidation callback holds a reference to the object,
@ -325,7 +340,9 @@ class RulesForRoom:
# to self around in the callback. # to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
async def get_rules(self, event, context): async def get_rules(
self, event: EventBase, context: EventContext
) -> Dict[str, List[Dict[str, dict]]]:
"""Given an event context return the rules for all users who are """Given an event context return the rules for all users who are
currently in the room. currently in the room.
""" """
@ -356,6 +373,8 @@ class RulesForRoom:
else: else:
current_state_ids = await context.get_current_state_ids() current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses() push_rules_delta_state_cache_metric.inc_misses()
# Ensure the state IDs exist.
assert current_state_ids is not None
push_rules_state_size_counter.inc(len(current_state_ids)) push_rules_state_size_counter.inc(len(current_state_ids))
@ -420,18 +439,23 @@ class RulesForRoom:
return ret_rules_by_user return ret_rules_by_user
async def _update_rules_with_member_event_ids( async def _update_rules_with_member_event_ids(
self, ret_rules_by_user, member_event_ids, state_group, event self,
): ret_rules_by_user: Dict[str, list],
member_event_ids: Dict[str, str],
state_group: Optional[int],
event: EventBase,
) -> None:
"""Update the partially filled rules_by_user dict by fetching rules for """Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list. any newly joined users in the `member_event_ids` list.
Args: Args:
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets ret_rules_by_user: Partially filled dict of push rules. Gets
updated with any new rules. updated with any new rules.
member_event_ids (dict): Dict of user id to event id for membership events member_event_ids: Dict of user id to event id for membership events
that have happened since the last time we filled rules_by_user that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules state_group: The state group we are currently computing push rules
for. Used when updating the cache. for. Used when updating the cache.
event: The event we are currently computing push rules for.
""" """
sequence = self.sequence sequence = self.sequence
@ -449,19 +473,19 @@ class RulesForRoom:
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values()) logger.debug("Found members %r: %r", self.room_id, members.values())
user_ids = { joined_user_ids = {
user_id user_id
for user_id, membership in members.values() for user_id, membership in members.values()
if membership == Membership.JOIN if membership == Membership.JOIN
} }
logger.debug("Joined: %r", user_ids) logger.debug("Joined: %r", joined_user_ids)
# Previously we only considered users with pushers or read receipts in that # Previously we only considered users with pushers or read receipts in that
# room. We can't do this anymore because we use push actions to calculate unread # room. We can't do this anymore because we use push actions to calculate unread
# counts, which don't rely on the user having pushers or sent a read receipt into # counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here. # the room. Therefore we just need to filter for local users here.
user_ids = list(filter(self.is_mine_id, user_ids)) user_ids = list(filter(self.is_mine_id, joined_user_ids))
rules_by_user = await self.store.bulk_get_push_rules( rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb user_ids, on_invalidate=self.invalidate_all_cb
@ -473,7 +497,7 @@ class RulesForRoom:
self.update_cache(sequence, members, ret_rules_by_user, state_group) self.update_cache(sequence, members, ret_rules_by_user, state_group)
def invalidate_all(self): def invalidate_all(self) -> None:
# Note: Don't hand this function directly to an invalidation callback # Note: Don't hand this function directly to an invalidation callback
# as it keeps a reference to self and will stop this instance from being # as it keeps a reference to self and will stop this instance from being
# GC'd if it gets dropped from the rules_to_user cache. Instead use # GC'd if it gets dropped from the rules_to_user cache. Instead use
@ -485,7 +509,7 @@ class RulesForRoom:
self.rules_by_user = {} self.rules_by_user = {}
push_rules_invalidation_counter.inc() push_rules_invalidation_counter.inc()
def update_cache(self, sequence, members, rules_by_user, state_group): def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
if sequence == self.sequence: if sequence == self.sequence:
self.member_map.update(members) self.member_map.update(members)
self.rules_by_user = rules_by_user self.rules_by_user = rules_by_user
@ -506,7 +530,7 @@ class _Invalidation:
cache = attr.ib(type=LruCache) cache = attr.ib(type=LruCache)
room_id = attr.ib(type=str) room_id = attr.ib(type=str)
def __call__(self): def __call__(self) -> None:
rules = self.cache.get(self.room_id, None, update_metrics=False) rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules: if rules:
rules.invalidate_all() rules.invalidate_all()

View File

@ -14,24 +14,27 @@
# limitations under the License. # limitations under the License.
import copy import copy
from typing import Any, Dict, List, Optional
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
from synapse.types import UserID
def format_push_rules_for_user(user, ruleslist): def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
"""Converts a list of rawrules and a enabled map into nested dictionaries """Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules""" to match the Matrix client-server format for push rules"""
# We're going to be mutating this a lot, so do a deep copy # We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(ruleslist) ruleslist = copy.deepcopy(ruleslist)
rules = {"global": {}, "device": {}} rules = {
"global": {},
"device": {},
} # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
rules["global"] = _add_empty_priority_class_arrays(rules["global"]) rules["global"] = _add_empty_priority_class_arrays(rules["global"])
for r in ruleslist: for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r["priority_class"]) template_name = _priority_class_to_template_name(r["priority_class"])
# Remove internal stuff. # Remove internal stuff.
@ -57,13 +60,13 @@ def format_push_rules_for_user(user, ruleslist):
return rules return rules
def _add_empty_priority_class_arrays(d): def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
for pc in PRIORITY_CLASS_MAP.keys(): for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = [] d[pc] = []
return d return d
def _rule_to_template(rule): def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
unscoped_rule_id = None unscoped_rule_id = None
if "rule_id" in rule: if "rule_id" in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"]) unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
@ -82,6 +85,10 @@ def _rule_to_template(rule):
return None return None
templaterule = {"actions": rule["actions"]} templaterule = {"actions": rule["actions"]}
templaterule["pattern"] = thecond["pattern"] templaterule["pattern"] = thecond["pattern"]
else:
# This should not be reached unless this function is not kept in sync
# with PRIORITY_CLASS_INVERSE_MAP.
raise ValueError("Unexpected template_name: %s" % (template_name,))
if unscoped_rule_id: if unscoped_rule_id:
templaterule["rule_id"] = unscoped_rule_id templaterule["rule_id"] = unscoped_rule_id
@ -90,9 +97,9 @@ def _rule_to_template(rule):
return templaterule return templaterule
def _rule_id_from_namespaced(in_rule_id): def _rule_id_from_namespaced(in_rule_id: str) -> str:
return in_rule_id.split("/")[-1] return in_rule_id.split("/")[-1]
def _priority_class_to_template_name(pc): def _priority_class_to_template_name(pc: int) -> str:
return PRIORITY_CLASS_INVERSE_MAP[pc] return PRIORITY_CLASS_INVERSE_MAP[pc]

View File

@ -15,8 +15,14 @@
import logging import logging
import re import re
from typing import TYPE_CHECKING, Dict, Iterable, Optional
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import StateMap
if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,25 +34,29 @@ ALL_ALONE = "Empty Room"
async def calculate_room_name( async def calculate_room_name(
store, store: "DataStore",
room_state_ids, room_state_ids: StateMap[str],
user_id, user_id: str,
fallback_to_members=True, fallback_to_members: bool = True,
fallback_to_single_member=True, fallback_to_single_member: bool = True,
): ) -> Optional[str]:
""" """
Works out a user-facing name for the given room as per Matrix Works out a user-facing name for the given room as per Matrix
spec recommendations. spec recommendations.
Does not yet support internationalisation. Does not yet support internationalisation.
Args: Args:
room_state: Dictionary of the room's state store: The data store to query.
room_state_ids: Dictionary of the room's state IDs.
user_id: The ID of the user to whom the room name is being presented user_id: The ID of the user to whom the room name is being presented
fallback_to_members: If False, return None instead of generating a name fallback_to_members: If False, return None instead of generating a name
based on the room's members if the room has no based on the room's members if the room has no
title or aliases. title or aliases.
fallback_to_single_member: If False, return None instead of generating a
name based on the user who invited this user to the room if the room
has no title or aliases.
Returns: Returns:
(string or None) A human readable name for the room. A human readable name for the room, if possible.
""" """
# does it have a name? # does it have a name?
if (EventTypes.Name, "") in room_state_ids: if (EventTypes.Name, "") in room_state_ids:
@ -97,7 +107,7 @@ async def calculate_room_name(
name_from_member_event(inviter_member_event), name_from_member_event(inviter_member_event),
) )
else: else:
return return None
else: else:
return "Room Invite" return "Room Invite"
@ -150,19 +160,19 @@ async def calculate_room_name(
else: else:
return ALL_ALONE return ALL_ALONE
elif len(other_members) == 1 and not fallback_to_single_member: elif len(other_members) == 1 and not fallback_to_single_member:
return return None
else:
return descriptor_from_member_events(other_members) return descriptor_from_member_events(other_members)
def descriptor_from_member_events(member_events): def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
"""Get a description of the room based on the member events. """Get a description of the room based on the member events.
Args: Args:
member_events (Iterable[FrozenEvent]) member_events: The events of a room.
Returns: Returns:
str The room description
""" """
member_events = list(member_events) member_events = list(member_events)
@ -183,7 +193,7 @@ def descriptor_from_member_events(member_events):
) )
def name_from_member_event(member_event): def name_from_member_event(member_event: EventBase) -> str:
if ( if (
member_event.content member_event.content
and "displayname" in member_event.content and "displayname" in member_event.content
@ -193,12 +203,12 @@ def name_from_member_event(member_event):
return member_event.state_key return member_event.state_key
def _state_as_two_level_dict(state): def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
ret = {} ret = {} # type: Dict[str, Dict[str, str]]
for k, v in state.items(): for k, v in state.items():
ret.setdefault(k[0], {})[k[1]] = v ret.setdefault(k[0], {})[k[1]] = v
return ret return ret
def _looks_like_an_alias(string): def _looks_like_an_alias(string: str) -> bool:
return ALIAS_RE.match(string) is not None return ALIAS_RE.match(string) is not None

View File

@ -30,22 +30,30 @@ IS_GLOB = re.compile(r"[\?\*\[\]]")
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def _room_member_count(ev, condition, room_member_count): def _room_member_count(
ev: EventBase, condition: Dict[str, Any], room_member_count: int
) -> bool:
return _test_ineq_condition(condition, room_member_count) return _test_ineq_condition(condition, room_member_count)
def _sender_notification_permission(ev, condition, sender_power_level, power_levels): def _sender_notification_permission(
ev: EventBase,
condition: Dict[str, Any],
sender_power_level: int,
power_levels: Dict[str, Union[int, Dict[str, int]]],
) -> bool:
notif_level_key = condition.get("key") notif_level_key = condition.get("key")
if notif_level_key is None: if notif_level_key is None:
return False return False
notif_levels = power_levels.get("notifications", {}) notif_levels = power_levels.get("notifications", {})
assert isinstance(notif_levels, dict)
room_notif_level = notif_levels.get(notif_level_key, 50) room_notif_level = notif_levels.get(notif_level_key, 50)
return sender_power_level >= room_notif_level return sender_power_level >= room_notif_level
def _test_ineq_condition(condition, number): def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
if "is" not in condition: if "is" not in condition:
return False return False
m = INEQUALITY_EXPR.match(condition["is"]) m = INEQUALITY_EXPR.match(condition["is"])
@ -110,7 +118,7 @@ class PushRuleEvaluatorForEvent:
event: EventBase, event: EventBase,
room_member_count: int, room_member_count: int,
sender_power_level: int, sender_power_level: int,
power_levels: dict, power_levels: Dict[str, Union[int, Dict[str, int]]],
): ):
self._event = event self._event = event
self._room_member_count = room_member_count self._room_member_count = room_member_count
@ -120,7 +128,9 @@ class PushRuleEvaluatorForEvent:
# Maps strings of e.g. 'content.body' -> event["content"]["body"] # Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event) self._value_cache = _flatten_dict(event)
def matches(self, condition: dict, user_id: str, display_name: str) -> bool: def matches(
self, condition: Dict[str, Any], user_id: str, display_name: str
) -> bool:
if condition["kind"] == "event_match": if condition["kind"] == "event_match":
return self._event_match(condition, user_id) return self._event_match(condition, user_id)
elif condition["kind"] == "contains_display_name": elif condition["kind"] == "contains_display_name":
@ -261,7 +271,13 @@ def _re_word_boundary(r: str) -> str:
return r"(^|\W)%s(\W|$)" % (r,) return r"(^|\W)%s(\W|$)" % (r,)
def _flatten_dict(d, prefix=[], result=None): def _flatten_dict(
d: Union[EventBase, dict],
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
if prefix is None:
prefix = []
if result is None: if result is None:
result = {} result = {}
for key, value in d.items(): for key, value in d.items():