mirror of
				https://git.anonymousland.org/anonymousland/synapse.git
				synced 2025-11-03 23:24:17 -05:00 
			
		
		
		
	Add type hints to the push module. (#8901)
This commit is contained in:
		
							parent
							
								
									a8eceb01e5
								
							
						
					
					
						commit
						5d34f40d49
					
				
					 9 changed files with 159 additions and 86 deletions
				
			
		
							
								
								
									
										1
									
								
								changelog.d/8901.misc
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								changelog.d/8901.misc
									
										
									
									
									
										Normal file
									
								
							| 
						 | 
				
			
			@ -0,0 +1 @@
 | 
			
		|||
Add type hints to push module.
 | 
			
		||||
							
								
								
									
										7
									
								
								mypy.ini
									
										
									
									
									
								
							
							
						
						
									
										7
									
								
								mypy.ini
									
										
									
									
									
								
							| 
						 | 
				
			
			@ -56,12 +56,7 @@ files =
 | 
			
		|||
  synapse/metrics,
 | 
			
		||||
  synapse/module_api,
 | 
			
		||||
  synapse/notifier.py,
 | 
			
		||||
  synapse/push/emailpusher.py,
 | 
			
		||||
  synapse/push/httppusher.py,
 | 
			
		||||
  synapse/push/mailer.py,
 | 
			
		||||
  synapse/push/pusher.py,
 | 
			
		||||
  synapse/push/pusherpool.py,
 | 
			
		||||
  synapse/push/push_rule_evaluator.py,
 | 
			
		||||
  synapse/push,
 | 
			
		||||
  synapse/replication,
 | 
			
		||||
  synapse/rest,
 | 
			
		||||
  synapse/server.py,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -31,6 +31,8 @@ class SynapsePlugin(Plugin):
 | 
			
		|||
    ) -> Optional[Callable[[MethodSigContext], CallableType]]:
 | 
			
		||||
        if fullname.startswith(
 | 
			
		||||
            "synapse.util.caches.descriptors._CachedFunction.__call__"
 | 
			
		||||
        ) or fullname.startswith(
 | 
			
		||||
            "synapse.util.caches.descriptors._LruCachedFunction.__call__"
 | 
			
		||||
        ):
 | 
			
		||||
            return cached_function_method_signature
 | 
			
		||||
        return None
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,19 +14,22 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
from typing import TYPE_CHECKING
 | 
			
		||||
 | 
			
		||||
from synapse.events import EventBase
 | 
			
		||||
from synapse.events.snapshot import EventContext
 | 
			
		||||
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
 | 
			
		||||
from synapse.util.metrics import Measure
 | 
			
		||||
 | 
			
		||||
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from synapse.app.homeserver import HomeServer
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ActionGenerator:
 | 
			
		||||
    def __init__(self, hs):
 | 
			
		||||
        self.hs = hs
 | 
			
		||||
    def __init__(self, hs: "HomeServer"):
 | 
			
		||||
        self.clock = hs.get_clock()
 | 
			
		||||
        self.store = hs.get_datastore()
 | 
			
		||||
        self.bulk_evaluator = BulkPushRuleEvaluator(hs)
 | 
			
		||||
        # really we want to get all user ids and all profile tags too,
 | 
			
		||||
        # since we want the actions for each profile tag for every user and
 | 
			
		||||
| 
						 | 
				
			
			@ -35,6 +38,8 @@ class ActionGenerator:
 | 
			
		|||
        # event stream, so we just run the rules for a client with no profile
 | 
			
		||||
        # tag (ie. we just need all the users).
 | 
			
		||||
 | 
			
		||||
    async def handle_push_actions_for_event(self, event, context):
 | 
			
		||||
    async def handle_push_actions_for_event(
 | 
			
		||||
        self, event: EventBase, context: EventContext
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        with Measure(self.clock, "action_for_event_by_user"):
 | 
			
		||||
            await self.bulk_evaluator.action_for_event_by_user(event, context)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,16 +15,19 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import copy
 | 
			
		||||
from typing import Any, Dict, List
 | 
			
		||||
 | 
			
		||||
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def list_with_base_rules(rawrules, use_new_defaults=False):
 | 
			
		||||
def list_with_base_rules(
 | 
			
		||||
    rawrules: List[Dict[str, Any]], use_new_defaults: bool = False
 | 
			
		||||
) -> List[Dict[str, Any]]:
 | 
			
		||||
    """Combine the list of rules set by the user with the default push rules
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        rawrules(list): The rules the user has modified or set.
 | 
			
		||||
        use_new_defaults(bool): Whether to use the new experimental default rules when
 | 
			
		||||
        rawrules: The rules the user has modified or set.
 | 
			
		||||
        use_new_defaults: Whether to use the new experimental default rules when
 | 
			
		||||
            appending or prepending default rules.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
| 
						 | 
				
			
			@ -94,7 +97,11 @@ def list_with_base_rules(rawrules, use_new_defaults=False):
 | 
			
		|||
    return ruleslist
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
 | 
			
		||||
def make_base_append_rules(
 | 
			
		||||
    kind: str,
 | 
			
		||||
    modified_base_rules: Dict[str, Dict[str, Any]],
 | 
			
		||||
    use_new_defaults: bool = False,
 | 
			
		||||
) -> List[Dict[str, Any]]:
 | 
			
		||||
    rules = []
 | 
			
		||||
 | 
			
		||||
    if kind == "override":
 | 
			
		||||
| 
						 | 
				
			
			@ -116,6 +123,7 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
 | 
			
		|||
    rules = copy.deepcopy(rules)
 | 
			
		||||
    for r in rules:
 | 
			
		||||
        # Only modify the actions, keep the conditions the same.
 | 
			
		||||
        assert isinstance(r["rule_id"], str)
 | 
			
		||||
        modified = modified_base_rules.get(r["rule_id"])
 | 
			
		||||
        if modified:
 | 
			
		||||
            r["actions"] = modified["actions"]
 | 
			
		||||
| 
						 | 
				
			
			@ -123,7 +131,11 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
 | 
			
		|||
    return rules
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
 | 
			
		||||
def make_base_prepend_rules(
 | 
			
		||||
    kind: str,
 | 
			
		||||
    modified_base_rules: Dict[str, Dict[str, Any]],
 | 
			
		||||
    use_new_defaults: bool = False,
 | 
			
		||||
) -> List[Dict[str, Any]]:
 | 
			
		||||
    rules = []
 | 
			
		||||
 | 
			
		||||
    if kind == "override":
 | 
			
		||||
| 
						 | 
				
			
			@ -133,6 +145,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
 | 
			
		|||
    rules = copy.deepcopy(rules)
 | 
			
		||||
    for r in rules:
 | 
			
		||||
        # Only modify the actions, keep the conditions the same.
 | 
			
		||||
        assert isinstance(r["rule_id"], str)
 | 
			
		||||
        modified = modified_base_rules.get(r["rule_id"])
 | 
			
		||||
        if modified:
 | 
			
		||||
            r["actions"] = modified["actions"]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,6 +15,7 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
 | 
			
		||||
 | 
			
		||||
import attr
 | 
			
		||||
from prometheus_client import Counter
 | 
			
		||||
| 
						 | 
				
			
			@ -25,18 +26,18 @@ from synapse.events import EventBase
 | 
			
		|||
from synapse.events.snapshot import EventContext
 | 
			
		||||
from synapse.state import POWER_KEY
 | 
			
		||||
from synapse.util.async_helpers import Linearizer
 | 
			
		||||
from synapse.util.caches import register_cache
 | 
			
		||||
from synapse.util.caches import CacheMetric, register_cache
 | 
			
		||||
from synapse.util.caches.descriptors import lru_cache
 | 
			
		||||
from synapse.util.caches.lrucache import LruCache
 | 
			
		||||
 | 
			
		||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from synapse.app.homeserver import HomeServer
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
rules_by_room = {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
push_rules_invalidation_counter = Counter(
 | 
			
		||||
    "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
 | 
			
		||||
)
 | 
			
		||||
| 
						 | 
				
			
			@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
 | 
			
		|||
    room at once.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hs):
 | 
			
		||||
    def __init__(self, hs: "HomeServer"):
 | 
			
		||||
        self.hs = hs
 | 
			
		||||
        self.store = hs.get_datastore()
 | 
			
		||||
        self.auth = hs.get_auth()
 | 
			
		||||
| 
						 | 
				
			
			@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
 | 
			
		|||
            resizable=False,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def _get_rules_for_event(self, event, context):
 | 
			
		||||
    async def _get_rules_for_event(
 | 
			
		||||
        self, event: EventBase, context: EventContext
 | 
			
		||||
    ) -> Dict[str, List[Dict[str, Any]]]:
 | 
			
		||||
        """This gets the rules for all users in the room at the time of the event,
 | 
			
		||||
        as well as the push rules for the invitee if the event is an invite.
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -140,11 +143,8 @@ class BulkPushRuleEvaluator:
 | 
			
		|||
        return rules_by_user
 | 
			
		||||
 | 
			
		||||
    @lru_cache()
 | 
			
		||||
    def _get_rules_for_room(self, room_id):
 | 
			
		||||
    def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
 | 
			
		||||
        """Get the current RulesForRoom object for the given room id
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            RulesForRoom
 | 
			
		||||
        """
 | 
			
		||||
        # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
 | 
			
		||||
        # before any lookup methods get called on it as otherwise there may be
 | 
			
		||||
| 
						 | 
				
			
			@ -156,20 +156,21 @@ class BulkPushRuleEvaluator:
 | 
			
		|||
            self.room_push_rule_cache_metrics,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    async def _get_power_levels_and_sender_level(self, event, context):
 | 
			
		||||
    async def _get_power_levels_and_sender_level(
 | 
			
		||||
        self, event: EventBase, context: EventContext
 | 
			
		||||
    ) -> Tuple[dict, int]:
 | 
			
		||||
        prev_state_ids = await context.get_prev_state_ids()
 | 
			
		||||
        pl_event_id = prev_state_ids.get(POWER_KEY)
 | 
			
		||||
        if pl_event_id:
 | 
			
		||||
            # fastpath: if there's a power level event, that's all we need, and
 | 
			
		||||
            # not having a power level event is an extreme edge case
 | 
			
		||||
            pl_event = await self.store.get_event(pl_event_id)
 | 
			
		||||
            auth_events = {POWER_KEY: pl_event}
 | 
			
		||||
            auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
 | 
			
		||||
        else:
 | 
			
		||||
            auth_events_ids = self.auth.compute_auth_events(
 | 
			
		||||
                event, prev_state_ids, for_verification=False
 | 
			
		||||
            )
 | 
			
		||||
            auth_events = await self.store.get_events(auth_events_ids)
 | 
			
		||||
            auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
 | 
			
		||||
            auth_events_dict = await self.store.get_events(auth_events_ids)
 | 
			
		||||
            auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
 | 
			
		||||
 | 
			
		||||
        sender_level = get_user_power_level(event.sender, auth_events)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -177,7 +178,9 @@ class BulkPushRuleEvaluator:
 | 
			
		|||
 | 
			
		||||
        return pl_event.content if pl_event else {}, sender_level
 | 
			
		||||
 | 
			
		||||
    async def action_for_event_by_user(self, event, context) -> None:
 | 
			
		||||
    async def action_for_event_by_user(
 | 
			
		||||
        self, event: EventBase, context: EventContext
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Given an event and context, evaluate the push rules, check if the message
 | 
			
		||||
        should increment the unread count, and insert the results into the
 | 
			
		||||
        event_push_actions_staging table.
 | 
			
		||||
| 
						 | 
				
			
			@ -185,7 +188,7 @@ class BulkPushRuleEvaluator:
 | 
			
		|||
        count_as_unread = _should_count_as_unread(event, context)
 | 
			
		||||
 | 
			
		||||
        rules_by_user = await self._get_rules_for_event(event, context)
 | 
			
		||||
        actions_by_user = {}
 | 
			
		||||
        actions_by_user = {}  # type: Dict[str, List[Union[dict, str]]]
 | 
			
		||||
 | 
			
		||||
        room_members = await self.store.get_joined_users_from_context(event, context)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -198,7 +201,7 @@ class BulkPushRuleEvaluator:
 | 
			
		|||
            event, len(room_members), sender_power_level, power_levels
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        condition_cache = {}
 | 
			
		||||
        condition_cache = {}  # type: Dict[str, bool]
 | 
			
		||||
 | 
			
		||||
        for uid, rules in rules_by_user.items():
 | 
			
		||||
            if event.sender == uid:
 | 
			
		||||
| 
						 | 
				
			
			@ -249,7 +252,13 @@ class BulkPushRuleEvaluator:
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _condition_checker(evaluator, conditions, uid, display_name, cache):
 | 
			
		||||
def _condition_checker(
 | 
			
		||||
    evaluator: PushRuleEvaluatorForEvent,
 | 
			
		||||
    conditions: List[dict],
 | 
			
		||||
    uid: str,
 | 
			
		||||
    display_name: str,
 | 
			
		||||
    cache: Dict[str, bool],
 | 
			
		||||
) -> bool:
 | 
			
		||||
    for cond in conditions:
 | 
			
		||||
        _id = cond.get("_id", None)
 | 
			
		||||
        if _id:
 | 
			
		||||
| 
						 | 
				
			
			@ -277,15 +286,19 @@ class RulesForRoom:
 | 
			
		|||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
 | 
			
		||||
        self,
 | 
			
		||||
        hs: "HomeServer",
 | 
			
		||||
        room_id: str,
 | 
			
		||||
        rules_for_room_cache: LruCache,
 | 
			
		||||
        room_push_rule_cache_metrics: CacheMetric,
 | 
			
		||||
    ):
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            hs (HomeServer)
 | 
			
		||||
            room_id (str)
 | 
			
		||||
            hs: The HomeServer object.
 | 
			
		||||
            room_id: The room ID.
 | 
			
		||||
            rules_for_room_cache: The cache object that caches these
 | 
			
		||||
                RoomsForUser objects.
 | 
			
		||||
            room_push_rule_cache_metrics (CacheMetric)
 | 
			
		||||
            room_push_rule_cache_metrics: The metrics object
 | 
			
		||||
        """
 | 
			
		||||
        self.room_id = room_id
 | 
			
		||||
        self.is_mine_id = hs.is_mine_id
 | 
			
		||||
| 
						 | 
				
			
			@ -294,8 +307,10 @@ class RulesForRoom:
 | 
			
		|||
 | 
			
		||||
        self.linearizer = Linearizer(name="rules_for_room")
 | 
			
		||||
 | 
			
		||||
        self.member_map = {}  # event_id -> (user_id, state)
 | 
			
		||||
        self.rules_by_user = {}  # user_id -> rules
 | 
			
		||||
        # event_id -> (user_id, state)
 | 
			
		||||
        self.member_map = {}  # type: Dict[str, Tuple[str, str]]
 | 
			
		||||
        # user_id -> rules
 | 
			
		||||
        self.rules_by_user = {}  # type: Dict[str, List[Dict[str, dict]]]
 | 
			
		||||
 | 
			
		||||
        # The last state group we updated the caches for. If the state_group of
 | 
			
		||||
        # a new event comes along, we know that we can just return the cached
 | 
			
		||||
| 
						 | 
				
			
			@ -315,7 +330,7 @@ class RulesForRoom:
 | 
			
		|||
        # calculate push for)
 | 
			
		||||
        # These never need to be invalidated as we will never set up push for
 | 
			
		||||
        # them.
 | 
			
		||||
        self.uninteresting_user_set = set()
 | 
			
		||||
        self.uninteresting_user_set = set()  # type: Set[str]
 | 
			
		||||
 | 
			
		||||
        # We need to be clever on the invalidating caches callbacks, as
 | 
			
		||||
        # otherwise the invalidation callback holds a reference to the object,
 | 
			
		||||
| 
						 | 
				
			
			@ -325,7 +340,9 @@ class RulesForRoom:
 | 
			
		|||
        # to self around in the callback.
 | 
			
		||||
        self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
 | 
			
		||||
 | 
			
		||||
    async def get_rules(self, event, context):
 | 
			
		||||
    async def get_rules(
 | 
			
		||||
        self, event: EventBase, context: EventContext
 | 
			
		||||
    ) -> Dict[str, List[Dict[str, dict]]]:
 | 
			
		||||
        """Given an event context return the rules for all users who are
 | 
			
		||||
        currently in the room.
 | 
			
		||||
        """
 | 
			
		||||
| 
						 | 
				
			
			@ -356,6 +373,8 @@ class RulesForRoom:
 | 
			
		|||
            else:
 | 
			
		||||
                current_state_ids = await context.get_current_state_ids()
 | 
			
		||||
                push_rules_delta_state_cache_metric.inc_misses()
 | 
			
		||||
            # Ensure the state IDs exist.
 | 
			
		||||
            assert current_state_ids is not None
 | 
			
		||||
 | 
			
		||||
            push_rules_state_size_counter.inc(len(current_state_ids))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -420,18 +439,23 @@ class RulesForRoom:
 | 
			
		|||
        return ret_rules_by_user
 | 
			
		||||
 | 
			
		||||
    async def _update_rules_with_member_event_ids(
 | 
			
		||||
        self, ret_rules_by_user, member_event_ids, state_group, event
 | 
			
		||||
    ):
 | 
			
		||||
        self,
 | 
			
		||||
        ret_rules_by_user: Dict[str, list],
 | 
			
		||||
        member_event_ids: Dict[str, str],
 | 
			
		||||
        state_group: Optional[int],
 | 
			
		||||
        event: EventBase,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """Update the partially filled rules_by_user dict by fetching rules for
 | 
			
		||||
        any newly joined users in the `member_event_ids` list.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
 | 
			
		||||
            ret_rules_by_user: Partially filled dict of push rules. Gets
 | 
			
		||||
                updated with any new rules.
 | 
			
		||||
            member_event_ids (dict): Dict of user id to event id for membership events
 | 
			
		||||
            member_event_ids: Dict of user id to event id for membership events
 | 
			
		||||
                that have happened since the last time we filled rules_by_user
 | 
			
		||||
            state_group: The state group we are currently computing push rules
 | 
			
		||||
                for. Used when updating the cache.
 | 
			
		||||
            event: The event we are currently computing push rules for.
 | 
			
		||||
        """
 | 
			
		||||
        sequence = self.sequence
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -449,19 +473,19 @@ class RulesForRoom:
 | 
			
		|||
        if logger.isEnabledFor(logging.DEBUG):
 | 
			
		||||
            logger.debug("Found members %r: %r", self.room_id, members.values())
 | 
			
		||||
 | 
			
		||||
        user_ids = {
 | 
			
		||||
        joined_user_ids = {
 | 
			
		||||
            user_id
 | 
			
		||||
            for user_id, membership in members.values()
 | 
			
		||||
            if membership == Membership.JOIN
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        logger.debug("Joined: %r", user_ids)
 | 
			
		||||
        logger.debug("Joined: %r", joined_user_ids)
 | 
			
		||||
 | 
			
		||||
        # Previously we only considered users with pushers or read receipts in that
 | 
			
		||||
        # room. We can't do this anymore because we use push actions to calculate unread
 | 
			
		||||
        # counts, which don't rely on the user having pushers or sent a read receipt into
 | 
			
		||||
        # the room. Therefore we just need to filter for local users here.
 | 
			
		||||
        user_ids = list(filter(self.is_mine_id, user_ids))
 | 
			
		||||
        user_ids = list(filter(self.is_mine_id, joined_user_ids))
 | 
			
		||||
 | 
			
		||||
        rules_by_user = await self.store.bulk_get_push_rules(
 | 
			
		||||
            user_ids, on_invalidate=self.invalidate_all_cb
 | 
			
		||||
| 
						 | 
				
			
			@ -473,7 +497,7 @@ class RulesForRoom:
 | 
			
		|||
 | 
			
		||||
        self.update_cache(sequence, members, ret_rules_by_user, state_group)
 | 
			
		||||
 | 
			
		||||
    def invalidate_all(self):
 | 
			
		||||
    def invalidate_all(self) -> None:
 | 
			
		||||
        # Note: Don't hand this function directly to an invalidation callback
 | 
			
		||||
        # as it keeps a reference to self and will stop this instance from being
 | 
			
		||||
        # GC'd if it gets dropped from the rules_to_user cache. Instead use
 | 
			
		||||
| 
						 | 
				
			
			@ -485,7 +509,7 @@ class RulesForRoom:
 | 
			
		|||
        self.rules_by_user = {}
 | 
			
		||||
        push_rules_invalidation_counter.inc()
 | 
			
		||||
 | 
			
		||||
    def update_cache(self, sequence, members, rules_by_user, state_group):
 | 
			
		||||
    def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
 | 
			
		||||
        if sequence == self.sequence:
 | 
			
		||||
            self.member_map.update(members)
 | 
			
		||||
            self.rules_by_user = rules_by_user
 | 
			
		||||
| 
						 | 
				
			
			@ -506,7 +530,7 @@ class _Invalidation:
 | 
			
		|||
    cache = attr.ib(type=LruCache)
 | 
			
		||||
    room_id = attr.ib(type=str)
 | 
			
		||||
 | 
			
		||||
    def __call__(self):
 | 
			
		||||
    def __call__(self) -> None:
 | 
			
		||||
        rules = self.cache.get(self.room_id, None, update_metrics=False)
 | 
			
		||||
        if rules:
 | 
			
		||||
            rules.invalidate_all()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,24 +14,27 @@
 | 
			
		|||
# limitations under the License.
 | 
			
		||||
 | 
			
		||||
import copy
 | 
			
		||||
from typing import Any, Dict, List, Optional
 | 
			
		||||
 | 
			
		||||
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
 | 
			
		||||
from synapse.types import UserID
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def format_push_rules_for_user(user, ruleslist):
 | 
			
		||||
def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
 | 
			
		||||
    """Converts a list of rawrules and a enabled map into nested dictionaries
 | 
			
		||||
    to match the Matrix client-server format for push rules"""
 | 
			
		||||
 | 
			
		||||
    # We're going to be mutating this a lot, so do a deep copy
 | 
			
		||||
    ruleslist = copy.deepcopy(ruleslist)
 | 
			
		||||
 | 
			
		||||
    rules = {"global": {}, "device": {}}
 | 
			
		||||
    rules = {
 | 
			
		||||
        "global": {},
 | 
			
		||||
        "device": {},
 | 
			
		||||
    }  # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
 | 
			
		||||
 | 
			
		||||
    rules["global"] = _add_empty_priority_class_arrays(rules["global"])
 | 
			
		||||
 | 
			
		||||
    for r in ruleslist:
 | 
			
		||||
        rulearray = None
 | 
			
		||||
 | 
			
		||||
        template_name = _priority_class_to_template_name(r["priority_class"])
 | 
			
		||||
 | 
			
		||||
        # Remove internal stuff.
 | 
			
		||||
| 
						 | 
				
			
			@ -57,13 +60,13 @@ def format_push_rules_for_user(user, ruleslist):
 | 
			
		|||
    return rules
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _add_empty_priority_class_arrays(d):
 | 
			
		||||
def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
 | 
			
		||||
    for pc in PRIORITY_CLASS_MAP.keys():
 | 
			
		||||
        d[pc] = []
 | 
			
		||||
    return d
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _rule_to_template(rule):
 | 
			
		||||
def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
 | 
			
		||||
    unscoped_rule_id = None
 | 
			
		||||
    if "rule_id" in rule:
 | 
			
		||||
        unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
 | 
			
		||||
| 
						 | 
				
			
			@ -82,6 +85,10 @@ def _rule_to_template(rule):
 | 
			
		|||
            return None
 | 
			
		||||
        templaterule = {"actions": rule["actions"]}
 | 
			
		||||
        templaterule["pattern"] = thecond["pattern"]
 | 
			
		||||
    else:
 | 
			
		||||
        # This should not be reached unless this function is not kept in sync
 | 
			
		||||
        # with PRIORITY_CLASS_INVERSE_MAP.
 | 
			
		||||
        raise ValueError("Unexpected template_name: %s" % (template_name,))
 | 
			
		||||
 | 
			
		||||
    if unscoped_rule_id:
 | 
			
		||||
        templaterule["rule_id"] = unscoped_rule_id
 | 
			
		||||
| 
						 | 
				
			
			@ -90,9 +97,9 @@ def _rule_to_template(rule):
 | 
			
		|||
    return templaterule
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _rule_id_from_namespaced(in_rule_id):
 | 
			
		||||
def _rule_id_from_namespaced(in_rule_id: str) -> str:
 | 
			
		||||
    return in_rule_id.split("/")[-1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _priority_class_to_template_name(pc):
 | 
			
		||||
def _priority_class_to_template_name(pc: int) -> str:
 | 
			
		||||
    return PRIORITY_CLASS_INVERSE_MAP[pc]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -15,8 +15,14 @@
 | 
			
		|||
 | 
			
		||||
import logging
 | 
			
		||||
import re
 | 
			
		||||
from typing import TYPE_CHECKING, Dict, Iterable, Optional
 | 
			
		||||
 | 
			
		||||
from synapse.api.constants import EventTypes
 | 
			
		||||
from synapse.events import EventBase
 | 
			
		||||
from synapse.types import StateMap
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from synapse.storage.databases.main import DataStore
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -28,25 +34,29 @@ ALL_ALONE = "Empty Room"
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
async def calculate_room_name(
 | 
			
		||||
    store,
 | 
			
		||||
    room_state_ids,
 | 
			
		||||
    user_id,
 | 
			
		||||
    fallback_to_members=True,
 | 
			
		||||
    fallback_to_single_member=True,
 | 
			
		||||
):
 | 
			
		||||
    store: "DataStore",
 | 
			
		||||
    room_state_ids: StateMap[str],
 | 
			
		||||
    user_id: str,
 | 
			
		||||
    fallback_to_members: bool = True,
 | 
			
		||||
    fallback_to_single_member: bool = True,
 | 
			
		||||
) -> Optional[str]:
 | 
			
		||||
    """
 | 
			
		||||
    Works out a user-facing name for the given room as per Matrix
 | 
			
		||||
    spec recommendations.
 | 
			
		||||
    Does not yet support internationalisation.
 | 
			
		||||
    Args:
 | 
			
		||||
        room_state: Dictionary of the room's state
 | 
			
		||||
        store: The data store to query.
 | 
			
		||||
        room_state_ids: Dictionary of the room's state IDs.
 | 
			
		||||
        user_id: The ID of the user to whom the room name is being presented
 | 
			
		||||
        fallback_to_members: If False, return None instead of generating a name
 | 
			
		||||
                             based on the room's members if the room has no
 | 
			
		||||
                             title or aliases.
 | 
			
		||||
        fallback_to_single_member: If False, return None instead of generating a
 | 
			
		||||
            name based on the user who invited this user to the room if the room
 | 
			
		||||
            has no title or aliases.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        (string or None) A human readable name for the room.
 | 
			
		||||
        A human readable name for the room, if possible.
 | 
			
		||||
    """
 | 
			
		||||
    # does it have a name?
 | 
			
		||||
    if (EventTypes.Name, "") in room_state_ids:
 | 
			
		||||
| 
						 | 
				
			
			@ -97,7 +107,7 @@ async def calculate_room_name(
 | 
			
		|||
                        name_from_member_event(inviter_member_event),
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    return
 | 
			
		||||
                    return None
 | 
			
		||||
        else:
 | 
			
		||||
            return "Room Invite"
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -150,19 +160,19 @@ async def calculate_room_name(
 | 
			
		|||
        else:
 | 
			
		||||
            return ALL_ALONE
 | 
			
		||||
    elif len(other_members) == 1 and not fallback_to_single_member:
 | 
			
		||||
        return
 | 
			
		||||
    else:
 | 
			
		||||
        return descriptor_from_member_events(other_members)
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
    return descriptor_from_member_events(other_members)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def descriptor_from_member_events(member_events):
 | 
			
		||||
def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
 | 
			
		||||
    """Get a description of the room based on the member events.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        member_events (Iterable[FrozenEvent])
 | 
			
		||||
        member_events: The events of a room.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        str
 | 
			
		||||
        The room description
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    member_events = list(member_events)
 | 
			
		||||
| 
						 | 
				
			
			@ -183,7 +193,7 @@ def descriptor_from_member_events(member_events):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def name_from_member_event(member_event):
 | 
			
		||||
def name_from_member_event(member_event: EventBase) -> str:
 | 
			
		||||
    if (
 | 
			
		||||
        member_event.content
 | 
			
		||||
        and "displayname" in member_event.content
 | 
			
		||||
| 
						 | 
				
			
			@ -193,12 +203,12 @@ def name_from_member_event(member_event):
 | 
			
		|||
    return member_event.state_key
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _state_as_two_level_dict(state):
 | 
			
		||||
    ret = {}
 | 
			
		||||
def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
 | 
			
		||||
    ret = {}  # type: Dict[str, Dict[str, str]]
 | 
			
		||||
    for k, v in state.items():
 | 
			
		||||
        ret.setdefault(k[0], {})[k[1]] = v
 | 
			
		||||
    return ret
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _looks_like_an_alias(string):
 | 
			
		||||
def _looks_like_an_alias(string: str) -> bool:
 | 
			
		||||
    return ALIAS_RE.match(string) is not None
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -30,22 +30,30 @@ IS_GLOB = re.compile(r"[\?\*\[\]]")
 | 
			
		|||
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _room_member_count(ev, condition, room_member_count):
 | 
			
		||||
def _room_member_count(
 | 
			
		||||
    ev: EventBase, condition: Dict[str, Any], room_member_count: int
 | 
			
		||||
) -> bool:
 | 
			
		||||
    return _test_ineq_condition(condition, room_member_count)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
 | 
			
		||||
def _sender_notification_permission(
 | 
			
		||||
    ev: EventBase,
 | 
			
		||||
    condition: Dict[str, Any],
 | 
			
		||||
    sender_power_level: int,
 | 
			
		||||
    power_levels: Dict[str, Union[int, Dict[str, int]]],
 | 
			
		||||
) -> bool:
 | 
			
		||||
    notif_level_key = condition.get("key")
 | 
			
		||||
    if notif_level_key is None:
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    notif_levels = power_levels.get("notifications", {})
 | 
			
		||||
    assert isinstance(notif_levels, dict)
 | 
			
		||||
    room_notif_level = notif_levels.get(notif_level_key, 50)
 | 
			
		||||
 | 
			
		||||
    return sender_power_level >= room_notif_level
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _test_ineq_condition(condition, number):
 | 
			
		||||
def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
 | 
			
		||||
    if "is" not in condition:
 | 
			
		||||
        return False
 | 
			
		||||
    m = INEQUALITY_EXPR.match(condition["is"])
 | 
			
		||||
| 
						 | 
				
			
			@ -110,7 +118,7 @@ class PushRuleEvaluatorForEvent:
 | 
			
		|||
        event: EventBase,
 | 
			
		||||
        room_member_count: int,
 | 
			
		||||
        sender_power_level: int,
 | 
			
		||||
        power_levels: dict,
 | 
			
		||||
        power_levels: Dict[str, Union[int, Dict[str, int]]],
 | 
			
		||||
    ):
 | 
			
		||||
        self._event = event
 | 
			
		||||
        self._room_member_count = room_member_count
 | 
			
		||||
| 
						 | 
				
			
			@ -120,7 +128,9 @@ class PushRuleEvaluatorForEvent:
 | 
			
		|||
        # Maps strings of e.g. 'content.body' -> event["content"]["body"]
 | 
			
		||||
        self._value_cache = _flatten_dict(event)
 | 
			
		||||
 | 
			
		||||
    def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
 | 
			
		||||
    def matches(
 | 
			
		||||
        self, condition: Dict[str, Any], user_id: str, display_name: str
 | 
			
		||||
    ) -> bool:
 | 
			
		||||
        if condition["kind"] == "event_match":
 | 
			
		||||
            return self._event_match(condition, user_id)
 | 
			
		||||
        elif condition["kind"] == "contains_display_name":
 | 
			
		||||
| 
						 | 
				
			
			@ -261,7 +271,13 @@ def _re_word_boundary(r: str) -> str:
 | 
			
		|||
    return r"(^|\W)%s(\W|$)" % (r,)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _flatten_dict(d, prefix=[], result=None):
 | 
			
		||||
def _flatten_dict(
 | 
			
		||||
    d: Union[EventBase, dict],
 | 
			
		||||
    prefix: Optional[List[str]] = None,
 | 
			
		||||
    result: Optional[Dict[str, str]] = None,
 | 
			
		||||
) -> Dict[str, str]:
 | 
			
		||||
    if prefix is None:
 | 
			
		||||
        prefix = []
 | 
			
		||||
    if result is None:
 | 
			
		||||
        result = {}
 | 
			
		||||
    for key, value in d.items():
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue