Merge branch 'develop' into markjh/edu_frequency

This commit is contained in:
Mark Haines 2015-12-01 19:15:27 +00:00
commit f73ea0bda2
45 changed files with 854 additions and 252 deletions

View File

@ -115,8 +115,8 @@ Installing prerequisites on CentOS 7::
sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \ sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
lcms2-devel libwebp-devel tcl-devel tk-devel \ lcms2-devel libwebp-devel tcl-devel tk-devel \
python-virtualenv python-virtualenv libffi-devel openssl-devel
sudo yum groupinstall "Development Tools" sudo yum groupinstall "Development Tools"
Installing prerequisites on Mac OS X:: Installing prerequisites on Mac OS X::
@ -152,7 +152,8 @@ To set up your homeserver, run (in your virtualenv, as before)::
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config \
--report-stats=[yes|no]
Substituting your host and domain name as appropriate. Substituting your host and domain name as appropriate.

View File

@ -50,7 +50,7 @@ class Filtering(object):
# many definitions. # many definitions.
top_level_definitions = [ top_level_definitions = [
"presence" "presence", "account_data"
] ]
room_level_definitions = [ room_level_definitions = [
@ -139,6 +139,10 @@ class FilterCollection(object):
self.filter_json.get("presence", {}) self.filter_json.get("presence", {})
) )
self.account_data = Filter(
self.filter_json.get("account_data", {})
)
def timeline_limit(self): def timeline_limit(self):
return self.room_timeline_filter.limit() return self.room_timeline_filter.limit()
@ -151,6 +155,9 @@ class FilterCollection(object):
def filter_presence(self, events): def filter_presence(self, events):
return self.presence_filter.filter(events) return self.presence_filter.filter(events)
def filter_account_data(self, events):
return self.account_data.filter(events)
def filter_room_state(self, events): def filter_room_state(self, events):
return self.room_state_filter.filter(events) return self.room_state_filter.filter(events)

View File

@ -100,22 +100,20 @@ def format_event_raw(d):
def format_event_for_client_v1(d): def format_event_for_client_v1(d):
d["user_id"] = d.pop("sender", None) d = format_event_for_client_v2(d)
move_keys = ( sender = d.get("sender")
if sender is not None:
d["user_id"] = sender
copy_keys = (
"age", "redacted_because", "replaces_state", "prev_content", "age", "redacted_because", "replaces_state", "prev_content",
"invite_room_state", "invite_room_state",
) )
for key in move_keys: for key in copy_keys:
if key in d["unsigned"]: if key in d["unsigned"]:
d[key] = d["unsigned"][key] d[key] = d["unsigned"][key]
drop_keys = (
"auth_events", "prev_events", "hashes", "signatures", "depth",
"unsigned", "origin", "prev_state"
)
for key in drop_keys:
d.pop(key, None)
return d return d

View File

@ -165,7 +165,7 @@ class BaseFederationServlet(object):
if code is None: if code is None:
continue continue
server.register_path(method, pattern, self._wrap(code)) server.register_paths(method, (pattern,), self._wrap(code))
class FederationSendServlet(BaseFederationServlet): class FederationSendServlet(BaseFederationServlet):

View File

@ -29,9 +29,10 @@ class AccountDataEventSource(object):
last_stream_id = from_key last_stream_id = from_key
current_stream_id = yield self.store.get_max_account_data_stream_id() current_stream_id = yield self.store.get_max_account_data_stream_id()
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
results = [] results = []
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items(): for room_id, room_tags in tags.items():
results.append({ results.append({
"type": "m.tag", "type": "m.tag",
@ -39,6 +40,24 @@ class AccountDataEventSource(object):
"room_id": room_id, "room_id": room_id,
}) })
account_data, room_account_data = (
yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
)
for account_data_type, content in account_data.items():
results.append({
"type": account_data_type,
"content": content,
})
for room_id, account_data in room_account_data.items():
for account_data_type, content in account_data.items():
results.append({
"type": account_data_type,
"content": content,
"room_id": room_id,
})
defer.returnValue((results, current_stream_id)) defer.returnValue((results, current_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -359,6 +359,10 @@ class MessageHandler(BaseHandler):
tags_by_room = yield self.store.get_tags_for_user(user_id) tags_by_room = yield self.store.get_tags_for_user(user_id)
account_data, account_data_by_room = (
yield self.store.get_account_data_for_user(user_id)
)
public_room_ids = yield self.store.get_public_room_ids() public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit limit = pagin_config.limit
@ -436,14 +440,22 @@ class MessageHandler(BaseHandler):
for c in current_state.values() for c in current_state.values()
] ]
account_data = [] account_data_events = []
tags = tags_by_room.get(event.room_id) tags = tags_by_room.get(event.room_id)
if tags: if tags:
account_data.append({ account_data_events.append({
"type": "m.tag", "type": "m.tag",
"content": {"tags": tags}, "content": {"tags": tags},
}) })
d["account_data"] = account_data
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
d["account_data"] = account_data_events
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
@ -456,9 +468,17 @@ class MessageHandler(BaseHandler):
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
account_data_events = []
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
ret = { ret = {
"rooms": rooms_ret, "rooms": rooms_ret,
"presence": presence, "presence": presence,
"account_data": account_data_events,
"receipts": receipt, "receipts": receipt,
"end": now_token.to_string(), "end": now_token.to_string(),
} }
@ -498,14 +518,22 @@ class MessageHandler(BaseHandler):
user_id, room_id, pagin_config, membership, member_event_id, is_guest user_id, room_id, pagin_config, membership, member_event_id, is_guest
) )
account_data = [] account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id) tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags: if tags:
account_data.append({ account_data_events.append({
"type": "m.tag", "type": "m.tag",
"content": {"tags": tags}, "content": {"tags": tags},
}) })
result["account_data"] = account_data
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
result["account_data"] = account_data_events
defer.returnValue(result) defer.returnValue(result)

View File

@ -467,7 +467,7 @@ class PresenceHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, observer_user, observed_user): def send_presence_invite(self, observer_user, observed_user):
"""Request the presence of a local or remote user for a local user""" """Request the presence of a local or remote user for a local user"""
if not self.hs.is_mine(observer_user): if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")

View File

@ -131,6 +131,17 @@ class SearchHandler(BaseHandler):
if batch_group == "room_id": if batch_group == "room_id":
room_ids.intersection_update({batch_group_key}) room_ids.intersection_update({batch_group_key})
if not room_ids:
defer.returnValue({
"search_categories": {
"room_events": {
"results": {},
"count": 0,
"highlights": [],
}
}
})
rank_map = {} # event_id -> rank of event rank_map = {} # event_id -> rank of event
allowed_events = [] allowed_events = []
room_groups = {} # Holds result of grouping by room, if applicable room_groups = {} # Holds result of grouping by room, if applicable
@ -178,85 +189,67 @@ class SearchHandler(BaseHandler):
s["results"].append(e.event_id) s["results"].append(e.event_id)
elif order_by == "recent": elif order_by == "recent":
# In this case we specifically loop through each room as the given room_events = []
# limit applies to each room, rather than a global list. i = 0
# This is not necessarilly a good idea.
for room_id in room_ids: pagination_token = batch_token
room_events = []
if batch_group == "room_id" and batch_group_key == room_id: # We keep looping and we keep filtering until we reach the limit
pagination_token = batch_token # or we run out of things.
else: # But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5:
i += 1
search_result = yield self.store.search_rooms(
room_ids, search_term, keys, search_filter.limit() * 2,
pagination_token=pagination_token,
)
if search_result["highlights"]:
highlights.update(search_result["highlights"])
results = search_result["results"]
results_map = {r["event"].event_id: r for r in results}
rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = search_filter.filter([
r["event"] for r in results
])
events = yield self._filter_events_for_client(
user.to_string(), filtered_events
)
room_events.extend(events)
room_events = room_events[:search_filter.limit()]
if len(results) < search_filter.limit() * 2:
pagination_token = None pagination_token = None
i = 0 break
# We keep looping and we keep filtering until we reach the limit
# or we run out of things.
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5:
i += 1
search_result = yield self.store.search_room(
room_id, search_term, keys, search_filter.limit() * 2,
pagination_token=pagination_token,
)
if search_result["highlights"]:
highlights.update(search_result["highlights"])
results = search_result["results"]
results_map = {r["event"].event_id: r for r in results}
rank_map.update({r["event"].event_id: r["rank"] for r in results})
filtered_events = search_filter.filter([
r["event"] for r in results
])
events = yield self._filter_events_for_client(
user.to_string(), filtered_events
)
room_events.extend(events)
room_events = room_events[:search_filter.limit()]
if len(results) < search_filter.limit() * 2:
pagination_token = None
break
else:
pagination_token = results[-1]["pagination_token"]
if room_events:
res = results_map[room_events[-1].event_id]
pagination_token = res["pagination_token"]
group = room_groups.setdefault(room_id, {})
if pagination_token:
next_batch = encode_base64("%s\n%s\n%s" % (
"room_id", room_id, pagination_token
))
group["next_batch"] = next_batch
if batch_token:
global_next_batch = next_batch
group["results"] = [e.event_id for e in room_events]
group["order"] = max(
e.origin_server_ts/1000 for e in room_events
if hasattr(e, "origin_server_ts")
)
allowed_events.extend(room_events)
# Normalize the group orders
if room_groups:
if len(room_groups) > 1:
mx = max(g["order"] for g in room_groups.values())
mn = min(g["order"] for g in room_groups.values())
for g in room_groups.values():
g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
else: else:
room_groups.values()[0]["order"] = 1 pagination_token = results[-1]["pagination_token"]
for event in room_events:
group = room_groups.setdefault(event.room_id, {
"results": [],
})
group["results"].append(event.event_id)
if room_events and len(room_events) >= search_filter.limit():
last_event_id = room_events[-1].event_id
pagination_token = results_map[last_event_id]["pagination_token"]
global_next_batch = encode_base64("%s\n%s\n%s" % (
"all", "", pagination_token
))
for room_id, group in room_groups.items():
group["next_batch"] = encode_base64("%s\n%s\n%s" % (
"room_id", room_id, pagination_token
))
allowed_events.extend(room_events)
else: else:
# We should never get here due to the guard earlier. # We should never get here due to the guard earlier.

View File

@ -100,6 +100,7 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
class SyncResult(collections.namedtuple("SyncResult", [ class SyncResult(collections.namedtuple("SyncResult", [
"next_batch", # Token for the next sync "next_batch", # Token for the next sync
"presence", # List of presence events for the user. "presence", # List of presence events for the user.
"account_data", # List of account_data events for the user.
"joined", # JoinedSyncResult for each joined room. "joined", # JoinedSyncResult for each joined room.
"invited", # InvitedSyncResult for each invited room. "invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room. "archived", # ArchivedSyncResult for each archived room.
@ -195,6 +196,12 @@ class SyncHandler(BaseHandler):
) )
) )
account_data, account_data_by_room = (
yield self.store.get_account_data_for_user(
sync_config.user.to_string()
)
)
tags_by_room = yield self.store.get_tags_for_user( tags_by_room = yield self.store.get_tags_for_user(
sync_config.user.to_string() sync_config.user.to_string()
) )
@ -211,6 +218,7 @@ class SyncHandler(BaseHandler):
timeline_since_token=timeline_since_token, timeline_since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room, ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room, tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
) )
joined.append(room_sync) joined.append(room_sync)
elif event.membership == Membership.INVITE: elif event.membership == Membership.INVITE:
@ -230,11 +238,13 @@ class SyncHandler(BaseHandler):
leave_token=leave_token, leave_token=leave_token,
timeline_since_token=timeline_since_token, timeline_since_token=timeline_since_token,
tags_by_room=tags_by_room, tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
) )
archived.append(room_sync) archived.append(room_sync)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
account_data=self.account_data_for_user(account_data),
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived, archived=archived,
@ -244,7 +254,8 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def full_state_sync_for_joined_room(self, room_id, sync_config, def full_state_sync_for_joined_room(self, room_id, sync_config,
now_token, timeline_since_token, now_token, timeline_since_token,
ephemeral_by_room, tags_by_room): ephemeral_by_room, tags_by_room,
account_data_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred JoinedSyncResult.
@ -262,19 +273,38 @@ class SyncHandler(BaseHandler):
state=current_state, state=current_state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
account_data=self.account_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room room_id, tags_by_room, account_data_by_room
), ),
)) ))
def account_data_for_room(self, room_id, tags_by_room): def account_data_for_user(self, account_data):
account_data = [] account_data_events = []
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
return account_data_events
def account_data_for_room(self, room_id, tags_by_room, account_data_by_room):
account_data_events = []
tags = tags_by_room.get(room_id) tags = tags_by_room.get(room_id)
if tags is not None: if tags is not None:
account_data.append({ account_data_events.append({
"type": "m.tag", "type": "m.tag",
"content": {"tags": tags}, "content": {"tags": tags},
}) })
return account_data
account_data = account_data_by_room.get(room_id, {})
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
return account_data_events
@defer.inlineCallbacks @defer.inlineCallbacks
def ephemeral_by_room(self, sync_config, now_token, since_token=None): def ephemeral_by_room(self, sync_config, now_token, since_token=None):
@ -341,7 +371,8 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config, def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token, leave_event_id, leave_token,
timeline_since_token, tags_by_room): timeline_since_token, tags_by_room,
account_data_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred JoinedSyncResult.
@ -358,7 +389,7 @@ class SyncHandler(BaseHandler):
timeline=batch, timeline=batch,
state=leave_state, state=leave_state,
account_data=self.account_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room room_id, tags_by_room, account_data_by_room
), ),
)) ))
@ -415,6 +446,13 @@ class SyncHandler(BaseHandler):
since_token.account_data_key, since_token.account_data_key,
) )
account_data, account_data_by_room = (
yield self.store.get_updated_account_data_for_user(
sync_config.user.to_string(),
since_token.account_data_key,
)
)
joined = [] joined = []
archived = [] archived = []
if len(room_events) <= timeline_limit: if len(room_events) <= timeline_limit:
@ -469,7 +507,7 @@ class SyncHandler(BaseHandler):
state=state, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
account_data=self.account_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room room_id, tags_by_room, account_data_by_room
), ),
) )
logger.debug("Result for room %s: %r", room_id, room_sync) logger.debug("Result for room %s: %r", room_id, room_sync)
@ -492,14 +530,15 @@ class SyncHandler(BaseHandler):
for room_id in joined_room_ids: for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room( room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token, room_id, sync_config, since_token, now_token,
ephemeral_by_room, tags_by_room ephemeral_by_room, tags_by_room, account_data_by_room
) )
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
for leave_event in leave_events: for leave_event in leave_events:
room_sync = yield self.incremental_sync_for_archived_room( room_sync = yield self.incremental_sync_for_archived_room(
sync_config, leave_event, since_token, tags_by_room sync_config, leave_event, since_token, tags_by_room,
account_data_by_room
) )
archived.append(room_sync) archived.append(room_sync)
@ -510,6 +549,7 @@ class SyncHandler(BaseHandler):
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
account_data=self.account_data_for_user(account_data),
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived, archived=archived,
@ -566,7 +606,8 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_with_gap_for_room(self, room_id, sync_config, def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token, since_token, now_token,
ephemeral_by_room, tags_by_room): ephemeral_by_room, tags_by_room,
account_data_by_room):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to the room. Gives the client the most recent events and the changes to
state. state.
@ -606,7 +647,7 @@ class SyncHandler(BaseHandler):
state=state, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
account_data=self.account_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room room_id, tags_by_room, account_data_by_room
), ),
) )
@ -616,7 +657,8 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event, def incremental_sync_for_archived_room(self, sync_config, leave_event,
since_token, tags_by_room): since_token, tags_by_room,
account_data_by_room):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the archived room. the archived room.
Returns: Returns:
@ -654,7 +696,7 @@ class SyncHandler(BaseHandler):
timeline=batch, timeline=batch,
state=state_events_delta, state=state_events_delta,
account_data=self.account_data_for_room( account_data=self.account_data_for_room(
leave_event.room_id, tags_by_room leave_event.room_id, tags_by_room, account_data_by_room
), ),
) )

View File

@ -120,7 +120,7 @@ class HttpServer(object):
""" Interface for registering callbacks on a HTTP server """ Interface for registering callbacks on a HTTP server
""" """
def register_path(self, method, path_pattern, callback): def register_paths(self, method, path_patterns, callback):
""" Register a callback that gets fired if we receive a http request """ Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex. with the given method for a path that matches the given regex.
@ -129,7 +129,7 @@ class HttpServer(object):
Args: Args:
method (str): The method to listen to. method (str): The method to listen to.
path_pattern (str): The regex used to match requests. path_patterns (list<SRE_Pattern>): The regex used to match requests.
callback (function): The function to fire if we receive a matched callback (function): The function to fire if we receive a matched
request. The first argument will be the request object and request. The first argument will be the request object and
subsequent arguments will be any matched groups from the regex. subsequent arguments will be any matched groups from the regex.
@ -165,10 +165,11 @@ class JsonResource(HttpServer, resource.Resource):
self.version_string = hs.version_string self.version_string = hs.version_string
self.hs = hs self.hs = hs
def register_path(self, method, path_pattern, callback): def register_paths(self, method, path_patterns, callback):
self.path_regexs.setdefault(method, []).append( for path_pattern in path_patterns:
self._PathEntry(path_pattern, callback) self.path_regexs.setdefault(method, []).append(
) self._PathEntry(path_pattern, callback)
)
def render(self, request): def render(self, request):
""" This gets called by twisted every time someone sends us a request. """ This gets called by twisted every time someone sends us a request.

View File

@ -19,7 +19,6 @@ from synapse.api.errors import SynapseError
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -102,12 +101,13 @@ class RestServlet(object):
def register(self, http_server): def register(self, http_server):
""" Register this servlet with the given HTTP server. """ """ Register this servlet with the given HTTP server. """
if hasattr(self, "PATTERN"): if hasattr(self, "PATTERNS"):
pattern = self.PATTERN patterns = self.PATTERNS
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method,)): if hasattr(self, "on_%s" % (method,)):
method_handler = getattr(self, "on_%s" % (method,)) method_handler = getattr(self, "on_%s" % (method,))
http_server.register_path(method, pattern, method_handler) http_server.register_paths(method, patterns, method_handler)
else: else:
raise NotImplementedError("RestServlet must register something.") raise NotImplementedError("RestServlet must register something.")

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.types import UserID from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_patterns
import logging import logging
@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class WhoisRestServlet(ClientV1RestServlet): class WhoisRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/admin/whois/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)", releases=())
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):

View File

@ -27,7 +27,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def client_path_pattern(path_regex): def client_path_patterns(path_regex, releases=(0,)):
"""Creates a regex compiled client path with the correct client path """Creates a regex compiled client path with the correct client path
prefix. prefix.
@ -37,7 +37,13 @@ def client_path_pattern(path_regex):
Returns: Returns:
SRE_Pattern SRE_Pattern
""" """
return re.compile("^" + CLIENT_PREFIX + path_regex) patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)]
unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable")
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_PREFIX.replace("/api/v1", "/r%d" % release)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
class ClientV1RestServlet(RestServlet): class ClientV1RestServlet(RestServlet):

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError, Codes from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import RoomAlias from synapse.types import RoomAlias
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json import simplejson as json
import logging import logging
@ -32,7 +32,7 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(ClientV1RestServlet): class ClientDirectoryServer(ClientV1RestServlet):
PATTERN = client_path_pattern("/directory/room/(?P<room_alias>[^/]*)$") PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_alias): def on_GET(self, request, room_alias):

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_patterns
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
import logging import logging
@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
class EventStreamRestServlet(ClientV1RestServlet): class EventStreamRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/events$") PATTERNS = client_path_patterns("/events$")
DEFAULT_LONGPOLL_TIME_MS = 30000 DEFAULT_LONGPOLL_TIME_MS = 30000
@ -72,7 +72,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
# TODO: Unit test gets, with and without auth, with different kinds of events. # TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(ClientV1RestServlet): class EventRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$") PATTERNS = client_path_patterns("/events/(?P<event_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs):
super(EventRestServlet, self).__init__(hs) super(EventRestServlet, self).__init__(hs)

View File

@ -16,12 +16,12 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_patterns
# TODO: Needs unit testing # TODO: Needs unit testing
class InitialSyncRestServlet(ClientV1RestServlet): class InitialSyncRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/initialSync$") PATTERNS = client_path_patterns("/initialSync$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, LoginError, Codes from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.types import UserID from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_patterns
import simplejson as json import simplejson as json
import urllib import urllib
@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
class LoginRestServlet(ClientV1RestServlet): class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$") PATTERNS = client_path_patterns("/login$", releases=())
PASS_TYPE = "m.login.password" PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2" SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas" CAS_TYPE = "m.login.cas"
@ -238,7 +238,7 @@ class LoginRestServlet(ClientV1RestServlet):
class SAML2RestServlet(ClientV1RestServlet): class SAML2RestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/saml2") PATTERNS = client_path_patterns("/login/saml2", releases=())
def __init__(self, hs): def __init__(self, hs):
super(SAML2RestServlet, self).__init__(hs) super(SAML2RestServlet, self).__init__(hs)
@ -282,7 +282,7 @@ class SAML2RestServlet(ClientV1RestServlet):
# TODO Delete this after all CAS clients switch to token login instead # TODO Delete this after all CAS clients switch to token login instead
class CasRestServlet(ClientV1RestServlet): class CasRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas") PATTERNS = client_path_patterns("/login/cas", releases=())
def __init__(self, hs): def __init__(self, hs):
super(CasRestServlet, self).__init__(hs) super(CasRestServlet, self).__init__(hs)
@ -293,7 +293,7 @@ class CasRestServlet(ClientV1RestServlet):
class CasRedirectServlet(ClientV1RestServlet): class CasRedirectServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas/redirect") PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
def __init__(self, hs): def __init__(self, hs):
super(CasRedirectServlet, self).__init__(hs) super(CasRedirectServlet, self).__init__(hs)
@ -316,7 +316,7 @@ class CasRedirectServlet(ClientV1RestServlet):
class CasTicketServlet(ClientV1RestServlet): class CasTicketServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas/ticket") PATTERNS = client_path_patterns("/login/cas/ticket", releases=())
def __init__(self, hs): def __init__(self, hs):
super(CasTicketServlet, self).__init__(hs) super(CasTicketServlet, self).__init__(hs)

View File

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import UserID from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json import simplejson as json
import logging import logging
@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(ClientV1RestServlet): class PresenceStatusRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status") PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -73,7 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
class PresenceListRestServlet(ClientV1RestServlet): class PresenceListRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/presence/list/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -120,7 +120,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if len(u) == 0: if len(u) == 0:
continue continue
invited_user = UserID.from_string(u) invited_user = UserID.from_string(u)
yield self.handlers.presence_handler.send_invite( yield self.handlers.presence_handler.send_presence_invite(
observer_user=user, observed_user=invited_user observer_user=user, observed_user=invited_user
) )

View File

@ -16,14 +16,14 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """ """ This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_patterns
from synapse.types import UserID from synapse.types import UserID
import simplejson as json import simplejson as json
class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -56,7 +56,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -89,7 +89,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
class ProfileRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError
) )
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_patterns
from synapse.storage.push_rule import ( from synapse.storage.push_rule import (
InconsistentRuleException, RuleNotFoundException InconsistentRuleException, RuleNotFoundException
) )
@ -31,7 +31,7 @@ import simplejson as json
class PushRuleRestServlet(ClientV1RestServlet): class PushRuleRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/pushrules/.*$") PATTERNS = client_path_patterns("/pushrules/.*$")
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash") "Unrecognised request: You probably wanted a trailing slash")

View File

@ -17,13 +17,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json import simplejson as json
class PusherRestServlet(ClientV1RestServlet): class PusherRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/pushers/set$") PATTERNS = client_path_patterns("/pushers/set$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
@ -48,7 +48,7 @@ class RegisterRestServlet(ClientV1RestServlet):
handler doesn't have a concept of multi-stages or sessions. handler doesn't have a concept of multi-stages or sessions.
""" """
PATTERN = client_path_pattern("/register$") PATTERNS = client_path_patterns("/register$", releases=())
def __init__(self, hs): def __init__(self, hs):
super(RegisterRestServlet, self).__init__(hs) super(RegisterRestServlet, self).__init__(hs)

View File

@ -16,7 +16,7 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """ """ This module contains REST servlets to do with rooms: /rooms/<paths> """
from twisted.internet import defer from twisted.internet import defer
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_patterns
from synapse.api.errors import SynapseError, Codes, AuthError from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -34,16 +34,16 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# No PATTERN; we have custom dispatch rules here # No PATTERN; we have custom dispatch rules here
def register(self, http_server): def register(self, http_server):
PATTERN = "/createRoom" PATTERNS = "/createRoom"
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERNS, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity # define CORS for all of /rooms in RoomCreateRestServlet for simplicity
http_server.register_path("OPTIONS", http_server.register_paths("OPTIONS",
client_path_pattern("/rooms(?:/.*)?$"), client_path_patterns("/rooms(?:/.*)?$"),
self.on_OPTIONS) self.on_OPTIONS)
# define CORS for /createRoom[/txnid] # define CORS for /createRoom[/txnid]
http_server.register_path("OPTIONS", http_server.register_paths("OPTIONS",
client_path_pattern("/createRoom(?:/.*)?$"), client_path_patterns("/createRoom(?:/.*)?$"),
self.on_OPTIONS) self.on_OPTIONS)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, txn_id): def on_PUT(self, request, txn_id):
@ -103,18 +103,18 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
state_key = ("/rooms/(?P<room_id>[^/]*)/state/" state_key = ("/rooms/(?P<room_id>[^/]*)/state/"
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$") "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
http_server.register_path("GET", http_server.register_paths("GET",
client_path_pattern(state_key), client_path_patterns(state_key),
self.on_GET) self.on_GET)
http_server.register_path("PUT", http_server.register_paths("PUT",
client_path_pattern(state_key), client_path_patterns(state_key),
self.on_PUT) self.on_PUT)
http_server.register_path("GET", http_server.register_paths("GET",
client_path_pattern(no_state_key), client_path_patterns(no_state_key, releases=()),
self.on_GET_no_state_key) self.on_GET_no_state_key)
http_server.register_path("PUT", http_server.register_paths("PUT",
client_path_pattern(no_state_key), client_path_patterns(no_state_key, releases=()),
self.on_PUT_no_state_key) self.on_PUT_no_state_key)
def on_GET_no_state_key(self, request, room_id, event_type): def on_GET_no_state_key(self, request, room_id, event_type):
return self.on_GET(request, room_id, event_type, "") return self.on_GET(request, room_id, event_type, "")
@ -170,8 +170,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
PATTERN = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
register_txn_path(self, PATTERN, http_server, with_get=True) register_txn_path(self, PATTERNS, http_server, with_get=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None): def on_POST(self, request, room_id, event_type, txn_id=None):
@ -215,8 +215,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
def register(self, http_server): def register(self, http_server):
# /join/$room_identifier[/$txn_id] # /join/$room_identifier[/$txn_id]
PATTERN = ("/join/(?P<room_identifier>[^/]*)") PATTERNS = ("/join/(?P<room_identifier>[^/]*)")
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None): def on_POST(self, request, room_identifier, txn_id=None):
@ -280,7 +280,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class PublicRoomListRestServlet(ClientV1RestServlet): class PublicRoomListRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/publicRooms$") PATTERNS = client_path_patterns("/publicRooms$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -291,7 +291,7 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -328,7 +328,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
# TODO: Needs better unit testing # TODO: Needs better unit testing
class RoomMessageListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -351,7 +351,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomStateRestServlet(ClientV1RestServlet): class RoomStateRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/state$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -368,7 +368,7 @@ class RoomStateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomInitialSyncRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/initialSync$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -384,7 +384,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
class RoomTriggerBackfill(ClientV1RestServlet): class RoomTriggerBackfill(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/backfill$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/backfill$", releases=())
def __init__(self, hs): def __init__(self, hs):
super(RoomTriggerBackfill, self).__init__(hs) super(RoomTriggerBackfill, self).__init__(hs)
@ -408,7 +408,7 @@ class RoomTriggerBackfill(ClientV1RestServlet):
class RoomEventContext(ClientV1RestServlet): class RoomEventContext(ClientV1RestServlet):
PATTERN = client_path_pattern( PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
) )
@ -447,9 +447,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
PATTERN = ("/rooms/(?P<room_id>[^/]*)/" PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
"(?P<membership_action>join|invite|leave|ban|kick|forget)") "(?P<membership_action>join|invite|leave|ban|kick|forget)")
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action, txn_id=None): def on_POST(self, request, room_id, membership_action, txn_id=None):
@ -543,8 +543,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet):
def register(self, http_server): def register(self, http_server):
PATTERN = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None): def on_POST(self, request, room_id, event_id, txn_id=None):
@ -582,7 +582,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
class RoomTypingRestServlet(ClientV1RestServlet): class RoomTypingRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern( PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
) )
@ -615,7 +615,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
class SearchRestServlet(ClientV1RestServlet): class SearchRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern( PATTERNS = client_path_patterns(
"/search$" "/search$"
) )
@ -655,20 +655,20 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
http_server : The http_server to register paths with. http_server : The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs. with_get: True to also register respective GET paths for the PUTs.
""" """
http_server.register_path( http_server.register_paths(
"POST", "POST",
client_path_pattern(regex_string + "$"), client_path_patterns(regex_string + "$"),
servlet.on_POST servlet.on_POST
) )
http_server.register_path( http_server.register_paths(
"PUT", "PUT",
client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"), client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"),
servlet.on_PUT servlet.on_PUT
) )
if with_get: if with_get:
http_server.register_path( http_server.register_paths(
"GET", "GET",
client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"), client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"),
servlet.on_GET servlet.on_GET
) )

View File

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_patterns
import hmac import hmac
@ -24,7 +24,7 @@ import base64
class VoipRestServlet(ClientV1RestServlet): class VoipRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/voip/turnServer$") PATTERNS = client_path_patterns("/voip/turnServer$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View File

@ -23,6 +23,7 @@ from . import (
keys, keys,
tokenrefresh, tokenrefresh,
tags, tags,
account_data,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -46,3 +47,4 @@ class ClientV2AlphaRestResource(JsonResource):
keys.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource)
tokenrefresh.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource)
tags.register_servlets(hs, client_resource) tags.register_servlets(hs, client_resource)
account_data.register_servlets(hs, client_resource)

View File

@ -27,7 +27,7 @@ import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def client_v2_pattern(path_regex): def client_v2_patterns(path_regex, releases=(0,)):
"""Creates a regex compiled client path with the correct client path """Creates a regex compiled client path with the correct client path
prefix. prefix.
@ -37,7 +37,13 @@ def client_v2_pattern(path_regex):
Returns: Returns:
SRE_Pattern SRE_Pattern
""" """
return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)]
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
def parse_request_allow_empty(request): def parse_request_allow_empty(request):

View File

@ -20,7 +20,7 @@ from synapse.api.errors import LoginError, SynapseError, Codes
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from ._base import client_v2_pattern, parse_json_dict_from_request from ._base import client_v2_patterns, parse_json_dict_from_request
import logging import logging
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class PasswordRestServlet(RestServlet): class PasswordRestServlet(RestServlet):
PATTERN = client_v2_pattern("/account/password") PATTERNS = client_v2_patterns("/account/password", releases=())
def __init__(self, hs): def __init__(self, hs):
super(PasswordRestServlet, self).__init__() super(PasswordRestServlet, self).__init__()
@ -89,7 +89,7 @@ class PasswordRestServlet(RestServlet):
class ThreepidRestServlet(RestServlet): class ThreepidRestServlet(RestServlet):
PATTERN = client_v2_pattern("/account/3pid") PATTERNS = client_v2_patterns("/account/3pid", releases=())
def __init__(self, hs): def __init__(self, hs):
super(ThreepidRestServlet, self).__init__() super(ThreepidRestServlet, self).__init__()

View File

@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import client_v2_patterns
from synapse.http.servlet import RestServlet
from synapse.api.errors import AuthError, SynapseError
from twisted.internet import defer
import logging
import simplejson as json
logger = logging.getLogger(__name__)
class AccountDataServlet(RestServlet):
"""
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
"""
PATTERNS = client_v2_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
)
def __init__(self, hs):
super(AccountDataServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def on_PUT(self, request, user_id, account_data_type):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
try:
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid JSON")
max_id = yield self.store.add_account_data_for_user(
user_id, account_data_type, body
)
yield self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
class RoomAccountDataServlet(RestServlet):
"""
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
"""
PATTERNS = client_v2_patterns(
"/user/(?P<user_id>[^/]*)"
"/rooms/(?P<room_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)"
)
def __init__(self, hs):
super(RoomAccountDataServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, account_data_type):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
try:
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid JSON")
if not isinstance(body, dict):
raise ValueError("Expected a JSON object")
max_id = yield self.store.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
yield self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
AccountDataServlet(hs).register(http_server)
RoomAccountDataServlet(hs).register(http_server)

View File

@ -20,7 +20,7 @@ from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern from ._base import client_v2_patterns
import logging import logging
@ -97,7 +97,7 @@ class AuthRestServlet(RestServlet):
cannot be handled in the normal flow (with requests to the same endpoint). cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth. Current use is for web fallback auth.
""" """
PATTERN = client_v2_pattern("/auth/(?P<stagetype>[\w\.]*)/fallback/web") PATTERNS = client_v2_patterns("/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs): def __init__(self, hs):
super(AuthRestServlet, self).__init__() super(AuthRestServlet, self).__init__()

View File

@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import UserID from synapse.types import UserID
from ._base import client_v2_pattern from ._base import client_v2_patterns
import simplejson as json import simplejson as json
import logging import logging
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet): class GetFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs):
super(GetFilterRestServlet, self).__init__() super(GetFilterRestServlet, self).__init__()
@ -65,7 +65,7 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter") PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs): def __init__(self, hs):
super(CreateFilterRestServlet, self).__init__() super(CreateFilterRestServlet, self).__init__()

View File

@ -21,7 +21,7 @@ from synapse.types import UserID
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from ._base import client_v2_pattern from ._base import client_v2_patterns
import simplejson as json import simplejson as json
import logging import logging
@ -54,7 +54,7 @@ class KeyUploadServlet(RestServlet):
}, },
} }
""" """
PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)") PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs):
super(KeyUploadServlet, self).__init__() super(KeyUploadServlet, self).__init__()
@ -154,12 +154,13 @@ class KeyQueryServlet(RestServlet):
} } } } } } } } } } } }
""" """
PATTERN = client_v2_pattern( PATTERNS = client_v2_patterns(
"/keys/query(?:" "/keys/query(?:"
"/(?P<user_id>[^/]*)(?:" "/(?P<user_id>[^/]*)(?:"
"/(?P<device_id>[^/]*)" "/(?P<device_id>[^/]*)"
")?" ")?"
")?" ")?",
releases=()
) )
def __init__(self, hs): def __init__(self, hs):
@ -245,10 +246,11 @@ class OneTimeKeyServlet(RestServlet):
} } } } } } } }
""" """
PATTERN = client_v2_pattern( PATTERNS = client_v2_patterns(
"/keys/claim(?:/?|(?:/" "/keys/claim(?:/?|(?:/"
"(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)" "(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
")?)" ")?)",
releases=()
) )
def __init__(self, hs): def __init__(self, hs):

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern from ._base import client_v2_patterns
import logging import logging
@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class ReceiptRestServlet(RestServlet): class ReceiptRestServlet(RestServlet):
PATTERN = client_v2_pattern( PATTERNS = client_v2_patterns(
"/rooms/(?P<room_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)" "/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$" "/(?P<event_id>[^/]*)$"

View File

@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_json_dict_from_request from ._base import client_v2_patterns, parse_json_dict_from_request
import logging import logging
import hmac import hmac
@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
class RegisterRestServlet(RestServlet): class RegisterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/register") PATTERNS = client_v2_patterns("/register")
def __init__(self, hs): def __init__(self, hs):
super(RegisterRestServlet, self).__init__() super(RegisterRestServlet, self).__init__()

View File

@ -25,7 +25,7 @@ from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_room_id, serialize_event, format_event_for_client_v2_without_room_id,
) )
from synapse.api.filtering import FilterCollection from synapse.api.filtering import FilterCollection
from ._base import client_v2_pattern from ._base import client_v2_patterns
import copy import copy
import logging import logging
@ -69,7 +69,7 @@ class SyncRestServlet(RestServlet):
} }
""" """
PATTERN = client_v2_pattern("/sync$") PATTERNS = client_v2_patterns("/sync$")
ALLOWED_PRESENCE = set(["online", "offline"]) ALLOWED_PRESENCE = set(["online", "offline"])
def __init__(self, hs): def __init__(self, hs):
@ -144,6 +144,9 @@ class SyncRestServlet(RestServlet):
) )
response_content = { response_content = {
"account_data": self.encode_account_data(
sync_result.account_data, filter, time_now
),
"presence": self.encode_presence( "presence": self.encode_presence(
sync_result.presence, filter, time_now sync_result.presence, filter, time_now
), ),
@ -165,6 +168,9 @@ class SyncRestServlet(RestServlet):
formatted.append(event) formatted.append(event)
return {"events": filter.filter_presence(formatted)} return {"events": filter.filter_presence(formatted)}
def encode_account_data(self, events, filter, time_now):
return {"events": filter.filter_account_data(events)}
def encode_joined(self, rooms, filter, time_now, token_id): def encode_joined(self, rooms, filter, time_now, token_id):
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import client_v2_pattern from ._base import client_v2_patterns
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
@ -31,7 +31,7 @@ class TagListServlet(RestServlet):
""" """
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
""" """
PATTERN = client_v2_pattern( PATTERNS = client_v2_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
) )
@ -56,7 +56,7 @@ class TagServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
""" """
PATTERN = client_v2_pattern( PATTERNS = client_v2_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
) )

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_json_dict_from_request from ._base import client_v2_patterns, parse_json_dict_from_request
class TokenRefreshRestServlet(RestServlet): class TokenRefreshRestServlet(RestServlet):
@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet):
Exchanges refresh tokens for a pair of an access token and a new refresh Exchanges refresh tokens for a pair of an access token and a new refresh
token. token.
""" """
PATTERN = client_v2_pattern("/tokenrefresh") PATTERNS = client_v2_patterns("/tokenrefresh")
def __init__(self, hs): def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__() super(TokenRefreshRestServlet, self).__init__()

View File

@ -42,6 +42,7 @@ from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore from .receipts import ReceiptsStore
from .search import SearchStore from .search import SearchStore
from .tags import TagsStore from .tags import TagsStore
from .account_data import AccountDataStore
import logging import logging
@ -73,6 +74,7 @@ class DataStore(RoomMemberStore, RoomStore,
EndToEndKeyStore, EndToEndKeyStore,
SearchStore, SearchStore,
TagsStore, TagsStore,
AccountDataStore,
): ):
def __init__(self, hs): def __init__(self, hs):

View File

@ -0,0 +1,211 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from twisted.internet import defer
import ujson as json
import logging
logger = logging.getLogger(__name__)
class AccountDataStore(SQLBaseStore):
def get_account_data_for_user(self, user_id):
"""Get all the client account_data for a user.
Args:
user_id(str): The user to get the account_data for.
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
"""
def get_account_data_for_user_txn(txn):
rows = self._simple_select_list_txn(
txn, "account_data", {"user_id": user_id},
["account_data_type", "content"]
)
global_account_data = {
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
rows = self._simple_select_list_txn(
txn, "room_account_data", {"user_id": user_id},
["room_id", "account_data_type", "content"]
)
by_room = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
room_data[row["account_data_type"]] = json.loads(row["content"])
return (global_account_data, by_room)
return self.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
def get_account_data_for_room(self, user_id, room_id):
"""Get all the client account_data for a user for a room.
Args:
user_id(str): The user to get the account_data for.
room_id(str): The room to get the account_data for.
Returns:
A deferred dict of the room account_data
"""
def get_account_data_for_room_txn(txn):
rows = self._simple_select_list_txn(
txn, "room_account_data", {"user_id": user_id, "room_id": room_id},
["account_data_type", "content"]
)
return {
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
return self.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):
"""Get all the client account_data for a that's changed.
Args:
user_id(str): The user to get the account_data for.
stream_id(int): The point in the stream since which to get updates
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
"""
def get_updated_account_data_for_user_txn(txn):
sql = (
"SELECT account_data_type, content FROM account_data"
" WHERE user_id = ? AND stream_id > ?"
)
txn.execute(sql, (user_id, stream_id))
global_account_data = {
row[0]: json.loads(row[1]) for row in txn.fetchall()
}
sql = (
"SELECT room_id, account_data_type, content FROM room_account_data"
" WHERE user_id = ? AND stream_id > ?"
)
txn.execute(sql, (user_id, stream_id))
account_data_by_room = {}
for row in txn.fetchall():
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2])
return (global_account_data, account_data_by_room)
return self.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@defer.inlineCallbacks
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
"""Add some account_data to a room for a user.
Args:
user_id(str): The user to add a tag for.
room_id(str): The room to add a tag for.
account_data_type(str): The type of account_data to add.
content(dict): A json object to associate with the tag.
Returns:
A deferred that completes once the account_data has been added.
"""
content_json = json.dumps(content)
def add_account_data_txn(txn, next_id):
self._simple_upsert_txn(
txn,
table="room_account_data",
keyvalues={
"user_id": user_id,
"room_id": room_id,
"account_data_type": account_data_type,
},
values={
"stream_id": next_id,
"content": content_json,
}
)
self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction(
"add_room_account_data", add_account_data_txn, next_id
)
result = yield self._account_data_id_gen.get_max_token(self)
defer.returnValue(result)
@defer.inlineCallbacks
def add_account_data_for_user(self, user_id, account_data_type, content):
"""Add some account_data to a room for a user.
Args:
user_id(str): The user to add a tag for.
account_data_type(str): The type of account_data to add.
content(dict): A json object to associate with the tag.
Returns:
A deferred that completes once the account_data has been added.
"""
content_json = json.dumps(content)
def add_account_data_txn(txn, next_id):
self._simple_upsert_txn(
txn,
table="account_data",
keyvalues={
"user_id": user_id,
"account_data_type": account_data_type,
},
values={
"stream_id": next_id,
"content": content_json,
}
)
self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction(
"add_user_account_data", add_account_data_txn, next_id
)
result = yield self._account_data_id_gen.get_max_token(self)
defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id):
"""Update the max stream_id
Args:
txn: The database cursor
next_id(int): The the revision to advance to.
"""
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
" SET stream_id = ?"
" WHERE stream_id < ?"
)
txn.execute(update_max_id_sql, (next_id, next_id))

View File

@ -51,6 +51,14 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
class EventsStore(SQLBaseStore): class EventsStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
def __init__(self, hs):
super(EventsStore, self).__init__(hs)
self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
@defer.inlineCallbacks @defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False, def persist_events(self, events_and_contexts, backfilled=False,
is_new_state=True): is_new_state=True):
@ -365,6 +373,7 @@ class EventsStore(SQLBaseStore):
"processed": True, "processed": True,
"outlier": event.internal_metadata.is_outlier(), "outlier": event.internal_metadata.is_outlier(),
"content": encode_json(event.content).decode("UTF-8"), "content": encode_json(event.content).decode("UTF-8"),
"origin_server_ts": int(event.origin_server_ts),
} }
for event, _ in events_and_contexts for event, _ in events_and_contexts
], ],
@ -964,3 +973,71 @@ class EventsStore(SQLBaseStore):
ret = yield self.runInteraction("count_messages", _count_messages) ret = yield self.runInteraction("count_messages", _count_messages)
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
)
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = txn.fetchall()
if not rows:
return 0
min_stream_id = rows[-1][0]
event_ids = [row[1] for row in rows]
events = self._get_events_txn(txn, event_ids)
rows = []
for event in events:
try:
event_id = event.event_id
origin_server_ts = event.origin_server_ts
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
rows.append((origin_server_ts, event_id))
sql = (
"UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
)
for index in range(0, len(rows), INSERT_CLUMP_SIZE):
clump = rows[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(rows)
}
self._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
)
return len(rows)
result = yield self.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
defer.returnValue(result)

View File

@ -15,3 +15,26 @@
ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id; ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id;
CREATE TABLE IF NOT EXISTS account_data(
user_id TEXT NOT NULL,
account_data_type TEXT NOT NULL, -- The type of the account_data.
stream_id BIGINT NOT NULL, -- The version of the account_data.
content TEXT NOT NULL, -- The JSON content of the account_data
CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type)
);
CREATE TABLE IF NOT EXISTS room_account_data(
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
account_data_type TEXT NOT NULL, -- The type of the account_data.
stream_id BIGINT NOT NULL, -- The version of the account_data.
content TEXT NOT NULL, -- The JSON content of the account_data
CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type)
);
CREATE INDEX account_data_stream_id on account_data(user_id, stream_id);
CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id);

View File

@ -0,0 +1,57 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.storage.prepare_database import get_statements
import ujson
logger = logging.getLogger(__name__)
ALTER_TABLE = (
"ALTER TABLE events ADD COLUMN origin_server_ts BIGINT;"
"CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering);"
)
def run_upgrade(cur, database_engine, *args, **kwargs):
for statement in get_statements(ALTER_TABLE.splitlines()):
cur.execute(statement)
cur.execute("SELECT MIN(stream_ordering) FROM events")
rows = cur.fetchall()
min_stream_id = rows[0][0]
cur.execute("SELECT MAX(stream_ordering) FROM events")
rows = cur.fetchall()
max_stream_id = rows[0][0]
if min_stream_id is not None and max_stream_id is not None:
progress = {
"target_min_stream_id_inclusive": min_stream_id,
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
progress_json = ujson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
" VALUES (?, ?)"
)
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_origin_server_ts", progress_json))

View File

@ -212,11 +212,11 @@ class SearchStore(BackgroundUpdateStore):
}) })
@defer.inlineCallbacks @defer.inlineCallbacks
def search_room(self, room_id, search_term, keys, limit, pagination_token=None): def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.
Args: Args:
room_id (str): The room_id to search in room_id (list): The room_ids to search in
search_term (str): Search term to search for search_term (str): Search term to search for
keys (list): List of keys to search in, currently supports keys (list): List of keys to search in, currently supports
"content.body", "content.name", "content.topic" "content.body", "content.name", "content.topic"
@ -226,7 +226,15 @@ class SearchStore(BackgroundUpdateStore):
list of dicts list of dicts
""" """
clauses = [] clauses = []
args = [search_term, room_id] args = [search_term]
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
clauses.append(
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
)
args.extend(room_ids)
local_clauses = [] local_clauses = []
for key in keys: for key in keys:
@ -239,25 +247,25 @@ class SearchStore(BackgroundUpdateStore):
if pagination_token: if pagination_token:
try: try:
topo, stream = pagination_token.split(",") origin_server_ts, stream = pagination_token.split(",")
topo = int(topo) origin_server_ts = int(origin_server_ts)
stream = int(stream) stream = int(stream)
except: except:
raise SynapseError(400, "Invalid pagination token") raise SynapseError(400, "Invalid pagination token")
clauses.append( clauses.append(
"(topological_ordering < ?" "(origin_server_ts < ?"
" OR (topological_ordering = ? AND stream_ordering < ?))" " OR (origin_server_ts = ? AND stream_ordering < ?))"
) )
args.extend([topo, topo, stream]) args.extend([origin_server_ts, origin_server_ts, stream])
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
sql = ( sql = (
"SELECT ts_rank_cd(vector, query) as rank," "SELECT ts_rank_cd(vector, query) as rank,"
" topological_ordering, stream_ordering, room_id, event_id" " origin_server_ts, stream_ordering, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search" " FROM plainto_tsquery('english', ?) as query, event_search"
" NATURAL JOIN events" " NATURAL JOIN events"
" WHERE vector @@ query AND room_id = ?" " WHERE vector @@ query AND "
) )
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
# We use CROSS JOIN here to ensure we use the right indexes. # We use CROSS JOIN here to ensure we use the right indexes.
@ -270,24 +278,23 @@ class SearchStore(BackgroundUpdateStore):
# MATCH unless it uses the full text search index # MATCH unless it uses the full text search index
sql = ( sql = (
"SELECT rank(matchinfo) as rank, room_id, event_id," "SELECT rank(matchinfo) as rank, room_id, event_id,"
" topological_ordering, stream_ordering" " origin_server_ts, stream_ordering"
" FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
" FROM event_search" " FROM event_search"
" WHERE value MATCH ?" " WHERE value MATCH ?"
" )" " )"
" CROSS JOIN events USING (event_id)" " CROSS JOIN events USING (event_id)"
" WHERE room_id = ?" " WHERE "
) )
else: else:
# This should be unreachable. # This should be unreachable.
raise Exception("Unrecognized database engine") raise Exception("Unrecognized database engine")
for clause in clauses: sql += " AND ".join(clauses)
sql += " AND " + clause
# We add an arbitrary limit here to ensure we don't try to pull the # We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database. # entire table from the database.
sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
args.append(limit) args.append(limit)
@ -295,6 +302,8 @@ class SearchStore(BackgroundUpdateStore):
"search_rooms", self.cursor_to_dict, sql, *args "search_rooms", self.cursor_to_dict, sql, *args
) )
results = filter(lambda row: row["room_id"] in room_ids, results)
events = yield self._get_events([r["event_id"] for r in results]) events = yield self._get_events([r["event_id"] for r in results])
event_map = { event_map = {
@ -312,7 +321,7 @@ class SearchStore(BackgroundUpdateStore):
"event": event_map[r["event_id"]], "event": event_map[r["event_id"]],
"rank": r["rank"], "rank": r["rank"],
"pagination_token": "%s,%s" % ( "pagination_token": "%s,%s" % (
r["topological_ordering"], r["stream_ordering"] r["origin_server_ts"], r["stream_ordering"]
), ),
} }
for r in results for r in results

View File

@ -48,8 +48,8 @@ class TagsStore(SQLBaseStore):
Args: Args:
user_id(str): The user to get the tags for. user_id(str): The user to get the tags for.
Returns: Returns:
A deferred dict mapping from room_id strings to lists of tag A deferred dict mapping from room_id strings to dicts mapping from
strings. tag strings to tag content.
""" """
deferred = self._simple_select_list( deferred = self._simple_select_list(

View File

@ -365,7 +365,7 @@ class PresenceInvitesTestCase(PresenceTestCase):
# TODO(paul): This test will likely break if/when real auth permissions # TODO(paul): This test will likely break if/when real auth permissions
# are added; for now the HS will always accept any invite # are added; for now the HS will always accept any invite
yield self.handler.send_invite( yield self.handler.send_presence_invite(
observer_user=self.u_apple, observed_user=self.u_banana) observer_user=self.u_apple, observed_user=self.u_banana)
self.assertEquals( self.assertEquals(
@ -384,7 +384,7 @@ class PresenceInvitesTestCase(PresenceTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invite_local_nonexistant(self): def test_invite_local_nonexistant(self):
yield self.handler.send_invite( yield self.handler.send_presence_invite(
observer_user=self.u_apple, observed_user=self.u_durian) observer_user=self.u_apple, observed_user=self.u_durian)
self.assertEquals( self.assertEquals(
@ -414,7 +414,7 @@ class PresenceInvitesTestCase(PresenceTestCase):
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
yield self.handler.send_invite( yield self.handler.send_presence_invite(
observer_user=self.u_apple, observed_user=u_rocket) observer_user=self.u_apple, observed_user=u_rocket)
self.assertEquals( self.assertEquals(

View File

@ -168,8 +168,9 @@ class MockHttpResource(HttpServer):
raise KeyError("No event can handle %s" % path) raise KeyError("No event can handle %s" % path)
def register_path(self, method, path_pattern, callback): def register_paths(self, method, path_patterns, callback):
self.callbacks.append((method, path_pattern, callback)) for path_pattern in path_patterns:
self.callbacks.append((method, path_pattern, callback))
class MockKey(object): class MockKey(object):