Merge branch 'develop' into markjh/twisted-15

Conflicts:
	synapse/http/matrixfederationclient.py
This commit is contained in:
Mark Haines 2015-08-12 17:07:22 +01:00
commit 998a72d4d9
96 changed files with 4528 additions and 1677 deletions

View File

@ -38,3 +38,10 @@ Brabo <brabo at riseup.net>
Ivan Shapovalov <intelfx100 at gmail.com> Ivan Shapovalov <intelfx100 at gmail.com>
* contrib/systemd: a sample systemd unit file and a logger configuration * contrib/systemd: a sample systemd unit file and a logger configuration
Eric Myhre <hash at exultant.us>
* Fix bug where ``media_store_path`` config option was ignored by v0 content
repository API.
Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
* Add SAML2 support for registration and logins.

View File

@ -1,3 +1,54 @@
Changes in synapse v0.9.3 (2015-07-01)
======================================
No changes from v0.9.3 Release Candidate 1.
Changes in synapse v0.9.3-rc1 (2015-06-23)
==========================================
General:
* Fix a memory leak in the notifier. (SYN-412)
* Improve performance of room initial sync. (SYN-418)
* General improvements to logging.
* Remove ``access_token`` query params from ``INFO`` level logging.
Configuration:
* Add support for specifying and configuring multiple listeners. (SYN-389)
Application services:
* Fix bug where synapse failed to send user queries to application services.
Changes in synapse v0.9.2-r2 (2015-06-15)
=========================================
Fix packaging so that schema delta python files get included in the package.
Changes in synapse v0.9.2 (2015-06-12)
======================================
General:
* Use ultrajson for json (de)serialisation when a canonical encoding is not
required. Ultrajson is significantly faster than simplejson in certain
circumstances.
* Use connection pools for outgoing HTTP connections.
* Process thumbnails on separate threads.
Configuration:
* Add option, ``gzip_responses``, to disable HTTP response compression.
Federation:
* Improve resilience of backfill by ensuring we fetch any missing auth events.
* Improve performance of backfill and joining remote rooms by removing
unnecessary computations. This included handling events we'd previously
handled as well as attempting to compute the current state for outliers.
Changes in synapse v0.9.1 (2015-05-26) Changes in synapse v0.9.1 (2015-05-26)
====================================== ======================================

View File

@ -5,6 +5,7 @@ include *.rst
include demo/README include demo/README
recursive-include synapse/storage/schema *.sql recursive-include synapse/storage/schema *.sql
recursive-include synapse/storage/schema *.py
recursive-include demo *.dh recursive-include demo *.dh
recursive-include demo *.py recursive-include demo *.py

View File

@ -101,36 +101,40 @@ header files for python C extensions.
Installing prerequisites on Ubuntu or Debian:: Installing prerequisites on Ubuntu or Debian::
$ sudo apt-get install build-essential python2.7-dev libffi-dev \ sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \ python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev libssl-dev python-virtualenv libjpeg-dev
Installing prerequisites on ArchLinux:: Installing prerequisites on ArchLinux::
$ sudo pacman -S base-devel python2 python-pip \ sudo pacman -S base-devel python2 python-pip \
python-setuptools python-virtualenv sqlite3 python-setuptools python-virtualenv sqlite3
Installing prerequisites on Mac OS X:: Installing prerequisites on Mac OS X::
$ xcode-select --install xcode-select --install
$ sudo pip install virtualenv sudo easy_install pip
sudo pip install virtualenv
To install the synapse homeserver run:: To install the synapse homeserver run::
$ virtualenv -p python2.7 ~/.synapse virtualenv -p python2.7 ~/.synapse
$ source ~/.synapse/bin/activate source ~/.synapse/bin/activate
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
This installs synapse, along with the libraries it uses, into a virtual This installs synapse, along with the libraries it uses, into a virtual
environment under ``~/.synapse``. environment under ``~/.synapse``. Feel free to pick a different directory
if you prefer.
In case of problems, please see the _Troubleshooting section below.
Alternatively, Silvio Fricke has contributed a Dockerfile to automate the Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/. above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
To set up your homeserver, run (in your virtualenv, as before):: To set up your homeserver, run (in your virtualenv, as before)::
$ cd ~/.synapse cd ~/.synapse
$ python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
@ -189,9 +193,9 @@ Running Synapse
To actually run your new homeserver, pick a working directory for Synapse to run To actually run your new homeserver, pick a working directory for Synapse to run
(e.g. ``~/.synapse``), and:: (e.g. ``~/.synapse``), and::
$ cd ~/.synapse cd ~/.synapse
$ source ./bin/activate source ./bin/activate
$ synctl start synctl start
Platform Specific Instructions Platform Specific Instructions
============================== ==============================
@ -209,12 +213,12 @@ defaults to python 3, but synapse currently assumes python 2.7 by default:
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 ):: pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
$ sudo pip2.7 install --upgrade pip sudo pip2.7 install --upgrade pip
You also may need to explicitly specify python 2.7 again during the install You also may need to explicitly specify python 2.7 again during the install
request:: request::
$ pip2.7 install --process-dependency-links \ pip2.7 install --process-dependency-links \
https://github.com/matrix-org/synapse/tarball/master https://github.com/matrix-org/synapse/tarball/master
If you encounter an error with lib bcrypt causing an Wrong ELF Class: If you encounter an error with lib bcrypt causing an Wrong ELF Class:
@ -222,13 +226,13 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
compile it under the right architecture. (This should not be needed if compile it under the right architecture. (This should not be needed if
installing under virtualenv):: installing under virtualenv)::
$ sudo pip2.7 uninstall py-bcrypt sudo pip2.7 uninstall py-bcrypt
$ sudo pip2.7 install py-bcrypt sudo pip2.7 install py-bcrypt
During setup of Synapse you need to call python2.7 directly again:: During setup of Synapse you need to call python2.7 directly again::
$ cd ~/.synapse cd ~/.synapse
$ python2.7 -m synapse.app.homeserver \ python2.7 -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
@ -276,22 +280,22 @@ Synapse requires pip 1.7 or later, so if your OS provides too old a version and
you get errors about ``error: no such option: --process-dependency-links`` you you get errors about ``error: no such option: --process-dependency-links`` you
may need to manually upgrade it:: may need to manually upgrade it::
$ sudo pip install --upgrade pip sudo pip install --upgrade pip
If pip crashes mid-installation for reason (e.g. lost terminal), pip may If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it refuse to run until you remove the temporary installation directory it
created. To reset the installation:: created. To reset the installation::
$ rm -rf /tmp/pip_install_matrix rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are happens, you will have to individually install the dependencies which are
failing, e.g.:: failing, e.g.::
$ pip install twisted pip install twisted
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you On OS X, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments. will need to export CFLAGS=-Qunused-arguments.
Troubleshooting Running Troubleshooting Running
@ -307,10 +311,11 @@ correctly, causing all tests to fail with errors about missing "sodium.h". To
fix try re-installing from PyPI or directly from fix try re-installing from PyPI or directly from
(https://github.com/pyca/pynacl):: (https://github.com/pyca/pynacl)::
$ # Install from PyPI # Install from PyPI
$ pip install --user --upgrade --force pynacl pip install --user --upgrade --force pynacl
$ # Install from github
$ pip install --user https://github.com/pyca/pynacl/tarball/master # Install from github
pip install --user https://github.com/pyca/pynacl/tarball/master
ArchLinux ArchLinux
~~~~~~~~~ ~~~~~~~~~
@ -318,7 +323,7 @@ ArchLinux
If running `$ synctl start` fails with 'returned non-zero exit status 1', If running `$ synctl start` fails with 'returned non-zero exit status 1',
you will need to explicitly call Python2.7 - either running as:: you will need to explicitly call Python2.7 - either running as::
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
...or by editing synctl with the correct python executable. ...or by editing synctl with the correct python executable.
@ -328,16 +333,16 @@ Synapse Development
To check out a synapse for development, clone the git repo into a working To check out a synapse for development, clone the git repo into a working
directory of your choice:: directory of your choice::
$ git clone https://github.com/matrix-org/synapse.git git clone https://github.com/matrix-org/synapse.git
$ cd synapse cd synapse
Synapse has a number of external dependencies, that are easiest Synapse has a number of external dependencies, that are easiest
to install using pip and a virtualenv:: to install using pip and a virtualenv::
$ virtualenv env virtualenv env
$ source env/bin/activate source env/bin/activate
$ python synapse/python_dependencies.py | xargs -n1 pip install python synapse/python_dependencies.py | xargs -n1 pip install
$ pip install setuptools_trial mock pip install setuptools_trial mock
This will run a process of downloading and installing all the needed This will run a process of downloading and installing all the needed
dependencies into a virtual env. dependencies into a virtual env.
@ -345,7 +350,7 @@ dependencies into a virtual env.
Once this is done, you may wish to run Synapse's unit tests, to Once this is done, you may wish to run Synapse's unit tests, to
check that everything is installed as it should be:: check that everything is installed as it should be::
$ python setup.py test python setup.py test
This should end with a 'PASSED' result:: This should end with a 'PASSED' result::
@ -386,11 +391,11 @@ IDs:
For the first form, simply pass the required hostname (of the machine) as the For the first form, simply pass the required hostname (of the machine) as the
--server-name parameter:: --server-name parameter::
$ python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
$ python -m synapse.app.homeserver --config-path homeserver.yaml python -m synapse.app.homeserver --config-path homeserver.yaml
Alternatively, you can run ``synctl start`` to guide you through the process. Alternatively, you can run ``synctl start`` to guide you through the process.
@ -407,11 +412,11 @@ record would then look something like::
At this point, you should then run the homeserver with the hostname of this At this point, you should then run the homeserver with the hostname of this
SRV record, as that is the name other machines will expect it to have:: SRV record, as that is the name other machines will expect it to have::
$ python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--server-name YOURDOMAIN \ --server-name YOURDOMAIN \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
$ python -m synapse.app.homeserver --config-path homeserver.yaml python -m synapse.app.homeserver --config-path homeserver.yaml
You may additionally want to pass one or more "-v" options, in order to You may additionally want to pass one or more "-v" options, in order to
@ -425,7 +430,7 @@ private federation (``localhost:8080``, ``localhost:8081`` and
``localhost:8082``) which you can then access through the webclient running at ``localhost:8082``) which you can then access through the webclient running at
http://localhost:8080. Simply run:: http://localhost:8080. Simply run::
$ demo/start.sh demo/start.sh
This is mainly useful just for development purposes. This is mainly useful just for development purposes.
@ -499,10 +504,10 @@ Building Internal API Documentation
Before building internal API documentation install sphinx and Before building internal API documentation install sphinx and
sphinxcontrib-napoleon:: sphinxcontrib-napoleon::
$ pip install sphinx pip install sphinx
$ pip install sphinxcontrib-napoleon pip install sphinxcontrib-napoleon
Building internal API documentation:: Building internal API documentation::
$ python setup.py build_sphinx python setup.py build_sphinx

View File

@ -11,7 +11,9 @@ if [ -f $PID_FILE ]; then
exit 1 exit 1
fi fi
find "$DIR" -name "*.log" -delete for port in 8080 8081 8082; do
find "$DIR" -name "*.db" -delete rm -rf $DIR/$port
rm -rf $DIR/media_store.$port
done
rm -rf $DIR/etc rm -rf $DIR/etc

View File

@ -8,14 +8,6 @@ cd "$DIR/.."
mkdir -p demo/etc mkdir -p demo/etc
# Check the --no-rate-limit param
PARAMS=""
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
fi
fi
export PYTHONPATH=$(readlink -f $(pwd)) export PYTHONPATH=$(readlink -f $(pwd))
@ -35,6 +27,15 @@ for port in 8080 8081 8082; do
-H "localhost:$https_port" \ -H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \
# Check script parameters
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
# Set high limits in config file to disable rate limiting
perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
fi
fi
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \
-D \ -D \

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.9.1" __version__ = "0.9.3"

View File

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
AuthEventTypes = ( AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels, EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules, EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
) )
@ -44,6 +44,11 @@ class Auth(object):
def check(self, event, auth_events): def check(self, event, auth_events):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Args:
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
Returns: Returns:
True if the auth checks pass. True if the auth checks pass.
""" """
@ -187,6 +192,9 @@ class Auth(object):
join_rule = JoinRules.INVITE join_rule = JoinRules.INVITE
user_level = self._get_user_power_level(event.user_id, auth_events) user_level = self._get_user_power_level(event.user_id, auth_events)
target_level = self._get_user_power_level(
target_user_id, auth_events
)
# FIXME (erikj): What should we do here as the default? # FIXME (erikj): What should we do here as the default?
ban_level = self._get_named_level(auth_events, "ban", 50) ban_level = self._get_named_level(auth_events, "ban", 50)
@ -258,12 +266,12 @@ class Auth(object):
elif target_user_id != event.user_id: elif target_user_id != event.user_id:
kick_level = self._get_named_level(auth_events, "kick", 50) kick_level = self._get_named_level(auth_events, "kick", 50)
if user_level < kick_level: if user_level < kick_level or user_level <= target_level:
raise AuthError( raise AuthError(
403, "You cannot kick user %s." % target_user_id 403, "You cannot kick user %s." % target_user_id
) )
elif Membership.BAN == membership: elif Membership.BAN == membership:
if user_level < ban_level: if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban") raise AuthError(403, "You don't have permission to ban")
else: else:
raise AuthError(500, "Unknown membership %s" % membership) raise AuthError(500, "Unknown membership %s" % membership)
@ -316,7 +324,7 @@ class Auth(object):
Returns: Returns:
tuple : of UserID and device string: tuple : of UserID and device string:
User ID object of the user making the request User ID object of the user making the request
Client ID object of the client instance the user is using ClientInfo object of the client instance the user is using
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
@ -349,7 +357,7 @@ class Auth(object):
) )
return return
except KeyError: except KeyError:
pass # normal users won't have this query parameter set pass # normal users won't have the user_id query parameter set.
user_info = yield self.get_user_by_token(access_token) user_info = yield self.get_user_by_token(access_token)
user = user_info["user"] user = user_info["user"]
@ -370,6 +378,8 @@ class Auth(object):
user_agent=user_agent user_agent=user_agent
) )
request.authenticated_entity = user.to_string()
defer.returnValue((user, ClientInfo(device_id, token_id))) defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError: except KeyError:
raise AuthError( raise AuthError(
@ -516,7 +526,6 @@ class Auth(object):
# Check state_key # Check state_key
if hasattr(event, "state_key"): if hasattr(event, "state_key"):
if not event.state_key.startswith("_"):
if event.state_key.startswith("@"): if event.state_key.startswith("@"):
if event.state_key != event.user_id: if event.state_key != event.user_id:
raise AuthError( raise AuthError(
@ -571,25 +580,26 @@ class Auth(object):
# Check other levels: # Check other levels:
levels_to_check = [ levels_to_check = [
("users_default", []), ("users_default", None),
("events_default", []), ("events_default", None),
("ban", []), ("state_default", None),
("redact", []), ("ban", None),
("kick", []), ("redact", None),
("invite", []), ("kick", None),
("invite", None),
] ]
old_list = current_state.content.get("users") old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()): for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append( levels_to_check.append(
(user, ["users"]) (user, "users")
) )
old_list = current_state.content.get("events") old_list = current_state.content.get("events")
new_list = event.content.get("events") new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()): for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append( levels_to_check.append(
(ev_id, ["events"]) (ev_id, "events")
) )
old_state = current_state.content old_state = current_state.content
@ -597,12 +607,10 @@ class Auth(object):
for level_to_check, dir in levels_to_check: for level_to_check, dir in levels_to_check:
old_loc = old_state old_loc = old_state
for d in dir:
old_loc = old_loc.get(d, {})
new_loc = new_state new_loc = new_state
for d in dir: if dir:
new_loc = new_loc.get(d, {}) old_loc = old_loc.get(dir, {})
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc: if level_to_check in old_loc:
old_level = int(old_loc[level_to_check]) old_level = int(old_loc[level_to_check])
@ -618,6 +626,14 @@ class Auth(object):
if new_level == old_level: if new_level == old_level:
continue continue
if dir == "users" and level_to_check != event.user_id:
if old_level == user_level:
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
)
if old_level > user_level or new_level > user_level: if old_level > user_level or new_level > user_level:
raise AuthError( raise AuthError(
403, 403,

View File

@ -75,6 +75,8 @@ class EventTypes(object):
Redaction = "m.room.redaction" Redaction = "m.room.redaction"
Feedback = "m.room.message.feedback" Feedback = "m.room.message.feedback"
RoomHistoryVisibility = "m.room.history_visibility"
# These are used for validation # These are used for validation
Message = "m.room.message" Message = "m.room.message"
Topic = "m.room.topic" Topic = "m.room.topic"
@ -85,3 +87,8 @@ class RejectedReason(object):
AUTH_ERROR = "auth_error" AUTH_ERROR = "auth_error"
REPLACED = "replaced" REPLACED = "replaced"
NOT_ANCESTOR = "not_ancestor" NOT_ANCESTOR = "not_ancestor"
class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat"

View File

@ -34,8 +34,7 @@ from twisted.application import service
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory from twisted.web.server import Site, GzipEncoderFactory, Request
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
@ -61,11 +60,13 @@ import twisted.manhole.telnet
import synapse import synapse
import contextlib
import logging import logging
import os import os
import re import re
import resource import resource
import subprocess import subprocess
import time
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
@ -87,10 +88,10 @@ class SynapseHomeServer(HomeServer):
return MatrixFederationHttpClient(self) return MatrixFederationHttpClient(self)
def build_resource_for_client(self): def build_resource_for_client(self):
return gz_wrap(ClientV1RestResource(self)) return ClientV1RestResource(self)
def build_resource_for_client_v2_alpha(self): def build_resource_for_client_v2_alpha(self):
return gz_wrap(ClientV2AlphaRestResource(self)) return ClientV2AlphaRestResource(self)
def build_resource_for_federation(self): def build_resource_for_federation(self):
return JsonResource(self) return JsonResource(self)
@ -113,7 +114,7 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_content_repo(self): def build_resource_for_content_repo(self):
return ContentRepoResource( return ContentRepoResource(
self, self.upload_dir, self.auth, self.content_addr self, self.config.uploads_path, self.auth, self.content_addr
) )
def build_resource_for_media_repository(self): def build_resource_for_media_repository(self):
@ -139,152 +140,105 @@ class SynapseHomeServer(HomeServer):
**self.db_config.get("args", {}) **self.db_config.get("args", {})
) )
def create_resource_tree(self, redirect_root_to_web_client): def _listener_http(self, config, listener_config):
"""Create the resource tree for this Home Server. port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
tls = listener_config.get("tls", False)
site_tag = listener_config.get("tag", port)
This in unduly complicated because Twisted does not support putting if tls and config.no_tls:
child resources more than 1 level deep at a time. return
Args:
web_client (bool): True to enable the web client.
redirect_root_to_web_client (bool): True to redirect '/' to the
location of the web client. This does nothing if web_client is not
True.
"""
config = self.get_config()
web_client = config.web_client
# list containing (path_str, Resource) e.g:
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
desired_tree = [
(CLIENT_PREFIX, self.get_resource_for_client()),
(CLIENT_V2_ALPHA_PREFIX, self.get_resource_for_client_v2_alpha()),
(FEDERATION_PREFIX, self.get_resource_for_federation()),
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
(SERVER_KEY_V2_PREFIX, self.get_resource_for_server_key_v2()),
(MEDIA_PREFIX, self.get_resource_for_media_repository()),
(STATIC_PREFIX, self.get_resource_for_static_content()),
]
if web_client:
logger.info("Adding the web client.")
desired_tree.append((WEB_CLIENT_PREFIX,
self.get_resource_for_web_client()))
if web_client and redirect_root_to_web_client:
self.root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
self.root_resource = Resource()
metrics_resource = self.get_resource_for_metrics() metrics_resource = self.get_resource_for_metrics()
if config.metrics_port is None and metrics_resource is not None:
desired_tree.append((METRICS_PREFIX, metrics_resource))
# ideally we'd just use getChild and putChild but getChild doesn't work resources = {}
# unless you give it a Request object IN ADDITION to the name :/ So for res in listener_config["resources"]:
# instead, we'll store a copy of this mapping so we can actually add for name in res["names"]:
# extra resources to existing nodes. See self._resource_id for the key. if name == "client":
resource_mappings = {} if res["compress"]:
for full_path, res in desired_tree: client_v1 = gz_wrap(self.get_resource_for_client())
logger.info("Attaching %s to path %s", res, full_path) client_v2 = gz_wrap(self.get_resource_for_client_v2_alpha())
last_resource = self.root_resource
for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = Resource()
last_resource.putChild(path_seg, child_resource)
res_id = self._resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
last_resource = child_resource
else: else:
# we have an existing Resource, use that instead. client_v1 = self.get_resource_for_client()
res_id = self._resource_id(last_resource, path_seg) client_v2 = self.get_resource_for_client_v2_alpha()
last_resource = resource_mappings[res_id]
# =========================== resources.update({
# now attach the actual desired resource CLIENT_PREFIX: client_v1,
last_path_seg = full_path.split('/')[-1] CLIENT_V2_ALPHA_PREFIX: client_v2,
})
# if there is already a resource here, thieve its children and if name == "federation":
# replace it resources.update({
res_id = self._resource_id(last_resource, last_path_seg) FEDERATION_PREFIX: self.get_resource_for_federation(),
if res_id in resource_mappings: })
# there is a dummy resource at this path already, which needs
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
child_res_id = self._resource_id(existing_dummy_resource,
child_name)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place if name in ["static", "client"]:
last_resource.putChild(last_path_seg, res) resources.update({
res_id = self._resource_id(last_resource, last_path_seg) STATIC_PREFIX: self.get_resource_for_static_content(),
resource_mappings[res_id] = res })
return self.root_resource if name in ["media", "federation", "client"]:
resources.update({
MEDIA_PREFIX: self.get_resource_for_media_repository(),
CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(),
})
def _resource_id(self, resource, path_seg): if name in ["keys", "federation"]:
"""Construct an arbitrary resource ID so you can retrieve the mapping resources.update({
later. SERVER_KEY_PREFIX: self.get_resource_for_server_key(),
SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(),
})
If you want to represent resource A putChild resource B with path C, if name == "webclient":
the mapping should looks like _resource_id(A,C) = B. resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client()
Args: if name == "metrics" and metrics_resource:
resource (Resource): The *parent* Resource resources[METRICS_PREFIX] = metrics_resource
path_seg (str): The name of the child Resource to be attached.
Returns: root_resource = create_resource_tree(resources)
str: A unique string which can be a key to the child Resource. if tls:
""" reactor.listenSSL(
return "%s-%s" % (resource, path_seg) port,
SynapseSite(
"synapse.access.https.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
self.tls_context_factory,
interface=bind_address
)
else:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse now listening on port %d", port)
def start_listening(self): def start_listening(self):
config = self.get_config() config = self.get_config()
if not config.no_tls and config.bind_port is not None: for listener in config.listeners:
reactor.listenSSL( if listener["type"] == "http":
config.bind_port, self._listener_http(config, listener)
SynapseSite( elif listener["type"] == "manhole":
"synapse.access.https", f = twisted.manhole.telnet.ShellFactory()
config, f.username = "matrix"
self.root_resource, f.password = "rabbithole"
), f.namespace['hs'] = self
self.tls_context_factory,
interface=config.bind_host
)
logger.info("Synapse now listening on port %d", config.bind_port)
if config.unsecure_port is not None:
reactor.listenTCP( reactor.listenTCP(
config.unsecure_port, listener["port"],
SynapseSite( f,
"synapse.access.http", interface=listener.get("bind_address", '127.0.0.1')
config,
self.root_resource,
),
interface=config.bind_host
)
logger.info("Synapse now listening on port %d", config.unsecure_port)
metrics_resource = self.get_resource_for_metrics()
if metrics_resource and config.metrics_port is not None:
reactor.listenTCP(
config.metrics_port,
SynapseSite(
"synapse.access.metrics",
config,
metrics_resource,
),
interface=config.metrics_bind_host,
)
logger.info(
"Metrics now running on %s port %d",
config.metrics_bind_host, config.metrics_port,
) )
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
def run_startup_checks(self, db_conn, database_engine): def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain( all_users_native = are_all_users_on_domain(
@ -419,11 +373,6 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts events.USE_FROZEN_DICTS = config.use_frozen_dicts
if re.search(":[0-9]+$", config.server_name):
domain_with_port = config.server_name
else:
domain_with_port = "%s:%s" % (config.server_name, config.bind_port)
tls_context_factory = context_factory.ServerContextFactory(config) tls_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config.database_config["name"]) database_engine = create_engine(config.database_config["name"])
@ -431,8 +380,6 @@ def setup(config_options):
hs = SynapseHomeServer( hs = SynapseHomeServer(
config.server_name, config.server_name,
domain_with_port=domain_with_port,
upload_dir=os.path.abspath("uploads"),
db_config=config.database_config, db_config=config.database_config,
tls_context_factory=tls_context_factory, tls_context_factory=tls_context_factory,
config=config, config=config,
@ -441,10 +388,6 @@ def setup(config_options):
database_engine=database_engine, database_engine=database_engine,
) )
hs.create_resource_tree(
redirect_root_to_web_client=True,
)
logger.info("Preparing database: %r...", config.database_config) logger.info("Preparing database: %r...", config.database_config)
try: try:
@ -469,13 +412,6 @@ def setup(config_options):
logger.info("Database prepared in %r.", config.database_config) logger.info("Database prepared in %r.", config.database_config)
if config.manhole:
f = twisted.manhole.telnet.ShellFactory()
f.username = "matrix"
f.password = "rabbithole"
f.namespace['hs'] = hs
reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
hs.start_listening() hs.start_listening()
hs.get_pusherpool().start() hs.get_pusherpool().start()
@ -501,22 +437,194 @@ class SynapseService(service.Service):
return self._port.stopListening() return self._port.stopListening()
class SynapseRequest(Request):
def __init__(self, site, *args, **kw):
Request.__init__(self, *args, **kw)
self.site = site
self.authenticated_entity = None
self.start_time = 0
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
self.__class__.__name__,
id(self),
self.method,
self.get_redacted_uri(),
self.clientproto,
self.site.site_tag,
)
def get_redacted_uri(self):
return re.sub(
r'(\?.*access_token=)[^&]*(.*)$',
r'\1<redacted>\2',
self.uri
)
def get_user_agent(self):
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
def started_processing(self):
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.method,
self.get_redacted_uri()
)
self.start_time = int(time.time() * 1000)
def finished_processing(self):
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %dms %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
self.authenticated_entity,
int(time.time() * 1000) - self.start_time,
self.sentLength,
self.code,
self.method,
self.get_redacted_uri(),
self.clientproto,
self.get_user_agent(),
)
@contextlib.contextmanager
def processing(self):
self.started_processing()
yield
self.finished_processing()
class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw):
SynapseRequest.__init__(self, *args, **kw)
"""
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
return self.requestHeaders.getRawHeaders(
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
class SynapseRequestFactory(object):
def __init__(self, site, x_forwarded_for):
self.site = site
self.x_forwarded_for = x_forwarded_for
def __call__(self, *args, **kwargs):
if self.x_forwarded_for:
return XForwardedForRequest(self.site, *args, **kwargs)
else:
return SynapseRequest(self.site, *args, **kwargs)
class SynapseSite(Site): class SynapseSite(Site):
""" """
Subclass of a twisted http Site that does access logging with python's Subclass of a twisted http Site that does access logging with python's
standard logging standard logging
""" """
def __init__(self, logger_name, config, resource, *args, **kwargs): def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs) Site.__init__(self, resource, *args, **kwargs)
if config.captcha_ip_origin_is_x_forwarded:
self._log_formatter = proxiedLogFormatter self.site_tag = site_tag
else:
self._log_formatter = combinedLogFormatter proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name) self.access_logger = logging.getLogger(logger_name)
def log(self, request): def log(self, request):
line = self._log_formatter(self._logDateTime, request) pass
self.access_logger.info(line)
def create_resource_tree(desired_tree, redirect_root_to_web_client=True):
"""Create the resource tree for this Home Server.
This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
Args:
web_client (bool): True to enable the web client.
redirect_root_to_web_client (bool): True to redirect '/' to the
location of the web client. This does nothing if web_client is not
True.
"""
if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
root_resource = Resource()
# ideally we'd just use getChild and putChild but getChild doesn't work
# unless you give it a Request object IN ADDITION to the name :/ So
# instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree.items():
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = Resource()
last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
last_resource = child_resource
else:
# we have an existing Resource, use that instead.
res_id = _resource_id(last_resource, path_seg)
last_resource = resource_mappings[res_id]
# ===========================
# now attach the actual desired resource
last_path_seg = full_path.split('/')[-1]
# if there is already a resource here, thieve its children and
# replace it
res_id = _resource_id(last_resource, last_path_seg)
if res_id in resource_mappings:
# there is a dummy resource at this path already, which needs
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
child_res_id = _resource_id(
existing_dummy_resource, child_name
)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place
last_resource.putChild(last_path_seg, res)
res_id = _resource_id(last_resource, last_path_seg)
resource_mappings[res_id] = res
return root_resource
def _resource_id(resource, path_seg):
"""Construct an arbitrary resource ID so you can retrieve the mapping
later.
If you want to represent resource A putChild resource B with path C,
the mapping should looks like _resource_id(A,C) = B.
Args:
resource (Resource): The *parent* Resource
path_seg (str): The name of the child Resource to be attached.
Returns:
str: A unique string which can be a key to the child Resource.
"""
return "%s-%s" % (resource, path_seg)
def run(hs): def run(hs):
@ -549,6 +657,7 @@ def run(hs):
if hs.config.daemonize: if hs.config.daemonize:
if hs.config.print_pidfile:
print hs.config.pid_file print hs.config.pid_file
daemon = Daemonize( daemon = Daemonize(

View File

@ -138,47 +138,37 @@ class Config(object):
action="store_true", action="store_true",
help="Generate a config file for the server name" help="Generate a config file for the server name"
) )
config_parser.add_argument(
"--generate-keys",
action="store_true",
help="Generate any missing key files then exit"
)
config_parser.add_argument( config_parser.add_argument(
"-H", "--server-name", "-H", "--server-name",
help="The server name to generate a config file for" help="The server name to generate a config file for"
) )
config_args, remaining_args = config_parser.parse_known_args(argv) config_args, remaining_args = config_parser.parse_known_args(argv)
generate_keys = config_args.generate_keys
if config_args.generate_config: if config_args.generate_config:
if not config_args.config_path: if not config_args.config_path:
config_parser.error( config_parser.error(
"Must supply a config file.\nA config file can be automatically" "Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -h SERVER_NAME" " generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\"" " -c CONFIG-FILE\""
) )
(config_path,) = config_args.config_path
config_dir_path = os.path.dirname(config_args.config_path[0]) if not os.path.exists(config_path):
config_dir_path = os.path.dirname(config_path)
config_dir_path = os.path.abspath(config_dir_path) config_dir_path = os.path.abspath(config_dir_path)
server_name = config_args.server_name server_name = config_args.server_name
if not server_name: if not server_name:
print "Must specify a server_name to a generate config for." print "Must specify a server_name to a generate config for."
sys.exit(1) sys.exit(1)
(config_path,) = config_args.config_path
if not os.path.exists(config_dir_path): if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path) os.makedirs(config_dir_path)
if os.path.exists(config_path):
print "Config file %r already exists" % (config_path,)
yaml_config = cls.read_config_file(config_path)
yaml_name = yaml_config["server_name"]
if server_name != yaml_name:
print (
"Config file %r has a different server_name: "
" %r != %r" % (config_path, server_name, yaml_name)
)
sys.exit(1)
config_bytes, config = obj.generate_config(
config_dir_path, server_name
)
config.update(yaml_config)
print "Generating any missing keys for %r" % (server_name,)
obj.invoke_all("generate_files", config)
sys.exit(0)
with open(config_path, "wb") as config_file: with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config( config_bytes, config = obj.generate_config(
config_dir_path, server_name config_dir_path, server_name
@ -186,16 +176,22 @@ class Config(object):
obj.invoke_all("generate_files", config) obj.invoke_all("generate_files", config)
config_file.write(config_bytes) config_file.write(config_bytes)
print ( print (
"A config file has been generated in %s for server name" "A config file has been generated in %r for server name"
" '%s' with corresponding SSL keys and self-signed" " %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it to" " certificates. Please review this file and customise it"
" your needs." " to your needs."
) % (config_path, server_name) ) % (config_path, server_name)
print ( print (
"If this server name is incorrect, you will need to regenerate" "If this server name is incorrect, you will need to"
" the SSL certificates" " regenerate the SSL certificates"
) )
sys.exit(0) sys.exit(0)
else:
print (
"Config file %r already exists. Generating any missing key"
" files."
) % (config_path,)
generate_keys = True
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
parents=[config_parser], parents=[config_parser],
@ -209,11 +205,11 @@ class Config(object):
if not config_args.config_path: if not config_args.config_path:
config_parser.error( config_parser.error(
"Must supply a config file.\nA config file can be automatically" "Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -h SERVER_NAME" " generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\"" " -c CONFIG-FILE\""
) )
config_dir_path = os.path.dirname(config_args.config_path[0]) config_dir_path = os.path.dirname(config_args.config_path[-1])
config_dir_path = os.path.abspath(config_dir_path) config_dir_path = os.path.abspath(config_dir_path)
specified_config = {} specified_config = {}
@ -226,6 +222,10 @@ class Config(object):
config.pop("log_config") config.pop("log_config")
config.update(specified_config) config.update(specified_config)
if generate_keys:
obj.invoke_all("generate_files", config)
sys.exit(0)
obj.invoke_all("read_config", config) obj.invoke_all("read_config", config)
obj.invoke_all("read_arguments", args) obj.invoke_all("read_arguments", args)

View File

@ -21,10 +21,6 @@ class CaptchaConfig(Config):
self.recaptcha_private_key = config["recaptcha_private_key"] self.recaptcha_private_key = config["recaptcha_private_key"]
self.recaptcha_public_key = config["recaptcha_public_key"] self.recaptcha_public_key = config["recaptcha_public_key"]
self.enable_registration_captcha = config["enable_registration_captcha"] self.enable_registration_captcha = config["enable_registration_captcha"]
# XXX: This is used for more than just captcha
self.captcha_ip_origin_is_x_forwarded = (
config["captcha_ip_origin_is_x_forwarded"]
)
self.captcha_bypass_secret = config.get("captcha_bypass_secret") self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"] self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
@ -33,20 +29,16 @@ class CaptchaConfig(Config):
## Captcha ## ## Captcha ##
# This Home Server's ReCAPTCHA public key. # This Home Server's ReCAPTCHA public key.
recaptcha_private_key: "YOUR_PUBLIC_KEY" recaptcha_private_key: "YOUR_PRIVATE_KEY"
# This Home Server's ReCAPTCHA private key. # This Home Server's ReCAPTCHA private key.
recaptcha_public_key: "YOUR_PRIVATE_KEY" recaptcha_public_key: "YOUR_PUBLIC_KEY"
# Enables ReCaptcha checks when registering, preventing signup # Enables ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha # unless a captcha is answered. Requires a valid ReCaptcha
# public/private key. # public/private key.
enable_registration_captcha: False enable_registration_captcha: False
# When checking captchas, use the X-Forwarded-For (XFF) header
# as the client IP and not the actual client IP.
captcha_ip_origin_is_x_forwarded: False
# A secret key used to bypass the captcha test entirely. # A secret key used to bypass the captcha test entirely.
#captcha_bypass_secret: "YOUR_SECRET_HERE" #captcha_bypass_secret: "YOUR_SECRET_HERE"

View File

@ -25,12 +25,13 @@ from .registration import RegistrationConfig
from .metrics import MetricsConfig from .metrics import MetricsConfig
from .appservice import AppServiceConfig from .appservice import AppServiceConfig
from .key import KeyConfig from .key import KeyConfig
from .saml2 import SAML2Config
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, VoipConfig, RegistrationConfig, MetricsConfig,
MetricsConfig, AppServiceConfig, KeyConfig,): AppServiceConfig, KeyConfig, SAML2Config, ):
pass pass

View File

@ -28,10 +28,4 @@ class MetricsConfig(Config):
# Enable collection and rendering of performance metrics # Enable collection and rendering of performance metrics
enable_metrics: False enable_metrics: False
# Separate port to accept metrics requests on
# metrics_port: 8081
# Which host to bind the metric listener to
# metrics_bind_host: 127.0.0.1
""" """

View File

@ -21,13 +21,18 @@ class ContentRepositoryConfig(Config):
self.max_upload_size = self.parse_size(config["max_upload_size"]) self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_image_pixels = self.parse_size(config["max_image_pixels"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.media_store_path = self.ensure_directory(config["media_store_path"]) self.media_store_path = self.ensure_directory(config["media_store_path"])
self.uploads_path = self.ensure_directory(config["uploads_path"])
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name):
media_store = self.default_path("media_store") media_store = self.default_path("media_store")
uploads_path = self.default_path("uploads")
return """ return """
# Directory where uploaded images and attachments are stored. # Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s" media_store_path: "%(media_store)s"
# Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s"
# The largest allowed upload size in bytes # The largest allowed upload size in bytes
max_upload_size: "10M" max_upload_size: "10M"

54
synapse/config/saml2.py Normal file
View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Ericsson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class SAML2Config(Config):
"""SAML2 Configuration
Synapse uses pysaml2 libraries for providing SAML2 support
config_path: Path to the sp_conf.py configuration file
idp_redirect_url: Identity provider URL which will redirect
the user back to /login/saml2 with proper info.
sp_conf.py file is something like:
https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example
More information: https://pythonhosted.org/pysaml2/howto/config.html
"""
def read_config(self, config):
saml2_config = config.get("saml2_config", None)
if saml2_config:
self.saml2_enabled = True
self.saml2_config_path = saml2_config["config_path"]
self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
else:
self.saml2_enabled = False
self.saml2_config_path = None
self.saml2_idp_redirect_url = None
def default_config(self, config_dir_path, server_name):
return """
# Enable SAML2 for registration and login. Uses pysaml2
# config_path: Path to the sp_conf.py configuration file
# idp_redirect_url: Identity provider URL which will redirect
# the user back to /login/saml2 with proper info.
# See pysaml2 docs for format of config.
#saml2_config:
# config_path: "%s/sp_conf.py"
# idp_redirect_url: "http://%s/idp"
""" % (config_dir_path, server_name)

View File

@ -20,25 +20,98 @@ class ServerConfig(Config):
def read_config(self, config): def read_config(self, config):
self.server_name = config["server_name"] self.server_name = config["server_name"]
self.bind_port = config["bind_port"]
self.bind_host = config["bind_host"]
self.unsecure_port = config["unsecure_port"]
self.manhole = config.get("manhole")
self.pid_file = self.abspath(config.get("pid_file")) self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"] self.web_client = config["web_client"]
self.soft_file_limit = config["soft_file_limit"] self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize") self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile")
self.use_frozen_dicts = config.get("use_frozen_dicts", True) self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.listeners = config.get("listeners", [])
bind_port = config.get("bind_port")
if bind_port:
self.listeners = []
bind_host = config.get("bind_host", "")
gzip_responses = config.get("gzip_responses", True)
names = ["client", "webclient"] if self.web_client else ["client"]
self.listeners.append({
"port": bind_port,
"bind_address": bind_host,
"tls": True,
"type": "http",
"resources": [
{
"names": names,
"compress": gzip_responses,
},
{
"names": ["federation"],
"compress": False,
}
]
})
unsecure_port = config.get("unsecure_port", bind_port - 400)
if unsecure_port:
self.listeners.append({
"port": unsecure_port,
"bind_address": bind_host,
"tls": False,
"type": "http",
"resources": [
{
"names": names,
"compress": gzip_responses,
},
{
"names": ["federation"],
"compress": False,
}
]
})
manhole = config.get("manhole")
if manhole:
self.listeners.append({
"port": manhole,
"bind_address": "127.0.0.1",
"type": "manhole",
})
metrics_port = config.get("metrics_port")
if metrics_port:
self.listeners.append({
"port": metrics_port,
"bind_address": config.get("metrics_bind_host", "127.0.0.1"),
"tls": False,
"type": "http",
"resources": [
{
"names": ["metrics"],
"compress": False,
},
]
})
# Attempt to guess the content_addr for the v0 content repostitory # Attempt to guess the content_addr for the v0 content repostitory
content_addr = config.get("content_addr") content_addr = config.get("content_addr")
if not content_addr: if not content_addr:
for listener in self.listeners:
if listener["type"] == "http" and not listener.get("tls", False):
unsecure_port = listener["port"]
break
else:
raise RuntimeError("Could not determine 'content_addr'")
host = self.server_name host = self.server_name
if ':' not in host: if ':' not in host:
host = "%s:%d" % (host, self.unsecure_port) host = "%s:%d" % (host, unsecure_port)
else: else:
host = host.split(':')[0] host = host.split(':')[0]
host = "%s:%d" % (host, self.unsecure_port) host = "%s:%d" % (host, unsecure_port)
content_addr = "http://%s" % (host,) content_addr = "http://%s" % (host,)
self.content_addr = content_addr self.content_addr = content_addr
@ -60,18 +133,6 @@ class ServerConfig(Config):
# e.g. matrix.org, localhost:8080, etc. # e.g. matrix.org, localhost:8080, etc.
server_name: "%(server_name)s" server_name: "%(server_name)s"
# The port to listen for HTTPS requests on.
# For when matrix traffic is sent directly to synapse.
bind_port: %(bind_port)s
# The port to listen for HTTP requests on.
# For when matrix traffic passes through loadbalancer that unwraps TLS.
unsecure_port: %(unsecure_port)s
# Local interface to listen on.
# The empty string will cause synapse to listen on all interfaces.
bind_host: ""
# When running as a daemon, the file to store the pid in # When running as a daemon, the file to store the pid in
pid_file: %(pid_file)s pid_file: %(pid_file)s
@ -83,9 +144,64 @@ class ServerConfig(Config):
# hard limit. # hard limit.
soft_file_limit: 0 soft_file_limit: 0
# List of ports that Synapse should listen on, their purpose and their
# configuration.
listeners:
# Main HTTPS listener
# For when matrix traffic is sent directly to synapse.
-
# The port to listen for HTTPS requests on.
port: %(bind_port)s
# Local interface to listen on.
# The empty string will cause synapse to listen on all interfaces.
bind_address: ''
# This is a 'http' listener, allows us to specify 'resources'.
type: http
tls: true
# Use the X-Forwarded-For (XFF) header as the client IP and not the
# actual client IP.
x_forwarded: false
# List of HTTP resources to serve on this listener.
resources:
-
# List of resources to host on this listener.
names:
- client # The client-server APIs, both v1 and v2
- webclient # The bundled webclient.
# Should synapse compress HTTP responses to clients that support it?
# This should be disabled if running synapse behind a load balancer
# that can do automatic compression.
compress: true
- names: [federation] # Federation APIs
compress: false
# Unsecure HTTP listener,
# For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: %(unsecure_port)s
tls: false
bind_address: ''
type: http
x_forwarded: false
resources:
- names: [client, webclient]
compress: true
- names: [federation]
compress: false
# Turn on the twisted telnet manhole service on localhost on the given # Turn on the twisted telnet manhole service on localhost on the given
# port. # port.
#manhole: 9000 # - port: 9000
# bind_address: 127.0.0.1
# type: manhole
""" % locals() """ % locals()
def read_arguments(self, args): def read_arguments(self, args):
@ -93,12 +209,18 @@ class ServerConfig(Config):
self.manhole = args.manhole self.manhole = args.manhole
if args.daemonize is not None: if args.daemonize is not None:
self.daemonize = args.daemonize self.daemonize = args.daemonize
if args.print_pidfile is not None:
self.print_pidfile = args.print_pidfile
def add_arguments(self, parser): def add_arguments(self, parser):
server_group = parser.add_argument_group("server") server_group = parser.add_argument_group("server")
server_group.add_argument("-D", "--daemonize", action='store_true', server_group.add_argument("-D", "--daemonize", action='store_true',
default=None, default=None,
help="Daemonize the home server") help="Daemonize the home server")
server_group.add_argument("--print-pidfile", action='store_true',
default=None,
help="Print the path to the pidfile just"
" before daemonizing")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole", server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int, type=int,
help="Turn on the twisted telnet manhole" help="Turn on the twisted telnet manhole"

View File

@ -27,6 +27,7 @@ class TlsConfig(Config):
self.tls_certificate = self.read_tls_certificate( self.tls_certificate = self.read_tls_certificate(
config.get("tls_certificate_path") config.get("tls_certificate_path")
) )
self.tls_certificate_file = config.get("tls_certificate_path")
self.no_tls = config.get("no_tls", False) self.no_tls = config.get("no_tls", False)
@ -49,7 +50,11 @@ class TlsConfig(Config):
tls_dh_params_path = base_key_name + ".tls.dh" tls_dh_params_path = base_key_name + ".tls.dh"
return """\ return """\
# PEM encoded X509 certificate for TLS # PEM encoded X509 certificate for TLS.
# You can replace the self-signed certificate that synapse
# autogenerates on launch with your own SSL certificate + key pair
# if you like. Any required intermediary certificates can be
# appended after the primary certificate in hierarchical order.
tls_certificate_path: "%(tls_certificate_path)s" tls_certificate_path: "%(tls_certificate_path)s"
# PEM encoded private key for TLS # PEM encoded private key for TLS
@ -91,7 +96,7 @@ class TlsConfig(Config):
) )
if not os.path.exists(tls_certificate_path): if not os.path.exists(tls_certificate_path):
with open(tls_certificate_path, "w") as certifcate_file: with open(tls_certificate_path, "w") as certificate_file:
cert = crypto.X509() cert = crypto.X509()
subject = cert.get_subject() subject = cert.get_subject()
subject.CN = config["server_name"] subject.CN = config["server_name"]
@ -106,7 +111,7 @@ class TlsConfig(Config):
cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
certifcate_file.write(cert_pem) certificate_file.write(cert_pem)
if not os.path.exists(tls_dh_params_path): if not os.path.exists(tls_dh_params_path):
if GENERATE_DH_PARAMS: if GENERATE_DH_PARAMS:

View File

@ -35,9 +35,9 @@ class ServerContextFactory(ssl.ContextFactory):
_ecCurve = _OpenSSLECCurve(_defaultCurveName) _ecCurve = _OpenSSLECCurve(_defaultCurveName)
_ecCurve.addECKeyToContext(context) _ecCurve.addECKeyToContext(context)
except: except:
logger.exception("Failed to enable eliptic curve for TLS") logger.exception("Failed to enable elliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
context.use_certificate(config.tls_certificate) context.use_certificate_chain_file(config.tls_certificate_file)
if not config.no_tls: if not config.no_tls:
context.use_privatekey(config.tls_private_key) context.use_privatekey(config.tls_private_key)

View File

@ -25,11 +25,13 @@ from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from OpenSSL import crypto from OpenSSL import crypto
from collections import namedtuple
import urllib import urllib
import hashlib import hashlib
import logging import logging
@ -38,6 +40,9 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
class Keyring(object): class Keyring(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -49,18 +54,55 @@ class Keyring(object):
self.key_downloads = {} self.key_downloads = {}
@defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object):
return self.verify_json_objects_for_server(
[(server_name, json_object)]
)[0]
def verify_json_objects_for_server(self, server_and_json):
"""Bulk verfies signatures of json objects, bulk fetching keys as
necessary.
Args:
server_and_json (list): List of pairs of (server_name, json_object)
Returns:
list of deferreds indicating success or failure to verify each
json object's signature for the given server_name.
"""
group_id_to_json = {}
group_id_to_group = {}
group_ids = []
next_group_id = 0
deferreds = {}
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name) logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name) key_ids = signature_ids(json_object, server_name)
if not key_ids: if not key_ids:
raise SynapseError( deferreds[group_id] = defer.fail(SynapseError(
400, 400,
"Not signed with a supported algorithm", "Not signed with a supported algorithm",
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) ))
else:
deferreds[group_id] = defer.Deferred()
group = KeyGroup(server_name, group_id, key_ids)
group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
@defer.inlineCallbacks
def handle_key_deferred(group, deferred):
server_name = group.server_name
try: try:
verify_key = yield self.get_server_verify_key(server_name, key_ids) _, _, key_id, verify_key = yield deferred
except IOError as e: except IOError as e:
logger.warn( logger.warn(
"Got IOError when downloading keys for %s: %s %s", "Got IOError when downloading keys for %s: %s %s",
@ -72,7 +114,7 @@ class Keyring(object):
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
except Exception as e: except Exception as e:
logger.warn( logger.exception(
"Got Exception when downloading keys for %s: %s %s", "Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message), server_name, type(e).__name__, str(e.message),
) )
@ -82,6 +124,8 @@ class Keyring(object):
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
json_object = group_id_to_json[group.group_id]
try: try:
verify_signed_json(json_object, server_name, verify_key) verify_signed_json(json_object, server_name, verify_key)
except: except:
@ -93,79 +137,208 @@ class Keyring(object):
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
@defer.inlineCallbacks server_to_deferred = {
def get_server_verify_key(self, server_name, key_ids): server_name: defer.Deferred()
"""Finds a verification key for the server with one of the key ids. for server_name, _ in server_and_json
Trys to fetch the key from a trusted perspective server first. }
Args:
server_name(str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
if cached: # We want to wait for any previous lookups to complete before
defer.returnValue(cached[0]) # proceeding.
return wait_on_deferred = self.wait_for_previous_lookups(
[server_name for server_name, _ in server_and_json],
download = self.key_downloads.get(server_name) server_to_deferred,
if download is None:
download = self._get_server_verify_key_impl(server_name, key_ids)
download = ObservableDeferred(
download,
consumeErrors=True
) )
self.key_downloads[server_name] = download
@download.addBoth # Actually start fetching keys.
def callback(ret): wait_on_deferred.addBoth(
del self.key_downloads[server_name] lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
return ret )
r = yield download.observe() # When we've finished fetching all the keys for a given server_name,
defer.returnValue(r) # resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
server_to_gids = {}
def remove_deferreds(res, server_name, group_id):
server_to_gids[server_name].discard(group_id)
if not server_to_gids[server_name]:
server_to_deferred.pop(server_name).callback(None)
return res
for g_id, deferred in deferreds.items():
server_name = group_id_to_group[g_id].server_name
server_to_gids.setdefault(server_name, set()).add(g_id)
deferred.addBoth(remove_deferreds, server_name, g_id)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
handle_key_deferred(
group_id_to_group[g_id],
deferreds[g_id],
)
for g_id in group_ids
]
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_server_verify_key_impl(self, server_name, key_ids): def wait_for_previous_lookups(self, server_names, server_to_deferred):
keys = None """Waits for any previous key lookups for the given servers to finish.
Args:
server_names (list): list of server_names we want to lookup
server_to_deferred (dict): server_name to deferred which gets
resolved once we've finished looking up keys for that server
"""
while True:
wait_on = [
self.key_downloads[server_name]
for server_name in server_names
if server_name in self.key_downloads
]
if wait_on:
yield defer.DeferredList(wait_on)
else:
break
for server_name, deferred in server_to_deferred:
self.key_downloads[server_name] = ObservableDeferred(deferred)
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
# These are functions that produce keys given a list of key ids
key_fetch_fns = (
self.get_keys_from_store, # First try the local store
self.get_keys_from_perspectives, # Then try via perspectives
self.get_keys_from_server, # Then try directly
)
@defer.inlineCallbacks
def do_iterations():
merged_results = {}
missing_keys = {
group.server_name: key_id
for group in group_id_to_group.values()
for key_id in group.key_ids
}
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which groups we have keys for
# and which we don't
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
key_id,
merged_results[group.server_name][key_id],
))
break
else:
missing_groups.setdefault(
group.server_name, []
).append(group)
if not missing_groups:
break
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
)
for server_name, groups in missing_groups.items()
}
for group in missing_groups.values():
group_id_to_deferred[group.group_id].errback(SynapseError(
401,
"No key for %s with id %s" % (
group.server_name, group.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
for deferred in group_id_to_deferred.values():
if not deferred.called:
deferred.errback(err)
do_iterations().addErrback(on_err)
return group_id_to_deferred
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
[
self.store.get_server_verify_keys(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(dict(zip(
[server_name for server_name, _ in server_name_and_key_ids],
res
)))
@defer.inlineCallbacks
def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_key(perspective_name, perspective_keys): def get_key(perspective_name, perspective_keys):
try: try:
result = yield self.get_server_verify_key_v2_indirect( result = yield self.get_server_verify_key_v2_indirect(
server_name, key_ids, perspective_name, perspective_keys server_name_and_key_ids, perspective_name, perspective_keys
) )
defer.returnValue(result) defer.returnValue(result)
except Exception as e: except Exception as e:
logging.info( logger.exception(
"Unable to getting key %r for %r from %r: %s %s", "Unable to get key from %r: %s %s",
key_ids, server_name, perspective_name, perspective_name,
type(e).__name__, str(e.message), type(e).__name__, str(e.message),
) )
defer.returnValue({})
perspective_results = yield defer.gatherResults([ results = yield defer.gatherResults(
[
get_key(p_name, p_keys) get_key(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items() for p_name, p_keys in self.perspective_servers.items()
]) ],
consumeErrors=True,
).addErrback(unwrapFirstError)
for results in perspective_results: union_of_keys = {}
if results is not None: for result in results:
keys = results for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
defer.returnValue(union_of_keys)
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
server_name, server_name,
self.clock, self.clock,
self.store, self.store,
) )
with limiter: with limiter:
if not keys: keys = None
try: try:
keys = yield self.get_server_verify_key_v2_direct( keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids server_name, key_ids
) )
except Exception as e: except Exception as e:
logging.info( logger.info(
"Unable to getting key %r for %r directly: %s %s", "Unable to getting key %r for %r directly: %s %s",
key_ids, server_name, key_ids, server_name,
type(e).__name__, str(e.message), type(e).__name__, str(e.message),
@ -176,14 +349,30 @@ class Keyring(object):
server_name, key_ids server_name, key_ids
) )
for key_id in key_ids: keys = {server_name: keys}
if key_id in keys:
defer.returnValue(keys[key_id]) defer.returnValue(keys)
return
raise ValueError("No verification key found for given key ids") results = yield defer.gatherResults(
[
get_key(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
merged = {}
for result in results:
merged.update(result)
defer.returnValue({
server_name: keys
for server_name, keys in merged.items()
if keys
})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, server_name, key_ids, def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name, perspective_name,
perspective_keys): perspective_keys):
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
@ -204,6 +393,7 @@ class Keyring(object):
u"minimum_valid_until_ts": 0 u"minimum_valid_until_ts": 0
} for key_id in key_ids } for key_id in key_ids
} }
for server_name, key_ids in server_names_and_key_ids
} }
}, },
) )
@ -243,23 +433,29 @@ class Keyring(object):
" server %r" % (perspective_name,) " server %r" % (perspective_name,)
) )
response_keys = yield self.process_v2_response( processed_response = yield self.process_v2_response(
server_name, perspective_name, response perspective_name, response
) )
keys.update(response_keys) for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
yield self.store_keys( yield defer.gatherResults(
[
self.store_keys(
server_name=server_name, server_name=server_name,
from_server=perspective_name, from_server=perspective_name,
verify_keys=keys, verify_keys=response_keys,
) )
for server_name, response_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
defer.returnValue(keys) defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids): def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} keys = {}
for requested_key_id in key_ids: for requested_key_id in key_ids:
@ -295,25 +491,30 @@ class Keyring(object):
raise ValueError("TLS certificate not allowed by fingerprints") raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
server_name=server_name,
from_server=server_name, from_server=server_name,
requested_id=requested_key_id, requested_ids=[requested_key_id],
response_json=response, response_json=response,
) )
keys.update(response_keys) keys.update(response_keys)
yield self.store_keys( yield defer.gatherResults(
server_name=server_name, [
self.store_keys(
server_name=key_server_name,
from_server=server_name, from_server=server_name,
verify_keys=keys, verify_keys=verify_keys,
) )
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
defer.returnValue(keys) defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response(self, server_name, from_server, response_json, def process_v2_response(self, from_server, response_json,
requested_id=None): requested_ids=[]):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
response_keys = {} response_keys = {}
verify_keys = {} verify_keys = {}
@ -335,6 +536,8 @@ class Keyring(object):
verify_key.time_added = time_now_ms verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key old_verify_keys[key_id] = verify_key
results = {}
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}): for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]: if key_id not in response_json["verify_keys"]:
raise ValueError( raise ValueError(
@ -357,17 +560,16 @@ class Keyring(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json) signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"] ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
updated_key_ids = set() updated_key_ids = set(requested_ids)
if requested_id is not None:
updated_key_ids.add(requested_id)
updated_key_ids.update(verify_keys) updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys) updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys) response_keys.update(verify_keys)
response_keys.update(old_verify_keys) response_keys.update(old_verify_keys)
for key_id in updated_key_ids: yield defer.gatherResults(
yield self.store.store_server_keys_json( [
self.store.store_server_keys_json(
server_name=server_name, server_name=server_name,
key_id=key_id, key_id=key_id,
from_server=server_name, from_server=server_name,
@ -375,10 +577,14 @@ class Keyring(object):
ts_expires_ms=ts_valid_until_ms, ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes, key_json_bytes=signed_key_json_bytes,
) )
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(response_keys) results[server_name] = response_keys
raise ValueError("No verification key found for given key ids") defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids): def get_server_verify_key_v1_direct(self, server_name, key_ids):
@ -462,8 +668,13 @@ class Keyring(object):
Returns: Returns:
A deferred that completes when the keys are stored. A deferred that completes when the keys are stored.
""" """
for key_id, key in verify_keys.items():
# TODO(markjh): Store whether the keys have expired. # TODO(markjh): Store whether the keys have expired.
yield self.store.store_server_verify_key( yield defer.gatherResults(
[
self.store.store_server_verify_key(
server_name, server_name, key.time_added, key server_name, server_name, key.time_added, key
) )
for key_id, key in verify_keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)

View File

@ -74,6 +74,8 @@ def prune_event(event):
) )
elif event_type == EventTypes.Aliases: elif event_type == EventTypes.Aliases:
add_fields("aliases") add_fields("aliases")
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")
allowed_fields = { allowed_fields = {
k: v k: v

View File

@ -18,8 +18,6 @@ from twisted.internet import defer
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
@ -34,7 +32,8 @@ logger = logging.getLogger(__name__)
class FederationBase(object): class FederationBase(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False): def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
include_none=False):
"""Takes a list of PDUs and checks the signatures and hashs of each """Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of the database and if not then request if from the originating server of
@ -52,85 +51,108 @@ class FederationBase(object):
Returns: Returns:
Deferred : A list of PDUs that have valid signatures and hashes. Deferred : A list of PDUs that have valid signatures and hashes.
""" """
deferreds = self._check_sigs_and_hashes(pdus)
signed_pdus = [] def callback(pdu):
return pdu
@defer.inlineCallbacks def errback(failure, pdu):
def do(pdu): failure.trap(SynapseError)
try: return None
new_pdu = yield self._check_sigs_and_hash(pdu)
signed_pdus.append(new_pdu)
except SynapseError:
# FIXME: We should handle signature failures more gracefully.
def try_local_db(res, pdu):
if not res:
# Check local db. # Check local db.
new_pdu = yield self.store.get_event( return self.store.get_event(
pdu.event_id, pdu.event_id,
allow_rejected=True, allow_rejected=True,
allow_none=True, allow_none=True,
) )
if new_pdu: return res
signed_pdus.append(new_pdu)
return
# Check pdu.origin def try_remote(res, pdu):
if pdu.origin != origin: if not res and pdu.origin != origin:
try: return self.get_pdu(
new_pdu = yield self.get_pdu(
destinations=[pdu.origin], destinations=[pdu.origin],
event_id=pdu.event_id, event_id=pdu.event_id,
outlier=outlier, outlier=outlier,
timeout=10000, timeout=10000,
) ).addErrback(lambda e: None)
return res
if new_pdu:
signed_pdus.append(new_pdu)
return
except:
pass
def warn(res, pdu):
if not res:
logger.warn( logger.warn(
"Failed to find copy of %s with valid signature", "Failed to find copy of %s with valid signature",
pdu.event_id, pdu.event_id,
) )
return res
yield defer.gatherResults( for pdu, deferred in zip(pdus, deferreds):
[do(pdu) for pdu in pdus], deferred.addCallbacks(
callback, errback, errbackArgs=[pdu]
).addCallback(
try_local_db, pdu
).addCallback(
try_remote, pdu
).addCallback(
warn, pdu
)
valid_pdus = yield defer.gatherResults(
deferreds,
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
defer.returnValue(signed_pdus) if include_none:
defer.returnValue(valid_pdus)
else:
defer.returnValue([p for p in valid_pdus if p])
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu): def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct return self._check_sigs_and_hashes([pdu])[0]
def _check_sigs_and_hashes(self, pdus):
"""Throws a SynapseError if a PDU does not have the correct
signatures. signatures.
Returns: Returns:
FrozenEvent: Either the given event or it redacted if it failed the FrozenEvent: Either the given event or it redacted if it failed the
content hash check. content hash check.
""" """
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try: redacted_pdus = [
yield self.keyring.verify_json_for_server( prune_event(pdu)
pdu.origin, redacted_pdu_json for pdu in pdus
) ]
except SynapseError:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise
deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
def callback(_, pdu, redacted):
if not check_event_content_hash(pdu): if not check_event_content_hash(pdu):
logger.warn( logger.warn(
"Event content has been tampered, redacting %s, %s", "Event content has been tampered, redacting %s: %s",
pdu.event_id, encode_canonical_json(pdu.get_dict()) pdu.event_id, pdu.get_pdu_json()
) )
defer.returnValue(redacted_event) return redacted
return pdu
defer.returnValue(pdu) def errback(failure, pdu):
failure.trap(SynapseError)
logger.warn(
"Signature check failed for %s",
pdu.event_id,
)
return failure
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
deferred.addCallbacks(
callback, errback,
callbackArgs=[pdu, redacted],
errbackArgs=[pdu],
)
return deferreds

View File

@ -30,6 +30,7 @@ import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
import copy
import itertools import itertools
import logging import logging
import random import random
@ -167,7 +168,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully. # FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults( pdus[:] = yield defer.gatherResults(
[self._check_sigs_and_hash(pdu) for pdu in pdus], self._check_sigs_and_hashes(pdus),
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -230,7 +231,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0] pdu = pdu_list[0]
# Check signatures are correct. # Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu) pdu = yield self._check_sigs_and_hashes([pdu])[0]
break break
@ -327,6 +328,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id): def make_join(self, destinations, room_id, user_id):
for destination in destinations: for destination in destinations:
if destination == self.server_name:
continue
try: try:
ret = yield self.transport_layer.make_join( ret = yield self.transport_layer.make_join(
destination, room_id, user_id destination, room_id, user_id
@ -353,6 +357,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_join(self, destinations, pdu): def send_join(self, destinations, pdu):
for destination in destinations: for destination in destinations:
if destination == self.server_name:
continue
try: try:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join( _, content = yield self.transport_layer.send_join(
@ -374,17 +381,39 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", []) for p in content.get("auth_chain", [])
] ]
signed_state, signed_auth = yield defer.gatherResults( pdus = {
[ p.event_id: p
self._check_sigs_and_hash_and_fetch( for p in itertools.chain(state, auth_chain)
destination, state, outlier=True }
),
self._check_sigs_and_hash_and_fetch( valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True destination, pdus.values(),
outlier=True,
) )
],
consumeErrors=True valid_pdus_map = {
).addErrback(unwrapFirstError) p.event_id: p
for p in valid_pdus
}
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
signed_state = [
copy.copy(valid_pdus_map[p.event_id])
for p in state
if p.event_id in valid_pdus_map
]
signed_auth = [
valid_pdus_map[p.event_id]
for p in auth_chain
if p.event_id in valid_pdus_map
]
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
@ -396,7 +425,7 @@ class FederationClient(FederationBase):
except CodeMessageException: except CodeMessageException:
raise raise
except Exception as e: except Exception as e:
logger.warn( logger.exception(
"Failed to send_join via %s: %s", "Failed to send_join via %s: %s",
destination, e.message destination, e.message
) )

View File

@ -93,6 +93,9 @@ class TransportLayerServer(object):
yield self.keyring.verify_json_for_server(origin, json_request) yield self.keyring.verify_json_for_server(origin, json_request)
logger.info("Request from %s", origin)
request.authenticated_entity = origin
defer.returnValue((origin, content)) defer.returnValue((origin, content))
@log_function @log_function

View File

@ -32,6 +32,7 @@ from .appservice import ApplicationServicesHandler
from .sync import SyncHandler from .sync import SyncHandler
from .auth import AuthHandler from .auth import AuthHandler
from .identity import IdentityHandler from .identity import IdentityHandler
from .receipts import ReceiptsHandler
class Handlers(object): class Handlers(object):
@ -57,6 +58,7 @@ class Handlers(object):
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs)
asapi = ApplicationServiceApi(hs) asapi = ApplicationServiceApi(hs)
self.appservice_handler = ApplicationServicesHandler( self.appservice_handler = ApplicationServicesHandler(
hs, asapi, AppServiceScheduler( hs, asapi, AppServiceScheduler(

View File

@ -78,7 +78,9 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder) context = yield state_handler.compute_event_context(builder)
if builder.is_state(): if builder.is_state():
builder.prev_state = context.prev_state_events builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
yield self.auth.add_auth_events(builder, context) yield self.auth.add_auth_events(builder, context)

View File

@ -177,7 +177,7 @@ class ApplicationServicesHandler(object):
return return
user_info = yield self.store.get_user_by_id(user_id) user_info = yield self.store.get_user_by_id(user_id)
if not user_info: if user_info:
defer.returnValue(False) defer.returnValue(False)
return return

View File

@ -85,8 +85,10 @@ class AuthHandler(BaseHandler):
# email auth link on there). It's probably too open to abuse # email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects # because it lets unauthenticated clients store arbitrary objects
# on a home server. # on a home server.
# sess['clientdict'] = clientdict # Revisit: Assumimg the REST APIs do sensible validation, the data
# self._save_session(sess) # isn't arbintrary.
sess['clientdict'] = clientdict
self._save_session(sess)
pass pass
elif 'clientdict' in sess: elif 'clientdict' in sess:
clientdict = sess['clientdict'] clientdict = sess['clientdict']

View File

@ -31,6 +31,8 @@ from synapse.crypto.event_signing import (
) )
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from twisted.internet import defer from twisted.internet import defer
@ -138,25 +140,28 @@ class FederationHandler(BaseHandler):
if state and auth_chain is not None: if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication # If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.) # layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state): for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids: if e.event_id in seen_ids:
continue continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try:
auth_ids = [e_id for e_id, _ in e.auth_events] auth_ids = [e_id for e_id, _ in e.auth_events]
auth = { auth = {
(e.type, e.state_key): e for e in auth_chain (e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids
} }
yield self._handle_new_event( event_infos.append({
origin, e, auth_events=auth "event": e,
) "auth_events": auth,
})
seen_ids.add(e.event_id) seen_ids.add(e.event_id)
except:
logger.exception( yield self._handle_new_events(
"Failed to handle state event %s", origin,
e.event_id, event_infos,
outliers=True
) )
try: try:
@ -222,6 +227,56 @@ class FederationHandler(BaseHandler):
"user_joined_room", user=user, room_id=event.room_id "user_joined_room", user=user, room_id=event.room_id
) )
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
states = yield self.store.get_state_for_events(
room_id, [e.event_id for e in events],
)
events_and_states = zip(events, states)
def redact_disallowed(event_and_state):
event, state = event_and_state
if not state:
return event
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
if visibility in ["invited", "joined"]:
# We now loop through all state events looking for
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
for ev in state.values():
if ev.type != EventTypes.Member:
continue
try:
domain = UserID.from_string(ev.state_key).domain
except:
continue
if domain != server_name:
continue
memtype = ev.membership
if memtype == Membership.JOIN:
return event
elif memtype == Membership.INVITE:
if visibility == "invited":
return event
else:
return prune_event(event)
return event
res = map(redact_disallowed, events_and_states)
logger.info("_filter_events_for_server %r", res)
defer.returnValue(res)
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities=[]): def backfill(self, dest, room_id, limit, extremities=[]):
@ -247,9 +302,15 @@ class FederationHandler(BaseHandler):
if set(e_id for e_id, _ in ev.prev_events) - event_ids if set(e_id for e_id, _ in ev.prev_events) - event_ids
] ]
logger.info(
"backfill: Got %d events with %d edges",
len(events), len(edges),
)
# For each edge get the current state. # For each edge get the current state.
auth_events = {} auth_events = {}
state_events = {}
events_to_state = {} events_to_state = {}
for e_id in edges: for e_id in edges:
state, auth = yield self.replication_layer.get_state_for_room( state, auth = yield self.replication_layer.get_state_for_room(
@ -258,27 +319,57 @@ class FederationHandler(BaseHandler):
event_id=e_id event_id=e_id
) )
auth_events.update({a.event_id: a for a in auth}) auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state events_to_state[e_id] = state
yield defer.gatherResults( seen_events = yield self.store.have_events(
[ set(auth_events.keys()) | set(state_events.keys())
self._handle_new_event(dest, a)
for a in auth_events.values()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
yield defer.gatherResults(
[
self._handle_new_event(
dest, event_map[e_id],
state=events_to_state[e_id],
backfilled=True,
) )
for e_id in events_to_state
all_events = events + state_events.values() + auth_events.values()
required_auth = set(
a_id for event in all_events for a_id, _ in event.auth_events
)
missing_auth = required_auth - set(auth_events)
results = yield defer.gatherResults(
[
self.replication_layer.get_pdu(
[dest],
event_id,
outlier=True,
timeout=10000,
)
for event_id in missing_auth
], ],
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results})
ev_infos = []
for a in auth_events.values():
if a.event_id in seen_events:
continue
ev_infos.append({
"event": a,
"auth_events": {
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in a.auth_events
}
})
for e_id in events_to_state:
ev_infos.append({
"event": event_map[e_id],
"state": events_to_state[e_id],
"auth_events": {
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events
}
})
events.sort(key=lambda e: e.depth) events.sort(key=lambda e: e.depth)
@ -286,8 +377,12 @@ class FederationHandler(BaseHandler):
if event in events_to_state: if event in events_to_state:
continue continue
yield self._handle_new_event( ev_infos.append({
dest, event, "event": event,
})
yield self._handle_new_events(
dest, ev_infos,
backfilled=True, backfilled=True,
) )
@ -555,32 +650,22 @@ class FederationHandler(BaseHandler):
# FIXME # FIXME
pass pass
yield self._handle_auth_events( ev_infos = []
origin, [e for e in auth_chain if e.event_id != event.event_id] for e in itertools.chain(state, auth_chain):
)
@defer.inlineCallbacks
def handle_state(e):
if e.event_id == event.event_id: if e.event_id == event.event_id:
return continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try:
auth_ids = [e_id for e_id, _ in e.auth_events] auth_ids = [e_id for e_id, _ in e.auth_events]
auth = { ev_infos.append({
"event": e,
"auth_events": {
(e.type, e.state_key): e for e in auth_chain (e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids
} }
yield self._handle_new_event( })
origin, e, auth_events=auth
)
except:
logger.exception(
"Failed to handle state event %s",
e.event_id,
)
yield defer.DeferredList([handle_state(e) for e in state]) yield self._handle_new_events(origin, ev_infos, outliers=True)
auth_ids = [e_id for e_id, _ in event.auth_events] auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = { auth_events = {
@ -837,6 +922,8 @@ class FederationHandler(BaseHandler):
limit limit
) )
events = yield self._filter_events_for_server(origin, room_id, events)
defer.returnValue(events) defer.returnValue(events)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -895,25 +982,63 @@ class FederationHandler(BaseHandler):
def _handle_new_event(self, origin, event, state=None, backfilled=False, def _handle_new_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None): current_state=None, auth_events=None):
logger.debug( outlier = event.internal_metadata.is_outlier()
"_handle_new_event: %s, sigs: %s",
event.event_id, event.signatures, context = yield self._prep_event(
origin, event,
state=state,
backfilled=backfilled,
current_state=current_state,
auth_events=auth_events,
) )
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=(not outlier and not backfilled),
current_state=current_state,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False,
outliers=False):
contexts = yield defer.gatherResults(
[
self._prep_event(
origin,
ev_info["event"],
state=ev_info.get("state"),
backfilled=backfilled,
auth_events=ev_info.get("auth_events"),
)
for ev_info in event_infos
]
)
yield self.store.persist_events(
[
(ev_info["event"], context)
for ev_info, context in itertools.izip(event_infos, contexts)
],
backfilled=backfilled,
is_new_state=(not outliers and not backfilled),
)
@defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context( context = yield self.state_handler.compute_event_context(
event, old_state=state event, old_state=state, outlier=outlier,
) )
if not auth_events: if not auth_events:
auth_events = context.current_state auth_events = context.current_state
logger.debug(
"_handle_new_event: %s, auth_events: %s",
event.event_id, auth_events,
)
is_new_state = not event.internal_metadata.is_outlier()
# This is a hack to fix some old rooms where the initial join event # This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events. # didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events: if event.type == EventTypes.Member and not event.auth_events:
@ -937,26 +1062,7 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
# FIXME: Don't store as rejected with AUTH_ERROR if we haven't defer.returnValue(context)
# seen all the auth events.
yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=False,
current_state=current_state,
)
raise
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=(is_new_state and not backfilled),
current_state=current_state,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
@ -1019,14 +1125,24 @@ class FederationHandler(BaseHandler):
@log_function @log_function
def do_auth(self, origin, event, context, auth_events): def do_auth(self, origin, event, context, auth_events):
# Check if we have all the auth events. # Check if we have all the auth events.
have_events = yield self.store.have_events( current_state = set(e.event_id for e in auth_events.values())
[e_id for e_id, _ in event.auth_events]
)
event_auth_events = set(e_id for e_id, _ in event.auth_events) event_auth_events = set(e_id for e_id, _ in event.auth_events)
if event_auth_events - current_state:
have_events = yield self.store.have_events(
event_auth_events - current_state
)
else:
have_events = {}
have_events.update({
e.event_id: ""
for e in auth_events.values()
})
seen_events = set(have_events.keys()) seen_events = set(have_events.keys())
missing_auth = event_auth_events - seen_events missing_auth = event_auth_events - seen_events - current_state
if missing_auth: if missing_auth:
logger.info("Missing auth: %s", missing_auth) logger.info("Missing auth: %s", missing_auth)

View File

@ -44,7 +44,7 @@ class IdentityHandler(BaseHandler):
http_client = SimpleHttpClient(self.hs) http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable! # XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090'] # trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org'] trustedIdServers = ['matrix.org', 'vector.im']
if 'id_server' in creds: if 'id_server' in creds:
id_server = creds['id_server'] id_server = creds['id_server']

View File

@ -113,11 +113,21 @@ class MessageHandler(BaseHandler):
"room_key", next_key "room_key", next_key
) )
if not events:
defer.returnValue({
"chunk": [],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
})
events = yield self._filter_events_for_client(user_id, room_id, events)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
chunk = { chunk = {
"chunk": [ "chunk": [
serialize_event(e, time_now, as_client_event) for e in events serialize_event(e, time_now, as_client_event)
for e in events
], ],
"start": pagin_config.from_token.to_string(), "start": pagin_config.from_token.to_string(),
"end": next_token.to_string(), "end": next_token.to_string(),
@ -125,6 +135,52 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events):
states = yield self.store.get_state_for_events(
room_id, [e.event_id for e in events],
)
events_and_states = zip(events, states)
def allowed(event_and_state):
event, state = event_and_state
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
events_and_states = filter(allowed, events_and_states)
defer.returnValue([
ev
for ev, _ in events_and_states
])
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True, def create_and_send_event(self, event_dict, ratelimit=True,
client=None, txn_id=None): client=None, txn_id=None):
@ -278,6 +334,11 @@ class MessageHandler(BaseHandler):
user, pagination_config.get_source_config("presence"), None user, pagination_config.get_source_config("presence"), None
) )
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
public_room_ids = yield self.store.get_public_room_ids() public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit limit = pagin_config.limit
@ -316,6 +377,10 @@ class MessageHandler(BaseHandler):
] ]
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client(
user_id, event.room_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1]) end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -344,7 +409,8 @@ class MessageHandler(BaseHandler):
ret = { ret = {
"rooms": rooms_ret, "rooms": rooms_ret,
"presence": presence, "presence": presence,
"end": now_token.to_string() "receipts": receipt,
"end": now_token.to_string(),
} }
defer.returnValue(ret) defer.returnValue(ret)
@ -380,15 +446,6 @@ class MessageHandler(BaseHandler):
if limit is None: if limit is None:
limit = 10 limit = 10
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=now_token.room_key,
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
room_members = [ room_members = [
m for m in current_state.values() m for m in current_state.values()
if m.type == EventTypes.Member if m.type == EventTypes.Member
@ -396,20 +453,46 @@ class MessageHandler(BaseHandler):
] ]
presence_handler = self.hs.get_handlers().presence_handler presence_handler = self.hs.get_handlers().presence_handler
presence = []
for m in room_members: @defer.inlineCallbacks
try: def get_presence():
member_presence = yield presence_handler.get_state( presence_defs = yield defer.DeferredList(
[
presence_handler.get_state(
target_user=UserID.from_string(m.user_id), target_user=UserID.from_string(m.user_id),
auth_user=auth_user, auth_user=auth_user,
as_event=True, as_event=True,
check_auth=False,
) )
presence.append(member_presence) for m in room_members
except SynapseError: ],
logger.exception( consumeErrors=True,
"Failed to get member presence of %r", m.user_id
) )
defer.returnValue([p for success, p in presence_defs if success])
receipts_handler = self.hs.get_handlers().receipts_handler
presence, receipts, (messages, token) = yield defer.gatherResults(
[
get_presence(),
receipts_handler.get_receipts_for_room(room_id, now_token.receipt_key),
self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=now_token.room_key,
)
],
consumeErrors=True,
).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client(
user_id, room_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
defer.returnValue({ defer.returnValue({
@ -421,5 +504,6 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(), "end": end_token.to_string(),
}, },
"state": state, "state": state,
"presence": presence "presence": presence,
"receipts": receipts,
}) })

View File

@ -191,8 +191,9 @@ class PresenceHandler(BaseHandler):
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state(self, target_user, auth_user, as_event=False): def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
if check_auth:
visible = yield self.is_presence_visible( visible = yield self.is_presence_visible(
observer_user=auth_user, observer_user=auth_user,
observed_user=target_user observed_user=target_user
@ -200,15 +201,14 @@ class PresenceHandler(BaseHandler):
if not visible: if not visible:
raise SynapseError(404, "Presence information not visible") raise SynapseError(404, "Presence information not visible")
if target_user in self._user_cachemap:
state = self._user_cachemap[target_user].get_state()
else:
state = yield self.store.get_presence_state(target_user.localpart) state = yield self.store.get_presence_state(target_user.localpart)
if "mtime" in state: if "mtime" in state:
del state["mtime"] del state["mtime"]
state["presence"] = state.pop("state") state["presence"] = state.pop("state")
if target_user in self._user_cachemap:
cached_state = self._user_cachemap[target_user].get_state()
if "last_active" in cached_state:
state["last_active"] = cached_state["last_active"]
else: else:
# TODO(paul): Have remote server send us permissions set # TODO(paul): Have remote server send us permissions set
state = self._get_or_offline_usercache(target_user).get_state() state = self._get_or_offline_usercache(target_user).get_state()
@ -992,7 +992,7 @@ class PresenceHandler(BaseHandler):
room_ids([str]): List of room_ids to notify. room_ids([str]): List of room_ids to notify.
""" """
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notifier.on_new_user_event( self.notifier.on_new_event(
"presence_key", "presence_key",
self._user_cachemap_latest_serial, self._user_cachemap_latest_serial,
users_to_push, users_to_push,

View File

@ -0,0 +1,212 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseHandler
from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
import logging
logger = logging.getLogger(__name__)
class ReceiptsHandler(BaseHandler):
def __init__(self, hs):
super(ReceiptsHandler, self).__init__(hs)
self.hs = hs
self.federation = hs.get_replication_layer()
self.federation.register_edu_handler(
"m.receipt", self._received_remote_receipt
)
self.clock = self.hs.get_clock()
self._receipt_cache = None
@defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id,
event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
receipt = {
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": [event_id],
"data": {
"ts": int(self.clock.time_msec()),
}
}
is_new = yield self._handle_new_receipts([receipt])
if is_new:
self._push_remotes([receipt])
@defer.inlineCallbacks
def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = [
{
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": user_values["event_ids"],
"data": user_values.get("data", {}),
}
for room_id, room_values in content.items()
for receipt_type, users in room_values.items()
for user_id, user_values in users.items()
]
yield self._handle_new_receipts(receipts)
@defer.inlineCallbacks
def _handle_new_receipts(self, receipts):
"""Takes a list of receipts, stores them and informs the notifier.
"""
for receipt in receipts:
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
res = yield self.store.insert_receipt(
room_id, receipt_type, user_id, event_ids, data
)
if not res:
# res will be None if this read receipt is 'old'
defer.returnValue(False)
stream_id, max_persisted_id = res
with PreserveLoggingContext():
self.notifier.on_new_event(
"receipt_key", max_persisted_id, rooms=[room_id]
)
defer.returnValue(True)
@defer.inlineCallbacks
def _push_remotes(self, receipts):
"""Given a list of receipts, works out which remote servers should be
poked and pokes them.
"""
# TODO: Some of this stuff should be coallesced.
for receipt in receipts:
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
remotedomains = set()
rm_handler = self.hs.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=None, remotedomains=remotedomains
)
logger.debug("Sending receipt to: %r", remotedomains)
for domain in remotedomains:
self.federation.send_edu(
destination=domain,
edu_type="m.receipt",
content={
room_id: {
receipt_type: {
user_id: {
"event_ids": event_ids,
"data": data,
}
}
},
},
)
@defer.inlineCallbacks
def get_receipts_for_room(self, room_id, to_key):
"""Gets all receipts for a room, upto the given key.
"""
result = yield self.store.get_linearized_receipts_for_room(
room_id,
to_key=to_key,
)
if not result:
defer.returnValue([])
event = {
"type": "m.receipt",
"room_id": room_id,
"content": result,
}
defer.returnValue([event])
class ReceiptEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit):
defer.returnValue(([], from_key))
from_key = int(from_key)
to_key = yield self.get_current_key()
if from_key == to_key:
defer.returnValue(([], to_key))
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms(
rooms,
from_key=from_key,
to_key=to_key,
)
defer.returnValue((events, to_key))
def get_current_key(self, direction='f'):
return self.store.get_max_receipt_stream_id()
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key)
defer.returnValue(([], to_key))
if config.to_key:
from_key = int(config.to_key)
else:
from_key = None
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms(
rooms,
from_key=from_key,
to_key=to_key,
)
defer.returnValue((events, to_key))

View File

@ -73,7 +73,8 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None, localpart : The local part of the user ID to register. If None,
one will be randomly generated. one will be randomly generated.
password (str) : The password to assign to this user so they can password (str) : The password to assign to this user so they can
login again. login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
Returns: Returns:
A tuple of (user_id, access_token). A tuple of (user_id, access_token).
Raises: Raises:
@ -192,6 +193,35 @@ class RegistrationHandler(BaseHandler):
else: else:
logger.info("Valid captcha entered from %s", ip) logger.info("Valid captcha entered from %s", ip)
@defer.inlineCallbacks
def register_saml2(self, localpart):
"""
Registers email_id as SAML2 Based Auth.
"""
if urllib.quote(localpart) != localpart:
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id)
try:
yield self.store.register(
user_id=user_id,
token=token,
password_hash=None
)
yield self.distributor.fire("registered_user", user)
except Exception, e:
yield self.store.add_access_token_to_user(user_id, token)
# Ignore Registration errors
logger.exception(e)
defer.returnValue((user_id, token))
@defer.inlineCallbacks @defer.inlineCallbacks
def register_email(self, threepidCreds): def register_email(self, threepidCreds):
""" """

View File

@ -19,12 +19,15 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset,
)
from synapse.api.errors import StoreError, SynapseError from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from collections import OrderedDict
import logging import logging
import string import string
@ -33,6 +36,19 @@ logger = logging.getLogger(__name__)
class RoomCreationHandler(BaseHandler): class RoomCreationHandler(BaseHandler):
PRESETS_DICT = {
RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": "invited",
"original_invitees_have_ops": False,
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
"history_visibility": "shared",
"original_invitees_have_ops": False,
},
}
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room(self, user_id, room_id, config): def create_room(self, user_id, room_id, config):
""" Creates a new room. """ Creates a new room.
@ -121,9 +137,25 @@ class RoomCreationHandler(BaseHandler):
servers=[self.hs.hostname], servers=[self.hs.hostname],
) )
preset_config = config.get(
"preset",
RoomCreationPreset.PUBLIC_CHAT
if is_public
else RoomCreationPreset.PRIVATE_CHAT
)
raw_initial_state = config.get("initial_state", [])
initial_state = OrderedDict()
for val in raw_initial_state:
initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
creation_events = self._create_events_for_new_room( creation_events = self._create_events_for_new_room(
user, room_id, is_public=is_public user, room_id,
preset_config=preset_config,
invite_list=invite_list,
initial_state=initial_state,
) )
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
@ -170,7 +202,10 @@ class RoomCreationHandler(BaseHandler):
defer.returnValue(result) defer.returnValue(result)
def _create_events_for_new_room(self, creator, room_id, is_public=False): def _create_events_for_new_room(self, creator, room_id, preset_config,
invite_list, initial_state):
config = RoomCreationHandler.PRESETS_DICT[preset_config]
creator_id = creator.to_string() creator_id = creator.to_string()
event_keys = { event_keys = {
@ -203,9 +238,10 @@ class RoomCreationHandler(BaseHandler):
}, },
) )
power_levels_event = create( returned_events = [creation_event, join_event]
etype=EventTypes.PowerLevels,
content={ if (EventTypes.PowerLevels, '') not in initial_state:
power_level_content = {
"users": { "users": {
creator.to_string(): 100, creator.to_string(): 100,
}, },
@ -213,6 +249,7 @@ class RoomCreationHandler(BaseHandler):
"events": { "events": {
EventTypes.Name: 100, EventTypes.Name: 100,
EventTypes.PowerLevels: 100, EventTypes.PowerLevels: 100,
EventTypes.RoomHistoryVisibility: 100,
}, },
"events_default": 0, "events_default": 0,
"state_default": 50, "state_default": 50,
@ -220,21 +257,43 @@ class RoomCreationHandler(BaseHandler):
"kick": 50, "kick": 50,
"redact": 50, "redact": 50,
"invite": 0, "invite": 0,
}, }
if config["original_invitees_have_ops"]:
for invitee in invite_list:
power_level_content["users"][invitee] = 100
power_levels_event = create(
etype=EventTypes.PowerLevels,
content=power_level_content,
) )
join_rule = JoinRules.PUBLIC if is_public else JoinRules.INVITE returned_events.append(power_levels_event)
if (EventTypes.JoinRules, '') not in initial_state:
join_rules_event = create( join_rules_event = create(
etype=EventTypes.JoinRules, etype=EventTypes.JoinRules,
content={"join_rule": join_rule}, content={"join_rule": config["join_rules"]},
) )
return [ returned_events.append(join_rules_event)
creation_event,
join_event, if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
power_levels_event, history_event = create(
join_rules_event, etype=EventTypes.RoomHistoryVisibility,
] content={"history_visibility": config["history_visibility"]}
)
returned_events.append(history_event)
for (etype, state_key), content in initial_state.items():
returned_events.append(create(
etype=etype,
state_key=state_key,
content=content,
))
return returned_events
class RoomMemberHandler(BaseHandler): class RoomMemberHandler(BaseHandler):

View File

@ -292,6 +292,51 @@ class SyncHandler(BaseHandler):
next_batch=now_token, next_batch=now_token,
)) ))
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events):
states = yield self.store.get_state_for_events(
room_id, [e.event_id for e in events],
)
events_and_states = zip(events, states)
def allowed(event_and_state):
event, state = event_and_state
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
events_and_states = filter(allowed, events_and_states)
defer.returnValue([
ev
for ev, _ in events_and_states
])
@defer.inlineCallbacks @defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token, def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None): since_token=None):
@ -313,6 +358,9 @@ class SyncHandler(BaseHandler):
(room_key, _) = keys (room_key, _) = keys
end_key = "s" + room_key.split('-')[-1] end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_events(events) loaded_recents = sync_config.filter.filter_room_events(events)
loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), room_id, loaded_recents,
)
loaded_recents.extend(recents) loaded_recents.extend(recents)
recents = loaded_recents recents = loaded_recents
if len(events) <= load_limit: if len(events) <= load_limit:

View File

@ -218,7 +218,7 @@ class TypingNotificationHandler(BaseHandler):
self._room_serials[room_id] = self._latest_room_serial self._room_serials[room_id] = self._latest_room_serial
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notifier.on_new_user_event( self.notifier.on_new_event(
"typing_key", self._latest_room_serial, rooms=[room_id] "typing_key", self._latest_room_serial, rooms=[room_id]
) )

View File

@ -61,21 +61,31 @@ class SimpleHttpClient(object):
self.agent = Agent(reactor, pool=pool) self.agent = Agent(reactor, pool=pool)
self.version_string = hs.version_string self.version_string = hs.version_string
def request(self, method, *args, **kwargs): def request(self, method, uri, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach # A small wrapper around self.agent.request() so we can easily attach
# counters to it # counters to it
outgoing_requests_counter.inc(method) outgoing_requests_counter.inc(method)
d = preserve_context_over_fn( d = preserve_context_over_fn(
self.agent.request, self.agent.request,
method, *args, **kwargs method, uri, *args, **kwargs
) )
logger.info("Sending request %s %s", method, uri)
def _cb(response): def _cb(response):
incoming_responses_counter.inc(method, response.code) incoming_responses_counter.inc(method, response.code)
logger.info(
"Received response to %s %s: %s",
method, uri, response.code
)
return response return response
def _eb(failure): def _eb(failure):
incoming_responses_counter.inc(method, "ERR") incoming_responses_counter.inc(method, "ERR")
logger.info(
"Error sending request to %s %s: %s %s",
method, uri, failure.type, failure.getErrorMessage()
)
return failure return failure
d.addCallbacks(_cb, _eb) d.addCallbacks(_cb, _eb)
@ -84,7 +94,9 @@ class SimpleHttpClient(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}): def post_urlencoded_get_json(self, uri, args={}):
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args) logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
response = yield self.request( response = yield self.request(
@ -97,7 +109,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(query_bytes)) bodyProducer=FileBodyProducer(StringIO(query_bytes))
) )
body = yield readBody(response) body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -105,7 +117,7 @@ class SimpleHttpClient(object):
def post_json_get_json(self, uri, post_json): def post_json_get_json(self, uri, post_json):
json_str = encode_canonical_json(post_json) json_str = encode_canonical_json(post_json)
logger.info("HTTP POST %s -> %s", json_str, uri) logger.debug("HTTP POST %s -> %s", json_str, uri)
response = yield self.request( response = yield self.request(
"POST", "POST",
@ -116,7 +128,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(json_str)) bodyProducer=FileBodyProducer(StringIO(json_str))
) )
body = yield readBody(response) body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -149,7 +161,7 @@ class SimpleHttpClient(object):
}) })
) )
body = yield readBody(response) body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300: if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -192,7 +204,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(json_str)) bodyProducer=FileBodyProducer(StringIO(json_str))
) )
body = yield readBody(response) body = yield preserve_context_over_fn(readBody, response)
if 200 <= response.code < 300: if 200 <= response.code < 300:
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -226,7 +238,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
) )
try: try:
body = yield readBody(response) body = yield preserve_context_over_fn(readBody, response)
defer.returnValue(body) defer.returnValue(body)
except PartialDownloadError as e: except PartialDownloadError as e:
# twisted dislikes google's response, no content length. # twisted dislikes google's response, no content length.

View File

@ -35,11 +35,13 @@ from syutil.crypto.jsonsign import sign_json
import simplejson as json import simplejson as json
import logging import logging
import sys
import urllib import urllib
import urlparse import urlparse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
outbound_logger = logging.getLogger("synapse.http.outbound")
metrics = synapse.metrics.get_metrics_for(__name__) metrics = synapse.metrics.get_metrics_for(__name__)
@ -86,6 +88,7 @@ class MatrixFederationHttpClient(object):
) )
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.version_string = hs.version_string self.version_string = hs.version_string
self._next_id = 1
def _create_url(self, destination, path_bytes, param_bytes, query_bytes): def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
return urlparse.urlunparse( return urlparse.urlunparse(
@ -106,16 +109,12 @@ class MatrixFederationHttpClient(object):
destination, path_bytes, param_bytes, query_bytes destination, path_bytes, param_bytes, query_bytes
) )
logger.info("Sending request to %s: %s %s", txn_id = "%s-O-%s" % (method, self._next_id)
destination, method, url_bytes) self._next_id = (self._next_id + 1) % (sys.maxint - 1)
logger.debug( outbound_logger.info(
"Types: %s", "{%s} [%s] Sending request: %s %s",
[ txn_id, destination, method, url_bytes
type(destination), type(method), type(path_bytes),
type(param_bytes),
type(query_bytes)
]
) )
# XXX: Would be much nicer to retry only at the transaction-layer # XXX: Would be much nicer to retry only at the transaction-layer
@ -126,12 +125,15 @@ class MatrixFederationHttpClient(object):
("", "", path_bytes, param_bytes, query_bytes, "") ("", "", path_bytes, param_bytes, query_bytes, "")
) )
log_result = None
try:
while True: while True:
producer = None producer = None
if body_callback: if body_callback:
producer = body_callback(method, http_url_bytes, headers_dict) producer = body_callback(method, http_url_bytes, headers_dict)
try: try:
def send_request():
request_deferred = preserve_context_over_fn( request_deferred = preserve_context_over_fn(
self.agent.request, self.agent.request,
method, method,
@ -140,12 +142,17 @@ class MatrixFederationHttpClient(object):
producer producer
) )
response = yield self.clock.time_bound_deferred(
return self.clock.time_bound_deferred(
request_deferred, request_deferred,
time_out=timeout/1000. if timeout else 60, time_out=timeout/1000. if timeout else 60,
) )
logger.debug("Got response to %s", method) response = yield preserve_context_over_fn(
send_request,
)
log_result = "%d %s" % (response.code, response.phrase,)
break break
except Exception as e: except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError): if not retry_on_dns_fail and isinstance(e, DNSLookupError):
@ -154,10 +161,14 @@ class MatrixFederationHttpClient(object):
destination, destination,
e e
) )
log_result = "DNS Lookup failed to %s with %s" % (
destination, e
)
raise raise
logger.warn( logger.warn(
"Sending request failed to %s: %s %s: %s - %s", "{%s} Sending request failed to %s: %s %s: %s - %s",
txn_id,
destination, destination,
method, method,
url_bytes, url_bytes,
@ -165,19 +176,21 @@ class MatrixFederationHttpClient(object):
_flatten_response_never_received(e), _flatten_response_never_received(e),
) )
log_result = "%s - %s" % (
type(e).__name__, _flatten_response_never_received(e),
)
if retries_left and not timeout: if retries_left and not timeout:
yield sleep(2 ** (5 - retries_left)) yield sleep(2 ** (5 - retries_left))
retries_left -= 1 retries_left -= 1
else: else:
raise raise
finally:
logger.info( outbound_logger.info(
"Received response %d %s for %s: %s %s", "{%s} [%s] Result: %s",
response.code, txn_id,
response.phrase,
destination, destination,
method, log_result,
url_bytes
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
@ -185,7 +198,7 @@ class MatrixFederationHttpClient(object):
else: else:
# :'( # :'(
# Update transactions table? # Update transactions table?
body = yield readBody(response) body = yield preserve_context_over_fn(readBody, response)
raise HttpResponseException( raise HttpResponseException(
response.code, response.phrase, body response.code, response.phrase, body
) )
@ -265,10 +278,7 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json" "Content-Type not application/json"
) )
logger.debug("Getting resp body") body = yield preserve_context_over_fn(readBody, response)
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -311,9 +321,7 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json" "Content-Type not application/json"
) )
logger.debug("Getting resp body") body = yield preserve_context_over_fn(readBody, response)
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -371,9 +379,7 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json" "Content-Type not application/json"
) )
logger.debug("Getting resp body") body = yield preserve_context_over_fn(readBody, response)
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -416,7 +422,10 @@ class MatrixFederationHttpClient(object):
headers = dict(response.headers.getAllRawHeaders()) headers = dict(response.headers.getAllRawHeaders())
try: try:
length = yield _readBodyToFile(response, output_stream, max_size) length = yield preserve_context_over_fn(
_readBodyToFile,
response, output_stream, max_size
)
except: except:
logger.exception("Failed to download body") logger.exception("Failed to download body")
raise raise

View File

@ -79,17 +79,11 @@ def request_handler(request_handler):
_next_request_id += 1 _next_request_id += 1
with LoggingContext(request_id) as request_context: with LoggingContext(request_id) as request_context:
request_context.request = request_id request_context.request = request_id
code = None with request.processing():
start = self.clock.time_msec()
try: try:
logger.info(
"Received request: %s %s",
request.method, request.path
)
d = request_handler(self, request) d = request_handler(self, request)
with PreserveLoggingContext(): with PreserveLoggingContext():
yield d yield d
code = request.code
except CodeMessageException as e: except CodeMessageException as e:
code = e.code code = e.code
if isinstance(e, SynapseError): if isinstance(e, SynapseError):
@ -105,7 +99,6 @@ def request_handler(request_handler):
version_string=self.version_string, version_string=self.version_string,
) )
except: except:
code = 500
logger.exception( logger.exception(
"Failed handle request %s.%s on %r: %r", "Failed handle request %s.%s on %r: %r",
request_handler.__module__, request_handler.__module__,
@ -119,13 +112,6 @@ def request_handler(request_handler):
{"error": "Internal server error"}, {"error": "Internal server error"},
send_cors=True send_cors=True
) )
finally:
code = str(code) if code else "-"
end = self.clock.time_msec()
logger.info(
"Processed request: %dms %s %s %s",
end-start, code, request.method, request.path
)
return wrapped_request_handler return wrapped_request_handler
@ -221,7 +207,7 @@ class JsonResource(HttpServer, resource.Resource):
incoming_requests_counter.inc(request.method, servlet_classname) incoming_requests_counter.inc(request.method, servlet_classname)
args = [ args = [
urllib.unquote(u).decode("UTF-8") for u in m.groups() urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
] ]
callback_return = yield callback(request, *args) callback_return = yield callback(request, *args)

View File

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor, ObservableDeferred
from synapse.types import StreamToken from synapse.types import StreamToken
import synapse.metrics import synapse.metrics
@ -45,21 +45,11 @@ class _NotificationListener(object):
The events stream handler will have yielded to the deferred, so to The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred. notify the handler it is sufficient to resolve the deferred.
""" """
__slots__ = ["deferred"]
def __init__(self, deferred): def __init__(self, deferred):
self.deferred = deferred self.deferred = deferred
def notified(self):
return self.deferred.called
def notify(self, token):
""" Inform whoever is listening about the new events.
"""
try:
self.deferred.callback(token)
except defer.AlreadyCalledError:
pass
class _NotifierUserStream(object): class _NotifierUserStream(object):
"""This represents a user connected to the event stream. """This represents a user connected to the event stream.
@ -75,11 +65,12 @@ class _NotifierUserStream(object):
appservice=None): appservice=None):
self.user = str(user) self.user = str(user)
self.appservice = appservice self.appservice = appservice
self.listeners = set()
self.rooms = set(rooms) self.rooms = set(rooms)
self.current_token = current_token self.current_token = current_token
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key, stream_id, time_now_ms): def notify(self, stream_key, stream_id, time_now_ms):
"""Notify any listeners for this user of a new event from an """Notify any listeners for this user of a new event from an
event source. event source.
@ -91,12 +82,10 @@ class _NotifierUserStream(object):
self.current_token = self.current_token.copy_and_advance( self.current_token = self.current_token.copy_and_advance(
stream_key, stream_id stream_key, stream_id
) )
if self.listeners:
self.last_notified_ms = time_now_ms self.last_notified_ms = time_now_ms
listeners = self.listeners noify_deferred = self.notify_deferred
self.listeners = set() self.notify_deferred = ObservableDeferred(defer.Deferred())
for listener in listeners: noify_deferred.callback(self.current_token)
listener.notify(self.current_token)
def remove(self, notifier): def remove(self, notifier):
""" Remove this listener from all the indexes in the Notifier """ Remove this listener from all the indexes in the Notifier
@ -114,6 +103,18 @@ class _NotifierUserStream(object):
self.appservice, set() self.appservice, set()
).discard(self) ).discard(self)
def count_listeners(self):
return len(self.notify_deferred.observers())
def new_listener(self, token):
"""Returns a deferred that is resolved when there is a new token
greater than the given token.
"""
if self.current_token.is_after(token):
return _NotificationListener(defer.succeed(self.current_token))
else:
return _NotificationListener(self.notify_deferred.observe())
class Notifier(object): class Notifier(object):
""" This class is responsible for notifying any listeners when there are """ This class is responsible for notifying any listeners when there are
@ -158,7 +159,7 @@ class Notifier(object):
for x in self.appservice_to_user_streams.values(): for x in self.appservice_to_user_streams.values():
all_user_streams |= x all_user_streams |= x
return sum(len(stream.listeners) for stream in all_user_streams) return sum(stream.count_listeners() for stream in all_user_streams)
metrics.register_callback("listeners", count_listeners) metrics.register_callback("listeners", count_listeners)
metrics.register_callback( metrics.register_callback(
@ -220,16 +221,7 @@ class Notifier(object):
event event
) )
room_id = event.room_id app_streams = set()
room_user_streams = self.room_to_user_streams.get(room_id, set())
user_streams = room_user_streams.copy()
for user in extra_users:
user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
user_streams.add(user_stream)
for appservice in self.appservice_to_user_streams: for appservice in self.appservice_to_user_streams:
# TODO (kegan): Redundant appservice listener checks? # TODO (kegan): Redundant appservice listener checks?
@ -241,24 +233,20 @@ class Notifier(object):
app_user_streams = self.appservice_to_user_streams.get( app_user_streams = self.appservice_to_user_streams.get(
appservice, set() appservice, set()
) )
user_streams |= app_user_streams app_streams |= app_user_streams
logger.debug("on_new_room_event listeners %s", user_streams) self.on_new_event(
"room_key", room_stream_id,
time_now_ms = self.clock.time_msec() users=extra_users,
for user_stream in user_streams: rooms=[event.room_id],
try: extra_streams=app_streams,
user_stream.notify(
"room_key", "s%d" % (room_stream_id,), time_now_ms
) )
except:
logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]): def on_new_event(self, stream_key, new_token, users=[], rooms=[],
""" Used to inform listeners that something has happend extra_streams=set()):
presence/user event wise. """ Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
@ -282,14 +270,10 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, rooms, timeout, callback, def wait_for_events(self, user, rooms, timeout, callback,
from_token=StreamToken("s0", "0", "0")): from_token=StreamToken("s0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
deferred = defer.Deferred()
time_now_ms = self.clock.time_msec()
user = str(user) user = str(user)
user_stream = self.user_to_user_stream.get(user) user_stream = self.user_to_user_stream.get(user)
if user_stream is None: if user_stream is None:
@ -302,55 +286,44 @@ class Notifier(object):
rooms=rooms, rooms=rooms,
appservice=appservice, appservice=appservice,
current_token=current_token, current_token=current_token,
time_now_ms=time_now_ms, time_now_ms=self.clock.time_msec(),
) )
self._register_with_keys(user_stream) self._register_with_keys(user_stream)
else:
result = None
if timeout:
# Will be set to a _NotificationListener that we'll be waiting on.
# Allows us to cancel it.
listener = None
def timed_out():
if listener:
listener.deferred.cancel()
timer = self.clock.call_later(timeout/1000., timed_out)
prev_token = from_token
while not result:
try:
current_token = user_stream.current_token current_token = user_stream.current_token
listener = [_NotificationListener(deferred)] result = yield callback(prev_token, current_token)
if timeout and not current_token.is_after(from_token):
user_stream.listeners.add(listener[0])
if current_token.is_after(from_token):
result = yield callback(from_token, current_token)
else:
result = None
timer = [None]
if result: if result:
user_stream.listeners.discard(listener[0]) break
defer.returnValue(result)
return
if timeout: # Now we wait for the _NotifierUserStream to be told there
timed_out = [False] # is a new token.
# We need to supply the token we supplied to callback so
# that we don't miss any current_token updates.
prev_token = current_token
listener = user_stream.new_listener(prev_token)
yield listener.deferred
except defer.CancelledError:
break
def _timeout_listener(): self.clock.cancel_call_later(timer, ignore_errs=True)
timed_out[0] = True else:
timer[0] = None current_token = user_stream.current_token
user_stream.listeners.discard(listener[0]) result = yield callback(from_token, current_token)
listener[0].notify(current_token)
# We create multiple notification listeners so we have to manage
# canceling the timeout ourselves.
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
while not result and not timed_out[0]:
new_token = yield deferred
deferred = defer.Deferred()
listener[0] = _NotificationListener(deferred)
user_stream.listeners.add(listener[0])
result = yield callback(current_token, new_token)
current_token = new_token
if timer[0] is not None:
try:
self.clock.cancel_call_later(timer[0])
except:
logger.exception("Failed to cancel notifer timer")
defer.returnValue(result) defer.returnValue(result)
@ -368,6 +341,9 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_for_updates(before_token, after_token): def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
defer.returnValue(None)
events = [] events = []
end_token = from_token end_token = from_token
for name, source in self.event_sources.sources.items(): for name, source in self.event_sources.sources.items():
@ -376,10 +352,10 @@ class Notifier(object):
after_id = getattr(after_token, keyname) after_id = getattr(after_token, keyname)
if before_id == after_id: if before_id == after_id:
continue continue
stuff, new_key = yield source.get_new_events_for_user( new_events, new_key = yield source.get_new_events_for_user(
user, getattr(from_token, keyname), limit, user, getattr(from_token, keyname), limit,
) )
events.extend(stuff) events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key) end_token = end_token.copy_and_replace(keyname, new_key)
if events: if events:
@ -402,7 +378,7 @@ class Notifier(object):
expired_streams = [] expired_streams = []
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
for stream in self.user_to_user_stream.values(): for stream in self.user_to_user_stream.values():
if stream.listeners: if stream.count_listeners():
continue continue
if stream.last_notified_ms < expire_before_ts: if stream.last_notified_ms < expire_before_ts:
expired_streams.append(stream) expired_streams.append(stream)

View File

@ -24,6 +24,7 @@ import baserules
import logging import logging
import simplejson as json import simplejson as json
import re import re
import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -256,12 +257,31 @@ class Pusher(object):
logger.info("Pusher %s for user %s starting from token %s", logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token) self.pushkey, self.user_name, self.last_token)
wait = 0
while self.alive: while self.alive:
try:
if wait > 0:
yield synapse.util.async.sleep(wait)
yield self.get_and_dispatch()
wait = 0
except:
if wait == 0:
wait = 1
else:
wait = min(wait * 2, 1800)
logger.exception(
"Exception in pusher loop for pushkey %s. Pausing for %ds",
self.pushkey, wait
)
@defer.inlineCallbacks
def get_and_dispatch(self):
from_tok = StreamToken.from_string(self.last_token) from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1') config = PaginationConfig(from_token=from_tok, limit='1')
timeout = (300 + random.randint(-60, 60)) * 1000
chunk = yield self.evStreamHandler.get_stream( chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, self.user_name, config,
timeout=100*365*24*60*60*1000, affect_presence=False timeout=timeout, affect_presence=False
) )
# limiting to 1 may get 1 event plus 1 presence event, so # limiting to 1 may get 1 event plus 1 presence event, so
@ -273,10 +293,11 @@ class Pusher(object):
break break
if not single_event: if not single_event:
self.last_token = chunk['end'] self.last_token = chunk['end']
continue logger.debug("Event stream timeout for pushkey %s", self.pushkey)
return
if not self.alive: if not self.alive:
continue return
processed = False processed = False
actions = yield self._actions_for_event(single_event) actions = yield self._actions_for_event(single_event)
@ -319,7 +340,7 @@ class Pusher(object):
) )
if not self.alive: if not self.alive:
continue return
if processed: if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF self.backoff_delay = Pusher.INITIAL_BACKOFF

View File

@ -164,7 +164,7 @@ def make_base_append_underride_rules(user):
] ]
}, },
{ {
'rule_id': 'global/override/.m.rule.contains_display_name', 'rule_id': 'global/underride/.m.rule.contains_display_name',
'conditions': [ 'conditions': [
{ {
'kind': 'contains_display_name' 'kind': 'contains_display_name'

View File

@ -31,6 +31,8 @@ REQUIREMENTS = {
"pillow": ["PIL"], "pillow": ["PIL"],
"pydenticon": ["pydenticon"], "pydenticon": ["pydenticon"],
"ujson": ["ujson"], "ujson": ["ujson"],
"blist": ["blist"],
"pysaml2": ["saml2"],
} }
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
"web_client": { "web_client": {

View File

@ -20,14 +20,32 @@ from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_pattern
import simplejson as json import simplejson as json
import urllib
import logging
from saml2 import BINDING_HTTP_POST
from saml2 import config
from saml2.client import Saml2Client
logger = logging.getLogger(__name__)
class LoginRestServlet(ClientV1RestServlet): class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$") PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password" PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.saml2_enabled = hs.config.saml2_enabled
def on_GET(self, request): def on_GET(self, request):
return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]}) flows = [{"type": LoginRestServlet.PASS_TYPE}]
if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SAML2_TYPE})
return (200, {"flows": flows})
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@ -39,6 +57,16 @@ class LoginRestServlet(ClientV1RestServlet):
if login_submission["type"] == LoginRestServlet.PASS_TYPE: if login_submission["type"] == LoginRestServlet.PASS_TYPE:
result = yield self.do_password_login(login_submission) result = yield self.do_password_login(login_submission)
defer.returnValue(result) defer.returnValue(result)
elif self.saml2_enabled and (login_submission["type"] ==
LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
relay_state = "&RelayState="+urllib.quote(
login_submission["relay_state"])
result = {
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
}
defer.returnValue((200, result))
else: else:
raise SynapseError(400, "Bad login type.") raise SynapseError(400, "Bad login type.")
except KeyError: except KeyError:
@ -94,6 +122,49 @@ class PasswordResetRestServlet(ClientV1RestServlet):
) )
class SAML2RestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/saml2")
def __init__(self, hs):
super(SAML2RestServlet, self).__init__(hs)
self.sp_config = hs.config.saml2_config_path
@defer.inlineCallbacks
def on_POST(self, request):
saml2_auth = None
try:
conf = config.SPConfig()
conf.load_file(self.sp_config)
SP = Saml2Client(conf)
saml2_auth = SP.parse_authn_request_response(
request.args['SAMLResponse'][0], BINDING_HTTP_POST)
except Exception, e: # Not authenticated
logger.exception(e)
if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
username = saml2_auth.name_id.text
handler = self.handlers.registration_handler
(user_id, token) = yield handler.register_saml2(username)
# Forward to the RelayState callback along with ava
if 'RelayState' in request.args:
request.redirect(urllib.unquote(
request.args['RelayState'][0]) +
'?status=authenticated&access_token=' +
token + '&user_id=' + user_id + '&ava=' +
urllib.quote(json.dumps(saml2_auth.ava)))
request.finish()
defer.returnValue(None)
defer.returnValue((200, {"status": "authenticated",
"user_id": user_id, "token": token,
"ava": saml2_auth.ava}))
elif 'RelayState' in request.args:
request.redirect(urllib.unquote(
request.args['RelayState'][0]) +
'?status=not_authenticated')
request.finish()
defer.returnValue(None)
defer.returnValue((200, {"status": "not_authenticated"}))
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -106,4 +177,6 @@ def _parse_json(request):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server) LoginRestServlet(hs).register(http_server)
if hs.config.saml2_enabled:
SAML2RestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server)

View File

@ -412,6 +412,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
if "user_id" not in content: if "user_id" not in content:
raise SynapseError(400, "Missing user_id key.") raise SynapseError(400, "Missing user_id key.")
state_key = content["user_id"] state_key = content["user_id"]
# make sure it looks like a user ID; it'll throw if it's invalid.
UserID.from_string(state_key)
if membership_action == "kick": if membership_action == "kick":
membership_action = "leave" membership_action = "leave"

View File

@ -39,10 +39,10 @@ class HttpTransactionStore(object):
A tuple of (HTTP response code, response content) or None. A tuple of (HTTP response code, response content) or None.
""" """
try: try:
logger.debug("get_response Key: %s TxnId: %s", key, txn_id) logger.debug("get_response TxnId: %s", txn_id)
(last_txn_id, response) = self.transactions[key] (last_txn_id, response) = self.transactions[key]
if txn_id == last_txn_id: if txn_id == last_txn_id:
logger.info("get_response: Returning a response for %s", key) logger.info("get_response: Returning a response for %s", txn_id)
return response return response
except KeyError: except KeyError:
pass pass
@ -58,7 +58,7 @@ class HttpTransactionStore(object):
txn_id (str): The transaction ID for this request. txn_id (str): The transaction ID for this request.
response (tuple): A tuple of (HTTP response code, response content) response (tuple): A tuple of (HTTP response code, response content)
""" """
logger.debug("store_response Key: %s TxnId: %s", key, txn_id) logger.debug("store_response TxnId: %s", txn_id)
self.transactions[key] = (txn_id, response) self.transactions[key] = (txn_id, response)
def store_client_transaction(self, request, txn_id, response): def store_client_transaction(self, request, txn_id, response):

View File

@ -18,7 +18,9 @@ from . import (
filter, filter,
account, account,
register, register,
auth auth,
receipts,
keys,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -38,3 +40,5 @@ class ClientV2AlphaRestResource(JsonResource):
account.register_servlets(hs, client_resource) account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource) register.register_servlets(hs, client_resource)
auth.register_servlets(hs, client_resource) auth.register_servlets(hs, client_resource)
receipts.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource)

View File

@ -0,0 +1,276 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from syutil.jsonutil import encode_canonical_json
from ._base import client_v2_pattern
import simplejson as json
import logging
logger = logging.getLogger(__name__)
class KeyUploadServlet(RestServlet):
"""
POST /keys/upload/<device_id> HTTP/1.1
Content-Type: application/json
{
"device_keys": {
"user_id": "<user_id>",
"device_id": "<device_id>",
"valid_until_ts": <millisecond_timestamp>,
"algorithms": [
"m.olm.curve25519-aes-sha256",
]
"keys": {
"<algorithm>:<device_id>": "<key_base64>",
},
"signatures:" {
"<user_id>" {
"<algorithm>:<device_id>": "<signature_base64>"
} } },
"one_time_keys": {
"<algorithm>:<key_id>": "<key_base64>"
},
}
"""
PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)")
def __init__(self, hs):
super(KeyUploadServlet, self).__init__()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request, device_id):
auth_user, client_info = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string()
# TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead.
try:
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = body.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %r at %d",
device_id, auth_user, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)
one_time_keys = body.get("one_time_keys", None)
if one_time_keys:
logger.info(
"Adding %d one_time_keys for device %r for user %r at %d",
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
@defer.inlineCallbacks
def on_GET(self, request, device_id):
auth_user, client_info = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string()
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
class KeyQueryServlet(RestServlet):
"""
GET /keys/query/<user_id> HTTP/1.1
GET /keys/query/<user_id>/<device_id> HTTP/1.1
POST /keys/query HTTP/1.1
Content-Type: application/json
{
"device_keys": {
"<user_id>": ["<device_id>"]
} }
HTTP/1.1 200 OK
{
"device_keys": {
"<user_id>": {
"<device_id>": {
"user_id": "<user_id>", // Duplicated to be signed
"device_id": "<device_id>", // Duplicated to be signed
"valid_until_ts": <millisecond_timestamp>,
"algorithms": [ // List of supported algorithms
"m.olm.curve25519-aes-sha256",
],
"keys": { // Must include a ed25519 signing key
"<algorithm>:<key_id>": "<key_base64>",
},
"signatures:" {
// Must be signed with device's ed25519 key
"<user_id>/<device_id>": {
"<algorithm>:<key_id>": "<signature_base64>"
}
// Must be signed by this server.
"<server_name>": {
"<algorithm>:<key_id>": "<signature_base64>"
} } } } } }
"""
PATTERN = client_v2_pattern(
"/keys/query(?:"
"/(?P<user_id>[^/]*)(?:"
"/(?P<device_id>[^/]*)"
")?"
")?"
)
def __init__(self, hs):
super(KeyQueryServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id):
logger.debug("onPOST")
yield self.auth.get_user_by_req(request)
try:
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
query = []
for user_id, device_ids in body.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
defer.returnValue(self.json_result(request, results))
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
auth_user, client_info = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string()
if not user_id:
user_id = auth_user_id
if not device_id:
device_id = None
# Returns a map of user_id->device_id->json_bytes.
results = yield self.store.get_e2e_device_keys([(user_id, device_id)])
defer.returnValue(self.json_result(request, results))
def json_result(self, request, results):
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
return (200, {"device_keys": json_result})
class OneTimeKeyServlet(RestServlet):
"""
GET /keys/claim/<user-id>/<device-id>/<algorithm> HTTP/1.1
POST /keys/claim HTTP/1.1
{
"one_time_keys": {
"<user_id>": {
"<device_id>": "<algorithm>"
} } }
HTTP/1.1 200 OK
{
"one_time_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
} } } }
"""
PATTERN = client_v2_pattern(
"/keys/claim(?:/?|(?:/"
"(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
")?)"
)
def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
results = yield self.store.claim_e2e_one_time_keys(
[(user_id, device_id, algorithm)]
)
defer.returnValue(self.json_result(request, results))
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
try:
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
query = []
for user_id, device_keys in body.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
results = yield self.store.claim_e2e_one_time_keys(query)
defer.returnValue(self.json_result(request, results))
def json_result(self, request, results):
json_result = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
return (200, {"one_time_keys": json_result})
def register_servlets(hs, http_server):
KeyUploadServlet(hs).register(http_server)
KeyQueryServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)

View File

@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern
import logging
logger = logging.getLogger(__name__)
class ReceiptRestServlet(RestServlet):
PATTERN = client_v2_pattern(
"/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
super(ReceiptRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_handlers().receipts_handler
@defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id):
user, client = yield self.auth.get_user_by_req(request)
yield self.receipts_handler.received_client_receipt(
room_id,
receipt_type,
user_id=user.to_string(),
event_id=event_id
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
ReceiptRestServlet(hs).register(http_server)

View File

@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_request_allow_empty from ._base import client_v2_pattern, parse_json_dict_from_request
import logging import logging
import hmac import hmac
@ -55,21 +55,55 @@ class RegisterRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
body = parse_json_dict_from_request(request)
body = parse_request_allow_empty(request) # we do basic sanity checks here because the auth layer will store these
if 'password' not in body: # in sessions. Pull out the username/password provided to us.
raise SynapseError(400, "", Codes.MISSING_PARAM) desired_password = None
if 'password' in body:
if (not isinstance(body['password'], basestring) or
len(body['password']) > 512):
raise SynapseError(400, "Invalid password")
desired_password = body["password"]
desired_username = None
if 'username' in body: if 'username' in body:
if (not isinstance(body['username'], basestring) or
len(body['username']) > 512):
raise SynapseError(400, "Invalid username")
desired_username = body['username'] desired_username = body['username']
yield self.registration_handler.check_username(desired_username)
is_using_shared_secret = False appservice = None
is_application_server = False
service = None
if 'access_token' in request.args: if 'access_token' in request.args:
service = yield self.auth.get_appservice_by_req(request) appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
# have completely different registration flows to normal users
# == Application Service Registration ==
if appservice:
result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0]
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
# == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body:
# FIXME: Should we really be determining if this is shared secret
# auth based purely on the 'mac' key?
result = yield self._do_shared_secret_registration(
desired_username, desired_password, body["mac"]
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
# == Normal User Registration == (everyone else)
if self.hs.config.disable_registration:
raise SynapseError(403, "Registration has been disabled")
if desired_username is not None:
yield self.registration_handler.check_username(desired_username)
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
flows = [ flows = [
@ -82,39 +116,20 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY]
] ]
result = None
if service:
is_application_server = True
params = body
elif 'mac' in body:
# Check registration-specific shared secret auth
if 'username' not in body:
raise SynapseError(400, "", Codes.MISSING_PARAM)
self._check_shared_secret_auth(
body['username'], body['mac']
)
is_using_shared_secret = True
params = body
else:
authed, result, params = yield self.auth_handler.check_auth( authed, result, params = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request) flows, body, self.hs.get_ip_from_request(request)
) )
if not authed: if not authed:
defer.returnValue((401, result)) defer.returnValue((401, result))
return
can_register = ( # NB: This may be from the auth handler and NOT from the POST
not self.hs.config.disable_registration
or is_application_server
or is_using_shared_secret
)
if not can_register:
raise SynapseError(403, "Registration has been disabled")
if 'password' not in params: if 'password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM) raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
desired_username = params['username'] if 'username' in params else None
new_password = params['password'] desired_username = params.get("username", None)
new_password = params.get("password", None)
(user_id, token) = yield self.registration_handler.register( (user_id, token) = yield self.registration_handler.register(
localpart=desired_username, localpart=desired_username,
@ -147,18 +162,21 @@ class RegisterRestServlet(RestServlet):
else: else:
logger.info("bind_email not specified: not binding email") logger.info("bind_email not specified: not binding email")
result = { result = self._create_registration_details(user_id, token)
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result)) defer.returnValue((200, result))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
def _check_shared_secret_auth(self, username, mac): @defer.inlineCallbacks
def _do_appservice_registration(self, username, as_token):
(user_id, token) = yield self.registration_handler.appservice_register(
username, as_token
)
defer.returnValue(self._create_registration_details(user_id, token))
@defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, mac):
if not self.hs.config.registration_shared_secret: if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled") raise SynapseError(400, "Shared secret registration is not enabled")
@ -174,13 +192,23 @@ class RegisterRestServlet(RestServlet):
digestmod=sha1, digestmod=sha1,
).hexdigest() ).hexdigest()
if compare_digest(want_mac, got_mac): if not compare_digest(want_mac, got_mac):
return True
else:
raise SynapseError( raise SynapseError(
403, "HMAC incorrect", 403, "HMAC incorrect",
) )
(user_id, token) = yield self.registration_handler.register(
localpart=username, password=password
)
defer.returnValue(self._create_registration_details(user_id, token))
def _create_registration_details(self, user_id, token):
return {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)

View File

@ -15,20 +15,23 @@
from .thumbnailer import Thumbnailer from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.http.server import respond_with_json from synapse.http.server import respond_with_json
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from synapse.api.errors import ( from synapse.api.errors import (
cs_error, Codes, SynapseError cs_error, Codes, SynapseError
) )
from twisted.internet import defer from twisted.internet import defer, threads
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
import os import os
import cgi
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,8 +39,13 @@ logger = logging.getLogger(__name__)
def parse_media_id(request): def parse_media_id(request):
try: try:
server_name, media_id = request.postpath # This allows users to append e.g. /test.png to the URL. Useful for
return (server_name, media_id) # clients that parse the URL to see content type.
server_name, media_id = request.postpath[:2]
if len(request.postpath) > 2 and is_ascii(request.postpath[-1]):
return server_name, media_id, request.postpath[-1]
else:
return server_name, media_id, None
except: except:
raise SynapseError( raise SynapseError(
404, 404,
@ -52,7 +60,7 @@ class BaseMediaResource(Resource):
def __init__(self, hs, filepaths): def __init__(self, hs, filepaths):
Resource.__init__(self) Resource.__init__(self)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.client = hs.get_http_client() self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.server_name = hs.hostname self.server_name = hs.hostname
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -127,12 +135,21 @@ class BaseMediaResource(Resource):
media_type = headers["Content-Type"][0] media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
upload_name = params.get("filename", None)
if upload_name and not is_ascii(upload_name):
upload_name = None
else:
upload_name = None
yield self.store.store_cached_remote_media( yield self.store.store_cached_remote_media(
origin=server_name, origin=server_name,
media_id=media_id, media_id=media_id,
media_type=media_type, media_type=media_type,
time_now_ms=self.clock.time_msec(), time_now_ms=self.clock.time_msec(),
upload_name=None, upload_name=upload_name,
media_length=length, media_length=length,
filesystem_id=file_id, filesystem_id=file_id,
) )
@ -143,7 +160,7 @@ class BaseMediaResource(Resource):
media_info = { media_info = {
"media_type": media_type, "media_type": media_type,
"media_length": length, "media_length": length,
"upload_name": None, "upload_name": upload_name,
"created_ts": time_now_ms, "created_ts": time_now_ms,
"filesystem_id": file_id, "filesystem_id": file_id,
} }
@ -156,11 +173,16 @@ class BaseMediaResource(Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_with_file(self, request, media_type, file_path, def _respond_with_file(self, request, media_type, file_path,
file_size=None): file_size=None, upload_name=None):
logger.debug("Responding with %r", file_path) logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path): if os.path.isfile(file_path):
request.setHeader(b"Content-Type", media_type.encode("UTF-8")) request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
if upload_name:
request.setHeader(
b"Content-Disposition",
b"inline; filename=%s" % (upload_name.encode("utf-8"),),
)
# cache for at least a day. # cache for at least a day.
# XXX: we might want to turn this off for data we don't want to # XXX: we might want to turn this off for data we don't want to
@ -222,6 +244,9 @@ class BaseMediaResource(Resource):
) )
return return
local_thumbnails = []
def generate_thumbnails():
scales = set() scales = set()
crops = set() crops = set()
for r_width, r_height, r_method, r_type in requirements: for r_width, r_height, r_method, r_type in requirements:
@ -240,9 +265,10 @@ class BaseMediaResource(Resource):
) )
self._makedirs(t_path) self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
yield self.store.store_local_thumbnail(
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len media_id, t_width, t_height, t_type, t_method, t_len
) ))
for t_width, t_height, t_type in crops: for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales: if (t_width, t_height, t_type) in scales:
@ -256,9 +282,14 @@ class BaseMediaResource(Resource):
) )
self._makedirs(t_path) self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_local_thumbnail( local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len media_id, t_width, t_height, t_type, t_method, t_len
) ))
yield threads.deferToThread(generate_thumbnails)
for l in local_thumbnails:
yield self.store.store_local_thumbnail(*l)
defer.returnValue({ defer.returnValue({
"width": m_width, "width": m_width,
@ -273,11 +304,14 @@ class BaseMediaResource(Resource):
if not requirements: if not requirements:
return return
remote_thumbnails = []
input_path = self.filepaths.remote_media_filepath(server_name, file_id) input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path) thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width m_width = thumbnailer.width
m_height = thumbnailer.height m_height = thumbnailer.height
def generate_thumbnails():
if m_width * m_height >= self.max_image_pixels: if m_width * m_height >= self.max_image_pixels:
logger.info( logger.info(
"Image too large to thumbnail %r x %r > %r", "Image too large to thumbnail %r x %r > %r",
@ -303,10 +337,10 @@ class BaseMediaResource(Resource):
) )
self._makedirs(t_path) self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
yield self.store.store_remote_media_thumbnail( remote_thumbnails.append([
server_name, media_id, file_id, server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len t_width, t_height, t_type, t_method, t_len
) ])
for t_width, t_height, t_type in crops: for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales: if (t_width, t_height, t_type) in scales:
@ -320,10 +354,15 @@ class BaseMediaResource(Resource):
) )
self._makedirs(t_path) self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_remote_media_thumbnail( remote_thumbnails.append([
server_name, media_id, file_id, server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len t_width, t_height, t_type, t_method, t_len
) ])
yield threads.deferToThread(generate_thumbnails)
for r in remote_thumbnails:
yield self.store.store_remote_media_thumbnail(*r)
defer.returnValue({ defer.returnValue({
"width": m_width, "width": m_width,

View File

@ -32,14 +32,16 @@ class DownloadResource(BaseMediaResource):
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_GET(self, request): def _async_render_GET(self, request):
server_name, media_id = parse_media_id(request) server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name: if server_name == self.server_name:
yield self._respond_local_file(request, media_id) yield self._respond_local_file(request, media_id, name)
else: else:
yield self._respond_remote_file(request, server_name, media_id) yield self._respond_remote_file(
request, server_name, media_id, name
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_local_file(self, request, media_id): def _respond_local_file(self, request, media_id, name):
media_info = yield self.store.get_local_media(media_id) media_info = yield self.store.get_local_media(media_id)
if not media_info: if not media_info:
self._respond_404(request) self._respond_404(request)
@ -47,24 +49,28 @@ class DownloadResource(BaseMediaResource):
media_type = media_info["media_type"] media_type = media_info["media_type"]
media_length = media_info["media_length"] media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.local_media_filepath(media_id) file_path = self.filepaths.local_media_filepath(media_id)
yield self._respond_with_file( yield self._respond_with_file(
request, media_type, file_path, media_length request, media_type, file_path, media_length,
upload_name=upload_name,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_remote_file(self, request, server_name, media_id): def _respond_remote_file(self, request, server_name, media_id, name):
media_info = yield self._get_remote_media(server_name, media_id) media_info = yield self._get_remote_media(server_name, media_id)
media_type = media_info["media_type"] media_type = media_info["media_type"]
media_length = media_info["media_length"] media_length = media_info["media_length"]
filesystem_id = media_info["filesystem_id"] filesystem_id = media_info["filesystem_id"]
upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.remote_media_filepath( file_path = self.filepaths.remote_media_filepath(
server_name, filesystem_id server_name, filesystem_id
) )
yield self._respond_with_file( yield self._respond_with_file(
request, media_type, file_path, media_length request, media_type, file_path, media_length,
upload_name=upload_name,
) )

View File

@ -36,7 +36,7 @@ class ThumbnailResource(BaseMediaResource):
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_GET(self, request): def _async_render_GET(self, request):
server_name, media_id = parse_media_id(request) server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width") width = parse_integer(request, "width")
height = parse_integer(request, "height") height = parse_integer(request, "height")
method = parse_string(request, "method", "scale") method = parse_string(request, "method", "scale")
@ -162,11 +162,12 @@ class ThumbnailResource(BaseMediaResource):
t_method = info["thumbnail_method"] t_method = info["thumbnail_method"]
if t_method == "scale" or t_method == "crop": if t_method == "scale" or t_method == "crop":
aspect_quality = abs(d_w * t_h - d_h * t_w) aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h)) size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"] type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"] length_quality = info["thumbnail_length"]
info_list.append(( info_list.append((
aspect_quality, size_quality, type_quality, aspect_quality, min_quality, size_quality, type_quality,
length_quality, info length_quality, info
)) ))
if info_list: if info_list:

View File

@ -82,7 +82,7 @@ class Thumbnailer(object):
def save_image(self, output_image, output_type, output_path): def save_image(self, output_image, output_type, output_path):
output_bytes_io = BytesIO() output_bytes_io = BytesIO()
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=70) output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
output_bytes = output_bytes_io.getvalue() output_bytes = output_bytes_io.getvalue()
with open(output_path, "wb") as output_file: with open(output_path, "wb") as output_file:
output_file.write(output_bytes) output_file.write(output_bytes)

View File

@ -15,7 +15,7 @@
from synapse.http.server import respond_with_json, request_handler from synapse.http.server import respond_with_json, request_handler
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string, is_ascii
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
@ -84,6 +84,12 @@ class UploadResource(BaseMediaResource):
code=413, code=413,
) )
upload_name = request.args.get("filename", None)
if upload_name:
upload_name = upload_name[0]
if upload_name and not is_ascii(upload_name):
raise SynapseError(400, "filename must be ascii")
headers = request.requestHeaders headers = request.requestHeaders
if headers.hasHeader("Content-Type"): if headers.hasHeader("Content-Type"):
@ -99,7 +105,7 @@ class UploadResource(BaseMediaResource):
# TODO(markjh): parse content-dispostion # TODO(markjh): parse content-dispostion
content_uri = yield self.create_content( content_uri = yield self.create_content(
media_type, None, request.content.read(), media_type, upload_name, request.content.read(),
content_length, auth_user content_length, auth_user
) )

View File

@ -132,16 +132,8 @@ class BaseHomeServer(object):
setattr(BaseHomeServer, "get_%s" % (depname), _get) setattr(BaseHomeServer, "get_%s" % (depname), _get)
def get_ip_from_request(self, request): def get_ip_from_request(self, request):
# May be an X-Forwarding-For header depending on config # X-Forwarded-For is handled by our custom request type.
ip_addr = request.getClientIP() return request.getClientIP()
if self.config.captcha_ip_origin_is_x_forwarded:
# use the header
if request.requestHeaders.hasHeader("X-Forwarded-For"):
ip_addr = request.requestHeaders.getRawHeaders(
"X-Forwarded-For"
)[0]
return ip_addr
def is_mine(self, domain_specific_string): def is_mine(self, domain_specific_string):
return domain_specific_string.domain == self.hostname return domain_specific_string.domain == self.hostname

View File

@ -106,7 +106,7 @@ class StateHandler(object):
defer.returnValue(state) defer.returnValue(state)
@defer.inlineCallbacks @defer.inlineCallbacks
def compute_event_context(self, event, old_state=None): def compute_event_context(self, event, old_state=None, outlier=False):
""" Fills out the context with the `current state` of the graph. The """ Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph `current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event` just before the event - i.e. it never includes `event`
@ -119,9 +119,23 @@ class StateHandler(object):
Returns: Returns:
an EventContext an EventContext
""" """
yield run_on_reactor()
context = EventContext() context = EventContext()
yield run_on_reactor() if outlier:
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
if old_state:
context.current_state = {
(s.type, s.state_key): s for s in old_state
}
else:
context.current_state = {}
context.prev_state_events = []
context.state_group = None
defer.returnValue(context)
if old_state: if old_state:
context.current_state = { context.current_state = {
@ -155,10 +169,6 @@ class StateHandler(object):
context.current_state = curr_state context.current_state = curr_state
context.state_group = group if not event.is_state() else None context.state_group = group if not event.is_state() else None
prev_state = yield self.store.add_event_hashes(
prev_state
)
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.current_state: if key in context.current_state:

View File

@ -37,6 +37,9 @@ from .rejections import RejectionsStore
from .state import StateStore from .state import StateStore
from .signatures import SignatureStore from .signatures import SignatureStore
from .filtering import FilteringStore from .filtering import FilteringStore
from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore
import fnmatch import fnmatch
@ -51,7 +54,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 19 SCHEMA_VERSION = 21
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
@ -74,6 +77,8 @@ class DataStore(RoomMemberStore, RoomStore,
PushRuleStore, PushRuleStore,
ApplicationServiceTransactionStore, ApplicationServiceTransactionStore,
EventsStore, EventsStore,
ReceiptsStore,
EndToEndKeyStore,
): ):
def __init__(self, hs): def __init__(self, hs):
@ -94,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
key = (user.to_string(), access_token, device_id, ip) key = (user.to_string(), access_token, device_id, ip)
try: try:
last_seen = self.client_ip_last_seen.get(*key) last_seen = self.client_ip_last_seen.get(key)
except KeyError: except KeyError:
last_seen = None last_seen = None
@ -102,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None) defer.returnValue(None)
self.client_ip_last_seen.prefill(*key + (now,)) self.client_ip_last_seen.prefill(key, now)
# It's safe not to lock here: a) no unique constraint, # It's safe not to lock here: a) no unique constraint,
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
@ -348,7 +353,12 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file module_name, absolute_path, python_file
) )
logger.debug("Running script %s", relative_path) logger.debug("Running script %s", relative_path)
module.run_upgrade(cur) module.run_upgrade(cur, database_engine)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
pass
elif ext == ".sql": elif ext == ".sql":
# A plain old .sql file, just read and execute it # A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path) logger.debug("Applying schema %s", relative_path)

View File

@ -15,6 +15,7 @@
import logging import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.async import ObservableDeferred
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache from synapse.util.lrucache import LruCache
@ -27,6 +28,7 @@ from twisted.internet import defer
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
import functools import functools
import inspect
import sys import sys
import time import time
import threading import threading
@ -55,9 +57,12 @@ cache_counter = metrics.register_cache(
) )
_CacheSentinel = object()
class Cache(object): class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=False): def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru: if lru:
self.cache = LruCache(max_size=max_entries) self.cache = LruCache(max_size=max_entries)
self.max_entries = None self.max_entries = None
@ -81,45 +86,44 @@ class Cache(object):
"Cache objects can only be accessed from the main thread" "Cache objects can only be accessed from the main thread"
) )
def get(self, *keyargs): def get(self, key, default=_CacheSentinel):
if len(keyargs) != self.keylen: val = self.cache.get(key, _CacheSentinel)
raise ValueError("Expected a key to have %d items", self.keylen) if val is not _CacheSentinel:
if keyargs in self.cache:
cache_counter.inc_hits(self.name) cache_counter.inc_hits(self.name)
return self.cache[keyargs] return val
cache_counter.inc_misses(self.name) cache_counter.inc_misses(self.name)
raise KeyError()
def update(self, sequence, *args): if default is _CacheSentinel:
raise KeyError()
else:
return default
def update(self, sequence, key, value):
self.check_thread() self.check_thread()
if self.sequence == sequence: if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the # Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369) # number that the cache had before the SELECT was started (SYN-369)
self.prefill(*args) self.prefill(key, value)
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
def prefill(self, key, value):
if self.max_entries is not None: if self.max_entries is not None:
while len(self.cache) >= self.max_entries: while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False) self.cache.popitem(last=False)
self.cache[keyargs] = value self.cache[key] = value
def invalidate(self, *keyargs): def invalidate(self, key):
self.check_thread() self.check_thread()
if len(keyargs) != self.keylen: if not isinstance(key, tuple):
raise ValueError("Expected a key to have %d items", self.keylen) raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
# Increment the sequence number so that any SELECT statements that # Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369) # raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1 self.sequence += 1
self.cache.pop(keyargs, None) self.cache.pop(key, None)
def invalidate_all(self): def invalidate_all(self):
self.check_thread() self.check_thread()
@ -127,9 +131,12 @@ class Cache(object):
self.cache.clear() self.cache.clear()
def cached(max_entries=1000, num_args=1, lru=False): class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function. """ A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
The function is presumed to take zero or more arguments, which are used in The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache; a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value. misses use the function body to generate the value.
@ -141,47 +148,108 @@ def cached(max_entries=1000, num_args=1, lru=False):
which can be used to insert values into the cache specifically, without which can be used to insert values into the cache specifically, without
calling the calculation function. calling the calculation function.
""" """
def wrap(orig): def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
cache = Cache( inlineCallbacks=False):
name=orig.__name__, self.orig = orig
max_entries=max_entries,
keylen=num_args, if inlineCallbacks:
lru=lru, self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
) )
@functools.wraps(orig) self.cache = Cache(
@defer.inlineCallbacks name=self.orig.__name__,
def wrapped(self, *keyargs): max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try: try:
cached_result = cache.get(*keyargs) cached_result_d = self.cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES: if DEBUG_CACHES:
actual_result = yield orig(self, *keyargs) @defer.inlineCallbacks
def check_result(cached_result):
actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result: if actual_result != cached_result:
logger.error( logger.error(
"Stale cache entry %s%r: cached: %r, actual %r", "Stale cache entry %s%r: cached: %r, actual %r",
orig.__name__, keyargs, self.orig.__name__, cache_key,
cached_result, actual_result, cached_result, actual_result,
) )
raise ValueError("Stale cache entry") raise ValueError("Stale cache entry")
defer.returnValue(cached_result) defer.returnValue(cached_result)
observer.addCallback(check_result)
return observer
except KeyError: except KeyError:
# Get the sequence number of the cache before reading from the # Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated # database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369) # while the SELECT is executing (SYN-369)
sequence = cache.sequence sequence = self.cache.sequence
ret = yield orig(self, *keyargs) ret = defer.maybeDeferred(
self.function_to_call,
obj, *args, **kwargs
)
cache.update(sequence, *keyargs + (ret,)) def onErr(f):
self.cache.invalidate(cache_key)
return f
defer.returnValue(ret) ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
return ret.observe()
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill
return wrapped return wrapped
return wrap
def cached(max_entries=1000, num_args=1, lru=True):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
inlineCallbacks=True,
)
class LoggingTransaction(object): class LoggingTransaction(object):
@ -312,13 +380,14 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator() self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
def start_profiling(self): def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec() self._previous_loop_ts = self._clock.time_msec()

View File

@ -104,7 +104,7 @@ class DirectoryStore(SQLBaseStore):
}, },
desc="create_room_alias_association", desc="create_room_alias_association",
) )
self.get_aliases_for_room.invalidate(room_id) self.get_aliases_for_room.invalidate((room_id,))
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_room_alias(self, room_alias): def delete_room_alias(self, room_alias):
@ -114,7 +114,7 @@ class DirectoryStore(SQLBaseStore):
room_alias, room_alias,
) )
self.get_aliases_for_room.invalidate(room_id) self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id) defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias): def _delete_room_alias_txn(self, txn, room_alias):

View File

@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _base import SQLBaseStore
class EndToEndKeyStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
return self._simple_upsert(
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": json_bytes,
}
)
def get_e2e_device_keys(self, query_list):
"""Fetch a list of device keys.
Args:
query_list(list): List of pairs of user_ids and device_ids.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key json byte strings.
"""
def _get_e2e_device_keys(txn):
result = {}
for user_id, device_id in query_list:
user_result = result.setdefault(user_id, {})
keyvalues = {"user_id": user_id}
if device_id:
keyvalues["device_id"] = device_id
rows = self._simple_select_list_txn(
txn, table="e2e_device_keys_json",
keyvalues=keyvalues,
retcols=["device_id", "key_json"]
)
for row in rows:
user_result[row["device_id"]] = row["key_json"]
return result
return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
def _add_e2e_one_time_keys(txn):
for (algorithm, key_id, json_bytes) in key_list:
self._simple_upsert_txn(
txn, table="e2e_one_time_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
"key_id": key_id,
},
values={
"ts_added_ms": time_now,
"key_json": json_bytes,
}
)
return self.runInteraction(
"add_e2e_one_time_keys", _add_e2e_one_time_keys
)
def count_e2e_one_time_keys(self, user_id, device_id):
""" Count the number of one time keys the server has for a device
Returns:
Dict mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
sql = (
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ?"
" GROUP BY algorithm"
)
txn.execute(sql, (user_id, device_id))
result = {}
for algorithm, key_count in txn.fetchall():
result[algorithm] = key_count
return result
return self.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
def _claim_e2e_one_time_keys(txn):
sql = (
"SELECT key_id, key_json FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
result = {}
delete = []
for user_id, device_id, algorithm in query_list:
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
for key_id, key_json in txn.fetchall():
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
sql = (
"DELETE FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" AND key_id = ?"
)
for user_id, device_id, algorithm, key_id in delete:
txn.execute(sql, (user_id, device_id, algorithm, key_id))
return result
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)

View File

@ -49,14 +49,22 @@ class EventFederationStore(SQLBaseStore):
results = set() results = set()
base_sql = ( base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id = ?" "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
) )
front = set(event_ids) front = set(event_ids)
while front: while front:
new_front = set() new_front = set()
for f in front: front_list = list(front)
txn.execute(base_sql, (f,)) chunks = [
front_list[x:x+100]
for x in xrange(0, len(front), 100)
]
for chunk in chunks:
txn.execute(
base_sql % (",".join(["?"] * len(chunk)),),
chunk
)
new_front.update([r[0] for r in txn.fetchall()]) new_front.update([r[0] for r in txn.fetchall()])
new_front -= results new_front -= results
@ -274,8 +282,7 @@ class EventFederationStore(SQLBaseStore):
}, },
) )
def _handle_prev_events(self, txn, outlier, event_id, prev_events, def _handle_mult_prev_events(self, txn, events):
room_id):
""" """
For the given event, update the event edges table and forward and For the given event, update the event edges table and forward and
backward extremities tables. backward extremities tables.
@ -285,45 +292,47 @@ class EventFederationStore(SQLBaseStore):
table="event_edges", table="event_edges",
values=[ values=[
{ {
"event_id": event_id, "event_id": ev.event_id,
"prev_event_id": e_id, "prev_event_id": e_id,
"room_id": room_id, "room_id": ev.room_id,
"is_state": False, "is_state": False,
} }
for e_id, _ in prev_events for ev in events
for e_id, _ in ev.prev_events
], ],
) )
# Update the extremities table if this is not an outlier. events_by_room = {}
if not outlier: for ev in events:
for e_id, _ in prev_events: events_by_room.setdefault(ev.room_id, []).append(ev)
# TODO (erikj): This could be done as a bulk insert
self._simple_delete_txn( for room_id, room_events in events_by_room.items():
txn, prevs = [
table="event_forward_extremities", e_id for ev in room_events for e_id, _ in ev.prev_events
keyvalues={ if not ev.internal_metadata.is_outlier()
"event_id": e_id, ]
"room_id": room_id, if prevs:
} txn.execute(
"DELETE FROM event_forward_extremities"
" WHERE room_id = ?"
" AND event_id in (%s)" % (
",".join(["?"] * len(prevs)),
),
[room_id] + prevs,
) )
# We only insert as a forward extremity the new event if there are
# no other events that reference it as a prev event
query = ( query = (
"SELECT 1 FROM event_edges WHERE prev_event_id = ?" "INSERT INTO event_forward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
" SELECT 1 FROM event_edges WHERE prev_event_id = ?"
" )"
) )
txn.execute(query, (event_id,)) txn.executemany(
query,
if not txn.fetchone(): [(ev.event_id, ev.room_id, ev.event_id) for ev in events]
query = (
"INSERT INTO event_forward_extremities"
" (event_id, room_id)"
" VALUES (?, ?)"
) )
txn.execute(query, (event_id, room_id))
query = ( query = (
"INSERT INTO event_backward_extremities (event_id, room_id)" "INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS (" " SELECT ?, ? WHERE NOT EXISTS ("
@ -337,18 +346,23 @@ class EventFederationStore(SQLBaseStore):
) )
txn.executemany(query, [ txn.executemany(query, [
(e_id, room_id, e_id, room_id, e_id, room_id, False) (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
for e_id, _ in prev_events for ev in events for e_id, _ in ev.prev_events
if not ev.internal_metadata.is_outlier()
]) ])
query = ( query = (
"DELETE FROM event_backward_extremities" "DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?" " WHERE event_id = ? AND room_id = ?"
) )
txn.execute(query, (event_id, room_id)) txn.executemany(
query,
[(ev.event_id, ev.room_id) for ev in events]
)
for room_id in events_by_room:
txn.call_after( txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id self.get_latest_event_ids_in_room.invalidate, (room_id,)
) )
def get_backfill_events(self, room_id, event_list, limit): def get_backfill_events(self, room_id, event_list, limit):
@ -400,9 +414,11 @@ class EventFederationStore(SQLBaseStore):
keyvalues={ keyvalues={
"event_id": event_id, "event_id": event_id,
}, },
retcol="depth" retcol="depth",
allow_none=True,
) )
if depth:
queue.put((-depth, event_id)) queue.put((-depth, event_id))
while not queue.empty() and len(event_results) < limit: while not queue.empty() and len(event_results) < limit:
@ -489,4 +505,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?" query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,)) txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))

View File

@ -23,9 +23,7 @@ from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred from synapse.util.logcontext import preserve_context_over_deferred
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_json from syutil.jsonutil import encode_json
from contextlib import contextmanager from contextlib import contextmanager
@ -46,6 +44,48 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
class EventsStore(SQLBaseStore): class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False,
is_new_state=True):
if not events_and_contexts:
return
if backfilled:
if not self.min_token_deferred.called:
yield self.min_token_deferred
start = self.min_token - 1
self.min_token -= len(events_and_contexts) + 1
stream_orderings = range(start, self.min_token, -1)
@contextmanager
def stream_ordering_manager():
yield stream_orderings
stream_ordering_manager = stream_ordering_manager()
else:
stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
self, len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings:
for (event, _), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
chunks = [
events_and_contexts[x:x+100]
for x in xrange(0, len(events_and_contexts), 100)
]
for chunk in chunks:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
yield self.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
is_new_state=is_new_state,
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, backfilled=False, def persist_event(self, event, context, backfilled=False,
@ -67,13 +107,13 @@ class EventsStore(SQLBaseStore):
try: try:
with stream_ordering_manager as stream_ordering: with stream_ordering_manager as stream_ordering:
event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction( yield self.runInteraction(
"persist_event", "persist_event",
self._persist_event_txn, self._persist_event_txn,
event=event, event=event,
context=context, context=context,
backfilled=backfilled, backfilled=backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state, is_new_state=is_new_state,
current_state=current_state, current_state=current_state,
) )
@ -116,19 +156,14 @@ class EventsStore(SQLBaseStore):
@log_function @log_function
def _persist_event_txn(self, txn, event, context, backfilled, def _persist_event_txn(self, txn, event, context, backfilled,
stream_ordering=None, is_new_state=True, is_new_state=True, current_state=None):
current_state=None):
# Remove the any existing cache entries for the event_id
txn.call_after(self._invalidate_get_event_cache, event.event_id)
# We purposefully do this first since if we include a `current_state` # We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table # key, we *want* to update the `current_state_events` table
if current_state: if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all) txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, event.room_id) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases, event.room_id) txn.call_after(self.get_room_name_and_aliases, event.room_id)
self._simple_delete_txn( self._simple_delete_txn(
@ -149,21 +184,72 @@ class EventsStore(SQLBaseStore):
} }
) )
outlier = event.internal_metadata.is_outlier() return self._persist_events_txn(
if not outlier:
self._update_min_depth_for_room_txn(
txn, txn,
event.room_id, [(event, context)],
event.depth backfilled=backfilled,
is_new_state=is_new_state,
) )
have_persisted = self._simple_select_one_txn( @log_function
txn, def _persist_events_txn(self, txn, events_and_contexts, backfilled,
table="events", is_new_state=True):
keyvalues={"event_id": event.event_id},
retcols=["event_id", "outlier"], # Remove the any existing cache entries for the event_ids
allow_none=True, for event, _ in events_and_contexts:
txn.call_after(self._invalidate_get_event_cache, event.event_id)
depth_updates = {}
for event, _ in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
depth_updates[event.room_id] = max(
event.depth, depth_updates.get(event.room_id, event.depth)
)
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
txn.execute(
"SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
",".join(["?"] * len(events_and_contexts)),
),
[event.event_id for event, _ in events_and_contexts]
)
have_persisted = {
event_id: outlier
for event_id, outlier in txn.fetchall()
}
event_map = {}
to_remove = set()
for event, context in events_and_contexts:
# Handle the case of the list including the same event multiple
# times. The tricky thing here is when they differ by whether
# they are an outlier.
if event.event_id in event_map:
other = event_map[event.event_id]
if not other.internal_metadata.is_outlier():
to_remove.add(event)
continue
elif not event.internal_metadata.is_outlier():
to_remove.add(event)
continue
else:
to_remove.add(other)
event_map[event.event_id] = event
if event.event_id not in have_persisted:
continue
to_remove.add(event)
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
self._store_state_groups_txn(
txn, event, context,
) )
metadata_json = encode_json( metadata_json = encode_json(
@ -171,16 +257,6 @@ class EventsStore(SQLBaseStore):
using_frozen_dicts=USE_FROZEN_DICTS using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8") ).decode("UTF-8")
# If we have already persisted this event, we don't need to do any
# more processing.
# The processing above must be done on every call to persist event,
# since they might not have happened on previous calls. For example,
# if we are persisting an event that we had persisted as an outlier,
# but is no longer one.
if have_persisted:
if not outlier and have_persisted["outlier"]:
self._store_state_groups_txn(txn, event, context)
sql = ( sql = (
"UPDATE event_json SET internal_metadata = ?" "UPDATE event_json SET internal_metadata = ?"
" WHERE event_id = ?" " WHERE event_id = ?"
@ -198,29 +274,45 @@ class EventsStore(SQLBaseStore):
sql, sql,
(False, event.event_id,) (False, event.event_id,)
) )
return
if not outlier: events_and_contexts = filter(
self._store_state_groups_txn(txn, event, context) lambda ec: ec[0] not in to_remove,
events_and_contexts
self._handle_prev_events(
txn,
outlier=outlier,
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
) )
if event.type == EventTypes.Member: if not events_and_contexts:
self._store_room_member_txn(txn, event) return
elif event.type == EventTypes.Name:
self._store_mult_state_groups_txn(txn, [
(event, context)
for event, context in events_and_contexts
if not event.internal_metadata.is_outlier()
])
self._handle_mult_prev_events(
txn,
events=[event for event, _ in events_and_contexts],
)
for event, _ in events_and_contexts:
if event.type == EventTypes.Name:
self._store_room_name_txn(txn, event) self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic: elif event.type == EventTypes.Topic:
self._store_room_topic_txn(txn, event) self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Redaction: elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event) self._store_redaction(txn, event)
event_dict = { self._store_room_members_txn(
txn,
[
event
for event, _ in events_and_contexts
if event.type == EventTypes.Member
]
)
def event_dict(event):
return {
k: v k: v
for k, v in event.get_dict().items() for k, v in event.get_dict().items()
if k not in [ if k not in [
@ -229,63 +321,44 @@ class EventsStore(SQLBaseStore):
] ]
} }
self._simple_insert_txn( self._simple_insert_many_txn(
txn, txn,
table="event_json", table="event_json",
values={ values=[
{
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
"internal_metadata": metadata_json, "internal_metadata": encode_json(
"json": encode_json( event.internal_metadata.get_dict(),
event_dict, using_frozen_dicts=USE_FROZEN_DICTS using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8"), ).decode("UTF-8"),
}, "json": encode_json(
event_dict(event), using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8"),
}
for event, _ in events_and_contexts
],
) )
content = encode_json( self._simple_insert_many_txn(
event.content, using_frozen_dicts=USE_FROZEN_DICTS txn,
).decode("UTF-8") table="events",
values=[
vals = { {
"stream_ordering": event.internal_metadata.stream_ordering,
"topological_ordering": event.depth, "topological_ordering": event.depth,
"event_id": event.event_id,
"type": event.type,
"room_id": event.room_id,
"content": content,
"processed": True,
"outlier": outlier,
"depth": event.depth, "depth": event.depth,
"event_id": event.event_id,
"room_id": event.room_id,
"type": event.type,
"processed": True,
"outlier": event.internal_metadata.is_outlier(),
"content": encode_json(
event.content, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8"),
} }
for event, _ in events_and_contexts
unrec = { ],
k: v
for k, v in event.get_dict().items()
if k not in vals.keys() and k not in [
"redacted",
"redacted_because",
"signatures",
"hashes",
"prev_events",
]
}
vals["unrecognized_keys"] = encode_json(
unrec, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8")
sql = (
"INSERT INTO events"
" (stream_ordering, topological_ordering, event_id, type,"
" room_id, content, processed, outlier, depth)"
" VALUES (?,?,?,?,?,?,?,?,?)"
)
txn.execute(
sql,
(
stream_ordering, event.depth, event.event_id, event.type,
event.room_id, content, True, outlier, event.depth
)
) )
if context.rejected: if context.rejected:
@ -293,20 +366,6 @@ class EventsStore(SQLBaseStore):
txn, event.event_id, context.rejected txn, event.event_id, context.rejected
) )
for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn(
txn, event.event_id, hash_alg, hash_bytes,
)
for prev_event_id, prev_hashes in event.prev_events:
for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_prev_event_hash_txn(
txn, event.event_id, prev_event_id, alg,
hash_bytes
)
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
table="event_auth", table="event_auth",
@ -316,16 +375,22 @@ class EventsStore(SQLBaseStore):
"room_id": event.room_id, "room_id": event.room_id,
"auth_id": auth_id, "auth_id": auth_id,
} }
for event, _ in events_and_contexts
for auth_id, _ in event.auth_events for auth_id, _ in event.auth_events
], ],
) )
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) self._store_event_reference_hashes_txn(
self._store_event_reference_hash_txn( txn, [event for event, _ in events_and_contexts]
txn, event.event_id, ref_alg, ref_hash_bytes
) )
if event.is_state(): state_events_and_contexts = filter(
lambda i: i[0].is_state(),
events_and_contexts,
)
state_values = []
for event, context in state_events_and_contexts:
vals = { vals = {
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
@ -337,10 +402,12 @@ class EventsStore(SQLBaseStore):
if hasattr(event, "replaces_state"): if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state vals["prev_state"] = event.replaces_state
self._simple_insert_txn( state_values.append(vals)
self._simple_insert_many_txn(
txn, txn,
"state_events", table="state_events",
vals, values=state_values,
) )
self._simple_insert_many_txn( self._simple_insert_many_txn(
@ -349,25 +416,27 @@ class EventsStore(SQLBaseStore):
values=[ values=[
{ {
"event_id": event.event_id, "event_id": event.event_id,
"prev_event_id": e_id, "prev_event_id": prev_id,
"room_id": event.room_id, "room_id": event.room_id,
"is_state": True, "is_state": True,
} }
for e_id, h in event.prev_state for event, _ in state_events_and_contexts
for prev_id, _ in event.prev_state
], ],
) )
if is_new_state and not context.rejected: if is_new_state:
for event, _ in state_events_and_contexts:
if not context.rejected:
txn.call_after( txn.call_after(
self.get_current_state_for_key.invalidate, self.get_current_state_for_key.invalidate,
event.room_id, event.type, event.state_key (event.room_id, event.type, event.state_key,)
) )
if (event.type == EventTypes.Name if event.type in [EventTypes.Name, EventTypes.Aliases]:
or event.type == EventTypes.Aliases):
txn.call_after( txn.call_after(
self.get_room_name_and_aliases.invalidate, self.get_room_name_and_aliases.invalidate,
event.room_id (event.room_id,)
) )
self._simple_upsert_txn( self._simple_upsert_txn(
@ -498,8 +567,9 @@ class EventsStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id): def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True): for check_redacted in (False, True):
for get_prev_content in (False, True): for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted, self._get_event_cache.invalidate(
get_prev_content) (event_id, check_redacted, get_prev_content)
)
def _get_event_txn(self, txn, event_id, check_redacted=True, def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False): get_prev_content=False, allow_rejected=False):
@ -520,7 +590,7 @@ class EventsStore(SQLBaseStore):
for event_id in events: for event_id in events:
try: try:
ret = self._get_event_cache.get( ret = self._get_event_cache.get(
event_id, check_redacted, get_prev_content (event_id, check_redacted, get_prev_content,)
) )
if allow_rejected or not ret.rejected_reason: if allow_rejected or not ret.rejected_reason:
@ -736,7 +806,8 @@ class EventsStore(SQLBaseStore):
because = yield self.get_event( because = yield self.get_event(
redaction_id, redaction_id,
check_redacted=False check_redacted=False,
allow_none=True,
) )
if because: if because:
@ -746,12 +817,13 @@ class EventsStore(SQLBaseStore):
prev = yield self.get_event( prev = yield self.get_event(
ev.unsigned["replaces_state"], ev.unsigned["replaces_state"],
get_prev_content=False, get_prev_content=False,
allow_none=True,
) )
if prev: if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill( self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev (ev.event_id, check_redacted, get_prev_content), ev
) )
defer.returnValue(ev) defer.returnValue(ev)
@ -808,7 +880,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill( self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev (ev.event_id, check_redacted, get_prev_content), ev
) )
return ev return ev

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from _base import SQLBaseStore from _base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
@ -71,6 +71,24 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate", desc="store_server_certificate",
) )
@cachedInlineCallbacks()
def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
},
retcols=["key_id", "verify_key"],
desc="get_all_server_verify_keys",
)
defer.returnValue({
row["key_id"]: decode_verify_key_bytes(
row["key_id"], str(row["verify_key"])
)
for row in rows
})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids): def get_server_verify_keys(self, server_name, key_ids):
"""Retrieve the NACL verification key for a given server for the given """Retrieve the NACL verification key for a given server for the given
@ -81,24 +99,14 @@ class KeyStore(SQLBaseStore):
Returns: Returns:
(list of VerifyKey): The verification keys. (list of VerifyKey): The verification keys.
""" """
sql = ( keys = yield self.get_all_server_verify_keys(server_name)
"SELECT key_id, verify_key FROM server_signature_keys" defer.returnValue({
" WHERE server_name = ?" k: keys[k]
" AND key_id in (" + ",".join("?" for key_id in key_ids) + ")" for k in key_ids
) if k in keys and keys[k]
})
rows = yield self._execute_and_decode(
"get_server_verify_keys", sql, server_name, *key_ids
)
keys = []
for row in rows:
key_id = row["key_id"]
key_bytes = row["verify_key"]
key = decode_verify_key_bytes(key_id, str(key_bytes))
keys.append(key)
defer.returnValue(keys)
@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms, def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key): verify_key):
"""Stores a NACL verification key for the given server. """Stores a NACL verification key for the given server.
@ -109,7 +117,7 @@ class KeyStore(SQLBaseStore):
ts_now_ms (int): The time now in milliseconds ts_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key. verification_key (VerifyKey): The NACL verify key.
""" """
return self._simple_upsert( yield self._simple_upsert(
table="server_signature_keys", table="server_signature_keys",
keyvalues={ keyvalues={
"server_name": server_name, "server_name": server_name,
@ -123,6 +131,8 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key", desc="store_server_verify_key",
) )
self.get_all_server_verify_keys.invalidate((server_name,))
def store_server_keys_json(self, server_name, key_id, from_server, def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes): ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server """Stores the JSON bytes for a set of keys from a server
@ -152,6 +162,7 @@ class KeyStore(SQLBaseStore):
"ts_valid_until_ms": ts_expires_ms, "ts_valid_until_ms": ts_expires_ms,
"key_json": buffer(key_json_bytes), "key_json": buffer(key_json_bytes),
}, },
desc="store_server_keys_json",
) )
def get_server_keys_json(self, server_keys): def get_server_keys_json(self, server_keys):

View File

@ -98,7 +98,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"accepted": True}, updatevalues={"accepted": True},
desc="set_presence_list_accepted", desc="set_presence_list_accepted",
) )
self.get_presence_list_accepted.invalidate(observer_localpart) self.get_presence_list_accepted.invalidate((observer_localpart,))
defer.returnValue(result) defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None): def get_presence_list(self, observer_localpart, accepted=None):
@ -133,4 +133,4 @@ class PresenceStore(SQLBaseStore):
"observed_user_id": observed_userid}, "observed_user_id": observed_userid},
desc="del_presence_list", desc="del_presence_list",
) )
self.get_presence_list_accepted.invalidate(observer_localpart) self.get_presence_list_accepted.invalidate((observer_localpart,))

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -23,8 +23,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_push_rules_for_user(self, user_name): def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table=PushRuleTable.table_name, table=PushRuleTable.table_name,
@ -41,8 +40,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows) defer.returnValue(rows)
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_push_rules_enabled_for_user(self, user_name): def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table=PushRuleEnableTable.table_name, table=PushRuleEnableTable.table_name,
@ -153,11 +151,11 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -189,10 +187,10 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority'] = new_prio new_rule['priority'] = new_prio
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -218,8 +216,8 @@ class PushRuleStore(SQLBaseStore):
desc="delete_push_rule", desc="delete_push_rule",
) )
self.get_push_rules_for_user.invalidate(user_name) self.get_push_rules_for_user.invalidate((user_name,))
self.get_push_rules_enabled_for_user.invalidate(user_name) self.get_push_rules_enabled_for_user.invalidate((user_name,))
@defer.inlineCallbacks @defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled): def set_push_rule_enabled(self, user_name, rule_id, enabled):
@ -240,10 +238,10 @@ class PushRuleStore(SQLBaseStore):
{'id': new_id}, {'id': new_id},
) )
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )

347
synapse/storage/receipts.py Normal file
View File

@ -0,0 +1,347 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer
from synapse.util import unwrapFirstError
from blist import sorteddict
import logging
import ujson as json
logger = logging.getLogger(__name__)
class ReceiptsStore(SQLBaseStore):
def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = _RoomStreamChangeCache()
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
"""Get receipts for multiple rooms for sending to clients.
Args:
room_ids (list): List of room_ids.
to_key (int): Max stream id to fetch receipts upto.
from_key (int): Min stream id to fetch receipts from. None fetches
from the start.
Returns:
list: A list of receipts.
"""
room_ids = set(room_ids)
if from_key:
room_ids = yield self._receipts_stream_cache.get_rooms_changed(
self, room_ids, from_key
)
results = yield defer.gatherResults(
[
self.get_linearized_receipts_for_room(
room_id, to_key, from_key=from_key
)
for room_id in room_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue([ev for res in results for ev in res])
@defer.inlineCallbacks
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
Args:
room_ids (str): The room id.
to_key (int): Max stream id to fetch receipts upto.
from_key (int): Min stream id to fetch receipts from. None fetches
from the start.
Returns:
list: A list of receipts.
"""
def f(txn):
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id = ? AND stream_id > ? AND stream_id <= ?"
)
txn.execute(
sql,
(room_id, from_key, to_key)
)
else:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id = ? AND stream_id <= ?"
)
txn.execute(
sql,
(room_id, to_key)
)
rows = self.cursor_to_dict(txn)
return rows
rows = yield self.runInteraction(
"get_linearized_receipts_for_room", f
)
if not rows:
defer.returnValue([])
content = {}
for row in rows:
content.setdefault(
row["event_id"], {}
).setdefault(
row["receipt_type"], {}
)[row["user_id"]] = json.loads(row["data"])
defer.returnValue([{
"type": "m.receipt",
"room_id": room_id,
"content": content,
}])
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self)
@cachedInlineCallbacks()
def get_graph_receipts_for_room(self, room_id):
"""Get receipts for sending to remote servers.
"""
rows = yield self._simple_select_list(
table="receipts_graph",
keyvalues={"room_id": room_id},
retcols=["receipt_type", "user_id", "event_id"],
desc="get_linearized_receipts_for_room",
)
result = {}
for row in rows:
result.setdefault(
row["user_id"], {}
).setdefault(
row["receipt_type"], []
).append(row["event_id"])
defer.returnValue(result)
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
sql = (
"SELECT topological_ordering, stream_ordering, event_id FROM events"
" INNER JOIN receipts_linearized as r USING (event_id, room_id)"
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
)
txn.execute(sql, (room_id, receipt_type, user_id))
results = txn.fetchall()
if results:
res = self._simple_select_one_txn(
txn,
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
)
topological_ordering = int(res["topological_ordering"])
stream_ordering = int(res["stream_ordering"])
for to, so, _ in results:
if int(to) > topological_ordering:
return False
elif int(to) == topological_ordering and int(so) >= stream_ordering:
return False
self._simple_delete_txn(
txn,
table="receipts_linearized",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
}
)
self._simple_insert_txn(
txn,
table="receipts_linearized",
values={
"stream_id": stream_id,
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_id": event_id,
"data": json.dumps(data),
}
)
return True
@defer.inlineCallbacks
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
representations.
"""
if not event_ids:
return
if len(event_ids) == 1:
linearized_event_id = event_ids[0]
else:
# we need to points in graph -> linearized form.
# TODO: Make this better.
def graph_to_linear(txn):
query = (
"SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
" SELECT max(stream_ordering) WHERE event_id IN (%s)"
")"
) % (",".join(["?"] * len(event_ids)))
txn.execute(query, [room_id] + event_ids)
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
linearized_event_id = yield self.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = yield self._receipts_id_gen.get_next(self)
with stream_id_manager as stream_id:
yield self._receipts_stream_cache.room_has_changed(
self, room_id, stream_id
)
have_persisted = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id, receipt_type, user_id, linearized_event_id,
data,
stream_id=stream_id,
)
if not have_persisted:
defer.returnValue(None)
yield self.insert_graph_receipt(
room_id, receipt_type, user_id, event_ids, data
)
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
defer.returnValue((stream_id, max_persisted_id))
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
data):
return self.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id, receipt_type, user_id, event_ids, data
)
def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_ids, data):
self._simple_delete_txn(
txn,
table="receipts_graph",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
}
)
self._simple_insert_txn(
txn,
table="receipts_graph",
values={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": json.dumps(event_ids),
"data": json.dumps(data),
}
)
class _RoomStreamChangeCache(object):
"""Keeps track of the stream_id of the latest change in rooms.
Given a list of rooms and stream key, it will give a subset of rooms that
may have changed since that key. If the key is too old then the cache
will simply return all rooms.
"""
def __init__(self, size_of_cache=10000):
self._size_of_cache = size_of_cache
self._room_to_key = {}
self._cache = sorteddict()
self._earliest_key = None
@defer.inlineCallbacks
def get_rooms_changed(self, store, room_ids, key):
"""Returns subset of room ids that have had new receipts since the
given key. If the key is too old it will just return the given list.
"""
if key > (yield self._get_earliest_key(store)):
keys = self._cache.keys()
i = keys.bisect_right(key)
result = set(
self._cache[k] for k in keys[i:]
).intersection(room_ids)
else:
result = room_ids
defer.returnValue(result)
@defer.inlineCallbacks
def room_has_changed(self, store, room_id, key):
"""Informs the cache that the room has been changed at the given key.
"""
if key > (yield self._get_earliest_key(store)):
old_key = self._room_to_key.get(room_id, None)
if old_key:
key = max(key, old_key)
self._cache.pop(old_key, None)
self._cache[key] = room_id
while len(self._cache) > self._size_of_cache:
k, r = self._cache.popitem()
self._earliest_key = max(k, self._earliest_key)
self._room_to_key.pop(r, None)
@defer.inlineCallbacks
def _get_earliest_key(self, store):
if self._earliest_key is None:
self._earliest_key = yield store.get_max_receipt_stream_id()
self._earliest_key = int(self._earliest_key)
defer.returnValue(self._earliest_key)

View File

@ -131,7 +131,7 @@ class RegistrationStore(SQLBaseStore):
user_id user_id
) )
for r in rows: for r in rows:
self.get_user_by_token.invalidate(r) self.get_user_by_token.invalidate((r,))
@cached() @cached()
def get_user_by_token(self, token): def get_user_by_token(self, token):

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cachedInlineCallbacks
import collections import collections
import logging import logging
@ -186,8 +186,7 @@ class RoomStore(SQLBaseStore):
} }
) )
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id): def get_room_name_and_aliases(self, room_id):
def f(txn): def f(txn):
sql = ( sql = (

View File

@ -35,38 +35,28 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore): class RoomMemberStore(SQLBaseStore):
def _store_room_member_txn(self, txn, event): def _store_room_members_txn(self, txn, events):
"""Store a room member in the database. """Store a room member in the database.
""" """
try: self._simple_insert_many_txn(
target_user_id = event.state_key
except:
logger.exception(
"Failed to parse target_user_id=%s", target_user_id
)
raise
logger.debug(
"_store_room_member_txn: target_user_id=%s, membership=%s",
target_user_id,
event.membership,
)
self._simple_insert_txn(
txn, txn,
"room_memberships", table="room_memberships",
values=[
{ {
"event_id": event.event_id, "event_id": event.event_id,
"user_id": target_user_id, "user_id": event.state_key,
"sender": event.user_id, "sender": event.user_id,
"room_id": event.room_id, "room_id": event.room_id,
"membership": event.membership, "membership": event.membership,
} }
for event in events
]
) )
txn.call_after(self.get_rooms_for_user.invalidate, target_user_id) for event in events:
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_users_in_room.invalidate, event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.
@ -88,7 +78,7 @@ class RoomMemberStore(SQLBaseStore):
lambda events: events[0] if events else None lambda events: events[0] if events else None
) )
@cached() @cached(max_entries=5000)
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
def f(txn): def f(txn):
@ -164,7 +154,7 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn) RoomsForUser(**r) for r in self.cursor_to_dict(txn)
] ]
@cached() @cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id): def get_joined_hosts_for_room(self, room_id):
return self.runInteraction( return self.runInteraction(
"get_joined_hosts_for_room", "get_joined_hosts_for_room",

View File

@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def run_upgrade(cur): def run_upgrade(cur, *args, **kwargs):
cur.execute("SELECT id, regex FROM application_services_regex") cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall(): for row in cur.fetchall():
try: try:

View File

@ -0,0 +1 @@
SELECT 1;

View File

@ -0,0 +1,76 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Main purpose of this upgrade is to change the unique key on the
pushers table again (it was missed when the v16 full schema was
made) but this also changes the pushkey and data columns to text.
When selecting a bytea column into a text column, postgres inserts
the hex encoded data, and there's no portable way of getting the
UTF-8 bytes, so we have to do it in Python.
"""
import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...")
cur.execute("""
CREATE TABLE IF NOT EXISTS pushers2 (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
access_token BIGINT DEFAULT NULL,
profile_tag VARCHAR(32) NOT NULL,
kind VARCHAR(8) NOT NULL,
app_id VARCHAR(64) NOT NULL,
app_display_name VARCHAR(64) NOT NULL,
device_display_name VARCHAR(128) NOT NULL,
pushkey TEXT NOT NULL,
ts BIGINT NOT NULL,
lang VARCHAR(8),
data TEXT,
last_token TEXT,
last_success BIGINT,
failing_since BIGINT,
UNIQUE (app_id, pushkey, user_name)
)
""")
cur.execute("""SELECT
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
FROM pushers
""")
count = 0
for row in cur.fetchall():
row = list(row)
row[8] = bytes(row[8]).decode("utf-8")
row[11] = bytes(row[11]).decode("utf-8")
cur.execute(database_engine.convert_param_style("""
INSERT into pushers2 (
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
row
)
count += 1
cur.execute("DROP TABLE pushers")
cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
logger.info("Moved %d pushers to new table", count)

View File

@ -0,0 +1,34 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS e2e_device_keys_json (
user_id TEXT NOT NULL, -- The user these keys are for.
device_id TEXT NOT NULL, -- Which of the user's devices these keys are for.
ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded.
key_json TEXT NOT NULL, -- The keys for the device as a JSON blob.
CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id)
);
CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json (
user_id TEXT NOT NULL, -- The user this one-time key is for.
device_id TEXT NOT NULL, -- The device this one-time key is for.
algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for.
key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
ts_added_ms BIGINT NOT NULL, -- When this key was uploaded.
key_json TEXT NOT NULL, -- The key as a JSON blob.
CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id)
);

View File

@ -0,0 +1,38 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS receipts_graph(
room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL,
event_ids TEXT NOT NULL,
data TEXT NOT NULL,
CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id)
);
CREATE TABLE IF NOT EXISTS receipts_linearized (
stream_id BIGINT NOT NULL,
room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL,
event_id TEXT NOT NULL,
data TEXT NOT NULL,
CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id)
);
CREATE INDEX receipts_linearized_id ON receipts_linearized(
stream_id
);

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from _base import SQLBaseStore from _base import SQLBaseStore
from syutil.base64util import encode_base64 from syutil.base64util import encode_base64
from synapse.crypto.event_signing import compute_event_reference_hash
class SignatureStore(SQLBaseStore): class SignatureStore(SQLBaseStore):
@ -101,23 +102,26 @@ class SignatureStore(SQLBaseStore):
txn.execute(query, (event_id, )) txn.execute(query, (event_id, ))
return {k: v for k, v in txn.fetchall()} return {k: v for k, v in txn.fetchall()}
def _store_event_reference_hash_txn(self, txn, event_id, algorithm, def _store_event_reference_hashes_txn(self, txn, events):
hash_bytes):
"""Store a hash for a PDU """Store a hash for a PDU
Args: Args:
txn (cursor): txn (cursor):
event_id (str): Id for the Event. events (list): list of Events.
algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes.
""" """
self._simple_insert_txn(
vals = []
for event in events:
ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
vals.append({
"event_id": event.event_id,
"algorithm": ref_alg,
"hash": buffer(ref_hash_bytes),
})
self._simple_insert_many_txn(
txn, txn,
"event_reference_hashes", table="event_reference_hashes",
{ values=vals,
"event_id": event_id,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
) )
def _get_event_signatures_txn(self, txn, event_id): def _get_event_signatures_txn(self, txn, event_id):

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore, cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
@ -81,31 +81,41 @@ class StateStore(SQLBaseStore):
f, f,
) )
@defer.inlineCallbacks state_list = yield defer.gatherResults(
def c(vals):
vals[:] = yield self._get_events(vals, get_prev_content=False)
yield defer.gatherResults(
[ [
c(vals) self._fetch_events_for_group(group, vals)
for vals in states.values() for group, vals in states.items()
], ],
consumeErrors=True, consumeErrors=True,
) )
defer.returnValue(states) defer.returnValue(dict(state_list))
def _fetch_events_for_group(self, key, events):
return self._get_events(
events, get_prev_content=False
).addCallback(
lambda evs: (key, evs)
)
def _store_state_groups_txn(self, txn, event, context): def _store_state_groups_txn(self, txn, event, context):
return self._store_mult_state_groups_txn(txn, [(event, context)])
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if context.current_state is None: if context.current_state is None:
return continue
if context.state_group is not None:
state_groups[event.event_id] = context.state_group
continue
state_events = dict(context.current_state) state_events = dict(context.current_state)
if event.is_state(): if event.is_state():
state_events[(event.type, event.state_key)] = event state_events[(event.type, event.state_key)] = event
state_group = context.state_group
if not state_group:
state_group = self._state_groups_id_gen.get_next_txn(txn) state_group = self._state_groups_id_gen.get_next_txn(txn)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
@ -131,14 +141,19 @@ class StateStore(SQLBaseStore):
for state in state_events.values() for state in state_events.values()
], ],
) )
state_groups[event.event_id] = state_group
self._simple_insert_txn( self._simple_insert_many_txn(
txn, txn,
table="event_to_state_groups", table="event_to_state_groups",
values={ values=[
"state_group": state_group, {
"state_group": state_groups[event.event_id],
"event_id": event.event_id, "event_id": event.event_id,
}, }
for event, context in events_and_contexts
if context.current_state is not None
],
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -173,8 +188,7 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False) events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)
@cached(num_args=3) @cachedInlineCallbacks(num_args=3)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key): def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn): def f(txn):
sql = ( sql = (
@ -190,6 +204,65 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False) events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)
@defer.inlineCallbacks
def get_state_for_events(self, room_id, event_ids):
def f(txn):
groups = set()
event_to_group = {}
for event_id in event_ids:
# TODO: Remove this loop.
group = self._simple_select_one_onecol_txn(
txn,
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
allow_none=True,
)
if group:
event_to_group[event_id] = group
groups.add(group)
group_to_state_ids = {}
for group in groups:
state_ids = self._simple_select_onecol_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": group},
retcol="event_id",
)
group_to_state_ids[group] = state_ids
return event_to_group, group_to_state_ids
res = yield self.runInteraction(
"annotate_events_with_state_groups",
f,
)
event_to_group, group_to_state_ids = res
state_list = yield defer.gatherResults(
[
self._fetch_events_for_group(group, vals)
for group, vals in group_to_state_ids.items()
],
consumeErrors=True,
)
state_dict = {
group: {
(ev.type, ev.state_key): ev
for ev in state
}
for group, state in state_list
}
defer.returnValue([
state_dict.get(event_to_group.get(event, None), None)
for event in event_ids
])
def _make_group_id(clock): def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5) return str(int(clock.time_msec())) + random_string(5)

View File

@ -72,7 +72,10 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next_txn(txn) as stream_id: with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ... # ... persist event ...
""" """
def __init__(self): def __init__(self, table, column):
self.table = table
self.column = column
self._lock = threading.Lock() self._lock = threading.Lock()
self._current_max = None self._current_max = None
@ -107,6 +110,37 @@ class StreamIdGenerator(object):
defer.returnValue(manager()) defer.returnValue(manager())
@defer.inlineCallbacks
def get_next_mult(self, store, n):
"""
Usage:
with yield stream_id_gen.get_next(store, n) as stream_ids:
# ... persist events ...
"""
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n
for next_id in next_ids:
self._unfinished_ids.append(next_id)
@contextlib.contextmanager
def manager():
try:
yield next_ids
finally:
with self._lock:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
defer.returnValue(manager())
@defer.inlineCallbacks @defer.inlineCallbacks
def get_max_token(self, store): def get_max_token(self, store):
"""Returns the maximum stream id such that all stream ids less than or """Returns the maximum stream id such that all stream ids less than or
@ -126,7 +160,7 @@ class StreamIdGenerator(object):
def _get_or_compute_current_max(self, txn): def _get_or_compute_current_max(self, txn):
with self._lock: with self._lock:
txn.execute("SELECT MAX(stream_ordering) FROM events") txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
rows = txn.fetchall() rows = txn.fetchall()
val, = rows[0] val, = rows[0]

View File

@ -20,6 +20,7 @@ from synapse.types import StreamToken
from synapse.handlers.presence import PresenceEventSource from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.room import RoomEventSource from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource
class NullSource(object): class NullSource(object):
@ -43,6 +44,7 @@ class EventSources(object):
"room": RoomEventSource, "room": RoomEventSource,
"presence": PresenceEventSource, "presence": PresenceEventSource,
"typing": TypingNotificationEventSource, "typing": TypingNotificationEventSource,
"receipt": ReceiptEventSource,
} }
def __init__(self, hs): def __init__(self, hs):
@ -62,7 +64,10 @@ class EventSources(object):
), ),
typing_key=( typing_key=(
yield self.sources["typing"].get_current_key() yield self.sources["typing"].get_current_key()
) ),
receipt_key=(
yield self.sources["receipt"].get_current_key()
),
) )
defer.returnValue(token) defer.returnValue(token)

View File

@ -100,7 +100,7 @@ class EventID(DomainSpecificString):
class StreamToken( class StreamToken(
namedtuple( namedtuple(
"Token", "Token",
("room_key", "presence_key", "typing_key") ("room_key", "presence_key", "typing_key", "receipt_key")
) )
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -109,6 +109,9 @@ class StreamToken(
def from_string(cls, string): def from_string(cls, string):
try: try:
keys = string.split(cls._SEPARATOR) keys = string.split(cls._SEPARATOR)
if len(keys) == len(cls._fields) - 1:
# i.e. old token from before receipt_key
keys.append("0")
return cls(*keys) return cls(*keys)
except: except:
raise SynapseError(400, "Invalid Token") raise SynapseError(400, "Invalid Token")
@ -131,6 +134,7 @@ class StreamToken(
(other_token.room_stream_id < self.room_stream_id) (other_token.room_stream_id < self.room_stream_id)
or (int(other_token.presence_key) < int(self.presence_key)) or (int(other_token.presence_key) < int(self.presence_key))
or (int(other_token.typing_key) < int(self.typing_key)) or (int(other_token.typing_key) < int(self.typing_key))
or (int(other_token.receipt_key) < int(self.receipt_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):
@ -174,7 +178,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
Live tokens start with an "s" followed by the "stream_ordering" id of the Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, follewed by "-", "topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after. followed by the "stream_ordering" id of the event it comes after.
""" """
__slots__ = [] __slots__ = []
@ -207,4 +211,5 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
# token_id is the primary key ID of the access token, not the access token itself.
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))

View File

@ -91,8 +91,12 @@ class Clock(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
return reactor.callLater(delay, wrapped_callback, *args, **kwargs) return reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer): def cancel_call_later(self, timer, ignore_errs=False):
try:
timer.cancel() timer.cancel()
except:
if not ignore_errs:
raise
def time_bound_deferred(self, given_deferred, time_out): def time_bound_deferred(self, given_deferred, time_out):
if given_deferred.called: if given_deferred.called:

View File

@ -38,6 +38,9 @@ class ObservableDeferred(object):
deferred. deferred.
If consumeErrors is true errors will be captured from the origin deferred. If consumeErrors is true errors will be captured from the origin deferred.
Cancelling or otherwise resolving an observer will not affect the original
ObservableDeferred.
""" """
__slots__ = ["_deferred", "_observers", "_result"] __slots__ = ["_deferred", "_observers", "_result"]
@ -45,10 +48,10 @@ class ObservableDeferred(object):
def __init__(self, deferred, consumeErrors=False): def __init__(self, deferred, consumeErrors=False):
object.__setattr__(self, "_deferred", deferred) object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None) object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", []) object.__setattr__(self, "_observers", set())
def callback(r): def callback(r):
self._result = (True, r) object.__setattr__(self, "_result", (True, r))
while self._observers: while self._observers:
try: try:
self._observers.pop().callback(r) self._observers.pop().callback(r)
@ -57,7 +60,7 @@ class ObservableDeferred(object):
return r return r
def errback(f): def errback(f):
self._result = (False, f) object.__setattr__(self, "_result", (False, f))
while self._observers: while self._observers:
try: try:
self._observers.pop().errback(f) self._observers.pop().errback(f)
@ -74,14 +77,28 @@ class ObservableDeferred(object):
def observe(self): def observe(self):
if not self._result: if not self._result:
d = defer.Deferred() d = defer.Deferred()
self._observers.append(d)
def remove(r):
self._observers.discard(d)
return r
d.addBoth(remove)
self._observers.add(d)
return d return d
else: else:
success, res = self._result success, res = self._result
return defer.succeed(res) if success else defer.fail(res) return defer.succeed(res) if success else defer.fail(res)
def observers(self):
return self._observers
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self._deferred, name) return getattr(self._deferred, name)
def __setattr__(self, name, value): def __setattr__(self, name, value):
setattr(self._deferred, name, value) setattr(self._deferred, name, value)
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)

View File

@ -140,6 +140,37 @@ class PreserveLoggingContext(object):
) )
class _PreservingContextDeferred(defer.Deferred):
"""A deferred that ensures that all callbacks and errbacks are called with
the given logging context.
"""
def __init__(self, context):
self._log_context = context
defer.Deferred.__init__(self)
def addCallbacks(self, callback, errback=None,
callbackArgs=None, callbackKeywords=None,
errbackArgs=None, errbackKeywords=None):
callback = self._wrap_callback(callback)
errback = self._wrap_callback(errback)
return defer.Deferred.addCallbacks(
self, callback,
errback=errback,
callbackArgs=callbackArgs,
callbackKeywords=callbackKeywords,
errbackArgs=errbackArgs,
errbackKeywords=errbackKeywords,
)
def _wrap_callback(self, f):
def g(res, *args, **kwargs):
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = self._log_context
res = f(res, *args, **kwargs)
return res
return g
def preserve_context_over_fn(fn, *args, **kwargs): def preserve_context_over_fn(fn, *args, **kwargs):
"""Takes a function and invokes it with the given arguments, but removes """Takes a function and invokes it with the given arguments, but removes
and restores the current logging context while doing so. and restores the current logging context while doing so.
@ -160,24 +191,7 @@ def preserve_context_over_deferred(deferred):
"""Given a deferred wrap it such that any callbacks added later to it will """Given a deferred wrap it such that any callbacks added later to it will
be invoked with the current context. be invoked with the current context.
""" """
d = defer.Deferred()
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
d = _PreservingContextDeferred(current_context)
def cb(res): deferred.chainDeferred(d)
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = current_context
res = d.callback(res)
return res
def eb(failure):
with PreserveLoggingContext():
LoggingContext.thread_local.current_context = current_context
res = d.errback(failure)
return res
if deferred.called:
return deferred
deferred.addCallbacks(cb, eb)
return d return d

View File

@ -33,3 +33,12 @@ def random_string_with_symbols(length):
return ''.join( return ''.join(
random.choice(_string_with_symbols) for _ in xrange(length) random.choice(_string_with_symbols) for _ in xrange(length)
) )
def is_ascii(s):
try:
s.encode("ascii")
except UnicodeDecodeError:
return False
else:
return True

View File

@ -57,6 +57,49 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service, event interested_service, event
) )
@defer.inlineCallbacks
def test_query_user_exists_unknown_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user = Mock(return_value=True)
self.mock_store.get_app_services = Mock(return_value=services)
self.mock_store.get_user_by_id = Mock(return_value=None)
event = Mock(
sender=user_id,
type="m.room.message",
room_id="!foo:bar"
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
yield self.handler.notify_interested_services(event)
self.mock_as_api.query_user.assert_called_once_with(
services[0], user_id
)
@defer.inlineCallbacks
def test_query_user_exists_known_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
services[0].is_interested_in_user = Mock(return_value=True)
self.mock_store.get_app_services = Mock(return_value=services)
self.mock_store.get_user_by_id = Mock(return_value={
"name": user_id
})
event = Mock(
sender=user_id,
type="m.room.message",
room_id="!foo:bar"
)
self.mock_as_api.push = Mock()
self.mock_as_api.query_user = Mock()
yield self.handler.notify_interested_services(event)
self.assertFalse(
self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been."
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_query_room_alias_exists(self): def test_query_room_alias_exists(self):
room_alias_str = "#foo:bar" room_alias_str = "#foo:bar"

View File

@ -100,7 +100,7 @@ class FederationTestCase(unittest.TestCase):
return defer.succeed({}) return defer.succeed({})
self.datastore.have_events.side_effect = have_events self.datastore.have_events.side_effect = have_events
def annotate(ev, old_state=None): def annotate(ev, old_state=None, outlier=False):
context = Mock() context = Mock()
context.current_state = {} context.current_state = {}
context.auth_events = {} context.auth_events = {}
@ -120,7 +120,7 @@ class FederationTestCase(unittest.TestCase):
) )
self.state_handler.compute_event_context.assert_called_once_with( self.state_handler.compute_event_context.assert_called_once_with(
ANY, old_state=None, ANY, old_state=None, outlier=False
) )
self.auth.check.assert_called_once_with(ANY, auth_events={}) self.auth.check.assert_called_once_with(ANY, auth_events={})

View File

@ -42,6 +42,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
"get_room", "get_room",
"store_room", "store_room",
"get_latest_events_in_room", "get_latest_events_in_room",
"add_event_hashes",
]), ]),
resource_for_federation=NonCallableMock(), resource_for_federation=NonCallableMock(),
http_client=NonCallableMock(spec_set=[]), http_client=NonCallableMock(spec_set=[]),
@ -88,6 +89,7 @@ class RoomMemberHandlerTestCase(unittest.TestCase):
self.ratelimiter.send_message.return_value = (True, 0) self.ratelimiter.send_message.return_value = (True, 0)
self.datastore.persist_event.return_value = (1,1) self.datastore.persist_event.return_value = (1,1)
self.datastore.add_event_hashes.return_value = []
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invite(self): def test_invite(self):

View File

@ -66,8 +66,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.mock_federation_resource = MockHttpResource() self.mock_federation_resource = MockHttpResource()
mock_notifier = Mock(spec=["on_new_user_event"]) mock_notifier = Mock(spec=["on_new_event"])
self.on_new_user_event = mock_notifier.on_new_user_event self.on_new_event = mock_notifier.on_new_event
self.auth = Mock(spec=[]) self.auth = Mock(spec=[])
@ -182,7 +182,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=20000, timeout=20000,
) )
self.on_new_user_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 1, rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
@ -245,7 +245,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
) )
) )
self.on_new_user_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 1, rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
@ -299,7 +299,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
room_id=self.room_id, room_id=self.room_id,
) )
self.on_new_user_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 1, rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
@ -331,10 +331,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=10000, timeout=10000,
) )
self.on_new_user_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 1, rooms=[self.room_id]), call('typing_key', 1, rooms=[self.room_id]),
]) ])
self.on_new_user_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
@ -351,7 +351,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.clock.advance_time(11) self.clock.advance_time(11)
self.on_new_user_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 2, rooms=[self.room_id]), call('typing_key', 2, rooms=[self.room_id]),
]) ])
@ -377,10 +377,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
timeout=10000, timeout=10000,
) )
self.on_new_user_event.assert_has_calls([ self.on_new_event.assert_has_calls([
call('typing_key', 3, rooms=[self.room_id]), call('typing_key', 3, rooms=[self.room_id]),
]) ])
self.on_new_user_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3) self.assertEquals(self.event_source.get_current_key(), 3)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)

View File

@ -183,7 +183,17 @@ class EventStreamPermissionsTestCase(RestTestCase):
) )
self.assertEquals(200, code, msg=str(response)) self.assertEquals(200, code, msg=str(response))
self.assertEquals(0, len(response["chunk"])) # We may get a presence event for ourselves down
self.assertEquals(
0,
len([
c for c in response["chunk"]
if not (
c.get("type") == "m.presence"
and c["content"].get("user_id") == self.user_id
)
])
)
# joined room (expect all content for room) # joined room (expect all content for room)
yield self.join(room=room_id, user=self.user_id, tok=self.token) yield self.join(room=room_id, user=self.user_id, tok=self.token)

View File

@ -357,7 +357,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
# all be ours # all be ours
# I'll already get my own presence state change # I'll already get my own presence state change
self.assertEquals({"start": "0_1_0", "end": "0_1_0", "chunk": []}, self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []},
response response
) )
@ -376,7 +376,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
"/events?from=s0_1_0&timeout=0", None) "/events?from=s0_1_0&timeout=0", None)
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals({"start": "s0_1_0", "end": "s0_2_0", "chunk": [ self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [
{"type": "m.presence", {"type": "m.presence",
"content": { "content": {
"user_id": "@banana:test", "user_id": "@banana:test",

View File

@ -0,0 +1,134 @@
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.api.errors import SynapseError
from twisted.internet import defer
from mock import Mock, MagicMock
from tests import unittest
import json
class RegisterRestServletTestCase(unittest.TestCase):
def setUp(self):
# do the dance to hook up request data to self.request_data
self.request_data = ""
self.request = Mock(
content=Mock(read=Mock(side_effect=lambda: self.request_data)),
)
self.request.args = {}
self.appservice = None
self.auth = Mock(get_appservice_by_req=Mock(
side_effect=lambda x: defer.succeed(self.appservice))
)
self.auth_result = (False, None, None)
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x,y,z: self.auth_result)
)
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
# do the dance to hook it up to the hs global
self.handlers = Mock(
auth_handler=self.auth_handler,
registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler
)
self.hs = Mock()
self.hs.hostname = "superbig~testing~thing.com"
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.config.disable_registration = False
# init the thing we're testing
self.servlet = RegisterRestServlet(self.hs)
@defer.inlineCallbacks
def test_POST_appservice_registration_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
self.request.args = {
"access_token": "i_am_an_app_service"
}
self.request_data = json.dumps({
"username": "kermit"
})
self.appservice = {
"id": "1234"
}
self.registration_handler.appservice_register = Mock(
return_value=(user_id, token)
)
result = yield self.servlet.on_POST(self.request)
self.assertEquals(result, (200, {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname
}))
@defer.inlineCallbacks
def test_POST_appservice_registration_invalid(self):
self.request.args = {
"access_token": "i_am_an_app_service"
}
self.request_data = json.dumps({
"username": "kermit"
})
self.appservice = None # no application service exists
result = yield self.servlet.on_POST(self.request)
self.assertEquals(result, (401, None))
def test_POST_bad_password(self):
self.request_data = json.dumps({
"username": "kermit",
"password": 666
})
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError)
def test_POST_bad_username(self):
self.request_data = json.dumps({
"username": 777,
"password": "monkey"
})
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError)
@defer.inlineCallbacks
def test_POST_user_valid(self):
user_id = "@kermit:muppet"
token = "kermits_access_token"
self.request_data = json.dumps({
"username": "kermit",
"password": "monkey"
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
})
self.registration_handler.register = Mock(return_value=(user_id, token))
result = yield self.servlet.on_POST(self.request)
self.assertEquals(result, (200, {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname
}))
def test_POST_disabled_registration(self):
self.hs.config.disable_registration = True
self.request_data = json.dumps({
"username": "kermit",
"password": "monkey"
})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (True, None, {
"username": "kermit",
"password": "monkey"
})
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
d = self.servlet.on_POST(self.request)
return self.assertFailure(d, SynapseError)

View File

@ -17,6 +17,8 @@
from tests import unittest from tests import unittest
from twisted.internet import defer from twisted.internet import defer
from synapse.util.async import ObservableDeferred
from synapse.storage._base import Cache, cached from synapse.storage._base import Cache, cached
@ -40,12 +42,12 @@ class CacheTestCase(unittest.TestCase):
self.assertEquals(self.cache.get("foo"), 123) self.assertEquals(self.cache.get("foo"), 123)
def test_invalidate(self): def test_invalidate(self):
self.cache.prefill("foo", 123) self.cache.prefill(("foo",), 123)
self.cache.invalidate("foo") self.cache.invalidate(("foo",))
failed = False failed = False
try: try:
self.cache.get("foo") self.cache.get(("foo",))
except KeyError: except KeyError:
failed = True failed = True
@ -96,87 +98,102 @@ class CacheDecoratorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_passthrough(self): def test_passthrough(self):
class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
return key return key
self.assertEquals((yield func(self, "foo")), "foo") a = A()
self.assertEquals((yield func(self, "bar")), "bar")
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals((yield a.func("bar")), "bar")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_hit(self): def test_hit(self):
callcount = [0] callcount = [0]
class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return key
yield func(self, "foo") a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
self.assertEquals((yield func(self, "foo")), "foo") self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invalidate(self): def test_invalidate(self):
callcount = [0] callcount = [0]
class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return key
yield func(self, "foo") a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1) self.assertEquals(callcount[0], 1)
func.invalidate("foo") a.func.invalidate(("foo",))
yield func(self, "foo") yield a.func("foo")
self.assertEquals(callcount[0], 2) self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self): def test_invalidate_missing(self):
class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
return key return key
func.invalidate("what") A().func.invalidate(("what",))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_max_entries(self): def test_max_entries(self):
callcount = [0] callcount = [0]
class A(object):
@cached(max_entries=10) @cached(max_entries=10)
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return key
for k in range(0,12): a = A()
yield func(self, k)
for k in range(0, 12):
yield a.func(k)
self.assertEquals(callcount[0], 12) self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate # There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times # all 12 values again, we must get called at least 2 more times
for k in range(0,12): for k in range(0,12):
yield func(self, k) yield a.func(k)
self.assertTrue(callcount[0] >= 14, self.assertTrue(callcount[0] >= 14,
msg="Expected callcount >= 14, got %d" % (callcount[0])) msg="Expected callcount >= 14, got %d" % (callcount[0]))
@defer.inlineCallbacks
def test_prefill(self): def test_prefill(self):
callcount = [0] callcount = [0]
d = defer.succeed(123)
class A(object):
@cached() @cached()
def func(self, key): def func(self, key):
callcount[0] += 1 callcount[0] += 1
return key return d
func.prefill("foo", 123) a = A()
self.assertEquals((yield func(self, "foo")), 123) a.func.prefill(("foo",), ObservableDeferred(d))
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0) self.assertEquals(callcount[0], 0)

View File

@ -46,7 +46,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
(yield self.store.get_user_by_id(self.user_id)) (yield self.store.get_user_by_id(self.user_id))
) )
result = yield self.store.get_user_by_token(self.tokens[1]) result = yield self.store.get_user_by_token(self.tokens[0])
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {

View File

@ -73,8 +73,8 @@ class DistributorTestCase(unittest.TestCase):
yield d yield d
self.assertTrue(d.called) self.assertTrue(d.called)
observers[0].assert_called_once("Go") observers[0].assert_called_once_with("Go")
observers[1].assert_called_once("Go") observers[1].assert_called_once_with("Go")
self.assertEquals(mock_logger.warning.call_count, 1) self.assertEquals(mock_logger.warning.call_count, 1)
self.assertIsInstance(mock_logger.warning.call_args[0][0], self.assertIsInstance(mock_logger.warning.call_args[0][0],

View File

@ -114,6 +114,8 @@ class MockHttpResource(HttpServer):
mock_request.method = http_method mock_request.method = http_method
mock_request.uri = path mock_request.uri = path
mock_request.getClientIP.return_value = "-"
mock_request.requestHeaders.getRawHeaders.return_value=[ mock_request.requestHeaders.getRawHeaders.return_value=[
"X-Matrix origin=test,key=,sig=" "X-Matrix origin=test,key=,sig="
] ]