mirror of
https://git.anonymousland.org/anonymousland/synapse.git
synced 2025-05-04 23:15:02 -04:00
Add type hints to the push module. (#8901)
This commit is contained in:
parent
a8eceb01e5
commit
5d34f40d49
9 changed files with 159 additions and 86 deletions
|
@ -15,6 +15,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
@ -25,18 +26,18 @@ from synapse.events import EventBase
|
|||
from synapse.events.snapshot import EventContext
|
||||
from synapse.state import POWER_KEY
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches import register_cache
|
||||
from synapse.util.caches import CacheMetric, register_cache
|
||||
from synapse.util.caches.descriptors import lru_cache
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
|
||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
rules_by_room = {}
|
||||
|
||||
|
||||
push_rules_invalidation_counter = Counter(
|
||||
"synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
|
||||
)
|
||||
|
@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
|
|||
room at once.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
|
|||
resizable=False,
|
||||
)
|
||||
|
||||
async def _get_rules_for_event(self, event, context):
|
||||
async def _get_rules_for_event(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""This gets the rules for all users in the room at the time of the event,
|
||||
as well as the push rules for the invitee if the event is an invite.
|
||||
|
||||
|
@ -140,11 +143,8 @@ class BulkPushRuleEvaluator:
|
|||
return rules_by_user
|
||||
|
||||
@lru_cache()
|
||||
def _get_rules_for_room(self, room_id):
|
||||
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
|
||||
"""Get the current RulesForRoom object for the given room id
|
||||
|
||||
Returns:
|
||||
RulesForRoom
|
||||
"""
|
||||
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
|
||||
# before any lookup methods get called on it as otherwise there may be
|
||||
|
@ -156,20 +156,21 @@ class BulkPushRuleEvaluator:
|
|||
self.room_push_rule_cache_metrics,
|
||||
)
|
||||
|
||||
async def _get_power_levels_and_sender_level(self, event, context):
|
||||
async def _get_power_levels_and_sender_level(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Tuple[dict, int]:
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
pl_event_id = prev_state_ids.get(POWER_KEY)
|
||||
if pl_event_id:
|
||||
# fastpath: if there's a power level event, that's all we need, and
|
||||
# not having a power level event is an extreme edge case
|
||||
pl_event = await self.store.get_event(pl_event_id)
|
||||
auth_events = {POWER_KEY: pl_event}
|
||||
auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
|
||||
else:
|
||||
auth_events_ids = self.auth.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=False
|
||||
)
|
||||
auth_events = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
||||
auth_events_dict = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
|
||||
|
||||
sender_level = get_user_power_level(event.sender, auth_events)
|
||||
|
||||
|
@ -177,7 +178,9 @@ class BulkPushRuleEvaluator:
|
|||
|
||||
return pl_event.content if pl_event else {}, sender_level
|
||||
|
||||
async def action_for_event_by_user(self, event, context) -> None:
|
||||
async def action_for_event_by_user(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> None:
|
||||
"""Given an event and context, evaluate the push rules, check if the message
|
||||
should increment the unread count, and insert the results into the
|
||||
event_push_actions_staging table.
|
||||
|
@ -185,7 +188,7 @@ class BulkPushRuleEvaluator:
|
|||
count_as_unread = _should_count_as_unread(event, context)
|
||||
|
||||
rules_by_user = await self._get_rules_for_event(event, context)
|
||||
actions_by_user = {}
|
||||
actions_by_user = {} # type: Dict[str, List[Union[dict, str]]]
|
||||
|
||||
room_members = await self.store.get_joined_users_from_context(event, context)
|
||||
|
||||
|
@ -198,7 +201,7 @@ class BulkPushRuleEvaluator:
|
|||
event, len(room_members), sender_power_level, power_levels
|
||||
)
|
||||
|
||||
condition_cache = {}
|
||||
condition_cache = {} # type: Dict[str, bool]
|
||||
|
||||
for uid, rules in rules_by_user.items():
|
||||
if event.sender == uid:
|
||||
|
@ -249,7 +252,13 @@ class BulkPushRuleEvaluator:
|
|||
)
|
||||
|
||||
|
||||
def _condition_checker(evaluator, conditions, uid, display_name, cache):
|
||||
def _condition_checker(
|
||||
evaluator: PushRuleEvaluatorForEvent,
|
||||
conditions: List[dict],
|
||||
uid: str,
|
||||
display_name: str,
|
||||
cache: Dict[str, bool],
|
||||
) -> bool:
|
||||
for cond in conditions:
|
||||
_id = cond.get("_id", None)
|
||||
if _id:
|
||||
|
@ -277,15 +286,19 @@ class RulesForRoom:
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
room_id: str,
|
||||
rules_for_room_cache: LruCache,
|
||||
room_push_rule_cache_metrics: CacheMetric,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hs (HomeServer)
|
||||
room_id (str)
|
||||
hs: The HomeServer object.
|
||||
room_id: The room ID.
|
||||
rules_for_room_cache: The cache object that caches these
|
||||
RoomsForUser objects.
|
||||
room_push_rule_cache_metrics (CacheMetric)
|
||||
room_push_rule_cache_metrics: The metrics object
|
||||
"""
|
||||
self.room_id = room_id
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
@ -294,8 +307,10 @@ class RulesForRoom:
|
|||
|
||||
self.linearizer = Linearizer(name="rules_for_room")
|
||||
|
||||
self.member_map = {} # event_id -> (user_id, state)
|
||||
self.rules_by_user = {} # user_id -> rules
|
||||
# event_id -> (user_id, state)
|
||||
self.member_map = {} # type: Dict[str, Tuple[str, str]]
|
||||
# user_id -> rules
|
||||
self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
|
||||
|
||||
# The last state group we updated the caches for. If the state_group of
|
||||
# a new event comes along, we know that we can just return the cached
|
||||
|
@ -315,7 +330,7 @@ class RulesForRoom:
|
|||
# calculate push for)
|
||||
# These never need to be invalidated as we will never set up push for
|
||||
# them.
|
||||
self.uninteresting_user_set = set()
|
||||
self.uninteresting_user_set = set() # type: Set[str]
|
||||
|
||||
# We need to be clever on the invalidating caches callbacks, as
|
||||
# otherwise the invalidation callback holds a reference to the object,
|
||||
|
@ -325,7 +340,9 @@ class RulesForRoom:
|
|||
# to self around in the callback.
|
||||
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
|
||||
|
||||
async def get_rules(self, event, context):
|
||||
async def get_rules(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> Dict[str, List[Dict[str, dict]]]:
|
||||
"""Given an event context return the rules for all users who are
|
||||
currently in the room.
|
||||
"""
|
||||
|
@ -356,6 +373,8 @@ class RulesForRoom:
|
|||
else:
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
push_rules_delta_state_cache_metric.inc_misses()
|
||||
# Ensure the state IDs exist.
|
||||
assert current_state_ids is not None
|
||||
|
||||
push_rules_state_size_counter.inc(len(current_state_ids))
|
||||
|
||||
|
@ -420,18 +439,23 @@ class RulesForRoom:
|
|||
return ret_rules_by_user
|
||||
|
||||
async def _update_rules_with_member_event_ids(
|
||||
self, ret_rules_by_user, member_event_ids, state_group, event
|
||||
):
|
||||
self,
|
||||
ret_rules_by_user: Dict[str, list],
|
||||
member_event_ids: Dict[str, str],
|
||||
state_group: Optional[int],
|
||||
event: EventBase,
|
||||
) -> None:
|
||||
"""Update the partially filled rules_by_user dict by fetching rules for
|
||||
any newly joined users in the `member_event_ids` list.
|
||||
|
||||
Args:
|
||||
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
|
||||
ret_rules_by_user: Partially filled dict of push rules. Gets
|
||||
updated with any new rules.
|
||||
member_event_ids (dict): Dict of user id to event id for membership events
|
||||
member_event_ids: Dict of user id to event id for membership events
|
||||
that have happened since the last time we filled rules_by_user
|
||||
state_group: The state group we are currently computing push rules
|
||||
for. Used when updating the cache.
|
||||
event: The event we are currently computing push rules for.
|
||||
"""
|
||||
sequence = self.sequence
|
||||
|
||||
|
@ -449,19 +473,19 @@ class RulesForRoom:
|
|||
if logger.isEnabledFor(logging.DEBUG):
|
||||
logger.debug("Found members %r: %r", self.room_id, members.values())
|
||||
|
||||
user_ids = {
|
||||
joined_user_ids = {
|
||||
user_id
|
||||
for user_id, membership in members.values()
|
||||
if membership == Membership.JOIN
|
||||
}
|
||||
|
||||
logger.debug("Joined: %r", user_ids)
|
||||
logger.debug("Joined: %r", joined_user_ids)
|
||||
|
||||
# Previously we only considered users with pushers or read receipts in that
|
||||
# room. We can't do this anymore because we use push actions to calculate unread
|
||||
# counts, which don't rely on the user having pushers or sent a read receipt into
|
||||
# the room. Therefore we just need to filter for local users here.
|
||||
user_ids = list(filter(self.is_mine_id, user_ids))
|
||||
user_ids = list(filter(self.is_mine_id, joined_user_ids))
|
||||
|
||||
rules_by_user = await self.store.bulk_get_push_rules(
|
||||
user_ids, on_invalidate=self.invalidate_all_cb
|
||||
|
@ -473,7 +497,7 @@ class RulesForRoom:
|
|||
|
||||
self.update_cache(sequence, members, ret_rules_by_user, state_group)
|
||||
|
||||
def invalidate_all(self):
|
||||
def invalidate_all(self) -> None:
|
||||
# Note: Don't hand this function directly to an invalidation callback
|
||||
# as it keeps a reference to self and will stop this instance from being
|
||||
# GC'd if it gets dropped from the rules_to_user cache. Instead use
|
||||
|
@ -485,7 +509,7 @@ class RulesForRoom:
|
|||
self.rules_by_user = {}
|
||||
push_rules_invalidation_counter.inc()
|
||||
|
||||
def update_cache(self, sequence, members, rules_by_user, state_group):
|
||||
def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
|
||||
if sequence == self.sequence:
|
||||
self.member_map.update(members)
|
||||
self.rules_by_user = rules_by_user
|
||||
|
@ -506,7 +530,7 @@ class _Invalidation:
|
|||
cache = attr.ib(type=LruCache)
|
||||
room_id = attr.ib(type=str)
|
||||
|
||||
def __call__(self):
|
||||
def __call__(self) -> None:
|
||||
rules = self.cache.get(self.room_id, None, update_metrics=False)
|
||||
if rules:
|
||||
rules.invalidate_all()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue