diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py index e0d039518..15b7898a4 100644 --- a/synapse/replication/resource.py +++ b/synapse/replication/resource.py @@ -36,6 +36,7 @@ STREAM_NAMES = ( ("receipts",), ("user_account_data", "room_account_data", "tag_account_data",), ("backfill",), + ("push_rules",), ) @@ -63,6 +64,7 @@ class ReplicationResource(Resource): * "room_account_data: Per room per user account data. * "tag_account_data": Per room per user tags. * "backfill": Old events that have been backfilled from other servers. + * "push_rules": Per user changes to push rules. The API takes two additional query parameters: @@ -117,14 +119,16 @@ class ReplicationResource(Resource): def current_replication_token(self): stream_token = yield self.sources.get_current_token() backfill_token = yield self.store.get_current_backfill_token() + push_rules_token, room_stream_token = self.store.get_push_rules_stream_token() defer.returnValue(_ReplicationToken( - stream_token.room_stream_id, + room_stream_token, int(stream_token.presence_key), int(stream_token.typing_key), int(stream_token.receipt_key), int(stream_token.account_data_key), backfill_token, + push_rules_token, )) @request_handler @@ -146,6 +150,7 @@ class ReplicationResource(Resource): yield self.presence(writer, current_token) # TODO: implement limit yield self.typing(writer, current_token) # TODO: implement limit yield self.receipts(writer, current_token, limit) + yield self.push_rules(writer, current_token, limit) self.streams(writer, current_token) logger.info("Replicated %d rows", writer.total) @@ -277,6 +282,21 @@ class ReplicationResource(Resource): "position", "user_id", "room_id", "tags" )) + @defer.inlineCallbacks + def push_rules(self, writer, current_token, limit): + current_position = current_token.push_rules + + push_rules = parse_integer(writer.request, "push_rules") + + if push_rules is not None: + rows = yield self.store.get_all_push_rule_updates( + push_rules, current_position, limit + ) + writer.write_header_and_rows("push_rules", rows, ( + "position", "stream_ordering", "user_id", "rule_id", "op", + "priority_class", "priority", "conditions", "actions" + )) + class _Writer(object): """Writes the streams as a JSON object as the response to the request""" @@ -307,12 +327,16 @@ class _Writer(object): class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( "events", "presence", "typing", "receipts", "account_data", "backfill", + "push_rules" ))): __slots__ = [] def __new__(cls, *args): if len(args) == 1: - return cls(*(int(value) for value in args[0].split("_"))) + streams = [int(value) for value in args[0].split("_")] + if len(streams) < len(cls._fields): + streams.extend([0] * (len(cls._fields) - len(streams))) + return cls(*streams) else: return super(_ReplicationToken, cls).__new__(cls, *args) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index f3ebd4949..e03402410 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -412,6 +412,12 @@ class PushRuleStore(SQLBaseStore): "get_all_push_rule_updates", get_all_push_rule_updates_txn ) + def get_push_rules_stream_token(self): + """Get the position of the push rules stream. + Returns a pair of a stream id for the push_rules stream and the + room stream ordering it corresponds to.""" + return self._push_rules_stream_id_gen.get_max_token() + class RuleNotFoundException(Exception): pass diff --git a/tests/replication/test_resource.py b/tests/replication/test_resource.py index 38daaf87e..a30d59a86 100644 --- a/tests/replication/test_resource.py +++ b/tests/replication/test_resource.py @@ -35,7 +35,8 @@ class ReplicationResourceCase(unittest.TestCase): "send_message", ]), ) - self.user = UserID.from_string("@seeing:red") + self.user_id = "@seeing:red" + self.user = UserID.from_string(self.user_id) self.hs.get_ratelimiter().send_message.return_value = (True, 0) @@ -101,7 +102,7 @@ class ReplicationResourceCase(unittest.TestCase): event_id = yield self.send_text_message(room_id, "Hello, World") get = self.get(receipts="-1") yield self.hs.get_handlers().receipts_handler.received_client_receipt( - room_id, "m.read", self.user.to_string(), event_id + room_id, "m.read", self.user_id, event_id ) code, body = yield get self.assertEquals(code, 200) @@ -129,6 +130,7 @@ class ReplicationResourceCase(unittest.TestCase): test_timeout_room_account_data = _test_timeout("room_account_data") test_timeout_tag_account_data = _test_timeout("tag_account_data") test_timeout_backfill = _test_timeout("backfill") + test_timeout_push_rules = _test_timeout("push_rules") @defer.inlineCallbacks def send_text_message(self, room_id, message):