Merge pull request #618 from matrix-org/markjh/pushrule_stream

Add a stream for push rule updates
This commit is contained in:
Mark Haines 2016-03-04 16:35:08 +00:00
commit b7a3be693b
15 changed files with 484 additions and 194 deletions

View File

@ -662,8 +662,8 @@ class MessageHandler(BaseHandler):
user_id, messages, is_peeking=is_peeking user_id, messages, is_peeking=is_peeking
) )
start_token = StreamToken(token[0], 0, 0, 0, 0) start_token = StreamToken.START.copy_and_replace("room_key", token[0])
end_token = StreamToken(token[1], 0, 0, 0, 0) end_token = StreamToken.START.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View File

@ -20,6 +20,7 @@ from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.logcontext import LoggingContext, preserve_fn from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.push.clientformat import format_push_rules_for_user
from twisted.internet import defer from twisted.internet import defer
@ -224,6 +225,10 @@ class SyncHandler(BaseHandler):
) )
) )
account_data['m.push_rules'] = yield self.push_rules_for_user(
sync_config.user
)
tags_by_room = yield self.store.get_tags_for_user( tags_by_room = yield self.store.get_tags_for_user(
sync_config.user.to_string() sync_config.user.to_string()
) )
@ -328,6 +333,14 @@ class SyncHandler(BaseHandler):
defer.returnValue(room_sync) defer.returnValue(room_sync)
@defer.inlineCallbacks
def push_rules_for_user(self, user):
user_id = user.to_string()
rawrules = yield self.store.get_push_rules_for_user(user_id)
enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
rules = format_push_rules_for_user(user, rawrules, enabled_map)
defer.returnValue(rules)
def account_data_for_user(self, account_data): def account_data_for_user(self, account_data):
account_data_events = [] account_data_events = []
@ -487,6 +500,15 @@ class SyncHandler(BaseHandler):
) )
) )
push_rules_changed = yield self.store.have_push_rules_changed_for_user(
user_id, int(since_token.push_rules_key)
)
if push_rules_changed:
account_data["m.push_rules"] = yield self.push_rules_for_user(
sync_config.user
)
# Get a list of membership change events that have happened. # Get a list of membership change events that have happened.
rooms_changed = yield self.store.get_membership_changes_for_user( rooms_changed = yield self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key user_id, since_token.room_key, now_token.room_key

View File

@ -284,7 +284,7 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user_id, timeout, callback, room_ids=None, def wait_for_events(self, user_id, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0", "0")): from_token=StreamToken.START):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """

View File

@ -0,0 +1,112 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push.baserules import list_with_base_rules
from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
)
import copy
import simplejson as json
def format_push_rules_for_user(user, rawrules, enabled_map):
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}}
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r['priority_class'])
# Remove internal stuff.
for c in r["conditions"]:
c.pop("_id", None)
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":
c["pattern"] = user.to_string()
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r)
if template_rule:
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' in r:
template_rule['enabled'] = r['enabled']
else:
template_rule['enabled'] = True
rulearray.append(template_rule)
return rules
def _add_empty_priority_class_arrays(d):
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _rule_to_template(rule):
unscoped_rule_id = None
if 'rule_id' in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
template_name = _priority_class_to_template_name(rule['priority_class'])
if template_name in ['override', 'underride']:
templaterule = {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ["sender", "room"]:
templaterule = {'actions': rule['actions']}
unscoped_rule_id = rule['conditions'][0]['pattern']
elif template_name == 'content':
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
templaterule = {'actions': rule['actions']}
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule
def _rule_id_from_namespaced(in_rule_id):
return in_rule_id.split('/')[-1]
def _priority_class_to_template_name(pc):
return PRIORITY_CLASS_INVERSE_MAP[pc]

View File

@ -36,6 +36,7 @@ STREAM_NAMES = (
("receipts",), ("receipts",),
("user_account_data", "room_account_data", "tag_account_data",), ("user_account_data", "room_account_data", "tag_account_data",),
("backfill",), ("backfill",),
("push_rules",),
) )
@ -63,6 +64,7 @@ class ReplicationResource(Resource):
* "room_account_data: Per room per user account data. * "room_account_data: Per room per user account data.
* "tag_account_data": Per room per user tags. * "tag_account_data": Per room per user tags.
* "backfill": Old events that have been backfilled from other servers. * "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules.
The API takes two additional query parameters: The API takes two additional query parameters:
@ -117,14 +119,16 @@ class ReplicationResource(Resource):
def current_replication_token(self): def current_replication_token(self):
stream_token = yield self.sources.get_current_token() stream_token = yield self.sources.get_current_token()
backfill_token = yield self.store.get_current_backfill_token() backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
stream_token.room_stream_id, room_stream_token,
int(stream_token.presence_key), int(stream_token.presence_key),
int(stream_token.typing_key), int(stream_token.typing_key),
int(stream_token.receipt_key), int(stream_token.receipt_key),
int(stream_token.account_data_key), int(stream_token.account_data_key),
backfill_token, backfill_token,
push_rules_token,
)) ))
@request_handler @request_handler
@ -146,6 +150,7 @@ class ReplicationResource(Resource):
yield self.presence(writer, current_token) # TODO: implement limit yield self.presence(writer, current_token) # TODO: implement limit
yield self.typing(writer, current_token) # TODO: implement limit yield self.typing(writer, current_token) # TODO: implement limit
yield self.receipts(writer, current_token, limit) yield self.receipts(writer, current_token, limit)
yield self.push_rules(writer, current_token, limit)
self.streams(writer, current_token) self.streams(writer, current_token)
logger.info("Replicated %d rows", writer.total) logger.info("Replicated %d rows", writer.total)
@ -277,6 +282,21 @@ class ReplicationResource(Resource):
"position", "user_id", "room_id", "tags" "position", "user_id", "room_id", "tags"
)) ))
@defer.inlineCallbacks
def push_rules(self, writer, current_token, limit):
current_position = current_token.push_rules
push_rules = parse_integer(writer.request, "push_rules")
if push_rules is not None:
rows = yield self.store.get_all_push_rule_updates(
push_rules, current_position, limit
)
writer.write_header_and_rows("push_rules", rows, (
"position", "event_stream_ordering", "user_id", "rule_id", "op",
"priority_class", "priority", "conditions", "actions"
))
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
@ -307,12 +327,16 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules"
))): ))):
__slots__ = [] __slots__ = []
def __new__(cls, *args): def __new__(cls, *args):
if len(args) == 1: if len(args) == 1:
return cls(*(int(value) for value in args[0].split("_"))) streams = [int(value) for value in args[0].split("_")]
if len(streams) < len(cls._fields):
streams.extend([0] * (len(cls._fields) - len(streams)))
return cls(*streams)
else: else:
return super(_ReplicationToken, cls).__new__(cls, *args) return super(_ReplicationToken, cls).__new__(cls, *args)

View File

@ -22,12 +22,10 @@ from .base import ClientV1RestServlet, client_path_patterns
from synapse.storage.push_rule import ( from synapse.storage.push_rule import (
InconsistentRuleException, RuleNotFoundException InconsistentRuleException, RuleNotFoundException
) )
from synapse.push.baserules import list_with_base_rules, BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import ( from synapse.push.baserules import BASE_RULE_IDS
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP
)
import copy
import simplejson as json import simplejson as json
@ -36,6 +34,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash") "Unrecognised request: You probably wanted a trailing slash")
def __init__(self, hs):
super(PushRuleRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request): def on_PUT(self, request):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
@ -51,8 +54,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
content = _parse_json(request) content = _parse_json(request)
user_id = requester.user.to_string()
if 'attr' in spec: if 'attr' in spec:
yield self.set_rule_attr(requester.user.to_string(), spec, content) yield self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
if spec['rule_id'].startswith('.'): if spec['rule_id'].startswith('.'):
@ -77,8 +83,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
after = _namespaced_rule_id(spec, after[0]) after = _namespaced_rule_id(spec, after[0])
try: try:
yield self.hs.get_datastore().add_push_rule( yield self.store.add_push_rule(
user_id=requester.user.to_string(), user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec), rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class, priority_class=priority_class,
conditions=conditions, conditions=conditions,
@ -86,6 +92,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
before=before, before=before,
after=after after=after
) )
self.notify_user(user_id)
except InconsistentRuleException as e: except InconsistentRuleException as e:
raise SynapseError(400, e.message) raise SynapseError(400, e.message)
except RuleNotFoundException as e: except RuleNotFoundException as e:
@ -98,13 +105,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
spec = _rule_spec_from_path(request.postpath) spec = _rule_spec_from_path(request.postpath)
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try: try:
yield self.hs.get_datastore().delete_push_rule( yield self.store.delete_push_rule(
requester.user.to_string(), namespaced_rule_id user_id, namespaced_rule_id
) )
self.notify_user(user_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
@ -115,58 +124,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = requester.user user_id = requester.user.to_string()
# we build up the full structure and then decide which bits of it # we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is # to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference # is probably not going to make a whole lot of difference
rawrules = yield self.hs.get_datastore().get_push_rules_for_user( rawrules = yield self.store.get_push_rules_for_user(user_id)
user.to_string()
)
ruleslist = [] enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
rules = {'global': {}, 'device': {}}
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
enabled_map = yield self.hs.get_datastore().\
get_push_rules_enabled_for_user(user.to_string())
for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r['priority_class'])
# Remove internal stuff.
for c in r["conditions"]:
c.pop("_id", None)
pattern_type = c.pop("pattern_type", None)
if pattern_type == "user_id":
c["pattern"] = user.to_string()
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r)
if template_rule:
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' in r:
template_rule['enabled'] = r['enabled']
else:
template_rule['enabled'] = True
rulearray.append(template_rule)
path = request.postpath[1:] path = request.postpath[1:]
@ -188,6 +155,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
def notify_user(self, user_id):
stream_id, _ = self.store.get_push_rules_stream_token()
self.notifier.on_new_event(
"push_rules_key", stream_id, users=[user_id]
)
def set_rule_attr(self, user_id, spec, val): def set_rule_attr(self, user_id, spec, val):
if spec['attr'] == 'enabled': if spec['attr'] == 'enabled':
if isinstance(val, dict) and "enabled" in val: if isinstance(val, dict) and "enabled" in val:
@ -198,7 +171,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
# bools directly, so let's not break them. # bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean") raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
return self.hs.get_datastore().set_push_rule_enabled( return self.store.set_push_rule_enabled(
user_id, namespaced_rule_id, val user_id, namespaced_rule_id, val
) )
elif spec['attr'] == 'actions': elif spec['attr'] == 'actions':
@ -210,7 +183,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
if is_default_rule: if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS: if namespaced_rule_id not in BASE_RULE_IDS:
raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
return self.hs.get_datastore().set_push_rule_actions( return self.store.set_push_rule_actions(
user_id, namespaced_rule_id, actions, is_default_rule user_id, namespaced_rule_id, actions, is_default_rule
) )
else: else:
@ -308,12 +281,6 @@ def _check_actions(actions):
raise InvalidRuleException("Unrecognised action") raise InvalidRuleException("Unrecognised action")
def _add_empty_priority_class_arrays(d):
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _filter_ruleset_with_path(ruleset, path): def _filter_ruleset_with_path(ruleset, path):
if path == []: if path == []:
raise UnrecognizedRequestError( raise UnrecognizedRequestError(
@ -362,37 +329,6 @@ def _priority_class_from_spec(spec):
return pc return pc
def _priority_class_to_template_name(pc):
return PRIORITY_CLASS_INVERSE_MAP[pc]
def _rule_to_template(rule):
unscoped_rule_id = None
if 'rule_id' in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
template_name = _priority_class_to_template_name(rule['priority_class'])
if template_name in ['override', 'underride']:
templaterule = {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ["sender", "room"]:
templaterule = {'actions': rule['actions']}
unscoped_rule_id = rule['conditions'][0]['pattern']
elif template_name == 'content':
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
templaterule = {'actions': rule['actions']}
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule
def _namespaced_rule_id_from_spec(spec): def _namespaced_rule_id_from_spec(spec):
return _namespaced_rule_id(spec, spec['rule_id']) return _namespaced_rule_id(spec, spec['rule_id'])
@ -401,10 +337,6 @@ def _namespaced_rule_id(spec, rule_id):
return "global/%s/%s" % (spec['template'], rule_id) return "global/%s/%s" % (spec['template'], rule_id)
def _rule_id_from_namespaced(in_rule_id):
return in_rule_id.split('/')[-1]
class InvalidRuleException(Exception): class InvalidRuleException(Exception):
pass pass

View File

@ -45,7 +45,7 @@ from .search import SearchStore
from .tags import TagsStore from .tags import TagsStore
from .account_data import AccountDataStore from .account_data import AccountDataStore
from util.id_generators import IdGenerator, StreamIdGenerator from util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -122,6 +122,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id") self._pushers_id_gen = IdGenerator(db_conn, "pushers", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._push_rules_stream_id_gen = ChainedIdGenerator(
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
)
events_max = self._stream_id_gen.get_max_token() events_max = self._stream_id_gen.get_max_token()
event_cache_prefill, min_event_val = self._get_cache_dict( event_cache_prefill, min_event_val = self._get_cache_dict(
@ -157,6 +160,18 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=presence_cache_prefill prefilled_cache=presence_cache_prefill
) )
push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._push_rules_stream_id_gen.get_max_token()[0],
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id,
prefilled_cache=push_rules_prefill,
)
super(DataStore, self).__init__(hs) super(DataStore, self).__init__(hs)
def take_presence_startup_info(self): def take_presence_startup_info(self):

View File

@ -766,6 +766,19 @@ class SQLBaseStore(object):
"""Executes a DELETE query on the named table, expecting to delete a """Executes a DELETE query on the named table, expecting to delete a
single row. single row.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
return self.runInteraction(
desc, self._simple_delete_one_txn, table, keyvalues
)
@staticmethod
def _simple_delete_one_txn(txn, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
Args: Args:
table : string giving the table name table : string giving the table name
keyvalues : dict of column names and values to select the row with keyvalues : dict of column names and values to select the row with
@ -775,13 +788,11 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues) " AND ".join("%s = ?" % (k, ) for k in keyvalues)
) )
def func(txn): txn.execute(sql, keyvalues.values())
txn.execute(sql, keyvalues.values()) if txn.rowcount == 0:
if txn.rowcount == 0: raise StoreError(404, "No row found")
raise StoreError(404, "No row found") if txn.rowcount > 1:
if txn.rowcount > 1: raise StoreError(500, "more than one row matched")
raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func)
@staticmethod @staticmethod
def _simple_delete_txn(txn, table, keyvalues): def _simple_delete_txn(txn, table, keyvalues):

View File

@ -99,30 +99,32 @@ class PushRuleStore(SQLBaseStore):
results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled'] results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
defer.returnValue(results) defer.returnValue(results)
@defer.inlineCallbacks
def add_push_rule( def add_push_rule(
self, user_id, rule_id, priority_class, conditions, actions, self, user_id, rule_id, priority_class, conditions, actions,
before=None, after=None before=None, after=None
): ):
conditions_json = json.dumps(conditions) conditions_json = json.dumps(conditions)
actions_json = json.dumps(actions) actions_json = json.dumps(actions)
with self._push_rules_stream_id_gen.get_next() as ids:
if before or after: stream_id, event_stream_ordering = ids
return self.runInteraction( if before or after:
"_add_push_rule_relative_txn", yield self.runInteraction(
self._add_push_rule_relative_txn, "_add_push_rule_relative_txn",
user_id, rule_id, priority_class, self._add_push_rule_relative_txn,
conditions_json, actions_json, before, after, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
) conditions_json, actions_json, before, after,
else: )
return self.runInteraction( else:
"_add_push_rule_highest_priority_txn", yield self.runInteraction(
self._add_push_rule_highest_priority_txn, "_add_push_rule_highest_priority_txn",
user_id, rule_id, priority_class, self._add_push_rule_highest_priority_txn,
conditions_json, actions_json, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
) conditions_json, actions_json,
)
def _add_push_rule_relative_txn( def _add_push_rule_relative_txn(
self, txn, user_id, rule_id, priority_class, self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json, before, after conditions_json, actions_json, before, after
): ):
# Lock the table since otherwise we'll have annoying races between the # Lock the table since otherwise we'll have annoying races between the
@ -174,12 +176,12 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_id, priority_class, new_rule_priority)) txn.execute(sql, (user_id, priority_class, new_rule_priority))
self._upsert_push_rule_txn( self._upsert_push_rule_txn(
txn, user_id, rule_id, priority_class, new_rule_priority, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json, new_rule_priority, conditions_json, actions_json,
) )
def _add_push_rule_highest_priority_txn( def _add_push_rule_highest_priority_txn(
self, txn, user_id, rule_id, priority_class, self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json conditions_json, actions_json
): ):
# Lock the table since otherwise we'll have annoying races between the # Lock the table since otherwise we'll have annoying races between the
@ -201,13 +203,13 @@ class PushRuleStore(SQLBaseStore):
self._upsert_push_rule_txn( self._upsert_push_rule_txn(
txn, txn,
user_id, rule_id, priority_class, new_prio, stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio,
conditions_json, actions_json, conditions_json, actions_json,
) )
def _upsert_push_rule_txn( def _upsert_push_rule_txn(
self, txn, user_id, rule_id, priority_class, self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
priority, conditions_json, actions_json priority, conditions_json, actions_json, update_stream=True
): ):
"""Specialised version of _simple_upsert_txn that picks a push_rule_id """Specialised version of _simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes using the _push_rule_id_gen if it needs to insert the rule. It assumes
@ -242,12 +244,17 @@ class PushRuleStore(SQLBaseStore):
}, },
) )
txn.call_after( if update_stream:
self.get_push_rules_for_user.invalidate, (user_id,) self._insert_push_rules_update_txn(
) txn, stream_id, event_stream_ordering, user_id, rule_id,
txn.call_after( op="ADD",
self.get_push_rules_enabled_for_user.invalidate, (user_id,) data={
) "priority_class": priority_class,
"priority": priority,
"conditions": conditions_json,
"actions": actions_json,
}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_push_rule(self, user_id, rule_id): def delete_push_rule(self, user_id, rule_id):
@ -260,25 +267,37 @@ class PushRuleStore(SQLBaseStore):
user_id (str): The matrix ID of the push rule owner user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted rule_id (str): The rule_id of the rule to be deleted
""" """
yield self._simple_delete_one( def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
"push_rules", self._simple_delete_one_txn(
{'user_name': user_id, 'rule_id': rule_id}, txn,
desc="delete_push_rule", "push_rules",
) {'user_name': user_id, 'rule_id': rule_id},
)
self.get_push_rules_for_user.invalidate((user_id,)) self._insert_push_rules_update_txn(
self.get_push_rules_enabled_for_user.invalidate((user_id,)) txn, stream_id, event_stream_ordering, user_id, rule_id,
op="DELETE"
)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.runInteraction(
"delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering
)
@defer.inlineCallbacks @defer.inlineCallbacks
def set_push_rule_enabled(self, user_id, rule_id, enabled): def set_push_rule_enabled(self, user_id, rule_id, enabled):
ret = yield self.runInteraction( with self._push_rules_stream_id_gen.get_next() as ids:
"_set_push_rule_enabled_txn", stream_id, event_stream_ordering = ids
self._set_push_rule_enabled_txn, yield self.runInteraction(
user_id, rule_id, enabled "_set_push_rule_enabled_txn",
) self._set_push_rule_enabled_txn,
defer.returnValue(ret) stream_id, event_stream_ordering, user_id, rule_id, enabled
)
def _set_push_rule_enabled_txn(self, txn, user_id, rule_id, enabled): def _set_push_rule_enabled_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
):
new_id = self._push_rules_enable_id_gen.get_next() new_id = self._push_rules_enable_id_gen.get_next()
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
@ -287,25 +306,26 @@ class PushRuleStore(SQLBaseStore):
{'enabled': 1 if enabled else 0}, {'enabled': 1 if enabled else 0},
{'id': new_id}, {'id': new_id},
) )
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,) self._insert_push_rules_update_txn(
) txn, stream_id, event_stream_ordering, user_id, rule_id,
txn.call_after( op="ENABLE" if enabled else "DISABLE"
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
) )
@defer.inlineCallbacks
def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule): def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
actions_json = json.dumps(actions) actions_json = json.dumps(actions)
def set_push_rule_actions_txn(txn): def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
if is_default_rule: if is_default_rule:
# Add a dummy rule to the rules table with the user specified # Add a dummy rule to the rules table with the user specified
# actions. # actions.
priority_class = -1 priority_class = -1
priority = 1 priority = 1
self._upsert_push_rule_txn( self._upsert_push_rule_txn(
txn, user_id, rule_id, priority_class, priority, txn, stream_id, event_stream_ordering, user_id, rule_id,
"[]", actions_json priority_class, priority, "[]", actions_json,
update_stream=False
) )
else: else:
self._simple_update_one_txn( self._simple_update_one_txn(
@ -315,9 +335,80 @@ class PushRuleStore(SQLBaseStore):
{'actions': actions_json}, {'actions': actions_json},
) )
return self.runInteraction( self._insert_push_rules_update_txn(
"set_push_rule_actions", set_push_rule_actions_txn, txn, stream_id, event_stream_ordering, user_id, rule_id,
op="ACTIONS", data={"actions": actions_json}
)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.runInteraction(
"set_push_rule_actions", set_push_rule_actions_txn,
stream_id, event_stream_ordering
)
def _insert_push_rules_update_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
):
values = {
"stream_id": stream_id,
"event_stream_ordering": event_stream_ordering,
"user_id": user_id,
"rule_id": rule_id,
"op": op,
}
if data is not None:
values.update(data)
self._simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,)
) )
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
txn.call_after(
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
def get_all_push_rule_updates(self, last_id, current_id, limit):
"""Get all the push rules changes that have happend on the server"""
def get_all_push_rule_updates_txn(txn):
sql = (
"SELECT stream_id, event_stream_ordering, user_id, rule_id,"
" op, priority_class, priority, conditions, actions"
" FROM push_rules_stream"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)
def get_push_rules_stream_token(self):
"""Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_max_token()
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
else:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
" WHERE user_id = ? AND ? < stream_id"
)
txn.execute(sql, (user_id, last_id))
count, = txn.fetchone()
return bool(count)
return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
class RuleNotFoundException(Exception): class RuleNotFoundException(Exception):

View File

@ -0,0 +1,38 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE push_rules_stream(
stream_id BIGINT NOT NULL,
event_stream_ordering BIGINT NOT NULL,
user_id TEXT NOT NULL,
rule_id TEXT NOT NULL,
op TEXT NOT NULL, -- One of "ENABLE", "DISABLE", "ACTIONS", "ADD", "DELETE"
priority_class SMALLINT,
priority INTEGER,
conditions TEXT,
actions TEXT
);
-- The extra data for each operation is:
-- * ENABLE, DISABLE, DELETE: []
-- * ACTIONS: ["actions"]
-- * ADD: ["priority_class", "priority", "actions", "conditions"]
-- Index for replication queries.
CREATE INDEX push_rules_stream_id ON push_rules_stream(stream_id);
-- Index for /sync queries.
CREATE INDEX push_rules_stream_user_stream_id on push_rules_stream(user_id, stream_id);

View File

@ -20,23 +20,21 @@ import threading
class IdGenerator(object): class IdGenerator(object):
def __init__(self, db_conn, table, column): def __init__(self, db_conn, table, column):
self.table = table
self.column = column
self._lock = threading.Lock() self._lock = threading.Lock()
cur = db_conn.cursor() self._next_id = _load_max_id(db_conn, table, column)
self._next_id = self._load_next_id(cur)
cur.close()
def _load_next_id(self, txn):
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table,))
val, = txn.fetchone()
return val + 1 if val else 1
def get_next(self): def get_next(self):
with self._lock: with self._lock:
i = self._next_id
self._next_id += 1 self._next_id += 1
return i return self._next_id
def _load_max_id(db_conn, table, column):
cur = db_conn.cursor()
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
val, = cur.fetchone()
cur.close()
return val if val else 1
class StreamIdGenerator(object): class StreamIdGenerator(object):
@ -52,23 +50,10 @@ class StreamIdGenerator(object):
# ... persist event ... # ... persist event ...
""" """
def __init__(self, db_conn, table, column): def __init__(self, db_conn, table, column):
self.table = table
self.column = column
self._lock = threading.Lock() self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column)
cur = db_conn.cursor()
self._current_max = self._load_current_max(cur)
cur.close()
self._unfinished_ids = deque() self._unfinished_ids = deque()
def _load_current_max(self, txn):
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
rows = txn.fetchall()
val, = rows[0]
return int(val) if val else 1
def get_next(self): def get_next(self):
""" """
Usage: Usage:
@ -124,3 +109,50 @@ class StreamIdGenerator(object):
return self._unfinished_ids[0] - 1 return self._unfinished_ids[0] - 1
return self._current_max return self._current_max
class ChainedIdGenerator(object):
"""Used to generate new stream ids where the stream must be kept in sync
with another stream. It generates pairs of IDs, the first element is an
integer ID for this stream, the second element is the ID for the stream
that this stream needs to be kept in sync with."""
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column)
self._unfinished_ids = deque()
def get_next(self):
"""
Usage:
with stream_id_gen.get_next() as (stream_id, chained_id):
# ... persist event ...
"""
with self._lock:
self._current_max += 1
next_id = self._current_max
chained_id = self.chained_generator.get_max_token()
self._unfinished_ids.append((next_id, chained_id))
@contextlib.contextmanager
def manager():
try:
yield (next_id, chained_id)
finally:
with self._lock:
self._unfinished_ids.remove((next_id, chained_id))
return manager()
def get_max_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
with self._lock:
if self._unfinished_ids:
stream_id, chained_id = self._unfinished_ids[0]
return (stream_id - 1, chained_id)
return (self._current_max, self.chained_generator.get_max_token())

View File

@ -38,9 +38,12 @@ class EventSources(object):
name: cls(hs) name: cls(hs)
for name, cls in EventSources.SOURCE_TYPES.items() for name, cls in EventSources.SOURCE_TYPES.items()
} }
self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_token(self, direction='f'): def get_current_token(self, direction='f'):
push_rules_key, _ = self.store.get_push_rules_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
yield self.sources["room"].get_current_key(direction) yield self.sources["room"].get_current_key(direction)
@ -57,5 +60,6 @@ class EventSources(object):
account_data_key=( account_data_key=(
yield self.sources["account_data"].get_current_key() yield self.sources["account_data"].get_current_key()
), ),
push_rules_key=push_rules_key,
) )
defer.returnValue(token) defer.returnValue(token)

View File

@ -115,6 +115,7 @@ class StreamToken(
"typing_key", "typing_key",
"receipt_key", "receipt_key",
"account_data_key", "account_data_key",
"push_rules_key",
)) ))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -150,6 +151,7 @@ class StreamToken(
or (int(other.typing_key) < int(self.typing_key)) or (int(other.typing_key) < int(self.typing_key))
or (int(other.receipt_key) < int(self.receipt_key)) or (int(other.receipt_key) < int(self.receipt_key))
or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.account_data_key) < int(self.account_data_key))
or (int(other.push_rules_key) < int(self.push_rules_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):
@ -174,6 +176,11 @@ class StreamToken(
return StreamToken(**d) return StreamToken(**d)
StreamToken.START = StreamToken(
*(["s0"] + ["0"] * (len(StreamToken._fields) - 1))
)
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")): class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
"""Tokens are positions between events. The token "s1" comes after event 1. """Tokens are positions between events. The token "s1" comes after event 1.

View File

@ -35,7 +35,8 @@ class ReplicationResourceCase(unittest.TestCase):
"send_message", "send_message",
]), ]),
) )
self.user = UserID.from_string("@seeing:red") self.user_id = "@seeing:red"
self.user = UserID.from_string(self.user_id)
self.hs.get_ratelimiter().send_message.return_value = (True, 0) self.hs.get_ratelimiter().send_message.return_value = (True, 0)
@ -101,7 +102,7 @@ class ReplicationResourceCase(unittest.TestCase):
event_id = yield self.send_text_message(room_id, "Hello, World") event_id = yield self.send_text_message(room_id, "Hello, World")
get = self.get(receipts="-1") get = self.get(receipts="-1")
yield self.hs.get_handlers().receipts_handler.received_client_receipt( yield self.hs.get_handlers().receipts_handler.received_client_receipt(
room_id, "m.read", self.user.to_string(), event_id room_id, "m.read", self.user_id, event_id
) )
code, body = yield get code, body = yield get
self.assertEquals(code, 200) self.assertEquals(code, 200)
@ -129,6 +130,7 @@ class ReplicationResourceCase(unittest.TestCase):
test_timeout_room_account_data = _test_timeout("room_account_data") test_timeout_room_account_data = _test_timeout("room_account_data")
test_timeout_tag_account_data = _test_timeout("tag_account_data") test_timeout_tag_account_data = _test_timeout("tag_account_data")
test_timeout_backfill = _test_timeout("backfill") test_timeout_backfill = _test_timeout("backfill")
test_timeout_push_rules = _test_timeout("push_rules")
@defer.inlineCallbacks @defer.inlineCallbacks
def send_text_message(self, room_id, message): def send_text_message(self, room_id, message):

View File

@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_topo_token_is_accepted(self): def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0" token = "t1-0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self): def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0" token = "s0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))