Inform the client of new room tags using v1 /events

This commit is contained in:
Mark Haines 2015-10-29 15:20:52 +00:00
parent a89b86dc47
commit f40b0ed5e1
6 changed files with 91 additions and 14 deletions

View File

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
class PrivateUserDataEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
def get_current_key(self, direction='f'):
return self.store.get_max_private_user_data_stream_id()
@defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit):
user_id = user.to_string()
last_stream_id = from_key
current_stream_id = yield self.store.get_max_private_user_data_stream_id()
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
results = []
for room_id, room_tags in tags.items():
results.append({
"type": "m.tag",
"content": {"tags": room_tags},
"room_id": room_id,
})
defer.returnValue((results, current_stream_id))
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
defer.returnValue(([], config.to_id))

View File

@ -270,7 +270,7 @@ class Notifier(object):
@defer.inlineCallbacks
def wait_for_events(self, user, rooms, timeout, callback,
from_token=StreamToken("s0", "0", "0", "0")):
from_token=StreamToken("s0", "0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""

View File

@ -62,6 +62,7 @@ class TagServlet(RestServlet):
super(TagServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, tag):
@ -69,9 +70,12 @@ class TagServlet(RestServlet):
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
yield self.store.add_tag_to_room(user_id, room_id, tag)
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag)
yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id]
)
# TODO: poke the notifier.
defer.returnValue((200, {}))
@defer.inlineCallbacks
@ -80,7 +84,11 @@ class TagServlet(RestServlet):
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
yield self.store.remove_tag_from_room(user_id, room_id, tag)
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id]
)
# TODO: poke the notifier.
defer.returnValue((200, {}))

View File

@ -31,6 +31,14 @@ class TagsStore(SQLBaseStore):
"private_user_data_max_stream_id", "stream_id"
)
def get_max_private_user_data_stream_id(self):
"""Get the current max stream id for the private user data stream
Returns:
A deferred int.
"""
return self._private_user_data_id_gen.get_max_token(self)
@cached()
def get_tags_for_user(self, user_id):
"""Get all the tags for a user.
@ -83,7 +91,7 @@ class TagsStore(SQLBaseStore):
results = {}
if room_ids:
tags_by_room = yield self.get_tags_for_user(self, user_id)
tags_by_room = yield self.get_tags_for_user(user_id)
for room_id in room_ids:
results[room_id] = tags_by_room[room_id]
@ -129,6 +137,9 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self)
defer.returnValue(result)
@defer.inlineCallbacks
def remove_tag_from_room(self, user_id, room_id, tag):
"""Remove a tag from a room for a user.
@ -148,6 +159,9 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self)
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):
"""Update the latest revision of the tags for the given user and room.

View File

@ -21,6 +21,7 @@ from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.private_user_data import PrivateUserDataEventSource
class EventSources(object):
@ -29,6 +30,7 @@ class EventSources(object):
"presence": PresenceEventSource,
"typing": TypingNotificationEventSource,
"receipt": ReceiptEventSource,
"private_user_data": PrivateUserDataEventSource,
}
def __init__(self, hs):
@ -52,5 +54,8 @@ class EventSources(object):
receipt_key=(
yield self.sources["receipt"].get_current_key()
),
private_user_data_key=(
yield self.sources["private_user_data"].get_current_key()
),
)
defer.returnValue(token)

View File

@ -98,10 +98,13 @@ class EventID(DomainSpecificString):
class StreamToken(
namedtuple(
"Token",
("room_key", "presence_key", "typing_key", "receipt_key")
)
namedtuple("Token", (
"room_key",
"presence_key",
"typing_key",
"receipt_key",
"private_user_data_key",
))
):
_SEPARATOR = "_"
@ -128,13 +131,14 @@ class StreamToken(
else:
return int(self.room_key[1:].split("-")[-1])
def is_after(self, other_token):
def is_after(self, other):
"""Does this token contain events that the other doesn't?"""
return (
(other_token.room_stream_id < self.room_stream_id)
or (int(other_token.presence_key) < int(self.presence_key))
or (int(other_token.typing_key) < int(self.typing_key))
or (int(other_token.receipt_key) < int(self.receipt_key))
(other.room_stream_id < self.room_stream_id)
or (int(other.presence_key) < int(self.presence_key))
or (int(other.typing_key) < int(self.typing_key))
or (int(other.receipt_key) < int(self.receipt_key))
or (int(other.private_user_data_key) < int(self.private_user_data_key))
)
def copy_and_advance(self, key, new_value):