Merge pull request #146 from matrix-org/erikj/push_rules_fixes

Fix 500 on push rule updates.
This commit is contained in:
Erik Johnston 2015-05-11 11:33:47 +01:00
commit 79b7154454
2 changed files with 53 additions and 56 deletions

View File

@ -308,6 +308,7 @@ class SQLBaseStore(object):
self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
def start_profiling(self): def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec() self._previous_loop_ts = self._clock.time_msec()

View File

@ -19,7 +19,6 @@ from ._base import SQLBaseStore, Table
from twisted.internet import defer from twisted.internet import defer
import logging import logging
import copy
import simplejson as json import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,46 +27,45 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_push_rules_for_user(self, user_name): def get_push_rules_for_user(self, user_name):
sql = ( rows = yield self._simple_select_list(
"SELECT "+",".join(PushRuleTable.fields)+" " table=PushRuleTable.table_name,
"FROM "+PushRuleTable.table_name+" " keyvalues={
"WHERE user_name = ? " "user_name": user_name,
"ORDER BY priority_class DESC, priority DESC" },
retcols=PushRuleTable.fields,
) )
rows = yield self._execute("get_push_rules_for_user", None, sql, user_name)
dicts = [] rows.sort(
for r in rows: key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
d = {} )
for i, f in enumerate(PushRuleTable.fields):
d[f] = r[i]
dicts.append(d)
defer.returnValue(dicts) defer.returnValue(rows)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_push_rules_enabled_for_user(self, user_name): def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list( results = yield self._simple_select_list(
PushRuleEnableTable.table_name, table=PushRuleEnableTable.table_name,
{'user_name': user_name}, keyvalues={
PushRuleEnableTable.fields, 'user_name': user_name
},
retcols=PushRuleEnableTable.fields,
desc="get_push_rules_enabled_for_user", desc="get_push_rules_enabled_for_user",
) )
defer.returnValue( defer.returnValue({
{r['rule_id']: False if r['enabled'] == 0 else True for r in results} r['rule_id']: False if r['enabled'] == 0 else True for r in results
) })
@defer.inlineCallbacks @defer.inlineCallbacks
def add_push_rule(self, before, after, **kwargs): def add_push_rule(self, before, after, **kwargs):
vals = copy.copy(kwargs) vals = kwargs
if 'conditions' in vals: if 'conditions' in vals:
vals['conditions'] = json.dumps(vals['conditions']) vals['conditions'] = json.dumps(vals['conditions'])
if 'actions' in vals: if 'actions' in vals:
vals['actions'] = json.dumps(vals['actions']) vals['actions'] = json.dumps(vals['actions'])
# we could check the rest of the keys are valid column names # we could check the rest of the keys are valid column names
# but sqlite will do that anyway so I think it's just pointless. # but sqlite will do that anyway so I think it's just pointless.
if 'id' in vals: vals.pop("id", None)
del vals['id']
if before or after: if before or after:
ret = yield self.runInteraction( ret = yield self.runInteraction(
@ -87,39 +85,39 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(ret) defer.returnValue(ret)
def _add_push_rule_relative_txn(self, txn, user_name, **kwargs): def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
after = None after = kwargs.pop("after", None)
relative_to_rule = None relative_to_rule = kwargs.pop("before", after)
if 'after' in kwargs and kwargs['after']:
after = kwargs['after']
relative_to_rule = after
if 'before' in kwargs and kwargs['before']:
relative_to_rule = kwargs['before']
# get the priority of the rule we're inserting after/before res = self._simple_select_one_txn(
sql = ( txn,
"SELECT priority_class, priority FROM ? " table=PushRuleTable.table_name,
"WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,) keyvalues={
"user_name": user_name,
"rule_id": relative_to_rule,
},
retcols=["priority_class", "priority"],
allow_none=True,
) )
txn.execute(sql, (user_name, relative_to_rule))
res = txn.fetchall()
if not res: if not res:
raise RuleNotFoundException( raise RuleNotFoundException(
"before/after rule not found: %s" % (relative_to_rule,) "before/after rule not found: %s" % (relative_to_rule,)
) )
priority_class, base_rule_priority = res[0]
priority_class = res["priority_class"]
base_rule_priority = res["priority"]
if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class: if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
raise InconsistentRuleException( raise InconsistentRuleException(
"Given priority class does not match class of relative rule" "Given priority class does not match class of relative rule"
) )
new_rule = copy.copy(kwargs) new_rule = kwargs
if 'before' in new_rule: new_rule.pop("before", None)
del new_rule['before'] new_rule.pop("after", None)
if 'after' in new_rule:
del new_rule['after']
new_rule['priority_class'] = priority_class new_rule['priority_class'] = priority_class
new_rule['user_name'] = user_name new_rule['user_name'] = user_name
new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
# check if the priority before/after is free # check if the priority before/after is free
new_rule_priority = base_rule_priority new_rule_priority = base_rule_priority
@ -153,12 +151,11 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.execute(sql, (user_name, priority_class, new_rule_priority))
# now insert the new rule self._simple_insert_txn(
sql = "INSERT INTO "+PushRuleTable.table_name+" (" txn,
sql += ",".join(new_rule.keys())+") VALUES (" table=PushRuleTable.table_name,
sql += ", ".join(["?" for _ in new_rule.keys()])+")" values=new_rule,
)
txn.execute(sql, new_rule.values())
def _add_push_rule_highest_priority_txn(self, txn, user_name, def _add_push_rule_highest_priority_txn(self, txn, user_name,
priority_class, **kwargs): priority_class, **kwargs):
@ -176,18 +173,17 @@ class PushRuleStore(SQLBaseStore):
new_prio = highest_prio + 1 new_prio = highest_prio + 1
# and insert the new rule # and insert the new rule
new_rule = copy.copy(kwargs) new_rule = kwargs
if 'id' in new_rule: new_rule['id'] = self._push_rule_id_gen.get_next_txn(txn)
del new_rule['id']
new_rule['user_name'] = user_name new_rule['user_name'] = user_name
new_rule['priority_class'] = priority_class new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio new_rule['priority'] = new_prio
sql = "INSERT INTO "+PushRuleTable.table_name+" (" self._simple_insert_txn(
sql += ",".join(new_rule.keys())+") VALUES (" txn,
sql += ", ".join(["?" for _ in new_rule.keys()])+")" table=PushRuleTable.table_name,
values=new_rule,
txn.execute(sql, new_rule.values()) )
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_push_rule(self, user_name, rule_id): def delete_push_rule(self, user_name, rule_id):