Convert push to async/await. (#7948)

This commit is contained in:
Patrick Cloke 2020-07-27 12:21:34 -04:00 committed by GitHub
parent 7c2e2c2077
commit 8144bc26a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 106 additions and 145 deletions

1
changelog.d/7948.misc Normal file
View File

@ -0,0 +1 @@
Convert push to async/await.

View File

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
@ -37,7 +35,6 @@ class ActionGenerator(object):
# event stream, so we just run the rules for a client with no profile # event stream, so we just run the rules for a client with no profile
# tag (ie. we just need all the users). # tag (ie. we just need all the users).
@defer.inlineCallbacks async def handle_push_actions_for_event(self, event, context):
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "action_for_event_by_user"): with Measure(self.clock, "action_for_event_by_user"):
yield self.bulk_evaluator.action_for_event_by_user(event, context) await self.bulk_evaluator.action_for_event_by_user(event, context)

View File

@ -19,8 +19,6 @@ from collections import namedtuple
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.event_auth import get_user_power_level from synapse.event_auth import get_user_power_level
from synapse.state import POWER_KEY from synapse.state import POWER_KEY
@ -70,8 +68,7 @@ class BulkPushRuleEvaluator(object):
resizable=False, resizable=False,
) )
@defer.inlineCallbacks async def _get_rules_for_event(self, event, context):
def _get_rules_for_event(self, event, context):
"""This gets the rules for all users in the room at the time of the event, """This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite. as well as the push rules for the invitee if the event is an invite.
@ -79,19 +76,19 @@ class BulkPushRuleEvaluator(object):
dict of user_id -> push_rules dict of user_id -> push_rules
""" """
room_id = event.room_id room_id = event.room_id
rules_for_room = yield self._get_rules_for_room(room_id) rules_for_room = await self._get_rules_for_room(room_id)
rules_by_user = yield rules_for_room.get_rules(event, context) rules_by_user = await rules_for_room.get_rules(event, context)
# if this event is an invite event, we may need to run rules for the user # if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited # who's been invited, otherwise they won't get told they've been invited
if event.type == "m.room.member" and event.content["membership"] == "invite": if event.type == "m.room.member" and event.content["membership"] == "invite":
invited = event.state_key invited = event.state_key
if invited and self.hs.is_mine_id(invited): if invited and self.hs.is_mine_id(invited):
has_pusher = yield self.store.user_has_pusher(invited) has_pusher = await self.store.user_has_pusher(invited)
if has_pusher: if has_pusher:
rules_by_user = dict(rules_by_user) rules_by_user = dict(rules_by_user)
rules_by_user[invited] = yield self.store.get_push_rules_for_user( rules_by_user[invited] = await self.store.get_push_rules_for_user(
invited invited
) )
@ -114,20 +111,19 @@ class BulkPushRuleEvaluator(object):
self.room_push_rule_cache_metrics, self.room_push_rule_cache_metrics,
) )
@defer.inlineCallbacks async def _get_power_levels_and_sender_level(self, event, context):
def _get_power_levels_and_sender_level(self, event, context): prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = yield context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY) pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id: if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and # fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case # not having a power level event is an extreme edge case
pl_event = yield self.store.get_event(pl_event_id) pl_event = await self.store.get_event(pl_event_id)
auth_events = {POWER_KEY: pl_event} auth_events = {POWER_KEY: pl_event}
else: else:
auth_events_ids = yield self.auth.compute_auth_events( auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False event, prev_state_ids, for_verification=False
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()} auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
sender_level = get_user_power_level(event.sender, auth_events) sender_level = get_user_power_level(event.sender, auth_events)
@ -136,23 +132,19 @@ class BulkPushRuleEvaluator(object):
return pl_event.content if pl_event else {}, sender_level return pl_event.content if pl_event else {}, sender_level
@defer.inlineCallbacks async def action_for_event_by_user(self, event, context) -> None:
def action_for_event_by_user(self, event, context):
"""Given an event and context, evaluate the push rules and insert the """Given an event and context, evaluate the push rules and insert the
results into the event_push_actions_staging table. results into the event_push_actions_staging table.
Returns:
Deferred
""" """
rules_by_user = yield self._get_rules_for_event(event, context) rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {} actions_by_user = {}
room_members = yield self.store.get_joined_users_from_context(event, context) room_members = await self.store.get_joined_users_from_context(event, context)
( (
power_levels, power_levels,
sender_power_level, sender_power_level,
) = yield self._get_power_levels_and_sender_level(event, context) ) = await self._get_power_levels_and_sender_level(event, context)
evaluator = PushRuleEvaluatorForEvent( evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels event, len(room_members), sender_power_level, power_levels
@ -165,7 +157,7 @@ class BulkPushRuleEvaluator(object):
continue continue
if not event.is_state(): if not event.is_state():
is_ignored = yield self.store.is_ignored_by(event.sender, uid) is_ignored = await self.store.is_ignored_by(event.sender, uid)
if is_ignored: if is_ignored:
continue continue
@ -197,7 +189,7 @@ class BulkPushRuleEvaluator(object):
# Mark in the DB staging area the push actions for users who should be # Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist # notified for this event. (This will then get handled when we persist
# the event) # the event)
yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user) await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
def _condition_checker(evaluator, conditions, uid, display_name, cache): def _condition_checker(evaluator, conditions, uid, display_name, cache):
@ -274,8 +266,7 @@ class RulesForRoom(object):
# to self around in the callback. # to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id) self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
@defer.inlineCallbacks async def get_rules(self, event, context):
def get_rules(self, event, context):
"""Given an event context return the rules for all users who are """Given an event context return the rules for all users who are
currently in the room. currently in the room.
""" """
@ -286,7 +277,7 @@ class RulesForRoom(object):
self.room_push_rule_cache_metrics.inc_hits() self.room_push_rule_cache_metrics.inc_hits()
return self.rules_by_user return self.rules_by_user
with (yield self.linearizer.queue(())): with (await self.linearizer.queue(())):
if state_group and self.state_group == state_group: if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id) logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits() self.room_push_rule_cache_metrics.inc_hits()
@ -304,9 +295,7 @@ class RulesForRoom(object):
push_rules_delta_state_cache_metric.inc_hits() push_rules_delta_state_cache_metric.inc_hits()
else: else:
current_state_ids = yield defer.ensureDeferred( current_state_ids = await context.get_current_state_ids()
context.get_current_state_ids()
)
push_rules_delta_state_cache_metric.inc_misses() push_rules_delta_state_cache_metric.inc_misses()
push_rules_state_size_counter.inc(len(current_state_ids)) push_rules_state_size_counter.inc(len(current_state_ids))
@ -353,7 +342,7 @@ class RulesForRoom(object):
# If we have some member events we haven't seen, look them up # If we have some member events we haven't seen, look them up
# and fetch push rules for them if appropriate. # and fetch push rules for them if appropriate.
logger.debug("Found new member events %r", missing_member_event_ids) logger.debug("Found new member events %r", missing_member_event_ids)
yield self._update_rules_with_member_event_ids( await self._update_rules_with_member_event_ids(
ret_rules_by_user, missing_member_event_ids, state_group, event ret_rules_by_user, missing_member_event_ids, state_group, event
) )
else: else:
@ -371,8 +360,7 @@ class RulesForRoom(object):
) )
return ret_rules_by_user return ret_rules_by_user
@defer.inlineCallbacks async def _update_rules_with_member_event_ids(
def _update_rules_with_member_event_ids(
self, ret_rules_by_user, member_event_ids, state_group, event self, ret_rules_by_user, member_event_ids, state_group, event
): ):
"""Update the partially filled rules_by_user dict by fetching rules for """Update the partially filled rules_by_user dict by fetching rules for
@ -388,7 +376,7 @@ class RulesForRoom(object):
""" """
sequence = self.sequence sequence = self.sequence
rows = yield self.store.get_membership_from_event_ids(member_event_ids.values()) rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
@ -410,7 +398,7 @@ class RulesForRoom(object):
logger.debug("Joined: %r", interested_in_user_ids) logger.debug("Joined: %r", interested_in_user_ids)
if_users_with_pushers = yield self.store.get_if_users_have_pushers( if_users_with_pushers = await self.store.get_if_users_have_pushers(
interested_in_user_ids, on_invalidate=self.invalidate_all_cb interested_in_user_ids, on_invalidate=self.invalidate_all_cb
) )
@ -420,7 +408,7 @@ class RulesForRoom(object):
logger.debug("With pushers: %r", user_ids) logger.debug("With pushers: %r", user_ids)
users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
self.room_id, on_invalidate=self.invalidate_all_cb self.room_id, on_invalidate=self.invalidate_all_cb
) )
@ -431,7 +419,7 @@ class RulesForRoom(object):
if uid in interested_in_user_ids: if uid in interested_in_user_ids:
user_ids.add(uid) user_ids.add(uid)
rules_by_user = yield self.store.bulk_get_push_rules( rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb user_ids, on_invalidate=self.invalidate_all_cb
) )

View File

@ -17,7 +17,6 @@ import logging
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
@ -128,12 +127,11 @@ class HttpPusher(object):
# but currently that's the only type of receipt anyway... # but currently that's the only type of receipt anyway...
run_as_background_process("http_pusher.on_new_receipts", self._update_badge) run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
@defer.inlineCallbacks async def _update_badge(self):
def _update_badge(self):
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it. # to be largely redundant. perhaps we can remove it.
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
yield self._send_badge(badge) await self._send_badge(badge)
def on_timer(self): def on_timer(self):
self._start_processing() self._start_processing()
@ -152,8 +150,7 @@ class HttpPusher(object):
run_as_background_process("httppush.process", self._process) run_as_background_process("httppush.process", self._process)
@defer.inlineCallbacks async def _process(self):
def _process(self):
# we should never get here if we are already processing # we should never get here if we are already processing
assert not self._is_processing assert not self._is_processing
@ -164,7 +161,7 @@ class HttpPusher(object):
while True: while True:
starting_max_ordering = self.max_stream_ordering starting_max_ordering = self.max_stream_ordering
try: try:
yield self._unsafe_process() await self._unsafe_process()
except Exception: except Exception:
logger.exception("Exception processing notifs") logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering: if self.max_stream_ordering == starting_max_ordering:
@ -172,8 +169,7 @@ class HttpPusher(object):
finally: finally:
self._is_processing = False self._is_processing = False
@defer.inlineCallbacks async def _unsafe_process(self):
def _unsafe_process(self):
""" """
Looks for unsent notifications and dispatches them, in order Looks for unsent notifications and dispatches them, in order
Never call this directly: use _process which will only allow this to Never call this directly: use _process which will only allow this to
@ -181,7 +177,7 @@ class HttpPusher(object):
""" """
fn = self.store.get_unread_push_actions_for_user_in_range_for_http fn = self.store.get_unread_push_actions_for_user_in_range_for_http
unprocessed = yield fn( unprocessed = await fn(
self.user_id, self.last_stream_ordering, self.max_stream_ordering self.user_id, self.last_stream_ordering, self.max_stream_ordering
) )
@ -203,13 +199,13 @@ class HttpPusher(object):
"app_display_name": self.app_display_name, "app_display_name": self.app_display_name,
}, },
): ):
processed = yield self._process_one(push_action) processed = await self._process_one(push_action)
if processed: if processed:
http_push_processed_counter.inc() http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"] self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success( pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_id, self.user_id,
@ -224,14 +220,14 @@ class HttpPusher(object):
if self.failing_since: if self.failing_since:
self.failing_since = None self.failing_since = None
yield self.store.update_pusher_failing_since( await self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id, self.failing_since self.app_id, self.pushkey, self.user_id, self.failing_since
) )
else: else:
http_push_failed_counter.inc() http_push_failed_counter.inc()
if not self.failing_since: if not self.failing_since:
self.failing_since = self.clock.time_msec() self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since( await self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id, self.failing_since self.app_id, self.pushkey, self.user_id, self.failing_since
) )
@ -250,7 +246,7 @@ class HttpPusher(object):
) )
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"] self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering( pusher_still_exists = await self.store.update_pusher_last_stream_ordering(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_id, self.user_id,
@ -263,7 +259,7 @@ class HttpPusher(object):
return return
self.failing_since = None self.failing_since = None
yield self.store.update_pusher_failing_since( await self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id, self.failing_since self.app_id, self.pushkey, self.user_id, self.failing_since
) )
else: else:
@ -276,18 +272,17 @@ class HttpPusher(object):
) )
break break
@defer.inlineCallbacks async def _process_one(self, push_action):
def _process_one(self, push_action):
if "notify" not in push_action["actions"]: if "notify" not in push_action["actions"]:
return True return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"]) tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id) badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
event = yield self.store.get_event(push_action["event_id"], allow_none=True) event = await self.store.get_event(push_action["event_id"], allow_none=True)
if event is None: if event is None:
return True # It's been redacted return True # It's been redacted
rejected = yield self.dispatch_push(event, tweaks, badge) rejected = await self.dispatch_push(event, tweaks, badge)
if rejected is False: if rejected is False:
return False return False
@ -301,11 +296,10 @@ class HttpPusher(object):
) )
else: else:
logger.info("Pushkey %s was rejected: removing", pk) logger.info("Pushkey %s was rejected: removing", pk)
yield self.hs.remove_pusher(self.app_id, pk, self.user_id) await self.hs.remove_pusher(self.app_id, pk, self.user_id)
return True return True
@defer.inlineCallbacks async def _build_notification_dict(self, event, tweaks, badge):
def _build_notification_dict(self, event, tweaks, badge):
priority = "low" priority = "low"
if ( if (
event.type == EventTypes.Encrypted event.type == EventTypes.Encrypted
@ -335,7 +329,7 @@ class HttpPusher(object):
} }
return d return d
ctx = yield push_tools.get_context_for_event( ctx = await push_tools.get_context_for_event(
self.storage, self.state_handler, event, self.user_id self.storage, self.state_handler, event, self.user_id
) )
@ -377,13 +371,12 @@ class HttpPusher(object):
return d return d
@defer.inlineCallbacks async def dispatch_push(self, event, tweaks, badge):
def dispatch_push(self, event, tweaks, badge): notification_dict = await self._build_notification_dict(event, tweaks, badge)
notification_dict = yield self._build_notification_dict(event, tweaks, badge)
if not notification_dict: if not notification_dict:
return [] return []
try: try:
resp = yield self.http_client.post_json_get_json( resp = await self.http_client.post_json_get_json(
self.url, notification_dict self.url, notification_dict
) )
except Exception as e: except Exception as e:
@ -400,8 +393,7 @@ class HttpPusher(object):
rejected = resp["rejected"] rejected = resp["rejected"]
return rejected return rejected
@defer.inlineCallbacks async def _send_badge(self, badge):
def _send_badge(self, badge):
""" """
Args: Args:
badge (int): number of unread messages badge (int): number of unread messages
@ -424,7 +416,7 @@ class HttpPusher(object):
} }
} }
try: try:
yield self.http_client.post_json_get_json(self.url, d) await self.http_client.post_json_get_json(self.url, d)
http_badges_processed_counter.inc() http_badges_processed_counter.inc()
except Exception as e: except Exception as e:
logger.warning( logger.warning(

View File

@ -16,8 +16,6 @@
import logging import logging
import re import re
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -29,8 +27,7 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
ALL_ALONE = "Empty Room" ALL_ALONE = "Empty Room"
@defer.inlineCallbacks async def calculate_room_name(
def calculate_room_name(
store, store,
room_state_ids, room_state_ids,
user_id, user_id,
@ -53,7 +50,7 @@ def calculate_room_name(
""" """
# does it have a name? # does it have a name?
if (EventTypes.Name, "") in room_state_ids: if (EventTypes.Name, "") in room_state_ids:
m_room_name = yield store.get_event( m_room_name = await store.get_event(
room_state_ids[(EventTypes.Name, "")], allow_none=True room_state_ids[(EventTypes.Name, "")], allow_none=True
) )
if m_room_name and m_room_name.content and m_room_name.content["name"]: if m_room_name and m_room_name.content and m_room_name.content["name"]:
@ -61,7 +58,7 @@ def calculate_room_name(
# does it have a canonical alias? # does it have a canonical alias?
if (EventTypes.CanonicalAlias, "") in room_state_ids: if (EventTypes.CanonicalAlias, "") in room_state_ids:
canon_alias = yield store.get_event( canon_alias = await store.get_event(
room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True
) )
if ( if (
@ -81,7 +78,7 @@ def calculate_room_name(
my_member_event = None my_member_event = None
if (EventTypes.Member, user_id) in room_state_ids: if (EventTypes.Member, user_id) in room_state_ids:
my_member_event = yield store.get_event( my_member_event = await store.get_event(
room_state_ids[(EventTypes.Member, user_id)], allow_none=True room_state_ids[(EventTypes.Member, user_id)], allow_none=True
) )
@ -90,7 +87,7 @@ def calculate_room_name(
and my_member_event.content["membership"] == "invite" and my_member_event.content["membership"] == "invite"
): ):
if (EventTypes.Member, my_member_event.sender) in room_state_ids: if (EventTypes.Member, my_member_event.sender) in room_state_ids:
inviter_member_event = yield store.get_event( inviter_member_event = await store.get_event(
room_state_ids[(EventTypes.Member, my_member_event.sender)], room_state_ids[(EventTypes.Member, my_member_event.sender)],
allow_none=True, allow_none=True,
) )
@ -107,7 +104,7 @@ def calculate_room_name(
# we're going to have to generate a name based on who's in the room, # we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user. # so find out who is in the room that isn't the user.
if EventTypes.Member in room_state_bytype_ids: if EventTypes.Member in room_state_bytype_ids:
member_events = yield store.get_events( member_events = await store.get_events(
list(room_state_bytype_ids[EventTypes.Member].values()) list(room_state_bytype_ids[EventTypes.Member].values())
) )
all_members = [ all_members = [

View File

@ -13,18 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage from synapse.storage import Storage
@defer.inlineCallbacks async def get_badge_count(store, user_id):
def get_badge_count(store, user_id): invites = await store.get_invited_rooms_for_local_user(user_id)
invites = yield store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id)
joins = yield store.get_rooms_for_user(user_id)
my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read") my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
badge = len(invites) badge = len(invites)
@ -32,7 +29,7 @@ def get_badge_count(store, user_id):
if room_id in my_receipts_by_room: if room_id in my_receipts_by_room:
last_unread_event_id = my_receipts_by_room[room_id] last_unread_event_id = my_receipts_by_room[room_id]
notifs = yield ( notifs = await (
store.get_unread_event_push_actions_by_room_for_user( store.get_unread_event_push_actions_by_room_for_user(
room_id, user_id, last_unread_event_id room_id, user_id, last_unread_event_id
) )
@ -43,23 +40,22 @@ def get_badge_count(store, user_id):
return badge return badge
@defer.inlineCallbacks async def get_context_for_event(storage: Storage, state_handler, ev, user_id):
def get_context_for_event(storage: Storage, state_handler, ev, user_id):
ctx = {} ctx = {}
room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id) room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the # we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or # human-readable name instead, be that m.room.name, an alias or
# a list of people in the room # a list of people in the room
name = yield calculate_room_name( name = await calculate_room_name(
storage.main, room_state_ids, user_id, fallback_to_single_member=False storage.main, room_state_ids, user_id, fallback_to_single_member=False
) )
if name: if name:
ctx["name"] = name ctx["name"] = name
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)] sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
sender_state_event = yield storage.main.get_event(sender_state_event_id) sender_state_event = await storage.main.get_event(sender_state_event_id)
ctx["sender_display_name"] = name_from_member_event(sender_state_event) ctx["sender_display_name"] = name_from_member_event(sender_state_event)
return ctx return ctx

View File

@ -19,8 +19,6 @@ from typing import TYPE_CHECKING, Dict, Union
from prometheus_client import Gauge from prometheus_client import Gauge
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from synapse.push.emailpusher import EmailPusher from synapse.push.emailpusher import EmailPusher
@ -52,7 +50,7 @@ class PusherPool:
Note that it is expected that each pusher will have its own 'processing' loop which Note that it is expected that each pusher will have its own 'processing' loop which
will send out the notifications in the background, rather than blocking until the will send out the notifications in the background, rather than blocking until the
notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and
Pusher.on_new_receipts are not expected to return deferreds. Pusher.on_new_receipts are not expected to return awaitables.
""" """
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
@ -77,8 +75,7 @@ class PusherPool:
return return
run_as_background_process("start_pushers", self._start_pushers) run_as_background_process("start_pushers", self._start_pushers)
@defer.inlineCallbacks async def add_pusher(
def add_pusher(
self, self,
user_id, user_id,
access_token, access_token,
@ -94,7 +91,7 @@ class PusherPool:
"""Creates a new pusher and adds it to the pool """Creates a new pusher and adds it to the pool
Returns: Returns:
Deferred[EmailPusher|HttpPusher] EmailPusher|HttpPusher
""" """
time_now_msec = self.clock.time_msec() time_now_msec = self.clock.time_msec()
@ -124,9 +121,9 @@ class PusherPool:
# create the pusher setting last_stream_ordering to the current maximum # create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process # stream ordering in event_push_actions, so it will process
# pushes from this point onwards. # pushes from this point onwards.
last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering() last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
yield self.store.add_pusher( await self.store.add_pusher(
user_id=user_id, user_id=user_id,
access_token=access_token, access_token=access_token,
kind=kind, kind=kind,
@ -140,15 +137,14 @@ class PusherPool:
last_stream_ordering=last_stream_ordering, last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag, profile_tag=profile_tag,
) )
pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id) pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
return pusher return pusher
@defer.inlineCallbacks async def remove_pushers_by_app_id_and_pushkey_not_user(
def remove_pushers_by_app_id_and_pushkey_not_user(
self, app_id, pushkey, not_user_id self, app_id, pushkey, not_user_id
): ):
to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove: for p in to_remove:
if p["user_name"] != not_user_id: if p["user_name"] != not_user_id:
logger.info( logger.info(
@ -157,10 +153,9 @@ class PusherPool:
pushkey, pushkey,
p["user_name"], p["user_name"],
) )
yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
@defer.inlineCallbacks async def remove_pushers_by_access_token(self, user_id, access_tokens):
def remove_pushers_by_access_token(self, user_id, access_tokens):
"""Remove the pushers for a given user corresponding to a set of """Remove the pushers for a given user corresponding to a set of
access_tokens. access_tokens.
@ -173,7 +168,7 @@ class PusherPool:
return return
tokens = set(access_tokens) tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)): for p in await self.store.get_pushers_by_user_id(user_id):
if p["access_token"] in tokens: if p["access_token"] in tokens:
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
@ -181,16 +176,15 @@ class PusherPool:
p["pushkey"], p["pushkey"],
p["user_name"], p["user_name"],
) )
yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
@defer.inlineCallbacks async def on_new_notifications(self, min_stream_id, max_stream_id):
def on_new_notifications(self, min_stream_id, max_stream_id):
if not self.pushers: if not self.pushers:
# nothing to do here. # nothing to do here.
return return
try: try:
users_affected = yield self.store.get_push_action_users_in_range( users_affected = await self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id min_stream_id, max_stream_id
) )
@ -202,8 +196,7 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_notifications") logger.exception("Exception in pusher on_new_notifications")
@defer.inlineCallbacks async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
if not self.pushers: if not self.pushers:
# nothing to do here. # nothing to do here.
return return
@ -211,7 +204,7 @@ class PusherPool:
try: try:
# Need to subtract 1 from the minimum because the lower bound here # Need to subtract 1 from the minimum because the lower bound here
# is not inclusive # is not inclusive
users_affected = yield self.store.get_users_sent_receipts_between( users_affected = await self.store.get_users_sent_receipts_between(
min_stream_id - 1, max_stream_id min_stream_id - 1, max_stream_id
) )
@ -223,12 +216,11 @@ class PusherPool:
except Exception: except Exception:
logger.exception("Exception in pusher on_new_receipts") logger.exception("Exception in pusher on_new_receipts")
@defer.inlineCallbacks async def start_pusher_by_id(self, app_id, pushkey, user_id):
def start_pusher_by_id(self, app_id, pushkey, user_id):
"""Look up the details for the given pusher, and start it """Look up the details for the given pusher, and start it
Returns: Returns:
Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any EmailPusher|HttpPusher|None: The pusher started, if any
""" """
if not self._should_start_pushers: if not self._should_start_pushers:
return return
@ -236,7 +228,7 @@ class PusherPool:
if not self._pusher_shard_config.should_handle(self._instance_name, user_id): if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return return
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_dict = None pusher_dict = None
for r in resultlist: for r in resultlist:
@ -245,34 +237,29 @@ class PusherPool:
pusher = None pusher = None
if pusher_dict: if pusher_dict:
pusher = yield self._start_pusher(pusher_dict) pusher = await self._start_pusher(pusher_dict)
return pusher return pusher
@defer.inlineCallbacks async def _start_pushers(self) -> None:
def _start_pushers(self):
"""Start all the pushers """Start all the pushers
Returns:
Deferred
""" """
pushers = yield self.store.get_all_pushers() pushers = await self.store.get_all_pushers()
# Stagger starting up the pushers so we don't completely drown the # Stagger starting up the pushers so we don't completely drown the
# process on start up. # process on start up.
yield concurrently_execute(self._start_pusher, pushers, 10) await concurrently_execute(self._start_pusher, pushers, 10)
logger.info("Started pushers") logger.info("Started pushers")
@defer.inlineCallbacks async def _start_pusher(self, pusherdict):
def _start_pusher(self, pusherdict):
"""Start the given pusher """Start the given pusher
Args: Args:
pusherdict (dict): dict with the values pulled from the db table pusherdict (dict): dict with the values pulled from the db table
Returns: Returns:
Deferred[EmailPusher|HttpPusher] EmailPusher|HttpPusher
""" """
if not self._pusher_shard_config.should_handle( if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"] self._instance_name, pusherdict["user_name"]
@ -315,7 +302,7 @@ class PusherPool:
user_id = pusherdict["user_name"] user_id = pusherdict["user_name"]
last_stream_ordering = pusherdict["last_stream_ordering"] last_stream_ordering = pusherdict["last_stream_ordering"]
if last_stream_ordering: if last_stream_ordering:
have_notifs = yield self.store.get_if_maybe_push_in_range_for_user( have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering user_id, last_stream_ordering
) )
else: else:
@ -327,8 +314,7 @@ class PusherPool:
return p return p
@defer.inlineCallbacks async def remove_pusher(self, app_id, pushkey, user_id):
def remove_pusher(self, app_id, pushkey, user_id):
appid_pushkey = "%s:%s" % (app_id, pushkey) appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {}) byuser = self.pushers.get(user_id, {})
@ -340,6 +326,6 @@ class PusherPool:
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
yield self.store.delete_pusher_by_app_id_pushkey_user_id( await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id app_id, pushkey, user_id
) )

View File

@ -411,7 +411,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
_get_if_maybe_push_in_range_for_user_txn, _get_if_maybe_push_in_range_for_user_txn,
) )
def add_push_actions_to_staging(self, event_id, user_id_actions): async def add_push_actions_to_staging(self, event_id, user_id_actions):
"""Add the push actions for the event to the push action staging area. """Add the push actions for the event to the push action staging area.
Args: Args:
@ -457,7 +457,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
), ),
) )
return self.db.runInteraction( return await self.db.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn "add_push_actions_to_staging", _add_push_actions_to_staging_txn
) )

View File

@ -366,7 +366,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_handler = self.hs.get_state_handler() state_handler = self.hs.get_state_handler()
context = self.get_success(state_handler.compute_event_context(event)) context = self.get_success(state_handler.compute_event_context(event))
self.get_success(
self.master_store.add_push_actions_to_staging( self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions} event.event_id, {user_id: actions for user_id, actions in push_actions}
) )
)
return event, context return event, context

View File

@ -72,9 +72,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream event.internal_metadata.stream_ordering = stream
event.depth = stream event.depth = stream
yield self.store.add_push_actions_to_staging( yield defer.ensureDeferred(
self.store.add_push_actions_to_staging(
event.event_id, {user_id: action} event.event_id, {user_id: action}
) )
)
yield self.store.db.runInteraction( yield self.store.db.runInteraction(
"", "",
self.persist_events_store._set_push_actions_for_event_and_users_txn, self.persist_events_store._set_push_actions_for_event_and_users_txn,