Load push rules in storage layer, so that they get cached

This commit is contained in:
Erik Johnston 2016-06-01 14:27:07 +01:00
parent 59f2d73522
commit 6a0afa582a
5 changed files with 63 additions and 44 deletions

View File

@ -198,9 +198,8 @@ class SyncHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def push_rules_for_user(self, user): def push_rules_for_user(self, user):
user_id = user.to_string() user_id = user.to_string()
rawrules = yield self.store.get_push_rules_for_user(user_id) rules = yield self.store.get_push_rules_for_user(user_id)
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) rules = format_push_rules_for_user(user, rules)
rules = format_push_rules_for_user(user, rawrules, enabled_map)
defer.returnValue(rules) defer.returnValue(rules)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -18,7 +18,6 @@ import ujson as json
from twisted.internet import defer from twisted.internet import defer
from .baserules import list_with_base_rules
from .push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -38,36 +37,9 @@ def decode_rule_json(rule):
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_rules(room_id, user_ids, store): def _get_rules(room_id, user_ids, store):
rules_by_user = yield store.bulk_get_push_rules(user_ids) rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None} rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
rules_by_user = {
uid: list_with_base_rules([
decode_rule_json(rule_list)
for rule_list in rules_by_user.get(uid, [])
])
for uid in user_ids
}
# We apply the rules-enabled map here: bulk_get_push_rules doesn't
# fetch disabled rules, but this won't account for any server default
# rules the user has disabled, so we need to do this too.
for uid in user_ids:
user_enabled_map = rules_enabled_by_user.get(uid)
if not user_enabled_map:
continue
for i, rule in enumerate(rules_by_user[uid]):
rule_id = rule['rule_id']
if rule_id in user_enabled_map:
if rule.get('enabled', True) != bool(user_enabled_map[rule_id]):
# Rules are cached across users.
rule = dict(rule)
rule['enabled'] = bool(user_enabled_map[rule_id])
rules_by_user[uid][i] = rule
defer.returnValue(rules_by_user) defer.returnValue(rules_by_user)

View File

@ -23,10 +23,7 @@ import copy
import simplejson as json import simplejson as json
def format_push_rules_for_user(user, rawrules, enabled_map): def load_rules_for_user(user, rawrules, enabled_map):
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
ruleslist = [] ruleslist = []
for rawrule in rawrules: for rawrule in rawrules:
rule = dict(rawrule) rule = dict(rawrule)
@ -35,7 +32,26 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
ruleslist.append(rule) ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy # We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) rules = list(list_with_base_rules(ruleslist))
for i, rule in enumerate(rules):
rule_id = rule['rule_id']
if rule_id in enabled_map:
if rule.get('enabled', True) != bool(enabled_map[rule_id]):
# Rules are cached across users.
rule = dict(rule)
rule['enabled'] = bool(enabled_map[rule_id])
rules[i] = rule
return rules
def format_push_rules_for_user(user, ruleslist):
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(ruleslist)
rules = {'global': {}, 'device': {}} rules = {'global': {}, 'device': {}}
@ -60,9 +76,7 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
template_rule = _rule_to_template(r) template_rule = _rule_to_template(r)
if template_rule: if template_rule:
if r['rule_id'] in enabled_map: if 'enabled' in r:
template_rule['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' in r:
template_rule['enabled'] = r['enabled'] template_rule['enabled'] = r['enabled']
else: else:
template_rule['enabled'] = True template_rule['enabled'] = True

View File

@ -128,11 +128,9 @@ class PushRuleRestServlet(ClientV1RestServlet):
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference # is probably not going to make a whole lot of difference
rawrules = yield self.store.get_push_rules_for_user(user_id) rules = yield self.store.get_push_rules_for_user(user_id)
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id) rules = format_push_rules_for_user(requester.user, rules)
rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
path = request.postpath[1:] path = request.postpath[1:]

View File

@ -15,6 +15,7 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.push.baserules import list_with_base_rules
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -23,6 +24,29 @@ import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _load_rules(rawrules, enabled_map):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
rules = list(list_with_base_rules(ruleslist))
for i, rule in enumerate(rules):
rule_id = rule['rule_id']
if rule_id in enabled_map:
if rule.get('enabled', True) != bool(enabled_map[rule_id]):
# Rules are cached across users.
rule = dict(rule)
rule['enabled'] = bool(enabled_map[rule_id])
rules[i] = rule
return rules
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cachedInlineCallbacks(lru=True) @cachedInlineCallbacks(lru=True)
def get_push_rules_for_user(self, user_id): def get_push_rules_for_user(self, user_id):
@ -42,7 +66,11 @@ class PushRuleStore(SQLBaseStore):
key=lambda row: (-int(row["priority_class"]), -int(row["priority"])) key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
) )
defer.returnValue(rows) enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
rules = _load_rules(rows, enabled_map)
defer.returnValue(rules)
@cachedInlineCallbacks(lru=True) @cachedInlineCallbacks(lru=True)
def get_push_rules_enabled_for_user(self, user_id): def get_push_rules_enabled_for_user(self, user_id):
@ -85,6 +113,14 @@ class PushRuleStore(SQLBaseStore):
for row in rows: for row in rows:
results.setdefault(row['user_name'], []).append(row) results.setdefault(row['user_name'], []).append(row)
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
results[user_id] = _load_rules(
rules, enabled_map_by_user.get(user_id, {})
)
defer.returnValue(results) defer.returnValue(results)
@cachedList(cached_method_name="get_push_rules_enabled_for_user", @cachedList(cached_method_name="get_push_rules_enabled_for_user",