From 4dcaa42b6d1fde87e29eb4f3c0080ea92fcc7fa2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 30 Nov 2015 17:45:31 +0000 Subject: [PATCH 01/10] Allow paginating search ordered by recents --- synapse/handlers/search.py | 146 ++++++++++++-------------- synapse/storage/events.py | 77 ++++++++++++++ synapse/storage/schema/delta/26/ts.py | 57 ++++++++++ synapse/storage/search.py | 41 +++++--- 4 files changed, 228 insertions(+), 93 deletions(-) create mode 100644 synapse/storage/schema/delta/26/ts.py diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 6d2197339..671dbb61b 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -131,6 +131,17 @@ class SearchHandler(BaseHandler): if batch_group == "room_id": room_ids.intersection_update({batch_group_key}) + if not room_ids: + defer.returnValue({ + "search_categories": { + "room_events": { + "results": {}, + "count": 0, + "highlights": [], + } + } + }) + rank_map = {} # event_id -> rank of event allowed_events = [] room_groups = {} # Holds result of grouping by room, if applicable @@ -178,85 +189,66 @@ class SearchHandler(BaseHandler): s["results"].append(e.event_id) elif order_by == "recent": - # In this case we specifically loop through each room as the given - # limit applies to each room, rather than a global list. - # This is not necessarilly a good idea. - for room_id in room_ids: - room_events = [] - if batch_group == "room_id" and batch_group_key == room_id: - pagination_token = batch_token - else: + room_events = [] + i = 0 + + pagination_token = batch_token + + # We keep looping and we keep filtering until we reach the limit + # or we run out of things. + # But only go around 5 times since otherwise synapse will be sad. + while len(room_events) < search_filter.limit() and i < 5: + i += 1 + search_result = yield self.store.search_rooms( + room_ids, search_term, keys, search_filter.limit() * 2, + pagination_token=pagination_token, + ) + + if search_result["highlights"]: + highlights.update(search_result["highlights"]) + + results = search_result["results"] + + results_map = {r["event"].event_id: r for r in results} + + rank_map.update({r["event"].event_id: r["rank"] for r in results}) + + filtered_events = search_filter.filter([ + r["event"] for r in results + ]) + + events = yield self._filter_events_for_client( + user.to_string(), filtered_events + ) + + room_events.extend(events) + room_events = room_events[:search_filter.limit()] + + if len(results) < search_filter.limit() * 2: pagination_token = None - i = 0 - - # We keep looping and we keep filtering until we reach the limit - # or we run out of things. - # But only go around 5 times since otherwise synapse will be sad. - while len(room_events) < search_filter.limit() and i < 5: - i += 1 - search_result = yield self.store.search_room( - room_id, search_term, keys, search_filter.limit() * 2, - pagination_token=pagination_token, - ) - - if search_result["highlights"]: - highlights.update(search_result["highlights"]) - - results = search_result["results"] - - results_map = {r["event"].event_id: r for r in results} - - rank_map.update({r["event"].event_id: r["rank"] for r in results}) - - filtered_events = search_filter.filter([ - r["event"] for r in results - ]) - - events = yield self._filter_events_for_client( - user.to_string(), filtered_events - ) - - room_events.extend(events) - room_events = room_events[:search_filter.limit()] - - if len(results) < search_filter.limit() * 2: - pagination_token = None - break - else: - pagination_token = results[-1]["pagination_token"] - - if room_events: - res = results_map[room_events[-1].event_id] - pagination_token = res["pagination_token"] - - group = room_groups.setdefault(room_id, {}) - if pagination_token: - next_batch = encode_base64("%s\n%s\n%s" % ( - "room_id", room_id, pagination_token - )) - group["next_batch"] = next_batch - - if batch_token: - global_next_batch = next_batch - - group["results"] = [e.event_id for e in room_events] - group["order"] = max( - e.origin_server_ts/1000 for e in room_events - if hasattr(e, "origin_server_ts") - ) - - allowed_events.extend(room_events) - - # Normalize the group orders - if room_groups: - if len(room_groups) > 1: - mx = max(g["order"] for g in room_groups.values()) - mn = min(g["order"] for g in room_groups.values()) - - for g in room_groups.values(): - g["order"] = (g["order"] - mn) * 1.0 / (mx - mn) else: - room_groups.values()[0]["order"] = 1 + pagination_token = results[-1]["pagination_token"] + + if room_events: + for event in room_events: + group = room_groups.setdefault(event.room_id, { + "results": [], + }) + group["results"].append(event.event_id) + + pagination_token = results_map[room_events[-1].event_id]["pagination_token"] + + if pagination_token: + global_next_batch = encode_base64("%s\n%s\n%s" % ( + "all", "", pagination_token + )) + + for room_id, group in room_groups.items(): + group["next_batch"] = encode_base64("%s\n%s\n%s" % ( + "room_id", room_id, pagination_token + )) + + allowed_events.extend(room_events) else: # We should never get here due to the guard earlier. diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 5d35ca90b..7088f2709 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -51,6 +51,14 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events class EventsStore(SQLBaseStore): + EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" + + def __init__(self, hs): + super(EventsStore, self).__init__(hs) + self.register_background_update_handler( + self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts + ) + @defer.inlineCallbacks def persist_events(self, events_and_contexts, backfilled=False, is_new_state=True): @@ -365,6 +373,7 @@ class EventsStore(SQLBaseStore): "processed": True, "outlier": event.internal_metadata.is_outlier(), "content": encode_json(event.content).decode("UTF-8"), + "origin_server_ts": int(event.origin_server_ts), } for event, _ in events_and_contexts ], @@ -964,3 +973,71 @@ class EventsStore(SQLBaseStore): ret = yield self.runInteraction("count_messages", _count_messages) defer.returnValue(ret) + + @defer.inlineCallbacks + def _background_reindex_origin_server_ts(self, progress, batch_size): + target_min_stream_id = progress["target_min_stream_id_inclusive"] + max_stream_id = progress["max_stream_id_exclusive"] + rows_inserted = progress.get("rows_inserted", 0) + + INSERT_CLUMP_SIZE = 1000 + + def reindex_search_txn(txn): + sql = ( + "SELECT stream_ordering, event_id FROM events" + " WHERE ? <= stream_ordering AND stream_ordering < ?" + " ORDER BY stream_ordering DESC" + " LIMIT ?" + ) + + txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + + rows = txn.fetchall() + if not rows: + return 0 + + min_stream_id = rows[-1][0] + event_ids = [row[1] for row in rows] + + events = self._get_events_txn(txn, event_ids) + + rows = [] + for event in events: + try: + event_id = event.event_id + origin_server_ts = event.origin_server_ts + except (KeyError, AttributeError): + # If the event is missing a necessary field then + # skip over it. + continue + + rows.append((origin_server_ts, event_id)) + + sql = ( + "UPDATE events SET origin_server_ts = ? WHERE event_id = ?" + ) + + for index in range(0, len(rows), INSERT_CLUMP_SIZE): + clump = rows[index:index + INSERT_CLUMP_SIZE] + txn.executemany(sql, clump) + + progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + "rows_inserted": rows_inserted + len(rows) + } + + self._background_update_progress_txn( + txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress + ) + + return len(rows) + + result = yield self.runInteraction( + self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn + ) + + if not result: + yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME) + + defer.returnValue(result) diff --git a/synapse/storage/schema/delta/26/ts.py b/synapse/storage/schema/delta/26/ts.py new file mode 100644 index 000000000..8d4a98197 --- /dev/null +++ b/synapse/storage/schema/delta/26/ts.py @@ -0,0 +1,57 @@ +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.prepare_database import get_statements + +import ujson + +logger = logging.getLogger(__name__) + + +ALTER_TABLE = ( + "ALTER TABLE events ADD COLUMN origin_server_ts BIGINT;" + "CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering);" +) + + +def run_upgrade(cur, database_engine, *args, **kwargs): + for statement in get_statements(ALTER_TABLE.splitlines()): + cur.execute(statement) + + cur.execute("SELECT MIN(stream_ordering) FROM events") + rows = cur.fetchall() + min_stream_id = rows[0][0] + + cur.execute("SELECT MAX(stream_ordering) FROM events") + rows = cur.fetchall() + max_stream_id = rows[0][0] + + if min_stream_id is not None and max_stream_id is not None: + progress = { + "target_min_stream_id_inclusive": min_stream_id, + "max_stream_id_exclusive": max_stream_id + 1, + "rows_inserted": 0, + } + progress_json = ujson.dumps(progress) + + sql = ( + "INSERT into background_updates (update_name, progress_json)" + " VALUES (?, ?)" + ) + + sql = database_engine.convert_param_style(sql) + + cur.execute(sql, ("event_origin_server_ts", progress_json)) diff --git a/synapse/storage/search.py b/synapse/storage/search.py index c6386642d..20a62d07f 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -212,11 +212,11 @@ class SearchStore(BackgroundUpdateStore): }) @defer.inlineCallbacks - def search_room(self, room_id, search_term, keys, limit, pagination_token=None): + def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None): """Performs a full text search over events with given keys. Args: - room_id (str): The room_id to search in + room_id (list): The room_ids to search in search_term (str): Search term to search for keys (list): List of keys to search in, currently supports "content.body", "content.name", "content.topic" @@ -226,7 +226,15 @@ class SearchStore(BackgroundUpdateStore): list of dicts """ clauses = [] - args = [search_term, room_id] + args = [search_term] + + # Make sure we don't explode because the person is in too many rooms. + # We filter the results below regardless. + if len(room_ids) < 500: + clauses.append( + "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),) + ) + args.extend(room_ids) local_clauses = [] for key in keys: @@ -239,25 +247,25 @@ class SearchStore(BackgroundUpdateStore): if pagination_token: try: - topo, stream = pagination_token.split(",") - topo = int(topo) + origin_server_ts, stream = pagination_token.split(",") + origin_server_ts = int(origin_server_ts) stream = int(stream) except: raise SynapseError(400, "Invalid pagination token") clauses.append( - "(topological_ordering < ?" - " OR (topological_ordering = ? AND stream_ordering < ?))" + "(origin_server_ts < ?" + " OR (origin_server_ts = ? AND stream_ordering < ?))" ) - args.extend([topo, topo, stream]) + args.extend([origin_server_ts, origin_server_ts, stream]) if isinstance(self.database_engine, PostgresEngine): sql = ( "SELECT ts_rank_cd(vector, query) as rank," - " topological_ordering, stream_ordering, room_id, event_id" + " origin_server_ts, stream_ordering, room_id, event_id" " FROM plainto_tsquery('english', ?) as query, event_search" " NATURAL JOIN events" - " WHERE vector @@ query AND room_id = ?" + " WHERE vector @@ query AND " ) elif isinstance(self.database_engine, Sqlite3Engine): # We use CROSS JOIN here to ensure we use the right indexes. @@ -270,24 +278,23 @@ class SearchStore(BackgroundUpdateStore): # MATCH unless it uses the full text search index sql = ( "SELECT rank(matchinfo) as rank, room_id, event_id," - " topological_ordering, stream_ordering" + " origin_server_ts, stream_ordering" " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" " FROM event_search" " WHERE value MATCH ?" " )" " CROSS JOIN events USING (event_id)" - " WHERE room_id = ?" + " WHERE " ) else: # This should be unreachable. raise Exception("Unrecognized database engine") - for clause in clauses: - sql += " AND " + clause + sql += " AND ".join(clauses) # We add an arbitrary limit here to ensure we don't try to pull the # entire table from the database. - sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" + sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?" args.append(limit) @@ -295,6 +302,8 @@ class SearchStore(BackgroundUpdateStore): "search_rooms", self.cursor_to_dict, sql, *args ) + results = filter(lambda row: row["room_id"] in room_ids, results) + events = yield self._get_events([r["event_id"] for r in results]) event_map = { @@ -312,7 +321,7 @@ class SearchStore(BackgroundUpdateStore): "event": event_map[r["event_id"]], "rank": r["rank"], "pagination_token": "%s,%s" % ( - r["topological_ordering"], r["stream_ordering"] + r["origin_server_ts"], r["stream_ordering"] ), } for r in results From bde8d78b8a19c941cef926d3e480c81555dbe993 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Mon, 30 Nov 2015 17:46:35 +0000 Subject: [PATCH 02/10] Copy rather than move the fields to shuffle between a v1 and a v2 event. This should make all v1 APIs compatible with v2 clients. While still allowing v1 clients to access the fields. This makes the documentation easier since we can just document the v2 format and explain that some of the fields, in some of the APIs are duplicated for backwards compatibility, rather than having to document two separate event formats. --- synapse/events/utils.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 44cc1ef13..666df5411 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -100,22 +100,18 @@ def format_event_raw(d): def format_event_for_client_v1(d): - d["user_id"] = d.pop("sender", None) + d = format_event_for_client_v2(d) - move_keys = ( + d["user_id"] = d.get("sender", None) + + copy_keys = ( "age", "redacted_because", "replaces_state", "prev_content", "invite_room_state", ) - for key in move_keys: + for key in copy_keys: if key in d["unsigned"]: d[key] = d["unsigned"][key] - drop_keys = ( - "auth_events", "prev_events", "hashes", "signatures", "depth", - "unsigned", "origin", "prev_state" - ) - for key in drop_keys: - d.pop(key, None) return d From da7dd586414653a3d7d3ae4225600cb5126059f5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 1 Dec 2015 11:06:40 +0000 Subject: [PATCH 03/10] Tidy up a bit --- synapse/handlers/search.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 671dbb61b..df6390cf0 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -226,19 +226,20 @@ class SearchHandler(BaseHandler): if len(results) < search_filter.limit() * 2: pagination_token = None + break else: pagination_token = results[-1]["pagination_token"] - if room_events: - for event in room_events: - group = room_groups.setdefault(event.room_id, { - "results": [], - }) - group["results"].append(event.event_id) + for event in room_events: + group = room_groups.setdefault(event.room_id, { + "results": [], + }) + group["results"].append(event.event_id) - pagination_token = results_map[room_events[-1].event_id]["pagination_token"] + if room_events and len(room_events) >= search_filter.limit(): + last_event_id = room_events[-1].event_id + pagination_token = results_map[last_event_id]["pagination_token"] - if pagination_token: global_next_batch = encode_base64("%s\n%s\n%s" % ( "all", "", pagination_token )) From 306415391dfe1a304738f307440d7bd79fe93972 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 1 Dec 2015 11:14:48 +0000 Subject: [PATCH 04/10] Only add the user_id if the sender is present --- synapse/events/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 666df5411..e634b149b 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -102,7 +102,9 @@ def format_event_raw(d): def format_event_for_client_v1(d): d = format_event_for_client_v2(d) - d["user_id"] = d.get("sender", None) + sender = d.get("sender") + if sender is not None: + d["user_id"] = sender copy_keys = ( "age", "redacted_because", "replaces_state", "prev_content", From 31069ecf6a9c91e62fcecab9059385e33a19a629 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 1 Dec 2015 15:59:45 +0000 Subject: [PATCH 05/10] Rename presence_handler.send_invite to presence_handler.send_presence_invite to distinguish it from normal invites --- synapse/handlers/presence.py | 2 +- synapse/rest/client/v1/presence.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index aca65096f..e95e821c9 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -467,7 +467,7 @@ class PresenceHandler(BaseHandler): ) @defer.inlineCallbacks - def send_invite(self, observer_user, observed_user): + def send_presence_invite(self, observer_user, observed_user): """Request the presence of a local or remote user for a local user""" if not self.hs.is_mine(observer_user): raise SynapseError(400, "User is not hosted on this Home Server") diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 6fe5d19a2..48533f9d6 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -120,7 +120,7 @@ class PresenceListRestServlet(ClientV1RestServlet): if len(u) == 0: continue invited_user = UserID.from_string(u) - yield self.handlers.presence_handler.send_invite( + yield self.handlers.presence_handler.send_presence_invite( observer_user=user, observed_user=invited_user ) From 7b593af7e16ddad7ae61173306649424e1078814 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 1 Dec 2015 16:06:17 +0000 Subject: [PATCH 06/10] rename the method in the tests as well --- tests/handlers/test_presence.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 1172ceae8..c42b5b80d 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -365,7 +365,7 @@ class PresenceInvitesTestCase(PresenceTestCase): # TODO(paul): This test will likely break if/when real auth permissions # are added; for now the HS will always accept any invite - yield self.handler.send_invite( + yield self.handler.send_presence_invite( observer_user=self.u_apple, observed_user=self.u_banana) self.assertEquals( @@ -384,7 +384,7 @@ class PresenceInvitesTestCase(PresenceTestCase): @defer.inlineCallbacks def test_invite_local_nonexistant(self): - yield self.handler.send_invite( + yield self.handler.send_presence_invite( observer_user=self.u_apple, observed_user=self.u_durian) self.assertEquals( @@ -414,7 +414,7 @@ class PresenceInvitesTestCase(PresenceTestCase): defer.succeed((200, "OK")) ) - yield self.handler.send_invite( + yield self.handler.send_presence_invite( observer_user=self.u_apple, observed_user=u_rocket) self.assertEquals( From 14d7acfad48ea7807b032b3fd99649b500e651f7 Mon Sep 17 00:00:00 2001 From: Daniel Wagner-Hall Date: Tue, 1 Dec 2015 17:34:32 +0000 Subject: [PATCH 07/10] Host /unstable and /r0 versions of r0 APIs --- synapse/federation/transport/server.py | 2 +- synapse/http/server.py | 13 +-- synapse/http/servlet.py | 8 +- synapse/rest/client/v1/admin.py | 4 +- synapse/rest/client/v1/base.py | 10 ++- synapse/rest/client/v1/directory.py | 4 +- synapse/rest/client/v1/events.py | 6 +- synapse/rest/client/v1/initial_sync.py | 4 +- synapse/rest/client/v1/login.py | 12 +-- synapse/rest/client/v1/presence.py | 6 +- synapse/rest/client/v1/profile.py | 8 +- synapse/rest/client/v1/push_rule.py | 4 +- synapse/rest/client/v1/pusher.py | 4 +- synapse/rest/client/v1/register.py | 4 +- synapse/rest/client/v1/room.py | 90 ++++++++++---------- synapse/rest/client/v1/voip.py | 4 +- synapse/rest/client/v2_alpha/_base.py | 10 ++- synapse/rest/client/v2_alpha/account.py | 6 +- synapse/rest/client/v2_alpha/auth.py | 4 +- synapse/rest/client/v2_alpha/filter.py | 6 +- synapse/rest/client/v2_alpha/keys.py | 14 +-- synapse/rest/client/v2_alpha/receipts.py | 4 +- synapse/rest/client/v2_alpha/register.py | 4 +- synapse/rest/client/v2_alpha/sync.py | 4 +- synapse/rest/client/v2_alpha/tags.py | 6 +- synapse/rest/client/v2_alpha/tokenrefresh.py | 4 +- tests/utils.py | 5 +- 27 files changed, 133 insertions(+), 117 deletions(-) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 127b4da4f..6b164fd2d 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -165,7 +165,7 @@ class BaseFederationServlet(object): if code is None: continue - server.register_path(method, pattern, self._wrap(code)) + server.register_paths(method, (pattern,), self._wrap(code)) class FederationSendServlet(BaseFederationServlet): diff --git a/synapse/http/server.py b/synapse/http/server.py index 50feea6f1..ef75be742 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -120,7 +120,7 @@ class HttpServer(object): """ Interface for registering callbacks on a HTTP server """ - def register_path(self, method, path_pattern, callback): + def register_paths(self, method, path_patterns, callback): """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. @@ -129,7 +129,7 @@ class HttpServer(object): Args: method (str): The method to listen to. - path_pattern (str): The regex used to match requests. + path_patterns (list): The regex used to match requests. callback (function): The function to fire if we receive a matched request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. @@ -165,10 +165,11 @@ class JsonResource(HttpServer, resource.Resource): self.version_string = hs.version_string self.hs = hs - def register_path(self, method, path_pattern, callback): - self.path_regexs.setdefault(method, []).append( - self._PathEntry(path_pattern, callback) - ) + def register_paths(self, method, path_patterns, callback): + for path_pattern in path_patterns: + self.path_regexs.setdefault(method, []).append( + self._PathEntry(path_pattern, callback) + ) def render(self, request): """ This gets called by twisted every time someone sends us a request. diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index 9cda17fcf..32b6d6cd7 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -19,7 +19,6 @@ from synapse.api.errors import SynapseError import logging - logger = logging.getLogger(__name__) @@ -102,12 +101,13 @@ class RestServlet(object): def register(self, http_server): """ Register this servlet with the given HTTP server. """ - if hasattr(self, "PATTERN"): - pattern = self.PATTERN + if hasattr(self, "PATTERNS"): + patterns = self.PATTERNS for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): if hasattr(self, "on_%s" % (method,)): method_handler = getattr(self, "on_%s" % (method,)) - http_server.register_path(method, pattern, method_handler) + http_server.register_paths(method, patterns, method_handler) + else: raise NotImplementedError("RestServlet must register something.") diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index bdde43864..010369788 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError from synapse.types import UserID -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns import logging @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class WhoisRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/admin/whois/(?P[^/]*)") + PATTERNS = client_path_patterns("/admin/whois/(?P[^/]*)", releases=()) @defer.inlineCallbacks def on_GET(self, request, user_id): diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 504a5e432..7ae3839a1 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -27,7 +27,7 @@ import logging logger = logging.getLogger(__name__) -def client_path_pattern(path_regex): +def client_path_patterns(path_regex, releases=(0,)): """Creates a regex compiled client path with the correct client path prefix. @@ -37,7 +37,13 @@ def client_path_pattern(path_regex): Returns: SRE_Pattern """ - return re.compile("^" + CLIENT_PREFIX + path_regex) + patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)] + unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) + for release in releases: + new_prefix = CLIENT_PREFIX.replace("/api/v1", "/r%d" % release) + patterns.append(re.compile("^" + new_prefix + path_regex)) + return patterns class ClientV1RestServlet(RestServlet): diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 240eedac7..f488e2dd4 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError, Codes from synapse.types import RoomAlias -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns import simplejson as json import logging @@ -32,7 +32,7 @@ def register_servlets(hs, http_server): class ClientDirectoryServer(ClientV1RestServlet): - PATTERN = client_path_pattern("/directory/room/(?P[^/]*)$") + PATTERNS = client_path_patterns("/directory/room/(?P[^/]*)$") @defer.inlineCallbacks def on_GET(self, request, room_alias): diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index 3e1750d1a..41b97e7d1 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.streams.config import PaginationConfig -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns from synapse.events.utils import serialize_event import logging @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class EventStreamRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/events$") + PATTERNS = client_path_patterns("/events$") DEFAULT_LONGPOLL_TIME_MS = 30000 @@ -72,7 +72,7 @@ class EventStreamRestServlet(ClientV1RestServlet): # TODO: Unit test gets, with and without auth, with different kinds of events. class EventRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/events/(?P[^/]*)$") + PATTERNS = client_path_patterns("/events/(?P[^/]*)$") def __init__(self, hs): super(EventRestServlet, self).__init__(hs) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 856a70f29..9ad3df8a9 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -16,12 +16,12 @@ from twisted.internet import defer from synapse.streams.config import PaginationConfig -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns # TODO: Needs unit testing class InitialSyncRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/initialSync$") + PATTERNS = client_path_patterns("/initialSync$") @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 720d6358e..b0b641e43 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, LoginError, Codes from synapse.http.client import SimpleHttpClient from synapse.types import UserID -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns import simplejson as json import urllib @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) class LoginRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login$") + PATTERNS = client_path_patterns("/login$", releases=()) PASS_TYPE = "m.login.password" SAML2_TYPE = "m.login.saml2" CAS_TYPE = "m.login.cas" @@ -238,7 +238,7 @@ class LoginRestServlet(ClientV1RestServlet): class SAML2RestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login/saml2") + PATTERNS = client_path_patterns("/login/saml2", releases=()) def __init__(self, hs): super(SAML2RestServlet, self).__init__(hs) @@ -282,7 +282,7 @@ class SAML2RestServlet(ClientV1RestServlet): # TODO Delete this after all CAS clients switch to token login instead class CasRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login/cas") + PATTERNS = client_path_patterns("/login/cas", releases=()) def __init__(self, hs): super(CasRestServlet, self).__init__(hs) @@ -293,7 +293,7 @@ class CasRestServlet(ClientV1RestServlet): class CasRedirectServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login/cas/redirect") + PATTERNS = client_path_patterns("/login/cas/redirect", releases=()) def __init__(self, hs): super(CasRedirectServlet, self).__init__(hs) @@ -316,7 +316,7 @@ class CasRedirectServlet(ClientV1RestServlet): class CasTicketServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/login/cas/ticket") + PATTERNS = client_path_patterns("/login/cas/ticket", releases=()) def __init__(self, hs): super(CasTicketServlet, self).__init__(hs) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 48533f9d6..e0949fe4b 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -19,7 +19,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.types import UserID -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns import simplejson as json import logging @@ -28,7 +28,7 @@ logger = logging.getLogger(__name__) class PresenceStatusRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/presence/(?P[^/]*)/status") + PATTERNS = client_path_patterns("/presence/(?P[^/]*)/status") @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -73,7 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): class PresenceListRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/presence/list/(?P[^/]*)") + PATTERNS = client_path_patterns("/presence/list/(?P[^/]*)") @defer.inlineCallbacks def on_GET(self, request, user_id): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 3218e4702..e6c6e5d02 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -16,14 +16,14 @@ """ This module contains REST servlets to do with profile: /profile/ """ from twisted.internet import defer -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns from synapse.types import UserID import simplejson as json class ProfileDisplaynameRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/profile/(?P[^/]*)/displayname") + PATTERNS = client_path_patterns("/profile/(?P[^/]*)/displayname") @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -56,7 +56,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/profile/(?P[^/]*)/avatar_url") + PATTERNS = client_path_patterns("/profile/(?P[^/]*)/avatar_url") @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -89,7 +89,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/profile/(?P[^/]*)") + PATTERNS = client_path_patterns("/profile/(?P[^/]*)") @defer.inlineCallbacks def on_GET(self, request, user_id): diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index b0870db1a..edf5b0ca4 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import ( SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError ) -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns from synapse.storage.push_rule import ( InconsistentRuleException, RuleNotFoundException ) @@ -31,7 +31,7 @@ import simplejson as json class PushRuleRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/pushrules/.*$") + PATTERNS = client_path_patterns("/pushrules/.*$") SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( "Unrecognised request: You probably wanted a trailing slash") diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index a110c0a4f..6f465035b 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -17,13 +17,13 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.push import PusherConfigException -from .base import ClientV1RestServlet, client_path_pattern +from .base import ClientV1RestServlet, client_path_patterns import simplejson as json class PusherRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/pushers/set$") + PATTERNS = client_path_patterns("/pushers/set$") @defer.inlineCallbacks def on_POST(self, request): diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index a56834e36..5b95d63e2 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, Codes from synapse.api.constants import LoginType -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns import synapse.util.stringutils as stringutils from synapse.util.async import run_on_reactor @@ -48,7 +48,7 @@ class RegisterRestServlet(ClientV1RestServlet): handler doesn't have a concept of multi-stages or sessions. """ - PATTERN = client_path_pattern("/register$") + PATTERNS = client_path_patterns("/register$", releases=()) def __init__(self, hs): super(RegisterRestServlet, self).__init__(hs) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 6952d269e..d86d26646 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -16,7 +16,7 @@ """ This module contains REST servlets to do with rooms: /rooms/ """ from twisted.internet import defer -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership @@ -34,16 +34,16 @@ class RoomCreateRestServlet(ClientV1RestServlet): # No PATTERN; we have custom dispatch rules here def register(self, http_server): - PATTERN = "/createRoom" - register_txn_path(self, PATTERN, http_server) + PATTERNS = "/createRoom" + register_txn_path(self, PATTERNS, http_server) # define CORS for all of /rooms in RoomCreateRestServlet for simplicity - http_server.register_path("OPTIONS", - client_path_pattern("/rooms(?:/.*)?$"), - self.on_OPTIONS) + http_server.register_paths("OPTIONS", + client_path_patterns("/rooms(?:/.*)?$"), + self.on_OPTIONS) # define CORS for /createRoom[/txnid] - http_server.register_path("OPTIONS", - client_path_pattern("/createRoom(?:/.*)?$"), - self.on_OPTIONS) + http_server.register_paths("OPTIONS", + client_path_patterns("/createRoom(?:/.*)?$"), + self.on_OPTIONS) @defer.inlineCallbacks def on_PUT(self, request, txn_id): @@ -103,18 +103,18 @@ class RoomStateEventRestServlet(ClientV1RestServlet): state_key = ("/rooms/(?P[^/]*)/state/" "(?P[^/]*)/(?P[^/]*)$") - http_server.register_path("GET", - client_path_pattern(state_key), - self.on_GET) - http_server.register_path("PUT", - client_path_pattern(state_key), - self.on_PUT) - http_server.register_path("GET", - client_path_pattern(no_state_key), - self.on_GET_no_state_key) - http_server.register_path("PUT", - client_path_pattern(no_state_key), - self.on_PUT_no_state_key) + http_server.register_paths("GET", + client_path_patterns(state_key), + self.on_GET) + http_server.register_paths("PUT", + client_path_patterns(state_key), + self.on_PUT) + http_server.register_paths("GET", + client_path_patterns(no_state_key, releases=()), + self.on_GET_no_state_key) + http_server.register_paths("PUT", + client_path_patterns(no_state_key, releases=()), + self.on_PUT_no_state_key) def on_GET_no_state_key(self, request, room_id, event_type): return self.on_GET(request, room_id, event_type, "") @@ -170,8 +170,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet): def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] - PATTERN = ("/rooms/(?P[^/]*)/send/(?P[^/]*)") - register_txn_path(self, PATTERN, http_server, with_get=True) + PATTERNS = ("/rooms/(?P[^/]*)/send/(?P[^/]*)") + register_txn_path(self, PATTERNS, http_server, with_get=True) @defer.inlineCallbacks def on_POST(self, request, room_id, event_type, txn_id=None): @@ -215,8 +215,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet): def register(self, http_server): # /join/$room_identifier[/$txn_id] - PATTERN = ("/join/(?P[^/]*)") - register_txn_path(self, PATTERN, http_server) + PATTERNS = ("/join/(?P[^/]*)") + register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_identifier, txn_id=None): @@ -280,7 +280,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): # TODO: Needs unit testing class PublicRoomListRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/publicRooms$") + PATTERNS = client_path_patterns("/publicRooms$") @defer.inlineCallbacks def on_GET(self, request): @@ -291,7 +291,7 @@ class PublicRoomListRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomMemberListRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P[^/]*)/members$") + PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/members$") @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -328,7 +328,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): # TODO: Needs better unit testing class RoomMessageListRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P[^/]*)/messages$") + PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/messages$") @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -351,7 +351,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomStateRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P[^/]*)/state$") + PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/state$") @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -368,7 +368,7 @@ class RoomStateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing class RoomInitialSyncRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P[^/]*)/initialSync$") + PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/initialSync$") @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -384,7 +384,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): class RoomTriggerBackfill(ClientV1RestServlet): - PATTERN = client_path_pattern("/rooms/(?P[^/]*)/backfill$") + PATTERNS = client_path_patterns("/rooms/(?P[^/]*)/backfill$", releases=()) def __init__(self, hs): super(RoomTriggerBackfill, self).__init__(hs) @@ -408,7 +408,7 @@ class RoomTriggerBackfill(ClientV1RestServlet): class RoomEventContext(ClientV1RestServlet): - PATTERN = client_path_pattern( + PATTERNS = client_path_patterns( "/rooms/(?P[^/]*)/context/(?P[^/]*)$" ) @@ -447,9 +447,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet): def register(self, http_server): # /rooms/$roomid/[invite|join|leave] - PATTERN = ("/rooms/(?P[^/]*)/" - "(?Pjoin|invite|leave|ban|kick|forget)") - register_txn_path(self, PATTERN, http_server) + PATTERNS = ("/rooms/(?P[^/]*)/" + "(?Pjoin|invite|leave|ban|kick|forget)") + register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, membership_action, txn_id=None): @@ -543,8 +543,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet): def register(self, http_server): - PATTERN = ("/rooms/(?P[^/]*)/redact/(?P[^/]*)") - register_txn_path(self, PATTERN, http_server) + PATTERNS = ("/rooms/(?P[^/]*)/redact/(?P[^/]*)") + register_txn_path(self, PATTERNS, http_server) @defer.inlineCallbacks def on_POST(self, request, room_id, event_id, txn_id=None): @@ -582,7 +582,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomTypingRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern( + PATTERNS = client_path_patterns( "/rooms/(?P[^/]*)/typing/(?P[^/]*)$" ) @@ -615,7 +615,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): class SearchRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern( + PATTERNS = client_path_patterns( "/search$" ) @@ -655,20 +655,20 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): http_server : The http_server to register paths with. with_get: True to also register respective GET paths for the PUTs. """ - http_server.register_path( + http_server.register_paths( "POST", - client_path_pattern(regex_string + "$"), + client_path_patterns(regex_string + "$"), servlet.on_POST ) - http_server.register_path( + http_server.register_paths( "PUT", - client_path_pattern(regex_string + "/(?P[^/]*)$"), + client_path_patterns(regex_string + "/(?P[^/]*)$"), servlet.on_PUT ) if with_get: - http_server.register_path( + http_server.register_paths( "GET", - client_path_pattern(regex_string + "/(?P[^/]*)$"), + client_path_patterns(regex_string + "/(?P[^/]*)$"), servlet.on_GET ) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index eb7c57cad..1567a03c8 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from base import ClientV1RestServlet, client_path_pattern +from base import ClientV1RestServlet, client_path_patterns import hmac @@ -24,7 +24,7 @@ import base64 class VoipRestServlet(ClientV1RestServlet): - PATTERN = client_path_pattern("/voip/turnServer$") + PATTERNS = client_path_patterns("/voip/turnServer$") @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 4540e8dcf..7b8b879c0 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -27,7 +27,7 @@ import simplejson logger = logging.getLogger(__name__) -def client_v2_pattern(path_regex): +def client_v2_patterns(path_regex, releases=(0,)): """Creates a regex compiled client path with the correct client path prefix. @@ -37,7 +37,13 @@ def client_v2_pattern(path_regex): Returns: SRE_Pattern """ - return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) + patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)] + unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) + for release in releases: + new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) + patterns.append(re.compile("^" + new_prefix + path_regex)) + return patterns def parse_request_allow_empty(request): diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 1970ad345..6f1c33f75 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -20,7 +20,7 @@ from synapse.api.errors import LoginError, SynapseError, Codes from synapse.http.servlet import RestServlet from synapse.util.async import run_on_reactor -from ._base import client_v2_pattern, parse_json_dict_from_request +from ._base import client_v2_patterns, parse_json_dict_from_request import logging @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class PasswordRestServlet(RestServlet): - PATTERN = client_v2_pattern("/account/password") + PATTERNS = client_v2_patterns("/account/password", releases=()) def __init__(self, hs): super(PasswordRestServlet, self).__init__() @@ -89,7 +89,7 @@ class PasswordRestServlet(RestServlet): class ThreepidRestServlet(RestServlet): - PATTERN = client_v2_pattern("/account/3pid") + PATTERNS = client_v2_patterns("/account/3pid", releases=()) def __init__(self, hs): super(ThreepidRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 4c726f05f..fb5947a14 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -20,7 +20,7 @@ from synapse.api.errors import SynapseError from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.http.servlet import RestServlet -from ._base import client_v2_pattern +from ._base import client_v2_patterns import logging @@ -97,7 +97,7 @@ class AuthRestServlet(RestServlet): cannot be handled in the normal flow (with requests to the same endpoint). Current use is for web fallback auth. """ - PATTERN = client_v2_pattern("/auth/(?P[\w\.]*)/fallback/web") + PATTERNS = client_v2_patterns("/auth/(?P[\w\.]*)/fallback/web") def __init__(self, hs): super(AuthRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index 97956a4b9..3cd0364b5 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, SynapseError from synapse.http.servlet import RestServlet from synapse.types import UserID -from ._base import client_v2_pattern +from ._base import client_v2_patterns import simplejson as json import logging @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class GetFilterRestServlet(RestServlet): - PATTERN = client_v2_pattern("/user/(?P[^/]*)/filter/(?P[^/]*)") + PATTERNS = client_v2_patterns("/user/(?P[^/]*)/filter/(?P[^/]*)") def __init__(self, hs): super(GetFilterRestServlet, self).__init__() @@ -65,7 +65,7 @@ class GetFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet): - PATTERN = client_v2_pattern("/user/(?P[^/]*)/filter") + PATTERNS = client_v2_patterns("/user/(?P[^/]*)/filter") def __init__(self, hs): super(CreateFilterRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 820d33336..c55e85920 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -21,7 +21,7 @@ from synapse.types import UserID from canonicaljson import encode_canonical_json -from ._base import client_v2_pattern +from ._base import client_v2_patterns import simplejson as json import logging @@ -54,7 +54,7 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERN = client_v2_pattern("/keys/upload/(?P[^/]*)") + PATTERNS = client_v2_patterns("/keys/upload/(?P[^/]*)") def __init__(self, hs): super(KeyUploadServlet, self).__init__() @@ -154,12 +154,13 @@ class KeyQueryServlet(RestServlet): } } } } } } """ - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/keys/query(?:" "/(?P[^/]*)(?:" "/(?P[^/]*)" ")?" - ")?" + ")?", + releases=() ) def __init__(self, hs): @@ -245,10 +246,11 @@ class OneTimeKeyServlet(RestServlet): } } } } """ - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/keys/claim(?:/?|(?:/" "(?P[^/]*)/(?P[^/]*)/(?P[^/]*)" - ")?)" + ")?)", + releases=() ) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index 788acd4ad..aa214e13b 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -17,7 +17,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet -from ._base import client_v2_pattern +from ._base import client_v2_patterns import logging @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class ReceiptRestServlet(RestServlet): - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/rooms/(?P[^/]*)" "/receipt/(?P[^/]*)" "/(?P[^/]*)$" diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index f89937631..b2b89652c 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -19,7 +19,7 @@ from synapse.api.constants import LoginType from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import RestServlet -from ._base import client_v2_pattern, parse_json_dict_from_request +from ._base import client_v2_patterns, parse_json_dict_from_request import logging import hmac @@ -41,7 +41,7 @@ logger = logging.getLogger(__name__) class RegisterRestServlet(RestServlet): - PATTERN = client_v2_pattern("/register") + PATTERNS = client_v2_patterns("/register") def __init__(self, hs): super(RegisterRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 775f49885..09693bb43 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -25,7 +25,7 @@ from synapse.events.utils import ( serialize_event, format_event_for_client_v2_without_room_id, ) from synapse.api.filtering import FilterCollection -from ._base import client_v2_pattern +from ._base import client_v2_patterns import copy import logging @@ -69,7 +69,7 @@ class SyncRestServlet(RestServlet): } """ - PATTERN = client_v2_pattern("/sync$") + PATTERNS = client_v2_patterns("/sync$") ALLOWED_PRESENCE = set(["online", "offline"]) def __init__(self, hs): diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index ba7223be1..b5d0db556 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import client_v2_pattern +from ._base import client_v2_patterns from synapse.http.servlet import RestServlet from synapse.api.errors import AuthError, SynapseError @@ -31,7 +31,7 @@ class TagListServlet(RestServlet): """ GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 """ - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags" ) @@ -56,7 +56,7 @@ class TagServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 """ - PATTERN = client_v2_pattern( + PATTERNS = client_v2_patterns( "/user/(?P[^/]*)/rooms/(?P[^/]*)/tags/(?P[^/]*)" ) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 901e77798..5a63afd51 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.http.servlet import RestServlet -from ._base import client_v2_pattern, parse_json_dict_from_request +from ._base import client_v2_patterns, parse_json_dict_from_request class TokenRefreshRestServlet(RestServlet): @@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet): Exchanges refresh tokens for a pair of an access token and a new refresh token. """ - PATTERN = client_v2_pattern("/tokenrefresh") + PATTERNS = client_v2_patterns("/tokenrefresh") def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() diff --git a/tests/utils.py b/tests/utils.py index 91040c2ef..aee69b1ca 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -168,8 +168,9 @@ class MockHttpResource(HttpServer): raise KeyError("No event can handle %s" % path) - def register_path(self, method, path_pattern, callback): - self.callbacks.append((method, path_pattern, callback)) + def register_paths(self, method, path_patterns, callback): + for path_pattern in path_patterns: + self.callbacks.append((method, path_pattern, callback)) class MockKey(object): From c533f69d380e1e2643e5653825a98d82fad1bb2a Mon Sep 17 00:00:00 2001 From: "Mads R. Christensen" Date: Tue, 1 Dec 2015 20:00:41 +0100 Subject: [PATCH 08/10] Added libffi-devel in CentOS 7 installation requirements and fixed indentation of yum groupinstall. Signed-off-by: Mads Robin Christensen --- README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 1761d3398..3d2a8ae34 100644 --- a/README.rst +++ b/README.rst @@ -115,8 +115,8 @@ Installing prerequisites on CentOS 7:: sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ lcms2-devel libwebp-devel tcl-devel tk-devel \ - python-virtualenv - sudo yum groupinstall "Development Tools" + python-virtualenv libffi-devel + sudo yum groupinstall "Development Tools" Installing prerequisites on Mac OS X:: From 98ee629d00c78fcf378ec9b9c1e64def85bbed5c Mon Sep 17 00:00:00 2001 From: "Mads R. Christensen" Date: Tue, 1 Dec 2015 20:20:53 +0100 Subject: [PATCH 09/10] Added --report-status=yes|no as Synapse won't generate the config without it --- README.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 3d2a8ae34..9149f2fee 100644 --- a/README.rst +++ b/README.rst @@ -115,7 +115,7 @@ Installing prerequisites on CentOS 7:: sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ lcms2-devel libwebp-devel tcl-devel tk-devel \ - python-virtualenv libffi-devel + python-virtualenv libffi-devel openssl-devel sudo yum groupinstall "Development Tools" @@ -152,7 +152,8 @@ To set up your homeserver, run (in your virtualenv, as before):: python -m synapse.app.homeserver \ --server-name machine.my.domain.name \ --config-path homeserver.yaml \ - --generate-config + --generate-config \ + --report-stats=[yes|no] Substituting your host and domain name as appropriate. From 95f30ecd1f90cd143c908589b600742148491c15 Mon Sep 17 00:00:00 2001 From: Mark Haines Date: Tue, 1 Dec 2015 18:41:32 +0000 Subject: [PATCH 10/10] Add API for setting account_data globaly or on a per room basis --- synapse/api/filtering.py | 9 +- synapse/handlers/account_data.py | 21 +- synapse/handlers/message.py | 40 +++- synapse/handlers/sync.py | 72 ++++-- synapse/rest/client/v2_alpha/__init__.py | 2 + synapse/rest/client/v2_alpha/account_data.py | 111 +++++++++ synapse/rest/client/v2_alpha/sync.py | 6 + synapse/storage/__init__.py | 2 + synapse/storage/account_data.py | 211 ++++++++++++++++++ .../storage/schema/delta/26/account_data.sql | 23 ++ synapse/storage/tags.py | 4 +- 11 files changed, 476 insertions(+), 25 deletions(-) create mode 100644 synapse/rest/client/v2_alpha/account_data.py create mode 100644 synapse/storage/account_data.py diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index 18f2ec3ae..19f30c273 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -50,7 +50,7 @@ class Filtering(object): # many definitions. top_level_definitions = [ - "presence" + "presence", "account_data" ] room_level_definitions = [ @@ -139,6 +139,10 @@ class FilterCollection(object): self.filter_json.get("presence", {}) ) + self.account_data = Filter( + self.filter_json.get("account_data", {}) + ) + def timeline_limit(self): return self.room_timeline_filter.limit() @@ -151,6 +155,9 @@ class FilterCollection(object): def filter_presence(self, events): return self.presence_filter.filter(events) + def filter_account_data(self, events): + return self.account_data.filter(events) + def filter_room_state(self, events): return self.room_state_filter.filter(events) diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 1d35d3b7d..fe773bee9 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -29,9 +29,10 @@ class AccountDataEventSource(object): last_stream_id = from_key current_stream_id = yield self.store.get_max_account_data_stream_id() - tags = yield self.store.get_updated_tags(user_id, last_stream_id) results = [] + tags = yield self.store.get_updated_tags(user_id, last_stream_id) + for room_id, room_tags in tags.items(): results.append({ "type": "m.tag", @@ -39,6 +40,24 @@ class AccountDataEventSource(object): "room_id": room_id, }) + account_data, room_account_data = ( + yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) + ) + + for account_data_type, content in account_data.items(): + results.append({ + "type": account_data_type, + "content": content, + }) + + for room_id, account_data in room_account_data.items(): + for account_data_type, content in account_data.items(): + results.append({ + "type": account_data_type, + "content": content, + "room_id": room_id, + }) + defer.returnValue((results, current_stream_id)) @defer.inlineCallbacks diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 64c57375f..e959ce50b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -359,6 +359,10 @@ class MessageHandler(BaseHandler): tags_by_room = yield self.store.get_tags_for_user(user_id) + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user(user_id) + ) + public_room_ids = yield self.store.get_public_room_ids() limit = pagin_config.limit @@ -436,14 +440,22 @@ class MessageHandler(BaseHandler): for c in current_state.values() ] - account_data = [] + account_data_events = [] tags = tags_by_room.get(event.room_id) if tags: - account_data.append({ + account_data_events.append({ "type": "m.tag", "content": {"tags": tags}, }) - d["account_data"] = account_data + + account_data = account_data_by_room.get(event.room_id, {}) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + d["account_data"] = account_data_events except: logger.exception("Failed to get snapshot") @@ -456,9 +468,17 @@ class MessageHandler(BaseHandler): consumeErrors=True ).addErrback(unwrapFirstError) + account_data_events = [] + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + ret = { "rooms": rooms_ret, "presence": presence, + "account_data": account_data_events, "receipts": receipt, "end": now_token.to_string(), } @@ -498,14 +518,22 @@ class MessageHandler(BaseHandler): user_id, room_id, pagin_config, membership, member_event_id, is_guest ) - account_data = [] + account_data_events = [] tags = yield self.store.get_tags_for_room(user_id, room_id) if tags: - account_data.append({ + account_data_events.append({ "type": "m.tag", "content": {"tags": tags}, }) - result["account_data"] = account_data + + account_data = yield self.store.get_account_data_for_room(user_id, room_id) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + result["account_data"] = account_data_events defer.returnValue(result) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 877328b29..943ce368e 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -100,6 +100,7 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [ class SyncResult(collections.namedtuple("SyncResult", [ "next_batch", # Token for the next sync "presence", # List of presence events for the user. + "account_data", # List of account_data events for the user. "joined", # JoinedSyncResult for each joined room. "invited", # InvitedSyncResult for each invited room. "archived", # ArchivedSyncResult for each archived room. @@ -195,6 +196,12 @@ class SyncHandler(BaseHandler): ) ) + account_data, account_data_by_room = ( + yield self.store.get_account_data_for_user( + sync_config.user.to_string() + ) + ) + tags_by_room = yield self.store.get_tags_for_user( sync_config.user.to_string() ) @@ -211,6 +218,7 @@ class SyncHandler(BaseHandler): timeline_since_token=timeline_since_token, ephemeral_by_room=ephemeral_by_room, tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, ) joined.append(room_sync) elif event.membership == Membership.INVITE: @@ -230,11 +238,13 @@ class SyncHandler(BaseHandler): leave_token=leave_token, timeline_since_token=timeline_since_token, tags_by_room=tags_by_room, + account_data_by_room=account_data_by_room, ) archived.append(room_sync) defer.returnValue(SyncResult( presence=presence, + account_data=self.account_data_for_user(account_data), joined=joined, invited=invited, archived=archived, @@ -244,7 +254,8 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def full_state_sync_for_joined_room(self, room_id, sync_config, now_token, timeline_since_token, - ephemeral_by_room, tags_by_room): + ephemeral_by_room, tags_by_room, + account_data_by_room): """Sync a room for a client which is starting without any state Returns: A Deferred JoinedSyncResult. @@ -262,19 +273,38 @@ class SyncHandler(BaseHandler): state=current_state, ephemeral=ephemeral_by_room.get(room_id, []), account_data=self.account_data_for_room( - room_id, tags_by_room + room_id, tags_by_room, account_data_by_room ), )) - def account_data_for_room(self, room_id, tags_by_room): - account_data = [] + def account_data_for_user(self, account_data): + account_data_events = [] + + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + return account_data_events + + def account_data_for_room(self, room_id, tags_by_room, account_data_by_room): + account_data_events = [] tags = tags_by_room.get(room_id) if tags is not None: - account_data.append({ + account_data_events.append({ "type": "m.tag", "content": {"tags": tags}, }) - return account_data + + account_data = account_data_by_room.get(room_id, {}) + for account_data_type, content in account_data.items(): + account_data_events.append({ + "type": account_data_type, + "content": content, + }) + + return account_data_events @defer.inlineCallbacks def ephemeral_by_room(self, sync_config, now_token, since_token=None): @@ -341,7 +371,8 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def full_state_sync_for_archived_room(self, room_id, sync_config, leave_event_id, leave_token, - timeline_since_token, tags_by_room): + timeline_since_token, tags_by_room, + account_data_by_room): """Sync a room for a client which is starting without any state Returns: A Deferred JoinedSyncResult. @@ -358,7 +389,7 @@ class SyncHandler(BaseHandler): timeline=batch, state=leave_state, account_data=self.account_data_for_room( - room_id, tags_by_room + room_id, tags_by_room, account_data_by_room ), )) @@ -415,6 +446,13 @@ class SyncHandler(BaseHandler): since_token.account_data_key, ) + account_data, account_data_by_room = ( + yield self.store.get_updated_account_data_for_user( + sync_config.user.to_string(), + since_token.account_data_key, + ) + ) + joined = [] archived = [] if len(room_events) <= timeline_limit: @@ -469,7 +507,7 @@ class SyncHandler(BaseHandler): state=state, ephemeral=ephemeral_by_room.get(room_id, []), account_data=self.account_data_for_room( - room_id, tags_by_room + room_id, tags_by_room, account_data_by_room ), ) logger.debug("Result for room %s: %r", room_id, room_sync) @@ -492,14 +530,15 @@ class SyncHandler(BaseHandler): for room_id in joined_room_ids: room_sync = yield self.incremental_sync_with_gap_for_room( room_id, sync_config, since_token, now_token, - ephemeral_by_room, tags_by_room + ephemeral_by_room, tags_by_room, account_data_by_room ) if room_sync: joined.append(room_sync) for leave_event in leave_events: room_sync = yield self.incremental_sync_for_archived_room( - sync_config, leave_event, since_token, tags_by_room + sync_config, leave_event, since_token, tags_by_room, + account_data_by_room ) archived.append(room_sync) @@ -510,6 +549,7 @@ class SyncHandler(BaseHandler): defer.returnValue(SyncResult( presence=presence, + account_data=self.account_data_for_user(account_data), joined=joined, invited=invited, archived=archived, @@ -566,7 +606,8 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def incremental_sync_with_gap_for_room(self, room_id, sync_config, since_token, now_token, - ephemeral_by_room, tags_by_room): + ephemeral_by_room, tags_by_room, + account_data_by_room): """ Get the incremental delta needed to bring the client up to date for the room. Gives the client the most recent events and the changes to state. @@ -606,7 +647,7 @@ class SyncHandler(BaseHandler): state=state, ephemeral=ephemeral_by_room.get(room_id, []), account_data=self.account_data_for_room( - room_id, tags_by_room + room_id, tags_by_room, account_data_by_room ), ) @@ -616,7 +657,8 @@ class SyncHandler(BaseHandler): @defer.inlineCallbacks def incremental_sync_for_archived_room(self, sync_config, leave_event, - since_token, tags_by_room): + since_token, tags_by_room, + account_data_by_room): """ Get the incremental delta needed to bring the client up to date for the archived room. Returns: @@ -654,7 +696,7 @@ class SyncHandler(BaseHandler): timeline=batch, state=state_events_delta, account_data=self.account_data_for_room( - leave_event.room_id, tags_by_room + leave_event.room_id, tags_by_room, account_data_by_room ), ) diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index a10813234..d7b59c84d 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -23,6 +23,7 @@ from . import ( keys, tokenrefresh, tags, + account_data, ) from synapse.http.server import JsonResource @@ -46,3 +47,4 @@ class ClientV2AlphaRestResource(JsonResource): keys.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource) tags.register_servlets(hs, client_resource) + account_data.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py new file mode 100644 index 000000000..5b8f454bf --- /dev/null +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import client_v2_patterns + +from synapse.http.servlet import RestServlet +from synapse.api.errors import AuthError, SynapseError + +from twisted.internet import defer + +import logging + +import simplejson as json + +logger = logging.getLogger(__name__) + + +class AccountDataServlet(RestServlet): + """ + PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 + """ + PATTERNS = client_v2_patterns( + "/user/(?P[^/]*)/account_data/(?P[^/]*)" + ) + + def __init__(self, hs): + super(AccountDataServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + + @defer.inlineCallbacks + def on_PUT(self, request, user_id, account_data_type): + auth_user, _, _ = yield self.auth.get_user_by_req(request) + if user_id != auth_user.to_string(): + raise AuthError(403, "Cannot add account data for other users.") + + try: + content_bytes = request.content.read() + body = json.loads(content_bytes) + except: + raise SynapseError(400, "Invalid JSON") + + max_id = yield self.store.add_account_data_for_user( + user_id, account_data_type, body + ) + + yield self.notifier.on_new_event( + "account_data_key", max_id, users=[user_id] + ) + + defer.returnValue((200, {})) + + +class RoomAccountDataServlet(RestServlet): + """ + PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 + """ + PATTERNS = client_v2_patterns( + "/user/(?P[^/]*)" + "/rooms/(?P[^/]*)" + "/account_data/(?P[^/]*)" + ) + + def __init__(self, hs): + super(RoomAccountDataServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() + + @defer.inlineCallbacks + def on_PUT(self, request, user_id, room_id, account_data_type): + auth_user, _, _ = yield self.auth.get_user_by_req(request) + if user_id != auth_user.to_string(): + raise AuthError(403, "Cannot add account data for other users.") + + try: + content_bytes = request.content.read() + body = json.loads(content_bytes) + except: + raise SynapseError(400, "Invalid JSON") + + if not isinstance(body, dict): + raise ValueError("Expected a JSON object") + + max_id = yield self.store.add_account_data_to_room( + user_id, room_id, account_data_type, body + ) + + yield self.notifier.on_new_event( + "account_data_key", max_id, users=[user_id] + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + AccountDataServlet(hs).register(http_server) + RoomAccountDataServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 09693bb43..4efe80248 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -144,6 +144,9 @@ class SyncRestServlet(RestServlet): ) response_content = { + "account_data": self.encode_account_data( + sync_result.account_data, filter, time_now + ), "presence": self.encode_presence( sync_result.presence, filter, time_now ), @@ -165,6 +168,9 @@ class SyncRestServlet(RestServlet): formatted.append(event) return {"events": filter.filter_presence(formatted)} + def encode_account_data(self, events, filter, time_now): + return {"events": filter.filter_account_data(events)} + def encode_joined(self, rooms, filter, time_now, token_id): """ Encode the joined rooms in a sync result diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e7443f283..c46b653f1 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -42,6 +42,7 @@ from .end_to_end_keys import EndToEndKeyStore from .receipts import ReceiptsStore from .search import SearchStore from .tags import TagsStore +from .account_data import AccountDataStore import logging @@ -73,6 +74,7 @@ class DataStore(RoomMemberStore, RoomStore, EndToEndKeyStore, SearchStore, TagsStore, + AccountDataStore, ): def __init__(self, hs): diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py new file mode 100644 index 000000000..d1829f84e --- /dev/null +++ b/synapse/storage/account_data.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from twisted.internet import defer + +import ujson as json +import logging + +logger = logging.getLogger(__name__) + + +class AccountDataStore(SQLBaseStore): + + def get_account_data_for_user(self, user_id): + """Get all the client account_data for a user. + + Args: + user_id(str): The user to get the account_data for. + Returns: + A deferred pair of a dict of global account_data and a dict + mapping from room_id string to per room account_data dicts. + """ + + def get_account_data_for_user_txn(txn): + rows = self._simple_select_list_txn( + txn, "account_data", {"user_id": user_id}, + ["account_data_type", "content"] + ) + + global_account_data = { + row["account_data_type"]: json.loads(row["content"]) for row in rows + } + + rows = self._simple_select_list_txn( + txn, "room_account_data", {"user_id": user_id}, + ["room_id", "account_data_type", "content"] + ) + + by_room = {} + for row in rows: + room_data = by_room.setdefault(row["room_id"], {}) + room_data[row["account_data_type"]] = json.loads(row["content"]) + + return (global_account_data, by_room) + + return self.runInteraction( + "get_account_data_for_user", get_account_data_for_user_txn + ) + + def get_account_data_for_room(self, user_id, room_id): + """Get all the client account_data for a user for a room. + + Args: + user_id(str): The user to get the account_data for. + room_id(str): The room to get the account_data for. + Returns: + A deferred dict of the room account_data + """ + def get_account_data_for_room_txn(txn): + rows = self._simple_select_list_txn( + txn, "room_account_data", {"user_id": user_id, "room_id": room_id}, + ["account_data_type", "content"] + ) + + return { + row["account_data_type"]: json.loads(row["content"]) for row in rows + } + + return self.runInteraction( + "get_account_data_for_room", get_account_data_for_room_txn + ) + + def get_updated_account_data_for_user(self, user_id, stream_id): + """Get all the client account_data for a that's changed. + + Args: + user_id(str): The user to get the account_data for. + stream_id(int): The point in the stream since which to get updates + Returns: + A deferred pair of a dict of global account_data and a dict + mapping from room_id string to per room account_data dicts. + """ + + def get_updated_account_data_for_user_txn(txn): + sql = ( + "SELECT account_data_type, content FROM account_data" + " WHERE user_id = ? AND stream_id > ?" + ) + + txn.execute(sql, (user_id, stream_id)) + + global_account_data = { + row[0]: json.loads(row[1]) for row in txn.fetchall() + } + + sql = ( + "SELECT room_id, account_data_type, content FROM room_account_data" + " WHERE user_id = ? AND stream_id > ?" + ) + + txn.execute(sql, (user_id, stream_id)) + + account_data_by_room = {} + for row in txn.fetchall(): + room_account_data = account_data_by_room.setdefault(row[0], {}) + room_account_data[row[1]] = json.loads(row[2]) + + return (global_account_data, account_data_by_room) + + return self.runInteraction( + "get_updated_account_data_for_user", get_updated_account_data_for_user_txn + ) + + @defer.inlineCallbacks + def add_account_data_to_room(self, user_id, room_id, account_data_type, content): + """Add some account_data to a room for a user. + Args: + user_id(str): The user to add a tag for. + room_id(str): The room to add a tag for. + account_data_type(str): The type of account_data to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the account_data has been added. + """ + content_json = json.dumps(content) + + def add_account_data_txn(txn, next_id): + self._simple_upsert_txn( + txn, + table="room_account_data", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "account_data_type": account_data_type, + }, + values={ + "stream_id": next_id, + "content": content_json, + } + ) + self._update_max_stream_id(txn, next_id) + + with (yield self._account_data_id_gen.get_next(self)) as next_id: + yield self.runInteraction( + "add_room_account_data", add_account_data_txn, next_id + ) + + result = yield self._account_data_id_gen.get_max_token(self) + defer.returnValue(result) + + @defer.inlineCallbacks + def add_account_data_for_user(self, user_id, account_data_type, content): + """Add some account_data to a room for a user. + Args: + user_id(str): The user to add a tag for. + account_data_type(str): The type of account_data to add. + content(dict): A json object to associate with the tag. + Returns: + A deferred that completes once the account_data has been added. + """ + content_json = json.dumps(content) + + def add_account_data_txn(txn, next_id): + self._simple_upsert_txn( + txn, + table="account_data", + keyvalues={ + "user_id": user_id, + "account_data_type": account_data_type, + }, + values={ + "stream_id": next_id, + "content": content_json, + } + ) + self._update_max_stream_id(txn, next_id) + + with (yield self._account_data_id_gen.get_next(self)) as next_id: + yield self.runInteraction( + "add_user_account_data", add_account_data_txn, next_id + ) + + result = yield self._account_data_id_gen.get_max_token(self) + defer.returnValue(result) + + def _update_max_stream_id(self, txn, next_id): + """Update the max stream_id + + Args: + txn: The database cursor + next_id(int): The the revision to advance to. + """ + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) diff --git a/synapse/storage/schema/delta/26/account_data.sql b/synapse/storage/schema/delta/26/account_data.sql index 3198a0d29..48ad9cc6b 100644 --- a/synapse/storage/schema/delta/26/account_data.sql +++ b/synapse/storage/schema/delta/26/account_data.sql @@ -15,3 +15,26 @@ ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id; + + +CREATE TABLE IF NOT EXISTS account_data( + user_id TEXT NOT NULL, + account_data_type TEXT NOT NULL, -- The type of the account_data. + stream_id BIGINT NOT NULL, -- The version of the account_data. + content TEXT NOT NULL, -- The JSON content of the account_data + CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type) +); + + +CREATE TABLE IF NOT EXISTS room_account_data( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + account_data_type TEXT NOT NULL, -- The type of the account_data. + stream_id BIGINT NOT NULL, -- The version of the account_data. + content TEXT NOT NULL, -- The JSON content of the account_data + CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type) +); + + +CREATE INDEX account_data_stream_id on account_data(user_id, stream_id); +CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id); diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index f6d826cc5..f520f60c6 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -48,8 +48,8 @@ class TagsStore(SQLBaseStore): Args: user_id(str): The user to get the tags for. Returns: - A deferred dict mapping from room_id strings to lists of tag - strings. + A deferred dict mapping from room_id strings to dicts mapping from + tag strings to tag content. """ deferred = self._simple_select_list(