Merge branch 'develop' into application-services

Conflicts:
	synapse/handlers/__init__.py
	synapse/storage/__init__.py
This commit is contained in:
Kegan Dougal 2015-02-02 15:57:59 +00:00
commit c059c9fea5
77 changed files with 6367 additions and 1536 deletions

View File

@ -95,27 +95,30 @@ Installing prerequisites on Ubuntu or Debian::
$ sudo apt-get install build-essential python2.7-dev libffi-dev \ $ sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \ python-pip python-setuptools sqlite3 \
libssl-dev libssl-dev python-virtualenv libjpeg-dev
Installing prerequisites on Mac OS X:: Installing prerequisites on Mac OS X::
$ xcode-select --install $ xcode-select --install
$ sudo pip install virtualenv
To install the synapse homeserver run:: To install the synapse homeserver run::
$ pip install --user --process-dependency-links https://github.com/matrix-org/synapse/tarball/master $ virtualenv ~/.synapse
$ source ~/.synapse/bin/activate
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
This installs synapse, along with the libraries it uses, into This installs synapse, along with the libraries it uses, into a virtual
``$HOME/.local/lib/`` on Linux or ``$HOME/Library/Python/2.7/lib/`` on OSX. environment under ``~/.synapse``.
Your python may not give priority to locally installed libraries over system To set up your homeserver, run (in your virtualenv, as before)::
libraries, in which case you must add your local packages to your python path::
$ # on Linux: $ python -m synapse.app.homeserver \
$ export PYTHONPATH=$HOME/.local/lib/python2.7/site-packages:$PYTHONPATH --server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
$ # on OSX: Substituting your host and domain name as appropriate.
$ export PYTHONPATH=$HOME/Library/Python/2.7/lib/python/site-packages:$PYTHONPATH
For reliable VoIP calls to be routed via this homeserver, you MUST configure For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See docs/turn-howto.rst for details. a TURN server. See docs/turn-howto.rst for details.
@ -140,7 +143,7 @@ host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are happens, you will have to individually install the dependencies which are
failing, e.g.:: failing, e.g.::
$ pip install --user twisted $ pip install twisted
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments. will need to export CFLAGS=-Qunused-arguments.
@ -182,23 +185,13 @@ Running Your Homeserver
To actually run your new homeserver, pick a working directory for Synapse to run To actually run your new homeserver, pick a working directory for Synapse to run
(e.g. ``~/.synapse``), and:: (e.g. ``~/.synapse``), and::
$ mkdir ~/.synapse
$ cd ~/.synapse $ cd ~/.synapse
$ source ./bin/activate
$ # on Linux $ synctl start
$ ~/.local/bin/synctl start
$ # on OSX
$ ~/Library/Python/2.7/bin/synctl start
Troubleshooting Running Troubleshooting Running
----------------------- -----------------------
If ``synctl`` fails with ``pkg_resources.DistributionNotFound`` errors you may
need a newer version of setuptools than that provided by your OS.::
$ sudo pip install setuptools --upgrade
If synapse fails with ``missing "sodium.h"`` crypto errors, you may need If synapse fails with ``missing "sodium.h"`` crypto errors, you may need
to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for
encryption and digital signatures. encryption and digital signatures.
@ -225,13 +218,15 @@ directory of your choice::
$ cd synapse $ cd synapse
The homeserver has a number of external dependencies, that are easiest The homeserver has a number of external dependencies, that are easiest
to install by making setup.py do so, in --user mode:: to install using pip and a virtualenv::
$ python setup.py develop --user $ virtualenv env
$ source env/bin/activate
$ python synapse/dependencies | xargs -i pip install
$ pip install setuptools_trial mock
This will run a process of downloading and installing into your This will run a process of downloading and installing all the needed
user's .local/lib directory all of the required dependencies that are dependencies into a virtual env.
missing.
Once this is done, you may wish to run the homeserver's unit tests, to Once this is done, you may wish to run the homeserver's unit tests, to
check that everything is installed as it should be:: check that everything is installed as it should be::
@ -252,7 +247,7 @@ IMPORTANT: Before upgrading an existing homeserver to a new version, please
refer to UPGRADE.rst for any additional instructions. refer to UPGRADE.rst for any additional instructions.
Otherwise, simply re-install the new codebase over the current one - e.g. Otherwise, simply re-install the new codebase over the current one - e.g.
by ``pip install --user --process-dependency-links by ``pip install --process-dependency-links
https://github.com/matrix-org/synapse/tarball/master`` https://github.com/matrix-org/synapse/tarball/master``
if using pip, or by ``git pull`` if running off a git working copy. if using pip, or by ``git pull`` if running off a git working copy.
@ -279,9 +274,9 @@ For the first form, simply pass the required hostname (of the machine) as the
$ python -m synapse.app.homeserver \ $ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.config \ --config-path homeserver.yaml \
--generate-config --generate-config
$ python -m synapse.app.homeserver --config-path homeserver.config $ python -m synapse.app.homeserver --config-path homeserver.yaml
Alternatively, you can run ``synctl start`` to guide you through the process. Alternatively, you can run ``synctl start`` to guide you through the process.
@ -301,9 +296,9 @@ SRV record, as that is the name other machines will expect it to have::
$ python -m synapse.app.homeserver \ $ python -m synapse.app.homeserver \
--server-name YOURDOMAIN \ --server-name YOURDOMAIN \
--bind-port 8448 \ --bind-port 8448 \
--config-path homeserver.config \ --config-path homeserver.yaml \
--generate-config --generate-config
$ python -m synapse.app.homeserver --config-path homeserver.config $ python -m synapse.app.homeserver --config-path homeserver.yaml
You may additionally want to pass one or more "-v" options, in order to You may additionally want to pass one or more "-v" options, in order to

View File

@ -32,8 +32,8 @@ setup(
description="Reference Synapse Home Server", description="Reference Synapse Home Server",
install_requires=[ install_requires=[
"syutil==0.0.2", "syutil==0.0.2",
"matrix_angular_sdk==0.6.0", "matrix_angular_sdk>=0.6.0",
"Twisted>=14.0.0", "Twisted==14.0.2",
"service_identity>=1.0.0", "service_identity>=1.0.0",
"pyopenssl>=0.14", "pyopenssl>=0.14",
"pyyaml", "pyyaml",
@ -50,6 +50,7 @@ setup(
"https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.6.0/#egg=matrix_angular_sdk-0.6.0", "https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.6.0/#egg=matrix_angular_sdk-0.6.0",
], ],
setup_requires=[ setup_requires=[
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"setuptools_trial", "setuptools_trial",
"setuptools>=1.0.0", # Needs setuptools that supports git+ssh. "setuptools>=1.0.0", # Needs setuptools that supports git+ssh.
# TODO: Do we need this now? we don't use git+ssh. # TODO: Do we need this now? we don't use git+ssh.

View File

@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.types import UserID from synapse.types import UserID, ClientInfo
import logging import logging
@ -102,6 +102,8 @@ class Auth(object):
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id) curr_state = yield self.state.get_current_state(room_id)
logger.debug("Got curr_state %s", curr_state)
for event in curr_state: for event in curr_state:
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
try: try:
@ -290,7 +292,9 @@ class Auth(object):
Args: Args:
request - An HTTP request with an access_token query parameter. request - An HTTP request with an access_token query parameter.
Returns: Returns:
UserID : User ID object of the user making the request tuple : of UserID and device string:
User ID object of the user making the request
Client ID object of the client instance the user is using
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
@ -299,6 +303,8 @@ class Auth(object):
access_token = request.args["access_token"][0] access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_token(access_token) user_info = yield self.get_user_by_token(access_token)
user = user_info["user"] user = user_info["user"]
device_id = user_info["device_id"]
token_id = user_info["token_id"]
ip_addr = self.hs.get_ip_from_request(request) ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders( user_agent = request.requestHeaders.getRawHeaders(
@ -314,7 +320,7 @@ class Auth(object):
user_agent=user_agent user_agent=user_agent
) )
defer.returnValue(user) defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError: except KeyError:
raise AuthError(403, "Missing access token.") raise AuthError(403, "Missing access token.")
@ -339,6 +345,7 @@ class Auth(object):
"admin": bool(ret.get("admin", False)), "admin": bool(ret.get("admin", False)),
"device_id": ret.get("device_id"), "device_id": ret.get("device_id"),
"user": UserID.from_string(ret.get("name")), "user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
} }
defer.returnValue(user_info) defer.returnValue(user_info)
@ -353,9 +360,23 @@ class Auth(object):
def add_auth_events(self, builder, context): def add_auth_events(self, builder, context):
yield run_on_reactor() yield run_on_reactor()
if builder.type == EventTypes.Create: auth_ids = self.compute_auth_events(builder, context)
builder.auth_events = []
return auth_events_entries = yield self.store.add_event_hashes(
auth_ids
)
builder.auth_events = auth_events_entries
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
def compute_auth_events(self, event, context):
if event.type == EventTypes.Create:
return []
auth_ids = [] auth_ids = []
@ -368,7 +389,7 @@ class Auth(object):
key = (EventTypes.JoinRules, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event = context.current_state.get(key) join_rule_event = context.current_state.get(key)
key = (EventTypes.Member, builder.user_id, ) key = (EventTypes.Member, event.user_id, )
member_event = context.current_state.get(key) member_event = context.current_state.get(key)
key = (EventTypes.Create, "", ) key = (EventTypes.Create, "", )
@ -382,8 +403,8 @@ class Auth(object):
else: else:
is_public = False is_public = False
if builder.type == EventTypes.Member: if event.type == EventTypes.Member:
e_type = builder.content["membership"] e_type = event.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]: if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event: if join_rule_event:
auth_ids.append(join_rule_event.event_id) auth_ids.append(join_rule_event.event_id)
@ -398,17 +419,7 @@ class Auth(object):
if member_event.content["membership"] == Membership.JOIN: if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id) auth_ids.append(member_event.event_id)
auth_events_entries = yield self.store.add_event_hashes( return auth_ids
auth_ids
)
builder.auth_events = auth_events_entries
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
@log_function @log_function
def _can_send_event(self, event, auth_events): def _can_send_event(self, event, auth_events):

View File

@ -74,3 +74,9 @@ class EventTypes(object):
Message = "m.room.message" Message = "m.room.message"
Topic = "m.room.topic" Topic = "m.room.topic"
Name = "m.room.name" Name = "m.room.name"
class RejectedReason(object):
AUTH_ERROR = "auth_error"
REPLACED = "replaced"
NOT_ANCESTOR = "not_ancestor"

View File

@ -21,6 +21,7 @@ logger = logging.getLogger(__name__)
class Codes(object): class Codes(object):
UNRECOGNIZED = "M_UNRECOGNIZED"
UNAUTHORIZED = "M_UNAUTHORIZED" UNAUTHORIZED = "M_UNAUTHORIZED"
FORBIDDEN = "M_FORBIDDEN" FORBIDDEN = "M_FORBIDDEN"
BAD_JSON = "M_BAD_JSON" BAD_JSON = "M_BAD_JSON"
@ -34,6 +35,7 @@ class Codes(object):
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID" CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
MISSING_PARAM = "M_MISSING_PARAM",
TOO_LARGE = "M_TOO_LARGE" TOO_LARGE = "M_TOO_LARGE"
@ -81,6 +83,35 @@ class RegistrationError(SynapseError):
pass pass
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.UNRECOGNIZED
message = None
if len(args) == 0:
message = "Unrecognized request"
else:
message = args[0]
super(UnrecognizedRequestError, self).__init__(
400,
message,
**kwargs
)
class NotFoundError(SynapseError):
"""An error indicating we can't find the thing you asked for"""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.NOT_FOUND
super(NotFoundError, self).__init__(
404,
"Not found",
**kwargs
)
class AuthError(SynapseError): class AuthError(SynapseError):
"""An error raised when there was a problem authorising an event.""" """An error raised when there was a problem authorising an event."""

229
synapse/api/filtering.py Normal file
View File

@ -0,0 +1,229 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID
class Filtering(object):
def __init__(self, hs):
super(Filtering, self).__init__()
self.store = hs.get_datastore()
def get_user_filter(self, user_localpart, filter_id):
result = self.store.get_user_filter(user_localpart, filter_id)
result.addCallback(Filter)
return result
def add_user_filter(self, user_localpart, user_filter):
self._check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter)
# TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for
# them however
def _check_valid_filter(self, user_filter_json):
"""Check if the provided filter is valid.
This inspects all definitions contained within the filter.
Args:
user_filter_json(dict): The filter
Raises:
SynapseError: If the filter is not valid.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
top_level_definitions = [
"public_user_data", "private_user_data", "server_data"
]
room_level_definitions = [
"state", "events", "ephemeral"
]
for key in top_level_definitions:
if key in user_filter_json:
self._check_definition(user_filter_json[key])
if "room" in user_filter_json:
for key in room_level_definitions:
if key in user_filter_json["room"]:
self._check_definition(user_filter_json["room"][key])
def _check_definition(self, definition):
"""Check if the provided definition is valid.
This inspects not only the types but also the values to make sure they
make sense.
Args:
definition(dict): The filter definition
Raises:
SynapseError: If there was a problem with this definition.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
if type(definition) != dict:
raise SynapseError(
400, "Expected JSON object, not %s" % (definition,)
)
# check rooms are valid room IDs
room_id_keys = ["rooms", "not_rooms"]
for key in room_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for room_id in definition[key]:
RoomID.from_string(room_id)
# check senders are valid user IDs
user_id_keys = ["senders", "not_senders"]
for key in user_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for user_id in definition[key]:
UserID.from_string(user_id)
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
event_keys = ["types", "not_types"]
for key in event_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for event_type in definition[key]:
if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string")
if "format" in definition:
event_format = definition["format"]
if event_format not in ["federation", "events"]:
raise SynapseError(400, "Invalid format: %s" % (event_format,))
if "select" in definition:
event_select_list = definition["select"]
for select_key in event_select_list:
if select_key not in ["event_id", "origin_server_ts",
"thread_id", "content", "content.body"]:
raise SynapseError(400, "Bad select: %s" % (select_key,))
if ("bundle_updates" in definition and
type(definition["bundle_updates"]) != bool):
raise SynapseError(400, "Bad bundle_updates: expected bool.")
class Filter(object):
def __init__(self, filter_json):
self.filter_json = filter_json
def filter_public_user_data(self, events):
return self._filter_on_key(events, ["public_user_data"])
def filter_private_user_data(self, events):
return self._filter_on_key(events, ["private_user_data"])
def filter_room_state(self, events):
return self._filter_on_key(events, ["room", "state"])
def filter_room_events(self, events):
return self._filter_on_key(events, ["room", "events"])
def filter_room_ephemeral(self, events):
return self._filter_on_key(events, ["room", "ephemeral"])
def _filter_on_key(self, events, keys):
filter_json = self.filter_json
if not filter_json:
return events
try:
# extract the right definition from the filter
definition = filter_json
for key in keys:
definition = definition[key]
return self._filter_with_definition(events, definition)
except KeyError:
# return all events if definition isn't specified.
return events
def _filter_with_definition(self, events, definition):
return [e for e in events if self._passes_definition(definition, e)]
def _passes_definition(self, definition, event):
"""Check if the event passes through the given definition.
Args:
definition(dict): The definition to check against.
event(Event): The event to check.
Returns:
True if the event passes through the filter.
"""
# Algorithm notes:
# For each key in the definition, check the event meets the criteria:
# * For types: Literal match or prefix match (if ends with wildcard)
# * For senders/rooms: Literal match only
# * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
# and 'not_types' then it is treated as only being in 'not_types')
# room checks
if hasattr(event, "room_id"):
room_id = event.room_id
allow_rooms = definition.get("rooms", None)
reject_rooms = definition.get("not_rooms", None)
if reject_rooms and room_id in reject_rooms:
return False
if allow_rooms and room_id not in allow_rooms:
return False
# sender checks
if hasattr(event, "sender"):
# Should we be including event.state_key for some event types?
sender = event.sender
allow_senders = definition.get("senders", None)
reject_senders = definition.get("not_senders", None)
if reject_senders and sender in reject_senders:
return False
if allow_senders and sender not in allow_senders:
return False
# type checks
if "not_types" in definition:
for def_type in definition["not_types"]:
if self._event_matches_type(event, def_type):
return False
if "types" in definition:
included = False
for def_type in definition["types"]:
if self._event_matches_type(event, def_type):
included = True
break
if not included:
return False
return True
def _event_matches_type(self, event, def_type):
if def_type.endswith("*"):
type_prefix = def_type[:-1]
return event.type.startswith(type_prefix)
else:
return event.type == def_type

View File

@ -277,6 +277,8 @@ def setup():
bind_port = None bind_port = None
hs.start_listening(bind_port, config.unsecure_port) hs.start_listening(bind_port, config.unsecure_port)
hs.get_pusherpool().start()
if config.daemonize: if config.daemonize:
print config.pid_file print config.pid_file
daemon = Daemonize( daemon = Daemonize(

View File

@ -61,9 +61,11 @@ class SynapseKeyClientProtocol(HTTPClient):
def __init__(self): def __init__(self):
self.remote_key = defer.Deferred() self.remote_key = defer.Deferred()
self.host = None
def connectionMade(self): def connectionMade(self):
logger.debug("Connected to %s", self.transport.getHost()) self.host = self.transport.getHost()
logger.debug("Connected to %s", self.host)
self.sendCommand(b"GET", b"/_matrix/key/v1/") self.sendCommand(b"GET", b"/_matrix/key/v1/")
self.endHeaders() self.endHeaders()
self.timer = reactor.callLater( self.timer = reactor.callLater(
@ -92,8 +94,7 @@ class SynapseKeyClientProtocol(HTTPClient):
self.timer.cancel() self.timer.cancel()
def on_timeout(self): def on_timeout(self):
logger.debug("Timeout waiting for response from %s", logger.debug("Timeout waiting for response from %s", self.host)
self.transport.getHost())
self.remote_key.errback(IOError("Timeout waiting for response")) self.remote_key.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection() self.transport.abortConnection()

View File

@ -18,7 +18,7 @@ from synapse.util.frozenutils import freeze, unfreeze
class _EventInternalMetadata(object): class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict): def __init__(self, internal_metadata_dict):
self.__dict__ = internal_metadata_dict self.__dict__ = dict(internal_metadata_dict)
def get_dict(self): def get_dict(self):
return dict(self.__dict__) return dict(self.__dict__)

View File

@ -23,14 +23,15 @@ import copy
class EventBuilder(EventBase): class EventBuilder(EventBase):
def __init__(self, key_values={}): def __init__(self, key_values={}, internal_metadata_dict={}):
signatures = copy.deepcopy(key_values.pop("signatures", {})) signatures = copy.deepcopy(key_values.pop("signatures", {}))
unsigned = copy.deepcopy(key_values.pop("unsigned", {})) unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
super(EventBuilder, self).__init__( super(EventBuilder, self).__init__(
key_values, key_values,
signatures=signatures, signatures=signatures,
unsigned=unsigned unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
) )
def build(self): def build(self):

View File

@ -20,3 +20,4 @@ class EventContext(object):
self.current_state = current_state self.current_state = current_state
self.auth_events = auth_events self.auth_events = auth_events
self.state_group = None self.state_group = None
self.rejected = False

View File

@ -45,12 +45,14 @@ def prune_event(event):
"membership", "membership",
] ]
event_dict = event.get_dict()
new_content = {} new_content = {}
def add_fields(*fields): def add_fields(*fields):
for field in fields: for field in fields:
if field in event.content: if field in event.content:
new_content[field] = event.content[field] new_content[field] = event_dict["content"][field]
if event_type == EventTypes.Member: if event_type == EventTypes.Member:
add_fields("membership") add_fields("membership")
@ -75,7 +77,7 @@ def prune_event(event):
allowed_fields = { allowed_fields = {
k: v k: v
for k, v in event.get_dict().items() for k, v in event_dict.items()
if k in allowed_keys if k in allowed_keys
} }
@ -86,10 +88,53 @@ def prune_event(event):
if "age_ts" in event.unsigned: if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"] allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
return type(event)(allowed_fields) return type(event)(
allowed_fields,
internal_metadata_dict=event.internal_metadata.get_dict()
)
def serialize_event(e, time_now_ms, client_event=True): def format_event_raw(d):
return d
def format_event_for_client_v1(d):
d["user_id"] = d.pop("sender", None)
move_keys = ("age", "redacted_because", "replaces_state", "prev_content")
for key in move_keys:
if key in d["unsigned"]:
d[key] = d["unsigned"][key]
drop_keys = (
"auth_events", "prev_events", "hashes", "signatures", "depth",
"unsigned", "origin", "prev_state"
)
for key in drop_keys:
d.pop(key, None)
return d
def format_event_for_client_v2(d):
drop_keys = (
"auth_events", "prev_events", "hashes", "signatures", "depth",
"origin", "prev_state",
)
for key in drop_keys:
d.pop(key, None)
return d
def format_event_for_client_v2_without_event_id(d):
d = format_event_for_client_v2(d)
d.pop("room_id", None)
d.pop("event_id", None)
return d
def serialize_event(e, time_now_ms, as_client_event=True,
event_format=format_event_for_client_v1,
token_id=None):
# FIXME(erikj): To handle the case of presence events and the like # FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase): if not isinstance(e, EventBase):
return e return e
@ -99,43 +144,22 @@ def serialize_event(e, time_now_ms, client_event=True):
# Should this strip out None's? # Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()} d = {k: v for k, v in e.get_dict().items()}
if not client_event:
# set the age and keep all other keys
if "age_ts" in d["unsigned"]: if "age_ts" in d["unsigned"]:
d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"] d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"]
return d
if "age_ts" in d["unsigned"]:
d["age"] = time_now_ms - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"] del d["unsigned"]["age_ts"]
d["user_id"] = d.pop("sender", None)
if "redacted_because" in e.unsigned: if "redacted_because" in e.unsigned:
d["redacted_because"] = serialize_event( d["unsigned"]["redacted_because"] = serialize_event(
e.unsigned["redacted_because"], time_now_ms e.unsigned["redacted_because"], time_now_ms
) )
del d["unsigned"]["redacted_because"] if token_id is not None:
if token_id == getattr(e.internal_metadata, "token_id", None):
if "redacted_by" in e.unsigned: txn_id = getattr(e.internal_metadata, "txn_id", None)
d["redacted_by"] = e.unsigned["redacted_by"] if txn_id is not None:
del d["unsigned"]["redacted_by"] d["unsigned"]["transaction_id"] = txn_id
if "replaces_state" in e.unsigned:
d["replaces_state"] = e.unsigned["replaces_state"]
del d["unsigned"]["replaces_state"]
if "prev_content" in e.unsigned:
d["prev_content"] = e.unsigned["prev_content"]
del d["unsigned"]["prev_content"]
del d["auth_events"]
del d["prev_events"]
del d["hashes"]
del d["signatures"]
d.pop("depth", None)
d.pop("unsigned", None)
d.pop("origin", None)
if as_client_event:
return event_format(d)
else:
return d return d

View File

@ -0,0 +1,409 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .units import Edu
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
import logging
logger = logging.getLogger(__name__)
class FederationClient(object):
@log_function
def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
TODO: Figure out when we should actually resolve the deferred.
Args:
pdu (Pdu): The new Pdu.
Returns:
Deferred: Completes when we have successfully processed the PDU
and replicated it to any interested remote home servers.
"""
order = self._order
self._order += 1
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, destinations, order)
logger.debug(
"[%s] transaction_layer.enqueue_pdu... done",
pdu.event_id
)
@log_function
def send_edu(self, destination, edu_type, content):
edu = Edu(
origin=self.server_name,
destination=destination,
edu_type=edu_type,
content=content,
)
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu)
return defer.succeed(None)
@log_function
def send_failure(self, failure, destination):
self._transaction_queue.enqueue_failure(failure, destination)
return defer.succeed(None)
@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=True):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
Args:
destination (str): Domain name of the remote homeserver
query_type (str): Category of the query type; should match the
handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details
of the query request.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
)
@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
Args:
dest (str): The remote home server to ask.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus
we have seen from the context
Returns:
Deferred: Results in the received PDUs.
"""
logger.debug("backfill extrem=%s", extremities)
# If there are no extremeties then we've (probably) reached the start.
if not extremities:
return
transaction_data = yield self.transport_layer.backfill(
dest, context, extremities, limit)
logger.debug("backfill transaction_data=%s", repr(transaction_data))
pdus = [
self.event_from_pdu_json(p, outlier=False)
for p in transaction_data["pdus"]
]
for i, pdu in enumerate(pdus):
pdus[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destinations, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
servers.
Will attempt to get the PDU from each destination in the list until
one succeeds.
This will persist the PDU locally upon receipt.
Args:
destinations (list): Which home servers to query
pdu_origin (str): The home server that originally sent the pdu.
event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
Returns:
Deferred: Results in the requested PDU.
"""
# TODO: Rate limit the number of times we try and get the same event.
pdu = None
for destination in destinations:
try:
transaction_data = yield self.transport_layer.get_event(
destination, event_id
)
logger.debug("transaction_data %r", transaction_data)
pdu_list = [
self.event_from_pdu_json(p, outlier=outlier)
for p in transaction_data["pdus"]
]
if pdu_list:
pdu = pdu_list[0]
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
break
except Exception as e:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
def get_state_for_room(self, destination, room_id, event_id):
"""Requests all of the `current` state PDUs for a given room from
a remote home server.
Args:
destination (str): The remote homeserver to query for the state.
room_id (str): The id of the room we're interested in.
event_id (str): The id of the event we want the state at.
Returns:
Deferred: Results in a list of PDUs.
"""
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
)
pdus = [
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
for i, pdu in enumerate(pdus):
pdus[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
for i, pdu in enumerate(auth_chain):
auth_chain[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue((pdus, auth_chain))
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
destination, room_id, event_id,
)
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"]
]
for i, pdu in enumerate(auth_chain):
auth_chain[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue(auth_chain)
@defer.inlineCallbacks
def make_join(self, destination, room_id, user_id):
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
)
pdu_dict = ret["event"]
logger.debug("Got response to make_join: %s", pdu_dict)
defer.returnValue(self.event_from_pdu_json(pdu_dict))
@defer.inlineCallbacks
def send_join(self, destination, pdu):
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
state = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
for i, pdu in enumerate(state):
state[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
for i, pdu in enumerate(auth_chain):
auth_chain[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue({
"state": state,
"auth_chain": auth_chain,
})
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
time_now = self._clock.time_msec()
code, content = yield self.transport_layer.send_invite(
destination=destination,
room_id=room_id,
event_id=event_id,
content=pdu.get_pdu_json(time_now),
)
pdu_dict = content["event"]
logger.debug("Got response to send_invite: %s", pdu_dict)
pdu = self.event_from_pdu_json(pdu_dict)
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue(pdu)
@defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth):
"""
Params:
destination (str)
event_it (str)
local_auth (list)
"""
time_now = self._clock.time_msec()
send_content = {
"auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
}
code, content = yield self.transport_layer.send_query_auth(
destination=destination,
room_id=room_id,
event_id=event_id,
content=send_content,
)
auth_chain = [
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
for e in content["auth_chain"]
]
ret = {
"auth_chain": auth_chain,
"rejects": content.get("rejects", []),
"missing": content.get("missing", []),
}
defer.returnValue(ret)
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
)
except SynapseError:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
)
defer.returnValue(redacted_event)
defer.returnValue(pdu)

View File

@ -0,0 +1,462 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .units import Transaction, Edu
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import FederationError, SynapseError
import logging
logger = logging.getLogger(__name__)
class FederationServer(object):
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
documented on :py:class:`.ReplicationHandler`.
"""
self.handler = handler
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation Query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (callable): Invoked to handle incoming queries of this type
handler is invoked as:
result = handler(args)
where 'args' is a dict mapping strings to strings of the query
arguments. It should return a Deferred that will eventually yield an
object to encode as JSON.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
pdus = yield self.handler.on_backfill_request(
origin, room_id, versions, limit
)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
transaction = Transaction(**transaction_data)
for p in transaction.pdus:
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
pdu_list = [
self.event_from_pdu_json(p) for p in transaction.pdus
]
logger.debug("[%s] Got transaction", transaction.transaction_id)
response = yield self.transaction_actions.have_responded(transaction)
if response:
logger.debug(
"[%s] We've already responed to this request",
transaction.transaction_id
)
defer.returnValue(response)
return
logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext():
dl = []
for pdu in pdu_list:
dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
results = yield defer.DeferredList(dl)
ret = []
for r in results:
if r[0]:
ret.append({})
else:
logger.exception(r[1])
ret.append({"error": str(r[1])})
logger.debug("Returning: %s", str(ret))
yield self.transaction_actions.set_response(
transaction,
200, response
)
defer.returnValue((200, response))
def received_edu(self, origin, edu_type, content):
if edu_type in self.edu_handlers:
self.edu_handlers[edu_type](origin, content)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, origin, room_id, event_id):
if event_id:
pdus = yield self.handler.get_state_for_pdu(
origin, room_id, event_id,
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
else:
raise NotImplementedError("Specify an event")
defer.returnValue((200, {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}))
@defer.inlineCallbacks
@log_function
def on_pdu_request(self, origin, event_id):
pdu = yield self._get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
(200, self._transaction_from_pdus([pdu]).get_dict())
)
else:
defer.returnValue((404, ""))
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
raise NotImplementedError("Pull transactions not implemented")
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
if query_type in self.query_handlers:
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type,))
)
@defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id):
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
pdu = self.event_from_pdu_json(content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
],
}))
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
defer.returnValue((200, {
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
}))
@defer.inlineCallbacks
def on_query_auth_request(self, origin, content, event_id):
"""
Content is a dict with keys::
auth_chain (list): A list of events that give the auth chain.
missing (list): A list of event_ids indicating what the other
side (`origin`) think we're missing.
rejects (dict): A mapping from event_id to a 2-tuple of reason
string and a proof (or None) of why the event was rejected.
The keys of this dict give the list of events the `origin` has
rejected.
Args:
origin (str)
content (dict)
event_id (str)
Returns:
Deferred: Results in `dict` with the same format as `content`
"""
auth_chain = [
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
for e in content["auth_chain"]
]
missing = [
(yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
for e in content.get("missing", [])
]
ret = yield self.handler.on_query_auth(
origin, event_id, auth_chain, content.get("rejects", []), missing
)
time_now = self._clock.time_msec()
send_content = {
"auth_chain": [
e.get_pdu_json(time_now)
for e in ret["auth_chain"]
],
"rejects": ret.get("rejects", []),
"missing": ret.get("missing", []),
}
defer.returnValue(
(200, send_content)
)
@log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
return self.handler.get_persisted_pdu(
origin, event_id, do_auth=do_auth
)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
time_now = self._clock.time_msec()
pdus = [p.get_pdu_json(time_now) for p in pdu_list]
return Transaction(
origin=self.server_name,
pdus=pdus,
origin_server_ts=int(time_now),
destination=None,
)
@defer.inlineCallbacks
@log_function
def _handle_new_pdu(self, origin, pdu, max_recursion=10):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False
)
# FIXME: Currently we fetch an event again when we already have it
# if it has been marked as an outlier.
already_seen = (
existing and (
not existing.internal_metadata.is_outlier()
or pdu.internal_metadata.is_outlier()
)
)
if already_seen:
logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return
# Check signature.
try:
pdu = yield self._check_sigs_and_hash(pdu)
except SynapseError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=pdu.event_id,
)
state = None
auth_chain = []
have_seen = yield self.store.have_events(
[ev for ev, _ in pdu.prev_events]
)
fetch_state = False
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id
)
logger.debug(
"_handle_new_pdu min_depth for %s: %d",
pdu.room_id, min_depth
)
if min_depth and pdu.depth > min_depth and max_recursion > 0:
for event_id, hashes in pdu.prev_events:
if event_id not in have_seen:
logger.debug(
"_handle_new_pdu requesting pdu %s",
event_id
)
try:
new_pdu = yield self.federation_client.get_pdu(
[origin, pdu.origin],
event_id=event_id,
)
if new_pdu:
yield self._handle_new_pdu(
origin,
new_pdu,
max_recursion=max_recursion-1
)
logger.debug("Processed pdu %s", event_id)
else:
logger.warn("Failed to get PDU %s", event_id)
fetch_state = True
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
fetch_state = True
else:
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if prevs - seen:
fetch_state = True
else:
fetch_state = True
if fetch_state:
# We need to get the state at this event, since we haven't
# processed all the prev events.
logger.debug(
"_handle_new_pdu getting state for %s",
pdu.room_id
)
state, auth_chain = yield self.get_state_for_room(
origin, pdu.room_id, pdu.event_id,
)
ret = yield self.handler.on_receive_pdu(
origin,
pdu,
backfilled=False,
state=state,
auth_chain=auth_chain,
)
defer.returnValue(ret)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
)
except SynapseError:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
)
defer.returnValue(redacted_event)
defer.returnValue(pdu)

View File

@ -17,23 +17,20 @@
a given transport. a given transport.
""" """
from twisted.internet import defer from .federation_client import FederationClient
from .federation_server import FederationServer
from .units import Transaction, Edu from .transaction_queue import TransactionQueue
from .persistence import TransactionActions from .persistence import TransactionActions
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReplicationLayer(object): class ReplicationLayer(FederationClient, FederationServer):
"""This layer is responsible for replicating with remote home servers over """This layer is responsible for replicating with remote home servers over
the given transport. I.e., does the sending and receiving of PDUs to the given transport. I.e., does the sending and receiving of PDUs to
remote home servers. remote home servers.
@ -54,898 +51,26 @@ class ReplicationLayer(object):
def __init__(self, hs, transport_layer): def __init__(self, hs, transport_layer):
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.transport_layer = transport_layer self.transport_layer = transport_layer
self.transport_layer.register_received_handler(self) self.transport_layer.register_received_handler(self)
self.transport_layer.register_request_handler(self) self.transport_layer.register_request_handler(self)
self.store = hs.get_datastore() self.federation_client = self
# self.pdu_actions = PduActions(self.store)
self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = _TransactionQueue( self.store = hs.get_datastore()
hs, self.transaction_actions, transport_layer
)
self.handler = None self.handler = None
self.edu_handlers = {} self.edu_handlers = {}
self.query_handlers = {} self.query_handlers = {}
self._order = 0
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.event_builder_factory = hs.get_event_builder_factory() self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = TransactionQueue(hs, transport_layer)
def set_handler(self, handler): self._order = 0
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
documented on :py:class:`.ReplicationHandler`.
"""
self.handler = handler
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation Query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (callable): Invoked to handle incoming queries of this type
handler is invoked as:
result = handler(args)
where 'args' is a dict mapping strings to strings of the query
arguments. It should return a Deferred that will eventually yield an
object to encode as JSON.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@log_function
def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
TODO: Figure out when we should actually resolve the deferred.
Args:
pdu (Pdu): The new Pdu.
Returns:
Deferred: Completes when we have successfully processed the PDU
and replicated it to any interested remote home servers.
"""
order = self._order
self._order += 1
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, destinations, order)
logger.debug(
"[%s] transaction_layer.enqueue_pdu... done",
pdu.event_id
)
@log_function
def send_edu(self, destination, edu_type, content):
edu = Edu(
origin=self.server_name,
destination=destination,
edu_type=edu_type,
content=content,
)
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu)
return defer.succeed(None)
@log_function
def send_failure(self, failure, destination):
self._transaction_queue.enqueue_failure(failure, destination)
return defer.succeed(None)
@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=True):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
Args:
destination (str): Domain name of the remote homeserver
query_type (str): Category of the query type; should match the
handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details
of the query request.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
)
@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
Args:
dest (str): The remote home server to ask.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus
we have seen from the context
Returns:
Deferred: Results in the received PDUs.
"""
logger.debug("backfill extrem=%s", extremities)
# If there are no extremeties then we've (probably) reached the start.
if not extremities:
return
transaction_data = yield self.transport_layer.backfill(
dest, context, extremities, limit)
logger.debug("backfill transaction_data=%s", repr(transaction_data))
transaction = Transaction(**transaction_data)
pdus = [
self.event_from_pdu_json(p, outlier=False)
for p in transaction.pdus
]
for pdu in pdus:
yield self._handle_new_pdu(dest, pdu, backfilled=True)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destination, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
server.
This will persist the PDU locally upon receipt.
Args:
destination (str): Which home server to query
pdu_origin (str): The home server that originally sent the pdu.
event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
Returns:
Deferred: Results in the requested PDU.
"""
transaction_data = yield self.transport_layer.get_event(
destination, event_id
)
transaction = Transaction(**transaction_data)
pdu_list = [
self.event_from_pdu_json(p, outlier=outlier)
for p in transaction.pdus
]
pdu = None
if pdu_list:
pdu = pdu_list[0]
yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
def get_state_for_room(self, destination, room_id, event_id):
"""Requests all of the `current` state PDUs for a given room from
a remote home server.
Args:
destination (str): The remote homeserver to query for the state.
room_id (str): The id of the room we're interested in.
event_id (str): The id of the event we want the state at.
Returns:
Deferred: Results in a list of PDUs.
"""
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
)
pdus = [
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
defer.returnValue((pdus, auth_chain))
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
destination, room_id, event_id,
)
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"]
]
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue(auth_chain)
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
pdus = yield self.handler.on_backfill_request(
origin, room_id, versions, limit
)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
transaction = Transaction(**transaction_data)
for p in transaction.pdus:
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
pdu_list = [
self.event_from_pdu_json(p) for p in transaction.pdus
]
logger.debug("[%s] Got transaction", transaction.transaction_id)
response = yield self.transaction_actions.have_responded(transaction)
if response:
logger.debug("[%s] We've already responed to this request",
transaction.transaction_id)
defer.returnValue(response)
return
logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext():
dl = []
for pdu in pdu_list:
dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
results = yield defer.DeferredList(dl)
ret = []
for r in results:
if r[0]:
ret.append({})
else:
logger.exception(r[1])
ret.append({"error": str(r[1])})
logger.debug("Returning: %s", str(ret))
yield self.transaction_actions.set_response(
transaction,
200, response
)
defer.returnValue((200, response))
def received_edu(self, origin, edu_type, content):
if edu_type in self.edu_handlers:
self.edu_handlers[edu_type](origin, content)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, origin, room_id, event_id):
if event_id:
pdus = yield self.handler.get_state_for_pdu(
origin, room_id, event_id,
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
else:
raise NotImplementedError("Specify an event")
defer.returnValue((200, {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}))
@defer.inlineCallbacks
@log_function
def on_pdu_request(self, origin, event_id):
pdu = yield self._get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
(200, self._transaction_from_pdus([pdu]).get_dict())
)
else:
defer.returnValue((404, ""))
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
raise NotImplementedError("Pull transactions not implemented")
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
if query_type in self.query_handlers:
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type,))
)
@defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id):
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
pdu = self.event_from_pdu_json(content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
],
}))
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
defer.returnValue((200, {
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
}))
@defer.inlineCallbacks
def make_join(self, destination, room_id, user_id):
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
)
pdu_dict = ret["event"]
logger.debug("Got response to make_join: %s", pdu_dict)
defer.returnValue(self.event_from_pdu_json(pdu_dict))
@defer.inlineCallbacks
def send_join(self, destination, pdu):
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
state = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue({
"state": state,
"auth_chain": auth_chain,
})
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
time_now = self._clock.time_msec()
code, content = yield self.transport_layer.send_invite(
destination=destination,
room_id=room_id,
event_id=event_id,
content=pdu.get_pdu_json(time_now),
)
pdu_dict = content["event"]
logger.debug("Got response to send_invite: %s", pdu_dict)
defer.returnValue(self.event_from_pdu_json(pdu_dict))
@log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
return self.handler.get_persisted_pdu(
origin, event_id, do_auth=do_auth
)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
time_now = self._clock.time_msec()
pdus = [p.get_pdu_json(time_now) for p in pdu_list]
return Transaction(
origin=self.server_name,
pdus=pdus,
origin_server_ts=int(time_now),
destination=None,
)
@defer.inlineCallbacks
@log_function
def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False
)
already_seen = (
existing and (
not existing.internal_metadata.is_outlier()
or pdu.internal_metadata.is_outlier()
)
)
if already_seen:
logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return
state = None
auth_chain = []
# We need to make sure we have all the auth events.
# for e_id, _ in pdu.auth_events:
# exists = yield self._get_persisted_pdu(
# origin,
# e_id,
# do_auth=False
# )
#
# if not exists:
# try:
# logger.debug(
# "_handle_new_pdu fetch missing auth event %s from %s",
# e_id,
# origin,
# )
#
# yield self.get_pdu(
# origin,
# event_id=e_id,
# outlier=True,
# )
#
# logger.debug("Processed pdu %s", e_id)
# except:
# logger.warn(
# "Failed to get auth event %s from %s",
# e_id,
# origin
# )
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id
)
logger.debug(
"_handle_new_pdu min_depth for %s: %d",
pdu.room_id, min_depth
)
if min_depth and pdu.depth > min_depth:
for event_id, hashes in pdu.prev_events:
exists = yield self._get_persisted_pdu(
origin,
event_id,
do_auth=False
)
if not exists:
logger.debug(
"_handle_new_pdu requesting pdu %s",
event_id
)
try:
yield self.get_pdu(
origin,
event_id=event_id,
)
logger.debug("Processed pdu %s", event_id)
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
else:
# We need to get the state at this event, since we have reached
# a backward extremity edge.
logger.debug(
"_handle_new_pdu getting state for %s",
pdu.room_id
)
state, auth_chain = yield self.get_state_for_room(
origin, pdu.room_id, pdu.event_id,
)
if not backfilled:
ret = yield self.handler.on_receive_pdu(
origin,
pdu,
backfilled=backfilled,
state=state,
auth_chain=auth_chain,
)
else:
ret = None
# yield self.pdu_actions.mark_as_processed(pdu)
defer.returnValue(ret)
def __str__(self): def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event
class _TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
a time for a given destination.
It batches pending PDUs into single transactions.
"""
def __init__(self, hs, transaction_actions, transport_layer):
self.server_name = hs.hostname
self.transaction_actions = transaction_actions
self.transport_layer = transport_layer
self._clock = hs.get_clock()
self.store = hs.get_datastore()
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
# done
self.pending_transactions = {}
# Is a mapping from destination -> list of
# tuple(pending pdus, deferred, order)
self.pending_pdus_by_dest = {}
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = {}
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
@defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
destinations = set(destinations)
destinations.discard(self.server_name)
destinations.discard("localhost")
logger.debug("Sending to: %s", str(destinations))
if not destinations:
return
deferreds = []
for destination in destinations:
deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, deferred, order)
)
def eb(failure):
if not deferred.called:
deferred.errback(failure)
else:
logger.warn("Failed to send pdu", failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb)
deferreds.append(deferred)
yield defer.DeferredList(deferreds)
# NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
if destination == self.server_name:
return
deferred = defer.Deferred()
self.pending_edus_by_dest.setdefault(destination, []).append(
(edu, deferred)
)
def eb(failure):
if not deferred.called:
deferred.errback(failure)
else:
logger.warn("Failed to send edu", failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb)
return deferred
@defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
deferred = defer.Deferred()
self.pending_failures_by_dest.setdefault(
destination, []
).append(
(failure, deferred)
)
yield deferred
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
(retry_last_ts, retry_interval) = (0, 0)
retry_timings = yield self.store.get_destination_retry_timings(
destination
)
if retry_timings:
(retry_last_ts, retry_interval) = (
retry_timings.retry_last_ts, retry_timings.retry_interval
)
if retry_last_ts + retry_interval > int(self._clock.time_msec()):
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
return
else:
logger.info("TX [%s] is ready for retry", destination)
logger.info("TX [%s] _attempt_new_transaction", destination)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
return
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
return
logger.debug(
"TX [%s] Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
try:
self.pending_transactions[destination] = 1
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] Sending transaction [%s]",
destination,
transaction.transaction_id,
)
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self._clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
code, response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
logger.info("TX [%s] got %d response", destination, code)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)
yield self.transaction_actions.delivered(
transaction, code, response
)
logger.debug("TX [%s] Marked as delivered", destination)
logger.debug("TX [%s] Yielding to callbacks...", destination)
for deferred in deferreds:
if code == 200:
if retry_last_ts:
# this host is alive! reset retry schedule
yield self.store.set_destination_retry_timings(
destination, 0, 0
)
deferred.callback(None)
else:
self.set_retrying(destination, retry_interval)
deferred.errback(RuntimeError("Got status %d" % code))
# Ensures we don't continue until all callbacks on that
# deferred have fired
try:
yield deferred
except:
pass
logger.debug("TX [%s] Yielded to callbacks", destination)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
self.set_retrying(destination, retry_interval)
for deferred in deferreds:
if not deferred.called:
deferred.errback(e)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
# Check to see if there is anything else to send.
self._attempt_new_transaction(destination)
@defer.inlineCallbacks
def set_retrying(self, destination, retry_interval):
# track that this destination is having problems and we should
# give it a chance to recover before trying it again
if retry_interval:
retry_interval *= 2
# plateau at hourly retries for now
if retry_interval >= 60 * 60 * 1000:
retry_interval = 60 * 60 * 1000
else:
retry_interval = 2000 # try again at first after 2 seconds
yield self.store.set_destination_retry_timings(
destination,
int(self._clock.time_msec()),
retry_interval
)

View File

@ -0,0 +1,317 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .persistence import TransactionActions
from .units import Transaction
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
import logging
logger = logging.getLogger(__name__)
class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
a time for a given destination.
It batches pending PDUs into single transactions.
"""
def __init__(self, hs, transport_layer):
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.transaction_actions = TransactionActions(self.store)
self.transport_layer = transport_layer
self._clock = hs.get_clock()
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
# done
self.pending_transactions = {}
# Is a mapping from destination -> list of
# tuple(pending pdus, deferred, order)
self.pending_pdus_by_dest = {}
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = {}
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
@defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
destinations = set(destinations)
destinations.discard(self.server_name)
destinations.discard("localhost")
logger.debug("Sending to: %s", str(destinations))
if not destinations:
return
deferreds = []
for destination in destinations:
deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, deferred, order)
)
def eb(failure):
if not deferred.called:
deferred.errback(failure)
else:
logger.warn("Failed to send pdu", failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb)
deferreds.append(deferred)
yield defer.DeferredList(deferreds)
# NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
if destination == self.server_name:
return
deferred = defer.Deferred()
self.pending_edus_by_dest.setdefault(destination, []).append(
(edu, deferred)
)
def eb(failure):
if not deferred.called:
deferred.errback(failure)
else:
logger.warn("Failed to send edu", failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb)
return deferred
@defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
deferred = defer.Deferred()
self.pending_failures_by_dest.setdefault(
destination, []
).append(
(failure, deferred)
)
yield deferred
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
(retry_last_ts, retry_interval) = (0, 0)
retry_timings = yield self.store.get_destination_retry_timings(
destination
)
if retry_timings:
(retry_last_ts, retry_interval) = (
retry_timings.retry_last_ts, retry_timings.retry_interval
)
if retry_last_ts + retry_interval > int(self._clock.time_msec()):
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
return
else:
logger.info("TX [%s] is ready for retry", destination)
logger.info("TX [%s] _attempt_new_transaction", destination)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
return
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
return
logger.debug(
"TX [%s] Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
try:
self.pending_transactions[destination] = 1
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] Sending transaction [%s]",
destination,
transaction.transaction_id,
)
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self._clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
code, response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
logger.info("TX [%s] got %d response", destination, code)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)
yield self.transaction_actions.delivered(
transaction, code, response
)
logger.debug("TX [%s] Marked as delivered", destination)
logger.debug("TX [%s] Yielding to callbacks...", destination)
for deferred in deferreds:
if code == 200:
if retry_last_ts:
# this host is alive! reset retry schedule
yield self.store.set_destination_retry_timings(
destination, 0, 0
)
deferred.callback(None)
else:
self.set_retrying(destination, retry_interval)
deferred.errback(RuntimeError("Got status %d" % code))
# Ensures we don't continue until all callbacks on that
# deferred have fired
try:
yield deferred
except:
pass
logger.debug("TX [%s] Yielded to callbacks", destination)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
self.set_retrying(destination, retry_interval)
for deferred in deferreds:
if not deferred.called:
deferred.errback(e)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
# Check to see if there is anything else to send.
self._attempt_new_transaction(destination)
@defer.inlineCallbacks
def set_retrying(self, destination, retry_interval):
# track that this destination is having problems and we should
# give it a chance to recover before trying it again
if retry_interval:
retry_interval *= 2
# plateau at hourly retries for now
if retry_interval >= 60 * 60 * 1000:
retry_interval = 60 * 60 * 1000
else:
retry_interval = 2000 # try again at first after 2 seconds
yield self.store.set_destination_retry_timings(
destination,
int(self._clock.time_msec()),
retry_interval
)

View File

@ -213,3 +213,19 @@ class TransportLayerClient(object):
) )
defer.returnValue(response) defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def send_query_auth(self, destination, room_id, event_id, content):
path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
code, content = yield self.client.post_json(
destination=destination,
path=path,
data=content,
)
if not 200 <= code < 300:
raise RuntimeError("Got %d from send_invite", code)
defer.returnValue(json.loads(content))

View File

@ -42,7 +42,7 @@ class TransportLayerServer(object):
content = None content = None
origin = None origin = None
if request.method == "PUT": if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types? # TODO: Handle other method types? other content types?
try: try:
content_bytes = request.content.read() content_bytes = request.content.read()
@ -234,6 +234,16 @@ class TransportLayerServer(object):
) )
) )
) )
self.server.register_path(
"POST",
re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
self._on_query_auth_request(
origin, content, event_id,
)
)
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -325,3 +335,12 @@ class TransportLayerServer(object):
) )
defer.returnValue((200, content)) defer.returnValue((200, content))
@defer.inlineCallbacks
@log_function
def _on_query_auth_request(self, origin, content, event_id):
new_content = yield self.request_handler.on_query_auth_request(
origin, content, event_id
)
defer.returnValue((200, new_content))

View File

@ -27,6 +27,7 @@ from .directory import DirectoryHandler
from .typing import TypingNotificationHandler from .typing import TypingNotificationHandler
from .admin import AdminHandler from .admin import AdminHandler
from .appservice import ApplicationServicesHandler from .appservice import ApplicationServicesHandler
from .sync import SyncHandler
class Handlers(object): class Handlers(object):
@ -53,3 +54,4 @@ class Handlers(object):
self.typing_notification_handler = TypingNotificationHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.appservice_handler = ApplicationServicesHandler(hs) self.appservice_handler = ApplicationServicesHandler(hs)
self.sync_handler = SyncHandler(hs)

View File

@ -49,10 +49,11 @@ class EventStreamHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True): as_client_event=True, affect_presence=True):
auth_user = UserID.from_string(auth_user_id) auth_user = UserID.from_string(auth_user_id)
try: try:
if affect_presence:
if auth_user not in self._streams_per_user: if auth_user not in self._streams_per_user:
self._streams_per_user[auth_user] = 0 self._streams_per_user[auth_user] = 0
if auth_user in self._stop_timer_per_user: if auth_user in self._stop_timer_per_user:
@ -94,6 +95,7 @@ class EventStreamHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
finally: finally:
if affect_presence:
self._streams_per_user[auth_user] -= 1 self._streams_per_user[auth_user] -= 1
if not self._streams_per_user[auth_user]: if not self._streams_per_user[auth_user]:
del self._streams_per_user[auth_user] del self._streams_per_user[auth_user]
@ -107,7 +109,7 @@ class EventStreamHandler(BaseHandler):
self._stop_timer_per_user.pop(auth_user, None) self._stop_timer_per_user.pop(auth_user, None)
yield self.distributor.fire( return self.distributor.fire(
"stopped_user_eventstream", auth_user "stopped_user_eventstream", auth_user
) )

View File

@ -17,19 +17,16 @@
from ._base import BaseHandler from ._base import BaseHandler
from synapse.events.utils import prune_event
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, FederationError, SynapseError, StoreError, AuthError, FederationError, StoreError,
) )
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import ( from synapse.crypto.event_signing import (
compute_event_signature, check_event_content_hash, compute_event_signature, add_hashes_and_signatures,
add_hashes_and_signatures,
) )
from synapse.types import UserID from synapse.types import UserID
from syutil.jsonutil import encode_canonical_json
from twisted.internet import defer from twisted.internet import defer
@ -113,33 +110,6 @@ class FederationHandler(BaseHandler):
logger.debug("Processing event: %s", event.event_id) logger.debug("Processing event: %s", event.event_id)
redacted_event = prune_event(event)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
event.origin, redacted_pdu_json
)
except SynapseError as e:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
if not check_event_content_hash(event):
logger.warn(
"Event content has been tampered, redacting %s, %s",
event.event_id, encode_canonical_json(event.get_dict())
)
event = redacted_event
logger.debug("Event: %s", event) logger.debug("Event: %s", event)
# FIXME (erikj): Awful hack to make the case where we are not currently # FIXME (erikj): Awful hack to make the case where we are not currently
@ -149,41 +119,20 @@ class FederationHandler(BaseHandler):
event.room_id, event.room_id,
self.server_name self.server_name
) )
if not is_in_room and not event.internal_metadata.outlier: if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.") logger.debug("Got event for room we're not in.")
replication = self.replication_layer
if not state:
state, auth_chain = yield replication.get_state_for_room(
origin, context=event.room_id, event_id=event.event_id,
)
if not auth_chain:
auth_chain = yield replication.get_event_auth(
origin,
context=event.room_id,
event_id=event.event_id,
)
for e in auth_chain:
e.internal_metadata.outlier = True
try:
yield self._handle_new_event(e, fetch_auth_from=origin)
except:
logger.exception(
"Failed to handle auth event %s",
e.event_id,
)
current_state = state current_state = state
if state: if state and auth_chain is not None:
for e in state: for e in state:
logging.info("A :) %r", e)
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e) auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(origin, e, auth_events=auth)
except: except:
logger.exception( logger.exception(
"Failed to handle state event %s", "Failed to handle state event %s",
@ -192,6 +141,7 @@ class FederationHandler(BaseHandler):
try: try:
yield self._handle_new_event( yield self._handle_new_event(
origin,
event, event,
state=state, state=state,
backfilled=backfilled, backfilled=backfilled,
@ -393,8 +343,19 @@ class FederationHandler(BaseHandler):
for e in auth_chain: for e in auth_chain:
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
if e.event_id == event.event_id:
continue
try: try:
yield self._handle_new_event(e) auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
target_host, e, auth_events=auth
)
except: except:
logger.exception( logger.exception(
"Failed to handle auth event %s", "Failed to handle auth event %s",
@ -402,11 +363,18 @@ class FederationHandler(BaseHandler):
) )
for e in state: for e in state:
# FIXME: Auth these. if e.event_id == event.event_id:
continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event( yield self._handle_new_event(
e, fetch_auth_from=target_host target_host, e, auth_events=auth
) )
except: except:
logger.exception( logger.exception(
@ -414,10 +382,18 @@ class FederationHandler(BaseHandler):
e.event_id, e.event_id,
) )
auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event( yield self._handle_new_event(
target_host,
new_event, new_event,
state=state, state=state,
current_state=state, current_state=state,
auth_events=auth_events,
) )
yield self.notifier.on_new_room_event( yield self.notifier.on_new_room_event(
@ -481,7 +457,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False event.internal_metadata.outlier = False
context = yield self._handle_new_event(event) context = yield self._handle_new_event(origin, event)
logger.debug( logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s", "on_send_join_request: After _handle_new_event: %s, sigs: %s",
@ -682,11 +658,12 @@ class FederationHandler(BaseHandler):
waiters.pop().callback(None) waiters.pop().callback(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False, @log_function
current_state=None, fetch_auth_from=None): def _handle_new_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
logger.debug( logger.debug(
"_handle_new_event: Before annotate: %s, sigs: %s", "_handle_new_event: %s, sigs: %s",
event.event_id, event.signatures, event.event_id, event.signatures,
) )
@ -694,65 +671,44 @@ class FederationHandler(BaseHandler):
event, old_state=state event, old_state=state
) )
if not auth_events:
auth_events = context.auth_events
logger.debug( logger.debug(
"_handle_new_event: Before auth fetch: %s, sigs: %s", "_handle_new_event: %s, auth_events: %s",
event.event_id, event.signatures, event.event_id, auth_events,
) )
is_new_state = not event.internal_metadata.is_outlier() is_new_state = not event.internal_metadata.is_outlier()
known_ids = set( # This is a hack to fix some old rooms where the initial join event
[s.event_id for s in context.auth_events.values()] # didn't reference the create event in its auth events.
)
for e_id, _ in event.auth_events:
if e_id not in known_ids:
e = yield self.store.get_event(e_id, allow_none=True)
if not e and fetch_auth_from is not None:
# Grab the auth_chain over federation if we are missing
# auth events.
auth_chain = yield self.replication_layer.get_event_auth(
fetch_auth_from, event.event_id, event.room_id
)
for auth_event in auth_chain:
yield self._handle_new_event(auth_event)
e = yield self.store.get_event(e_id, allow_none=True)
if not e:
# TODO: Do some conflict res to make sure that we're
# not the ones who are wrong.
logger.info(
"Rejecting %s as %s not in db or %s",
event.event_id, e_id, known_ids,
)
# FIXME: How does raising AuthError work with federation?
raise AuthError(403, "Cannot find auth event")
context.auth_events[(e.type, e.state_key)] = e
logger.debug(
"_handle_new_event: Before hack: %s, sigs: %s",
event.event_id, event.signatures,
)
if event.type == EventTypes.Member and not event.auth_events: if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1: if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0]) c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create: if c.type == EventTypes.Create:
context.auth_events[(c.type, c.state_key)] = c auth_events[(c.type, c.state_key)] = c
logger.debug( try:
"_handle_new_event: Before auth check: %s, sigs: %s", yield self.do_auth(
event.event_id, event.signatures, origin, event, context, auth_events=auth_events
)
except AuthError as e:
logger.warn(
"Rejecting %s because %s",
event.event_id, e.msg
) )
self.auth.check(event, auth_events=context.auth_events) context.rejected = RejectedReason.AUTH_ERROR
logger.debug( yield self.store.persist_event(
"_handle_new_event: Before persist_event: %s, sigs: %s", event,
event.event_id, event.signatures, context=context,
backfilled=backfilled,
is_new_state=False,
current_state=current_state,
) )
raise
yield self.store.persist_event( yield self.store.persist_event(
event, event,
@ -762,9 +718,294 @@ class FederationHandler(BaseHandler):
current_state=current_state, current_state=current_state,
) )
logger.debug( defer.returnValue(context)
"_handle_new_event: After persist_event: %s, sigs: %s",
event.event_id, event.signatures, @defer.inlineCallbacks
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
missing):
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
for e in remote_auth_chain:
try:
yield self._handle_new_event(origin, e)
except AuthError:
pass
# Now get the current auth_chain for the event.
local_auth_chain = yield self.store.get_auth_chain([event_id])
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
ret = yield self.construct_auth_difference(
local_auth_chain, remote_auth_chain
) )
defer.returnValue(context) for event in ret["auth_chain"]:
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
logger.debug("on_query_auth reutrning: %s", ret)
defer.returnValue(ret)
@defer.inlineCallbacks
@log_function
def do_auth(self, origin, event, context, auth_events):
# Check if we have all the auth events.
res = yield self.store.have_events(
[e_id for e_id, _ in event.auth_events]
)
event_auth_events = set(e_id for e_id, _ in event.auth_events)
seen_events = set(res.keys())
missing_auth = event_auth_events - seen_events
if missing_auth:
logger.debug("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them.
remote_auth_chain = yield self.replication_layer.get_event_auth(
origin, event.room_id, event.event_id
)
seen_remotes = yield self.store.have_events(
[e.event_id for e in remote_auth_chain]
)
for e in remote_auth_chain:
if e.event_id in seen_remotes.keys():
continue
if e.event_id == event.event_id:
continue
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in remote_auth_chain
if e.event_id in auth_ids
}
e.internal_metadata.outlier = True
logger.debug(
"do_auth %s missing_auth: %s",
event.event_id, e.event_id
)
yield self._handle_new_event(
origin, e, auth_events=auth
)
if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e
except AuthError:
pass
# FIXME: Assumes we have and stored all the state for all the
# prev_events
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
if different_auth and not event.internal_metadata.is_outlier():
# Do auth conflict res.
logger.debug("Different auth: %s", different_auth)
# 1. Get what we think is the auth chain.
auth_ids = self.auth.compute_auth_events(event, context)
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
# 2. Get remote difference.
result = yield self.replication_layer.query_auth(
origin,
event.room_id,
event.event_id,
local_auth_chain,
)
seen_remotes = yield self.store.have_events(
[e.event_id for e in result["auth_chain"]]
)
# 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]:
if ev.event_id in seen_remotes.keys():
continue
if ev.event_id == event.event_id:
continue
try:
auth_ids = [e_id for e_id, _ in ev.auth_events]
auth = {
(e.type, e.state_key): e for e in result["auth_chain"]
if e.event_id in auth_ids
}
ev.internal_metadata.outlier = True
logger.debug(
"do_auth %s different_auth: %s",
event.event_id, e.event_id
)
yield self._handle_new_event(
origin, ev, auth_events=auth
)
if ev.event_id in event_auth_events:
auth_events[(ev.type, ev.state_key)] = ev
except AuthError:
pass
# 4. Look at rejects and their proofs.
# TODO.
context.current_state.update(auth_events)
context.state_group = None
try:
self.auth.check(event, auth_events=auth_events)
except AuthError:
raise
@defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth):
""" Given a local and remote auth chain, find the differences. This
assumes that we have already processed all events in remote_auth
Params:
local_auth (list)
remote_auth (list)
Returns:
dict
"""
logger.debug("construct_auth_difference Start!")
# TODO: Make sure we are OK with local_auth or remote_auth having more
# auth events in them than strictly necessary.
def sort_fun(ev):
return ev.depth, ev.event_id
logger.debug("construct_auth_difference after sort_fun!")
# We find the differences by starting at the "bottom" of each list
# and iterating up on both lists. The lists are ordered by depth and
# then event_id, we iterate up both lists until we find the event ids
# don't match. Then we look at depth/event_id to see which side is
# missing that event, and iterate only up that list. Repeat.
remote_list = list(remote_auth)
remote_list.sort(key=sort_fun)
local_list = list(local_auth)
local_list.sort(key=sort_fun)
local_iter = iter(local_list)
remote_iter = iter(remote_list)
logger.debug("construct_auth_difference before get_next!")
def get_next(it, opt=None):
try:
return it.next()
except:
return opt
current_local = get_next(local_iter)
current_remote = get_next(remote_iter)
logger.debug("construct_auth_difference before while")
missing_remotes = []
missing_locals = []
while current_local or current_remote:
if current_remote is None:
missing_locals.append(current_local)
current_local = get_next(local_iter)
continue
if current_local is None:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
if current_local.event_id == current_remote.event_id:
current_local = get_next(local_iter)
current_remote = get_next(remote_iter)
continue
if current_local.depth < current_remote.depth:
missing_locals.append(current_local)
current_local = get_next(local_iter)
continue
if current_local.depth > current_remote.depth:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
# They have the same depth, so we fall back to the event_id order
if current_local.event_id < current_remote.event_id:
missing_locals.append(current_local)
current_local = get_next(local_iter)
if current_local.event_id > current_remote.event_id:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
logger.debug("construct_auth_difference after while")
# missing locals should be sent to the server
# We should find why we are missing remotes, as they will have been
# rejected.
# Remove events from missing_remotes if they are referencing a missing
# remote. We only care about the "root" rejected ones.
missing_remote_ids = [e.event_id for e in missing_remotes]
base_remote_rejected = list(missing_remotes)
for e in missing_remotes:
for e_id, _ in e.auth_events:
if e_id in missing_remote_ids:
base_remote_rejected.remove(e)
reason_map = {}
for e in base_remote_rejected:
reason = yield self.store.get_rejection_reason(e.event_id)
if reason is None:
# FIXME: ERRR?!
logger.warn("Could not find reason for %s", e.event_id)
raise RuntimeError("")
reason_map[e.event_id] = reason
if reason == RejectedReason.AUTH_ERROR:
pass
elif reason == RejectedReason.REPLACED:
# TODO: Get proof
pass
elif reason == RejectedReason.NOT_ANCESTOR:
# TODO: Get proof.
pass
logger.debug("construct_auth_difference returning")
defer.returnValue({
"auth_chain": local_auth,
"rejects": {
e.event_id: {
"reason": reason_map[e.event_id],
"proof": None,
}
for e in base_remote_rejected
},
"missing": [e.event_id for e in missing_locals],
})

View File

@ -114,7 +114,8 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True): def create_and_send_event(self, event_dict, ratelimit=True,
client=None, txn_id=None):
""" Given a dict from a client, create and handle a new event. """ Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events, Creates an FrozenEvent object, filling out auth_events, prev_events,
@ -148,6 +149,15 @@ class MessageHandler(BaseHandler):
builder.content builder.content
) )
if client is not None:
if client.token_id is not None:
builder.internal_metadata.token_id = client.token_id
if client.device_id is not None:
builder.internal_metadata.device_id = client.device_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self._create_new_client_event( event, context = yield self._create_new_client_event(
builder=builder, builder=builder,
) )

View File

@ -87,6 +87,10 @@ class PresenceHandler(BaseHandler):
"changed_presencelike_data", self.changed_presencelike_data "changed_presencelike_data", self.changed_presencelike_data
) )
# outbound signal from the presence module to advertise when a user's
# presence has changed
distributor.declare("user_presence_changed")
self.distributor = distributor self.distributor = distributor
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
@ -604,6 +608,7 @@ class PresenceHandler(BaseHandler):
room_ids=room_ids, room_ids=room_ids,
statuscache=statuscache, statuscache=statuscache,
) )
yield self.distributor.fire("user_presence_changed", user, statuscache)
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None): def _push_presence_remote(self, user, destination, state=None):

View File

@ -163,7 +163,7 @@ class RegistrationHandler(BaseHandler):
# each request # each request
httpCli = SimpleHttpClient(self.hs) httpCli = SimpleHttpClient(self.hs)
# XXX: make this configurable! # XXX: make this configurable!
trustedIdServers = ['matrix.org:8090'] trustedIdServers = ['matrix.org:8090', 'matrix.org']
if not creds['idServer'] in trustedIdServers: if not creds['idServer'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' + logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['idServer']) 'credentials', creds['idServer'])

434
synapse/handlers/sync.py Normal file
View File

@ -0,0 +1,434 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseHandler
from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes
from twisted.internet import defer
import collections
import logging
logger = logging.getLogger(__name__)
SyncConfig = collections.namedtuple("SyncConfig", [
"user",
"client_info",
"limit",
"gap",
"sort",
"backfill",
"filter",
])
class RoomSyncResult(collections.namedtuple("RoomSyncResult", [
"room_id",
"limited",
"published",
"events",
"state",
"prev_batch",
"ephemeral",
])):
__slots__ = []
def __nonzero__(self):
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(self.events or self.state or self.ephemeral)
class SyncResult(collections.namedtuple("SyncResult", [
"next_batch", # Token for the next sync
"private_user_data", # List of private events for the user.
"public_user_data", # List of public events for all users.
"rooms", # RoomSyncResult for each room.
])):
__slots__ = []
def __nonzero__(self):
"""Make the result appear empty if there are no updates. This is used
to tell if the notifier needs to wait for more events when polling for
events.
"""
return bool(
self.private_user_data or self.public_user_data or self.rooms
)
class SyncHandler(BaseHandler):
def __init__(self, hs):
super(SyncHandler, self).__init__(hs)
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0):
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
Returns:
A Deferred SyncResult.
"""
if timeout == 0 or since_token is None:
result = yield self.current_sync_for_user(sync_config, since_token)
defer.returnValue(result)
else:
def current_sync_callback():
return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
result = yield self.notifier.wait_for_events(
sync_config.user, room_ids,
sync_config.filter, timeout, current_sync_callback
)
defer.returnValue(result)
def current_sync_for_user(self, sync_config, since_token=None):
"""Get the sync for client needed to match what the server has now.
Returns:
A Deferred SyncResult.
"""
if since_token is None:
return self.initial_sync(sync_config)
else:
if sync_config.gap:
return self.incremental_sync_with_gap(sync_config, since_token)
else:
#TODO(mjark): Handle gapless sync
raise NotImplementedError()
@defer.inlineCallbacks
def initial_sync(self, sync_config):
"""Get a sync for a client which is starting without any state
Returns:
A Deferred SyncResult.
"""
if sync_config.sort == "timeline,desc":
# TODO(mjark): Handle going through events in reverse order?.
# What does "most recent events" mean when applying the limits mean
# in this case?
raise NotImplementedError()
now_token = yield self.event_sources.get_current_token()
presence_stream = self.event_sources.sources["presence"]
# TODO (mjark): This looks wrong, shouldn't we be getting the presence
# UP to the present rather than after the present?
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user=sync_config.user,
pagination_config=pagination_config.get_source_config("presence"),
key=None
)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=sync_config.user.to_string(),
membership_list=[Membership.INVITE, Membership.JOIN]
)
# TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True)
published_room_ids = set(r["room_id"] for r in published_rooms)
rooms = []
for event in room_list:
room_sync = yield self.initial_sync_for_room(
event.room_id, sync_config, now_token, published_room_ids
)
rooms.append(room_sync)
defer.returnValue(SyncResult(
public_user_data=presence,
private_user_data=[],
rooms=rooms,
next_batch=now_token,
))
@defer.inlineCallbacks
def initial_sync_for_room(self, room_id, sync_config, now_token,
published_room_ids):
"""Sync a room for a client which is starting without any state
Returns:
A Deferred RoomSyncResult.
"""
recents, prev_batch_token, limited = yield self.load_filtered_recents(
room_id, sync_config, now_token,
)
current_state_events = yield self.state_handler.get_current_state(
room_id
)
defer.returnValue(RoomSyncResult(
room_id=room_id,
published=room_id in published_room_ids,
events=recents,
prev_batch=prev_batch_token,
state=current_state_events,
limited=limited,
ephemeral=[],
))
@defer.inlineCallbacks
def incremental_sync_with_gap(self, sync_config, since_token):
""" Get the incremental delta needed to bring the client up to
date with the server.
Returns:
A Deferred SyncResult.
"""
if sync_config.sort == "timeline,desc":
# TODO(mjark): Handle going through events in reverse order?.
# What does "most recent events" mean when applying the limits mean
# in this case?
raise NotImplementedError()
now_token = yield self.event_sources.get_current_token()
presence_source = self.event_sources.sources["presence"]
presence, presence_key = yield presence_source.get_new_events_for_user(
user=sync_config.user,
from_key=since_token.presence_key,
limit=sync_config.limit,
)
now_token = now_token.copy_and_replace("presence_key", presence_key)
typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events_for_user(
user=sync_config.user,
from_key=since_token.typing_key,
limit=sync_config.limit,
)
now_token = now_token.copy_and_replace("typing_key", typing_key)
typing_by_room = {event["room_id"]: [event] for event in typing}
for event in typing:
event.pop("room_id")
logger.debug("Typing %r", typing_by_room)
rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
# TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True)
published_room_ids = set(r["room_id"] for r in published_rooms)
room_events, _ = yield self.store.get_room_events_stream(
sync_config.user.to_string(),
from_key=since_token.room_key,
to_key=now_token.room_key,
room_id=None,
limit=sync_config.limit + 1,
)
rooms = []
if len(room_events) <= sync_config.limit:
# There is no gap in any of the rooms. Therefore we can just
# partition the new events by room and return them.
events_by_room_id = {}
for event in room_events:
events_by_room_id.setdefault(event.room_id, []).append(event)
for room_id in room_ids:
recents = events_by_room_id.get(room_id, [])
state = [event for event in recents if event.is_state()]
if recents:
prev_batch = now_token.copy_and_replace(
"room_key", recents[0].internal_metadata.before
)
else:
prev_batch = now_token
state = yield self.check_joined_room(
sync_config, room_id, state
)
room_sync = RoomSyncResult(
room_id=room_id,
published=room_id in published_room_ids,
events=recents,
prev_batch=prev_batch,
state=state,
limited=False,
ephemeral=typing_by_room.get(room_id, [])
)
if room_sync:
rooms.append(room_sync)
else:
for room_id in room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token,
published_room_ids, typing_by_room
)
if room_sync:
rooms.append(room_sync)
defer.returnValue(SyncResult(
public_user_data=presence,
private_user_data=[],
rooms=rooms,
next_batch=now_token,
))
@defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None):
limited = True
recents = []
filtering_factor = 2
load_limit = max(sync_config.limit * filtering_factor, 100)
max_repeat = 3 # Only try a few times per room, otherwise
room_key = now_token.room_key
while limited and len(recents) < sync_config.limit and max_repeat:
events, keys = yield self.store.get_recent_events_for_room(
room_id,
limit=load_limit + 1,
from_token=since_token.room_key if since_token else None,
end_token=room_key,
)
(room_key, _) = keys
loaded_recents = sync_config.filter.filter_room_events(events)
loaded_recents.extend(recents)
recents = loaded_recents
if len(events) <= load_limit:
limited = False
max_repeat -= 1
if len(recents) > sync_config.limit:
recents = recents[-sync_config.limit:]
room_key = recents[0].internal_metadata.before
prev_batch_token = now_token.copy_and_replace(
"room_key", room_key
)
defer.returnValue((recents, prev_batch_token, limited))
@defer.inlineCallbacks
def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token,
published_room_ids, typing_by_room):
""" Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to
state.
Returns:
A Deferred RoomSyncResult
"""
# TODO(mjark): Check for redactions we might have missed.
recents, prev_batch_token, limited = yield self.load_filtered_recents(
room_id, sync_config, now_token, since_token,
)
logging.debug("Recents %r", recents)
# TODO(mjark): This seems racy since this isn't being passed a
# token to indicate what point in the stream this is
current_state_events = yield self.state_handler.get_current_state(
room_id
)
state_at_previous_sync = yield self.get_state_at_previous_sync(
room_id, since_token=since_token
)
state_events_delta = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=current_state_events,
)
state_events_delta = yield self.check_joined_room(
sync_config, room_id, state_events_delta
)
room_sync = RoomSyncResult(
room_id=room_id,
published=room_id in published_room_ids,
events=recents,
prev_batch=prev_batch_token,
state=state_events_delta,
limited=limited,
ephemeral=typing_by_room.get(room_id, [])
)
logging.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync)
@defer.inlineCallbacks
def get_state_at_previous_sync(self, room_id, since_token):
""" Get the room state at the previous sync the client made.
Returns:
A Deferred list of Events.
"""
last_events, token = yield self.store.get_recent_events_for_room(
room_id, end_token=since_token.room_key, limit=1,
)
if last_events:
last_event = last_events[0]
last_context = yield self.state_handler.compute_event_context(
last_event
)
if last_event.is_state():
state = [last_event] + last_context.current_state.values()
else:
state = last_context.current_state.values()
else:
state = ()
defer.returnValue(state)
def compute_state_delta(self, since_token, previous_state, current_state):
""" Works out the differnce in state between the current state and the
state the client got when it last performed a sync.
Returns:
A list of events.
"""
# TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state
# updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events.
previous_dict = {event.event_id: event for event in previous_state}
state_delta = []
for event in current_state:
if event.event_id not in previous_dict:
state_delta.append(event)
return state_delta
@defer.inlineCallbacks
def check_joined_room(self, sync_config, room_id, state_delta):
joined = False
for event in state_delta:
if (
event.type == EventTypes.Member
and event.state_key == sync_config.user.to_string()
):
if event.content["membership"] == Membership.JOIN:
joined = True
if joined:
state_delta = yield self.state_handler.get_current_state(room_id)
defer.returnValue(state_delta)

View File

@ -62,6 +62,25 @@ class SimpleHttpClient(object):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def post_json_get_json(self, uri, post_json):
json_str = json.dumps(post_json)
logger.info("HTTP POST %s -> %s", json_str, uri)
response = yield self.agent.request(
"POST",
uri.encode("ascii"),
headers=Headers({
"Content-Type": ["application/json"]
}),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, uri, args={}): def get_json(self, uri, args={}):
""" Get's some json from the given host and path """ Get's some json from the given host and path

View File

@ -244,6 +244,43 @@ class MatrixFederationHttpClient(object):
defer.returnValue((response.code, body)) defer.returnValue((response.code, body))
@defer.inlineCallbacks
def post_json(self, destination, path, data={}):
""" Sends the specifed json data using POST
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised.
"""
def body_callback(method, url_bytes, headers_dict):
self.sign_request(
destination, method, url_bytes, headers_dict, data
)
return _JsonProducer(data)
response = yield self._create_request(
destination.encode("ascii"),
"POST",
path.encode("ascii"),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
)
logger.debug("Getting resp body")
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue((response.code, body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True): def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
""" GETs some json from the given host homeserver and path """ GETs some json from the given host homeserver and path

View File

@ -16,7 +16,7 @@
from synapse.http.agent_name import AGENT_NAME from synapse.http.agent_name import AGENT_NAME
from synapse.api.errors import ( from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
) )
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
@ -139,11 +139,7 @@ class JsonResource(HttpServer, resource.Resource):
return return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400. # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
self._send_response( raise UnrecognizedRequestError()
request,
400,
{"error": "Unrecognized request"}
)
except CodeMessageException as e: except CodeMessageException as e:
if isinstance(e, SynapseError): if isinstance(e, SynapseError):
logger.info("%s SynapseError: %s - %s", request, e.code, e.msg) logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.types import StreamToken
import logging import logging
@ -205,6 +206,53 @@ class Notifier(object):
[notify(l).addErrback(eb) for l in listeners] [notify(l).addErrback(eb) for l in listeners]
) )
@defer.inlineCallbacks
def wait_for_events(self, user, rooms, filter, timeout, callback):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
deferred = defer.Deferred()
from_token = StreamToken("s0", "0", "0")
listener = [_NotificationListener(
user=user,
rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred,
)]
if timeout:
self._register_with_keys(listener[0])
result = yield callback()
if timeout:
timed_out = [False]
def _timeout_listener():
timed_out[0] = True
listener[0].notify(self, [], from_token, from_token)
self.clock.call_later(timeout/1000., _timeout_listener)
while not result and not timed_out[0]:
yield deferred
deferred = defer.Deferred()
listener[0] = _NotificationListener(
user=user,
rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred,
)
self._register_with_keys(listener[0])
result = yield callback()
defer.returnValue(result)
def get_events_for(self, user, rooms, pagination_config, timeout): def get_events_for(self, user, rooms, pagination_config, timeout):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any

410
synapse/push/__init__.py Normal file
View File

@ -0,0 +1,410 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
import synapse.util.async
import baserules
import logging
import fnmatch
import json
import re
logger = logging.getLogger(__name__)
class Pusher(object):
INITIAL_BACKOFF = 1000
MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000
DEFAULT_ACTIONS = ['notify']
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def __init__(self, _hs, instance_handle, user_name, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
self.hs = _hs
self.evStreamHandler = self.hs.get_handlers().event_stream_handler
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.instance_handle = instance_handle
self.user_name = user_name
self.app_id = app_id
self.app_display_name = app_display_name
self.device_display_name = device_display_name
self.pushkey = pushkey
self.pushkey_ts = pushkey_ts
self.data = data
self.last_token = last_token
self.last_success = last_success # not actually used
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.failing_since = failing_since
self.alive = True
# The last value of last_active_time that we saw
self.last_last_active_time = 0
self.has_unread = True
@defer.inlineCallbacks
def _actions_for_event(self, ev):
"""
This should take into account notification settings that the user
has configured both globally and per-room when we have the ability
to do such things.
"""
if ev['user_id'] == self.user_name:
# let's assume you probably know about messages you sent yourself
defer.returnValue(['dont_notify'])
if ev['type'] == 'm.room.member':
if ev['state_key'] != self.user_name:
defer.returnValue(['dont_notify'])
rules = yield self.store.get_push_rules_for_user_name(self.user_name)
for r in rules:
r['conditions'] = json.loads(r['conditions'])
r['actions'] = json.loads(r['actions'])
user_name_localpart = UserID.from_string(self.user_name).localpart
rules.extend(baserules.make_base_rules(user_name_localpart))
# get *our* member event for display name matching
member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=None
)
my_display_name = None
room_member_count = 0
for mev in member_events_for_room:
if mev.content['membership'] != 'join':
continue
# This loop does two things:
# 1) Find our current display name
if mev.state_key == self.user_name and 'displayname' in mev.content:
my_display_name = mev.content['displayname']
# and 2) Get the number of people in that room
room_member_count += 1
for r in rules:
matches = True
conditions = r['conditions']
actions = r['actions']
for c in conditions:
matches &= self._event_fulfills_condition(
ev, c, display_name=my_display_name,
room_member_count=room_member_count
)
# ignore rules with no actions (we have an explict 'dont_notify'
if len(actions) == 0:
logger.warn(
"Ignoring rule id %s with no actions for user %s" %
(r['rule_id'], r['user_name'])
)
continue
if matches:
defer.returnValue(actions)
defer.returnValue(Pusher.DEFAULT_ACTIONS)
def _event_fulfills_condition(self, ev, condition, display_name, room_member_count):
if condition['kind'] == 'event_match':
if 'pattern' not in condition:
logger.warn("event_match condition with no pattern")
return False
pat = condition['pattern']
if pat.strip("*?[]") == pat:
# no special glob characters so we assume the user means
# 'contains this string' rather than 'is this string'
pat = "*%s*" % (pat,)
val = _value_for_dotted_key(condition['key'], ev)
if val is None:
return False
return fnmatch.fnmatch(val.upper(), pat.upper())
elif condition['kind'] == 'device':
if 'instance_handle' not in condition:
return True
return condition['instance_handle'] == self.instance_handle
elif condition['kind'] == 'contains_display_name':
# This is special because display names can be different
# between rooms and so you can't really hard code it in a rule.
# Optimisation: we should cache these names and update them from
# the event stream.
if 'content' not in ev or 'body' not in ev['content']:
return False
if not display_name:
return False
return fnmatch.fnmatch(
ev['content']['body'].upper(), "*%s*" % (display_name.upper(),)
)
elif condition['kind'] == 'room_member_count':
if 'is' not in condition:
return False
m = Pusher.INEQUALITY_EXPR.match(condition['is'])
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs = int(rhs)
if ineq == '' or ineq == '==':
return room_member_count == rhs
elif ineq == '<':
return room_member_count < rhs
elif ineq == '>':
return room_member_count > rhs
elif ineq == '>=':
return room_member_count >= rhs
elif ineq == '<=':
return room_member_count <= rhs
else:
return False
else:
return True
@defer.inlineCallbacks
def get_context_for_event(self, ev):
name_aliases = yield self.store.get_room_name_and_aliases(
ev['room_id']
)
ctx = {'aliases': name_aliases[1]}
if name_aliases[0] is not None:
ctx['name'] = name_aliases[0]
their_member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=ev['user_id']
)
for mev in their_member_events_for_room:
if mev.content['membership'] == 'join' and 'displayname' in mev.content:
dn = mev.content['displayname']
if dn is not None:
ctx['sender_display_name'] = dn
defer.returnValue(ctx)
@defer.inlineCallbacks
def start(self):
if not self.last_token:
# First-time setup: get a token to start from (we can't
# just start from no token, ie. 'now'
# because we need the result to be reproduceable in case
# we fail to dispatch the push)
config = PaginationConfig(from_token=None, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, timeout=0)
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.user_name, self.pushkey, self.last_token)
logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token)
while self.alive:
from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config,
timeout=100*365*24*60*60*1000, affect_presence=False
)
# limiting to 1 may get 1 event plus 1 presence event, so
# pick out the actual event
single_event = None
for c in chunk['chunk']:
if 'event_id' in c: # Hmmm...
single_event = c
break
if not single_event:
self.last_token = chunk['end']
continue
if not self.alive:
continue
processed = False
actions = yield self._actions_for_event(single_event)
tweaks = _tweaks_for_actions(actions)
if len(actions) == 0:
logger.warn("Empty actions! Using default action.")
actions = Pusher.DEFAULT_ACTIONS
if 'notify' not in actions and 'dont_notify' not in actions:
logger.warn("Neither notify nor dont_notify in actions: adding default")
actions.extend(Pusher.DEFAULT_ACTIONS)
if 'dont_notify' in actions:
logger.debug(
"%s for %s: dont_notify",
single_event['event_id'], self.user_name
)
processed = True
else:
rejected = yield self.dispatch_push(single_event, tweaks)
self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
("Ignoring rejected pushkey %s because we"
" didn't send it"), pk
)
else:
logger.info(
"Pushkey %s was rejected: removing",
pk
)
yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk
)
if not self.alive:
continue
if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token_and_success(
self.user_name,
self.pushkey,
self.last_token,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
self.store.update_pusher_failing_since(
self.user_name,
self.pushkey,
self.failing_since)
else:
if not self.failing_since:
self.failing_since = self.clock.time_msec()
self.store.update_pusher_failing_since(
self.user_name,
self.pushkey,
self.failing_since
)
if (self.failing_since and
self.failing_since <
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_name, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.user_name,
self.pushkey,
self.last_token
)
self.failing_since = None
self.store.update_pusher_failing_since(
self.user_name,
self.pushkey,
self.failing_since
)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
self.user_name,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
self.backoff_delay *= 2
if self.backoff_delay > Pusher.MAX_BACKOFF:
self.backoff_delay = Pusher.MAX_BACKOFF
def stop(self):
self.alive = False
def dispatch_push(self, p, tweaks):
"""
Overridden by implementing classes to actually deliver the notification
Args:
p: The event to notify for as a single event from the event stream
Returns: If the notification was delivered, an array containing any
pushkeys that were rejected by the push gateway.
False if the notification could not be delivered (ie.
should be retried).
"""
pass
def reset_badge_count(self):
pass
def presence_changed(self, state):
"""
We clear badge counts whenever a user's last_active time is bumped
This is by no means perfect but I think it's the best we can do
without read receipts.
"""
if 'last_active' in state.state:
last_active = state.state['last_active']
if last_active > self.last_last_active_time:
self.last_last_active_time = last_active
if self.has_unread:
logger.info("Resetting badge count for %s", self.user_name)
self.reset_badge_count()
self.has_unread = False
def _value_for_dotted_key(dotted_key, event):
parts = dotted_key.split(".")
val = event
while len(parts) > 0:
if parts[0] not in val:
return None
val = val[parts[0]]
parts = parts[1:]
return val
def _tweaks_for_actions(actions):
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if 'set_sound' in a:
tweaks['sound'] = a['set_sound']
return tweaks
class PusherConfigException(Exception):
def __init__(self, msg):
super(PusherConfigException, self).__init__(msg)

48
synapse/push/baserules.py Normal file
View File

@ -0,0 +1,48 @@
def make_base_rules(user_name):
rules = [
{
'conditions': [
{
'kind': 'event_match',
'key': 'content.body',
'pattern': '*%s*' % (user_name,), # Matrix ID match
}
],
'actions': [
'notify',
{
'set_sound': 'default'
}
]
},
{
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_sound': 'default'
}
]
},
{
'conditions': [
{
'kind': 'room_member_count',
'is': '2'
}
],
'actions': [
'notify',
{
'set_sound': 'default'
}
]
}
]
for r in rules:
r['priority_class'] = 0
return rules

146
synapse/push/httppusher.py Normal file
View File

@ -0,0 +1,146 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push import Pusher, PusherConfigException
from synapse.http.client import SimpleHttpClient
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class HttpPusher(Pusher):
def __init__(self, _hs, instance_handle, user_name, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__(
_hs,
instance_handle,
user_name,
app_id,
app_display_name,
device_display_name,
pushkey,
pushkey_ts,
data,
last_token,
last_success,
failing_since
)
if 'url' not in data:
raise PusherConfigException(
"'url' required in data for HTTP pusher"
)
self.url = data['url']
self.httpCli = SimpleHttpClient(self.hs)
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url['url']
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks):
# we probably do not want to push for every presence update
# (we may want to be able to set up notifications when specific
# people sign in, but we'd want to only deliver the pertinent ones)
# Actually, presence events will not get this far now because we
# need to filter them out in the main Pusher code.
if 'event_id' not in event:
defer.returnValue(None)
ctx = yield self.get_context_for_event(event)
d = {
'notification': {
'id': event['event_id'],
'type': event['type'],
'sender': event['user_id'],
'counts': { # -- we don't mark messages as read yet so
# we have no way of knowing
# Just set the badge to 1 until we have read receipts
'unread': 1,
# 'missed_calls': 2
},
'devices': [
{
'app_id': self.app_id,
'pushkey': self.pushkey,
'pushkey_ts': long(self.pushkey_ts / 1000),
'data': self.data_minus_url,
'tweaks': tweaks
}
]
}
}
if event['type'] == 'm.room.member':
d['notification']['membership'] = event['content']['membership']
if 'content' in event:
d['notification']['content'] = event['content']
if len(ctx['aliases']):
d['notification']['room_alias'] = ctx['aliases'][0]
if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
d['notification']['sender_display_name'] = ctx['sender_display_name']
if 'name' in ctx and len(ctx['name']) > 0:
d['notification']['room_name'] = ctx['name']
defer.returnValue(d)
@defer.inlineCallbacks
def dispatch_push(self, event, tweaks):
notification_dict = yield self._build_notification_dict(event, tweaks)
if not notification_dict:
defer.returnValue([])
try:
resp = yield self.httpCli.post_json_get_json(self.url, notification_dict)
except:
logger.exception("Failed to push %s ", self.url)
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
rejected = resp['rejected']
defer.returnValue(rejected)
@defer.inlineCallbacks
def reset_badge_count(self):
d = {
'notification': {
'id': '',
'type': None,
'sender': '',
'counts': {
'unread': 0,
'missed_calls': 0
},
'devices': [
{
'app_id': self.app_id,
'pushkey': self.pushkey,
'pushkey_ts': long(self.pushkey_ts / 1000),
'data': self.data_minus_url,
}
]
}
}
try:
resp = yield self.httpCli.post_json_get_json(self.url, d)
except:
logger.exception("Failed to push %s ", self.url)
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
rejected = resp['rejected']
defer.returnValue(rejected)

152
synapse/push/pusherpool.py Normal file
View File

@ -0,0 +1,152 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from httppusher import HttpPusher
from synapse.push import PusherConfigException
import logging
import json
logger = logging.getLogger(__name__)
class PusherPool:
def __init__(self, _hs):
self.hs = _hs
self.store = self.hs.get_datastore()
self.pushers = {}
self.last_pusher_started = -1
distributor = self.hs.get_distributor()
distributor.observe(
"user_presence_changed", self.user_presence_changed
)
@defer.inlineCallbacks
def user_presence_changed(self, user, state):
user_name = user.to_string()
# until we have read receipts, pushers use this to reset a user's
# badge counters to zero
for p in self.pushers.values():
if p.user_name == user_name:
yield p.presence_changed(state)
@defer.inlineCallbacks
def start(self):
pushers = yield self.store.get_all_pushers()
for p in pushers:
p['data'] = json.loads(p['data'])
self._start_pushers(pushers)
@defer.inlineCallbacks
def add_pusher(self, user_name, instance_handle, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data):
# we try to create the pusher just to validate the config: it
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
self._create_pusher({
"user_name": user_name,
"kind": kind,
"instance_handle": instance_handle,
"app_id": app_id,
"app_display_name": app_display_name,
"device_display_name": device_display_name,
"pushkey": pushkey,
"pushkey_ts": self.hs.get_clock().time_msec(),
"lang": lang,
"data": data,
"last_token": None,
"last_success": None,
"failing_since": None
})
yield self._add_pusher_to_store(
user_name, instance_handle, kind, app_id,
app_display_name, device_display_name,
pushkey, lang, data
)
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, instance_handle, kind, app_id,
app_display_name, device_display_name,
pushkey, lang, data):
yield self.store.add_pusher(
user_name=user_name,
instance_handle=instance_handle,
kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
pushkey_ts=self.hs.get_clock().time_msec(),
lang=lang,
data=json.dumps(data)
)
self._refresh_pusher((app_id, pushkey))
def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http':
return HttpPusher(
self.hs,
instance_handle=pusherdict['instance_handle'],
user_name=pusherdict['user_name'],
app_id=pusherdict['app_id'],
app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'],
pushkey=pusherdict['pushkey'],
pushkey_ts=pusherdict['pushkey_ts'],
data=pusherdict['data'],
last_token=pusherdict['last_token'],
last_success=pusherdict['last_success'],
failing_since=pusherdict['failing_since']
)
else:
raise PusherConfigException(
"Unknown pusher type '%s' for user %s" %
(pusherdict['kind'], pusherdict['user_name'])
)
@defer.inlineCallbacks
def _refresh_pusher(self, app_id_pushkey):
p = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id_pushkey
)
p['data'] = json.loads(p['data'])
self._start_pushers([p])
def _start_pushers(self, pushers):
logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers:
p = self._create_pusher(pusherdict)
if p:
fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey'])
if fullid in self.pushers:
self.pushers[fullid].stop()
self.pushers[fullid] = p
p.start()
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey):
fullid = "%s:%s" % (app_id, pushkey)
if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid)
self.pushers[fullid].stop()
del self.pushers[fullid]
yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey)

View File

@ -6,7 +6,7 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"syutil==0.0.2": ["syutil"], "syutil==0.0.2": ["syutil"],
"matrix_angular_sdk==0.6.0": ["syweb>=0.6.0"], "matrix_angular_sdk==0.6.0": ["syweb>=0.6.0"],
"Twisted>=14.0.0": ["twisted>=14.0.0"], "Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"], "pyyaml": ["yaml"],

View File

@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import ( from . import (
room, events, register, login, profile, presence, initial_sync, directory, room, events, register, login, profile, presence, initial_sync, directory,
voip, admin, voip, admin, pusher, push_rule
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -41,3 +40,5 @@ class ClientV1RestResource(JsonResource):
directory.register_servlets(hs, client_resource) directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource) voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource) admin.register_servlets(hs, client_resource)
pusher.register_servlets(hs, client_resource)
push_rule.register_servlets(hs, client_resource)

View File

@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user) is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user: if not is_admin and target_user != auth_user:

View File

@ -45,7 +45,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_alias): def on_PUT(self, request, room_alias):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
if not "room_id" in content: if not "room_id" in content:
@ -85,7 +85,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, room_alias): def on_DELETE(self, request, room_alias):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user) is_admin = yield self.auth.is_server_admin(user)
if not is_admin: if not is_admin:

View File

@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
try: try:
handler = self.handlers.event_stream_handler handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request) pagin_config = PaginationConfig.from_request(request)
@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, event_id): def on_GET(self, request, event_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id) event = yield handler.get_event(auth_user, event_id)

View File

@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
with_feedback = "feedback" in request.args with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)

View File

@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state( state = yield self.handlers.presence_handler.get_state(
@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
state = {} state = {}
@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):

View File

@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:
@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
try: try:

View File

@ -0,0 +1,406 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError, NotFoundError, \
StoreError
from .base import ClientV1RestServlet, client_path_pattern
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
import synapse.push.baserules as baserules
import json
class PushRuleRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/pushrules/.*$")
PRIORITY_CLASS_MAP = {
'default': 0,
'underride': 1,
'sender': 2,
'room': 3,
'content': 4,
'override': 5,
}
PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()}
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash")
def rule_spec_from_path(self, path):
if len(path) < 2:
raise UnrecognizedRequestError()
if path[0] != 'pushrules':
raise UnrecognizedRequestError()
scope = path[1]
path = path[2:]
if scope not in ['global', 'device']:
raise UnrecognizedRequestError()
device = None
if scope == 'device':
if len(path) == 0:
raise UnrecognizedRequestError()
device = path[0]
path = path[1:]
if len(path) == 0:
raise UnrecognizedRequestError()
template = path[0]
path = path[1:]
if len(path) == 0:
raise UnrecognizedRequestError()
rule_id = path[0]
spec = {
'scope': scope,
'template': template,
'rule_id': rule_id
}
if device:
spec['device'] = device
return spec
def rule_tuple_from_request_object(self, rule_template, rule_id, req_obj, device=None):
if rule_template in ['override', 'underride']:
if 'conditions' not in req_obj:
raise InvalidRuleException("Missing 'conditions'")
conditions = req_obj['conditions']
for c in conditions:
if 'kind' not in c:
raise InvalidRuleException("Condition without 'kind'")
elif rule_template == 'room':
conditions = [{
'kind': 'event_match',
'key': 'room_id',
'pattern': rule_id
}]
elif rule_template == 'sender':
conditions = [{
'kind': 'event_match',
'key': 'user_id',
'pattern': rule_id
}]
elif rule_template == 'content':
if 'pattern' not in req_obj:
raise InvalidRuleException("Content rule missing 'pattern'")
pat = req_obj['pattern']
conditions = [{
'kind': 'event_match',
'key': 'content.body',
'pattern': pat
}]
else:
raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
if device:
conditions.append({
'kind': 'device',
'instance_handle': device
})
if 'actions' not in req_obj:
raise InvalidRuleException("No actions found")
actions = req_obj['actions']
for a in actions:
if a in ['notify', 'dont_notify', 'coalesce']:
pass
elif isinstance(a, dict) and 'set_sound' in a:
pass
else:
raise InvalidRuleException("Unrecognised action")
return conditions, actions
@defer.inlineCallbacks
def on_PUT(self, request):
spec = self.rule_spec_from_path(request.postpath)
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
user, _ = yield self.auth.get_user_by_req(request)
if spec['template'] == 'default':
raise SynapseError(403, "The default rules are immutable.")
content = _parse_json(request)
try:
(conditions, actions) = self.rule_tuple_from_request_object(
spec['template'],
spec['rule_id'],
content,
device=spec['device'] if 'device' in spec else None
)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
before = request.args.get("before", None)
if before and len(before):
before = before[0]
after = request.args.get("after", None)
if after and len(after):
after = after[0]
try:
yield self.hs.get_datastore().add_push_rule(
user_name=user.to_string(),
rule_id=spec['rule_id'],
priority_class=priority_class,
conditions=conditions,
actions=actions,
before=before,
after=after
)
except InconsistentRuleException as e:
raise SynapseError(400, e.message)
except RuleNotFoundException as e:
raise SynapseError(400, e.message)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_DELETE(self, request):
spec = self.rule_spec_from_path(request.postpath)
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
user, _ = yield self.auth.get_user_by_req(request)
if 'device' in spec:
rules = yield self.hs.get_datastore().get_push_rules_for_user_name(
user.to_string()
)
for r in rules:
conditions = json.loads(r['conditions'])
ih = _instance_handle_from_conditions(conditions)
if ih == spec['device'] and r['priority_class'] == priority_class:
yield self.hs.get_datastore().delete_push_rule(
user.to_string(), spec['rule_id']
)
defer.returnValue((200, {}))
raise NotFoundError()
else:
try:
yield self.hs.get_datastore().delete_push_rule(
user.to_string(), spec['rule_id'],
priority_class=priority_class
)
defer.returnValue((200, {}))
except StoreError as e:
if e.code == 404:
raise NotFoundError()
else:
raise
@defer.inlineCallbacks
def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request)
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name(user.to_string())
for r in rawrules:
r["conditions"] = json.loads(r["conditions"])
r["actions"] = json.loads(r["actions"])
rawrules.extend(baserules.make_base_rules(user.to_string()))
rules = {'global': {}, 'device': {}}
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
for r in rawrules:
rulearray = None
template_name = _priority_class_to_template_name(r['priority_class'])
if r['priority_class'] > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']:
# per-device rule
instance_handle = _instance_handle_from_conditions(r["conditions"])
r = _strip_device_condition(r)
if not instance_handle:
continue
if instance_handle not in rules['device']:
rules['device'][instance_handle] = {}
rules['device'][instance_handle] = (
_add_empty_priority_class_arrays(
rules['device'][instance_handle]
)
)
rulearray = rules['device'][instance_handle][template_name]
else:
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r)
if template_rule:
rulearray.append(template_rule)
path = request.postpath[1:]
if path == []:
# we're a reference impl: pedantry is our job.
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
defer.returnValue((200, rules))
elif path[0] == 'global':
path = path[1:]
result = _filter_ruleset_with_path(rules['global'], path)
defer.returnValue((200, result))
elif path[0] == 'device':
path = path[1:]
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
defer.returnValue((200, rules['device']))
instance_handle = path[0]
path = path[1:]
if instance_handle not in rules['device']:
ret = {}
ret = _add_empty_priority_class_arrays(ret)
defer.returnValue((200, ret))
ruleset = rules['device'][instance_handle]
result = _filter_ruleset_with_path(ruleset, path)
defer.returnValue((200, result))
else:
raise UnrecognizedRequestError()
def on_OPTIONS(self, _):
return 200, {}
def _add_empty_priority_class_arrays(d):
for pc in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _instance_handle_from_conditions(conditions):
"""
Given a list of conditions, return the instance handle of the
device rule if there is one
"""
for c in conditions:
if c['kind'] == 'device':
return c['instance_handle']
return None
def _filter_ruleset_with_path(ruleset, path):
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
return ruleset
template_kind = path[0]
if template_kind not in ruleset:
raise UnrecognizedRequestError()
path = path[1:]
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
return ruleset[template_kind]
rule_id = path[0]
for r in ruleset[template_kind]:
if r['rule_id'] == rule_id:
return r
raise NotFoundError
def _priority_class_from_spec(spec):
if spec['template'] not in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys():
raise InvalidRuleException("Unknown template: %s" % (spec['kind']))
pc = PushRuleRestServlet.PRIORITY_CLASS_MAP[spec['template']]
if spec['scope'] == 'device':
pc += len(PushRuleRestServlet.PRIORITY_CLASS_MAP)
return pc
def _priority_class_to_template_name(pc):
if pc > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']:
# per-device
prio_class_index = pc - len(PushRuleRestServlet.PRIORITY_CLASS_MAP)
return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
else:
return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[pc]
def _rule_to_template(rule):
template_name = _priority_class_to_template_name(rule['priority_class'])
if template_name in ['default']:
return {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ['override', 'underride']:
return {k: rule[k] for k in ["rule_id", "conditions", "actions"]}
elif template_name in ["sender", "room"]:
return {k: rule[k] for k in ["rule_id", "actions"]}
elif template_name == 'content':
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
ret = {k: rule[k] for k in ["rule_id", "actions"]}
ret["pattern"] = thecond["pattern"]
return ret
def _strip_device_condition(rule):
for i, c in enumerate(rule['conditions']):
if c['kind'] == 'device':
del rule['conditions'][i]
return rule
class InvalidRuleException(Exception):
pass
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server):
PushRuleRestServlet(hs).register(http_server)

View File

@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException
from .base import ClientV1RestServlet, client_path_pattern
import json
class PusherRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/pushers/set$")
@defer.inlineCallbacks
def on_POST(self, request):
user, _ = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
pusher_pool = self.hs.get_pusherpool()
if ('pushkey' in content and 'app_id' in content
and 'kind' in content and
content['kind'] is None):
yield pusher_pool.remove_pusher(
content['app_id'], content['pushkey']
)
defer.returnValue((200, {}))
reqd = ['instance_handle', 'kind', 'app_id', 'app_display_name',
'device_display_name', 'pushkey', 'lang', 'data']
missing = []
for i in reqd:
if i not in content:
missing.append(i)
if len(missing):
raise SynapseError(400, "Missing parameters: "+','.join(missing),
errcode=Codes.MISSING_PARAM)
try:
yield pusher_pool.add_pusher(
user_name=user.to_string(),
instance_handle=content['instance_handle'],
kind=content['kind'],
app_id=content['app_id'],
app_display_name=content['app_display_name'],
device_display_name=content['device_display_name'],
pushkey=content['pushkey'],
lang=content['lang'],
data=content['data']
)
except PusherConfigException as pce:
raise SynapseError(400, "Config Error: "+pce.message,
errcode=Codes.MISSING_PARAM)
defer.returnValue((200, {}))
def on_OPTIONS(self, _):
return 200, {}
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.",
errcode=Codes.NOT_JSON)
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server):
PusherRestServlet(hs).register(http_server)

View File

@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request) room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None) info = yield self.make_room(room_config, auth_user, None)
@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key): def on_GET(self, request, room_id, event_type, state_key):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
@ -142,8 +142,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
defer.returnValue((200, data.get_dict()["content"])) defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key): def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -158,7 +158,9 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
event_dict["state_key"] = state_key event_dict["state_key"] = state_key
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(event_dict) yield msg_handler.create_and_send_event(
event_dict, client=client, txn_id=txn_id,
)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -172,8 +174,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server, with_get=True) register_txn_path(self, PATTERN, http_server, with_get=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_type): def on_POST(self, request, room_id, event_type, txn_id=None):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -183,7 +185,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
"content": content, "content": content,
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": user.to_string(),
} },
client=client,
txn_id=txn_id,
) )
defer.returnValue((200, {"event_id": event.event_id})) defer.returnValue((200, {"event_id": event.event_id}))
@ -200,7 +204,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
except KeyError: except KeyError:
pass pass
response = yield self.on_POST(request, room_id, event_type) response = yield self.on_POST(request, room_id, event_type, txn_id)
self.txns.store_client_transaction(request, txn_id, response) self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response) defer.returnValue(response)
@ -215,8 +219,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_identifier): def on_POST(self, request, room_identifier, txn_id=None):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
# the identifier could be a room alias or a room id. Try one then the # the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid # other if it fails to parse, without swallowing other valid
@ -245,7 +249,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
"room_id": identifier.to_string(), "room_id": identifier.to_string(),
"sender": user.to_string(), "sender": user.to_string(),
"state_key": user.to_string(), "state_key": user.to_string(),
} },
client=client,
txn_id=txn_id,
) )
defer.returnValue((200, {"room_id": identifier.to_string()})) defer.returnValue((200, {"room_id": identifier.to_string()}))
@ -259,7 +265,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
except KeyError: except KeyError:
pass pass
response = yield self.on_POST(request, room_identifier) response = yield self.on_POST(request, room_identifier, txn_id)
self.txns.store_client_transaction(request, txn_id, response) self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response) defer.returnValue(response)
@ -283,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens) # TODO support Pagination stream API (limit/tokens)
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.room_member_handler handler = self.handlers.room_member_handler
members = yield handler.get_room_members_as_pagination_chunk( members = yield handler.get_room_members_as_pagination_chunk(
room_id=room_id, room_id=room_id,
@ -311,7 +317,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request( pagination_config = PaginationConfig.from_request(
request, default_limit=10, request, default_limit=10,
) )
@ -335,7 +341,7 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler handler = self.handlers.message_handler
# Get all the current state for this room # Get all the current state for this room
events = yield handler.get_state_events( events = yield handler.get_state_events(
@ -351,7 +357,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync( content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id, room_id=room_id,
@ -395,8 +401,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action): def on_POST(self, request, room_id, membership_action, txn_id=None):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -418,7 +424,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": user.to_string(),
"state_key": state_key, "state_key": state_key,
} },
client=client,
txn_id=txn_id,
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -432,7 +440,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
except KeyError: except KeyError:
pass pass
response = yield self.on_POST(request, room_id, membership_action) response = yield self.on_POST(
request, room_id, membership_action, txn_id
)
self.txns.store_client_transaction(request, txn_id, response) self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response) defer.returnValue(response)
@ -444,8 +454,8 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id): def on_POST(self, request, room_id, event_id, txn_id=None):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
@ -456,7 +466,9 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
"room_id": room_id, "room_id": room_id,
"sender": user.to_string(), "sender": user.to_string(),
"redacts": event_id, "redacts": event_id,
} },
client=client,
txn_id=txn_id,
) )
defer.returnValue((200, {"event_id": event.event_id})) defer.returnValue((200, {"event_id": event.event_id}))
@ -470,7 +482,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
except KeyError: except KeyError:
pass pass
response = yield self.on_POST(request, room_id, event_id) response = yield self.on_POST(request, room_id, event_id, txn_id)
self.txns.store_client_transaction(request, txn_id, response) self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response) defer.returnValue(response)
@ -483,7 +495,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id): def on_PUT(self, request, room_id, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id) room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id)) target_user = UserID.from_string(urllib.unquote(user_id))

View File

@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret turnSecret = self.hs.config.turn_shared_secret

View File

@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import (
sync,
filter
)
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -26,4 +30,5 @@ class ClientV2AlphaRestResource(JsonResource):
@staticmethod @staticmethod
def register_servlets(client_resource, hs): def register_servlets(client_resource, hs):
pass sync.register_servlets(hs, client_resource)
filter.register_servlets(hs, client_resource)

View File

@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from synapse.http.servlet import RestServlet
from synapse.types import UserID
from ._base import client_v2_pattern
import json
import logging
logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs):
super(GetFilterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id)
auth_user, client = yield self.auth.get_user_by_req(request)
if target_user != auth_user:
raise AuthError(403, "Cannot get filters for other users")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only get filters for local users")
try:
filter_id = int(filter_id)
except:
raise SynapseError(400, "Invalid filter_id")
try:
filter = yield self.filtering.get_user_filter(
user_localpart=target_user.localpart,
filter_id=filter_id,
)
defer.returnValue((200, filter.filter_json))
except KeyError:
raise SynapseError(400, "No such filter")
class CreateFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs):
super(CreateFilterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id)
auth_user, client = yield self.auth.get_user_by_req(request)
if target_user != auth_user:
raise AuthError(403, "Cannot create filters for other users")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only create filters for local users")
try:
content = json.loads(request.content.read())
# TODO(paul): check for required keys and invalid keys
except:
raise SynapseError(400, "Invalid filter definition")
filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart,
user_filter=content,
)
defer.returnValue((200, {"filter_id": str(filter_id)}))
def register_servlets(hs, http_server):
GetFilterRestServlet(hs).register(http_server)
CreateFilterRestServlet(hs).register(http_server)

View File

@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken
from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_event_id,
)
from synapse.api.filtering import Filter
from ._base import client_v2_pattern
import logging
logger = logging.getLogger(__name__)
class SyncRestServlet(RestServlet):
"""
GET parameters::
timeout(int): How long to wait for new events in milliseconds.
limit(int): Maxiumum number of events per room to return.
gap(bool): Create gaps the message history if limit is exceeded to
ensure that the client has the most recent messages. Defaults to
"true".
sort(str,str): tuple of sort key (e.g. "timeline") and direction
(e.g. "asc", "desc"). Defaults to "timeline,asc".
since(batch_token): Batch token when asking for incremental deltas.
set_presence(str): What state the device presence should be set to.
default is "online".
backfill(bool): Should the HS request message history from other
servers. This may take a long time making it unsuitable for clients
expecting a prompt response. Defaults to "true".
filter(filter_id): A filter to apply to the events returned.
filter_*: Filter override parameters.
Response JSON::
{
"next_batch": // batch token for the next /sync
"private_user_data": // private events for this user.
"public_user_data": // public events for all users including the
// public events for this user.
"rooms": [{ // List of rooms with updates.
"room_id": // Id of the room being updated
"limited": // Was the per-room event limit exceeded?
"published": // Is the room published by our HS?
"event_map": // Map of EventID -> event JSON.
"events": { // The recent events in the room if gap is "true"
// otherwise the next events in the room.
"batch": [] // list of EventIDs in the "event_map".
"prev_batch": // back token for getting previous events.
}
"state": [] // list of EventIDs updating the current state to
// be what it should be at the end of the batch.
"ephemeral": []
}]
}
"""
PATTERN = client_v2_pattern("/sync$")
ALLOWED_SORT = set(["timeline,asc", "timeline,desc"])
ALLOWED_PRESENCE = set(["online", "offline", "idle"])
def __init__(self, hs):
super(SyncRestServlet, self).__init__()
self.auth = hs.get_auth()
self.sync_handler = hs.get_handlers().sync_handler
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def on_GET(self, request):
user, client = yield self.auth.get_user_by_req(request)
timeout = self.parse_integer(request, "timeout", default=0)
limit = self.parse_integer(request, "limit", required=True)
gap = self.parse_boolean(request, "gap", default=True)
sort = self.parse_string(
request, "sort", default="timeline,asc",
allowed_values=self.ALLOWED_SORT
)
since = self.parse_string(request, "since")
set_presence = self.parse_string(
request, "set_presence", default="online",
allowed_values=self.ALLOWED_PRESENCE
)
backfill = self.parse_boolean(request, "backfill", default=False)
filter_id = self.parse_string(request, "filter", default=None)
logger.info(
"/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r,"
" set_presence=%r, backfill=%r, filter_id=%r" % (
user, timeout, limit, gap, sort, since, set_presence,
backfill, filter_id
)
)
# TODO(mjark): Load filter and apply overrides.
try:
filter = yield self.filtering.get_user_filter(
user.localpart, filter_id
)
except:
filter = Filter({})
# filter = filter.apply_overrides(http_request)
#if filter.matches(event):
# # stuff
sync_config = SyncConfig(
user=user,
client_info=client,
gap=gap,
limit=limit,
sort=sort,
backfill=backfill,
filter=filter,
)
if since is not None:
since_token = StreamToken.from_string(since)
else:
since_token = None
sync_result = yield self.sync_handler.wait_for_sync_for_user(
sync_config, since_token=since_token, timeout=timeout
)
time_now = self.clock.time_msec()
response_content = {
"public_user_data": self.encode_user_data(
sync_result.public_user_data, filter, time_now
),
"private_user_data": self.encode_user_data(
sync_result.private_user_data, filter, time_now
),
"rooms": self.encode_rooms(
sync_result.rooms, filter, time_now, client.token_id
),
"next_batch": sync_result.next_batch.to_string(),
}
defer.returnValue((200, response_content))
def encode_user_data(self, events, filter, time_now):
return events
def encode_rooms(self, rooms, filter, time_now, token_id):
return [
self.encode_room(room, filter, time_now, token_id)
for room in rooms
]
@staticmethod
def encode_room(room, filter, time_now, token_id):
event_map = {}
state_events = filter.filter_room_state(room.state)
recent_events = filter.filter_room_events(room.events)
state_event_ids = []
recent_event_ids = []
for event in state_events:
# TODO(mjark): Respect formatting requirements in the filter.
event_map[event.event_id] = serialize_event(
event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_event_id,
)
state_event_ids.append(event.event_id)
for event in recent_events:
# TODO(mjark): Respect formatting requirements in the filter.
event_map[event.event_id] = serialize_event(
event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_event_id,
)
recent_event_ids.append(event.event_id)
result = {
"room_id": room.room_id,
"event_map": event_map,
"events": {
"batch": recent_event_ids,
"prev_batch": room.prev_batch.to_string(),
},
"state": state_event_ids,
"limited": room.limited,
"published": room.published,
"ephemeral": room.ephemeral,
}
return result
def register_servlets(hs, http_server):
SyncRestServlet(hs).register(http_server)

View File

@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def map_request_to_name(self, request): def map_request_to_name(self, request):
# auth the user # auth the user
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user # namespace all file uploads on the user
prefix = base64.urlsafe_b64encode( prefix = base64.urlsafe_b64encode(

View File

@ -42,7 +42,7 @@ class UploadResource(BaseMediaResource):
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_POST(self, request): def _async_render_POST(self, request):
try: try:
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have # TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point # already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length") content_length = request.getHeader("Content-Length")

View File

@ -31,7 +31,9 @@ from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring from synapse.crypto.keyring import Keyring
from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory from synapse.events.builder import EventBuilderFactory
from synapse.api.filtering import Filtering
class BaseHomeServer(object): class BaseHomeServer(object):
@ -79,7 +81,9 @@ class BaseHomeServer(object):
'event_sources', 'event_sources',
'ratelimiter', 'ratelimiter',
'keyring', 'keyring',
'pusherpool',
'event_builder_factory', 'event_builder_factory',
'filtering',
] ]
def __init__(self, hostname, **kwargs): def __init__(self, hostname, **kwargs):
@ -198,3 +202,9 @@ class HomeServer(BaseHomeServer):
clock=self.get_clock(), clock=self.get_clock(),
hostname=self.hostname, hostname=self.hostname,
) )
def build_filtering(self):
return Filtering(self)
def build_pusherpool(self):
return PusherPool(self)

View File

@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from collections import namedtuple from collections import namedtuple
@ -36,12 +37,16 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,)
class StateHandler(object): class StateHandler(object):
""" Responsible for doing state conflict resolution. """ Responsible for doing state conflict resolution.
""" """
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs
@defer.inlineCallbacks @defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""): def get_current_state(self, room_id, event_type=None, state_key=""):
@ -163,10 +168,17 @@ class StateHandler(object):
first is the name of a state group if one and only one is involved, first is the name of a state group if one and only one is involved,
otherwise `None`. otherwise `None`.
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids)
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
event_ids event_ids
) )
logger.debug(
"resolve_state_groups state_groups %s",
state_groups.keys()
)
group_names = set(state_groups.keys()) group_names = set(state_groups.keys())
if len(group_names) == 1: if len(group_names) == 1:
name, state_list = state_groups.items().pop() name, state_list = state_groups.items().pop()
@ -210,64 +222,93 @@ class StateHandler(object):
else: else:
prev_states = [] prev_states = []
auth_events = {
k: e for k, e in unconflicted_state.items()
if k[0] in AuthEventTypes
}
try: try:
new_state = {} resolved_state = self._resolve_state_events(
new_state.update(unconflicted_state) conflicted_state, auth_events
for key, events in conflicted_state.items(): )
new_state[key] = self._resolve_state_events(events)
except: except:
logger.exception("Failed to resolve state") logger.exception("Failed to resolve state")
raise raise
new_state = unconflicted_state
new_state.update(resolved_state)
defer.returnValue((None, new_state, prev_states)) defer.returnValue((None, new_state, prev_states))
def _get_power_level_from_event_state(self, event, user_id):
if hasattr(event, "old_state_events") and event.old_state_events:
key = (EventTypes.PowerLevels, "", )
power_level_event = event.old_state_events.get(key)
level = None
if power_level_event:
level = power_level_event.content.get("users", {}).get(
user_id
)
if not level:
level = power_level_event.content.get("users_default", 0)
return level
else:
return 0
@log_function @log_function
def _resolve_state_events(self, events): def _resolve_state_events(self, conflicted_state, auth_events):
curr_events = events """ This is where we actually decide which of the conflicted state to
use.
new_powers = [ We resolve conflicts in the following order:
self._get_power_level_from_event_state(e, e.user_id) 1. power levels
for e in curr_events 2. memberships
] 3. other events.
"""
resolved_state = {}
power_key = (EventTypes.PowerLevels, "")
if power_key in conflicted_state.items():
power_levels = conflicted_state[power_key]
resolved_state[power_key] = self._resolve_auth_events(power_levels)
new_powers = [ auth_events.update(resolved_state)
int(p) if p else 0 for p in new_powers
]
max_power = max(new_powers) for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
curr_events = [ resolved_state[key] = self._resolve_auth_events(
z[0] for z in zip(curr_events, new_powers) events,
if z[1] == max_power auth_events
]
if not curr_events:
raise RuntimeError("Max didn't get a max?")
elif len(curr_events) == 1:
return curr_events[0]
# TODO: For now, just choose the one with the largest event_id.
return (
sorted(
curr_events,
key=lambda e: hashlib.sha1(
e.event_id + e.user_id + e.room_id + e.type
).hexdigest()
)[0]
) )
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key not in resolved_state:
resolved_state[key] = self._resolve_normal_events(
events, auth_events
)
return resolved_state
def _resolve_auth_events(self, events, auth_events):
reverse = [i for i in reversed(self._ordered_events(events))]
auth_events = dict(auth_events)
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self.hs.get_auth().check(event, auth_events)
prev_event = event
except AuthError:
return prev_event
return event
def _resolve_normal_events(self, events, auth_events):
for event in self._ordered_events(events):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
self.hs.get_auth().check(event, auth_events)
return event
except AuthError:
pass
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event
def _ordered_events(self, events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
return sorted(events, key=key_func)

View File

@ -30,10 +30,14 @@ from .stream import StreamStore
from .transactions import TransactionStore from .transactions import TransactionStore
from .keys import KeyStore from .keys import KeyStore
from .event_federation import EventFederationStore from .event_federation import EventFederationStore
from .pusher import PusherStore
from .push_rule import PushRuleStore
from .media_repository import MediaRepositoryStore from .media_repository import MediaRepositoryStore
from .rejections import RejectionsStore
from .state import StateStore from .state import StateStore
from .signatures import SignatureStore from .signatures import SignatureStore
from .filtering import FilteringStore
from syutil.base64util import decode_base64 from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
@ -61,14 +65,20 @@ SCHEMAS = [
"state", "state",
"event_edges", "event_edges",
"event_signatures", "event_signatures",
"pusher",
"media_repository", "media_repository",
<<<<<<< HEAD
"application_services" "application_services"
=======
"filtering",
"rejections",
>>>>>>> develop
] ]
# Remember to update this number every time an incompatible change is made to # Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts. # database schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 11 SCHEMA_VERSION = 12
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
@ -82,8 +92,17 @@ class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
PresenceStore, TransactionStore, PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore, DirectoryStore, KeyStore, StateStore, SignatureStore,
<<<<<<< HEAD
EventFederationStore, MediaRepositoryStore, EventFederationStore, MediaRepositoryStore,
ApplicationServiceStore ApplicationServiceStore
=======
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
FilteringStore,
PusherStore,
PushRuleStore
>>>>>>> develop
): ):
def __init__(self, hs): def __init__(self, hs):
@ -226,6 +245,9 @@ class DataStore(RoomMemberStore, RoomStore,
if not outlier: if not outlier:
self._store_state_groups_txn(txn, event, context) self._store_state_groups_txn(txn, event, context)
if context.rejected:
self._store_rejections_txn(txn, event.event_id, context.rejected)
if current_state: if current_state:
txn.execute( txn.execute(
"DELETE FROM current_state_events WHERE room_id = ?", "DELETE FROM current_state_events WHERE room_id = ?",
@ -264,7 +286,7 @@ class DataStore(RoomMemberStore, RoomStore,
or_replace=True, or_replace=True,
) )
if is_new_state: if is_new_state and not context.rejected:
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
"current_state_events", "current_state_events",
@ -290,7 +312,7 @@ class DataStore(RoomMemberStore, RoomStore,
or_ignore=True, or_ignore=True,
) )
if not backfilled: if not backfilled and not context.rejected:
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
table="state_forward_extremities", table="state_forward_extremities",
@ -372,9 +394,12 @@ class DataStore(RoomMemberStore, RoomStore,
"redacted": del_sql, "redacted": del_sql,
} }
if event_type: if event_type and state_key is not None:
sql += " AND s.type = ? AND s.state_key = ? " sql += " AND s.type = ? AND s.state_key = ? "
args = (room_id, event_type, state_key) args = (room_id, event_type, state_key)
elif event_type:
sql += " AND s.type = ?"
args = (room_id, event_type)
else: else:
args = (room_id, ) args = (room_id, )
@ -383,6 +408,41 @@ class DataStore(RoomMemberStore, RoomStore,
events = yield self._parse_events(results) events = yield self._parse_events(results)
defer.returnValue(events) defer.returnValue(events)
@defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id):
del_sql = (
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
"LIMIT 1"
)
sql = (
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
"INNER JOIN state_events as s ON e.event_id = s.event_id "
"WHERE c.room_id = ? "
) % {
"redacted": del_sql,
}
sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
sql += " OR s.type = 'm.room.aliases')"
args = (room_id,)
results = yield self._execute_and_decode(sql, *args)
events = yield self._parse_events(results)
name = None
aliases = []
for e in events:
if e.type == 'm.room.name':
name = e.content['name']
elif e.type == 'm.room.aliases':
aliases.extend(e.content['aliases'])
defer.returnValue((name, aliases))
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_min_token(self): def _get_min_token(self):
row = yield self._execute( row = yield self._execute(
@ -419,6 +479,35 @@ class DataStore(RoomMemberStore, RoomStore,
], ],
) )
def have_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Returns:
dict: Has an entry for each event id we already have seen. Maps to
the rejected reason string if we rejected the event, else maps to
None.
"""
def f(txn):
sql = (
"SELECT e.event_id, reason FROM events as e "
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
"WHERE e.event_id = ?"
)
res = {}
for event_id in event_ids:
txn.execute(sql, (event_id,))
row = txn.fetchone()
if row:
_, rejected = row
res[event_id] = rejected
return res
return self.runInteraction(
"have_events", f,
)
def schema_path(schema): def schema_path(schema):
""" Get a filesystem path for the named database schema """ Get a filesystem path for the named database schema

View File

@ -193,6 +193,50 @@ class SQLBaseStore(object):
txn.execute(sql, values.values()) txn.execute(sql, values.values())
return txn.lastrowid return txn.lastrowid
def _simple_upsert(self, table, keyvalues, values):
"""
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
values (dict): The nonunique columns and their new values
Returns: A deferred
"""
return self.runInteraction(
"_simple_upsert",
self._simple_upsert_txn, table, keyvalues, values
)
def _simple_upsert_txn(self, txn, table, keyvalues, values):
# Try to update
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
sqlargs = values.values() + keyvalues.values()
logger.debug(
"[SQL] %s Args=%s",
sql, sqlargs,
)
txn.execute(sql, sqlargs)
if txn.rowcount == 0:
# We didn't update and rows so insert a new one
allvalues = {}
allvalues.update(keyvalues)
allvalues.update(values)
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues)
)
logger.debug(
"[SQL] %s Args=%s",
sql, keyvalues.values(),
)
txn.execute(sql, allvalues.values())
def _simple_select_one(self, table, keyvalues, retcols, def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False): allow_none=False):
"""Executes a SELECT query on the named table, which is expected to """Executes a SELECT query on the named table, which is expected to
@ -344,8 +388,8 @@ class SQLBaseStore(object):
if updatevalues: if updatevalues:
update_sql = "UPDATE %s SET %s WHERE %s" % ( update_sql = "UPDATE %s SET %s WHERE %s" % (
table, table,
", ".join("%s = ?" % (k) for k in updatevalues), ", ".join("%s = ?" % (k,) for k in updatevalues),
" AND ".join("%s = ?" % (k) for k in keyvalues) " AND ".join("%s = ?" % (k,) for k in keyvalues)
) )
def func(txn): def func(txn):
@ -458,10 +502,12 @@ class SQLBaseStore(object):
return [e for e in events if e] return [e for e in events if e]
def _get_event_txn(self, txn, event_id, check_redacted=True, def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False): get_prev_content=False, allow_rejected=False):
sql = ( sql = (
"SELECT internal_metadata, json, r.event_id FROM event_json as e " "SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
"FROM event_json as e "
"LEFT JOIN redactions as r ON e.event_id = r.redacts " "LEFT JOIN redactions as r ON e.event_id = r.redacts "
"LEFT JOIN rejections as rej on rej.event_id = e.event_id "
"WHERE e.event_id = ? " "WHERE e.event_id = ? "
"LIMIT 1 " "LIMIT 1 "
) )
@ -473,13 +519,16 @@ class SQLBaseStore(object):
if not res: if not res:
return None return None
internal_metadata, js, redacted = res internal_metadata, js, redacted, rejected_reason = res
if allow_rejected or not rejected_reason:
return self._get_event_from_row_txn( return self._get_event_from_row_txn(
txn, internal_metadata, js, redacted, txn, internal_metadata, js, redacted,
check_redacted=check_redacted, check_redacted=check_redacted,
get_prev_content=get_prev_content, get_prev_content=get_prev_content,
) )
else:
return None
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False): check_redacted=True, get_prev_content=False):

View File

@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import SQLBaseStore
import json
class FilteringStore(SQLBaseStore):
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
def_json = yield self._simple_select_one_onecol(
table="user_filters",
keyvalues={
"user_id": user_localpart,
"filter_id": filter_id,
},
retcol="filter_json",
allow_none=False,
)
defer.returnValue(json.loads(def_json))
def add_user_filter(self, user_localpart, user_filter):
def_json = json.dumps(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
# INSERT a new one
def _do_txn(txn):
sql = (
"SELECT MAX(filter_id) FROM user_filters "
"WHERE user_id = ?"
)
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
if max_id is None:
filter_id = 0
else:
filter_id = max_id + 1
sql = (
"INSERT INTO user_filters (user_id, filter_id, filter_json)"
"VALUES(?, ?, ?)"
)
txn.execute(sql, (user_localpart, filter_id, def_json))
return filter_id
return self.runInteraction("add_user_filter", _do_txn)

View File

@ -0,0 +1,213 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from ._base import SQLBaseStore, Table
from twisted.internet import defer
import logging
import copy
import json
logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore):
@defer.inlineCallbacks
def get_push_rules_for_user_name(self, user_name):
sql = (
"SELECT "+",".join(PushRuleTable.fields)+" "
"FROM "+PushRuleTable.table_name+" "
"WHERE user_name = ? "
"ORDER BY priority_class DESC, priority DESC"
)
rows = yield self._execute(None, sql, user_name)
dicts = []
for r in rows:
d = {}
for i, f in enumerate(PushRuleTable.fields):
d[f] = r[i]
dicts.append(d)
defer.returnValue(dicts)
@defer.inlineCallbacks
def add_push_rule(self, before, after, **kwargs):
vals = copy.copy(kwargs)
if 'conditions' in vals:
vals['conditions'] = json.dumps(vals['conditions'])
if 'actions' in vals:
vals['actions'] = json.dumps(vals['actions'])
# we could check the rest of the keys are valid column names
# but sqlite will do that anyway so I think it's just pointless.
if 'id' in vals:
del vals['id']
if before or after:
ret = yield self.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
before=before,
after=after,
**vals
)
defer.returnValue(ret)
else:
ret = yield self.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
**vals
)
defer.returnValue(ret)
def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
after = None
relative_to_rule = None
if 'after' in kwargs and kwargs['after']:
after = kwargs['after']
relative_to_rule = after
if 'before' in kwargs and kwargs['before']:
relative_to_rule = kwargs['before']
# get the priority of the rule we're inserting after/before
sql = (
"SELECT priority_class, priority FROM ? "
"WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,)
)
txn.execute(sql, (user_name, relative_to_rule))
res = txn.fetchall()
if not res:
raise RuleNotFoundException("before/after rule not found: %s" % (relative_to_rule))
priority_class, base_rule_priority = res[0]
if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
raise InconsistentRuleException(
"Given priority class does not match class of relative rule"
)
new_rule = copy.copy(kwargs)
if 'before' in new_rule:
del new_rule['before']
if 'after' in new_rule:
del new_rule['after']
new_rule['priority_class'] = priority_class
new_rule['user_name'] = user_name
# check if the priority before/after is free
new_rule_priority = base_rule_priority
if after:
new_rule_priority -= 1
else:
new_rule_priority += 1
new_rule['priority'] = new_rule_priority
sql = (
"SELECT COUNT(*) FROM " + PushRuleTable.table_name +
" WHERE user_name = ? AND priority_class = ? AND priority = ?"
)
txn.execute(sql, (user_name, priority_class, new_rule_priority))
res = txn.fetchall()
num_conflicting = res[0][0]
# if there are conflicting rules, bump everything
if num_conflicting:
sql = "UPDATE "+PushRuleTable.table_name+" SET priority = priority "
if after:
sql += "-1"
else:
sql += "+1"
sql += " WHERE user_name = ? AND priority_class = ? AND priority "
if after:
sql += "<= ?"
else:
sql += ">= ?"
txn.execute(sql, (user_name, priority_class, new_rule_priority))
# now insert the new rule
sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
txn.execute(sql, new_rule.values())
def _add_push_rule_highest_priority_txn(self, txn, user_name,
priority_class, **kwargs):
# find the highest priority rule in that class
sql = (
"SELECT COUNT(*), MAX(priority) FROM " + PushRuleTable.table_name +
" WHERE user_name = ? and priority_class = ?"
)
txn.execute(sql, (user_name, priority_class))
res = txn.fetchall()
(how_many, highest_prio) = res[0]
new_prio = 0
if how_many > 0:
new_prio = highest_prio + 1
# and insert the new rule
new_rule = copy.copy(kwargs)
if 'id' in new_rule:
del new_rule['id']
new_rule['user_name'] = user_name
new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio
sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
sql += ",".join(new_rule.keys())+") VALUES ("
sql += ", ".join(["?" for _ in new_rule.keys()])+")"
txn.execute(sql, new_rule.values())
@defer.inlineCallbacks
def delete_push_rule(self, user_name, rule_id, **kwargs):
"""
Delete a push rule. Args specify the row to be deleted and can be
any of the columns in the push_rule table, but below are the
standard ones
Args:
user_name (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
yield self._simple_delete_one(PushRuleTable.table_name, kwargs)
class RuleNotFoundException(Exception):
pass
class InconsistentRuleException(Exception):
pass
class PushRuleTable(Table):
table_name = "push_rules"
fields = [
"id",
"user_name",
"rule_id",
"priority_class",
"priority",
"conditions",
"actions",
]
EntryType = collections.namedtuple("PushRuleEntry", fields)

173
synapse/storage/pusher.py Normal file
View File

@ -0,0 +1,173 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from ._base import SQLBaseStore, Table
from twisted.internet import defer
from synapse.api.errors import StoreError
import logging
logger = logging.getLogger(__name__)
class PusherStore(SQLBaseStore):
@defer.inlineCallbacks
def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey):
sql = (
"SELECT id, user_name, kind, instance_handle, app_id,"
"app_display_name, device_display_name, pushkey, ts, data, "
"last_token, last_success, failing_since "
"FROM pushers "
"WHERE app_id = ? AND pushkey = ?"
)
rows = yield self._execute(
None, sql, app_id_and_pushkey[0], app_id_and_pushkey[1]
)
ret = [
{
"id": r[0],
"user_name": r[1],
"kind": r[2],
"instance_handle": r[3],
"app_id": r[4],
"app_display_name": r[5],
"device_display_name": r[6],
"pushkey": r[7],
"pushkey_ts": r[8],
"data": r[9],
"last_token": r[10],
"last_success": r[11],
"failing_since": r[12]
}
for r in rows
]
defer.returnValue(ret[0])
@defer.inlineCallbacks
def get_all_pushers(self):
sql = (
"SELECT id, user_name, kind, instance_handle, app_id,"
"app_display_name, device_display_name, pushkey, ts, data, "
"last_token, last_success, failing_since "
"FROM pushers"
)
rows = yield self._execute(None, sql)
ret = [
{
"id": r[0],
"user_name": r[1],
"kind": r[2],
"instance_handle": r[3],
"app_id": r[4],
"app_display_name": r[5],
"device_display_name": r[6],
"pushkey": r[7],
"pushkey_ts": r[8],
"data": r[9],
"last_token": r[10],
"last_success": r[11],
"failing_since": r[12]
}
for r in rows
]
defer.returnValue(ret)
@defer.inlineCallbacks
def add_pusher(self, user_name, instance_handle, kind, app_id,
app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data):
try:
yield self._simple_upsert(
PushersTable.table_name,
dict(
app_id=app_id,
pushkey=pushkey,
),
dict(
user_name=user_name,
kind=kind,
instance_handle=instance_handle,
app_display_name=app_display_name,
device_display_name=device_display_name,
ts=pushkey_ts,
lang=lang,
data=data
))
except Exception as e:
logger.error("create_pusher with failed: %s", e)
raise StoreError(500, "Problem creating pusher.")
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
yield self._simple_delete_one(
PushersTable.table_name,
dict(app_id=app_id, pushkey=pushkey)
)
@defer.inlineCallbacks
def update_pusher_last_token(self, user_name, pushkey, last_token):
yield self._simple_update_one(
PushersTable.table_name,
{'user_name': user_name, 'pushkey': pushkey},
{'last_token': last_token}
)
@defer.inlineCallbacks
def update_pusher_last_token_and_success(self, user_name, pushkey,
last_token, last_success):
yield self._simple_update_one(
PushersTable.table_name,
{'user_name': user_name, 'pushkey': pushkey},
{'last_token': last_token, 'last_success': last_success}
)
@defer.inlineCallbacks
def update_pusher_failing_since(self, user_name, pushkey, failing_since):
yield self._simple_update_one(
PushersTable.table_name,
{'user_name': user_name, 'pushkey': pushkey},
{'failing_since': failing_since}
)
class PushersTable(Table):
table_name = "pushers"
fields = [
"id",
"user_name",
"kind",
"instance_handle",
"app_id",
"app_display_name",
"device_display_name",
"pushkey",
"pushkey_ts",
"data",
"last_token",
"last_success",
"failing_since"
]
EntryType = collections.namedtuple("PusherEntry", fields)

View File

@ -122,7 +122,8 @@ class RegistrationStore(SQLBaseStore):
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token):
sql = ( sql = (
"SELECT users.name, users.admin, access_tokens.device_id" "SELECT users.name, users.admin,"
" access_tokens.device_id, access_tokens.id as token_id"
" FROM users" " FROM users"
" INNER JOIN access_tokens on users.id = access_tokens.user_id" " INNER JOIN access_tokens on users.id = access_tokens.user_id"
" WHERE token = ?" " WHERE token = ?"

View File

@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
import logging
logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
def _store_rejections_txn(self, txn, event_id, reason):
self._simple_insert_txn(
txn,
table="rejections",
values={
"event_id": event_id,
"reason": reason,
"last_check": self._clock.time_msec(),
}
)
def get_rejection_reason(self, event_id):
return self._simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={
"event_id": event_id,
},
allow_none=True,
)

View File

@ -0,0 +1,54 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS rejections(
event_id TEXT NOT NULL,
reason TEXT NOT NULL,
last_check TEXT NOT NULL,
CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
);
-- Push notification endpoints that users have configured
CREATE TABLE IF NOT EXISTS pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
instance_handle varchar(32) NOT NULL,
kind varchar(8) NOT NULL,
app_id varchar(64) NOT NULL,
app_display_name varchar(64) NOT NULL,
device_display_name varchar(128) NOT NULL,
pushkey blob NOT NULL,
ts BIGINT NOT NULL,
lang varchar(8),
data blob,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
FOREIGN KEY(user_name) REFERENCES users(name),
UNIQUE (app_id, pushkey)
);
CREATE TABLE IF NOT EXISTS push_rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
rule_id TEXT NOT NULL,
priority_class TINYINT NOT NULL,
priority INTEGER NOT NULL DEFAULT 0,
conditions TEXT NOT NULL,
actions TEXT NOT NULL,
UNIQUE(user_name, rule_id)
);
CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);

View File

@ -0,0 +1,24 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS user_filters(
user_id TEXT,
filter_id INTEGER,
filter_json TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);
CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
user_id, filter_id
);

View File

@ -0,0 +1,24 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS user_filters(
user_id TEXT,
filter_id INTEGER,
filter_json TEXT,
FOREIGN KEY(user_id) REFERENCES users(id)
);
CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
user_id, filter_id
);

View File

@ -0,0 +1,46 @@
/* Copyright 2014 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Push notification endpoints that users have configured
CREATE TABLE IF NOT EXISTS pushers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
instance_handle varchar(32) NOT NULL,
kind varchar(8) NOT NULL,
app_id varchar(64) NOT NULL,
app_display_name varchar(64) NOT NULL,
device_display_name varchar(128) NOT NULL,
pushkey blob NOT NULL,
ts BIGINT NOT NULL,
lang varchar(8),
data blob,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
FOREIGN KEY(user_name) REFERENCES users(name),
UNIQUE (app_id, pushkey)
);
CREATE TABLE IF NOT EXISTS push_rules (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_name TEXT NOT NULL,
rule_id TEXT NOT NULL,
priority_class TINYINT NOT NULL,
priority INTEGER NOT NULL DEFAULT 0,
conditions TEXT NOT NULL,
actions TEXT NOT NULL,
UNIQUE(user_name, rule_id)
);
CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);

View File

@ -0,0 +1,21 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS rejections(
event_id TEXT NOT NULL,
reason TEXT NOT NULL,
last_check TEXT NOT NULL,
CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
);

View File

@ -82,10 +82,10 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
def parse(cls, string): def parse(cls, string):
try: try:
if string[0] == 's': if string[0] == 's':
return cls(None, int(string[1:])) return cls(topological=None, stream=int(string[1:]))
if string[0] == 't': if string[0] == 't':
parts = string[1:].split('-', 1) parts = string[1:].split('-', 1)
return cls(int(parts[1]), int(parts[0])) return cls(topological=int(parts[0]), stream=int(parts[1]))
except: except:
pass pass
raise SynapseError(400, "Invalid token %r" % (string,)) raise SynapseError(400, "Invalid token %r" % (string,))
@ -94,7 +94,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
def parse_stream_token(cls, string): def parse_stream_token(cls, string):
try: try:
if string[0] == 's': if string[0] == 's':
return cls(None, int(string[1:])) return cls(topological=None, stream=int(string[1:]))
except: except:
pass pass
raise SynapseError(400, "Invalid token %r" % (string,)) raise SynapseError(400, "Invalid token %r" % (string,))
@ -181,8 +181,11 @@ class StreamStore(SQLBaseStore):
get_prev_content=True get_prev_content=True
) )
self._set_before_and_after(ret, rows)
if rows: if rows:
key = "s%d" % max([r["stream_ordering"] for r in rows]) key = "s%d" % max([r["stream_ordering"] for r in rows])
else: else:
# Assume we didn't get anything because there was nothing to # Assume we didn't get anything because there was nothing to
# get. # get.
@ -260,22 +263,44 @@ class StreamStore(SQLBaseStore):
get_prev_content=True get_prev_content=True
) )
self._set_before_and_after(events, rows)
return events, next_token, return events, next_token,
return self.runInteraction("paginate_room_events", f) return self.runInteraction("paginate_room_events", f)
def get_recent_events_for_room(self, room_id, limit, end_token, def get_recent_events_for_room(self, room_id, limit, end_token,
with_feedback=False): with_feedback=False, from_token=None):
# TODO (erikj): Handle compressed feedback # TODO (erikj): Handle compressed feedback
end_token = _StreamToken.parse_stream_token(end_token)
if from_token is None:
sql = ( sql = (
"SELECT stream_ordering, topological_ordering, event_id FROM events " "SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
" WHERE room_id = ? AND stream_ordering <= ? AND outlier = 0" " WHERE room_id = ? AND stream_ordering <= ? AND outlier = 0"
"ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ? " " ORDER BY topological_ordering DESC, stream_ordering DESC"
" LIMIT ?"
)
else:
from_token = _StreamToken.parse_stream_token(from_token)
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
" WHERE room_id = ? AND stream_ordering > ?"
" AND stream_ordering <= ? AND outlier = 0"
" ORDER BY topological_ordering DESC, stream_ordering DESC"
" LIMIT ?"
) )
def f(txn): def get_recent_events_for_room_txn(txn):
txn.execute(sql, (room_id, end_token, limit,)) if from_token is None:
txn.execute(sql, (room_id, end_token.stream, limit,))
else:
txn.execute(sql, (
room_id, from_token.stream, end_token.stream, limit
))
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
@ -291,9 +316,9 @@ class StreamStore(SQLBaseStore):
toke = rows[0]["stream_ordering"] - 1 toke = rows[0]["stream_ordering"] - 1
start_token = str(_StreamToken(topo, toke)) start_token = str(_StreamToken(topo, toke))
token = (start_token, end_token) token = (start_token, str(end_token))
else: else:
token = (end_token, end_token) token = (str(end_token), str(end_token))
events = self._get_events_txn( events = self._get_events_txn(
txn, txn,
@ -301,9 +326,13 @@ class StreamStore(SQLBaseStore):
get_prev_content=True get_prev_content=True
) )
self._set_before_and_after(events, rows)
return events, token return events, token
return self.runInteraction("get_recent_events_for_room", f) return self.runInteraction(
"get_recent_events_for_room", get_recent_events_for_room_txn
)
def get_room_events_max_id(self): def get_room_events_max_id(self):
return self.runInteraction( return self.runInteraction(
@ -325,3 +354,12 @@ class StreamStore(SQLBaseStore):
key = res[0]["m"] key = res[0]["m"]
return "s%d" % (key,) return "s%d" % (key,)
@staticmethod
def _set_before_and_after(events, rows):
for event, row in zip(events, rows):
stream = row["stream_ordering"]
topo = event.depth
internal = event.internal_metadata
internal.before = str(_StreamToken(topo, stream - 1))
internal.after = str(_StreamToken(topo, stream))

View File

@ -119,3 +119,6 @@ class StreamToken(
d = self._asdict() d = self._asdict()
d[key] = new_value d[key] = new_value
return StreamToken(**d) return StreamToken(**d)
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))

512
tests/api/test_filtering.py Normal file
View File

@ -0,0 +1,512 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from tests import unittest
from twisted.internet import defer
from mock import Mock, NonCallableMock
from tests.utils import (
MockHttpResource, MockClock, DeferredMockCallable, SQLiteMemoryDbPool,
MockKey
)
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.api.filtering import Filter
user_localpart = "test_user"
MockEvent = namedtuple("MockEvent", "sender type room_id")
class FilteringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
db_pool = SQLiteMemoryDbPool()
yield db_pool.prepare()
self.mock_config = NonCallableMock()
self.mock_config.signing_key = [MockKey()]
self.mock_federation_resource = MockHttpResource()
self.mock_http_client = Mock(spec=[])
self.mock_http_client.put_json = DeferredMockCallable()
hs = HomeServer("test",
db_pool=db_pool,
handlers=None,
http_client=self.mock_http_client,
config=self.mock_config,
keyring=Mock(),
)
self.filtering = hs.get_filtering()
self.filter = Filter({})
self.datastore = hs.get_datastore()
def test_definition_types_works_with_literals(self):
definition = {
"types": ["m.room.message", "org.matrix.foo.bar"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertTrue(
self.filter._passes_definition(definition, event)
)
def test_definition_types_works_with_wildcards(self):
definition = {
"types": ["m.*", "org.matrix.foo.bar"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertTrue(
self.filter._passes_definition(definition, event)
)
def test_definition_types_works_with_unknowns(self):
definition = {
"types": ["m.room.message", "org.matrix.foo.bar"]
}
event = MockEvent(
sender="@foo:bar",
type="now.for.something.completely.different",
room_id="!foo:bar"
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
def test_definition_not_types_works_with_literals(self):
definition = {
"not_types": ["m.room.message", "org.matrix.foo.bar"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!foo:bar"
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
def test_definition_not_types_works_with_wildcards(self):
definition = {
"not_types": ["m.room.message", "org.matrix.*"]
}
event = MockEvent(
sender="@foo:bar",
type="org.matrix.custom.event",
room_id="!foo:bar"
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
def test_definition_not_types_works_with_unknowns(self):
definition = {
"not_types": ["m.*", "org.*"]
}
event = MockEvent(
sender="@foo:bar",
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertTrue(
self.filter._passes_definition(definition, event)
)
def test_definition_not_types_takes_priority_over_types(self):
definition = {
"not_types": ["m.*", "org.*"],
"types": ["m.room.message", "m.room.topic"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.topic",
room_id="!foo:bar"
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
def test_definition_senders_works_with_literals(self):
definition = {
"senders": ["@flibble:wibble"]
}
event = MockEvent(
sender="@flibble:wibble",
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertTrue(
self.filter._passes_definition(definition, event)
)
def test_definition_senders_works_with_unknowns(self):
definition = {
"senders": ["@flibble:wibble"]
}
event = MockEvent(
sender="@challenger:appears",
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
def test_definition_not_senders_works_with_literals(self):
definition = {
"not_senders": ["@flibble:wibble"]
}
event = MockEvent(
sender="@flibble:wibble",
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
def test_definition_not_senders_works_with_unknowns(self):
definition = {
"not_senders": ["@flibble:wibble"]
}
event = MockEvent(
sender="@challenger:appears",
type="com.nom.nom.nom",
room_id="!foo:bar"
)
self.assertTrue(
self.filter._passes_definition(definition, event)
)
def test_definition_not_senders_takes_priority_over_senders(self):
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets", "@misspiggy:muppets"]
}
event = MockEvent(
sender="@misspiggy:muppets",
type="m.room.topic",
room_id="!foo:bar"
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
def test_definition_rooms_works_with_literals(self):
definition = {
"rooms": ["!secretbase:unknown"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!secretbase:unknown"
)
self.assertTrue(
self.filter._passes_definition(definition, event)
)
def test_definition_rooms_works_with_unknowns(self):
definition = {
"rooms": ["!secretbase:unknown"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!anothersecretbase:unknown"
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
def test_definition_not_rooms_works_with_literals(self):
definition = {
"not_rooms": ["!anothersecretbase:unknown"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!anothersecretbase:unknown"
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
def test_definition_not_rooms_works_with_unknowns(self):
definition = {
"not_rooms": ["!secretbase:unknown"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!anothersecretbase:unknown"
)
self.assertTrue(
self.filter._passes_definition(definition, event)
)
def test_definition_not_rooms_takes_priority_over_rooms(self):
definition = {
"not_rooms": ["!secretbase:unknown"],
"rooms": ["!secretbase:unknown"]
}
event = MockEvent(
sender="@foo:bar",
type="m.room.message",
room_id="!secretbase:unknown"
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
def test_definition_combined_event(self):
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
"rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"]
}
event = MockEvent(
sender="@kermit:muppets", # yup
type="m.room.message", # yup
room_id="!stage:unknown" # yup
)
self.assertTrue(
self.filter._passes_definition(definition, event)
)
def test_definition_combined_event_bad_sender(self):
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
"rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"]
}
event = MockEvent(
sender="@misspiggy:muppets", # nope
type="m.room.message", # yup
room_id="!stage:unknown" # yup
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
def test_definition_combined_event_bad_room(self):
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
"rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"]
}
event = MockEvent(
sender="@kermit:muppets", # yup
type="m.room.message", # yup
room_id="!piggyshouse:muppets" # nope
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
def test_definition_combined_event_bad_type(self):
definition = {
"not_senders": ["@misspiggy:muppets"],
"senders": ["@kermit:muppets"],
"rooms": ["!stage:unknown"],
"not_rooms": ["!piggyshouse:muppets"],
"types": ["m.room.message", "muppets.kermit.*"],
"not_types": ["muppets.misspiggy.*"]
}
event = MockEvent(
sender="@kermit:muppets", # yup
type="muppets.misspiggy.kisses", # nope
room_id="!stage:unknown" # yup
)
self.assertFalse(
self.filter._passes_definition(definition, event)
)
@defer.inlineCallbacks
def test_filter_public_user_data_match(self):
user_filter_json = {
"public_user_data": {
"types": ["m.*"]
}
}
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter_json,
)
event = MockEvent(
sender="@foo:bar",
type="m.profile",
room_id="!foo:bar"
)
events = [event]
user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart,
filter_id=filter_id,
)
results = user_filter.filter_public_user_data(events=events)
self.assertEquals(events, results)
@defer.inlineCallbacks
def test_filter_public_user_data_no_match(self):
user_filter_json = {
"public_user_data": {
"types": ["m.*"]
}
}
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter_json,
)
event = MockEvent(
sender="@foo:bar",
type="custom.avatar.3d.crazy",
room_id="!foo:bar"
)
events = [event]
user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart,
filter_id=filter_id,
)
results = user_filter.filter_public_user_data(events=events)
self.assertEquals([], results)
@defer.inlineCallbacks
def test_filter_room_state_match(self):
user_filter_json = {
"room": {
"state": {
"types": ["m.*"]
}
}
}
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter_json,
)
event = MockEvent(
sender="@foo:bar",
type="m.room.topic",
room_id="!foo:bar"
)
events = [event]
user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart,
filter_id=filter_id,
)
results = user_filter.filter_room_state(events=events)
self.assertEquals(events, results)
@defer.inlineCallbacks
def test_filter_room_state_no_match(self):
user_filter_json = {
"room": {
"state": {
"types": ["m.*"]
}
}
}
user = UserID.from_string("@" + user_localpart + ":test")
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter_json,
)
event = MockEvent(
sender="@foo:bar",
type="org.matrix.custom.event",
room_id="!foo:bar"
)
events = [event]
user_filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart,
filter_id=filter_id,
)
results = user_filter.filter_room_state(events)
self.assertEquals([], results)
@defer.inlineCallbacks
def test_add_filter(self):
user_filter_json = {
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.filtering.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter_json,
)
self.assertEquals(filter_id, 0)
self.assertEquals(user_filter_json,
(yield self.datastore.get_user_filter(
user_localpart=user_localpart,
filter_id=0,
))
)
@defer.inlineCallbacks
def test_get_filter(self):
user_filter_json = {
"room": {
"state": {
"types": ["m.*"]
}
}
}
filter_id = yield self.datastore.add_user_filter(
user_localpart=user_localpart,
user_filter=user_filter_json,
)
filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart,
filter_id=filter_id,
)
self.assertEquals(filter.filter_json, user_filter_json)

View File

@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
"get_room", "get_room",
"get_destination_retry_timings", "get_destination_retry_timings",
"set_destination_retry_timings", "set_destination_retry_timings",
"have_events",
]), ]),
resource_for_federation=NonCallableMock(), resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]), http_client=NonCallableMock(spec_set=[]),
@ -90,6 +91,7 @@ class FederationTestCase(unittest.TestCase):
self.datastore.persist_event.return_value = defer.succeed(None) self.datastore.persist_event.return_value = defer.succeed(None)
self.datastore.get_room.return_value = defer.succeed(True) self.datastore.get_room.return_value = defer.succeed(True)
self.auth.check_host_in_room.return_value = defer.succeed(True) self.auth.check_host_in_room.return_value = defer.succeed(True)
self.datastore.have_events.return_value = defer.succeed({})
def annotate(ev, old_state=None): def annotate(ev, old_state=None):
context = Mock() context = Mock()

View File

@ -75,6 +75,7 @@ class PresenceStateTestCase(unittest.TestCase):
"user": UserID.from_string(myid), "user": UserID.from_string(myid),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -165,6 +166,7 @@ class PresenceListTestCase(unittest.TestCase):
"user": UserID.from_string(myid), "user": UserID.from_string(myid),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.handlers.room_member_handler = Mock( hs.handlers.room_member_handler = Mock(
@ -282,7 +284,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
hs.get_clock().time_msec.return_value = 1000000 hs.get_clock().time_msec.return_value = 1000000
def _get_user_by_req(req=None): def _get_user_by_req(req=None):
return UserID.from_string(myid) return (UserID.from_string(myid), "")
hs.get_auth().get_user_by_req = _get_user_by_req hs.get_auth().get_user_by_req = _get_user_by_req

View File

@ -58,7 +58,7 @@ class ProfileTestCase(unittest.TestCase):
) )
def _get_user_by_req(request=None): def _get_user_by_req(request=None):
return UserID.from_string(myid) return (UserID.from_string(myid), "")
hs.get_auth().get_user_by_req = _get_user_by_req hs.get_auth().get_user_by_req = _get_user_by_req

View File

@ -70,6 +70,7 @@ class RoomPermissionsTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -466,6 +467,7 @@ class RoomsMemberListTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -555,6 +557,7 @@ class RoomsCreateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -657,6 +660,7 @@ class RoomTopicTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -773,6 +777,7 @@ class RoomMemberStateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -909,6 +914,7 @@ class RoomMessagesTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
@ -1013,6 +1019,7 @@ class RoomInitialSyncTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token

View File

@ -73,6 +73,7 @@ class RoomTypingTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id), "user": UserID.from_string(self.auth_user_id),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token

View File

@ -39,9 +39,7 @@ class V2AlphaRestTestCase(unittest.TestCase):
hs = HomeServer("test", hs = HomeServer("test",
db_pool=None, db_pool=None,
datastore=Mock(spec=[ datastore=self.make_datastore_mock(),
"insert_client_ip",
]),
http_client=None, http_client=None,
resource_for_client=self.mock_resource, resource_for_client=self.mock_resource,
resource_for_federation=self.mock_resource, resource_for_federation=self.mock_resource,
@ -53,8 +51,14 @@ class V2AlphaRestTestCase(unittest.TestCase):
"user": UserID.from_string(self.USER_ID), "user": UserID.from_string(self.USER_ID),
"admin": False, "admin": False,
"device_id": None, "device_id": None,
"token_id": 1,
} }
hs.get_auth().get_user_by_token = _get_user_by_token hs.get_auth().get_user_by_token = _get_user_by_token
for r in self.TO_REGISTER: for r in self.TO_REGISTER:
r.register_servlets(hs, self.mock_resource) r.register_servlets(hs, self.mock_resource)
def make_datastore_mock(self):
return Mock(spec=[
"insert_client_ip",
])

View File

@ -0,0 +1,95 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from mock import Mock
from . import V2AlphaRestTestCase
from synapse.rest.client.v2_alpha import filter
from synapse.api.errors import StoreError
class FilterTestCase(V2AlphaRestTestCase):
USER_ID = "@apple:test"
TO_REGISTER = [filter]
def make_datastore_mock(self):
datastore = super(FilterTestCase, self).make_datastore_mock()
self._user_filters = {}
def add_user_filter(user_localpart, definition):
filters = self._user_filters.setdefault(user_localpart, [])
filter_id = len(filters)
filters.append(definition)
return defer.succeed(filter_id)
datastore.add_user_filter = add_user_filter
def get_user_filter(user_localpart, filter_id):
if user_localpart not in self._user_filters:
raise StoreError(404, "No user")
filters = self._user_filters[user_localpart]
if filter_id >= len(filters):
raise StoreError(404, "No filter")
return defer.succeed(filters[filter_id])
datastore.get_user_filter = get_user_filter
return datastore
@defer.inlineCallbacks
def test_add_filter(self):
(code, response) = yield self.mock_resource.trigger("POST",
"/user/%s/filter" % (self.USER_ID),
'{"type": ["m.*"]}'
)
self.assertEquals(200, code)
self.assertEquals({"filter_id": "0"}, response)
self.assertIn("apple", self._user_filters)
self.assertEquals(len(self._user_filters["apple"]), 1)
self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0])
@defer.inlineCallbacks
def test_get_filter(self):
self._user_filters["apple"] = [
{"type": ["m.*"]}
]
(code, response) = yield self.mock_resource.trigger("GET",
"/user/%s/filter/0" % (self.USER_ID), None
)
self.assertEquals(200, code)
self.assertEquals({"type": ["m.*"]}, response)
@defer.inlineCallbacks
def test_get_filter_no_id(self):
self._user_filters["apple"] = [
{"type": ["m.*"]}
]
(code, response) = yield self.mock_resource.trigger("GET",
"/user/%s/filter/2" % (self.USER_ID), None
)
self.assertEquals(404, code)
@defer.inlineCallbacks
def test_get_filter_no_user(self):
(code, response) = yield self.mock_resource.trigger("GET",
"/user/%s/filter/0" % (self.USER_ID), None
)
self.assertEquals(404, code)

View File

@ -53,7 +53,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
) )
self.assertEquals( self.assertEquals(
{"admin": 0, "device_id": None, "name": self.user_id}, {"admin": 0,
"device_id": None,
"name": self.user_id,
"token_id": 1},
(yield self.store.get_user_by_token(self.tokens[0])) (yield self.store.get_user_by_token(self.tokens[0]))
) )
@ -63,7 +66,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
self.assertEquals( self.assertEquals(
{"admin": 0, "device_id": None, "name": self.user_id}, {"admin": 0,
"device_id": None,
"name": self.user_id,
"token_id": 2},
(yield self.store.get_user_by_token(self.tokens[1])) (yield self.store.get_user_by_token(self.tokens[1]))
) )

View File

@ -16,11 +16,120 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.events import FrozenEvent
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.state import StateHandler from synapse.state import StateHandler
from mock import Mock from mock import Mock
_next_event_id = 1000
def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
prev_events=[], **kwargs):
global _next_event_id
if not event_id:
_next_event_id += 1
event_id = str(_next_event_id)
if not name:
if state_key is not None:
name = "<%s-%s, %s>" % (type, state_key, event_id,)
else:
name = "<%s, %s>" % (type, event_id,)
d = {
"event_id": event_id,
"type": type,
"sender": "@user_id:example.com",
"room_id": "!room_id:example.com",
"depth": depth,
"prev_events": prev_events,
}
if state_key is not None:
d["state_key"] = state_key
d.update(kwargs)
event = FrozenEvent(d)
return event
class StateGroupStore(object):
def __init__(self):
self._event_to_state_group = {}
self._group_to_state = {}
self._next_group = 1
def get_state_groups(self, event_ids):
groups = {}
for event_id in event_ids:
group = self._event_to_state_group.get(event_id)
if group:
groups[group] = self._group_to_state[group]
return defer.succeed(groups)
def store_state_groups(self, event, context):
if context.current_state is None:
return
state_events = context.current_state
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = context.state_group
if not state_group:
state_group = self._next_group
self._next_group += 1
self._group_to_state[state_group] = state_events.values()
self._event_to_state_group[event.event_id] = state_group
class DictObj(dict):
def __init__(self, **kwargs):
super(DictObj, self).__init__(kwargs)
self.__dict__ = self
class Graph(object):
def __init__(self, nodes, edges):
events = {}
clobbered = set(events.keys())
for event_id, fields in nodes.items():
refs = edges.get(event_id)
if refs:
clobbered.difference_update(refs)
prev_events = [(r, {}) for r in refs]
else:
prev_events = []
events[event_id] = create_event(
event_id=event_id,
prev_events=prev_events,
**fields
)
self._leaves = clobbered
self._events = sorted(events.values(), key=lambda e: e.depth)
def walk(self):
return iter(self._events)
def get_leaves(self):
return (self._events[i] for i in self._leaves)
class StateTestCase(unittest.TestCase): class StateTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.store = Mock( self.store = Mock(
@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase):
"add_event_hashes", "add_event_hashes",
] ]
) )
hs = Mock(spec=["get_datastore"]) hs = Mock(spec=["get_datastore", "get_auth", "get_state_handler"])
hs.get_datastore.return_value = self.store hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_auth.return_value = Auth(hs)
self.state = StateHandler(hs) self.state = StateHandler(hs)
self.event_id = 0 self.event_id = 0
@defer.inlineCallbacks
def test_branch_no_conflict(self):
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create,
state_key="",
depth=1,
),
"A": DictObj(
type=EventTypes.Message,
depth=2,
),
"B": DictObj(
type=EventTypes.Message,
depth=3,
),
"C": DictObj(
type=EventTypes.Name,
state_key="",
depth=3,
),
"D": DictObj(
type=EventTypes.Message,
depth=4,
),
},
edges={
"A": ["START"],
"B": ["A"],
"C": ["A"],
"D": ["B", "C"]
}
)
store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context)
context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].current_state))
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create,
state_key="creator",
content={"membership": "@user_id:example.com"},
depth=1,
),
"A": DictObj(
type=EventTypes.Member,
state_key="@user_id:example.com",
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
depth=2,
),
"B": DictObj(
type=EventTypes.Name,
state_key="",
depth=3,
),
"C": DictObj(
type=EventTypes.Name,
state_key="",
depth=4,
),
"D": DictObj(
type=EventTypes.Message,
depth=5,
),
},
edges={
"A": ["START"],
"B": ["A"],
"C": ["A"],
"D": ["B", "C"]
}
)
store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
{"START", "A", "C"},
{e.event_id for e in context_store["D"].current_state.values()}
)
@defer.inlineCallbacks
def test_branch_have_banned_conflict(self):
graph = Graph(
nodes={
"START": DictObj(
type=EventTypes.Create,
state_key="creator",
content={"membership": "@user_id:example.com"},
depth=1,
),
"A": DictObj(
type=EventTypes.Member,
state_key="@user_id:example.com",
content={"membership": Membership.JOIN},
membership=Membership.JOIN,
depth=2,
),
"B": DictObj(
type=EventTypes.Name,
state_key="",
depth=3,
),
"C": DictObj(
type=EventTypes.Member,
state_key="@user_id_2:example.com",
content={"membership": Membership.BAN},
membership=Membership.BAN,
depth=4,
),
"D": DictObj(
type=EventTypes.Name,
state_key="",
depth=4,
sender="@user_id_2:example.com",
),
"E": DictObj(
type=EventTypes.Message,
depth=5,
),
},
edges={
"A": ["START"],
"B": ["A"],
"C": ["B"],
"D": ["B"],
"E": ["C", "D"]
}
)
store = StateGroupStore()
self.store.get_state_groups.side_effect = store.get_state_groups
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
store.store_state_groups(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
{"START", "A", "B", "C"},
{e.event_id for e in context_store["E"].current_state.values()}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_message(self): def test_annotate_with_old_message(self):
event = self.create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
context = yield self.state.compute_event_context( context = yield self.state.compute_event_context(
@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_annotate_with_old_state(self): def test_annotate_with_old_state(self):
event = self.create_event(type="state", state_key="", name="event") event = create_event(type="state", state_key="", name="event")
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
context = yield self.state.compute_event_context( context = yield self.state.compute_event_context(
@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_message(self): def test_trivial_annotate_message(self):
event = self.create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
event.prev_events = []
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
group_name = "group_name_1" group_name = "group_name_1"
@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_trivial_annotate_state(self): def test_trivial_annotate_state(self):
event = self.create_event(type="state", state_key="", name="event") event = create_event(type="state", state_key="", name="event")
event.prev_events = []
old_state = [ old_state = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
group_name = "group_name_1" group_name = "group_name_1"
@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_message_conflict(self): def test_resolve_message_conflict(self):
event = self.create_event(type="test_message", name="event") event = create_event(type="test_message", name="event")
event.prev_events = []
old_state_1 = [ old_state_1 = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
old_state_2 = [ old_state_2 = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test3", state_key="2"), create_event(type="test3", state_key="2"),
self.create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
group_name_1 = "group_name_1" context = yield self._get_context(event, old_state_1, old_state_2)
group_name_2 = "group_name_2"
self.store.get_state_groups.return_value = {
group_name_1: old_state_1,
group_name_2: old_state_2,
}
context = yield self.state.compute_event_context(event)
self.assertEqual(len(context.current_state), 5) self.assertEqual(len(context.current_state), 5)
@ -181,21 +447,70 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_resolve_state_conflict(self): def test_resolve_state_conflict(self):
event = self.create_event(type="test4", state_key="", name="event") event = create_event(type="test4", state_key="", name="event")
event.prev_events = []
old_state_1 = [ old_state_1 = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test1", state_key="2"), create_event(type="test1", state_key="2"),
self.create_event(type="test2", state_key=""), create_event(type="test2", state_key=""),
] ]
old_state_2 = [ old_state_2 = [
self.create_event(type="test1", state_key="1"), create_event(type="test1", state_key="1"),
self.create_event(type="test3", state_key="2"), create_event(type="test3", state_key="2"),
self.create_event(type="test4", state_key=""), create_event(type="test4", state_key=""),
] ]
context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(len(context.current_state), 5)
self.assertIsNone(context.state_group)
@defer.inlineCallbacks
def test_standard_depth_conflict(self):
event = create_event(type="test4", name="event")
member_event = create_event(
type=EventTypes.Member,
state_key="@user_id:example.com",
content={
"membership": Membership.JOIN,
}
)
old_state_1 = [
member_event,
create_event(type="test1", state_key="1", depth=1),
]
old_state_2 = [
member_event,
create_event(type="test1", state_key="1", depth=2),
]
context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
# Reverse the depth to make sure we are actually using the depths
# during state resolution.
old_state_1 = [
member_event,
create_event(type="test1", state_key="1", depth=2),
]
old_state_2 = [
member_event,
create_event(type="test1", state_key="1", depth=1),
]
context = yield self._get_context(event, old_state_1, old_state_2)
self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
def _get_context(self, event, old_state_1, old_state_2):
group_name_1 = "group_name_1" group_name_1 = "group_name_1"
group_name_2 = "group_name_2" group_name_2 = "group_name_2"
@ -204,33 +519,4 @@ class StateTestCase(unittest.TestCase):
group_name_2: old_state_2, group_name_2: old_state_2,
} }
context = yield self.state.compute_event_context(event) return self.state.compute_event_context(event)
self.assertEqual(len(context.current_state), 5)
self.assertIsNone(context.state_group)
def create_event(self, name=None, type=None, state_key=None):
self.event_id += 1
event_id = str(self.event_id)
if not name:
if state_key is not None:
name = "<%s-%s>" % (type, state_key)
else:
name = "<%s>" % (type, )
event = Mock(name=name, spec=[])
event.type = type
if state_key is not None:
event.state_key = state_key
event.event_id = event_id
event.is_state = lambda: (state_key is not None)
event.unsigned = {}
event.user_id = "@user_id:example.com"
event.room_id = "!room_id:example.com"
return event