Merge branch 'rejections_storage' of github.com:matrix-org/synapse into replication_split

This commit is contained in:
Erik Johnston 2015-01-30 14:17:47 +00:00
commit 84b78c3b5f
52 changed files with 2372 additions and 315 deletions

View File

@ -33,7 +33,7 @@ setup(
install_requires=[ install_requires=[
"syutil==0.0.2", "syutil==0.0.2",
"matrix_angular_sdk==0.6.0", "matrix_angular_sdk==0.6.0",
"Twisted>=14.0.0", "Twisted==14.0.2",
"service_identity>=1.0.0", "service_identity>=1.0.0",
"pyopenssl>=0.14", "pyopenssl>=0.14",
"pyyaml", "pyyaml",

View File

@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.types import UserID from synapse.types import UserID, ClientInfo
import logging import logging
@ -290,7 +290,9 @@ class Auth(object):
Args: Args:
request - An HTTP request with an access_token query parameter. request - An HTTP request with an access_token query parameter.
Returns: Returns:
UserID : User ID object of the user making the request tuple : of UserID and device string:
User ID object of the user making the request
Client ID object of the client instance the user is using
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
@ -299,6 +301,8 @@ class Auth(object):
access_token = request.args["access_token"][0] access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_token(access_token) user_info = yield self.get_user_by_token(access_token)
user = user_info["user"] user = user_info["user"]
device_id = user_info["device_id"]
token_id = user_info["token_id"]
ip_addr = self.hs.get_ip_from_request(request) ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders( user_agent = request.requestHeaders.getRawHeaders(
@ -314,7 +318,7 @@ class Auth(object):
user_agent=user_agent user_agent=user_agent
) )
defer.returnValue(user) defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError: except KeyError:
raise AuthError(403, "Missing access token.") raise AuthError(403, "Missing access token.")
@ -339,6 +343,7 @@ class Auth(object):
"admin": bool(ret.get("admin", False)), "admin": bool(ret.get("admin", False)),
"device_id": ret.get("device_id"), "device_id": ret.get("device_id"),
"user": UserID.from_string(ret.get("name")), "user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
} }
defer.returnValue(user_info) defer.returnValue(user_info)

View File

@ -21,6 +21,7 @@ logger = logging.getLogger(__name__)
class Codes(object): class Codes(object):
UNRECOGNIZED = "M_UNRECOGNIZED"
UNAUTHORIZED = "M_UNAUTHORIZED" UNAUTHORIZED = "M_UNAUTHORIZED"
FORBIDDEN = "M_FORBIDDEN" FORBIDDEN = "M_FORBIDDEN"
BAD_JSON = "M_BAD_JSON" BAD_JSON = "M_BAD_JSON"
@ -34,6 +35,7 @@ class Codes(object):
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID" CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
MISSING_PARAM = "M_MISSING_PARAM",
TOO_LARGE = "M_TOO_LARGE" TOO_LARGE = "M_TOO_LARGE"
@ -81,6 +83,35 @@ class RegistrationError(SynapseError):
pass pass
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.UNRECOGNIZED
message = None
if len(args) == 0:
message = "Unrecognized request"
else:
message = args[0]
super(UnrecognizedRequestError, self).__init__(
400,
message,
**kwargs
)
class NotFoundError(SynapseError):
"""An error indicating we can't find the thing you asked for"""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.NOT_FOUND
super(NotFoundError, self).__init__(
404,
"Not found",
**kwargs
)
class AuthError(SynapseError): class AuthError(SynapseError):
"""An error raised when there was a problem authorising an event.""" """An error raised when there was a problem authorising an event."""

View File

@ -16,6 +16,7 @@
"""Contains the URL paths to prefix various aspects of the server with. """ """Contains the URL paths to prefix various aspects of the server with. """
CLIENT_PREFIX = "/_matrix/client/api/v1" CLIENT_PREFIX = "/_matrix/client/api/v1"
CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
FEDERATION_PREFIX = "/_matrix/federation/v1" FEDERATION_PREFIX = "/_matrix/federation/v1"
WEB_CLIENT_PREFIX = "/_matrix/client" WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content" CONTENT_REPO_PREFIX = "/_matrix/content"

View File

@ -32,12 +32,13 @@ from synapse.http.server_key_resource import LocalKey
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import ( from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX,
) )
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.rest.client.v1 import ClientV1RestResource from synapse.rest.client.v1 import ClientV1RestResource
from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
from daemonize import Daemonize from daemonize import Daemonize
import twisted.manhole.telnet import twisted.manhole.telnet
@ -62,6 +63,9 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_client(self): def build_resource_for_client(self):
return ClientV1RestResource(self) return ClientV1RestResource(self)
def build_resource_for_client_v2_alpha(self):
return ClientV2AlphaRestResource(self)
def build_resource_for_federation(self): def build_resource_for_federation(self):
return JsonResource() return JsonResource()
@ -105,6 +109,7 @@ class SynapseHomeServer(HomeServer):
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ] # [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
desired_tree = [ desired_tree = [
(CLIENT_PREFIX, self.get_resource_for_client()), (CLIENT_PREFIX, self.get_resource_for_client()),
(CLIENT_V2_ALPHA_PREFIX, self.get_resource_for_client_v2_alpha()),
(FEDERATION_PREFIX, self.get_resource_for_federation()), (FEDERATION_PREFIX, self.get_resource_for_federation()),
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()), (CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()), (SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
@ -267,6 +272,8 @@ def setup():
bind_port = None bind_port = None
hs.start_listening(bind_port, config.unsecure_port) hs.start_listening(bind_port, config.unsecure_port)
hs.get_pusherpool().start()
if config.daemonize: if config.daemonize:
print config.pid_file print config.pid_file
daemon = Daemonize( daemon = Daemonize(

View File

@ -89,31 +89,31 @@ def prune_event(event):
return type(event)(allowed_fields) return type(event)(allowed_fields)
def serialize_event(hs, e, client_event=True): def serialize_event(e, time_now_ms, client_event=True):
# FIXME(erikj): To handle the case of presence events and the like # FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase): if not isinstance(e, EventBase):
return e return e
time_now_ms = int(time_now_ms)
# Should this strip out None's? # Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()} d = {k: v for k, v in e.get_dict().items()}
if not client_event: if not client_event:
# set the age and keep all other keys # set the age and keep all other keys
if "age_ts" in d["unsigned"]: if "age_ts" in d["unsigned"]:
now = int(hs.get_clock().time_msec()) d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"]
d["unsigned"]["age"] = now - d["unsigned"]["age_ts"]
return d return d
if "age_ts" in d["unsigned"]: if "age_ts" in d["unsigned"]:
now = int(hs.get_clock().time_msec()) d["age"] = time_now_ms - d["unsigned"]["age_ts"]
d["age"] = now - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"] del d["unsigned"]["age_ts"]
d["user_id"] = d.pop("sender", None) d["user_id"] = d.pop("sender", None)
if "redacted_because" in e.unsigned: if "redacted_because" in e.unsigned:
d["redacted_because"] = serialize_event( d["redacted_because"] = serialize_event(
hs, e.unsigned["redacted_because"] e.unsigned["redacted_because"], time_now_ms
) )
del d["unsigned"]["redacted_because"] del d["unsigned"]["redacted_because"]

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import serialize_event
from ._base import BaseHandler from ._base import BaseHandler
@ -48,10 +49,11 @@ class EventStreamHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True): as_client_event=True, affect_presence=True):
auth_user = UserID.from_string(auth_user_id) auth_user = UserID.from_string(auth_user_id)
try: try:
if affect_presence:
if auth_user not in self._streams_per_user: if auth_user not in self._streams_per_user:
self._streams_per_user[auth_user] = 0 self._streams_per_user[auth_user] = 0
if auth_user in self._stop_timer_per_user: if auth_user in self._stop_timer_per_user:
@ -78,8 +80,10 @@ class EventStreamHandler(BaseHandler):
auth_user, room_ids, pagin_config, timeout auth_user, room_ids, pagin_config, timeout
) )
time_now = self.clock.time_msec()
chunks = [ chunks = [
self.hs.serialize_event(e, as_client_event) for e in events serialize_event(e, time_now, as_client_event) for e in events
] ]
chunk = { chunk = {
@ -91,6 +95,7 @@ class EventStreamHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
finally: finally:
if affect_presence:
self._streams_per_user[auth_user] -= 1 self._streams_per_user[auth_user] -= 1
if not self._streams_per_user[auth_user]: if not self._streams_per_user[auth_user]:
del self._streams_per_user[auth_user] del self._streams_per_user[auth_user]
@ -104,7 +109,7 @@ class EventStreamHandler(BaseHandler):
self._stop_timer_per_user.pop(auth_user, None) self._stop_timer_per_user.pop(auth_user, None)
yield self.distributor.fire( return self.distributor.fire(
"stopped_user_eventstream", auth_user "stopped_user_eventstream", auth_user
) )

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import RoomError from synapse.api.errors import RoomError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID from synapse.types import UserID
@ -100,9 +101,11 @@ class MessageHandler(BaseHandler):
"room_key", next_key "room_key", next_key
) )
time_now = self.clock.time_msec()
chunk = { chunk = {
"chunk": [ "chunk": [
self.hs.serialize_event(e, as_client_event) for e in events serialize_event(e, time_now, as_client_event) for e in events
], ],
"start": pagin_config.from_token.to_string(), "start": pagin_config.from_token.to_string(),
"end": next_token.to_string(), "end": next_token.to_string(),
@ -111,7 +114,8 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True): def create_and_send_event(self, event_dict, ratelimit=True,
client=None, txn_id=None):
""" Given a dict from a client, create and handle a new event. """ Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events, Creates an FrozenEvent object, filling out auth_events, prev_events,
@ -145,6 +149,15 @@ class MessageHandler(BaseHandler):
builder.content builder.content
) )
if client is not None:
if client.token_id is not None:
builder.internal_metadata.token_id = client.token_id
if client.device_id is not None:
builder.internal_metadata.device_id = client.device_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self._create_new_client_event( event, context = yield self._create_new_client_event(
builder=builder, builder=builder,
) )
@ -211,7 +224,8 @@ class MessageHandler(BaseHandler):
# TODO: This is duplicating logic from snapshot_all_rooms # TODO: This is duplicating logic from snapshot_all_rooms
current_state = yield self.state_handler.get_current_state(room_id) current_state = yield self.state_handler.get_current_state(room_id)
defer.returnValue([self.hs.serialize_event(c) for c in current_state]) now = self.clock.time_msec()
defer.returnValue([serialize_event(c, now) for c in current_state])
@defer.inlineCallbacks @defer.inlineCallbacks
def snapshot_all_rooms(self, user_id=None, pagin_config=None, def snapshot_all_rooms(self, user_id=None, pagin_config=None,
@ -283,10 +297,11 @@ class MessageHandler(BaseHandler):
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1]) end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
d["messages"] = { d["messages"] = {
"chunk": [ "chunk": [
self.hs.serialize_event(m, as_client_event) serialize_event(m, time_now, as_client_event)
for m in messages for m in messages
], ],
"start": start_token.to_string(), "start": start_token.to_string(),
@ -297,7 +312,8 @@ class MessageHandler(BaseHandler):
event.room_id event.room_id
) )
d["state"] = [ d["state"] = [
self.hs.serialize_event(c) for c in current_state serialize_event(c, time_now, as_client_event)
for c in current_state
] ]
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
@ -320,8 +336,9 @@ class MessageHandler(BaseHandler):
auth_user = UserID.from_string(user_id) auth_user = UserID.from_string(user_id)
# TODO: These concurrently # TODO: These concurrently
time_now = self.clock.time_msec()
state_tuples = yield self.state_handler.get_current_state(room_id) state_tuples = yield self.state_handler.get_current_state(room_id)
state = [self.hs.serialize_event(x) for x in state_tuples] state = [serialize_event(x, time_now) for x in state_tuples]
member_event = (yield self.store.get_room_member( member_event = (yield self.store.get_room_member(
user_id=user_id, user_id=user_id,
@ -360,11 +377,13 @@ class MessageHandler(BaseHandler):
"Failed to get member presence of %r", m.user_id "Failed to get member presence of %r", m.user_id
) )
time_now = self.clock.time_msec()
defer.returnValue({ defer.returnValue({
"membership": member_event.membership, "membership": member_event.membership,
"room_id": room_id, "room_id": room_id,
"messages": { "messages": {
"chunk": [self.hs.serialize_event(m) for m in messages], "chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(), "start": start_token.to_string(),
"end": end_token.to_string(), "end": end_token.to_string(),
}, },

View File

@ -87,6 +87,10 @@ class PresenceHandler(BaseHandler):
"changed_presencelike_data", self.changed_presencelike_data "changed_presencelike_data", self.changed_presencelike_data
) )
# outbound signal from the presence module to advertise when a user's
# presence has changed
distributor.declare("user_presence_changed")
self.distributor = distributor self.distributor = distributor
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
@ -604,6 +608,7 @@ class PresenceHandler(BaseHandler):
room_ids=room_ids, room_ids=room_ids,
statuscache=statuscache, statuscache=statuscache,
) )
yield self.distributor.fire("user_presence_changed", user, statuscache)
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None): def _push_presence_remote(self, user, destination, state=None):

View File

@ -163,7 +163,7 @@ class RegistrationHandler(BaseHandler):
# each request # each request
httpCli = SimpleHttpClient(self.hs) httpCli = SimpleHttpClient(self.hs)
# XXX: make this configurable! # XXX: make this configurable!
trustedIdServers = ['matrix.org:8090'] trustedIdServers = ['matrix.org:8090', 'matrix.org']
if not creds['idServer'] in trustedIdServers: if not creds['idServer'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' + logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['idServer']) 'credentials', creds['idServer'])

View File

@ -16,12 +16,14 @@
"""Contains functions for performing events on rooms.""" """Contains functions for performing events on rooms."""
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from ._base import BaseHandler from synapse.events.utils import serialize_event
import logging import logging
@ -293,8 +295,9 @@ class RoomMemberHandler(BaseHandler):
yield self.auth.check_joined_room(room_id, user_id) yield self.auth.check_joined_room(room_id, user_id)
member_list = yield self.store.get_room_members(room_id=room_id) member_list = yield self.store.get_room_members(room_id=room_id)
time_now = self.clock.time_msec()
event_list = [ event_list = [
self.hs.serialize_event(entry) serialize_event(entry, time_now)
for entry in member_list for entry in member_list
] ]
chunk_data = { chunk_data = {

View File

@ -62,6 +62,25 @@ class SimpleHttpClient(object):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def post_json_get_json(self, uri, post_json):
json_str = json.dumps(post_json)
logger.info("HTTP POST %s -> %s", json_str, uri)
response = yield self.agent.request(
"POST",
uri.encode("ascii"),
headers=Headers({
"Content-Type": ["application/json"]
}),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, uri, args={}): def get_json(self, uri, args={}):
""" Get's some json from the given host and path """ Get's some json from the given host and path

View File

@ -16,7 +16,7 @@
from synapse.http.agent_name import AGENT_NAME from synapse.http.agent_name import AGENT_NAME
from synapse.api.errors import ( from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
) )
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
@ -139,11 +139,7 @@ class JsonResource(HttpServer, resource.Resource):
return return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400. # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
self._send_response( raise UnrecognizedRequestError()
request,
400,
{"error": "Unrecognized request"}
)
except CodeMessageException as e: except CodeMessageException as e:
if isinstance(e, SynapseError): if isinstance(e, SynapseError):
logger.info("%s SynapseError: %s - %s", request, e.code, e.msg) logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)

View File

@ -15,6 +15,8 @@
""" This module contains base REST classes for constructing REST servlets. """ """ This module contains base REST classes for constructing REST servlets. """
from synapse.api.errors import SynapseError
import logging import logging
@ -54,3 +56,58 @@ class RestServlet(object):
http_server.register_path(method, pattern, method_handler) http_server.register_path(method, pattern, method_handler)
else: else:
raise NotImplementedError("RestServlet must register something.") raise NotImplementedError("RestServlet must register something.")
@staticmethod
def parse_integer(request, name, default=None, required=False):
if name in request.args:
try:
return int(request.args[name][0])
except:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing integer query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
@staticmethod
def parse_boolean(request, name, default=None, required=False):
if name in request.args:
try:
return {
"true": True,
"false": False,
}[request.args[name][0]]
except:
message = (
"Boolean query parameter %r must be one of"
" ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing boolean query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
@staticmethod
def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"):
if name in request.args:
value = request.args[name][0]
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values)
)
raise SynapseError(message)
else:
return value
else:
if required:
message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message)
else:
return default

364
synapse/push/__init__.py Normal file
View File

@ -0,0 +1,364 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
import synapse.util.async
import baserules
import logging
import fnmatch
import json
logger = logging.getLogger(__name__)
class Pusher(object):
INITIAL_BACKOFF = 1000
MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000
DEFAULT_ACTIONS = ['notify']
def __init__(self, _hs, instance_handle, user_name, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
self.hs = _hs
self.evStreamHandler = self.hs.get_handlers().event_stream_handler
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.instance_handle = instance_handle
self.user_name = user_name
self.app_id = app_id
self.app_display_name = app_display_name
self.device_display_name = device_display_name
self.pushkey = pushkey
self.pushkey_ts = pushkey_ts
self.data = data
self.last_token = last_token
self.last_success = last_success # not actually used
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.failing_since = failing_since
self.alive = True
# The last value of last_active_time that we saw
self.last_last_active_time = 0
self.has_unread = True
@defer.inlineCallbacks
def _actions_for_event(self, ev):
"""
This should take into account notification settings that the user
has configured both globally and per-room when we have the ability
to do such things.
"""
if ev['user_id'] == self.user_name:
# let's assume you probably know about messages you sent yourself
defer.returnValue(['dont_notify'])
if ev['type'] == 'm.room.member':
if ev['state_key'] != self.user_name:
defer.returnValue(['dont_notify'])
rules = yield self.store.get_push_rules_for_user_name(self.user_name)
for r in rules:
r['conditions'] = json.loads(r['conditions'])
r['actions'] = json.loads(r['actions'])
user_name_localpart = UserID.from_string(self.user_name).localpart
rules.extend(baserules.make_base_rules(user_name_localpart))
# get *our* member event for display name matching
member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=self.user_name
)
my_display_name = None
if len(member_events_for_room) > 0:
my_display_name = member_events_for_room[0].content['displayname']
for r in rules:
matches = True
conditions = r['conditions']
actions = r['actions']
for c in conditions:
matches &= self._event_fulfills_condition(
ev, c, display_name=my_display_name
)
# ignore rules with no actions (we have an explict 'dont_notify'
if len(actions) == 0:
logger.warn(
"Ignoring rule id %s with no actions for user %s" %
(r['rule_id'], r['user_name'])
)
continue
if matches:
defer.returnValue(actions)
defer.returnValue(Pusher.DEFAULT_ACTIONS)
def _event_fulfills_condition(self, ev, condition, display_name):
if condition['kind'] == 'event_match':
if 'pattern' not in condition:
logger.warn("event_match condition with no pattern")
return False
pat = condition['pattern']
val = _value_for_dotted_key(condition['key'], ev)
if val is None:
return False
return fnmatch.fnmatch(val.upper(), pat.upper())
elif condition['kind'] == 'device':
if 'instance_handle' not in condition:
return True
return condition['instance_handle'] == self.instance_handle
elif condition['kind'] == 'contains_display_name':
# This is special because display names can be different
# between rooms and so you can't really hard code it in a rule.
# Optimisation: we should cache these names and update them from
# the event stream.
if 'content' not in ev or 'body' not in ev['content']:
return False
return fnmatch.fnmatch(
ev['content']['body'].upper(), "*%s*" % (display_name.upper(),)
)
else:
return True
@defer.inlineCallbacks
def get_context_for_event(self, ev):
name_aliases = yield self.store.get_room_name_and_aliases(
ev['room_id']
)
ctx = {'aliases': name_aliases[1]}
if name_aliases[0] is not None:
ctx['name'] = name_aliases[0]
their_member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=ev['user_id']
)
if len(their_member_events_for_room) > 0:
dn = their_member_events_for_room[0].content['displayname']
if dn is not None:
ctx['sender_display_name'] = dn
defer.returnValue(ctx)
@defer.inlineCallbacks
def start(self):
if not self.last_token:
# First-time setup: get a token to start from (we can't
# just start from no token, ie. 'now'
# because we need the result to be reproduceable in case
# we fail to dispatch the push)
config = PaginationConfig(from_token=None, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, timeout=0)
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.user_name, self.pushkey, self.last_token)
logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token)
while self.alive:
from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config,
timeout=100*365*24*60*60*1000, affect_presence=False
)
# limiting to 1 may get 1 event plus 1 presence event, so
# pick out the actual event
single_event = None
for c in chunk['chunk']:
if 'event_id' in c: # Hmmm...
single_event = c
break
if not single_event:
self.last_token = chunk['end']
continue
if not self.alive:
continue
processed = False
actions = yield self._actions_for_event(single_event)
tweaks = _tweaks_for_actions(actions)
if len(actions) == 0:
logger.warn("Empty actions! Using default action.")
actions = Pusher.DEFAULT_ACTIONS
if 'notify' not in actions and 'dont_notify' not in actions:
logger.warn("Neither notify nor dont_notify in actions: adding default")
actions.extend(Pusher.DEFAULT_ACTIONS)
if 'dont_notify' in actions:
logger.debug(
"%s for %s: dont_notify",
single_event['event_id'], self.user_name
)
processed = True
else:
rejected = yield self.dispatch_push(single_event, tweaks)
self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
("Ignoring rejected pushkey %s because we"
" didn't send it"), pk
)
else:
logger.info(
"Pushkey %s was rejected: removing",
pk
)
yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk
)
if not self.alive:
continue
if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token_and_success(
self.user_name,
self.pushkey,
self.last_token,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
self.store.update_pusher_failing_since(
self.user_name,
self.pushkey,
self.failing_since)
else:
if not self.failing_since:
self.failing_since = self.clock.time_msec()
self.store.update_pusher_failing_since(
self.user_name,
self.pushkey,
self.failing_since
)
if (self.failing_since and
self.failing_since <
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_name, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.user_name,
self.pushkey,
self.last_token
)
self.failing_since = None
self.store.update_pusher_failing_since(
self.user_name,
self.pushkey,
self.failing_since
)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
self.user_name,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
self.backoff_delay *= 2
if self.backoff_delay > Pusher.MAX_BACKOFF:
self.backoff_delay = Pusher.MAX_BACKOFF
def stop(self):
self.alive = False
def dispatch_push(self, p, tweaks):
"""
Overridden by implementing classes to actually deliver the notification
Args:
p: The event to notify for as a single event from the event stream
Returns: If the notification was delivered, an array containing any
pushkeys that were rejected by the push gateway.
False if the notification could not be delivered (ie.
should be retried).
"""
pass
def reset_badge_count(self):
pass
def presence_changed(self, state):
"""
We clear badge counts whenever a user's last_active time is bumped
This is by no means perfect but I think it's the best we can do
without read receipts.
"""
if 'last_active' in state.state:
last_active = state.state['last_active']
if last_active > self.last_last_active_time:
self.last_last_active_time = last_active
if self.has_unread:
logger.info("Resetting badge count for %s", self.user_name)
self.reset_badge_count()
self.has_unread = False
def _value_for_dotted_key(dotted_key, event):
parts = dotted_key.split(".")
val = event
while len(parts) > 0:
if parts[0] not in val:
return None
val = val[parts[0]]
parts = parts[1:]
return val
def _tweaks_for_actions(actions):
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if 'set_sound' in a:
tweaks['sound'] = a['set_sound']
return tweaks
class PusherConfigException(Exception):
def __init__(self, msg):
super(PusherConfigException, self).__init__(msg)

35
synapse/push/baserules.py Normal file
View File

@ -0,0 +1,35 @@
def make_base_rules(user_name):
"""
Nominally we reserve priority class 0 for these rules, although
in practice we just append them to the end so we don't actually need it.
"""
return [
{
'conditions': [
{
'kind': 'event_match',
'key': 'content.body',
'pattern': '*%s*' % (user_name,), # Matrix ID match
}
],
'actions': [
'notify',
{
'set_sound': 'default'
}
]
},
{
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_sound': 'default'
}
]
},
]

146
synapse/push/httppusher.py Normal file
View File

@ -0,0 +1,146 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push import Pusher, PusherConfigException
from synapse.http.client import SimpleHttpClient
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class HttpPusher(Pusher):
def __init__(self, _hs, instance_handle, user_name, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__(
_hs,
instance_handle,
user_name,
app_id,
app_display_name,
device_display_name,
pushkey,
pushkey_ts,
data,
last_token,
last_success,
failing_since
)
if 'url' not in data:
raise PusherConfigException(
"'url' required in data for HTTP pusher"
)
self.url = data['url']
self.httpCli = SimpleHttpClient(self.hs)
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url['url']
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks):
# we probably do not want to push for every presence update
# (we may want to be able to set up notifications when specific
# people sign in, but we'd want to only deliver the pertinent ones)
# Actually, presence events will not get this far now because we
# need to filter them out in the main Pusher code.
if 'event_id' not in event:
defer.returnValue(None)
ctx = yield self.get_context_for_event(event)
d = {
'notification': {
'id': event['event_id'],
'type': event['type'],
'sender': event['user_id'],
'counts': { # -- we don't mark messages as read yet so
# we have no way of knowing
# Just set the badge to 1 until we have read receipts
'unread': 1,
# 'missed_calls': 2
},
'devices': [
{
'app_id': self.app_id,
'pushkey': self.pushkey,
'pushkey_ts': long(self.pushkey_ts / 1000),
'data': self.data_minus_url,
'tweaks': tweaks
}
]
}
}
if event['type'] == 'm.room.member':
d['notification']['membership'] = event['content']['membership']
if 'content' in event:
d['notification']['content'] = event['content']
if len(ctx['aliases']):
d['notification']['room_alias'] = ctx['aliases'][0]
if 'sender_display_name' in ctx:
d['notification']['sender_display_name'] = ctx['sender_display_name']
if 'name' in ctx:
d['notification']['room_name'] = ctx['name']
defer.returnValue(d)
@defer.inlineCallbacks
def dispatch_push(self, event, tweaks):
notification_dict = yield self._build_notification_dict(event, tweaks)
if not notification_dict:
defer.returnValue([])
try:
resp = yield self.httpCli.post_json_get_json(self.url, notification_dict)
except:
logger.exception("Failed to push %s ", self.url)
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
rejected = resp['rejected']
defer.returnValue(rejected)
@defer.inlineCallbacks
def reset_badge_count(self):
d = {
'notification': {
'id': '',
'type': None,
'sender': '',
'counts': {
'unread': 0,
'missed_calls': 0
},
'devices': [
{
'app_id': self.app_id,
'pushkey': self.pushkey,
'pushkey_ts': long(self.pushkey_ts / 1000),
'data': self.data_minus_url,
}
]
}
}
try:
resp = yield self.httpCli.post_json_get_json(self.url, d)
except:
logger.exception("Failed to push %s ", self.url)
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
rejected = resp['rejected']
defer.returnValue(rejected)

152
synapse/push/pusherpool.py Normal file
View File

@ -0,0 +1,152 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from httppusher import HttpPusher
from synapse.push import PusherConfigException
import logging
import json
logger = logging.getLogger(__name__)
class PusherPool:
def __init__(self, _hs):
self.hs = _hs
self.store = self.hs.get_datastore()
self.pushers = {}
self.last_pusher_started = -1
distributor = self.hs.get_distributor()
distributor.observe(
"user_presence_changed", self.user_presence_changed
)
@defer.inlineCallbacks
def user_presence_changed(self, user, state):
user_name = user.to_string()
# until we have read receipts, pushers use this to reset a user's
# badge counters to zero
for p in self.pushers.values():
if p.user_name == user_name:
yield p.presence_changed(state)
@defer.inlineCallbacks
def start(self):
pushers = yield self.store.get_all_pushers()
for p in pushers:
p['data'] = json.loads(p['data'])
self._start_pushers(pushers)
@defer.inlineCallbacks
def add_pusher(self, user_name, instance_handle, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data):
# we try to create the pusher just to validate the config: it
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
self._create_pusher({
"user_name": user_name,
"kind": kind,
"instance_handle": instance_handle,
"app_id": app_id,
"app_display_name": app_display_name,
"device_display_name": device_display_name,
"pushkey": pushkey,
"pushkey_ts": self.hs.get_clock().time_msec(),
"lang": lang,
"data": data,
"last_token": None,
"last_success": None,
"failing_since": None
})
yield self._add_pusher_to_store(
user_name, instance_handle, kind, app_id,
app_display_name, device_display_name,
pushkey, lang, data
)
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, instance_handle, kind, app_id,
app_display_name, device_display_name,
pushkey, lang, data):
yield self.store.add_pusher(
user_name=user_name,
instance_handle=instance_handle,
kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
pushkey_ts=self.hs.get_clock().time_msec(),
lang=lang,
data=json.dumps(data)
)
self._refresh_pusher((app_id, pushkey))
def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http':
return HttpPusher(
self.hs,
instance_handle=pusherdict['instance_handle'],
user_name=pusherdict['user_name'],
app_id=pusherdict['app_id'],
app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'],
pushkey=pusherdict['pushkey'],
pushkey_ts=pusherdict['pushkey_ts'],
data=pusherdict['data'],
last_token=pusherdict['last_token'],
last_success=pusherdict['last_success'],
failing_since=pusherdict['failing_since']
)
else:
raise PusherConfigException(
"Unknown pusher type '%s' for user %s" %
(pusherdict['kind'], pusherdict['user_name'])
)
@defer.inlineCallbacks
def _refresh_pusher(self, app_id_pushkey):
p = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id_pushkey
)
p['data'] = json.loads(p['data'])
self._start_pushers([p])
def _start_pushers(self, pushers):
logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers:
p = self._create_pusher(pusherdict)
if p:
fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey'])
if fullid in self.pushers:
self.pushers[fullid].stop()
self.pushers[fullid] = p
p.start()
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey):
fullid = "%s:%s" % (app_id, pushkey)
if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid)
self.pushers[fullid].stop()
del self.pushers[fullid]
yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey)

View File

@ -5,8 +5,8 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"syutil==0.0.2": ["syutil"], "syutil==0.0.2": ["syutil"],
"matrix_angular_sdk==0.6.0": ["syweb==0.6.0"], "matrix_angular_sdk==0.6.0": ["syweb>=0.6.0"],
"Twisted>=14.0.0": ["twisted>=14.0.0"], "Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"], "pyyaml": ["yaml"],

View File

@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import ( from . import (
room, events, register, login, profile, presence, initial_sync, directory, room, events, register, login, profile, presence, initial_sync, directory,
voip, admin, voip, admin, pusher, push_rule
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -41,3 +40,5 @@ class ClientV1RestResource(JsonResource):
directory.register_servlets(hs, client_resource) directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource) voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource) admin.register_servlets(hs, client_resource)
pusher.register_servlets(hs, client_resource)
push_rule.register_servlets(hs, client_resource)

View File

@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user) is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user: if not is_admin and target_user != auth_user:

View File

@ -45,7 +45,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_alias): def on_PUT(self, request, room_alias):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
if not "room_id" in content: if not "room_id" in content:
@ -85,7 +85,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, room_alias): def on_DELETE(self, request, room_alias):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user) is_admin = yield self.auth.is_server_admin(user)
if not is_admin: if not is_admin:

View File

@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_pattern
from synapse.events.utils import serialize_event
import logging import logging
@ -33,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
try: try:
handler = self.handlers.event_stream_handler handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request) pagin_config = PaginationConfig.from_request(request)
@ -64,14 +65,19 @@ class EventStreamRestServlet(ClientV1RestServlet):
class EventRestServlet(ClientV1RestServlet): class EventRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$") PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$")
def __init__(self, hs):
super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, event_id): def on_GET(self, request, event_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id) event = yield handler.get_event(auth_user, event_id)
time_now = self.clock.time_msec()
if event: if event:
defer.returnValue((200, self.hs.serialize_event(event))) defer.returnValue((200, serialize_event(event, time_now)))
else: else:
defer.returnValue((404, "Event not found.")) defer.returnValue((404, "Event not found."))

View File

@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
with_feedback = "feedback" in request.args with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)

View File

@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state( state = yield self.handlers.presence_handler.get_state(
@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = {} state = {}
@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):

View File

@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:
@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:

View File

@ -0,0 +1,401 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError, NotFoundError, \
StoreError
from .base import ClientV1RestServlet, client_path_pattern
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
import json
class PushRuleRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/pushrules/.*$")
PRIORITY_CLASS_MAP = {
'underride': 1,
'sender': 2,
'room': 3,
'content': 4,
'override': 5,
}
PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()}
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash")
def rule_spec_from_path(self, path):
if len(path) < 2:
raise UnrecognizedRequestError()
if path[0] != 'pushrules':
raise UnrecognizedRequestError()
scope = path[1]
path = path[2:]
if scope not in ['global', 'device']:
raise UnrecognizedRequestError()
device = None
if scope == 'device':
if len(path) == 0:
raise UnrecognizedRequestError()
device = path[0]
path = path[1:]
if len(path) == 0:
raise UnrecognizedRequestError()
template = path[0]
path = path[1:]
if len(path) == 0:
raise UnrecognizedRequestError()
rule_id = path[0]
spec = {
'scope': scope,
'template': template,
'rule_id': rule_id
}
if device:
spec['device'] = device
return spec
def rule_tuple_from_request_object(self, rule_template, rule_id, req_obj, device=None):
if rule_template in ['override', 'underride']:
if 'conditions' not in req_obj:
raise InvalidRuleException("Missing 'conditions'")
conditions = req_obj['conditions']
for c in conditions:
if 'kind' not in c:
raise InvalidRuleException("Condition without 'kind'")
elif rule_template == 'room':
conditions = [{
'kind': 'event_match',
'key': 'room_id',
'pattern': rule_id
}]
elif rule_template == 'sender':
conditions = [{
'kind': 'event_match',
'key': 'user_id',
'pattern': rule_id
}]
elif rule_template == 'content':
if 'pattern' not in req_obj:
raise InvalidRuleException("Content rule missing 'pattern'")
pat = req_obj['pattern']
if pat.strip("*?[]") == pat:
# no special glob characters so we assume the user means
# 'contains this string' rather than 'is this string'
pat = "*%s*" % (pat,)
conditions = [{
'kind': 'event_match',
'key': 'content.body',
'pattern': pat
}]
else:
raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
if device:
conditions.append({
'kind': 'device',
'instance_handle': device
})
if 'actions' not in req_obj:
raise InvalidRuleException("No actions found")
actions = req_obj['actions']
for a in actions:
if a in ['notify', 'dont_notify', 'coalesce']:
pass
elif isinstance(a, dict) and 'set_sound' in a:
pass
else:
raise InvalidRuleException("Unrecognised action")
return conditions, actions
@defer.inlineCallbacks
def on_PUT(self, request):
spec = self.rule_spec_from_path(request.postpath)
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
user, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
try:
(conditions, actions) = self.rule_tuple_from_request_object(
spec['template'],
spec['rule_id'],
content,
device=spec['device'] if 'device' in spec else None
)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
before = request.args.get("before", None)
if before and len(before):
before = before[0]
after = request.args.get("after", None)
if after and len(after):
after = after[0]
try:
yield self.hs.get_datastore().add_push_rule(
user_name=user.to_string(),
rule_id=spec['rule_id'],
priority_class=priority_class,
conditions=conditions,
actions=actions,
before=before,
after=after
)
except InconsistentRuleException as e:
raise SynapseError(400, e.message)
except RuleNotFoundException as e:
raise SynapseError(400, e.message)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_DELETE(self, request):
spec = self.rule_spec_from_path(request.postpath)
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
user, _ = yield self.auth.get_user_by_req(request)
if 'device' in spec:
rules = yield self.hs.get_datastore().get_push_rules_for_user_name(
user.to_string()
)
for r in rules:
conditions = json.loads(r['conditions'])
ih = _instance_handle_from_conditions(conditions)
if ih == spec['device'] and r['priority_class'] == priority_class:
yield self.hs.get_datastore().delete_push_rule(
user.to_string(), spec['rule_id']
)
defer.returnValue((200, {}))
raise NotFoundError()
else:
try:
yield self.hs.get_datastore().delete_push_rule(
user.to_string(), spec['rule_id'],
priority_class=priority_class
)
defer.returnValue((200, {}))
except StoreError as e:
if e.code == 404:
raise NotFoundError()
else:
raise
@defer.inlineCallbacks
def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request)
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name(user.to_string())
rules = {'global': {}, 'device': {}}
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
for r in rawrules:
rulearray = None
r["conditions"] = json.loads(r["conditions"])
r["actions"] = json.loads(r["actions"])
template_name = _priority_class_to_template_name(r['priority_class'])
if r['priority_class'] > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']:
# per-device rule
instance_handle = _instance_handle_from_conditions(r["conditions"])
r = _strip_device_condition(r)
if not instance_handle:
continue
if instance_handle not in rules['device']:
rules['device'][instance_handle] = {}
rules['device'][instance_handle] = (
_add_empty_priority_class_arrays(
rules['device'][instance_handle]
)
)
rulearray = rules['device'][instance_handle][template_name]
else:
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r)
if template_rule:
rulearray.append(template_rule)
path = request.postpath[1:]
if path == []:
# we're a reference impl: pedantry is our job.
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
defer.returnValue((200, rules))
elif path[0] == 'global':
path = path[1:]
result = _filter_ruleset_with_path(rules['global'], path)
defer.returnValue((200, result))
elif path[0] == 'device':
path = path[1:]
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
defer.returnValue((200, rules['device']))
instance_handle = path[0]
path = path[1:]
if instance_handle not in rules['device']:
ret = {}
ret = _add_empty_priority_class_arrays(ret)
defer.returnValue((200, ret))
ruleset = rules['device'][instance_handle]
result = _filter_ruleset_with_path(ruleset, path)
defer.returnValue((200, result))
else:
raise UnrecognizedRequestError()
def on_OPTIONS(self, _):
return 200, {}
def _add_empty_priority_class_arrays(d):
for pc in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _instance_handle_from_conditions(conditions):
"""
Given a list of conditions, return the instance handle of the
device rule if there is one
"""
for c in conditions:
if c['kind'] == 'device':
return c['instance_handle']
return None
def _filter_ruleset_with_path(ruleset, path):
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
return ruleset
template_kind = path[0]
if template_kind not in ruleset:
raise UnrecognizedRequestError()
path = path[1:]
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
return ruleset[template_kind]
rule_id = path[0]
for r in ruleset[template_kind]:
if r['rule_id'] == rule_id:
return r
raise NotFoundError
def _priority_class_from_spec(spec):
if spec['template'] not in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys():
raise InvalidRuleException("Unknown template: %s" % (spec['kind']))
pc = PushRuleRestServlet.PRIORITY_CLASS_MAP[spec['template']]
if spec['scope'] == 'device':
pc += len(PushRuleRestServlet.PRIORITY_CLASS_MAP)
return pc
def _priority_class_to_template_name(pc):
if pc > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']:
# per-device
prio_class_index = pc - len(PushRuleRestServlet.PRIORITY_CLASS_MAP)
return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
else:
return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[pc]
def _rule_to_template(rule):
template_name = _priority_class_to_template_name(rule['priority_class'])
if template_name in ['override', 'underride']:
return {k: rule[k] for k in ["rule_id", "conditions", "actions"]}
elif template_name in ["sender", "room"]:
return {k: rule[k] for k in ["rule_id", "actions"]}
elif template_name == 'content':
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
ret = {k: rule[k] for k in ["rule_id", "actions"]}
ret["pattern"] = thecond["pattern"]
return ret
def _strip_device_condition(rule):
for i, c in enumerate(rule['conditions']):
if c['kind'] == 'device':
del rule['conditions'][i]
return rule
class InvalidRuleException(Exception):
pass
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server):
PushRuleRestServlet(hs).register(http_server)

View File

@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException
from .base import ClientV1RestServlet, client_path_pattern
import json
class PusherRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/pushers/set$")
@defer.inlineCallbacks
def on_POST(self, request):
user, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
pusher_pool = self.hs.get_pusherpool()
if ('pushkey' in content and 'app_id' in content
and 'kind' in content and
content['kind'] is None):
yield pusher_pool.remove_pusher(
content['app_id'], content['pushkey']
)
defer.returnValue((200, {}))
reqd = ['instance_handle', 'kind', 'app_id', 'app_display_name',
'device_display_name', 'pushkey', 'lang', 'data']
missing = []
for i in reqd:
if i not in content:
missing.append(i)
if len(missing):
raise SynapseError(400, "Missing parameters: "+','.join(missing),
errcode=Codes.MISSING_PARAM)
try:
yield pusher_pool.add_pusher(
user_name=user.to_string(),
instance_handle=content['instance_handle'],
kind=content['kind'],
app_id=content['app_id'],
app_display_name=content['app_display_name'],
device_display_name=content['device_display_name'],
pushkey=content['pushkey'],
lang=content['lang'],
data=content['data']
)
except PusherConfigException as pce:
raise SynapseError(400, "Config Error: "+pce.message,
errcode=Codes.MISSING_PARAM)
defer.returnValue((200, {}))
def on_OPTIONS(self, _):
return 200, {}
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server):
PusherRestServlet(hs).register(http_server)

View File

@ -21,6 +21,7 @@ from synapse.api.errors import SynapseError, Codes
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.types import UserID, RoomID, RoomAlias from synapse.types import UserID, RoomID, RoomAlias
from synapse.events.utils import serialize_event
import json import json
import logging import logging
@ -61,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request) room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None) info = yield self.make_room(room_config, auth_user, None)
@ -124,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key): def on_GET(self, request, room_id, event_type, state_key):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
@ -141,8 +142,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
defer.returnValue((200, data.get_dict()["content"])) defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -157,7 +158,9 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
event_dict["state_key"] = state_key event_dict["state_key"] = state_key
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(event_dict) yield msg_handler.create_and_send_event(
event_dict, client=client, txn_id=txn_id,
)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -171,8 +174,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server, with_get=True) register_txn_path(self, PATTERN, http_server, with_get=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_type): def on_POST(self, request, room_id, event_type, txn_id=None):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -182,7 +185,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
"content": content, "content": content,
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": user.to_string(),
} },
client=client,
txn_id=txn_id,
) )
defer.returnValue((200, {"event_id": event.event_id})) defer.returnValue((200, {"event_id": event.event_id}))
@ -199,7 +204,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
except KeyError: except KeyError:
pass pass
response = yield self.on_POST(request, room_id, event_type) response = yield self.on_POST(request, room_id, event_type, txn_id)
self.txns.store_client_transaction(request, txn_id, response) self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response) defer.returnValue(response)
@ -214,8 +219,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_identifier): def on_POST(self, request, room_identifier, txn_id=None):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
# the identifier could be a room alias or a room id. Try one then the # the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid # other if it fails to parse, without swallowing other valid
@ -244,7 +249,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
"room_id": identifier.to_string(), "room_id": identifier.to_string(),
"sender": user.to_string(), "sender": user.to_string(),
"state_key": user.to_string(), "state_key": user.to_string(),
} },
client=client,
txn_id=txn_id,
) )
defer.returnValue((200, {"room_id": identifier.to_string()})) defer.returnValue((200, {"room_id": identifier.to_string()}))
@ -258,7 +265,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
except KeyError: except KeyError:
pass pass
response = yield self.on_POST(request, room_identifier) response = yield self.on_POST(request, room_identifier, txn_id)
self.txns.store_client_transaction(request, txn_id, response) self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response) defer.returnValue(response)
@ -282,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.room_member_handler handler = self.handlers.room_member_handler
members = yield handler.get_room_members_as_pagination_chunk( members = yield handler.get_room_members_as_pagination_chunk(
room_id=room_id, room_id=room_id,
@ -310,7 +317,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request( pagination_config = PaginationConfig.from_request(
request, default_limit=10, request, default_limit=10,
) )
@ -334,7 +341,7 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler
# Get all the current state for this room # Get all the current state for this room
events = yield handler.get_state_events( events = yield handler.get_state_events(
@ -350,7 +357,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync( content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id, room_id=room_id,
@ -363,6 +370,10 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
class RoomTriggerBackfill(ClientV1RestServlet): class RoomTriggerBackfill(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/backfill$") PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/backfill$")
def __init__(self, hs):
super(RoomTriggerBackfill, self).__init__(hs)
self.clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
remote_server = urllib.unquote( remote_server = urllib.unquote(
@ -374,7 +385,9 @@ class RoomTriggerBackfill(ClientV1RestServlet):
handler = self.handlers.federation_handler handler = self.handlers.federation_handler
events = yield handler.backfill(remote_server, room_id, limit) events = yield handler.backfill(remote_server, room_id, limit)
res = [self.hs.serialize_event(event) for event in events] time_now = self.clock.time_msec()
res = [serialize_event(event, time_now) for event in events]
defer.returnValue((200, res)) defer.returnValue((200, res))
@ -388,8 +401,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action): def on_POST(self, request, room_id, membership_action, txn_id=None):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -411,7 +424,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": user.to_string(),
"state_key": state_key, "state_key": state_key,
} },
client=client,
txn_id=txn_id,
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -425,7 +440,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
except KeyError: except KeyError:
pass pass
response = yield self.on_POST(request, room_id, membership_action) response = yield self.on_POST(
request, room_id, membership_action, txn_id
)
self.txns.store_client_transaction(request, txn_id, response) self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response) defer.returnValue(response)
@ -437,8 +454,8 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id): def on_POST(self, request, room_id, event_id, txn_id=None):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -449,7 +466,9 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": user.to_string(),
"redacts": event_id, "redacts": event_id,
} },
client=client,
txn_id=txn_id,
) )
defer.returnValue((200, {"event_id": event.event_id})) defer.returnValue((200, {"event_id": event.event_id}))
@ -463,7 +482,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
except KeyError: except KeyError:
pass pass
response = yield self.on_POST(request, room_id, event_id) response = yield self.on_POST(request, room_id, event_id, txn_id)
self.txns.store_client_transaction(request, txn_id, response) self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response) defer.returnValue(response)
@ -476,7 +495,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id) room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id)) target_user = UserID.from_string(urllib.unquote(user_id))

View File

@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret turnSecret = self.hs.config.turn_shared_secret

View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import JsonResource
class ClientV2AlphaRestResource(JsonResource):
"""A resource for version 2 alpha of the matrix client API."""
def __init__(self, hs):
JsonResource.__init__(self)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(client_resource, hs):
pass

View File

@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
import re
import logging
logger = logging.getLogger(__name__)
def client_v2_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)

View File

@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def map_request_to_name(self, request): def map_request_to_name(self, request):
# auth the user # auth the user
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user # namespace all file uploads on the user
prefix = base64.urlsafe_b64encode( prefix = base64.urlsafe_b64encode(

View File

@ -42,7 +42,7 @@ class UploadResource(BaseMediaResource):
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_POST(self, request): def _async_render_POST(self, request):
try: try:
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point # already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length") content_length = request.getHeader("Content-Length")

View File

@ -20,7 +20,6 @@
# Imports required for the default HomeServer() implementation # Imports required for the default HomeServer() implementation
from synapse.federation import initialize_http_replication from synapse.federation import initialize_http_replication
from synapse.events.utils import serialize_event
from synapse.notifier import Notifier from synapse.notifier import Notifier
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.handlers import Handlers from synapse.handlers import Handlers
@ -32,6 +31,7 @@ from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
@ -70,6 +70,7 @@ class BaseHomeServer(object):
'notifier', 'notifier',
'distributor', 'distributor',
'resource_for_client', 'resource_for_client',
'resource_for_client_v2_alpha',
'resource_for_federation', 'resource_for_federation',
'resource_for_web_client', 'resource_for_web_client',
'resource_for_content_repo', 'resource_for_content_repo',
@ -78,6 +79,7 @@ class BaseHomeServer(object):
'event_sources', 'event_sources',
'ratelimiter', 'ratelimiter',
'keyring', 'keyring',
'pusherpool',
'event_builder_factory', 'event_builder_factory',
] ]
@ -123,9 +125,6 @@ class BaseHomeServer(object):
setattr(BaseHomeServer, "get_%s" % (depname), _get) setattr(BaseHomeServer, "get_%s" % (depname), _get)
def serialize_event(self, e, as_client_event=True):
return serialize_event(self, e, as_client_event)
def get_ip_from_request(self, request): def get_ip_from_request(self, request):
# May be an X-Forwarding-For header depending on config # May be an X-Forwarding-For header depending on config
ip_addr = request.getClientIP() ip_addr = request.getClientIP()
@ -200,3 +199,6 @@ class HomeServer(BaseHomeServer):
clock=self.get_clock(), clock=self.get_clock(),
hostname=self.hostname, hostname=self.hostname,
) )
def build_pusherpool(self):
return PusherPool(self)

View File

@ -37,13 +37,15 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,)
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """ Responsible for doing state conflict resolution.
""" """
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
# self.auth = hs.get_auth()
self.hs = hs self.hs = hs
@defer.inlineCallbacks @defer.inlineCallbacks
@ -215,7 +217,7 @@ class StateHandler(object):
auth_events = { auth_events = {
k: e for k, e in unconflicted_state.items() k: e for k, e in unconflicted_state.items()
if k[0] in (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,) if k[0] in AuthEventTypes
} }
try: try:
@ -240,10 +242,6 @@ class StateHandler(object):
1. power levels 1. power levels
2. memberships 2. memberships
3. other events. 3. other events.
:param conflicted_state:
:param auth_events:
:return:
""" """
resolved_state = {} resolved_state = {}
power_key = (EventTypes.PowerLevels, "") power_key = (EventTypes.PowerLevels, "")

View File

@ -29,6 +29,8 @@ from .stream import StreamStore
from .transactions import TransactionStore from .transactions import TransactionStore
from .keys import KeyStore from .keys import KeyStore
from .event_federation import EventFederationStore from .event_federation import EventFederationStore
from .pusher import PusherStore
from .push_rule import PushRuleStore
from .media_repository import MediaRepositoryStore from .media_repository import MediaRepositoryStore
from .rejections import RejectionsStore from .rejections import RejectionsStore
@ -61,6 +63,7 @@ SCHEMAS = [
"state", "state",
"event_edges", "event_edges",
"event_signatures", "event_signatures",
"pusher",
"media_repository", "media_repository",
] ]
@ -84,6 +87,8 @@ class DataStore(RoomMemberStore, RoomStore,
EventFederationStore, EventFederationStore,
MediaRepositoryStore, MediaRepositoryStore,
RejectionsStore, RejectionsStore,
PusherStore,
PushRuleStore
): ):
def __init__(self, hs): def __init__(self, hs):
@ -386,6 +391,41 @@ class DataStore(RoomMemberStore, RoomStore,
events = yield self._parse_events(results) events = yield self._parse_events(results)
defer.returnValue(events) defer.returnValue(events)
@defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id):
del_sql = (
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
"LIMIT 1"
)
sql = (
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
"INNER JOIN state_events as s ON e.event_id = s.event_id "
"WHERE c.room_id = ? "
) % {
"redacted": del_sql,
}
sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
sql += " OR s.type = 'm.room.aliases')"
args = (room_id,)
results = yield self._execute_and_decode(sql, *args)
events = yield self._parse_events(results)
name = None
aliases = []
for e in events:
if e.type == 'm.room.name':
name = e.content['name']
elif e.type == 'm.room.aliases':
aliases.extend(e.content['aliases'])
defer.returnValue((name, aliases))
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_min_token(self): def _get_min_token(self):
row = yield self._execute( row = yield self._execute(

View File

@ -193,6 +193,50 @@ class SQLBaseStore(object):
txn.execute(sql, values.values()) txn.execute(sql, values.values())
return txn.lastrowid return txn.lastrowid
def _simple_upsert(self, table, keyvalues, values):
"""
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
Returns: A deferred
"""
return self.runInteraction(
"_simple_upsert",
self._simple_upsert_txn, table, keyvalues, values
)
def _simple_upsert_txn(self, txn, table, keyvalues, values):
# Try to update
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
sqlargs = values.values() + keyvalues.values()
logger.debug(
"[SQL] %s Args=%s",
sql, sqlargs,
)
txn.execute(sql, sqlargs)
if txn.rowcount == 0:
# We didn't update and rows so insert a new one
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues)
)
logger.debug(
"[SQL] %s Args=%s",
sql, keyvalues.values(),
)
txn.execute(sql, allvalues.values())
def _simple_select_one(self, table, keyvalues, retcols, def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False): allow_none=False):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
@ -344,8 +388,8 @@ class SQLBaseStore(object):
if updatevalues: if updatevalues:
update_sql = "UPDATE %s SET %s WHERE %s" % ( update_sql = "UPDATE %s SET %s WHERE %s" % (
table, table,
", ".join("%s = ?" % (k) for k in updatevalues), ", ".join("%s = ?" % (k,) for k in updatevalues),
" AND ".join("%s = ?" % (k) for k in keyvalues) " AND ".join("%s = ?" % (k,) for k in keyvalues)
) )
def func(txn): def func(txn):

View File

@ -0,0 +1,213 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from ._base import SQLBaseStore, Table
from twisted.internet import defer
import logging
import copy
import json
logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
@defer.inlineCallbacks
def get_push_rules_for_user_name(self, user_name):
sql = (
"SELECT "+",".join(PushRuleTable.fields)+" "
"FROM "+PushRuleTable.table_name+" "
"WHERE user_name = ? "
"ORDER BY priority_class DESC, priority DESC"
)
rows = yield self._execute(None, sql, user_name)
dicts = []
for r in rows:
d = {}
for i, f in enumerate(PushRuleTable.fields):
d[f] = r[i]
dicts.append(d)
defer.returnValue(dicts)
@defer.inlineCallbacks
def add_push_rule(self, before, after, **kwargs):
vals = copy.copy(kwargs)
if 'conditions' in vals:
vals['conditions'] = json.dumps(vals['conditions'])
if 'actions' in vals:
vals['actions'] = json.dumps(vals['actions'])
# we could check the rest of the keys are valid column names
# but sqlite will do that anyway so I think it's just pointless.
if 'id' in vals:
del vals['id']
if before or after:
ret = yield self.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
before=before,
after=after,
**vals
)
defer.returnValue(ret)
else:
ret = yield self.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
**vals
)
defer.returnValue(ret)
def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
after = None
relative_to_rule = None
if 'after' in kwargs and kwargs['after']:
after = kwargs['after']
relative_to_rule = after
if 'before' in kwargs and kwargs['before']:
relative_to_rule = kwargs['before']
# get the priority of the rule we're inserting after/before
sql = (
"SELECT priority_class, priority FROM ? "
"WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,)
)
txn.execute(sql, (user_name, relative_to_rule))
res = txn.fetchall()
if not res:
raise RuleNotFoundException("before/after rule not found: %s" % (relative_to_rule))
priority_class, base_rule_priority = res[0]
if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
raise InconsistentRuleException(
"Given priority class does not match class of relative rule"
)
new_rule = copy.copy(kwargs)
if 'before' in new_rule:
del new_rule['before']
if 'after' in new_rule:
del new_rule['after']
new_rule['priority_class'] = priority_class
new_rule['user_name'] = user_name
# check if the priority before/after is free
new_rule_priority = base_rule_priority
if after:
new_rule_priority -= 1
else:
new_rule_priority += 1
new_rule['priority'] = new_rule_priority
sql = (
"SELECT COUNT(*) FROM " + PushRuleTable.table_name +
" WHERE user_name = ? AND priority_class = ? AND priority = ?"
)
txn.execute(sql, (user_name, priority_class, new_rule_priority))
res = txn.fetchall()
num_conflicting = res[0][0]
# if there are conflicting rules, bump everything
if num_conflicting:
sql = "UPDATE "+PushRuleTable.table_name+" SET priority = priority "
if after:
sql += "-1"
else:
sql += "+1"
sql += " WHERE user_name = ? AND priority_class = ? AND priority "
if after:
sql += "<= ?"
else:
sql += ">= ?"
txn.execute(sql, (user_name, priority_class, new_rule_priority))
# now insert the new rule
sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
txn.execute(sql, new_rule.values())
def _add_push_rule_highest_priority_txn(self, txn, user_name,
priority_class, **kwargs):
# find the highest priority rule in that class
sql = (
"SELECT COUNT(*), MAX(priority) FROM " + PushRuleTable.table_name +
" WHERE user_name = ? and priority_class = ?"
)
txn.execute(sql, (user_name, priority_class))
res = txn.fetchall()
(how_many, highest_prio) = res[0]
new_prio = 0
if how_many > 0:
new_prio = highest_prio + 1
# and insert the new rule
new_rule = copy.copy(kwargs)
if 'id' in new_rule:
del new_rule['id']
new_rule['user_name'] = user_name
new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio
sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
txn.execute(sql, new_rule.values())
@defer.inlineCallbacks
def delete_push_rule(self, user_name, rule_id, **kwargs):
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the
standard ones
Args:
user_name (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
yield self._simple_delete_one(PushRuleTable.table_name, kwargs)
class RuleNotFoundException(Exception):
pass
class InconsistentRuleException(Exception):
pass
class PushRuleTable(Table):
table_name = "push_rules"
fields = [
"id",
"user_name",
"rule_id",
"priority_class",
"priority",
"conditions",
"actions",
]
EntryType = collections.namedtuple("PushRuleEntry", fields)

173
synapse/storage/pusher.py Normal file
View File

@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from ._base import SQLBaseStore, Table
from twisted.internet import defer
from synapse.api.errors import StoreError
import logging
logger = logging.getLogger(__name__)
class PusherStore(SQLBaseStore):
@defer.inlineCallbacks
def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey):
sql = (
"SELECT id, user_name, kind, instance_handle, app_id,"
"app_display_name, device_display_name, pushkey, ts, data, "
"last_token, last_success, failing_since "
"FROM pushers "
"WHERE app_id = ? AND pushkey = ?"
)
rows = yield self._execute(
None, sql, app_id_and_pushkey[0], app_id_and_pushkey[1]
)
ret = [
{
"id": r[0],
"user_name": r[1],
"kind": r[2],
"instance_handle": r[3],
"app_id": r[4],
"app_display_name": r[5],
"device_display_name": r[6],
"pushkey": r[7],
"pushkey_ts": r[8],
"data": r[9],
"last_token": r[10],
"last_success": r[11],
"failing_since": r[12]
}
for r in rows
]
defer.returnValue(ret[0])
@defer.inlineCallbacks
def get_all_pushers(self):
sql = (
"SELECT id, user_name, kind, instance_handle, app_id,"
"app_display_name, device_display_name, pushkey, ts, data, "
"last_token, last_success, failing_since "
"FROM pushers"
)
rows = yield self._execute(None, sql)
ret = [
{
"id": r[0],
"user_name": r[1],
"kind": r[2],
"instance_handle": r[3],
"app_id": r[4],
"app_display_name": r[5],
"device_display_name": r[6],
"pushkey": r[7],
"pushkey_ts": r[8],
"data": r[9],
"last_token": r[10],
"last_success": r[11],
"failing_since": r[12]
}
for r in rows
]
defer.returnValue(ret)
@defer.inlineCallbacks
def add_pusher(self, user_name, instance_handle, kind, app_id,
app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data):
try:
yield self._simple_upsert(
PushersTable.table_name,
dict(
app_id=app_id,
pushkey=pushkey,
),
dict(
user_name=user_name,
kind=kind,
instance_handle=instance_handle,
app_display_name=app_display_name,
device_display_name=device_display_name,
ts=pushkey_ts,
lang=lang,
data=data
))
except Exception as e:
logger.error("create_pusher with failed: %s", e)
raise StoreError(500, "Problem creating pusher.")
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
yield self._simple_delete_one(
PushersTable.table_name,
dict(app_id=app_id, pushkey=pushkey)
)
@defer.inlineCallbacks
def update_pusher_last_token(self, user_name, pushkey, last_token):
yield self._simple_update_one(
PushersTable.table_name,
{'user_name': user_name, 'pushkey': pushkey},
{'last_token': last_token}
)
@defer.inlineCallbacks
def update_pusher_last_token_and_success(self, user_name, pushkey,
last_token, last_success):
yield self._simple_update_one(
PushersTable.table_name,
{'user_name': user_name, 'pushkey': pushkey},
{'last_token': last_token, 'last_success': last_success}
)
@defer.inlineCallbacks
def update_pusher_failing_since(self, user_name, pushkey, failing_since):
yield self._simple_update_one(
PushersTable.table_name,
{'user_name': user_name, 'pushkey': pushkey},
{'failing_since': failing_since}
)
class PushersTable(Table):
table_name = "pushers"
fields = [
"id",
"user_name",
"kind",
"instance_handle",
"app_id",
"app_display_name",
"device_display_name",
"pushkey",
"pushkey_ts",
"data",
"last_token",
"last_success",
"failing_since"
]
EntryType = collections.namedtuple("PusherEntry", fields)

View File

@ -122,7 +122,8 @@ class RegistrationStore(SQLBaseStore):
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token):
sql = ( sql = (
"SELECT users.name, users.admin, access_tokens.device_id" "SELECT users.name, users.admin,"
" access_tokens.device_id, access_tokens.id as token_id"
" FROM users" " FROM users"
" INNER JOIN access_tokens on users.id = access_tokens.user_id" " INNER JOIN access_tokens on users.id = access_tokens.user_id"
" WHERE token = ?" " WHERE token = ?"

View File

@ -19,3 +19,36 @@ CREATE TABLE IF NOT EXISTS rejections(
last_check TEXT NOT NULL, last_check TEXT NOT NULL,
CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
); );
-- Push notification endpoints that users have configured
CREATE TABLE IF NOT EXISTS pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
instance_handle varchar(32) NOT NULL,
kind varchar(8) NOT NULL,
app_id varchar(64) NOT NULL,
app_display_name varchar(64) NOT NULL,
device_display_name varchar(128) NOT NULL,
pushkey blob NOT NULL,
ts BIGINT NOT NULL,
lang varchar(8),
data blob,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
FOREIGN KEY(user_name) REFERENCES users(name),
UNIQUE (app_id, pushkey)
);
CREATE TABLE IF NOT EXISTS push_rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
rule_id TEXT NOT NULL,
priority_class TINYINT NOT NULL,
priority INTEGER NOT NULL DEFAULT 0,
conditions TEXT NOT NULL,
actions TEXT NOT NULL,
UNIQUE(user_name, rule_id)
);
CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);

View File

@ -0,0 +1,46 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Push notification endpoints that users have configured
CREATE TABLE IF NOT EXISTS pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
instance_handle varchar(32) NOT NULL,
kind varchar(8) NOT NULL,
app_id varchar(64) NOT NULL,
app_display_name varchar(64) NOT NULL,
device_display_name varchar(128) NOT NULL,
pushkey blob NOT NULL,
ts BIGINT NOT NULL,
lang varchar(8),
data blob,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
FOREIGN KEY(user_name) REFERENCES users(name),
UNIQUE (app_id, pushkey)
);
CREATE TABLE IF NOT EXISTS push_rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
rule_id TEXT NOT NULL,
priority_class TINYINT NOT NULL,
priority INTEGER NOT NULL DEFAULT 0,
conditions TEXT NOT NULL,
actions TEXT NOT NULL,
UNIQUE(user_name, rule_id)
);
CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);

View File

@ -82,10 +82,10 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
def parse(cls, string): def parse(cls, string):
try: try:
if string[0] == 's': if string[0] == 's':
return cls(None, int(string[1:])) return cls(topological=None, stream=int(string[1:]))
if string[0] == 't': if string[0] == 't':
parts = string[1:].split('-', 1) parts = string[1:].split('-', 1)
return cls(int(parts[1]), int(parts[0])) return cls(topological=int(parts[0]), stream=int(parts[1]))
except: except:
pass pass
raise SynapseError(400, "Invalid token %r" % (string,)) raise SynapseError(400, "Invalid token %r" % (string,))
@ -94,7 +94,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
def parse_stream_token(cls, string): def parse_stream_token(cls, string):
try: try:
if string[0] == 's': if string[0] == 's':
return cls(None, int(string[1:])) return cls(topological=None, stream=int(string[1:]))
except: except:
pass pass
raise SynapseError(400, "Invalid token %r" % (string,)) raise SynapseError(400, "Invalid token %r" % (string,))

View File

@ -119,3 +119,6 @@ class StreamToken(
d = self._asdict() d = self._asdict()
d[key] = new_value d[key] = new_value
return StreamToken(**d) return StreamToken(**d)
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))

View File

@ -75,6 +75,7 @@ class PresenceStateTestCase(unittest.TestCase):
"user": UserID.from_string(myid), "user": UserID.from_string(myid),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -165,6 +166,7 @@ class PresenceListTestCase(unittest.TestCase):
"user": UserID.from_string(myid), "user": UserID.from_string(myid),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.handlers.room_member_handler = Mock( hs.handlers.room_member_handler = Mock(
@ -282,7 +284,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.get_clock().time_msec.return_value = 1000000 hs.get_clock().time_msec.return_value = 1000000
def _get_user_by_req(req=None): def _get_user_by_req(req=None):
return UserID.from_string(myid) return (UserID.from_string(myid), "")
hs.get_auth().get_user_by_req = _get_user_by_req hs.get_auth().get_user_by_req = _get_user_by_req

View File

@ -58,7 +58,7 @@ class ProfileTestCase(unittest.TestCase):
) )
def _get_user_by_req(request=None): def _get_user_by_req(request=None):
return UserID.from_string(myid) return (UserID.from_string(myid), "")
hs.get_auth().get_user_by_req = _get_user_by_req hs.get_auth().get_user_by_req = _get_user_by_req

View File

@ -70,6 +70,7 @@ class RoomPermissionsTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -466,6 +467,7 @@ class RoomsMemberListTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -555,6 +557,7 @@ class RoomsCreateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -657,6 +660,7 @@ class RoomTopicTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -773,6 +777,7 @@ class RoomMemberStateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -909,6 +914,7 @@ class RoomMessagesTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -1013,6 +1019,7 @@ class RoomInitialSyncTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token

View File

@ -73,6 +73,7 @@ class RoomTypingTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token

View File

@ -0,0 +1,60 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests import unittest
from mock import Mock
from ....utils import MockHttpResource, MockKey
from synapse.server import HomeServer
from synapse.types import UserID
PATH_PREFIX = "/_matrix/client/v2_alpha"
class V2AlphaRestTestCase(unittest.TestCase):
# Consumer must define
# USER_ID = <some string>
# TO_REGISTER = [<list of REST servlets to register>]
def setUp(self):
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
mock_config = Mock()
mock_config.signing_key = [MockKey()]
hs = HomeServer("test",
db_pool=None,
datastore=Mock(spec=[
"insert_client_ip",
]),
http_client=None,
resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource,
config=mock_config,
)
def _get_user_by_token(token=None):
return {
"user": UserID.from_string(self.USER_ID),
"admin": False,
"device_id": None,
}
hs.get_auth().get_user_by_token = _get_user_by_token
for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource)

View File

@ -53,7 +53,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
) )
self.assertEquals( self.assertEquals(
{"admin": 0, "device_id": None, "name": self.user_id}, {"admin": 0,
"device_id": None,
"name": self.user_id,
"token_id": 1},
(yield self.store.get_user_by_token(self.tokens[0])) (yield self.store.get_user_by_token(self.tokens[0]))
) )
@ -63,7 +66,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
self.assertEquals( self.assertEquals(
{"admin": 0, "device_id": None, "name": self.user_id}, {"admin": 0,
"device_id": None,
"name": self.user_id,
"token_id": 2},
(yield self.store.get_user_by_token(self.tokens[1])) (yield self.store.get_user_by_token(self.tokens[1]))
) )