Merge pull request #335 from matrix-org/markjh/room_tags

Add APIs for adding and removing tags from rooms
This commit is contained in:
Mark Haines 2015-11-03 16:45:53 +00:00
commit 6797fcd9ab
14 changed files with 525 additions and 26 deletions

View File

@ -147,6 +147,10 @@ class FilterCollection(object):
self.filter_json.get("room", {}).get("ephemeral", {}) self.filter_json.get("room", {}).get("ephemeral", {})
) )
self.room_private_user_data = Filter(
self.filter_json.get("room", {}).get("private_user_data", {})
)
self.presence_filter = Filter( self.presence_filter = Filter(
self.filter_json.get("presence", {}) self.filter_json.get("presence", {})
) )
@ -172,6 +176,9 @@ class FilterCollection(object):
def filter_room_ephemeral(self, events): def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(events) return self.room_ephemeral_filter.filter(events)
def filter_room_private_user_data(self, events):
return self.room_private_user_data.filter(events)
class Filter(object): class Filter(object):
def __init__(self, filter_json): def __init__(self, filter_json):

View File

@ -322,6 +322,8 @@ class MessageHandler(BaseHandler):
user, pagination_config.get_source_config("receipt"), None user, pagination_config.get_source_config("receipt"), None
) )
tags_by_room = yield self.store.get_tags_for_user(user_id)
public_room_ids = yield self.store.get_public_room_ids() public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit limit = pagin_config.limit
@ -398,6 +400,15 @@ class MessageHandler(BaseHandler):
serialize_event(c, time_now, as_client_event) serialize_event(c, time_now, as_client_event)
for c in current_state.values() for c in current_state.values()
] ]
private_user_data = []
tags = tags_by_room.get(event.room_id)
if tags:
private_user_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
d["private_user_data"] = private_user_data
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
@ -447,6 +458,16 @@ class MessageHandler(BaseHandler):
result = yield self._room_initial_sync_parted( result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, member_event user_id, room_id, pagin_config, member_event
) )
private_user_data = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
private_user_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
result["private_user_data"] = private_user_data
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -476,8 +497,8 @@ class MessageHandler(BaseHandler):
user_id, messages user_id, messages
) )
start_token = StreamToken(token[0], 0, 0, 0) start_token = StreamToken(token[0], 0, 0, 0, 0)
end_token = StreamToken(token[1], 0, 0, 0) end_token = StreamToken(token[1], 0, 0, 0, 0)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View File

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
class PrivateUserDataEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
def get_current_key(self, direction='f'):
return self.store.get_max_private_user_data_stream_id()
@defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit):
user_id = user.to_string()
last_stream_id = from_key
current_stream_id = yield self.store.get_max_private_user_data_stream_id()
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
results = []
for room_id, room_tags in tags.items():
results.append({
"type": "m.tag",
"content": {"tags": room_tags},
"room_id": room_id,
})
defer.returnValue((results, current_stream_id))
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
defer.returnValue(([], config.to_id))

View File

@ -51,6 +51,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"timeline", "timeline",
"state", "state",
"ephemeral", "ephemeral",
"private_user_data",
])): ])):
__slots__ = [] __slots__ = []
@ -58,13 +59,19 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"""Make the result appear empty if there are no updates. This is used """Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result. to tell if room needs to be part of the sync result.
""" """
return bool(self.timeline or self.state or self.ephemeral) return bool(
self.timeline
or self.state
or self.ephemeral
or self.private_user_data
)
class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id", "room_id",
"timeline", "timeline",
"state", "state",
"private_user_data",
])): ])):
__slots__ = [] __slots__ = []
@ -72,7 +79,11 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"""Make the result appear empty if there are no updates. This is used """Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result. to tell if room needs to be part of the sync result.
""" """
return bool(self.timeline or self.state) return bool(
self.timeline
or self.state
or self.private_user_data
)
class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
@ -197,6 +208,10 @@ class SyncHandler(BaseHandler):
) )
) )
tags_by_room = yield self.store.get_tags_for_user(
sync_config.user.to_string()
)
joined = [] joined = []
invited = [] invited = []
archived = [] archived = []
@ -207,7 +222,8 @@ class SyncHandler(BaseHandler):
sync_config=sync_config, sync_config=sync_config,
now_token=now_token, now_token=now_token,
timeline_since_token=timeline_since_token, timeline_since_token=timeline_since_token,
typing_by_room=typing_by_room typing_by_room=typing_by_room,
tags_by_room=tags_by_room,
) )
joined.append(room_sync) joined.append(room_sync)
elif event.membership == Membership.INVITE: elif event.membership == Membership.INVITE:
@ -226,6 +242,7 @@ class SyncHandler(BaseHandler):
leave_event_id=event.event_id, leave_event_id=event.event_id,
leave_token=leave_token, leave_token=leave_token,
timeline_since_token=timeline_since_token, timeline_since_token=timeline_since_token,
tags_by_room=tags_by_room,
) )
archived.append(room_sync) archived.append(room_sync)
@ -240,7 +257,7 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def full_state_sync_for_joined_room(self, room_id, sync_config, def full_state_sync_for_joined_room(self, room_id, sync_config,
now_token, timeline_since_token, now_token, timeline_since_token,
typing_by_room): typing_by_room, tags_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred JoinedSyncResult.
@ -260,8 +277,21 @@ class SyncHandler(BaseHandler):
timeline=batch, timeline=batch,
state=current_state_events, state=current_state_events,
ephemeral=typing_by_room.get(room_id, []), ephemeral=typing_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
)) ))
def private_user_data_for_room(self, room_id, tags_by_room):
private_user_data = []
tags = tags_by_room.get(room_id)
if tags:
private_user_data.append({
"type": "m.tag",
"content": {"tags": tags},
})
return private_user_data
@defer.inlineCallbacks @defer.inlineCallbacks
def typing_by_room(self, sync_config, now_token, since_token=None): def typing_by_room(self, sync_config, now_token, since_token=None):
"""Get the typing events for each room the user is in """Get the typing events for each room the user is in
@ -296,7 +326,7 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config, def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token, leave_event_id, leave_token,
timeline_since_token): timeline_since_token, tags_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred JoinedSyncResult.
@ -314,6 +344,9 @@ class SyncHandler(BaseHandler):
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=leave_state[leave_event_id].values(), state=leave_state[leave_event_id].values(),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
)) ))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -359,6 +392,11 @@ class SyncHandler(BaseHandler):
limit=timeline_limit + 1, limit=timeline_limit + 1,
) )
tags_by_room = yield self.store.get_updated_tags(
sync_config.user.to_string(),
since_token.private_user_data_key,
)
joined = [] joined = []
archived = [] archived = []
if len(room_events) <= timeline_limit: if len(room_events) <= timeline_limit:
@ -399,7 +437,10 @@ class SyncHandler(BaseHandler):
limited=limited, limited=limited,
), ),
state=state, state=state,
ephemeral=typing_by_room.get(room_id, []) ephemeral=typing_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
) )
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
@ -416,14 +457,14 @@ class SyncHandler(BaseHandler):
for room_id in joined_room_ids: for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room( room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token, room_id, sync_config, since_token, now_token,
typing_by_room typing_by_room, tags_by_room
) )
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
for leave_event in leave_events: for leave_event in leave_events:
room_sync = yield self.incremental_sync_for_archived_room( room_sync = yield self.incremental_sync_for_archived_room(
sync_config, leave_event, since_token sync_config, leave_event, since_token, tags_by_room
) )
archived.append(room_sync) archived.append(room_sync)
@ -487,7 +528,7 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_with_gap_for_room(self, room_id, sync_config, def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token, since_token, now_token,
typing_by_room): typing_by_room, tags_by_room):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to the room. Gives the client the most recent events and the changes to
state. state.
@ -528,7 +569,10 @@ class SyncHandler(BaseHandler):
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state_events_delta,
ephemeral=typing_by_room.get(room_id, []) ephemeral=typing_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room(
room_id, tags_by_room
),
) )
logging.debug("Room sync: %r", room_sync) logging.debug("Room sync: %r", room_sync)
@ -537,7 +581,7 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event, def incremental_sync_for_archived_room(self, sync_config, leave_event,
since_token): since_token, tags_by_room):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the archived room. the archived room.
Returns: Returns:
@ -578,6 +622,9 @@ class SyncHandler(BaseHandler):
room_id=leave_event.room_id, room_id=leave_event.room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state_events_delta,
private_user_data=self.private_user_data_for_room(
leave_event.room_id, tags_by_room
),
) )
logging.debug("Room sync: %r", room_sync) logging.debug("Room sync: %r", room_sync)

View File

@ -270,7 +270,7 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, rooms, timeout, callback, def wait_for_events(self, user, rooms, timeout, callback,
from_token=StreamToken("s0", "0", "0", "0")): from_token=StreamToken("s0", "0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """

View File

@ -22,6 +22,7 @@ from . import (
receipts, receipts,
keys, keys,
tokenrefresh, tokenrefresh,
tags,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -44,3 +45,4 @@ class ClientV2AlphaRestResource(JsonResource):
receipts.register_servlets(hs, client_resource) receipts.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource)
tokenrefresh.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource)
tags.register_servlets(hs, client_resource)

View File

@ -220,6 +220,10 @@ class SyncRestServlet(RestServlet):
) )
timeline_event_ids.append(event.event_id) timeline_event_ids.append(event.event_id)
private_user_data = filter.filter_room_private_user_data(
room.private_user_data
)
result = { result = {
"event_map": event_map, "event_map": event_map,
"timeline": { "timeline": {
@ -228,6 +232,7 @@ class SyncRestServlet(RestServlet):
"limited": room.timeline.limited, "limited": room.timeline.limited,
}, },
"state": {"events": state_event_ids}, "state": {"events": state_event_ids},
"private_user_data": {"events": private_user_data},
} }
if joined: if joined:

View File

@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import client_v2_pattern
from synapse.http.servlet import RestServlet
from synapse.api.errors import AuthError, SynapseError
from twisted.internet import defer
import logging
import simplejson as json
logger = logging.getLogger(__name__)
class TagListServlet(RestServlet):
"""
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
"""
PATTERN = client_v2_pattern(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
)
def __init__(self, hs):
super(TagListServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_GET(self, request, user_id, room_id):
auth_user, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
tags = yield self.store.get_tags_for_room(user_id, room_id)
defer.returnValue((200, {"tags": tags}))
class TagServlet(RestServlet):
"""
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
"""
PATTERN = client_v2_pattern(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
def __init__(self, hs):
super(TagServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, tag):
auth_user, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
try:
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid tag JSON")
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_DELETE(self, request, user_id, room_id, tag):
auth_user, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
TagListServlet(hs).register(http_server)
TagServlet(hs).register(http_server)

View File

@ -41,6 +41,7 @@ from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore from .receipts import ReceiptsStore
from .search import SearchStore from .search import SearchStore
from .tags import TagsStore
import logging import logging
@ -71,6 +72,7 @@ class DataStore(RoomMemberStore, RoomStore,
ReceiptsStore, ReceiptsStore,
EndToEndKeyStore, EndToEndKeyStore,
SearchStore, SearchStore,
TagsStore,
): ):
def __init__(self, hs): def __init__(self, hs):

View File

@ -0,0 +1,38 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS room_tags(
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
tag TEXT NOT NULL, -- The name of the tag.
content TEXT NOT NULL, -- The JSON content of the tag.
CONSTRAINT room_tag_uniqueness UNIQUE (user_id, room_id, tag)
);
CREATE TABLE IF NOT EXISTS room_tags_revisions (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
stream_id BIGINT NOT NULL, -- The current version of the room tags.
CONSTRAINT room_tag_revisions_uniqueness UNIQUE (user_id, room_id)
);
CREATE TABLE IF NOT EXISTS private_user_data_max_stream_id(
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_id BIGINT NOT NULL,
CHECK (Lock='X')
);
INSERT INTO private_user_data_max_stream_id (stream_id) VALUES (0);

216
synapse/storage/tags.py Normal file
View File

@ -0,0 +1,216 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
from .util.id_generators import StreamIdGenerator
import ujson as json
import logging
logger = logging.getLogger(__name__)
class TagsStore(SQLBaseStore):
def __init__(self, hs):
super(TagsStore, self).__init__(hs)
self._private_user_data_id_gen = StreamIdGenerator(
"private_user_data_max_stream_id", "stream_id"
)
def get_max_private_user_data_stream_id(self):
"""Get the current max stream id for the private user data stream
Returns:
A deferred int.
"""
return self._private_user_data_id_gen.get_max_token(self)
@cached()
def get_tags_for_user(self, user_id):
"""Get all the tags for a user.
Args:
user_id(str): The user to get the tags for.
Returns:
A deferred dict mapping from room_id strings to lists of tag
strings.
"""
deferred = self._simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
@deferred.addCallback
def tags_by_room(rows):
tags_by_room = {}
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = json.loads(row["content"])
return tags_by_room
return deferred
@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
"""Get all the tags for the rooms where the tags have changed since the
given version
Args:
user_id(str): The user to get the tags for.
stream_id(int): The earliest update to get for the user.
Returns:
A deferred dict mapping from room_id strings to lists of tag
strings for all the rooms that changed since the stream_id token.
"""
def get_updated_tags_txn(txn):
sql = (
"SELECT room_id from room_tags_revisions"
" WHERE user_id = ? AND stream_id > ?"
)
txn.execute(sql, (user_id, stream_id))
room_ids = [row[0] for row in txn.fetchall()]
return room_ids
room_ids = yield self.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
results = {}
if room_ids:
tags_by_room = yield self.get_tags_for_user(user_id)
for room_id in room_ids:
results[room_id] = tags_by_room[room_id]
defer.returnValue(results)
def get_tags_for_room(self, user_id, room_id):
"""Get all the tags for the given room
Args:
user_id(str): The user to get tags for
room_id(str): The room to get tags for
Returns:
A deferred list of string tags.
"""
return self._simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
).addCallback(lambda rows: {
row["tag"]: json.loads(row["content"]) for row in rows
})
@defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content):
"""Add a tag to a room for a user.
Args:
user_id(str): The user to add a tag for.
room_id(str): The room to add a tag for.
tag(str): The tag name to add.
content(dict): A json object to associate with the tag.
Returns:
A deferred that completes once the tag has been added.
"""
content_json = json.dumps(content)
def add_tag_txn(txn, next_id):
self._simple_upsert_txn(
txn,
table="room_tags",
keyvalues={
"user_id": user_id,
"room_id": room_id,
"tag": tag,
},
values={
"content": content_json,
}
)
self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self)
defer.returnValue(result)
@defer.inlineCallbacks
def remove_tag_from_room(self, user_id, room_id, tag):
"""Remove a tag from a room for a user.
Returns:
A deferred that completes once the tag has been removed
"""
def remove_tag_txn(txn, next_id):
sql = (
"DELETE FROM room_tags "
" WHERE user_id = ? AND room_id = ? AND tag = ?"
)
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._private_user_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self)
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):
"""Update the latest revision of the tags for the given user and room.
Args:
txn: The database cursor
user_id(str): The ID of the user.
room_id(str): The ID of the room.
next_id(int): The the revision to advance to.
"""
update_max_id_sql = (
"UPDATE private_user_data_max_stream_id"
" SET stream_id = ?"
" WHERE stream_id < ?"
)
txn.execute(update_max_id_sql, (next_id, next_id))
update_sql = (
"UPDATE room_tags_revisions"
" SET stream_id = ?"
" WHERE user_id = ?"
" AND room_id = ?"
)
txn.execute(update_sql, (next_id, user_id, room_id))
if txn.rowcount == 0:
insert_sql = (
"INSERT INTO room_tags_revisions (user_id, room_id, stream_id)"
" VALUES (?, ?, ?)"
)
try:
txn.execute(insert_sql, (user_id, room_id, next_id))
except self.database_engine.module.IntegrityError:
# Ignore insertion errors. It doesn't matter if the row wasn't
# inserted because if two updates happend concurrently the one
# with the higher stream_id will not be reported to a client
# unless the previous update has completed. It doesn't matter
# which stream_id ends up in the table, as long as it is higher
# than the id that the client has.
pass

View File

@ -21,6 +21,7 @@ from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.room import RoomEventSource from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.private_user_data import PrivateUserDataEventSource
class EventSources(object): class EventSources(object):
@ -29,6 +30,7 @@ class EventSources(object):
"presence": PresenceEventSource, "presence": PresenceEventSource,
"typing": TypingNotificationEventSource, "typing": TypingNotificationEventSource,
"receipt": ReceiptEventSource, "receipt": ReceiptEventSource,
"private_user_data": PrivateUserDataEventSource,
} }
def __init__(self, hs): def __init__(self, hs):
@ -52,5 +54,8 @@ class EventSources(object):
receipt_key=( receipt_key=(
yield self.sources["receipt"].get_current_key() yield self.sources["receipt"].get_current_key()
), ),
private_user_data_key=(
yield self.sources["private_user_data"].get_current_key()
),
) )
defer.returnValue(token) defer.returnValue(token)

View File

@ -98,10 +98,13 @@ class EventID(DomainSpecificString):
class StreamToken( class StreamToken(
namedtuple( namedtuple("Token", (
"Token", "room_key",
("room_key", "presence_key", "typing_key", "receipt_key") "presence_key",
) "typing_key",
"receipt_key",
"private_user_data_key",
))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -109,7 +112,7 @@ class StreamToken(
def from_string(cls, string): def from_string(cls, string):
try: try:
keys = string.split(cls._SEPARATOR) keys = string.split(cls._SEPARATOR)
if len(keys) == len(cls._fields) - 1: while len(keys) < len(cls._fields):
# i.e. old token from before receipt_key # i.e. old token from before receipt_key
keys.append("0") keys.append("0")
return cls(*keys) return cls(*keys)
@ -128,13 +131,14 @@ class StreamToken(
else: else:
return int(self.room_key[1:].split("-")[-1]) return int(self.room_key[1:].split("-")[-1])
def is_after(self, other_token): def is_after(self, other):
"""Does this token contain events that the other doesn't?""" """Does this token contain events that the other doesn't?"""
return ( return (
(other_token.room_stream_id < self.room_stream_id) (other.room_stream_id < self.room_stream_id)
or (int(other_token.presence_key) < int(self.presence_key)) or (int(other.presence_key) < int(self.presence_key))
or (int(other_token.typing_key) < int(self.typing_key)) or (int(other.typing_key) < int(self.typing_key))
or (int(other_token.receipt_key) < int(self.receipt_key)) or (int(other.receipt_key) < int(self.receipt_key))
or (int(other.private_user_data_key) < int(self.private_user_data_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):

View File

@ -369,7 +369,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
# all be ours # all be ours
# I'll already get my own presence state change # I'll already get my own presence state change
self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []}, self.assertEquals({"start": "0_1_0_0_0", "end": "0_1_0_0_0", "chunk": []},
response response
) )
@ -388,7 +388,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
"/events?from=s0_1_0&timeout=0", None) "/events?from=s0_1_0&timeout=0", None)
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [ self.assertEquals({"start": "s0_1_0_0_0", "end": "s0_2_0_0_0", "chunk": [
{"type": "m.presence", {"type": "m.presence",
"content": { "content": {
"user_id": "@banana:test", "user_id": "@banana:test",