diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py index 83e880fdd..bb2c40b6e 100644 --- a/synapse/replication/slave/storage/push_rule.py +++ b/synapse/replication/slave/storage/push_rule.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,29 +16,15 @@ from .events import SlavedEventStore from ._slaved_id_tracker import SlavedIdTracker -from synapse.storage import DataStore -from synapse.storage.push_rule import PushRuleStore -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.storage.push_rule import PushRulesWorkerStore -class SlavedPushRuleStore(SlavedEventStore): +class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore): def __init__(self, db_conn, hs): - super(SlavedPushRuleStore, self).__init__(db_conn, hs) self._push_rules_stream_id_gen = SlavedIdTracker( db_conn, "push_rules_stream", "stream_id", ) - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", - self._push_rules_stream_id_gen.get_current_token(), - ) - - get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"] - get_push_rules_enabled_for_user = ( - PushRuleStore.__dict__["get_push_rules_enabled_for_user"] - ) - have_push_rules_changed_for_user = ( - DataStore.have_push_rules_changed_for_user.__func__ - ) + super(SlavedPushRuleStore, self).__init__(db_conn, hs) def get_push_rules_stream_token(self): return ( @@ -45,6 +32,9 @@ class SlavedPushRuleStore(SlavedEventStore): self._stream_id_gen.get_current_token(), ) + def get_max_push_rules_stream_id(self): + return self._push_rules_stream_id_gen.get_current_token() + def stream_positions(self): result = super(SlavedPushRuleStore, self).stream_positions() result["push_rules"] = self._push_rules_stream_id_gen.get_current_token() diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e221284ee..0f136f8a0 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -169,18 +170,6 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=presence_cache_prefill ) - push_rules_prefill, push_rules_id = self._get_cache_dict( - db_conn, "push_rules_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self._push_rules_stream_id_gen.get_current_token()[0], - ) - - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", push_rules_id, - prefilled_cache=push_rules_prefill, - ) - max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( db_conn, "device_inbox", diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 8758b1c0c..583efb7bd 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,10 +16,12 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.push.baserules import list_with_base_rules from synapse.api.constants import EventTypes from twisted.internet import defer +import abc import logging import simplejson as json @@ -48,7 +51,39 @@ def _load_rules(rawrules, enabled_map): return rules -class PushRuleStore(SQLBaseStore): +class PushRulesWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_push_rules_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(db_conn, hs) + + push_rules_prefill, push_rules_id = self._get_cache_dict( + db_conn, "push_rules_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self.get_max_push_rules_stream_id(), + ) + + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", push_rules_id, + prefilled_cache=push_rules_prefill, + ) + + @abc.abstractmethod + def get_max_push_rules_stream_id(self): + """Get the position of the push rules stream. + + Returns: + int + """ + raise NotImplementedError() + @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( @@ -89,6 +124,24 @@ class PushRuleStore(SQLBaseStore): r['rule_id']: False if r['enabled'] == 0 else True for r in results }) + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + return defer.succeed(False) + else: + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + count, = txn.fetchone() + return bool(count) + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + + +class PushRuleStore(PushRulesWorkerStore): @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, inlineCallbacks=True) def bulk_get_push_rules(self, user_ids): @@ -526,21 +579,8 @@ class PushRuleStore(SQLBaseStore): room stream ordering it corresponds to.""" return self._push_rules_stream_id_gen.get_current_token() - def have_push_rules_changed_for_user(self, user_id, last_id): - if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) - else: - def have_push_rules_changed_txn(txn): - sql = ( - "SELECT COUNT(stream_id) FROM push_rules_stream" - " WHERE user_id = ? AND ? < stream_id" - ) - txn.execute(sql, (user_id, last_id)) - count, = txn.fetchone() - return bool(count) - return self.runInteraction( - "have_push_rules_changed", have_push_rules_changed_txn - ) + def get_max_push_rules_stream_id(self): + return self.get_push_rules_stream_token()[0] class RuleNotFoundException(Exception):