Merge remote-tracking branch 'origin/develop' into dbkr/notifications_api

David Baker 2016-08-11 14:09:13 +01:00
commit b4ecf0b886
208 changed files with 9809 additions and 3566 deletions


@ -1,3 +1,297 @@
Changes in synapse v0.17.0 (2016-08-08)
=======================================
This release contains significant security bug fixes regarding authenticating
events received over federation. PLEASE UPGRADE.
This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.
Changes:
* Add federation /version API (PR #990)
* Make psutil dependency optional (PR #992)
Bug fixes:
* Fix URL preview API to exclude HTML comments in description (PR #988)
* Fix error handling of remote joins (PR #991)
Changes in synapse v0.17.0-rc4 (2016-08-05)
===========================================
Changes:
* Change the way we summarize URLs when previewing (PR #973)
* Add new ``/state_ids/`` federation API (PR #979)
* Speed up processing of ``/state/`` response (PR #986)
Bug fixes:
* Fix event persistence when event has already been partially persisted
(PR #975, #983, #985)
* Fix port script to also copy across backfilled events (PR #982)
Changes in synapse v0.17.0-rc3 (2016-08-02)
===========================================
Changes:
* Forbid non-ASes from registering users whose names begin with '_' (PR #958)
* Add some basic admin API docs (PR #963)
Bug fixes:
* Send the correct host header when fetching keys (PR #941)
* Fix joining a room that has missing auth events (PR #964)
* Fix various push bugs (PR #966, #970)
* Fix adding emails on registration (PR #968)
Changes in synapse v0.17.0-rc2 (2016-08-02)
===========================================
(This release did not include the changes advertised and was identical to RC1)
Changes in synapse v0.17.0-rc1 (2016-07-28)
===========================================
This release changes the LDAP configuration format in a backwards incompatible
way, see PR #843 for details.
Features:
* Add purge_media_cache admin API (PR #902)
* Add deactivate account admin API (PR #903)
* Add optional pepper to password hashing (PR #907, #910 by KentShikama)
* Add an admin option to shared secret registration (breaks backwards compat)
(PR #909)
* Add purge local room history API (PR #911, #923, #924)
* Add requestToken endpoints (PR #915)
* Add an /account/deactivate endpoint (PR #921)
* Add filter param to /messages. Add 'contains_url' to filter. (PR #922)
* Add device_id support to /login (PR #929)
* Add device_id support to /v2/register flow. (PR #937, #942)
* Add GET /devices endpoint (PR #939, #944)
* Add GET /device/{deviceId} (PR #943)
* Add update and delete APIs for devices (PR #949)
Changes:
* Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt)
* Linearize some federation endpoints based on (origin, room_id) (PR #879)
* Remove the legacy v0 content upload API. (PR #888)
* Use similar naming we use in email notifs for push (PR #894)
* Optionally include password hash in createUser endpoint (PR #905 by
KentShikama)
* Use a query that postgresql optimises better for get_events_around (PR #906)
* Fall back to 'username' if 'user' is not given for appservice registration.
(PR #927 by Half-Shot)
* Add metrics for psutil derived memory usage (PR #936)
* Record device_id in client_ips (PR #938)
* Send the correct host header when fetching keys (PR #941)
* Log the hostname the reCAPTCHA was completed on (PR #946)
* Make the device id on e2e key upload optional (PR #956)
* Add r0.2.0 to the "supported versions" list (PR #960)
* Don't include name of room for invites in push (PR #961)
Bug fixes:
* Fix substitution failure in mail template (PR #887)
* Put most recent 20 messages in email notif (PR #892)
* Ensure that the guest user is in the database when upgrading accounts
(PR #914)
* Fix various edge cases in auth handling (PR #919)
* Fix 500 ISE when sending alias event without a state_key (PR #925)
* Fix bug where we stored rejections in the state_group, persist all
rejections (PR #948)
* Fix lack of check of if the user is banned when handling 3pid invites
(PR #952)
* Fix a couple of bugs in the transaction and keyring code (PR #954, #955)
Changes in synapse v0.16.1-r1 (2016-07-08)
==========================================
THIS IS A CRITICAL SECURITY UPDATE.
This fixes a bug which allowed users' accounts to be accessed by unauthorised
users.
Changes in synapse v0.16.1 (2016-06-20)
=======================================
Bug fixes:
* Fix assorted bugs in ``/preview_url`` (PR #872)
* Fix TypeError when setting unicode passwords (PR #873)
Performance improvements:
* Turn ``use_frozen_events`` off by default (PR #877)
* Disable responding with canonical json for federation (PR #878)
Changes in synapse v0.16.1-rc1 (2016-06-15)
===========================================
Features: None
Changes:
* Log requester for ``/publicRoom`` endpoints when possible (PR #856)
* 502 on ``/thumbnail`` when can't connect to remote server (PR #862)
* Linearize fetching of gaps on incoming events (PR #871)
Bug fixes:
* Fix bug where rooms were marked as published by default (PR #857)
* Fix bug when joining a room that has an event with an invalid sender (PR #868)
* Fix bug where backfilled events were sent down sync streams (PR #869)
* Fix bug where outgoing connections could wedge indefinitely, causing push
notifications to be unreliable (PR #870)
Performance improvements:
* Improve ``/publicRooms`` performance (PR #859)
Changes in synapse v0.16.0 (2016-06-09)
=======================================
NB: As of v0.14 all AS config files must have an ID field.
Bug fixes:
* Don't make rooms published by default (PR #857)
Changes in synapse v0.16.0-rc2 (2016-06-08)
===========================================
Features:
* Add configuration option for tuning GC via ``gc.set_threshold`` (PR #849)
Changes:
* Record metrics about GC (PR #771, #847, #852)
* Add metric counter for number of persisted events (PR #841)
Bug fixes:
* Fix 'From' header in email notifications (PR #843)
* Fix presence where timeouts were not being fired for the first 8h after
restarts (PR #842)
* Fix bug where synapse sent malformed transactions to AS's when retrying
transactions (Commits 310197b, 8437906)
Performance improvements:
* Remove event fetching from DB threads (PR #835)
* Change the way we cache events (PR #836)
* Add events to cache when we persist them (PR #840)
Changes in synapse v0.16.0-rc1 (2016-06-03)
===========================================
Version 0.15 was not released. See v0.15.0-rc1 below for additional changes.
Features:
* Add email notifications for missed messages (PR #759, #786, #799, #810, #815,
#821)
* Add a ``url_preview_ip_range_whitelist`` config param (PR #760)
* Add /report endpoint (PR #762)
* Add basic ignore user API (PR #763)
* Add an openidish mechanism for proving that you own a given user_id (PR #765)
* Allow clients to specify a server_name to avoid 'No known servers' (PR #794)
* Add secondary_directory_servers option to fetch room list from other servers
(PR #808, #813)
Changes:
* Report per request metrics for all of the things using request_handler (PR
#756)
* Correctly handle ``NULL`` password hashes from the database (PR #775)
* Allow receipts for events we haven't seen in the db (PR #784)
* Make synctl read a cache factor from config file (PR #785)
* Increment badge count per missed convo, not per msg (PR #793)
* Special case m.room.third_party_invite event auth to match invites (PR #814)
Bug fixes:
* Fix typo in event_auth servlet path (PR #757)
* Fix password reset (PR #758)
Performance improvements:
* Reduce database inserts when sending transactions (PR #767)
* Queue events by room for persistence (PR #768)
* Add cache to ``get_user_by_id`` (PR #772)
* Add and use ``get_domain_from_id`` (PR #773)
* Use tree cache for ``get_linearized_receipts_for_room`` (PR #779)
* Remove unused indices (PR #782)
* Add caches to ``bulk_get_push_rules*`` (PR #804)
* Cache ``get_event_reference_hashes`` (PR #806)
* Add ``get_users_with_read_receipts_in_room`` cache (PR #809)
* Use state to calculate ``get_users_in_room`` (PR #811)
* Load push rules in storage layer so that they get cached (PR #825)
* Make ``get_joined_hosts_for_room`` use get_users_in_room (PR #828)
* Poke notifier on next reactor tick (PR #829)
* Change CacheMetrics to be quicker (PR #830)
Changes in synapse v0.15.0-rc1 (2016-04-26)
===========================================
Features:
* Add login support for JSON Web Tokens, thanks to Niklas Riekenbrauck
(PR #671, #687)
* Add URL previewing support (PR #688)
* Add login support for LDAP, thanks to Christoph Witzany (PR #701)
* Add GET endpoint for pushers (PR #716)
Changes:
* Never notify for member events (PR #667)
* Deduplicate identical ``/sync`` requests (PR #668)
* Require user to have left room to forget room (PR #673)
* Use DNS cache if within TTL (PR #677)
* Let users see their own leave events (PR #699)
* Deduplicate membership changes (PR #700)
* Increase performance of pusher code (PR #705)
* Respond with error status 504 if failed to talk to remote server (PR #731)
* Increase search performance on postgres (PR #745)
Bug fixes:
* Fix bug where disabling all notifications still resulted in push (PR #678)
* Fix bug where users couldn't reject remote invites if remote refused (PR #691)
* Fix bug where synapse attempted to backfill from itself (PR #693)
* Fix bug where profile information was not correctly added when joining remote
rooms (PR #703)
* Fix bug where register API required incorrect key name for AS registration
(PR #727)
Changes in synapse v0.14.0 (2016-03-30)
=======================================
@ -511,7 +805,7 @@ Configuration:
* Add support for changing the bind host of the metrics listener via the
``metrics_bind_host`` option.
Changes in synapse v0.9.0-r5 (2015-05-21)
=========================================
@ -853,7 +1147,7 @@ See UPGRADE for information about changes to the client server API, including
breaking backwards compatibility with VoIP calls and registration API.
Homeserver:
* When a user changes their displayname or avatar the server will now update
all their join states to reflect this.
* The server now adds "age" key to events to indicate how old they are. This
is clock independent, so at no point does any server or webclient have to
@ -911,7 +1205,7 @@ Changes in synapse 0.2.2 (2014-09-06)
=====================================
Homeserver:
* When the server returns state events it now also includes the previous
content.
* Add support for inviting people when creating a new room.
* Make the homeserver inform the room via `m.room.aliases` when a new alias
@ -923,7 +1217,7 @@ Webclient:
* Handle `m.room.aliases` events.
* Asynchronously send messages and show a local echo.
* Inform the UI when a message failed to send.
* Only autoscroll on receiving a new message if the user was already at the
bottom of the screen.
* Add support for ban/kick reasons.


@ -14,6 +14,7 @@ recursive-include docs *
recursive-include res *
recursive-include scripts *
recursive-include scripts-dev *
+recursive-include synapse *.pyi
recursive-include tests *.py
recursive-include synapse/static *.css
@ -23,5 +24,7 @@ recursive-include synapse/static *.js
exclude jenkins.sh
exclude jenkins*.sh
+exclude jenkins*
+recursive-exclude jenkins *.sh
prune demo/etc


@ -11,8 +11,8 @@ VoIP. The basics you need to know to get up and running are:
like ``#matrix:matrix.org`` or ``#test:localhost:8448``.
- Matrix user IDs look like ``@matthew:matrix.org`` (although in the future
-you will normally refer to yourself and others using a 3PID: email
-address, phone number, etc rather than manipulating Matrix user IDs)
+you will normally refer to yourself and others using a third party identifier
+(3PID): email address, phone number, etc rather than manipulating Matrix user IDs)
The overall architecture is::
@ -58,12 +58,13 @@ the spec in the context of a codebase and let you run your own homeserver and
generally help bootstrap the ecosystem.
In Matrix, every user runs one or more Matrix clients, which connect through to
-a Matrix homeserver which stores all their personal chat history and user
-account information - much as a mail client connects through to an IMAP/SMTP
-server. Just like email, you can either run your own Matrix homeserver and
-control and own your own communications and history or use one hosted by
-someone else (e.g. matrix.org) - there is no single point of control or
-mandatory service provider in Matrix, unlike WhatsApp, Facebook, Hangouts, etc.
+a Matrix homeserver. The homeserver stores all their personal chat history and
+user account information - much as a mail client connects through to an
+IMAP/SMTP server. Just like email, you can either run your own Matrix
+homeserver and control and own your own communications and history or use one
+hosted by someone else (e.g. matrix.org) - there is no single point of control
+or mandatory service provider in Matrix, unlike WhatsApp, Facebook, Hangouts,
+etc.
Synapse ships with two basic demo Matrix clients: webclient (a basic group chat
web client demo implemented in AngularJS) and cmdclient (a basic Python
@ -444,7 +445,7 @@ You have two choices here, which will influence the form of your Matrix user
IDs:
1) Use the machine's own hostname as available on public DNS in the form of
-its A or AAAA records. This is easier to set up initially, perhaps for
+its A records. This is easier to set up initially, perhaps for
testing, but lacks the flexibility of SRV.
2) Set up a SRV record for your domain name. This requires you create a SRV
@ -617,7 +618,7 @@ Building internal API documentation::
-Halp!! Synapse eats all my RAM!
+Help!! Synapse eats all my RAM!
===============================
Synapse's architecture is quite RAM hungry currently - we deliberately


@ -27,7 +27,7 @@ running:
# Pull the latest version of the master branch.
git pull
# Update the versions of synapse's python dependencies.
-python synapse/python_dependencies.py | xargs -n1 pip install
+python synapse/python_dependencies.py | xargs -n1 pip install --upgrade
Upgrading to v0.15.0


@ -9,6 +9,7 @@ Description=Synapse Matrix homeserver
Type=simple
User=synapse
Group=synapse
+EnvironmentFile=-/etc/sysconfig/synapse
WorkingDirectory=/var/lib/synapse
ExecStart=/usr/bin/python2.7 -m synapse.app.homeserver --config-path=/etc/synapse/homeserver.yaml --log-config=/etc/synapse/log_config.yaml

docs/admin_api/README.rst (new file)

@ -0,0 +1,12 @@
Admin APIs
==========
This directory includes documentation for the various synapse specific admin
APIs available.
Only users that are server admins can use these APIs. A user can be marked as a
server admin by updating the database directly, e.g.:
``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'``
Restarting may be required for the changes to register.


@ -0,0 +1,15 @@
Purge History API
=================
The purge history API allows server admins to purge historic events from their
database, reclaiming disk space.
Depending on the amount of history being purged a call to the API may take
several minutes or longer. During this period users will not be able to
paginate further back in the room from the point being purged from.
The API is simply:
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
including an ``access_token`` of a server admin.
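For illustration, a minimal sketch of driving this endpoint with the Python ``requests`` library; the homeserver URL, admin access token, room ID and event ID below are placeholders rather than values from this release::
import requests

# Placeholders: substitute your own homeserver URL, an admin's
# access_token, and the room/event that marks the purge point.
access_token = "ADMIN_ACCESS_TOKEN"
room_id = "!someroom:example.com"
event_id = "$someevent:example.com"

resp = requests.post(
    "http://localhost:8008/_matrix/client/r0/admin/purge_history/%s/%s"
    % (room_id, event_id),
    params={"access_token": access_token},
)
print(resp.status_code)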


@ -0,0 +1,19 @@
Purge Remote Media API
======================
The purge remote media API allows server admins to purge old cached remote
media.
The API is::
POST /_matrix/client/r0/admin/purge_media_cache
{
"before_ts": <unix_timestamp_in_ms>
}
Which will remove all cached media that was last accessed before
``<unix_timestamp_in_ms>``.
If the user re-requests purged remote media, synapse will re-request the media
from the originating server.
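A similar hedged sketch for this endpoint, purging remote media that has not been accessed in the last 30 days (the homeserver URL and token are again placeholders)::
import time
import requests

access_token = "ADMIN_ACCESS_TOKEN"  # placeholder

# Remove remote media last accessed more than 30 days ago.
before_ts = int(time.time() * 1000) - 30 * 24 * 60 * 60 * 1000

resp = requests.post(
    "http://localhost:8008/_matrix/client/r0/admin/purge_media_cache",
    params={"access_token": access_token},
    json={"before_ts": before_ts},
)
print(resp.json())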


@ -32,5 +32,4 @@ The format of the AS configuration file is as follows:
See the spec_ for further details on how application services work.
-.. _spec: https://github.com/matrix-org/matrix-doc/blob/master/specification/25_application_service_api.rst#application-service-api
+.. _spec: https://matrix.org/docs/spec/application_service/unstable.html


@ -43,7 +43,10 @@ Basically, PEP8
together, or want to deliberately extend or preserve vertical/horizontal
space)
-Comments should follow the google code style. This is so that we can generate
-documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/)
+Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_.
+This is so that we can generate documentation with
+`sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the
+`examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
+in the sphinx documentation.
Code should pass pep8 --max-line-length=100 without any warnings.


@ -9,31 +9,35 @@ the Home Server to generate credentials that are valid for use on the TURN
server through the use of a secret shared between the Home Server and the
TURN server.
-This document described how to install coturn
-(https://code.google.com/p/coturn/) which also supports the TURN REST API,
+This document describes how to install coturn
+(https://github.com/coturn/coturn) which also supports the TURN REST API,
and integrate it with synapse.
coturn Setup
============
+You may be able to setup coturn via your package manager, or set it up manually using the usual ``configure, make, make install`` process.
1. Check out coturn::
-svn checkout http://coturn.googlecode.com/svn/trunk/ coturn
+git clone https://github.com/coturn/coturn.git coturn
cd coturn
2. Configure it::
./configure
-You may need to install libevent2: if so, you should do so
+You may need to install ``libevent2``: if so, you should do so
in the way recommended by your operating system.
You can ignore warnings about lack of database support: a
database is unnecessary for this purpose.
3. Build and install it::
make
make install
-4. Make a config file in /etc/turnserver.conf. You can customise
-a config file from turnserver.conf.default. The relevant
+4. Create or edit the config file in ``/etc/turnserver.conf``. The relevant
lines, with example values, are::
lt-cred-mech
@ -41,7 +45,7 @@ coturn Setup
static-auth-secret=[your secret key here]
realm=turn.myserver.org
-See turnserver.conf.default for explanations of the options.
+See turnserver.conf for explanations of the options.
One way to generate the static-auth-secret is with pwgen::
pwgen -s 64 1
@ -54,6 +58,7 @@ coturn Setup
import your private key and certificate.
7. Start the turn server::
bin/turnserver -o

jenkins-dendron-postgres.sh (new executable file)

@ -0,0 +1,22 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
./jenkins/prepare_synapse.sh
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
./dendron/jenkins/build_dendron.sh
./sytest/jenkins/prep_sytest_for_postgres.sh
./sytest/jenkins/install_and_run.sh \
--synapse-directory $WORKSPACE \
--dendron $WORKSPACE/dendron/bin/dendron \
--pusher \
--synchrotron \
--federation-reader \


@ -4,60 +4,14 @@ set -eux
: ${WORKSPACE:="$(pwd)"} : ${WORKSPACE:="$(pwd)"}
export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1 export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml ./jenkins/prepare_synapse.sh
export TRIAL_FLAGS="--reporter=subunit" ./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log ./sytest/jenkins/prep_sytest_for_postgres.sh
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove" ./sytest/jenkins/install_and_run.sh \
--synapse-directory $WORKSPACE \
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install psycopg2
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8000}
./jenkins/prep_sytest_for_postgres.sh
echo >&2 "Running sytest with PostgreSQL";
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--port-base $PORT_BASE
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml


@ -4,54 +4,12 @@ set -eux
: ${WORKSPACE:="$(pwd)"} : ${WORKSPACE:="$(pwd)"}
export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1 export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml ./jenkins/prepare_synapse.sh
export TRIAL_FLAGS="--reporter=subunit" ./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log ./sytest/jenkins/install_and_run.sh \
# Don't exit with non-0 status code on Jenkins, --synapse-directory $WORKSPACE \
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PORT_BASE:=8500}
./jenkins/install_and_run.sh --coverage \
--python $TOX_BIN/python \
--synapse-directory $WORKSPACE \
--port-base $PORT_BASE
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml


@ -22,4 +22,8 @@ export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished w
rm .coverage* || echo "No coverage files to remove"
+tox --notest -e py27
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
tox -e py27

jenkins/clone.sh (new executable file)

@ -0,0 +1,44 @@
#! /bin/bash
# This clones a project from github into a named subdirectory
# If the project has a branch with the same name as this branch
# then it will checkout that branch after cloning.
# Otherwise it will checkout "origin/develop."
# The first argument is the name of the directory to checkout
# the branch into.
# The second argument is the URL of the remote repository to checkout.
# Usually something like https://github.com/matrix-org/sytest.git
set -eux
NAME=$1
PROJECT=$2
BASE=".$NAME-base"
# Update our mirror.
if [ ! -d ".$NAME-base" ]; then
# Create a local mirror of the source repository.
# This saves us from having to download the entire repository
# when this script is next run.
git clone "$PROJECT" "$BASE" --mirror
else
# Fetch any updates from the source repository.
(cd "$BASE"; git fetch -p)
fi
# Remove the existing repository so that we have a clean copy
rm -rf "$NAME"
# Cloning with --shared means that we will share portions of the
# .git directory with our local mirror.
git clone "$BASE" "$NAME" --shared
# Jenkins may have supplied us with the name of the branch in the
# environment. Otherwise we will have to guess based on the current
# commit.
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
cd "$NAME"
# check out the relevant branch
git checkout "${GIT_BRANCH}" || (
echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop"
git checkout "origin/develop"
)

jenkins/prepare_synapse.sh (new executable file)

@ -0,0 +1,19 @@
#! /bin/bash
cd "`dirname $0`/.."
TOX_DIR=$WORKSPACE/.tox
mkdir -p $TOX_DIR
if ! [ $TOX_DIR -ef .tox ]; then
ln -s "$TOX_DIR" .tox
fi
# set up the virtualenv
tox -e py27 --notest -v
TOX_BIN=$TOX_DIR/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
$TOX_BIN/pip install psycopg2


@ -145,6 +145,11 @@ pre, code {
text-decoration: none;
}
+.debug {
+font-size: 10px;
+color: #888;
+}
.footer {
margin-top: 20px;
text-align: center;


@ -17,11 +17,15 @@
</td>
<td class="message_contents">
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
-<div class="sender_name">{{ message.sender_name }}</div>
+<div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
{% endif %}
<div class="message_body">
{% if message.msgtype == "m.text" %}
{{ message.body_text_html }}
+{% elif message.msgtype == "m.emote" %}
+{{ message.body_text_html }}
+{% elif message.msgtype == "m.notice" %}
+{{ message.body_text_html }}
{% elif message.msgtype == "m.image" %}
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
{% elif message.msgtype == "m.file" %}


@ -1,7 +1,11 @@
{% for message in notif.messages %}
-{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
+{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
{% if message.msgtype == "m.text" %}
{{ message.body_text_plain }}
+{% elif message.msgtype == "m.emote" %}
+{{ message.body_text_plain }}
+{% elif message.msgtype == "m.notice" %}
+{{ message.body_text_plain }}
{% elif message.msgtype == "m.image" %}
{{ message.body_text_plain }}
{% elif message.msgtype == "m.file" %}


@ -30,18 +30,20 @@
{% include 'room.html' with context %}
{% endfor %}
<div class="footer">
-<small>
-Sending email at {{ reason.now|format_ts("%c") }} due to activity in room '{{ reason.room_name }}' because:<br/>
-1. An event was received at {{ reason.received_at|format_ts("%c") }}
-which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago.<br/>
+<a href="{{ unsubscribe_link }}">Unsubscribe</a>
+<br/>
+<br/>
+<div class="debug">
+Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
+an event was received at {{ reason.received_at|format_ts("%c") }}
+which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
{% if reason.last_sent_ts %}
-2. The last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
+and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
{% else %}
-2. We can't remember the last time we sent a mail for this room.
+and we don't have a last time we sent a mail for this room.
{% endif %}
-</small>
-<a href="{{ unsubscribe_link }}">Unsubscribe</a>
+</div>
</div>
</td>
<td> </td>


@ -116,17 +116,19 @@ def get_json(origin_name, origin_key, destination, path):
authorization_headers = []
for key, sig in signed_json["signatures"][origin_name].items():
-authorization_headers.append(bytes(
-"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
-origin_name, key, sig,
-)
-))
+header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
+origin_name, key, sig,
+)
+authorization_headers.append(bytes(header))
+sys.stderr.write(header)
+sys.stderr.write("\n")
result = requests.get(
lookup(destination, path),
headers={"Authorization": authorization_headers[0]},
verify=False,
)
+sys.stderr.write("Status Code: %d\n" % (result.status_code,))
return result.json()
@ -141,6 +143,7 @@ def main():
)
json.dump(result, sys.stdout)
+print ""
if __name__ == "__main__":
main()


@ -1,10 +1,16 @@
#!/usr/bin/env python
import argparse
+import sys
import bcrypt
import getpass
+import yaml
bcrypt_rounds=12
+password_pepper = ""
def prompt_for_pass():
password = getpass.getpass("Password: ")
@ -28,12 +34,22 @@ if __name__ == "__main__":
default=None,
help="New password for user. Will prompt if omitted.",
)
+parser.add_argument(
+"-c", "--config",
+type=argparse.FileType('r'),
+help="Path to server config file. Used to read in bcrypt_rounds and password_pepper.",
+)
args = parser.parse_args()
+if "config" in args and args.config:
+config = yaml.safe_load(args.config)
+bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds)
+password_config = config.get("password_config", {})
+password_pepper = password_config.get("pepper", password_pepper)
password = args.password
if not password:
password = prompt_for_pass()
-print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))
+print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds))


@ -25,18 +25,26 @@ import urllib2
import yaml
-def request_registration(user, password, server_location, shared_secret):
+def request_registration(user, password, server_location, shared_secret, admin=False):
mac = hmac.new(
key=shared_secret,
-msg=user,
digestmod=hashlib.sha1,
-).hexdigest()
+)
+mac.update(user)
+mac.update("\x00")
+mac.update(password)
+mac.update("\x00")
+mac.update("admin" if admin else "notadmin")
+mac = mac.hexdigest()
data = {
"user": user,
"password": password,
"mac": mac,
"type": "org.matrix.login.shared_secret",
+"admin": admin,
}
server_location = server_location.rstrip("/")
@ -68,7 +76,7 @@ def request_registration(user, password, server_location, shared_secret):
sys.exit(1)
-def register_new_user(user, password, server_location, shared_secret):
+def register_new_user(user, password, server_location, shared_secret, admin):
if not user:
try:
default_user = getpass.getuser()
@ -99,7 +107,14 @@ def register_new_user(user, password, server_location, shared_secret):
print "Passwords do not match"
sys.exit(1)
-request_registration(user, password, server_location, shared_secret)
+if not admin:
+admin = raw_input("Make admin [no]: ")
+if admin in ("y", "yes", "true"):
+admin = True
+else:
+admin = False
+request_registration(user, password, server_location, shared_secret, bool(admin))
if __name__ == "__main__":
@ -119,6 +134,11 @@ if __name__ == "__main__":
default=None,
help="New password for user. Will prompt if omitted.",
)
+parser.add_argument(
+"-a", "--admin",
+action="store_true",
+help="Register new user as an admin. Will prompt if omitted.",
+)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
@ -151,4 +171,4 @@ if __name__ == "__main__":
else:
secret = args.shared_secret
-register_new_user(args.user, args.password, args.server_url, secret)
+register_new_user(args.user, args.password, args.server_url, secret, args.admin)
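To make the new MAC construction explicit, here is a standalone sketch of what the updated script computes: an HMAC-SHA1 keyed with the registration shared secret over the NUL-separated user, password and admin flag. The secret and credentials are placeholders, and the snippet assumes Python 2 strings as in the script above (on Python 3 they would need to be bytes)::
import hashlib
import hmac

shared_secret = "REGISTRATION_SHARED_SECRET"  # placeholder
user, password, admin = "alice", "wonderland", True

mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1)
mac.update(user)
mac.update("\x00")
mac.update(password)
mac.update("\x00")
mac.update("admin" if admin else "notadmin")

# This hex digest is the value sent as the "mac" field in the
# registration request body.
print(mac.hexdigest())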


@ -34,7 +34,7 @@ logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = { BOOLEAN_COLUMNS = {
"events": ["processed", "outlier"], "events": ["processed", "outlier", "contains_url"],
"rooms": ["is_public"], "rooms": ["is_public"],
"event_edges": ["is_state"], "event_edges": ["is_state"],
"presence_list": ["accepted"], "presence_list": ["accepted"],
@ -92,8 +92,12 @@ class Store(object):
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"] _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"] _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
_simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
_simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
_simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"] _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
_simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"] _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
"_simple_select_one_onecol_txn"
]
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"] _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"] _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
@ -158,31 +162,40 @@ class Porter(object):
def setup_table(self, table): def setup_table(self, table):
if table in APPEND_ONLY_TABLES: if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting. # It's safe to just carry on inserting.
next_chunk = yield self.postgres_store._simple_select_one_onecol( row = yield self.postgres_store._simple_select_one(
table="port_from_sqlite3", table="port_from_sqlite3",
keyvalues={"table_name": table}, keyvalues={"table_name": table},
retcol="rowid", retcols=("forward_rowid", "backward_rowid"),
allow_none=True, allow_none=True,
) )
total_to_port = None total_to_port = None
if next_chunk is None: if row is None:
if table == "sent_transactions": if table == "sent_transactions":
next_chunk, already_ported, total_to_port = ( forward_chunk, already_ported, total_to_port = (
yield self._setup_sent_transactions() yield self._setup_sent_transactions()
) )
backward_chunk = 0
else: else:
yield self.postgres_store._simple_insert( yield self.postgres_store._simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={"table_name": table, "rowid": 1} values={
"table_name": table,
"forward_rowid": 1,
"backward_rowid": 0,
}
) )
next_chunk = 1 forward_chunk = 1
backward_chunk = 0
already_ported = 0 already_ported = 0
else:
forward_chunk = row["forward_rowid"]
backward_chunk = row["backward_rowid"]
if total_to_port is None: if total_to_port is None:
already_ported, total_to_port = yield self._get_total_count_to_port( already_ported, total_to_port = yield self._get_total_count_to_port(
table, next_chunk table, forward_chunk, backward_chunk
) )
else: else:
def delete_all(txn): def delete_all(txn):
@ -196,46 +209,85 @@ class Porter(object):
yield self.postgres_store._simple_insert( yield self.postgres_store._simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={"table_name": table, "rowid": 0} values={
"table_name": table,
"forward_rowid": 1,
"backward_rowid": 0,
}
) )
next_chunk = 1 forward_chunk = 1
backward_chunk = 0
already_ported, total_to_port = yield self._get_total_count_to_port( already_ported, total_to_port = yield self._get_total_count_to_port(
table, next_chunk table, forward_chunk, backward_chunk
) )
defer.returnValue((table, already_ported, total_to_port, next_chunk)) defer.returnValue(
(table, already_ported, total_to_port, forward_chunk, backward_chunk)
)
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_table(self, table, postgres_size, table_size, next_chunk): def handle_table(self, table, postgres_size, table_size, forward_chunk,
backward_chunk):
if not table_size: if not table_size:
return return
self.progress.add_table(table, postgres_size, table_size) self.progress.add_table(table, postgres_size, table_size)
if table == "event_search": if table == "event_search":
yield self.handle_search_table(postgres_size, table_size, next_chunk) yield self.handle_search_table(
postgres_size, table_size, forward_chunk, backward_chunk
)
return return
select = ( forward_select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
% (table,) % (table,)
) )
backward_select = (
"SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
% (table,)
)
do_forward = [True]
do_backward = [True]
while True: while True:
def r(txn): def r(txn):
txn.execute(select, (next_chunk, self.batch_size,)) forward_rows = []
rows = txn.fetchall() backward_rows = []
headers = [column[0] for column in txn.description] if do_forward[0]:
txn.execute(forward_select, (forward_chunk, self.batch_size,))
forward_rows = txn.fetchall()
if not forward_rows:
do_forward[0] = False
return headers, rows if do_backward[0]:
txn.execute(backward_select, (backward_chunk, self.batch_size,))
backward_rows = txn.fetchall()
if not backward_rows:
do_backward[0] = False
headers, rows = yield self.sqlite_store.runInteraction("select", r) if forward_rows or backward_rows:
headers = [column[0] for column in txn.description]
else:
headers = None
if rows: return headers, forward_rows, backward_rows
next_chunk = rows[-1][0] + 1
headers, frows, brows = yield self.sqlite_store.runInteraction(
"select", r
)
if frows or brows:
if frows:
forward_chunk = max(row[0] for row in frows) + 1
if brows:
backward_chunk = min(row[0] for row in brows) - 1
rows = frows + brows
self._convert_rows(table, headers, rows) self._convert_rows(table, headers, rows)
def insert(txn): def insert(txn):
@ -247,7 +299,10 @@ class Porter(object):
txn, txn,
table="port_from_sqlite3", table="port_from_sqlite3",
keyvalues={"table_name": table}, keyvalues={"table_name": table},
updatevalues={"rowid": next_chunk}, updatevalues={
"forward_rowid": forward_chunk,
"backward_rowid": backward_chunk,
},
) )
yield self.postgres_store.execute(insert) yield self.postgres_store.execute(insert)
@ -259,7 +314,8 @@ class Porter(object):
return return
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_search_table(self, postgres_size, table_size, next_chunk): def handle_search_table(self, postgres_size, table_size, forward_chunk,
backward_chunk):
select = ( select = (
"SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
" FROM event_search as es" " FROM event_search as es"
@ -270,7 +326,7 @@ class Porter(object):
while True: while True:
def r(txn): def r(txn):
txn.execute(select, (next_chunk, self.batch_size,)) txn.execute(select, (forward_chunk, self.batch_size,))
rows = txn.fetchall() rows = txn.fetchall()
headers = [column[0] for column in txn.description] headers = [column[0] for column in txn.description]
@ -279,7 +335,7 @@ class Porter(object):
headers, rows = yield self.sqlite_store.runInteraction("select", r) headers, rows = yield self.sqlite_store.runInteraction("select", r)
if rows: if rows:
next_chunk = rows[-1][0] + 1 forward_chunk = rows[-1][0] + 1
# We have to treat event_search differently since it has a # We have to treat event_search differently since it has a
# different structure in the two different databases. # different structure in the two different databases.
@ -312,7 +368,10 @@ class Porter(object):
txn, txn,
table="port_from_sqlite3", table="port_from_sqlite3",
keyvalues={"table_name": "event_search"}, keyvalues={"table_name": "event_search"},
updatevalues={"rowid": next_chunk}, updatevalues={
"forward_rowid": forward_chunk,
"backward_rowid": backward_chunk,
},
) )
yield self.postgres_store.execute(insert) yield self.postgres_store.execute(insert)
@ -324,7 +383,6 @@ class Porter(object):
else: else:
return return
def setup_db(self, db_config, database_engine): def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect( db_conn = database_engine.module.connect(
**{ **{
@ -395,10 +453,32 @@ class Porter(object):
txn.execute( txn.execute(
"CREATE TABLE port_from_sqlite3 (" "CREATE TABLE port_from_sqlite3 ("
" table_name varchar(100) NOT NULL UNIQUE," " table_name varchar(100) NOT NULL UNIQUE,"
" rowid bigint NOT NULL" " forward_rowid bigint NOT NULL,"
" backward_rowid bigint NOT NULL"
")" ")"
) )
# The old port script created a table with just a "rowid" column.
# We want people to be able to rerun this script from an old port
# so that they can pick up any missing events that were not
# ported across.
def alter_table(txn):
txn.execute(
"ALTER TABLE IF EXISTS port_from_sqlite3"
" RENAME rowid TO forward_rowid"
)
txn.execute(
"ALTER TABLE IF EXISTS port_from_sqlite3"
" ADD backward_rowid bigint NOT NULL DEFAULT 0"
)
try:
yield self.postgres_store.runInteraction(
"alter_table", alter_table
)
except Exception as e:
logger.info("Failed to create port table: %s", e)
try: try:
yield self.postgres_store.runInteraction( yield self.postgres_store.runInteraction(
"create_port_table", create_port_table "create_port_table", create_port_table
@ -458,7 +538,7 @@ class Porter(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _setup_sent_transactions(self): def _setup_sent_transactions(self):
# Only save things from the last day # Only save things from the last day
yesterday = int(time.time()*1000) - 86400000 yesterday = int(time.time() * 1000) - 86400000
# And save the max transaction id from each destination # And save the max transaction id from each destination
select = ( select = (
@ -514,7 +594,11 @@ class Porter(object):
yield self.postgres_store._simple_insert( yield self.postgres_store._simple_insert(
table="port_from_sqlite3", table="port_from_sqlite3",
values={"table_name": "sent_transactions", "rowid": next_chunk} values={
"table_name": "sent_transactions",
"forward_rowid": next_chunk,
"backward_rowid": 0,
}
) )
def get_sent_table_size(txn): def get_sent_table_size(txn):
@ -535,13 +619,18 @@ class Porter(object):
defer.returnValue((next_chunk, inserted_rows, total_count)) defer.returnValue((next_chunk, inserted_rows, total_count))
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_remaining_count_to_port(self, table, next_chunk): def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
rows = yield self.sqlite_store.execute_sql( frows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
next_chunk, forward_chunk,
) )
defer.returnValue(rows[0][0]) brows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
backward_chunk,
)
defer.returnValue(frows[0][0] + brows[0][0])
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_already_ported_count(self, table): def _get_already_ported_count(self, table):
@ -552,10 +641,10 @@ class Porter(object):
defer.returnValue(rows[0][0]) defer.returnValue(rows[0][0])
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_total_count_to_port(self, table, next_chunk): def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
remaining, done = yield defer.gatherResults( remaining, done = yield defer.gatherResults(
[ [
self._get_remaining_count_to_port(table, next_chunk), self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
self._get_already_ported_count(table), self._get_already_ported_count(table),
], ],
consumeErrors=True, consumeErrors=True,
@ -686,7 +775,7 @@ class CursesProgress(Progress):
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1) color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
self.stdscr.addstr( self.stdscr.addstr(
i+2, left_margin + max_len - len(table), i + 2, left_margin + max_len - len(table),
table, table,
curses.A_BOLD | color, curses.A_BOLD | color,
) )
@ -694,18 +783,18 @@ class CursesProgress(Progress):
size = 20 size = 20
progress = "[%s%s]" % ( progress = "[%s%s]" % (
"#" * int(perc*size/100), "#" * int(perc * size / 100),
" " * (size - int(perc*size/100)), " " * (size - int(perc * size / 100)),
) )
self.stdscr.addstr( self.stdscr.addstr(
i+2, left_margin + max_len + middle_space, i + 2, left_margin + max_len + middle_space,
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]), "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
) )
if self.finished: if self.finished:
self.stdscr.addstr( self.stdscr.addstr(
rows-1, 0, rows - 1, 0,
"Press any key to exit...", "Press any key to exit...",
) )


@ -16,7 +16,5 @@ ignore =
[flake8]
max-line-length = 90
-ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
+# W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
+ignore = W503
-[pep8]
-max-line-length = 90


@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.14.0"
+__version__ = "0.17.0"


@ -13,23 +13,22 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""This module contains classes for authenticating the user.""" import logging
import pymacaroons
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, UserID, get_domain_from_id
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure
from unpaddedbase64 import decode_base64 from unpaddedbase64 import decode_base64
import logging import synapse.types
import pymacaroons from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -42,13 +41,20 @@ AuthEventTypes = (
class Auth(object): class Auth(object):
"""
FIXME: This class contains a mix of functions for authenticating users
of our client-server API and authenticating events added to room graphs.
"""
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401 self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
# Docs for these currently lives at
# https://github.com/matrix-org/matrix-doc/blob/master/drafts/macaroons_caveats.rst
# In addition, we have type == delete_pusher which grants access only to
# delete pushers.
self._KNOWN_CAVEAT_PREFIXES = set([ self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ", "gen = ",
"guest = ", "guest = ",
@ -57,7 +63,7 @@ class Auth(object):
"user_id = ", "user_id = ",
]) ])
def check(self, event, auth_events): def check(self, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Args: Args:
@ -73,6 +79,13 @@ class Auth(object):
if not hasattr(event, "room_id"): if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event) raise AuthError(500, "Event has no room_id: %s" % event)
sender_domain = get_domain_from_id(event.sender)
# Check the sender's domain has signed the event
if do_sig_check and not event.signatures.get(sender_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None: if auth_events is None:
# Oh, we don't know what the state of the room was, so we # Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now) # are trusting that this is allowed (at least for now)
@ -80,6 +93,12 @@ class Auth(object):
return True return True
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
)
# FIXME # FIXME
return True return True
@ -102,6 +121,22 @@ class Auth(object):
# FIXME: Temp hack # FIXME: Temp hack
if event.type == EventTypes.Aliases: if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
)
return True return True
logger.debug( logger.debug(
@ -120,6 +155,24 @@ class Auth(object):
return allowed return allowed
self.check_event_sender_in_room(event, auth_events) self.check_event_sender_in_room(event, auth_events)
# Special case to allow m.room.third_party_invite events wherever
# a user is allowed to issue invites. Fixes
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = self._get_user_power_level(event.user_id, auth_events)
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, (
"You cannot issue a third party invite for %s." %
(event.content.display_name,)
)
)
else:
return True
self._can_send_event(event, auth_events) self._can_send_event(event, auth_events)
if event.type == EventTypes.PowerLevels: if event.type == EventTypes.PowerLevels:
@ -323,6 +376,10 @@ class Auth(object):
if Membership.INVITE == membership and "third_party_invite" in event.content: if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events): if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.") raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
return True return True
if Membership.JOIN != membership: if Membership.JOIN != membership:
@ -507,15 +564,13 @@ class Auth(object):
return default return default
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_req(self, request, allow_guest=False): def get_user_by_req(self, request, allow_guest=False, rights="access"):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
request - An HTTP request with an access_token query parameter. request - An HTTP request with an access_token query parameter.
Returns: Returns:
tuple of: defer.Deferred: resolves to a ``synapse.types.Requester`` object
UserID (str)
Access token ID (str)
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
@ -524,16 +579,18 @@ class Auth(object):
user_id = yield self._get_appservice_user_id(request.args) user_id = yield self._get_appservice_user_id(request.args)
if user_id: if user_id:
request.authenticated_entity = user_id request.authenticated_entity = user_id
defer.returnValue( defer.returnValue(synapse.types.create_requester(user_id))
Requester(UserID.from_string(user_id), "", False)
)
access_token = request.args["access_token"][0] access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_access_token(access_token) user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"] user = user_info["user"]
token_id = user_info["token_id"] token_id = user_info["token_id"]
is_guest = user_info["is_guest"] is_guest = user_info["is_guest"]
# device_id may not be present if get_user_by_access_token has been
# stubbed out.
device_id = user_info.get("device_id")
ip_addr = self.hs.get_ip_from_request(request) ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders( user_agent = request.requestHeaders.getRawHeaders(
"User-Agent", "User-Agent",
@ -545,7 +602,8 @@ class Auth(object):
user=user, user=user,
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
user_agent=user_agent user_agent=user_agent,
device_id=device_id,
) )
if is_guest and not allow_guest: if is_guest and not allow_guest:
@ -555,7 +613,8 @@ class Auth(object):
request.authenticated_entity = user.to_string() request.authenticated_entity = user.to_string()
defer.returnValue(Requester(user, token_id, is_guest)) defer.returnValue(synapse.types.create_requester(
user, token_id, is_guest, device_id))
except KeyError: except KeyError:
raise AuthError( raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@ -590,7 +649,7 @@ class Auth(object):
defer.returnValue(user_id) defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_access_token(self, token): def get_user_by_access_token(self, token, rights="access"):
""" Get a registered user's ID. """ Get a registered user's ID.
Args: Args:
@ -601,47 +660,61 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
try: try:
ret = yield self.get_user_from_macaroon(token) ret = yield self.get_user_from_macaroon(token, rights)
except AuthError: except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens # TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons. # have been re-issued as macaroons.
if self.hs.config.expire_access_token:
raise
ret = yield self._look_up_user_by_access_token(token) ret = yield self._look_up_user_by_access_token(token)
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_from_macaroon(self, macaroon_str): def get_user_from_macaroon(self, macaroon_str, rights="access"):
try: try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self.validate_macaroon(macaroon, "access", self.hs.config.expire_access_token) user_id = self.get_user_id_from_macaroon(macaroon)
user = UserID.from_string(user_id)
self.validate_macaroon(
macaroon, rights, self.hs.config.expire_access_token,
user_id=user_id,
)
user_prefix = "user_id = "
user = None
guest = False guest = False
for caveat in macaroon.caveats: for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix): if caveat.caveat_id == "guest = true":
user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
elif caveat.caveat_id == "guest = true":
guest = True guest = True
if user is None:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
if guest: if guest:
ret = { ret = {
"user": user, "user": user,
"is_guest": True, "is_guest": True,
"token_id": None, "token_id": None,
"device_id": None,
}
elif rights == "delete_pusher":
# We don't store these tokens in the database
ret = {
"user": user,
"is_guest": False,
"token_id": None,
"device_id": None,
} }
else: else:
# This codepath exists so that we can actually return a # This codepath exists for several reasons:
# token ID, because we use token IDs in place of device # * so that we can actually return a token ID, which is used
# identifiers throughout the codebase. # in some parts of the schema (where we probably ought to
# TODO(daniel): Remove this fallback when device IDs are # use device IDs instead)
# properly implemented. # * the only way we currently have to invalidate an
# access_token is by removing it from the database, so we
# have to check here that it is still in the db
# * some attributes (notably device_id) aren't stored in the
# macaroon. They probably should be.
# TODO: build the dictionary from the macaroon once the
# above are fixed
ret = yield self._look_up_user_by_access_token(macaroon_str) ret = yield self._look_up_user_by_access_token(macaroon_str)
if ret["user"] != user: if ret["user"] != user:
logger.error( logger.error(
@ -661,21 +734,46 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
def validate_macaroon(self, macaroon, type_string, verify_expiry): def get_user_id_from_macaroon(self, macaroon):
"""Retrieve the user_id given by the caveats on the macaroon.
Does *not* validate the macaroon.
Args:
macaroon (pymacaroons.Macaroon): The macaroon to validate
Returns:
(str) user id
Raises:
AuthError if there is no user_id caveat in the macaroon
"""
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
""" """
validate that a Macaroon is understood by and was signed by this server. validate that a Macaroon is understood by and was signed by this server.
Args: Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token this is (e.g. "access", "refresh") type_string(str): The kind of token required (e.g. "access", "refresh",
"delete_pusher")
verify_expiry(bool): Whether to verify whether the macaroon has expired. verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet. token refresh, so we can't enforce expiry yet.
user_id (str): The user_id required
""" """
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1") v.satisfy_exact("gen = 1")
v.satisfy_exact("type = " + type_string) v.satisfy_exact("type = " + type_string)
v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_exact("user_id = %s" % user_id)
v.satisfy_exact("guest = true") v.satisfy_exact("guest = true")
if verify_expiry: if verify_expiry:
v.satisfy_general(self._verify_expiry) v.satisfy_general(self._verify_expiry)
@ -714,10 +812,14 @@ class Auth(object):
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
# we use ret.get() below because *lots* of unit tests stub out
# get_user_by_access_token in a way where it only returns a couple of
# the fields.
user_info = { user_info = {
"user": UserID.from_string(ret.get("name")), "user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None), "token_id": ret.get("token_id", None),
"is_guest": False, "is_guest": False,
"device_id": ret.get("device_id"),
} }
defer.returnValue(user_info) defer.returnValue(user_info)
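The macaroon handling above can be illustrated with a short, self-contained sketch (not part of this change; the signing key, location and user id are invented) showing how an access token carrying the ``gen``, ``type`` and ``user_id`` caveats might be minted and then checked by a Verifier configured like ``validate_macaroon()``:

    import pymacaroons

    key = "macaroon-secret-key"        # hypothetical signing key
    user_id = "@alice:example.com"     # hypothetical user

    macaroon = pymacaroons.Macaroon(
        location="example.com",
        identifier="key",
        key=key,
    )
    macaroon.add_first_party_caveat("gen = 1")
    macaroon.add_first_party_caveat("type = access")
    macaroon.add_first_party_caveat("user_id = %s" % (user_id,))

    v = pymacaroons.Verifier()
    v.satisfy_exact("gen = 1")
    v.satisfy_exact("type = access")
    v.satisfy_exact("user_id = %s" % (user_id,))
    v.satisfy_exact("guest = true")    # permitted but not required, as above
    assert v.verify(macaroon, key)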


@ -42,8 +42,10 @@ class Codes(object):
TOO_LARGE = "M_TOO_LARGE" TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE" EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "THREEPID_IN_USE" THREEPID_IN_USE = "M_THREEPID_IN_USE"
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
INVALID_USERNAME = "M_INVALID_USERNAME" INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):


@ -191,6 +191,17 @@ class Filter(object):
def __init__(self, filter_json): def __init__(self, filter_json):
self.filter_json = filter_json self.filter_json = filter_json
self.types = self.filter_json.get("types", None)
self.not_types = self.filter_json.get("not_types", [])
self.rooms = self.filter_json.get("rooms", None)
self.not_rooms = self.filter_json.get("not_rooms", [])
self.senders = self.filter_json.get("senders", None)
self.not_senders = self.filter_json.get("not_senders", [])
self.contains_url = self.filter_json.get("contains_url", None)
def check(self, event): def check(self, event):
"""Checks whether the filter matches the given event. """Checks whether the filter matches the given event.
@ -209,9 +220,10 @@ class Filter(object):
event.get("room_id", None), event.get("room_id", None),
sender, sender,
event.get("type", None), event.get("type", None),
"url" in event.get("content", {})
) )
def check_fields(self, room_id, sender, event_type): def check_fields(self, room_id, sender, event_type, contains_url):
"""Checks whether the filter matches the given event fields. """Checks whether the filter matches the given event fields.
Returns: Returns:
@ -225,15 +237,20 @@ class Filter(object):
for name, match_func in literal_keys.items(): for name, match_func in literal_keys.items():
not_name = "not_%s" % (name,) not_name = "not_%s" % (name,)
disallowed_values = self.filter_json.get(not_name, []) disallowed_values = getattr(self, not_name)
if any(map(match_func, disallowed_values)): if any(map(match_func, disallowed_values)):
return False return False
allowed_values = self.filter_json.get(name, None) allowed_values = getattr(self, name)
if allowed_values is not None: if allowed_values is not None:
if not any(map(match_func, allowed_values)): if not any(map(match_func, allowed_values)):
return False return False
contains_url_filter = self.filter_json.get("contains_url")
if contains_url_filter is not None:
if contains_url_filter != contains_url:
return False
return True return True
def filter_rooms(self, room_ids): def filter_rooms(self, room_ids):
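As an illustration of the new ``contains_url`` behaviour, here is a standalone sketch (with an invented event, not using the Filter class itself): the flag is compared against whether the event content carries a ``url`` key.

    filter_json = {"types": ["m.room.message"], "contains_url": True}

    event = {
        "type": "m.room.message",
        "room_id": "!room:example.com",
        "sender": "@alice:example.com",
        "content": {"msgtype": "m.image", "url": "mxc://example.com/abc123"},
    }

    contains_url = "url" in event.get("content", {})

    matches = True
    if filter_json.get("types") is not None:
        matches = matches and event["type"] in filter_json["types"]
    if filter_json.get("contains_url") is not None:
        matches = matches and filter_json["contains_url"] == contains_url

    print(matches)  # True: the event has a content.url, so the filter passes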


@ -16,13 +16,11 @@
import sys import sys
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
from synapse.python_dependencies import ( from synapse import python_dependencies # noqa: E402
check_requirements, MissingRequirementError
) # NOQA
try: try:
check_requirements() python_dependencies.check_requirements()
except MissingRequirementError as e: except python_dependencies.MissingRequirementError as e:
message = "\n".join([ message = "\n".join([
"Missing Requirement: %s" % (e.message,), "Missing Requirement: %s" % (e.message,),
"To install run:", "To install run:",


@ -0,0 +1,206 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse.api.urls import FEDERATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse.crypto import context_factory
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import gc
logger = logging.getLogger("synapse.app.federation_reader")
class FederationReaderSlavedStore(
SlavedEventStore,
SlavedKeyStore,
RoomStore,
DirectoryStore,
TransactionStore,
BaseSlavedStore,
):
pass
class FederationReaderServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse federation reader now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse federation reader", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.federation_reader"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
tls_server_context_factory = context_factory.ServerContextFactory(config)
ss = FederationReaderServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
ss.setup()
ss.get_handlers()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-federation-reader",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])
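For reference, an invented example of the shape ``_listen_http()`` above expects for each entry in ``config.worker_listeners`` (port, names and addresses are illustrative only):

    listener_config = {
        "type": "http",
        "port": 8009,
        "bind_address": "127.0.0.1",
        "tag": "federation_reader",
        "resources": [
            {"names": ["federation", "metrics"]},
        ],
    }

    # start_listening() dispatches on "type"; for "http" listeners, each name
    # is mapped to a resource ("metrics" -> MetricsResource, "federation" ->
    # TransportLayerServer) and served on the configured port.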


@ -16,6 +16,7 @@
import synapse import synapse
import gc
import logging import logging
import os import os
import sys import sys
@ -50,6 +51,7 @@ from synapse.api.urls import (
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.metrics import register_memory_metrics
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
@ -146,7 +148,7 @@ class SynapseHomeServer(HomeServer):
MEDIA_PREFIX: media_repo, MEDIA_PREFIX: media_repo,
LEGACY_MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo,
CONTENT_REPO_PREFIX: ContentRepoResource( CONTENT_REPO_PREFIX: ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr self, self.config.uploads_path
), ),
}) })
@ -265,10 +267,9 @@ def setup(config_options):
HomeServer HomeServer
""" """
try: try:
config = HomeServerConfig.load_config( config = HomeServerConfig.load_or_generate_config(
"Synapse Homeserver", "Synapse Homeserver",
config_options, config_options,
generate_section="Homeserver"
) )
except ConfigError as e: except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n") sys.stderr.write("\n" + e.message + "\n")
@ -284,7 +285,7 @@ def setup(config_options):
# check any extra requirements we have now we have a config # check any extra requirements we have now we have a config
check_requirements(config) check_requirements(config)
version_string = get_version_string("Synapse", synapse) version_string = "Synapse/" + get_version_string(synapse)
logger.info("Server hostname: %s", config.server_name) logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string) logger.info("Server version: %s", version_string)
@ -301,7 +302,6 @@ def setup(config_options):
db_config=config.database_config, db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory, tls_server_context_factory=tls_server_context_factory,
config=config, config=config,
content_addr=config.content_addr,
version_string=version_string, version_string=version_string,
database_engine=database_engine, database_engine=database_engine,
) )
@ -336,6 +336,8 @@ def setup(config_options):
hs.get_datastore().start_doing_background_updates() hs.get_datastore().start_doing_background_updates()
hs.get_replication_layer().start_get_pdu_cache() hs.get_replication_layer().start_get_pdu_cache()
register_memory_metrics(hs)
reactor.callWhenRunning(start) reactor.callWhenRunning(start)
return hs return hs
@ -351,6 +353,8 @@ class SynapseService(service.Service):
def startService(self): def startService(self):
hs = setup(self.config) hs = setup(self.config)
change_resource_limit(hs.config.soft_file_limit) change_resource_limit(hs.config.soft_file_limit)
if hs.config.gc_thresholds:
gc.set_threshold(*hs.config.gc_thresholds)
def stopService(self): def stopService(self):
return self._port.stopListening() return self._port.stopListening()
@ -422,6 +426,8 @@ def run(hs):
# sys.settrace(logcontext_tracer) # sys.settrace(logcontext_tracer)
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit) change_resource_limit(hs.config.soft_file_limit)
if hs.config.gc_thresholds:
gc.set_threshold(*hs.config.gc_thresholds)
reactor.run() reactor.run()
if hs.config.daemonize: if hs.config.daemonize:


@ -18,9 +18,8 @@ import synapse
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.config._base import ConfigError from synapse.config._base import ConfigError
from synapse.config.database import DatabaseConfig from synapse.config.logger import setup_logging
from synapse.config.logger import LoggingConfig from synapse.config.homeserver import HomeServerConfig
from synapse.config.emailconfig import EmailConfig
from synapse.http.site import SynapseSite from synapse.http.site import SynapseSite
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.storage.roommember import RoomMemberStore from synapse.storage.roommember import RoomMemberStore
@ -44,61 +43,11 @@ from daemonize import Daemonize
import sys import sys
import logging import logging
import gc
logger = logging.getLogger("synapse.app.pusher") logger = logging.getLogger("synapse.app.pusher")
class SlaveConfig(DatabaseConfig):
def read_config(self, config):
self.replication_url = config["replication_url"]
self.server_name = config["server_name"]
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
"use_insecure_ssl_client_just_for_testing_do_not_use", False
)
self.user_agent_suffix = None
self.start_pushers = True
self.listeners = config["listeners"]
self.soft_file_limit = config.get("soft_file_limit")
self.daemonize = config.get("daemonize")
self.pid_file = self.abspath(config.get("pid_file"))
self.public_baseurl = config["public_baseurl"]
def default_config(self, server_name, **kwargs):
pid_file = self.abspath("pusher.pid")
return """\
# Slave configuration
# The replication listener on the synapse to talk to.
#replication_url: https://localhost:{replication_port}/_synapse/replication
server_name: "%(server_name)s"
listeners: []
# Enable a ssh manhole listener on the pusher.
# - type: manhole
# port: {manhole_port}
# bind_address: 127.0.0.1
# Enable a metric listener on the pusher.
# - type: http
# port: {metrics_port}
# bind_address: 127.0.0.1
# resources:
# - names: ["metrics"]
# compress: False
report_stats: False
daemonize: False
pid_file: %(pid_file)s
""" % locals()
class PusherSlaveConfig(SlaveConfig, LoggingConfig, EmailConfig):
pass
class PusherSlaveStore( class PusherSlaveStore(
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore,
SlavedAccountDataStore SlavedAccountDataStore
@ -163,7 +112,7 @@ class PusherServer(HomeServer):
def remove_pusher(self, app_id, push_key, user_id): def remove_pusher(self, app_id, push_key, user_id):
http_client = self.get_simple_http_client() http_client = self.get_simple_http_client()
replication_url = self.config.replication_url replication_url = self.config.worker_replication_url
url = replication_url + "/remove_pushers" url = replication_url + "/remove_pushers"
return http_client.post_json_get_json(url, { return http_client.post_json_get_json(url, {
"remove": [{ "remove": [{
@ -196,8 +145,8 @@ class PusherServer(HomeServer):
) )
logger.info("Synapse pusher now listening on port %d", port) logger.info("Synapse pusher now listening on port %d", port)
def start_listening(self): def start_listening(self, listeners):
for listener in self.config.listeners: for listener in listeners:
if listener["type"] == "http": if listener["type"] == "http":
self._listen_http(listener) self._listen_http(listener)
elif listener["type"] == "manhole": elif listener["type"] == "manhole":
@ -217,7 +166,7 @@ class PusherServer(HomeServer):
def replicate(self): def replicate(self):
http_client = self.get_simple_http_client() http_client = self.get_simple_http_client()
store = self.get_datastore() store = self.get_datastore()
replication_url = self.config.replication_url replication_url = self.config.worker_replication_url
pusher_pool = self.get_pusherpool() pusher_pool = self.get_pusherpool()
clock = self.get_clock() clock = self.get_clock()
@ -290,22 +239,33 @@ class PusherServer(HomeServer):
poke_pushers(result) poke_pushers(result)
except: except:
logger.exception("Error replicating from %r", replication_url) logger.exception("Error replicating from %r", replication_url)
sleep(30) yield sleep(30)
def setup(config_options): def start(config_options):
try: try:
config = PusherSlaveConfig.load_config( config = HomeServerConfig.load_config(
"Synapse pusher", config_options "Synapse pusher", config_options
) )
except ConfigError as e: except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n") sys.stderr.write("\n" + e.message + "\n")
sys.exit(1) sys.exit(1)
if not config: assert config.worker_app == "synapse.app.pusher"
sys.exit(0)
config.setup_logging() setup_logging(config.worker_log_config, config.worker_log_file)
if config.start_pushers:
sys.stderr.write(
"\nThe pushers must be disabled in the main synapse process"
"\nbefore they can be run in a separate worker."
"\nPlease add ``start_pushers: false`` to the main config"
"\n"
)
sys.exit(1)
# Force the pushers to start since they will be disabled in the main config
config.start_pushers = True
database_engine = create_engine(config.database_config) database_engine = create_engine(config.database_config)
@ -313,14 +273,20 @@ def setup(config_options):
config.server_name, config.server_name,
db_config=config.database_config, db_config=config.database_config,
config=config, config=config,
version_string=get_version_string("Synapse", synapse), version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine, database_engine=database_engine,
) )
ps.setup() ps.setup()
ps.start_listening() ps.start_listening(config.worker_listeners)
change_resource_limit(ps.config.soft_file_limit) def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start(): def start():
ps.replicate() ps.replicate()
@ -329,28 +295,20 @@ def setup(config_options):
reactor.callWhenRunning(start) reactor.callWhenRunning(start)
return ps if config.worker_daemonize:
daemon = Daemonize(
app="synapse-pusher",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__': if __name__ == '__main__':
with LoggingContext("main"): with LoggingContext("main"):
ps = setup(sys.argv[1:]) ps = start(sys.argv[1:])
if ps.config.daemonize:
def run():
with LoggingContext("run"):
change_resource_limit(ps.config.soft_file_limit)
reactor.run()
daemon = Daemonize(
app="synapse-pusher",
pid=ps.config.pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
reactor.run()

synapse/app/synchrotron.py

@ -0,0 +1,465 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.api.constants import EventTypes, PresenceState
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.events import FrozenEvent
from synapse.handlers.presence import PresenceHandler
from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
from synapse.storage.presence import PresenceStore, UserPresenceState
from synapse.storage.roommember import RoomMemberStore
from synapse.util.async import sleep
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.manhole import manhole
from synapse.util.rlimit import change_resource_limit
from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor, defer
from twisted.web.resource import Resource
from daemonize import Daemonize
import sys
import logging
import contextlib
import gc
import ujson as json
logger = logging.getLogger("synapse.app.synchrotron")
class SynchrotronSlavedStore(
SlavedPushRuleStore,
SlavedEventStore,
SlavedReceiptsStore,
SlavedAccountDataStore,
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedFilteringStore,
SlavedPresenceStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
# XXX: This is a bit broken because we don't persist forgotten rooms
# in a way that they can be streamed. This means that we don't have a
# way to invalidate the forgotten rooms cache correctly.
# For now we expire the cache every hour.
BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
# XXX: This is a bit broken because we don't persist the accepted list in a
# way that can be replicated. This means that we don't have a way to
# invalidate the cache correctly.
get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted"
]
UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object):
def __init__(self, hs):
self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore()
self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {
state.user_id: state
for state in active_presence
}
self.process_id = random_string(16)
logger.info("Presence process_id is %r", self.process_id)
self._sending_sync = False
self._need_to_send_sync = False
self.clock.looping_call(
self._send_syncing_users_regularly,
UPDATE_SYNCING_USERS_MS,
)
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
def set_state(self, user, state):
# TODO How's this supposed to work?
pass
get_states = PresenceHandler.get_states.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks
def user_syncing(self, user_id, affect_presence):
if affect_presence:
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
prev_states = yield self.current_state_for_users([user_id])
if prev_states[user_id].state == PresenceState.OFFLINE:
# TODO: Don't block the sync request on this HTTP hit.
yield self._send_syncing_users_now()
def _end():
# We check that the user_id is in user_to_num_current_syncs because
# user_to_num_current_syncs may have been cleared if we are
# shutting down.
if affect_presence and user_id in self.user_to_num_current_syncs:
self.user_to_num_current_syncs[user_id] -= 1
@contextlib.contextmanager
def _user_syncing():
try:
yield
finally:
_end()
defer.returnValue(_user_syncing())
@defer.inlineCallbacks
def _on_shutdown(self):
# When the synchrotron is shutdown tell the master to clear the in
# progress syncs for this process
self.user_to_num_current_syncs.clear()
yield self._send_syncing_users_now()
def _send_syncing_users_regularly(self):
# Only send an update if we aren't in the middle of sending one.
if not self._sending_sync:
preserve_fn(self._send_syncing_users_now)()
@defer.inlineCallbacks
def _send_syncing_users_now(self):
if self._sending_sync:
# We don't want to race with sending another update.
# Instead we wait for that update to finish and send another
# update afterwards.
self._need_to_send_sync = True
return
# Flag that we are sending an update.
self._sending_sync = True
yield self.http_client.post_json_get_json(self.syncing_users_url, {
"process_id": self.process_id,
"syncing_users": [
user_id for user_id, count in self.user_to_num_current_syncs.items()
if count > 0
],
})
# Unset the flag as we are no longer sending an update.
self._sending_sync = False
if self._need_to_send_sync:
# If something happened while we were sending the update then
# we might need to send another update.
# TODO: Check if the update that was sent matches the current state
# as we only need to send an update if they are different.
self._need_to_send_sync = False
yield self._send_syncing_users_now()
def process_replication(self, result):
stream = result.get("presence", {"rows": []})
for row in stream["rows"]:
(
position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
) = row
self.user_to_current_state[user_id] = UserPresenceState(
user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
)
class SynchrotronTyping(object):
def __init__(self, hs):
self._latest_room_serial = 0
self._room_serials = {}
self._room_typing = {}
def stream_positions(self):
return {"typing": self._latest_room_serial}
def process_replication(self, result):
stream = result.get("typing")
if stream:
self._latest_room_serial = int(stream["position"])
for row in stream["rows"]:
position, room_id, typing_json = row
typing = json.loads(typing_json)
self._room_serials[room_id] = position
self._room_typing[room_id] = typing
class SynchrotronApplicationService(object):
def notify_interested_services(self, event):
pass
class SynchrotronServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse synchrotron now listening on port %d", port)
def start_listening(self, listeners):
for listener in listeners:
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
clock = self.get_clock()
notifier = self.get_notifier()
presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler()
def expire_broken_caches():
store.who_forgot_in_room.invalidate_all()
store.get_presence_list_accepted.invalidate_all()
def notify_from_stream(
result, stream_name, stream_key, room=None, user=None
):
stream = result.get(stream_name)
if stream:
position_index = stream["field_names"].index("position")
if room:
room_index = stream["field_names"].index(room)
if user:
user_index = stream["field_names"].index(user)
users = ()
rooms = ()
for row in stream["rows"]:
position = row[position_index]
if user:
users = (row[user_index],)
if room:
rooms = (row[room_index],)
notifier.on_new_event(
stream_key, position, users=users, rooms=rooms
)
def notify(result):
stream = result.get("events")
if stream:
max_position = stream["position"]
for row in stream["rows"]:
position = row[0]
internal = json.loads(row[1])
event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal)
extra_users = ()
if event.type == EventTypes.Member:
extra_users = (event.state_key,)
notifier.on_new_room_event(
event, position, max_position, extra_users
)
notify_from_stream(
result, "push_rules", "push_rules_key", user="user_id"
)
notify_from_stream(
result, "user_account_data", "account_data_key", user="user_id"
)
notify_from_stream(
result, "room_account_data", "account_data_key", user="user_id"
)
notify_from_stream(
result, "tag_account_data", "account_data_key", user="user_id"
)
notify_from_stream(
result, "receipts", "receipt_key", room="room_id"
)
notify_from_stream(
result, "typing", "typing_key", room="room_id"
)
next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args.update(typing_handler.stream_positions())
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
now_ms = clock.time_msec()
if now_ms > next_expire_broken_caches_ms:
expire_broken_caches()
next_expire_broken_caches_ms = (
now_ms + store.BROKEN_CACHE_EXPIRY_MS
)
yield store.process_replication(result)
typing_handler.process_replication(result)
presence_handler.process_replication(result)
notify(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)
def build_presence_handler(self):
return SynchrotronPresence(self)
def build_typing_handler(self):
return SynchrotronTyping(self)
def start(config_options):
try:
config = HomeServerConfig.load_config(
"Synapse synchrotron", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
assert config.worker_app == "synapse.app.synchrotron"
setup_logging(config.worker_log_config, config.worker_log_file)
database_engine = create_engine(config.database_config)
ss = SynchrotronServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
application_service_handler=SynchrotronApplicationService(),
)
ss.setup()
ss.start_listening(config.worker_listeners)
def run():
with LoggingContext("run"):
logger.info("Running")
change_resource_limit(config.soft_file_limit)
if config.gc_thresholds:
gc.set_threshold(*config.gc_thresholds)
reactor.run()
def start():
ss.get_datastore().start_profiling()
ss.replicate()
reactor.callWhenRunning(start)
if config.worker_daemonize:
daemon = Daemonize(
app="synapse-synchrotron",
pid=config.worker_pid_file,
action=run,
auto_close_fds=False,
verbose=True,
logger=logger,
)
daemon.start()
else:
run()
if __name__ == '__main__':
with LoggingContext("main"):
start(sys.argv[1:])
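A hedged sketch of the replication response shape that ``notify()`` and ``notify_from_stream()`` above consume (the stream and field names come from the code; the row values are invented): each stream carries ``field_names`` describing the row columns plus the ``rows`` themselves.

    result = {
        "receipts": {
            "field_names": ["position", "room_id"],
            "rows": [
                [105, "!room:example.com"],
            ],
        },
    }

    stream = result["receipts"]
    position_index = stream["field_names"].index("position")
    room_index = stream["field_names"].index("room_id")
    for row in stream["rows"]:
        # notify_from_stream() would call notifier.on_new_event("receipt_key",
        # position, rooms=(room_id,)) for each such row.
        print(row[position_index], row[room_index])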


@ -14,11 +14,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import sys import argparse
import collections
import glob
import os import os
import os.path import os.path
import subprocess
import signal import signal
import subprocess
import sys
import yaml import yaml
SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"] SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
@ -28,60 +31,181 @@ RED = "\x1b[1;31m"
NORMAL = "\x1b[m" NORMAL = "\x1b[m"
def write(message, colour=NORMAL, stream=sys.stdout):
if colour == NORMAL:
stream.write(message + "\n")
else:
stream.write(colour + message + NORMAL + "\n")
def start(configfile): def start(configfile):
print ("Starting ...") write("Starting ...")
args = SYNAPSE args = SYNAPSE
args.extend(["--daemonize", "-c", configfile]) args.extend(["--daemonize", "-c", configfile])
try: try:
subprocess.check_call(args) subprocess.check_call(args)
print (GREEN + "started" + NORMAL) write("started synapse.app.homeserver(%r)" % (configfile,), colour=GREEN)
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print ( write(
RED + "error starting (exit code: %d); see above for logs" % e.returncode,
"error starting (exit code: %d); see above for logs" % e.returncode + colour=RED,
NORMAL
) )
def stop(pidfile): def start_worker(app, configfile, worker_configfile):
args = [
"python", "-B",
"-m", app,
"-c", configfile,
"-c", worker_configfile
]
try:
subprocess.check_call(args)
write("started %s(%r)" % (app, worker_configfile), colour=GREEN)
except subprocess.CalledProcessError as e:
write(
"error starting %s(%r) (exit code: %d); see above for logs" % (
app, worker_configfile, e.returncode,
),
colour=RED,
)
def stop(pidfile, app):
if os.path.exists(pidfile): if os.path.exists(pidfile):
pid = int(open(pidfile).read()) pid = int(open(pidfile).read())
os.kill(pid, signal.SIGTERM) os.kill(pid, signal.SIGTERM)
print (GREEN + "stopped" + NORMAL) write("stopped %s" % (app,), colour=GREEN)
Worker = collections.namedtuple("Worker", [
"app", "configfile", "pidfile", "cache_factor"
])
def main(): def main():
configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml"
if not os.path.exists(configfile): parser = argparse.ArgumentParser()
sys.stderr.write(
"No config file found\n" parser.add_argument(
"To generate a config file, run '%s -c %s --generate-config" "action",
" --server-name=<server name>'\n" % ( choices=["start", "stop", "restart"],
" ".join(SYNAPSE), configfile help="whether to start, stop or restart the synapse",
) )
parser.add_argument(
"configfile",
nargs="?",
default="homeserver.yaml",
help="the homeserver config file, defaults to homserver.yaml",
)
parser.add_argument(
"-w", "--worker",
metavar="WORKERCONFIG",
help="start or stop a single worker",
)
parser.add_argument(
"-a", "--all-processes",
metavar="WORKERCONFIGDIR",
help="start or stop all the workers in the given directory"
" and the main synapse process",
)
options = parser.parse_args()
if options.worker and options.all_processes:
write(
'Cannot use "--worker" with "--all-processes"',
stream=sys.stderr
) )
sys.exit(1) sys.exit(1)
config = yaml.load(open(configfile)) configfile = options.configfile
if not os.path.exists(configfile):
write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), options.configfile
),
stream=sys.stderr,
)
sys.exit(1)
with open(configfile) as stream:
config = yaml.load(stream)
pidfile = config["pid_file"] pidfile = config["pid_file"]
cache_factor = config.get("synctl_cache_factor", None) cache_factor = config.get("synctl_cache_factor")
start_stop_synapse = True
if cache_factor: if cache_factor:
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor) os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
action = sys.argv[1] if sys.argv[1:] else "usage" worker_configfiles = []
if action == "start": if options.worker:
start(configfile) start_stop_synapse = False
elif action == "stop": worker_configfile = options.worker
stop(pidfile) if not os.path.exists(worker_configfile):
elif action == "restart": write(
stop(pidfile) "No worker config found at %r" % (worker_configfile,),
start(configfile) stream=sys.stderr,
else: )
sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],)) sys.exit(1)
sys.exit(1) worker_configfiles.append(worker_configfile)
if options.all_processes:
worker_configdir = options.all_processes
if not os.path.isdir(worker_configdir):
write(
"No worker config directory found at %r" % (worker_configdir,),
stream=sys.stderr,
)
sys.exit(1)
worker_configfiles.extend(sorted(glob.glob(
os.path.join(worker_configdir, "*.yaml")
)))
workers = []
for worker_configfile in worker_configfiles:
with open(worker_configfile) as stream:
worker_config = yaml.load(stream)
worker_app = worker_config["worker_app"]
worker_pidfile = worker_config["worker_pid_file"]
worker_daemonize = worker_config["worker_daemonize"]
assert worker_daemonize # TODO print something more user friendly
worker_cache_factor = worker_config.get("synctl_cache_factor")
workers.append(Worker(
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
))
action = options.action
if action == "stop" or action == "restart":
for worker in workers:
stop(worker.pidfile, worker.app)
if start_stop_synapse:
stop(pidfile, "synapse.app.homeserver")
# TODO: Wait for synapse to actually shutdown before starting it again
if action == "start" or action == "restart":
if start_stop_synapse:
start(configfile)
for worker in workers:
if worker.cache_factor:
os.environ["SYNAPSE_CACHE_FACTOR"] = str(worker.cache_factor)
start_worker(worker.app, configfile, worker.configfile)
if cache_factor:
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
else:
os.environ.pop("SYNAPSE_CACHE_FACTOR", None)
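To show what ``synctl`` now expects from each worker config file, here is a sketch of the keys read above after ``yaml.load()`` (values invented for illustration):

    worker_config = {
        "worker_app": "synapse.app.pusher",
        "worker_pid_file": "/var/run/synapse-pusher.pid",
        "worker_daemonize": True,          # synctl asserts this is truthy
        "synctl_cache_factor": 0.5,        # optional, per-worker cache factor
    }

    worker = Worker(
        worker_config["worker_app"],
        "pusher.yaml",                     # path the dict was loaded from
        worker_config["worker_pid_file"],
        worker_config.get("synctl_cache_factor"),
    )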
if __name__ == "__main__": if __name__ == "__main__":


@ -56,22 +56,22 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AppServiceScheduler(object): class ApplicationServiceScheduler(object):
""" Public facing API for this module. Does the required DI to tie the """ Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this components together. This also serves as the "event_pool", which in this
case is a simple array. case is a simple array.
""" """
def __init__(self, clock, store, as_api): def __init__(self, hs):
self.clock = clock self.clock = hs.get_clock()
self.store = store self.store = hs.get_datastore()
self.as_api = as_api self.as_api = hs.get_application_service_api()
def create_recoverer(service, callback): def create_recoverer(service, callback):
return _Recoverer(clock, store, as_api, service, callback) return _Recoverer(self.clock, self.store, self.as_api, service, callback)
self.txn_ctrl = _TransactionController( self.txn_ctrl = _TransactionController(
clock, store, as_api, create_recoverer self.clock, self.store, self.as_api, create_recoverer
) )
self.queuer = _ServiceQueuer(self.txn_ctrl) self.queuer = _ServiceQueuer(self.txn_ctrl)


@ -157,9 +157,40 @@ class Config(object):
return default_config, config return default_config, config
@classmethod @classmethod
def load_config(cls, description, argv, generate_section=None): def load_config(cls, description, argv):
obj = cls() config_parser = argparse.ArgumentParser(
description=description,
)
config_parser.add_argument(
"-c", "--config-path",
action="append",
metavar="CONFIG_FILE",
help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files."
)
config_parser.add_argument(
"--keys-directory",
metavar="DIRECTORY",
help="Where files such as certs and signing keys are stored when"
" their location is given explicitly in the config."
" Defaults to the directory containing the last config file",
)
config_args = config_parser.parse_args(argv)
config_files = find_config_files(search_paths=config_args.config_path)
obj = cls()
obj.read_config_files(
config_files,
keys_directory=config_args.keys_directory,
generate_keys=False,
)
return obj
@classmethod
def load_or_generate_config(cls, description, argv):
config_parser = argparse.ArgumentParser(add_help=False) config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument( config_parser.add_argument(
"-c", "--config-path", "-c", "--config-path",
@ -176,7 +207,7 @@ class Config(object):
config_parser.add_argument( config_parser.add_argument(
"--report-stats", "--report-stats",
action="store", action="store",
help="Stuff", help="Whether the generated config reports anonymized usage statistics",
choices=["yes", "no"] choices=["yes", "no"]
) )
config_parser.add_argument( config_parser.add_argument(
@ -197,36 +228,11 @@ class Config(object):
) )
config_args, remaining_args = config_parser.parse_known_args(argv) config_args, remaining_args = config_parser.parse_known_args(argv)
config_files = find_config_files(search_paths=config_args.config_path)
generate_keys = config_args.generate_keys generate_keys = config_args.generate_keys
config_files = [] obj = cls()
if config_args.config_path:
for config_path in config_args.config_path:
if os.path.isdir(config_path):
# We accept specifying directories as config paths, we search
# inside that directory for all files matching *.yaml, and then
# we apply them in *sorted* order.
files = []
for entry in os.listdir(config_path):
entry_path = os.path.join(config_path, entry)
if not os.path.isfile(entry_path):
print (
"Found subdirectory in config directory: %r. IGNORING."
) % (entry_path, )
continue
if not entry.endswith(".yaml"):
print (
"Found file in config directory that does not"
" end in '.yaml': %r. IGNORING."
) % (entry_path, )
continue
files.append(entry_path)
config_files.extend(sorted(files))
else:
config_files.append(config_path)
if config_args.generate_config: if config_args.generate_config:
if config_args.report_stats is None: if config_args.report_stats is None:
@ -299,28 +305,43 @@ class Config(object):
" -c CONFIG-FILE\"" " -c CONFIG-FILE\""
) )
if config_args.keys_directory: obj.read_config_files(
config_dir_path = config_args.keys_directory config_files,
else: keys_directory=config_args.keys_directory,
config_dir_path = os.path.dirname(config_args.config_path[-1]) generate_keys=generate_keys,
config_dir_path = os.path.abspath(config_dir_path) )
if generate_keys:
return None
obj.invoke_all("read_arguments", args)
return obj
def read_config_files(self, config_files, keys_directory=None,
generate_keys=False):
if not keys_directory:
keys_directory = os.path.dirname(config_files[-1])
config_dir_path = os.path.abspath(keys_directory)
specified_config = {} specified_config = {}
for config_file in config_files: for config_file in config_files:
yaml_config = cls.read_config_file(config_file) yaml_config = self.read_config_file(config_file)
specified_config.update(yaml_config) specified_config.update(yaml_config)
if "server_name" not in specified_config: if "server_name" not in specified_config:
raise ConfigError(MISSING_SERVER_NAME) raise ConfigError(MISSING_SERVER_NAME)
server_name = specified_config["server_name"] server_name = specified_config["server_name"]
_, config = obj.generate_config( _, config = self.generate_config(
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
server_name=server_name, server_name=server_name,
is_generating_file=False, is_generating_file=False,
) )
config.pop("log_config") config.pop("log_config")
config.update(specified_config) config.update(specified_config)
if "report_stats" not in config: if "report_stats" not in config:
raise ConfigError( raise ConfigError(
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
@ -328,11 +349,51 @@ class Config(object):
) )
if generate_keys: if generate_keys:
obj.invoke_all("generate_files", config) self.invoke_all("generate_files", config)
return return
obj.invoke_all("read_config", config) self.invoke_all("read_config", config)
obj.invoke_all("read_arguments", args)
return obj def find_config_files(search_paths):
"""Finds config files using a list of search paths. If a path is a file
then that file path is added to the list. If a search path is a directory
then all the "*.yaml" files in that directory are added to the list in
sorted order.
Args:
search_paths(list(str)): A list of paths to search.
Returns:
list(str): A list of file paths.
"""
config_files = []
if search_paths:
for config_path in search_paths:
if os.path.isdir(config_path):
# We accept specifying directories as config paths, we search
# inside that directory for all files matching *.yaml, and then
# we apply them in *sorted* order.
files = []
for entry in os.listdir(config_path):
entry_path = os.path.join(config_path, entry)
if not os.path.isfile(entry_path):
print (
"Found subdirectory in config directory: %r. IGNORING."
% (entry_path, )
)
continue
if not entry.endswith(".yaml"):
print (
"Found file in config directory that does not"
" end in '.yaml': %r. IGNORING."
% (entry_path, )
)
continue
files.append(entry_path)
config_files.extend(sorted(files))
else:
config_files.append(config_path)
return config_files
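A hypothetical call to ``find_config_files()`` (paths invented), illustrating that file paths are kept as given while directories expand to their ``*.yaml`` contents in sorted order:

    config_files = find_config_files([
        "homeserver.yaml",
        "/etc/synapse/conf.d",   # would expand to e.g. 01-database.yaml, 02-tls.yaml
    ])
    # -> ["homeserver.yaml", "/etc/synapse/conf.d/01-database.yaml",
    #     "/etc/synapse/conf.d/02-tls.yaml"]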


@ -27,6 +27,7 @@ class CaptchaConfig(Config):
def default_config(self, **kwargs): def default_config(self, **kwargs):
return """\ return """\
## Captcha ## ## Captcha ##
# See docs/CAPTCHA_SETUP for full details of configuring this.
# This Home Server's ReCAPTCHA public key. # This Home Server's ReCAPTCHA public key.
recaptcha_public_key: "YOUR_PUBLIC_KEY" recaptcha_public_key: "YOUR_PUBLIC_KEY"


@ -89,7 +89,7 @@ class EmailConfig(Config):
# enable_notifs: false # enable_notifs: false
# smtp_host: "localhost" # smtp_host: "localhost"
# smtp_port: 25 # smtp_port: 25
# notif_from: Your Friendly Matrix Home Server <noreply@example.com> # notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
# app_name: Matrix # app_name: Matrix
# template_dir: res/templates # template_dir: res/templates
# notif_template_html: notif_mail.html # notif_template_html: notif_mail.html


@ -32,13 +32,15 @@ from .password import PasswordConfig
from .jwt import JWTConfig from .jwt import JWTConfig
from .ldap import LDAPConfig from .ldap import LDAPConfig
from .emailconfig import EmailConfig from .emailconfig import EmailConfig
from .workers import WorkerConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,): JWTConfig, LDAPConfig, PasswordConfig, EmailConfig,
WorkerConfig,):
pass pass


@ -13,40 +13,88 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import Config from ._base import Config, ConfigError
MISSING_LDAP3 = (
"Missing ldap3 library. This is required for LDAP Authentication."
)
class LDAPMode(object):
SIMPLE = "simple",
SEARCH = "search",
LIST = (SIMPLE, SEARCH)
class LDAPConfig(Config): class LDAPConfig(Config):
def read_config(self, config): def read_config(self, config):
ldap_config = config.get("ldap_config", None) ldap_config = config.get("ldap_config", {})
if ldap_config:
self.ldap_enabled = ldap_config.get("enabled", False) self.ldap_enabled = ldap_config.get("enabled", False)
self.ldap_server = ldap_config["server"]
self.ldap_port = ldap_config["port"] if self.ldap_enabled:
self.ldap_tls = ldap_config.get("tls", False) # verify dependencies are available
self.ldap_search_base = ldap_config["search_base"] try:
self.ldap_search_property = ldap_config["search_property"] import ldap3
self.ldap_email_property = ldap_config["email_property"] ldap3 # to stop unused lint
self.ldap_full_name_property = ldap_config["full_name_property"] except ImportError:
else: raise ConfigError(MISSING_LDAP3)
self.ldap_enabled = False
self.ldap_server = None self.ldap_mode = LDAPMode.SIMPLE
self.ldap_port = None
self.ldap_tls = False # verify config sanity
self.ldap_search_base = None self.require_keys(ldap_config, [
self.ldap_search_property = None "uri",
self.ldap_email_property = None "base",
self.ldap_full_name_property = None "attributes",
])
self.ldap_uri = ldap_config["uri"]
self.ldap_start_tls = ldap_config.get("start_tls", False)
self.ldap_base = ldap_config["base"]
self.ldap_attributes = ldap_config["attributes"]
if "bind_dn" in ldap_config:
self.ldap_mode = LDAPMode.SEARCH
self.require_keys(ldap_config, [
"bind_dn",
"bind_password",
])
self.ldap_bind_dn = ldap_config["bind_dn"]
self.ldap_bind_password = ldap_config["bind_password"]
self.ldap_filter = ldap_config.get("filter", None)
# verify attribute lookup
self.require_keys(ldap_config['attributes'], [
"uid",
"name",
"mail",
])
def require_keys(self, config, required):
missing = [key for key in required if key not in config]
if missing:
raise ConfigError(
"LDAP enabled but missing required config values: {}".format(
", ".join(missing)
)
)
def default_config(self, **kwargs): def default_config(self, **kwargs):
return """\ return """\
# ldap_config: # ldap_config:
# enabled: true # enabled: true
# server: "ldap://localhost" # uri: "ldap://ldap.example.com:389"
# port: 389 # start_tls: true
# tls: false # base: "ou=users,dc=example,dc=com"
# search_base: "ou=Users,dc=example,dc=com" # attributes:
# search_property: "cn" # uid: "cn"
# email_property: "email" # mail: "email"
# full_name_property: "givenName" # name: "givenName"
# #bind_dn:
# #bind_password:
# #filter: "(objectClass=posixAccount)"
""" """

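For reference, a sketch of the new-style ldap_config section expressed as the Python dict that LDAPConfig.read_config() above receives; every value here is an invented example rather than something taken from the change.

ldap_config = {
    "enabled": True,
    "uri": "ldap://ldap.example.com:389",
    "start_tls": True,
    "base": "ou=users,dc=example,dc=com",
    # The attribute map must provide uid, name and mail lookups.
    "attributes": {"uid": "cn", "mail": "email", "name": "givenName"},
    # Supplying bind_dn and bind_password together switches read_config()
    # from LDAPMode.SIMPLE to LDAPMode.SEARCH; "filter" can then narrow it.
    # "bind_dn": "cn=search,dc=example,dc=com",
    # "bind_password": "secret",
}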

@ -126,54 +126,58 @@ class LoggingConfig(Config):
) )
def setup_logging(self): def setup_logging(self):
log_format = ( setup_logging(self.log_config, self.log_file, self.verbosity)
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
" - %(message)s"
)
if self.log_config is None:
level = logging.INFO
level_for_storage = logging.INFO
if self.verbosity:
level = logging.DEBUG
if self.verbosity > 1:
level_for_storage = logging.DEBUG
# FIXME: we need a logging.WARN for a -q quiet option def setup_logging(log_config=None, log_file=None, verbosity=None):
logger = logging.getLogger('') log_format = (
logger.setLevel(level) "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
" - %(message)s"
)
if log_config is None:
logging.getLogger('synapse.storage').setLevel(level_for_storage) level = logging.INFO
level_for_storage = logging.INFO
if verbosity:
level = logging.DEBUG
if verbosity > 1:
level_for_storage = logging.DEBUG
formatter = logging.Formatter(log_format) # FIXME: we need a logging.WARN for a -q quiet option
if self.log_file: logger = logging.getLogger('')
# TODO: Customisable file size / backup count logger.setLevel(level)
handler = logging.handlers.RotatingFileHandler(
self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
)
def sighup(signum, stack): logging.getLogger('synapse.storage').setLevel(level_for_storage)
logger.info("Closing log file due to SIGHUP")
handler.doRollover()
logger.info("Opened new log file due to SIGHUP")
# TODO(paul): obviously this is a terrible mechanism for formatter = logging.Formatter(log_format)
# stealing SIGHUP, because it means no other part of synapse if log_file:
# can use it instead. If we want to catch SIGHUP anywhere # TODO: Customisable file size / backup count
# else as well, I'd suggest we find a nicer way to broadcast handler = logging.handlers.RotatingFileHandler(
# it around. log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
if getattr(signal, "SIGHUP"): )
signal.signal(signal.SIGHUP, sighup)
else:
handler = logging.StreamHandler()
handler.setFormatter(formatter)
handler.addFilter(LoggingContextFilter(request="")) def sighup(signum, stack):
logger.info("Closing log file due to SIGHUP")
handler.doRollover()
logger.info("Opened new log file due to SIGHUP")
logger.addHandler(handler) # TODO(paul): obviously this is a terrible mechanism for
# stealing SIGHUP, because it means no other part of synapse
# can use it instead. If we want to catch SIGHUP anywhere
# else as well, I'd suggest we find a nicer way to broadcast
# it around.
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, sighup)
else: else:
with open(self.log_config, 'r') as f: handler = logging.StreamHandler()
logging.config.dictConfig(yaml.load(f)) handler.setFormatter(formatter)
observer = PythonLoggingObserver() handler.addFilter(LoggingContextFilter(request=""))
observer.start()
logger.addHandler(handler)
else:
with open(log_config, 'r') as f:
logging.config.dictConfig(yaml.load(f))
observer = PythonLoggingObserver()
observer.start()
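With the logic hoisted out of LoggingConfig into a module-level function, worker processes can configure logging without a LoggingConfig instance. A one-line usage sketch (the log path is invented):

setup_logging(log_config=None, log_file="/var/log/synapse/worker.log", verbosity=1)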


@ -23,10 +23,14 @@ class PasswordConfig(Config):
def read_config(self, config): def read_config(self, config):
password_config = config.get("password_config", {}) password_config = config.get("password_config", {})
self.password_enabled = password_config.get("enabled", True) self.password_enabled = password_config.get("enabled", True)
self.password_pepper = password_config.get("pepper", "")
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
return """ return """
# Enable password for login. # Enable password for login.
password_config: password_config:
enabled: true enabled: true
# Uncomment and change to a secret random string for extra security.
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
#pepper: ""
""" """

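The pepper is a server-wide secret mixed into password hashing so that leaked hashes cannot be cracked without it. A rough, self-contained sketch of the idea, not lifted from the synapse source (function names are illustrative):

import bcrypt

def hash_password(password, pepper):
    # Append the secret pepper to the password before bcrypt hashing.
    return bcrypt.hashpw((password + pepper).encode("utf8"), bcrypt.gensalt())

def check_password(password, pepper, stored_hash):
    # Re-hashing with the stored hash as the salt reproduces it only when
    # password + pepper match.
    return bcrypt.hashpw((password + pepper).encode("utf8"), stored_hash) == stored_hash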

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import Config from ._base import Config, ConfigError
class ServerConfig(Config): class ServerConfig(Config):
@ -27,8 +27,9 @@ class ServerConfig(Config):
self.daemonize = config.get("daemonize") self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile") self.print_pidfile = config.get("print_pidfile")
self.user_agent_suffix = config.get("user_agent_suffix") self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", True) self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl") self.public_baseurl = config.get("public_baseurl")
self.secondary_directory_servers = config.get("secondary_directory_servers", [])
if self.public_baseurl is not None: if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/': if self.public_baseurl[-1] != '/':
@ -37,6 +38,8 @@ class ServerConfig(Config):
self.listeners = config.get("listeners", []) self.listeners = config.get("listeners", [])
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
bind_port = config.get("bind_port") bind_port = config.get("bind_port")
if bind_port: if bind_port:
self.listeners = [] self.listeners = []
@ -104,26 +107,6 @@ class ServerConfig(Config):
] ]
}) })
# Attempt to guess the content_addr for the v0 content repository
content_addr = config.get("content_addr")
if not content_addr:
for listener in self.listeners:
if listener["type"] == "http" and not listener.get("tls", False):
unsecure_port = listener["port"]
break
else:
raise RuntimeError("Could not determine 'content_addr'")
host = self.server_name
if ':' not in host:
host = "%s:%d" % (host, unsecure_port)
else:
host = host.split(':')[0]
host = "%s:%d" % (host, unsecure_port)
content_addr = "http://%s" % (host,)
self.content_addr = content_addr
def default_config(self, server_name, **kwargs): def default_config(self, server_name, **kwargs):
if ":" in server_name: if ":" in server_name:
bind_port = int(server_name.split(":")[1]) bind_port = int(server_name.split(":")[1])
@ -156,6 +139,17 @@ class ServerConfig(Config):
# hard limit. # hard limit.
soft_file_limit: 0 soft_file_limit: 0
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
# gc_thresholds: [700, 10, 10]
# A list of other Home Servers to fetch the public room directory from
# and include in the public room directory of this home server
# This is a temporary stopgap solution to populate new server with a
# list of rooms until there exists a good solution of a decentralized
# room directory.
# secondary_directory_servers:
# - matrix.org
# List of ports that Synapse should listen on, their purpose and their # List of ports that Synapse should listen on, their purpose and their
# configuration. # configuration.
listeners: listeners:
@ -237,3 +231,20 @@ class ServerConfig(Config):
type=int, type=int,
help="Turn on the twisted telnet manhole" help="Turn on the twisted telnet manhole"
" service on the given port.") " service on the given port.")
def read_gc_thresholds(thresholds):
"""Reads the three integer thresholds for garbage collection. Ensures that
the thresholds are integers if thresholds are supplied.
"""
if thresholds is None:
return None
try:
assert len(thresholds) == 3
return (
int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
)
except:
raise ConfigError(
"Value of `gc_threshold` must be a list of three integers if set"
)
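Some examples of how read_gc_thresholds above behaves, based purely on the code:

assert read_gc_thresholds(None) is None                          # unset: keep GC defaults
assert read_gc_thresholds([700, 10, 10]) == (700, 10, 10)
assert read_gc_thresholds(["700", "10", "10"]) == (700, 10, 10)  # values coerced to int
try:
    read_gc_thresholds([700, 10])                                 # wrong length
except ConfigError:
    pass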

synapse/config/workers.py (new file, 31 lines)

@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# Copyright 2016 matrix.org
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class WorkerConfig(Config):
"""The workers are processes run separately from the main synapse process.
They have their own pid_file and listener configuration. They use the
replication_url to talk to the main synapse process."""
def read_config(self, config):
self.worker_app = config.get("worker_app")
self.worker_listeners = config.get("worker_listeners")
self.worker_daemonize = config.get("worker_daemonize")
self.worker_pid_file = config.get("worker_pid_file")
self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config")
self.worker_replication_url = config.get("worker_replication_url")
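A sketch of the dict a worker's YAML config might deserialise to before read_config() above picks the values off it; every value is an invented example, not taken from the synapse documentation.

worker_config = {
    "worker_app": "synapse.app.pusher",                   # which worker process to run
    "worker_listeners": [{"type": "http", "port": 9090}],
    "worker_daemonize": True,
    "worker_pid_file": "/var/run/synapse-pusher.pid",
    "worker_log_file": "/var/log/synapse/pusher.log",
    "worker_replication_url": "http://127.0.0.1:9092/_synapse/replication",
}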


@ -77,10 +77,12 @@ class SynapseKeyClientProtocol(HTTPClient):
def __init__(self): def __init__(self):
self.remote_key = defer.Deferred() self.remote_key = defer.Deferred()
self.host = None self.host = None
self._peer = None
def connectionMade(self): def connectionMade(self):
self.host = self.transport.getHost() self._peer = self.transport.getPeer()
logger.debug("Connected to %s", self.host) logger.debug("Connected to %s", self._peer)
self.sendCommand(b"GET", self.path) self.sendCommand(b"GET", self.path)
if self.host: if self.host:
self.sendHeader(b"Host", self.host) self.sendHeader(b"Host", self.host)
@ -124,7 +126,10 @@ class SynapseKeyClientProtocol(HTTPClient):
self.timer.cancel() self.timer.cancel()
def on_timeout(self): def on_timeout(self):
logger.debug("Timeout waiting for response from %s", self.host) logger.debug(
"Timeout waiting for response from %s: %s",
self.host, self._peer,
)
self.errback(IOError("Timeout waiting for response")) self.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection() self.transport.abortConnection()
@ -133,4 +138,5 @@ class SynapseKeyClientFactory(Factory):
def protocol(self): def protocol(self):
protocol = SynapseKeyClientProtocol() protocol = SynapseKeyClientProtocol()
protocol.path = self.path protocol.path = self.path
protocol.host = self.host
return protocol return protocol


@ -44,7 +44,25 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids")) VerifyKeyRequest = namedtuple("VerifyRequest", (
"server_name", "key_ids", "json_object", "deferred"
))
"""
A request for a verify key to verify a JSON object.
Attributes:
server_name(str): The name of the server to verify against.
key_ids(set(str)): The set of key_ids that could be used to verify the
JSON object
json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred):
A deferred that resolves to a (server_name, key_id, verify_key)
tuple when a verify key has been fetched
"""
class KeyLookupError(ValueError):
pass
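A minimal sketch of how one VerifyKeyRequest is assembled for a single signed object; the server name and signature payload are invented, and signature_ids is the same helper already used by the verification code below.

from twisted.internet import defer

json_object = {"signatures": {"example.org": {"ed25519:auto": "<sig>"}}}
key_ids = signature_ids(json_object, "example.org")    # e.g. ["ed25519:auto"]
request = VerifyKeyRequest("example.org", key_ids, json_object, defer.Deferred())
# request.deferred later fires with (server_name, key_id, verify_key).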
class Keyring(object): class Keyring(object):
@ -74,39 +92,32 @@ class Keyring(object):
list of deferreds indicating success or failure to verify each list of deferreds indicating success or failure to verify each
json object's signature for the given server_name. json object's signature for the given server_name.
""" """
group_id_to_json = {} verify_requests = []
group_id_to_group = {}
group_ids = []
next_group_id = 0
deferreds = {}
for server_name, json_object in server_and_json: for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name) logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name) key_ids = signature_ids(json_object, server_name)
if not key_ids: if not key_ids:
deferreds[group_id] = defer.fail(SynapseError( deferred = defer.fail(SynapseError(
400, 400,
"Not signed with a supported algorithm", "Not signed with a supported algorithm",
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
)) ))
else: else:
deferreds[group_id] = defer.Deferred() deferred = defer.Deferred()
group = KeyGroup(server_name, group_id, key_ids) verify_request = VerifyKeyRequest(
server_name, key_ids, json_object, deferred
)
group_id_to_group[group_id] = group verify_requests.append(verify_request)
group_id_to_json[group_id] = json_object
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_key_deferred(group, deferred): def handle_key_deferred(verify_request):
server_name = group.server_name server_name = verify_request.server_name
try: try:
_, _, key_id, verify_key = yield deferred _, key_id, verify_key = yield verify_request.deferred
except IOError as e: except IOError as e:
logger.warn( logger.warn(
"Got IOError when downloading keys for %s: %s %s", "Got IOError when downloading keys for %s: %s %s",
@ -128,7 +139,7 @@ class Keyring(object):
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
json_object = group_id_to_json[group.group_id] json_object = verify_request.json_object
try: try:
verify_signed_json(json_object, server_name, verify_key) verify_signed_json(json_object, server_name, verify_key)
@ -157,36 +168,34 @@ class Keyring(object):
# Actually start fetching keys. # Actually start fetching keys.
wait_on_deferred.addBoth( wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds) lambda _: self.get_server_verify_keys(verify_requests)
) )
# When we've finished fetching all the keys for a given server_name, # When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that # resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed. # any lookups waiting will proceed.
server_to_gids = {} server_to_request_ids = {}
def remove_deferreds(res, server_name, group_id): def remove_deferreds(res, server_name, verify_request):
server_to_gids[server_name].discard(group_id) request_id = id(verify_request)
if not server_to_gids[server_name]: server_to_request_ids[server_name].discard(request_id)
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None) d = server_to_deferred.pop(server_name, None)
if d: if d:
d.callback(None) d.callback(None)
return res return res
for g_id, deferred in deferreds.items(): for verify_request in verify_requests:
server_name = group_id_to_group[g_id].server_name server_name = verify_request.server_name
server_to_gids.setdefault(server_name, set()).add(g_id) request_id = id(verify_request)
deferred.addBoth(remove_deferreds, server_name, g_id) server_to_request_ids.setdefault(server_name, set()).add(request_id)
deferred.addBoth(remove_deferreds, server_name, verify_request)
# Pass those keys to handle_key_deferred so that the json object # Pass those keys to handle_key_deferred so that the json object
# signatures can be verified # signatures can be verified
return [ return [
preserve_context_over_fn( preserve_context_over_fn(handle_key_deferred, verify_request)
handle_key_deferred, for verify_request in verify_requests
group_id_to_group[g_id],
deferreds[g_id],
)
for g_id in group_ids
] ]
@defer.inlineCallbacks @defer.inlineCallbacks
@ -220,7 +229,7 @@ class Keyring(object):
d.addBoth(rm, server_name) d.addBoth(rm, server_name)
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred): def get_server_verify_keys(self, verify_requests):
"""Takes a dict of KeyGroups and tries to find at least one key for """Takes a dict of KeyGroups and tries to find at least one key for
each group. each group.
""" """
@ -237,62 +246,64 @@ class Keyring(object):
merged_results = {} merged_results = {}
missing_keys = {} missing_keys = {}
for group in group_id_to_group.values(): for verify_request in verify_requests:
missing_keys.setdefault(group.server_name, set()).update( missing_keys.setdefault(verify_request.server_name, set()).update(
group.key_ids verify_request.key_ids
) )
for fn in key_fetch_fns: for fn in key_fetch_fns:
results = yield fn(missing_keys.items()) results = yield fn(missing_keys.items())
merged_results.update(results) merged_results.update(results)
# We now need to figure out which groups we have keys for # We now need to figure out which verify requests we have keys
# and which we don't # for and which we don't
missing_groups = {} missing_keys = {}
for group in group_id_to_group.values(): requests_missing_keys = []
for key_id in group.key_ids: for verify_request in verify_requests:
if key_id in merged_results[group.server_name]: server_name = verify_request.server_name
result_keys = merged_results[server_name]
if verify_request.deferred.called:
# We've already called this deferred, which probably
# means that we've already found a key for it.
continue
for key_id in verify_request.key_ids:
if key_id in result_keys:
with PreserveLoggingContext(): with PreserveLoggingContext():
group_id_to_deferred[group.group_id].callback(( verify_request.deferred.callback((
group.group_id, server_name,
group.server_name,
key_id, key_id,
merged_results[group.server_name][key_id], result_keys[key_id],
)) ))
break break
else: else:
missing_groups.setdefault( # The else block is only reached if the loop above
group.server_name, [] # doesn't break.
).append(group) missing_keys.setdefault(server_name, set()).update(
verify_request.key_ids
)
requests_missing_keys.append(verify_request)
if not missing_groups: if not missing_keys:
break break
missing_keys = { for verify_request in requests_missing_keys:
server_name: set( verify_request.deferred.errback(SynapseError(
key_id for group in groups for key_id in group.key_ids
)
for server_name, groups in missing_groups.items()
}
for group in missing_groups.values():
group_id_to_deferred[group.group_id].errback(SynapseError(
401, 401,
"No key for %s with id %s" % ( "No key for %s with id %s" % (
group.server_name, group.key_ids, verify_request.server_name, verify_request.key_ids,
), ),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
)) ))
def on_err(err): def on_err(err):
for deferred in group_id_to_deferred.values(): for verify_request in verify_requests:
if not deferred.called: if not verify_request.deferred.called:
deferred.errback(err) verify_request.deferred.errback(err)
do_iterations().addErrback(on_err) do_iterations().addErrback(on_err)
return group_id_to_deferred
@defer.inlineCallbacks @defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids): def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults( res = yield defer.gatherResults(
@ -356,7 +367,7 @@ class Keyring(object):
) )
except Exception as e: except Exception as e:
logger.info( logger.info(
"Unable to getting key %r for %r directly: %s %s", "Unable to get key %r for %r directly: %s %s",
key_ids, server_name, key_ids, server_name,
type(e).__name__, str(e.message), type(e).__name__, str(e.message),
) )
@ -418,7 +429,7 @@ class Keyring(object):
for response in responses: for response in responses:
if (u"signatures" not in response if (u"signatures" not in response
or perspective_name not in response[u"signatures"]): or perspective_name not in response[u"signatures"]):
raise ValueError( raise KeyLookupError(
"Key response not signed by perspective server" "Key response not signed by perspective server"
" %r" % (perspective_name,) " %r" % (perspective_name,)
) )
@ -441,13 +452,13 @@ class Keyring(object):
list(response[u"signatures"][perspective_name]), list(response[u"signatures"][perspective_name]),
list(perspective_keys) list(perspective_keys)
) )
raise ValueError( raise KeyLookupError(
"Response not signed with a known key for perspective" "Response not signed with a known key for perspective"
" server %r" % (perspective_name,) " server %r" % (perspective_name,)
) )
processed_response = yield self.process_v2_response( processed_response = yield self.process_v2_response(
perspective_name, response perspective_name, response, only_from_server=False
) )
for server_name, response_keys in processed_response.items(): for server_name, response_keys in processed_response.items():
@ -484,10 +495,10 @@ class Keyring(object):
if (u"signatures" not in response if (u"signatures" not in response
or server_name not in response[u"signatures"]): or server_name not in response[u"signatures"]):
raise ValueError("Key response not signed by remote server") raise KeyLookupError("Key response not signed by remote server")
if "tls_fingerprints" not in response: if "tls_fingerprints" not in response:
raise ValueError("Key response missing TLS fingerprints") raise KeyLookupError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate( certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate crypto.FILETYPE_ASN1, tls_certificate
@ -501,7 +512,7 @@ class Keyring(object):
response_sha256_fingerprints.add(fingerprint[u"sha256"]) response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints: if sha256_fingerprint_b64 not in response_sha256_fingerprints:
raise ValueError("TLS certificate not allowed by fingerprints") raise KeyLookupError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
from_server=server_name, from_server=server_name,
@ -527,7 +538,7 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response(self, from_server, response_json, def process_v2_response(self, from_server, response_json,
requested_ids=[]): requested_ids=[], only_from_server=True):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
response_keys = {} response_keys = {}
verify_keys = {} verify_keys = {}
@ -551,9 +562,16 @@ class Keyring(object):
results = {} results = {}
server_name = response_json["server_name"] server_name = response_json["server_name"]
if only_from_server:
if server_name != from_server:
raise KeyLookupError(
"Expected a response for server %r not %r" % (
from_server, server_name
)
)
for key_id in response_json["signatures"].get(server_name, {}): for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]: if key_id not in response_json["verify_keys"]:
raise ValueError( raise KeyLookupError(
"Key response must include verification keys for all" "Key response must include verification keys for all"
" signatures" " signatures"
) )
@ -621,15 +639,15 @@ class Keyring(object):
if ("signatures" not in response if ("signatures" not in response
or server_name not in response["signatures"]): or server_name not in response["signatures"]):
raise ValueError("Key response not signed by remote server") raise KeyLookupError("Key response not signed by remote server")
if "tls_certificate" not in response: if "tls_certificate" not in response:
raise ValueError("Key response missing TLS certificate") raise KeyLookupError("Key response missing TLS certificate")
tls_certificate_b64 = response["tls_certificate"] tls_certificate_b64 = response["tls_certificate"]
if encode_base64(x509_certificate_bytes) != tls_certificate_b64: if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise ValueError("TLS certificate doesn't match") raise KeyLookupError("TLS certificate doesn't match")
# Cache the result in the datastore. # Cache the result in the datastore.
@ -645,7 +663,7 @@ class Keyring(object):
for key_id in response["signatures"][server_name]: for key_id in response["signatures"][server_name]:
if key_id not in response["verify_keys"]: if key_id not in response["verify_keys"]:
raise ValueError( raise KeyLookupError(
"Key response must include verification keys for all" "Key response must include verification keys for all"
" signatures" " signatures"
) )


@ -88,6 +88,8 @@ def prune_event(event):
if "age_ts" in event.unsigned: if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"] allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
if "replaces_state" in event.unsigned:
allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
return type(event)( return type(event)(
allowed_fields, allowed_fields,


@ -31,6 +31,9 @@ logger = logging.getLogger(__name__)
class FederationBase(object): class FederationBase(object):
def __init__(self, hs):
pass
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
include_none=False): include_none=False):


@ -24,6 +24,7 @@ from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, CodeMessageException, HttpResponseException, SynapseError,
) )
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
@ -50,7 +51,33 @@ sent_edus_counter = metrics.register_counter("sent_edus")
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
PDU_RETRY_TIME_MS = 1 * 60 * 1000
class FederationClient(FederationBase): class FederationClient(FederationBase):
def __init__(self, hs):
super(FederationClient, self).__init__(hs)
self.pdu_destination_tried = {}
self._clock.looping_call(
self._clear_tried_cache, 60 * 1000,
)
def _clear_tried_cache(self):
"""Clear pdu_destination_tried cache"""
now = self._clock.time_msec()
old_dict = self.pdu_destination_tried
self.pdu_destination_tried = {}
for event_id, destination_dict in old_dict.items():
destination_dict = {
dest: time
for dest, time in destination_dict.items()
if time + PDU_RETRY_TIME_MS > now
}
if destination_dict:
self.pdu_destination_tried[event_id] = destination_dict
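A toy illustration of the expiry rule in _clear_tried_cache above: per-destination attempt timestamps older than PDU_RETRY_TIME_MS are dropped, and events left with no live entries disappear from the cache. Timestamps are invented.

PDU_RETRY_TIME_MS = 1 * 60 * 1000
now = 10 * 60 * 1000
tried = {
    "$event:a": {"remote1": now - 30 * 1000,         # recent attempt, kept
                 "remote2": now - 5 * 60 * 1000},    # stale attempt, dropped
}
cleared = {}
for event_id, dests in tried.items():
    dests = {d: t for d, t in dests.items() if t + PDU_RETRY_TIME_MS > now}
    if dests:
        cleared[event_id] = dests
assert cleared == {"$event:a": {"remote1": now - 30 * 1000}}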
def start_get_pdu_cache(self): def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache( self._get_pdu_cache = ExpiringCache(
@ -233,12 +260,19 @@ class FederationClient(FederationBase):
# TODO: Rate limit the number of times we try and get the same event. # TODO: Rate limit the number of times we try and get the same event.
if self._get_pdu_cache: if self._get_pdu_cache:
e = self._get_pdu_cache.get(event_id) ev = self._get_pdu_cache.get(event_id)
if e: if ev:
defer.returnValue(e) defer.returnValue(ev)
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
pdu = None pdu = None
for destination in destinations: for destination in destinations:
now = self._clock.time_msec()
last_attempt = pdu_attempts.get(destination, 0)
if last_attempt + PDU_RETRY_TIME_MS > now:
continue
try: try:
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
destination, destination,
@ -266,25 +300,19 @@ class FederationClient(FederationBase):
break break
except SynapseError: pdu_attempts[destination] = now
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
except SynapseError as e:
logger.info( logger.info(
"Failed to get PDU %s from %s because %s", "Failed to get PDU %s from %s because %s",
event_id, destination, e, event_id, destination, e,
) )
continue
except NotRetryingDestination as e: except NotRetryingDestination as e:
logger.info(e.message) logger.info(e.message)
continue continue
except Exception as e: except Exception as e:
pdu_attempts[destination] = now
logger.info( logger.info(
"Failed to get PDU %s from %s because %s", "Failed to get PDU %s from %s because %s",
event_id, destination, e, event_id, destination, e,
@ -311,6 +339,42 @@ class FederationClient(FederationBase):
Deferred: Results in a list of PDUs. Deferred: Results in a list of PDUs.
""" """
try:
# First we try and ask for just the IDs, as that's far quicker if
# we have most of the state and auth_chain already.
# However, this may 404 if the other side has an old synapse.
result = yield self.transport_layer.get_room_state_ids(
destination, room_id, event_id=event_id,
)
state_event_ids = result["pdu_ids"]
auth_event_ids = result.get("auth_chain_ids", [])
fetched_events, failed_to_fetch = yield self.get_events(
[destination], room_id, set(state_event_ids + auth_event_ids)
)
if failed_to_fetch:
logger.warn("Failed to get %r", failed_to_fetch)
event_map = {
ev.event_id: ev for ev in fetched_events
}
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
auth_chain = [
event_map[e_id] for e_id in auth_event_ids if e_id in event_map
]
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue((pdus, auth_chain))
except HttpResponseException as e:
if e.code == 400 or e.code == 404:
logger.info("Failed to use get_room_state_ids API, falling back")
else:
raise e
result = yield self.transport_layer.get_room_state( result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id, destination, room_id, event_id=event_id,
) )
@ -324,18 +388,93 @@ class FederationClient(FederationBase):
for p in result.get("auth_chain", []) for p in result.get("auth_chain", [])
] ]
seen_events = yield self.store.get_events([
ev.event_id for ev in itertools.chain(pdus, auth_chain)
])
signed_pdus = yield self._check_sigs_and_hash_and_fetch( signed_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, pdus, outlier=True destination,
[p for p in pdus if p.event_id not in seen_events],
outlier=True
)
signed_pdus.extend(
seen_events[p.event_id] for p in pdus if p.event_id in seen_events
) )
signed_auth = yield self._check_sigs_and_hash_and_fetch( signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True destination,
[p for p in auth_chain if p.event_id not in seen_events],
outlier=True
)
signed_auth.extend(
seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
) )
signed_auth.sort(key=lambda e: e.depth) signed_auth.sort(key=lambda e: e.depth)
defer.returnValue((signed_pdus, signed_auth)) defer.returnValue((signed_pdus, signed_auth))
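A toy, dict-based illustration of the stitching above: state event ids are resolved against a map built from freshly fetched events plus events already in the local database, keeping the order of the id list. Ids and payloads are invented.

state_event_ids = ["$a", "$b", "$c"]
seen_events = {"$b": {"event_id": "$b", "from_local_db": True}}
fetched_events = [{"event_id": "$a"}, {"event_id": "$c"}]

event_map = {ev["event_id"]: ev for ev in fetched_events}
event_map.update(seen_events)
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
assert [p["event_id"] for p in pdus] == ["$a", "$b", "$c"]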
@defer.inlineCallbacks
def get_events(self, destinations, room_id, event_ids, return_local=True):
"""Fetch events from some remote destinations, checking if we already
have them.
Args:
destinations (list)
room_id (str)
event_ids (list)
return_local (bool): Whether to include events we already have in
the DB in the returned list of events
Returns:
Deferred: A deferred resolving to a 2-tuple where the first is a list of
events and the second is a list of event ids that we failed to fetch.
"""
if return_local:
seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
signed_events = seen_events.values()
else:
seen_events = yield self.store.have_events(event_ids)
signed_events = []
failed_to_fetch = set()
missing_events = set(event_ids)
for k in seen_events:
missing_events.discard(k)
if not missing_events:
defer.returnValue((signed_events, failed_to_fetch))
def random_server_list():
srvs = list(destinations)
random.shuffle(srvs)
return srvs
batch_size = 20
missing_events = list(missing_events)
for i in xrange(0, len(missing_events), batch_size):
batch = set(missing_events[i:i + batch_size])
deferreds = [
self.get_pdu(
destinations=random_server_list(),
event_id=e_id,
)
for e_id in batch
]
res = yield defer.DeferredList(deferreds, consumeErrors=True)
for success, result in res:
if success:
signed_events.append(result)
batch.discard(result.event_id)
# We removed all events we successfully fetched from `batch`
failed_to_fetch.update(batch)
defer.returnValue((signed_events, failed_to_fetch))
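The batching strategy in get_events can be pictured with plain values: missing event ids are taken twenty at a time and each batch is requested against a freshly shuffled copy of the destination list. Server names and counts below are invented.

import random

missing_events = ["$ev%d" % i for i in range(45)]
destinations = ["one.example.com", "two.example.com"]
batch_size = 20

for i in xrange(0, len(missing_events), batch_size):
    batch = missing_events[i:i + batch_size]
    servers = list(destinations)
    random.shuffle(servers)    # spread load across the candidate servers
    # each event id in `batch` would then be fetched via get_pdu(servers, e_id)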
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_event_auth(self, destination, room_id, event_id): def get_event_auth(self, destination, room_id, event_id):
@ -411,14 +550,19 @@ class FederationClient(FederationBase):
(destination, self.event_from_pdu_json(pdu_dict)) (destination, self.event_from_pdu_json(pdu_dict))
) )
break break
except CodeMessageException: except CodeMessageException as e:
raise if not 500 <= e.code < 600:
raise
else:
logger.warn(
"Failed to make_%s via %s: %s",
membership, destination, e.message
)
except Exception as e: except Exception as e:
logger.warn( logger.warn(
"Failed to make_%s via %s: %s", "Failed to make_%s via %s: %s",
membership, destination, e.message membership, destination, e.message
) )
raise
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")
@ -490,8 +634,14 @@ class FederationClient(FederationBase):
"auth_chain": signed_auth, "auth_chain": signed_auth,
"origin": destination, "origin": destination,
}) })
except CodeMessageException: except CodeMessageException as e:
raise if not 500 <= e.code < 600:
raise
else:
logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
except Exception as e: except Exception as e:
logger.exception( logger.exception(
"Failed to send_join via %s: %s", "Failed to send_join via %s: %s",
@ -550,6 +700,25 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.") raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def get_public_rooms(self, destinations):
results_by_server = {}
@defer.inlineCallbacks
def _get_result(s):
if s == self.server_name:
defer.returnValue()
try:
result = yield self.transport_layer.get_public_rooms(s)
results_by_server[s] = result
except:
logger.exception("Error getting room list from server %r", s)
yield concurrently_execute(_get_result, destinations, 3)
defer.returnValue(results_by_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth): def query_auth(self, destination, room_id, event_id, local_auth):
""" """


@ -19,11 +19,13 @@ from twisted.internet import defer
from .federation_base import FederationBase from .federation_base import FederationBase
from .units import Transaction, Edu from .units import Transaction, Edu
from synapse.util.async import Linearizer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
from synapse.api.errors import FederationError, SynapseError from synapse.api.errors import AuthError, FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
@ -44,6 +46,18 @@ received_queries_counter = metrics.register_counter("received_queries", labels=[
class FederationServer(FederationBase): class FederationServer(FederationBase):
def __init__(self, hs):
super(FederationServer, self).__init__(hs)
self.auth = hs.get_auth()
self._room_pdu_linearizer = Linearizer()
self._server_linearizer = Linearizer()
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
def set_handler(self, handler): def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate """Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are receipt of new PDUs from other home servers. The required methods are
@ -83,11 +97,14 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_backfill_request(self, origin, room_id, versions, limit): def on_backfill_request(self, origin, room_id, versions, limit):
pdus = yield self.handler.on_backfill_request( with (yield self._server_linearizer.queue((origin, room_id))):
origin, room_id, versions, limit pdus = yield self.handler.on_backfill_request(
) origin, room_id, versions, limit
)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) res = self._transaction_from_pdus(pdus).get_dict()
defer.returnValue((200, res))
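A hedged sketch of the Linearizer pattern now wrapped around these handlers: requests that share a key, here (origin, room_id), run one at a time while requests for other keys proceed concurrently. do_expensive_work is a hypothetical stand-in for the real handler body.

from twisted.internet import defer
from synapse.util.async import Linearizer

linearizer = Linearizer()

@defer.inlineCallbacks
def handle(origin, room_id):
    with (yield linearizer.queue((origin, room_id))):
        # only one handler per (origin, room_id) is inside this block at a time
        yield do_expensive_work(origin, room_id)    # hypothetical helper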
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -178,15 +195,59 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_context_state_request(self, origin, room_id, event_id): def on_context_state_request(self, origin, room_id, event_id):
if event_id: if not event_id:
pdus = yield self.handler.get_state_for_pdu( raise NotImplementedError("Specify an event")
origin, room_id, event_id,
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
for event in auth_chain: in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
result = self._state_resp_cache.get((room_id, event_id))
if not result:
with (yield self._server_linearizer.queue((origin, room_id))):
resp = yield self._state_resp_cache.set(
(room_id, event_id),
self._on_context_state_request_compute(room_id, event_id)
)
else:
resp = yield result
defer.returnValue((200, resp))
@defer.inlineCallbacks
def on_state_ids_request(self, origin, room_id, event_id):
if not event_id:
raise NotImplementedError("Specify an event")
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
pdus = yield self.handler.get_state_for_pdu(
room_id, event_id,
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
defer.returnValue((200, {
"pdu_ids": [pdu.event_id for pdu in pdus],
"auth_chain_ids": [pdu.event_id for pdu in auth_chain],
}))
@defer.inlineCallbacks
def _on_context_state_request_compute(self, room_id, event_id):
pdus = yield self.handler.get_state_for_pdu(
room_id, event_id,
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
for event in auth_chain:
# We sign these again because there was a bug where we
# incorrectly signed things the first time round
if self.hs.is_mine_id(event.event_id):
event.signatures.update( event.signatures.update(
compute_event_signature( compute_event_signature(
event, event,
@ -194,13 +255,11 @@ class FederationServer(FederationBase):
self.hs.config.signing_key[0] self.hs.config.signing_key[0]
) )
) )
else:
raise NotImplementedError("Specify an event")
defer.returnValue((200, { defer.returnValue({
"pdus": [pdu.get_pdu_json() for pdu in pdus], "pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
})) })
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -274,14 +333,16 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id): def on_event_auth(self, origin, room_id, event_id):
time_now = self._clock.time_msec() with (yield self._server_linearizer.queue((origin, room_id))):
auth_pdus = yield self.handler.on_event_auth(event_id) time_now = self._clock.time_msec()
defer.returnValue((200, { auth_pdus = yield self.handler.on_event_auth(event_id)
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], res = {
})) "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
}
defer.returnValue((200, res))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_auth_request(self, origin, content, event_id): def on_query_auth_request(self, origin, content, room_id, event_id):
""" """
Content is a dict with keys:: Content is a dict with keys::
auth_chain (list): A list of events that give the auth chain. auth_chain (list): A list of events that give the auth chain.
@ -300,58 +361,41 @@ class FederationServer(FederationBase):
Returns: Returns:
Deferred: Results in `dict` with the same format as `content` Deferred: Results in `dict` with the same format as `content`
""" """
auth_chain = [ with (yield self._server_linearizer.queue((origin, room_id))):
self.event_from_pdu_json(e) auth_chain = [
for e in content["auth_chain"] self.event_from_pdu_json(e)
] for e in content["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch( signed_auth = yield self._check_sigs_and_hash_and_fetch(
origin, auth_chain, outlier=True origin, auth_chain, outlier=True
) )
ret = yield self.handler.on_query_auth( ret = yield self.handler.on_query_auth(
origin, origin,
event_id, event_id,
signed_auth, signed_auth,
content.get("rejects", []), content.get("rejects", []),
content.get("missing", []), content.get("missing", []),
) )
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
send_content = { send_content = {
"auth_chain": [ "auth_chain": [
e.get_pdu_json(time_now) e.get_pdu_json(time_now)
for e in ret["auth_chain"] for e in ret["auth_chain"]
], ],
"rejects": ret.get("rejects", []), "rejects": ret.get("rejects", []),
"missing": ret.get("missing", []), "missing": ret.get("missing", []),
} }
defer.returnValue( defer.returnValue(
(200, send_content) (200, send_content)
) )
@defer.inlineCallbacks
@log_function @log_function
def on_query_client_keys(self, origin, content): def on_query_client_keys(self, origin, content):
query = [] return self.on_query_request("client_keys", content)
for user_id, device_ids in content.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
defer.returnValue({"device_keys": json_result})
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -377,11 +421,24 @@ class FederationServer(FederationBase):
@log_function @log_function
def on_get_missing_events(self, origin, room_id, earliest_events, def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth): latest_events, limit, min_depth):
missing_events = yield self.handler.on_get_missing_events( with (yield self._server_linearizer.queue((origin, room_id))):
origin, room_id, earliest_events, latest_events, limit, min_depth logger.info(
) "on_get_missing_events: earliest_events: %r, latest_events: %r,"
" limit: %d, min_depth: %d",
earliest_events, latest_events, limit, min_depth
)
missing_events = yield self.handler.on_get_missing_events(
origin, room_id, earliest_events, latest_events, limit, min_depth
)
time_now = self._clock.time_msec() if len(missing_events) < 5:
logger.info(
"Returning %d events: %r", len(missing_events), missing_events
)
else:
logger.info("Returning %d events", len(missing_events))
time_now = self._clock.time_msec()
defer.returnValue({ defer.returnValue({
"events": [ev.get_pdu_json(time_now) for ev in missing_events], "events": [ev.get_pdu_json(time_now) for ev in missing_events],
@ -481,42 +538,59 @@ class FederationServer(FederationBase):
pdu.internal_metadata.outlier = True pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth: elif min_depth and pdu.depth > min_depth:
if get_missing and prevs - seen: if get_missing and prevs - seen:
latest = yield self.store.get_latest_event_ids_in_room( # If we're missing stuff, ensure we only fetch stuff one
pdu.room_id # at a time.
) with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
# We recalculate seen, since it may have changed.
have_seen = yield self.store.have_events(prevs)
seen = set(have_seen.keys())
# We add the prev events that we have seen to the latest if prevs - seen:
# list to ensure the remote server doesn't give them to us latest = yield self.store.get_latest_event_ids_in_room(
latest = set(latest) pdu.room_id
latest |= seen )
missing_events = yield self.get_missing_events( # We add the prev events that we have seen to the latest
origin, # list to ensure the remote server doesn't give them to us
pdu.room_id, latest = set(latest)
earliest_events_ids=list(latest), latest |= seen
latest_events=[pdu],
limit=10,
min_depth=min_depth,
)
# We want to sort these by depth so we process them and logger.info(
# tell clients about them in order. "Missing %d events for room %r: %r...",
missing_events.sort(key=lambda x: x.depth) len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
)
for e in missing_events: missing_events = yield self.get_missing_events(
yield self._handle_new_pdu( origin,
origin, pdu.room_id,
e, earliest_events_ids=list(latest),
get_missing=False latest_events=[pdu],
) limit=10,
min_depth=min_depth,
)
have_seen = yield self.store.have_events( # We want to sort these by depth so we process them and
[ev for ev, _ in pdu.prev_events] # tell clients about them in order.
) missing_events.sort(key=lambda x: x.depth)
for e in missing_events:
yield self._handle_new_pdu(
origin,
e,
get_missing=False
)
have_seen = yield self.store.have_events(
[ev for ev, _ in pdu.prev_events]
)
prevs = {e_id for e_id, _ in pdu.prev_events} prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys()) seen = set(have_seen.keys())
if prevs - seen: if prevs - seen:
logger.info(
"Still missing %d events for room %r: %r...",
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
)
fetch_state = True fetch_state = True
if fetch_state: if fetch_state:
@ -531,7 +605,7 @@ class FederationServer(FederationBase):
origin, pdu.room_id, pdu.event_id, origin, pdu.room_id, pdu.event_id,
) )
except: except:
logger.warn("Failed to get state for event: %s", pdu.event_id) logger.exception("Failed to get state for event: %s", pdu.event_id)
yield self.handler.on_receive_pdu( yield self.handler.on_receive_pdu(
origin, origin,


@ -72,5 +72,7 @@ class ReplicationLayer(FederationClient, FederationServer):
self.hs = hs self.hs = hs
super(ReplicationLayer, self).__init__(hs)
def __str__(self): def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name


@ -21,11 +21,11 @@ from .units import Transaction
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.retryutils import ( from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination, get_retry_limiter, NotRetryingDestination,
) )
from synapse.util.metrics import measure_func
import synapse.metrics import synapse.metrics
import logging import logging
@ -51,7 +51,7 @@ class TransactionQueue(object):
self.transport_layer = transport_layer self.transport_layer = transport_layer
self._clock = hs.get_clock() self.clock = hs.get_clock()
# Is a mapping from destinations -> deferreds. Used to keep track # Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are # of which destinations have transactions in flight and when they are
@ -82,7 +82,7 @@ class TransactionQueue(object):
self.pending_failures_by_dest = {} self.pending_failures_by_dest = {}
# HACK to get unique tx id # HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec()) self._next_txn_id = int(self.clock.time_msec())
def can_send_to(self, destination): def can_send_to(self, destination):
"""Can we send messages to the given server? """Can we send messages to the given server?
@ -119,266 +119,215 @@ class TransactionQueue(object):
if not destinations: if not destinations:
return return
deferreds = []
for destination in destinations: for destination in destinations:
deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append( self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, deferred, order) (pdu, order)
) )
def chain(failure): preserve_context_over_fn(
if not deferred.called: self._attempt_new_transaction, destination
deferred.errback(failure) )
def log_failure(f):
logger.warn("Failed to send pdu to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
deferreds.append(deferred)
# NO inlineCallbacks
def enqueue_edu(self, edu): def enqueue_edu(self, edu):
destination = edu.destination destination = edu.destination
if not self.can_send_to(destination): if not self.can_send_to(destination):
return return
deferred = defer.Deferred() self.pending_edus_by_dest.setdefault(destination, []).append(edu)
self.pending_edus_by_dest.setdefault(destination, []).append(
(edu, deferred) preserve_context_over_fn(
self._attempt_new_transaction, destination
) )
def chain(failure):
if not deferred.called:
deferred.errback(failure)
def log_failure(f):
logger.warn("Failed to send edu to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
return deferred
@defer.inlineCallbacks
def enqueue_failure(self, failure, destination): def enqueue_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost": if destination == self.server_name or destination == "localhost":
return return
deferred = defer.Deferred()
if not self.can_send_to(destination): if not self.can_send_to(destination):
return return
self.pending_failures_by_dest.setdefault( self.pending_failures_by_dest.setdefault(
destination, [] destination, []
).append( ).append(failure)
(failure, deferred)
preserve_context_over_fn(
self._attempt_new_transaction, destination
) )
def chain(f):
if not deferred.called:
deferred.errback(f)
def log_failure(f):
logger.warn("Failed to send failure to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
yield deferred
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
yield run_on_reactor()
while True:
# list of (pending_pdu, deferred, order)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
logger.debug(
"TX [%s] Transaction already in progress",
destination
)
return

pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])

if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))

if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
return

yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures
)

@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures):

# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = pending_edus
failures = [x.get_dict() for x in pending_failures]

try:
self.pending_transactions[destination] = 1

logger.debug("TX [%s] _attempt_new_transaction", destination)

txn_id = str(self._next_txn_id)

limiter = yield get_retry_limiter(
destination,
self.clock,
self.store,
)

logger.debug(
"TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)

logger.debug("TX [%s] Persisting transaction...", destination)

transaction = Transaction.create_new(
origin_server_ts=int(self.clock.time_msec()),
transaction_id=txn_id,
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)

self._next_txn_id += 1

yield self.transaction_actions.prepare_to_send(transaction)

logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
transaction.transaction_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures),
)

with limiter:
# Actually send the transaction

# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self.clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data

try:
response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
code = 200

if response:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warn(
"Transaction returned error for %s: %s",
e_id, r,
)
except HttpResponseException as e:
code = e.code
response = e.response

logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)

logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)

yield self.transaction_actions.delivered(
transaction, code, response
)

logger.debug("TX [%s] Marked as delivered", destination)

if code != 200:
for p in pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, destination
)
except NotRetryingDestination:
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)

for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)

for p in pdus:
logger.info("Failed to send event %s to %s", p.event_id, destination)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
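The age_ts to age rewriting that json_data_cb performs just before the transaction is serialised can be illustrated on its own. A minimal, self-contained sketch; the sample event and timestamps are made up for illustration:

import time

def add_age(pdu_dict, now_ms=None):
    # Convert an absolute "age_ts" (origin timestamp) into a relative "age"
    # in milliseconds, mirroring json_data_cb above.
    now = int(now_ms if now_ms is not None else time.time() * 1000)
    if "age_ts" in pdu_dict:
        unsigned = pdu_dict.setdefault("unsigned", {})
        unsigned["age"] = now - int(pdu_dict["age_ts"])
        del pdu_dict["age_ts"]
    return pdu_dict

# add_age({"event_id": "$abc:example.com", "age_ts": 1470000000000}, 1470000005000)
# -> {"event_id": "$abc:example.com", "unsigned": {"age": 5000}}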

@ -54,6 +54,28 @@ class TransportLayerClient(object):
destination, path=path, args={"event_id": event_id},
)
@log_function
def get_room_state_ids(self, destination, room_id, event_id):
""" Requests all state for a given room from the given server at the
given event. Returns the state's event_ids.
Args:
destination (str): The host name of the remote home server we want
to get the state from.
room_id (str): The ID of the room we want the state of.
event_id (str): The event we want the context at.
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug("get_room_state_ids dest=%s, room=%s",
destination, room_id)
path = PREFIX + "/state_ids/%s/" % room_id
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
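A hypothetical caller of this new /state_ids/ client method might look like the sketch below; the helper name and variables are illustrative and not part of this change:

from twisted.internet import defer

@defer.inlineCallbacks
def fetch_state_ids(transport, destination, room_id, event_id):
    # "transport" is assumed to be a TransportLayerClient instance.
    result = yield transport.get_room_state_ids(
        destination, room_id, event_id=event_id
    )
    # The response describes the room state at that event as event IDs,
    # rather than the full events that /state/ returns.
    defer.returnValue(result)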
@log_function
def get_event(self, destination, event_id, timeout=None):
""" Requests the pdu with the given id and origin from the given server.
@ -224,6 +246,18 @@ class TransportLayerClient(object):
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def get_public_rooms(self, remote_server):
path = PREFIX + "/publicRooms"
response = yield self.client.get_json(
destination=remote_server,
path=path,
)
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):

@ -18,13 +18,14 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string

import functools
import logging
import re

import synapse

logger = logging.getLogger(__name__)
@ -37,7 +38,7 @@ class TransportLayerServer(JsonResource):
self.hs = hs
self.clock = hs.get_clock()

super(TransportLayerServer, self).__init__(hs, canonical_json=False)

self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
@ -60,6 +61,16 @@ class TransportLayerServer(JsonResource):
) )
class AuthenticationError(SynapseError):
"""There was a problem authenticating the request"""
pass
class NoAuthenticationError(AuthenticationError):
"""The request had no authentication information"""
pass
class Authenticator(object):
def __init__(self, hs):
self.keyring = hs.get_keyring()
@ -67,7 +78,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request, content):
json_request = {
"method": request.method,
"uri": request.uri,
@ -75,17 +86,10 @@ class Authenticator(object):
"signatures": {},
}

if content is not None:
json_request["content"] = content

origin = None
def parse_auth_header(header_str):
try:
@ -103,14 +107,14 @@ class Authenticator(object):
sig = strip_quotes(param_dict["sig"])
return (origin, key, sig)
except:
raise AuthenticationError(
400, "Malformed Authorization header", Codes.UNAUTHORIZED
)

auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")

if not auth_headers:
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
@ -121,7 +125,7 @@ class Authenticator(object):
json_request["signatures"].setdefault(origin, {})[key] = sig

if not json_request["signatures"]:
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
@ -130,38 +134,59 @@ class Authenticator(object):
logger.info("Request from %s", origin)
request.authenticated_entity = origin

defer.returnValue(origin)
class BaseFederationServlet(object):
REQUIRE_AUTH = True

def __init__(self, handler, authenticator, ratelimiter, server_name,
room_list_handler):
self.handler = handler
self.authenticator = authenticator
self.ratelimiter = ratelimiter
self.room_list_handler = room_list_handler

def _wrap(self, func):
authenticator = self.authenticator
ratelimiter = self.ratelimiter

@defer.inlineCallbacks
@functools.wraps(func)
def new_func(request, *args, **kwargs):
content = None
if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
content = parse_json_object_from_request(request)

try:
origin = yield authenticator.authenticate_request(request, content)
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
logger.exception("authenticate_request failed")
raise
except:
logger.exception("authenticate_request failed")
raise

if origin:
with ratelimiter.ratelimit(origin) as d:
yield d
response = yield func(
origin, content, request.args, *args, **kwargs
)
else:
response = yield func(
origin, content, request.args, *args, **kwargs
)

defer.returnValue(response)

# Extra logic that functools.wraps() doesn't finish
new_func.__self__ = func.__self__

return new_func
def register(self, server): def register(self, server):
pattern = re.compile("^" + PREFIX + self.PATH + "$") pattern = re.compile("^" + PREFIX + self.PATH + "$")
@ -269,6 +294,17 @@ class FederationStateServlet(BaseFederationServlet):
) )
class FederationStateIdsServlet(BaseFederationServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/"
def on_GET(self, origin, content, query, room_id):
return self.handler.on_state_ids_request(
origin,
room_id,
query.get("event_id", [None])[0],
)
class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/"
@ -365,10 +401,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query"

def on_POST(self, origin, content, query):
return self.handler.on_query_client_keys(origin, content)

class FederationClientKeysClaimServlet(BaseFederationServlet):
@ -386,7 +420,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_POST(self, origin, content, query, context, event_id):
new_content = yield self.handler.on_query_auth_request(
origin, content, context, event_id
)
defer.returnValue((200, new_content))
@ -418,9 +452,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
class On3pidBindServlet(BaseFederationServlet):
PATH = "/3pid/onbind"

REQUIRE_AUTH = False

@defer.inlineCallbacks
def on_POST(self, origin, content, query):
if "invites" in content:
last_exception = None
for invite in content["invites"]:
@ -442,11 +477,6 @@ class On3pidBindServlet(BaseFederationServlet):
raise last_exception
defer.returnValue((200, {}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
class OpenIdUserInfo(BaseFederationServlet):
"""
@ -467,9 +497,11 @@ class OpenIdUserInfo(BaseFederationServlet):
PATH = "/openid/userinfo"

REQUIRE_AUTH = False

@defer.inlineCallbacks
def on_GET(self, origin, content, query):
token = query.get("access_token", [None])[0]
if token is None:
defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
@ -486,10 +518,58 @@ class OpenIdUserInfo(BaseFederationServlet):
defer.returnValue((200, {"sub": user_id}))

class PublicRoomList(BaseFederationServlet):
"""
Fetch the public room list for this server.
This API returns information in the same format as /publicRooms on the
client API, but will only ever include local public rooms and hence is
intended for consumption by other home servers.
GET /publicRooms HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
{
"chunk": [
{
"aliases": [
"#test:localhost"
],
"guest_can_join": false,
"name": "test room",
"num_joined_members": 3,
"room_id": "!whkydVegtvatLfXmPN:localhost",
"world_readable": false
}
],
"end": "END",
"start": "START"
}
"""
PATH = "/publicRooms"
@defer.inlineCallbacks
def on_GET(self, origin, content, query):
data = yield self.room_list_handler.get_local_public_room_list()
defer.returnValue((200, data))
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
REQUIRE_AUTH = False
def on_GET(self, origin, content, query):
return defer.succeed((200, {
"server": {
"name": "Synapse",
"version": get_version_string(synapse)
},
}))
SERVLET_CLASSES = (
@ -497,6 +577,7 @@ SERVLET_CLASSES = (
FederationPullServlet,
FederationEventServlet,
FederationStateServlet,
FederationStateIdsServlet,
FederationBackfillServlet,
FederationQueryServlet,
FederationMakeJoinServlet,
@ -513,6 +594,8 @@ SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
OpenIdUserInfo,
PublicRoomList,
FederationVersionServlet,
)
@ -523,4 +606,5 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
room_list_handler=hs.get_room_list_handler(),
).register(resource)

@ -13,11 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.appservice.scheduler import AppServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler
from .room import (
RoomCreationHandler, RoomContextHandler,
)
from .room_member import RoomMemberHandler
from .message import MessageHandler
@ -26,8 +24,6 @@ from .federation import FederationHandler
from .profile import ProfileHandler
from .directory import DirectoryHandler
from .admin import AdminHandler
from .identity import IdentityHandler
from .receipts import ReceiptsHandler
from .search import SearchHandler
@ -35,10 +31,21 @@ from .search import SearchHandler
class Handlers(object):

""" Deprecated. A collection of handlers.

At some point most of the classes whose name ended "Handler" were
accessed through this class.

However this makes it painful to unit test the handlers and to run cut
down versions of synapse that only use specific handlers because using a
single handler required creating all of the handlers. So some of the
handlers have been lifted out of the Handlers object and are now accessed
directly through the homeserver object itself.

Any new handlers should follow the new pattern of being accessed through
the homeserver object and should not be added to the Handlers object.

The remaining handlers should be moved out of the handlers object.
"""
def __init__(self, hs): def __init__(self, hs):
@ -50,19 +57,9 @@ class Handlers(object):
self.event_handler = EventHandler(hs) self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs) self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs) self.profile_handler = ProfileHandler(hs)
self.room_list_handler = RoomListHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs) self.receipts_handler = ReceiptsHandler(hs)
asapi = ApplicationServiceApi(hs)
self.appservice_handler = ApplicationServicesHandler(
hs, asapi, AppServiceScheduler(
clock=hs.get_clock(),
store=hs.get_datastore(),
as_api=asapi
)
)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs) self.identity_handler = IdentityHandler(hs)
self.search_handler = SearchHandler(hs) self.search_handler = SearchHandler(hs)
self.room_context_handler = RoomContextHandler(hs) self.room_context_handler = RoomContextHandler(hs)

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from twisted.internet import defer

import synapse.types
from synapse.api.constants import Membership, EventTypes
from synapse.api.errors import LimitExceededError
from synapse.types import UserID

logger = logging.getLogger(__name__)
@ -31,11 +31,15 @@ class BaseHandler(object):
Common base class for the event handlers. Common base class for the event handlers.
Attributes: Attributes:
store (synapse.storage.events.StateStore): store (synapse.storage.DataStore):
state_handler (synapse.state.StateHandler): state_handler (synapse.state.StateHandler):
""" """
def __init__(self, hs): def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@ -120,7 +124,8 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
requester = synapse.types.create_requester(
target_user, is_guest=True)
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,

@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.types import UserID
import logging import logging
@ -35,16 +34,13 @@ def log_failure(failure):
) )
# NB: Purposefully not inheriting BaseHandler since that contains way too much
# setup code which this handler does not need or use. This makes testing a lot
# easier.
class ApplicationServicesHandler(object):

def __init__(self, hs):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
self.appservice_api = hs.get_application_service_api()
self.scheduler = hs.get_application_service_scheduler()
self.started_scheduler = False
@defer.inlineCallbacks
@ -169,8 +165,7 @@ class ApplicationServicesHandler(object):
@defer.inlineCallbacks
def _is_unknown_user(self, user_id):
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
defer.returnValue(False)

@ -18,8 +18,9 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
from synapse.util.async import run_on_reactor
from synapse.config.ldap import LDAPMode

from twisted.web.client import PartialDownloadError
@ -28,6 +29,12 @@ import bcrypt
import pymacaroons import pymacaroons
import simplejson import simplejson
try:
import ldap3
except ImportError:
ldap3 = None
pass
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
@ -38,6 +45,10 @@ class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
def __init__(self, hs): def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
super(AuthHandler, self).__init__(hs) super(AuthHandler, self).__init__(hs)
self.checkers = { self.checkers = {
LoginType.PASSWORD: self._check_password_auth, LoginType.PASSWORD: self._check_password_auth,
@ -50,19 +61,23 @@ class AuthHandler(BaseHandler):
self.INVALID_TOKEN_HTTP_STATUS = 401

self.ldap_enabled = hs.config.ldap_enabled
if self.ldap_enabled:
if not ldap3:
raise RuntimeError(
'Missing ldap3 library. This is required for LDAP Authentication.'
)
self.ldap_mode = hs.config.ldap_mode
self.ldap_uri = hs.config.ldap_uri
self.ldap_start_tls = hs.config.ldap_start_tls
self.ldap_base = hs.config.ldap_base
self.ldap_filter = hs.config.ldap_filter
self.ldap_attributes = hs.config.ldap_attributes
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = hs.config.ldap_bind_dn
self.ldap_bind_password = hs.config.ldap_bind_password

self.hs = hs  # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()

@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
@ -220,7 +235,6 @@ class AuthHandler(BaseHandler):
sess = self._get_session_info(session_id) sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default) return sess.setdefault('serverdict', {}).get(key, default)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _): def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict: if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM) raise LoginError(400, "", Codes.MISSING_PARAM)
@ -230,11 +244,7 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'): if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string() user_id = UserID.create(user_id, self.hs.hostname).to_string()
if not (yield self._check_password(user_id, password)): return self._check_password(user_id, password)
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip): def _check_recaptcha(self, authdict, clientip):
@ -270,8 +280,17 @@ class AuthHandler(BaseHandler):
data = pde.response data = pde.response
resp_body = simplejson.loads(data) resp_body = simplejson.loads(data)
if 'success' in resp_body and resp_body['success']: if 'success' in resp_body:
defer.returnValue(True) # Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
"Successful" if resp_body['success'] else "Failed",
resp_body.get('hostname')
)
if resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -338,67 +357,84 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id] return self.sessions[session_id]
def validate_password_login(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.

Args:
user_id (str): complete @user:id
password (str): Password
Returns:
defer.Deferred: (str) canonical user id
Raises:
StoreError if there was a problem accessing the database
LoginError if there was an authentication problem.
"""
return self._check_password(user_id, password)
@defer.inlineCallbacks
def get_login_tuple_for_user_id(self, user_id, device_id=None,
initial_display_name=None):
"""
Gets login tuple for the user with the given user ID.

Creates a new access/refresh token for the user.

The user is assumed to have been authenticated by some other
machanism (e.g. CAS), and the user_id converted to the canonical case.

The device will be recorded in the table if it is not there already.

Args:
user_id (str): canonical User ID
device_id (str|None): the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
initial_display_name (str): display name to associate with the
device if it needs re-registering
Returns:
A tuple of:
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
logger.info("Logging in user %s on device %s", user_id, device_id)
access_token = yield self.issue_access_token(user_id, device_id)
refresh_token = yield self.issue_refresh_token(user_id, device_id)

# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
# really don't want is active access_tokens without a record of the
# device, so we double-check it here.
if device_id is not None:
yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)

defer.returnValue((access_token, refresh_token))

@defer.inlineCallbacks
def check_user_exists(self, user_id):
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.

Args:
(str) user_id: complete @user:id

Returns:
defer.Deferred: (str) canonical_user_id, or None if zero or
multiple matches
"""
try:
res = yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(res[0])
except LoginError:
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id): def _find_user_id_and_pwd_hash(self, user_id):
@ -428,84 +464,232 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def _check_password(self, user_id, password):
"""Authenticate a user against the LDAP and local databases.

user_id is checked case insensitively against the local database, but
will throw if there are multiple inexact matches.

Args:
user_id (str): complete @user:id
Returns:
(str) the canonical_user_id
Raises:
LoginError if the password was incorrect
"""
valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap:
defer.returnValue(user_id)

result = yield self._check_local_password(user_id, password)
defer.returnValue(result)

@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
"""Authenticate a user against the local password database.

user_id is checked case insensitively, but will throw if there are
multiple inexact matches.

Args:
user_id (str): complete @user:id
Returns:
(str) the canonical_user_id
Raises:
LoginError if the password was incorrect
"""
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
result = self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
""" Attempt to authenticate a user against an LDAP Server
and register an account if none exists.

Returns:
True if authentication against LDAP was successful
"""

if not ldap3 or not self.ldap_enabled:
defer.returnValue(False)

if self.ldap_mode not in LDAPMode.LIST:
raise RuntimeError(
'Invalid ldap mode specified: {mode}'.format(
mode=self.ldap_mode
)
)

logger.info("Authenticating %s with LDAP" % user_id)
try:
server = ldap3.Server(self.ldap_uri)
logger.debug(
"Attempting ldap connection with %s",
self.ldap_uri
)

localpart = UserID.from_string(user_id).localpart
if self.ldap_mode == LDAPMode.SIMPLE:
# bind with the the local users ldap credentials
bind_dn = "{prop}={value},{base}".format(
prop=self.ldap_attributes['uid'],
value=localpart,
base=self.ldap_base
)
conn = ldap3.Connection(server, bind_dn, password)
logger.debug(
"Established ldap connection in simple mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded ldap connection in simple mode through StartTLS: %s",
conn
)
conn.bind()
elif self.ldap_mode == LDAPMode.SEARCH:
# connect with preconfigured credentials and search for local user
conn = ldap3.Connection(
server,
self.ldap_bind_dn,
self.ldap_bind_password
)
logger.debug(
"Established ldap connection in search mode: %s",
conn
)
if self.ldap_start_tls:
conn.start_tls()
logger.debug(
"Upgraded ldap connection in search mode through StartTLS: %s",
conn
)
conn.bind()
# find matching dn
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_filter:
query = "(&{query}{filter})".format(
query=query,
filter=self.ldap_filter
)
logger.debug("ldap search filter: %s", query)
result = conn.search(self.ldap_base, query)
if result and len(conn.response) == 1:
# found exactly one result
user_dn = conn.response[0]['dn']
logger.debug('ldap search found dn: %s', user_dn)
# unbind and reconnect, rebind with found dn
conn.unbind()
conn = ldap3.Connection(
server,
user_dn,
password,
auto_bind=True
)
else:
# found 0 or > 1 results, abort!
logger.warn(
"ldap search returned unexpected (%d!=1) amount of results",
len(conn.response)
)
defer.returnValue(False)
logger.info(
"User authenticated against ldap server: %s",
conn
)
# check for existing account, if none exists, create one
if not (yield self.check_user_exists(user_id)):
# query user metadata for account creation
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
)
if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter:
query = "(&{filter}{user_filter})".format(
filter=query,
user_filter=self.ldap_filter
)
logger.debug("ldap registration filter: %s", query)
result = conn.search(
search_base=self.ldap_base,
search_filter=query,
attributes=[
self.ldap_attributes['name'],
self.ldap_attributes['mail']
]
)
if len(conn.response) == 1:
attrs = conn.response[0]['attributes']
mail = attrs[self.ldap_attributes['mail']][0]
name = attrs[self.ldap_attributes['name']][0]
# create account
registration_handler = self.hs.get_handlers().registration_handler
user_id, access_token = (
yield registration_handler.register(localpart=localpart)
)
# TODO: bind email, set displayname with data from ldap directory
logger.info(
"ldap registration successful: %d: %s (%s, %)",
user_id,
localpart,
name,
mail
)
else:
logger.warn(
"ldap registration failed: unexpected (%d!=1) amount of results",
len(result)
)
defer.returnValue(False)
defer.returnValue(True)
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def issue_access_token(self, user_id): def issue_access_token(self, user_id, device_id=None):
access_token = self.generate_access_token(user_id) access_token = self.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token) yield self.store.add_access_token_to_user(user_id, access_token,
device_id)
defer.returnValue(access_token) defer.returnValue(access_token)
@defer.inlineCallbacks @defer.inlineCallbacks
def issue_refresh_token(self, user_id): def issue_refresh_token(self, user_id, device_id=None):
refresh_token = self.generate_refresh_token(user_id) refresh_token = self.generate_refresh_token(user_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token) yield self.store.add_refresh_token_to_user(user_id, refresh_token,
device_id)
defer.returnValue(refresh_token) defer.returnValue(refresh_token)
def generate_access_token(self, user_id, extra_caveats=None): def generate_access_token(self, user_id, extra_caveats=None,
duration_in_ms=(60 * 60 * 1000)):
extra_caveats = extra_caveats or [] extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id) macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
expiry = now + (60 * 60 * 1000) expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
for caveat in extra_caveats: for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat) macaroon.add_first_party_caveat(caveat)
@ -529,14 +713,20 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("time < %d" % (expiry,)) macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize() return macaroon.serialize()
def generate_delete_pusher_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token): def validate_short_term_login_token_and_get_user_id(self, login_token):
auth_api = self.hs.get_auth()
try: try:
macaroon = pymacaroons.Macaroon.deserialize(login_token) macaroon = pymacaroons.Macaroon.deserialize(login_token)
auth_api = self.hs.get_auth() user_id = auth_api.get_user_id_from_macaroon(macaroon)
auth_api.validate_macaroon(macaroon, "login", True) auth_api.validate_macaroon(macaroon, "login", True, user_id)
return self.get_user_from_macaroon(macaroon) return user_id
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): except Exception:
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN) raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
@ -547,23 +737,18 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
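As a rough sketch of what these macaroon-based tokens look like, the snippet below builds one by hand with pymacaroons. The location, key and caveat values are invented for illustration and are not taken from this change:

import pymacaroons

macaroon = pymacaroons.Macaroon(
    location="example.com",                  # normally the homeserver's name
    identifier="key",
    key="not-the-real-macaroon-secret",      # normally the configured macaroon secret
)
macaroon.add_first_party_caveat("user_id = @alice:example.com")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("time < 1470000000000")   # expiry, in ms
token = macaroon.serialize()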
def get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
errcode=Codes.UNKNOWN_TOKEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None): def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)
except_access_token_ids = [requester.access_token_id] if requester else [] except_access_token_ids = [requester.access_token_id] if requester else []
yield self.store.user_set_password_hash(user_id, password_hash) try:
yield self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
yield self.store.user_delete_access_tokens( yield self.store.user_delete_access_tokens(
user_id, except_access_token_ids user_id, except_access_token_ids
) )
@ -603,7 +788,8 @@ class AuthHandler(BaseHandler):
Returns:
Hashed password (str).
"""
return bcrypt.hashpw(password + self.hs.config.password_pepper,
bcrypt.gensalt(self.bcrypt_rounds))

def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@ -616,6 +802,7 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash (bool).
"""
if stored_hash:
return bcrypt.hashpw(password + self.hs.config.password_pepper,
stored_hash.encode('utf-8')) == stored_hash
else:
return False
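A standalone sketch of the peppered hashing introduced here, with a made-up pepper and password and an illustrative round count:

import bcrypt

pepper = b"myStrongPepper"            # stands in for the configured password_pepper
hashed = bcrypt.hashpw(b"s3kr1t" + pepper, bcrypt.gensalt(12))

# Validation mirrors validate_hash(): re-hash using the stored hash as salt.
assert bcrypt.hashpw(b"s3kr1t" + pepper, hashed) == hashed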

synapse/handlers/device.py (new file, 181 lines)
@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api import errors
from synapse.util import stringutils
from twisted.internet import defer
from ._base import BaseHandler
import logging
logger = logging.getLogger(__name__)
class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name=None):
"""
If the given device has not been registered, register it with the
supplied display name.
If no device_id is supplied, we make one up.
Args:
user_id (str): @user:id
device_id (str | None): device id supplied by client
initial_device_display_name (str | None): device display name from
client
Returns:
str: device id (generated if none was supplied)
"""
if device_id is not None:
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
)
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
attempts = 0
while attempts < 5:
try:
device_id = stringutils.random_string_with_symbols(16)
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=False,
)
defer.returnValue(device_id)
except errors.StoreError:
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
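A hypothetical call site for check_device_registered, mirroring how a login or registration path might use it; the helper, user and display name below are illustrative only:

from twisted.internet import defer

@defer.inlineCallbacks
def ensure_device(device_handler, user_id, requested_device_id=None):
    device_id = yield device_handler.check_device_registered(
        user_id=user_id,
        device_id=requested_device_id,   # None -> a random device ID is generated
        initial_device_display_name="Alice's phone",
    )
    defer.returnValue(device_id)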
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""
Retrieve the given user's devices
Args:
user_id (str):
Returns:
defer.Deferred: list[dict[str, X]]: info on each device
"""
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id) for device_id in device_map.keys())
)
devices = device_map.values()
for device in devices:
_update_device_from_client_ips(device, ips)
defer.returnValue(devices)
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
""" Retrieve the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred: dict[str, X]: info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
devices=((user_id, device_id),)
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
""" Delete the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred:
"""
try:
yield self.store.delete_device(user_id, device_id)
except errors.StoreError, e:
if e.code == 404:
# no match
pass
else:
raise
yield self.store.user_delete_access_tokens(
user_id, device_id=device_id,
delete_refresh_tokens=True,
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
Args:
user_id (str):
device_id (str):
content (dict): body of update request
Returns:
defer.Deferred:
"""
try:
yield self.store.update_device(
user_id,
device_id,
new_display_name=content.get("display_name")
)
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
device.update({
"last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"),
})

@ -33,6 +33,7 @@ class DirectoryHandler(BaseHandler):
super(DirectoryHandler, self).__init__(hs) super(DirectoryHandler, self).__init__(hs)
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.federation.register_query_handler( self.federation.register_query_handler(
@ -281,7 +282,7 @@ class DirectoryHandler(BaseHandler):
)
if not result:
# Query AS to see if it exists
as_handler = self.appservice_handler
result = yield as_handler.query_room_alias_exists(room_alias)
defer.returnValue(result)

@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import json
import logging
from twisted.internet import defer
from synapse.api import errors
import synapse.types
logger = logging.getLogger(__name__)
class E2eKeysHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
self.federation.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
@defer.inlineCallbacks
def query_devices(self, query_body):
""" Handle a device key query from a client
{
"device_keys": {
"<user_id>": ["<device_id>"]
}
}
->
{
"device_keys": {
"<user_id>": {
"<device_id>": {
...
}
}
}
}
"""
device_keys_query = query_body.get("device_keys", {})
# separate users by domain.
# make a map from domain to user_id to device_ids
queries_by_domain = collections.defaultdict(dict)
for user_id, device_ids in device_keys_query.items():
user = synapse.types.UserID.from_string(user_id)
queries_by_domain[user.domain][user_id] = device_ids
# do the queries
# TODO: do these in parallel
results = {}
for destination, destination_query in queries_by_domain.items():
if destination == self.server_name:
res = yield self.query_local_devices(destination_query)
else:
res = yield self.federation.query_client_keys(
destination, {"device_keys": destination_query}
)
res = res["device_keys"]
for user_id, keys in res.items():
if user_id in destination_query:
results[user_id] = keys
defer.returnValue((200, {"device_keys": results}))
@defer.inlineCallbacks
def query_local_devices(self, query):
"""Get E2E device keys for local users
Args:
query (dict[string, list[string]|None): map from user_id to a list
of devices to query (None for all devices)
Returns:
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
map from user_id -> device_id -> device details
"""
local_query = []
result_dict = {}
for user_id, device_ids in query.items():
if not self.is_mine_id(user_id):
logger.warning("Request for keys for non-local user %s",
user_id)
raise errors.SynapseError(400, "Not a user here")
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
results = yield self.store.get_e2e_device_keys(local_query)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():
r = json.loads(device_info["key_json"])
r["unsigned"] = {}
display_name = device_info["device_display_name"]
if display_name is not None:
r["unsigned"]["device_display_name"] = display_name
result_dict[user_id][device_id] = r
defer.returnValue(result_dict)
@defer.inlineCallbacks
def on_federation_query_client_keys(self, query_body):
""" Handle a device key query from a federated server
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
defer.returnValue({"device_keys": res})

@ -66,10 +66,6 @@ class FederationHandler(BaseHandler):
self.hs = hs self.hs = hs
self.distributor.observe("user_joined_room", self.user_joined_room)
self.waiting_for_join_list = {}
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.replication_layer = hs.get_replication_layer() self.replication_layer = hs.get_replication_layer()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
@ -128,7 +124,7 @@ class FederationHandler(BaseHandler):
try:
event_stream_id, max_stream_id = yield self._persist_auth_tree(
origin, auth_chain, state, event
)
except AuthError as e:
raise FederationError(
@ -253,7 +249,7 @@ class FederationHandler(BaseHandler):
if ev.type != EventTypes.Member:
continue
try:
domain = get_domain_from_id(ev.state_key)
except:
continue
@ -339,30 +335,59 @@ class FederationHandler(BaseHandler):
state_events.update({s.event_id: s for s in state}) state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state events_to_state[e_id] = state
required_auth = set(
a_id
for event in events + state_events.values() + auth_events.values()
for a_id, _ in event.auth_events
)
auth_events.update({
e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
})
missing_auth = required_auth - set(auth_events)
failed_to_fetch = set()
# Try and fetch any missing auth events from both DB and remote servers.
# We repeatedly do this until we stop finding new auth events.
while missing_auth - failed_to_fetch:
logger.info("Missing auth for backfill: %r", missing_auth)
ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
auth_events.update(ret_events)
required_auth.update(
a_id for event in ret_events.values() for a_id, _ in event.auth_events
)
missing_auth = required_auth - set(auth_events)
if missing_auth - failed_to_fetch:
logger.info(
"Fetching missing auth for backfill: %r",
missing_auth - failed_to_fetch
)
results = yield defer.gatherResults(
[
self.replication_layer.get_pdu(
[dest],
event_id,
outlier=True,
timeout=10000,
)
for event_id in missing_auth - failed_to_fetch
],
consumeErrors=True
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results})
required_auth.update(
a_id for event in results for a_id, _ in event.auth_events
)
missing_auth = required_auth - set(auth_events)
failed_to_fetch = missing_auth - set(auth_events)
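The loop above is a fixed-point computation: each batch of fetched auth events can name further auth events, so fetching repeats until nothing new turns up or everything left has failed. A toy, self-contained illustration of that shape (the event IDs and the two "stores" below are invented, not synapse APIs):

local_store = {"B": {"auth": ["C"]}}   # events already in the database
remote = {"C": {"auth": []}}           # events only the remote server has

auth_events = {"A": {"auth": ["B"]}}   # what we start with
required = {"B"}
failed_to_fetch = set()

while required - set(auth_events) - failed_to_fetch:
    missing = required - set(auth_events) - failed_to_fetch
    fetched = {e: local_store[e] for e in missing if e in local_store}
    still_missing = missing - set(fetched)
    fetched.update({e: remote[e] for e in still_missing if e in remote})
    auth_events.update(fetched)
    required.update(a for ev in fetched.values() for a in ev["auth"])
    failed_to_fetch |= missing - set(fetched)

print(sorted(auth_events))  # ['A', 'B', 'C']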
seen_events = yield self.store.have_events( seen_events = yield self.store.have_events(
set(auth_events.keys()) | set(state_events.keys()) set(auth_events.keys()) | set(state_events.keys())
) )
all_events = events + state_events.values() + auth_events.values()
required_auth = set(
a_id for event in all_events for a_id, _ in event.auth_events
)
missing_auth = required_auth - set(auth_events)
results = yield defer.gatherResults(
[
self.replication_layer.get_pdu(
[dest],
event_id,
outlier=True,
timeout=10000,
)
for event_id in missing_auth
],
consumeErrors=True
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results})
ev_infos = [] ev_infos = []
for a in auth_events.values(): for a in auth_events.values():
if a.event_id in seen_events: if a.event_id in seen_events:
@ -374,6 +399,7 @@ class FederationHandler(BaseHandler):
(auth_events[a_id].type, auth_events[a_id].state_key): (auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id] auth_events[a_id]
for a_id, _ in a.auth_events for a_id, _ in a.auth_events
if a_id in auth_events
} }
}) })
@ -385,6 +411,7 @@ class FederationHandler(BaseHandler):
(auth_events[a_id].type, auth_events[a_id].state_key): (auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id] auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events for a_id, _ in event_map[e_id].auth_events
if a_id in auth_events
} }
}) })
@ -403,7 +430,7 @@ class FederationHandler(BaseHandler):
# previous to work out the state. # previous to work out the state.
# TODO: We can probably do something more clever here. # TODO: We can probably do something more clever here.
yield self._handle_new_event( yield self._handle_new_event(
dest, event dest, event, backfilled=True,
) )
defer.returnValue(events) defer.returnValue(events)
@ -639,7 +666,7 @@ class FederationHandler(BaseHandler):
pass pass
event_stream_id, max_stream_id = yield self._persist_auth_tree( event_stream_id, max_stream_id = yield self._persist_auth_tree(
auth_chain, state, event origin, auth_chain, state, event
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
@ -690,7 +717,9 @@ class FederationHandler(BaseHandler):
logger.warn("Failed to create join %r because %s", event, e) logger.warn("Failed to create join %r because %s", event, e)
raise e raise e
self.auth.check(event, auth_events=context.current_state) # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
defer.returnValue(event) defer.returnValue(event)
@ -920,7 +949,9 @@ class FederationHandler(BaseHandler):
) )
try: try:
self.auth.check(event, auth_events=context.current_state) # The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
self.auth.check(event, auth_events=context.current_state, do_sig_check=False)
except AuthError as e: except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e) logger.warn("Failed to create new leave %r because %s", event, e)
raise e raise e
@ -989,14 +1020,9 @@ class FederationHandler(BaseHandler):
defer.returnValue(None) defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): def get_state_for_pdu(self, room_id, event_id):
yield run_on_reactor() yield run_on_reactor()
if do_auth:
in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
room_id, [event_id] room_id, [event_id]
) )
@ -1020,13 +1046,16 @@ class FederationHandler(BaseHandler):
res = results.values() res = results.values()
for event in res: for event in res:
event.signatures.update( # We sign these again because there was a bug where we
compute_event_signature( # incorrectly signed things the first time round
event, if self.hs.is_mine_id(event.event_id):
self.hs.hostname, event.signatures.update(
self.hs.config.signing_key[0] compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
) )
)
defer.returnValue(res) defer.returnValue(res)
else: else:
@ -1064,16 +1093,17 @@ class FederationHandler(BaseHandler):
) )
if event: if event:
# FIXME: This is a temporary work around where we occasionally if self.hs.is_mine_id(event.event_id):
# return events slightly differently than when they were # FIXME: This is a temporary work around where we occasionally
# originally signed # return events slightly differently than when they were
event.signatures.update( # originally signed
compute_event_signature( event.signatures.update(
event, compute_event_signature(
self.hs.hostname, event,
self.hs.config.signing_key[0] self.hs.hostname,
self.hs.config.signing_key[0]
)
) )
)
if do_auth: if do_auth:
in_room = yield self.auth.check_host_in_room( in_room = yield self.auth.check_host_in_room(
@ -1083,6 +1113,12 @@ class FederationHandler(BaseHandler):
if not in_room: if not in_room:
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
events = yield self._filter_events_for_server(
origin, event.room_id, [event]
)
event = events[0]
defer.returnValue(event) defer.returnValue(event)
else: else:
defer.returnValue(None) defer.returnValue(None)
@ -1091,15 +1127,6 @@ class FederationHandler(BaseHandler):
def get_min_depth_for_context(self, context): def get_min_depth_for_context(self, context):
return self.store.get_min_depth(context) return self.store.get_min_depth(context)
@log_function
def user_joined_room(self, user, room_id):
waiters = self.waiting_for_join_list.get(
(user.to_string(), room_id),
[]
)
while waiters:
waiters.pop().callback(None)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _handle_new_event(self, origin, event, state=None, auth_events=None, def _handle_new_event(self, origin, event, state=None, auth_events=None,
@ -1122,11 +1149,12 @@ class FederationHandler(BaseHandler):
backfilled=backfilled, backfilled=backfilled,
) )
# this intentionally does not yield: we don't care about the result if not backfilled:
# and don't need to wait for it. # this intentionally does not yield: we don't care about the result
preserve_fn(self.hs.get_pusherpool().on_new_notifications)( # and don't need to wait for it.
event_stream_id, max_stream_id preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
) event_stream_id, max_stream_id
)
defer.returnValue((context, event_stream_id, max_stream_id)) defer.returnValue((context, event_stream_id, max_stream_id))
@ -1158,11 +1186,19 @@ class FederationHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _persist_auth_tree(self, auth_events, state, event): def _persist_auth_tree(self, origin, auth_events, state, event):
"""Checks the auth chain is valid (and passes auth checks) for the """Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically. state and event. Then persists the auth chain and state atomically.
Persists the event seperately. Persists the event seperately.
Will attempt to fetch missing auth events.
Args:
origin (str): Where the events came from
auth_events (list)
state (list)
event (Event)
Returns: Returns:
2-tuple of (event_stream_id, max_stream_id) from the persist_event 2-tuple of (event_stream_id, max_stream_id) from the persist_event
call for `event` call for `event`
@ -1175,7 +1211,7 @@ class FederationHandler(BaseHandler):
event_map = { event_map = {
e.event_id: e e.event_id: e
for e in auth_events for e in itertools.chain(auth_events, state, [event])
} }
create_event = None create_event = None
@ -1184,10 +1220,29 @@ class FederationHandler(BaseHandler):
create_event = e create_event = e
break break
missing_auth_events = set()
for e in itertools.chain(auth_events, state, [event]):
for e_id, _ in e.auth_events:
if e_id not in event_map:
missing_auth_events.add(e_id)
for e_id in missing_auth_events:
m_ev = yield self.replication_layer.get_pdu(
[origin],
e_id,
outlier=True,
timeout=10000,
)
if m_ev and m_ev.event_id == e_id:
event_map[e_id] = m_ev
else:
logger.info("Failed to find auth event %r", e_id)
for e in itertools.chain(auth_events, state, [event]): for e in itertools.chain(auth_events, state, [event]):
auth_for_e = { auth_for_e = {
(event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
for e_id, _ in e.auth_events for e_id, _ in e.auth_events
if e_id in event_map
} }
if create_event: if create_event:
auth_for_e[(EventTypes.Create, "")] = create_event auth_for_e[(EventTypes.Create, "")] = create_event
@ -1421,7 +1476,7 @@ class FederationHandler(BaseHandler):
local_view = dict(auth_events) local_view = dict(auth_events)
remote_view = dict(auth_events) remote_view = dict(auth_events)
remote_view.update({ remote_view.update({
(d.type, d.state_key): d for d in different_events (d.type, d.state_key): d for d in different_events if d
}) })
new_state, prev_state = self.state_handler.resolve_events( new_state, prev_state = self.state_handler.resolve_events(


@ -21,7 +21,7 @@ from synapse.api.errors (
 )
 from ._base import BaseHandler
 from synapse.util.async import run_on_reactor
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, Codes

 import json
 import logging
@ -41,6 +41,20 @@ class IdentityHandler(BaseHandler):
             hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
         )

+    def _should_trust_id_server(self, id_server):
+        if id_server not in self.trusted_id_servers:
+            if self.trust_any_id_server_just_for_testing_do_not_use:
+                logger.warn(
+                    "Trusting untrustworthy ID server %r even though it isn't"
+                    " in the trusted id list for testing because"
+                    " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
+                    " is set in the config",
+                    id_server,
+                )
+            else:
+                return False
+        return True
+
     @defer.inlineCallbacks
     def threepid_from_creds(self, creds):
         yield run_on_reactor()
@ -59,19 +73,12 @@ class IdentityHandler(BaseHandler):
         else:
             raise SynapseError(400, "No client_secret in creds")

-        if id_server not in self.trusted_id_servers:
-            if self.trust_any_id_server_just_for_testing_do_not_use:
-                logger.warn(
-                    "Trusting untrustworthy ID server %r even though it isn't"
-                    " in the trusted id list for testing because"
-                    " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
-                    " is set in the config",
-                    id_server,
-                )
-            else:
-                logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
-                            'credentials', id_server)
-                defer.returnValue(None)
+        if not self._should_trust_id_server(id_server):
+            logger.warn(
+                '%s is not a trusted ID server: rejecting 3pid ' +
+                'credentials', id_server
+            )
+            defer.returnValue(None)

         data = {}
         try:
@ -129,6 +136,12 @@ class IdentityHandler(BaseHandler):
     def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
         yield run_on_reactor()

+        if not self._should_trust_id_server(id_server):
+            raise SynapseError(
+                400, "Untrusted ID server '%s'" % id_server,
+                Codes.SERVER_NOT_TRUSTED
+            )
+
         params = {
             'email': email,
             'client_secret': client_secret,


@ -26,9 +26,9 @@ from synapse.types import (
UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
) )
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
from synapse.util.caches.snapshot_cache import SnapshotCache from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn from synapse.util.logcontext import preserve_fn
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from ._base import BaseHandler from ._base import BaseHandler
@ -50,9 +50,23 @@ class MessageHandler(BaseHandler):
self.validator = EventValidator() self.validator = EventValidator()
self.snapshot_cache = SnapshotCache() self.snapshot_cache = SnapshotCache()
self.pagination_lock = ReadWriteLock()
@defer.inlineCallbacks
def purge_history(self, room_id, event_id):
event = yield self.store.get_event(event_id)
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
depth = event.depth
with (yield self.pagination_lock.write(room_id)):
yield self.store.delete_old_state(room_id, depth)
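The pagination lock above is a per-room reader/writer lock: pagination requests take the read side, so they can run concurrently, while purge_history takes the write side and therefore waits for in-flight paginations and blocks new ones while it deletes. A hedged usage sketch of that pattern follows; the function names are illustrative, and only the ReadWriteLock import shown in this diff is assumed.

from twisted.internet import defer
from synapse.util.async import ReadWriteLock

pagination_lock = ReadWriteLock()

@defer.inlineCallbacks
def paginate(room_id):
    # readers share the lock: concurrent pagination requests for a room proceed together
    with (yield pagination_lock.read(room_id)):
        pass  # ... load events for the room ...

@defer.inlineCallbacks
def purge(room_id, depth):
    # the writer excludes readers: runs only once in-flight paginations for the room finish
    with (yield pagination_lock.write(room_id)):
        pass  # ... delete history older than `depth` ...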
@defer.inlineCallbacks @defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None, def get_messages(self, requester, room_id=None, pagin_config=None,
as_client_event=True): as_client_event=True, event_filter=None):
"""Get messages in a room. """Get messages in a room.
Args: Args:
@ -61,11 +75,11 @@ class MessageHandler(BaseHandler):
pagin_config (synapse.api.streams.PaginationConfig): The pagination pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any. config rules to apply, if any.
as_client_event (bool): True to get events in client-server format. as_client_event (bool): True to get events in client-server format.
event_filter (Filter): Filter to apply to results or None
Returns: Returns:
dict: Pagination API results dict: Pagination API results
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token: if pagin_config.from_token:
room_token = pagin_config.from_token.room_key room_token = pagin_config.from_token.room_key
@ -85,42 +99,48 @@ class MessageHandler(BaseHandler):
source_config = pagin_config.get_source_config("room") source_config = pagin_config.get_source_config("room")
membership, member_event_id = yield self._check_in_room_or_world_readable( with (yield self.pagination_lock.read(room_id)):
room_id, user_id membership, member_event_id = yield self._check_in_room_or_world_readable(
) room_id, user_id
if source_config.direction == 'b':
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
max_topo = room_token.topological
else:
max_topo = yield self.store.get_max_topological_token_for_stream_and_room(
room_id, room_token.stream
)
if membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
leave_token = yield self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo
) )
events, next_key = yield data_source.get_pagination_rows( if source_config.direction == 'b':
requester.user, source_config, room_id # if we're going backwards, we might need to backfill. This
) # requires that we have a topo token.
if room_token.topological:
max_topo = room_token.topological
else:
max_topo = yield self.store.get_max_topological_token(
room_id, room_token.stream
)
next_token = pagin_config.from_token.copy_and_replace( if membership == Membership.LEAVE:
"room_key", next_key # If they have left the room then clamp the token to be before
) # they left the room, to save the effort of loading from the
# database.
leave_token = yield self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo
)
events, next_key = yield self.store.paginate_room_events(
room_id=room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
direction=source_config.direction,
limit=source_config.limit,
event_filter=event_filter,
)
next_token = pagin_config.from_token.copy_and_replace(
"room_key", next_key
)
if not events: if not events:
defer.returnValue({ defer.returnValue({
@ -129,6 +149,9 @@ class MessageHandler(BaseHandler):
"end": next_token.to_string(), "end": next_token.to_string(),
}) })
if event_filter:
events = event_filter.filter(events)
events = yield filter_events_for_client( events = yield filter_events_for_client(
self.store, self.store,
user_id, user_id,
@ -908,13 +931,16 @@ class MessageHandler(BaseHandler):
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
) )
with PreserveLoggingContext(): @defer.inlineCallbacks
# Don't block waiting on waking up all the listeners. def _notify():
yield run_on_reactor()
self.notifier.on_new_room_event( self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, event, event_stream_id, max_stream_id,
extra_users=extra_users extra_users=extra_users
) )
preserve_fn(_notify)()
# If invite, remove room_state from unsigned before sending. # If invite, remove room_state from unsigned before sending.
event.unsigned.pop("invite_room_state", None) event.unsigned.pop("invite_room_state", None)


@ -50,6 +50,8 @@ timers_fired_counter = metrics.register_counter("timers_fired")
federation_presence_counter = metrics.register_counter("federation_presence") federation_presence_counter = metrics.register_counter("federation_presence")
bump_active_time_counter = metrics.register_counter("bump_active_time") bump_active_time_counter = metrics.register_counter("bump_active_time")
get_updates_counter = metrics.register_counter("get_updates", labels=["type"])
# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them
# "currently_active" # "currently_active"
@ -68,6 +70,10 @@ FEDERATION_TIMEOUT = 30 * 60 * 1000
# How often to resend presence to remote servers # How often to resend presence to remote servers
FEDERATION_PING_INTERVAL = 25 * 60 * 1000 FEDERATION_PING_INTERVAL = 25 * 60 * 1000
# How long we will wait before assuming that the syncs from an external process
# are dead.
EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
@ -158,15 +164,26 @@ class PresenceHandler(object):
self.serial_to_user = {} self.serial_to_user = {}
self._next_serial = 1 self._next_serial = 1
# Keeps track of the number of *ongoing* syncs. While this is non zero # Keeps track of the number of *ongoing* syncs on this process. While
# a user will never go offline. # this is non zero a user will never go offline.
self.user_to_num_current_syncs = {} self.user_to_num_current_syncs = {}
# Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never
# go offline.
# Each process has a unique identifier and an update frequency. If
# no update is received from that process within the update period then
# we assume that all the sync requests on that process have stopped.
# Stored as a dict from process_id to set of user_id, and a dict of
# process_id to millisecond timestamp last updated.
self.external_process_to_current_syncs = {}
self.external_process_last_updated_ms = {}
# Start a LoopingCall in 30s that fires every 5s. # Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to # The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline. # reconnect before we treat them as offline.
self.clock.call_later( self.clock.call_later(
30 * 1000, 30,
self.clock.looping_call, self.clock.looping_call,
self._handle_timeouts, self._handle_timeouts,
5000, 5000,
@ -266,31 +283,48 @@ class PresenceHandler(object):
"""Checks the presence of users that have timed out and updates as """Checks the presence of users that have timed out and updates as
appropriate. appropriate.
""" """
logger.info("Handling presence timeouts")
now = self.clock.time_msec() now = self.clock.time_msec()
with Measure(self.clock, "presence_handle_timeouts"): try:
# Fetch the list of users that *may* have timed out. Things may have with Measure(self.clock, "presence_handle_timeouts"):
# changed since the timeout was set, so we won't necessarily have to # Fetch the list of users that *may* have timed out. Things may have
# take any action. # changed since the timeout was set, so we won't necessarily have to
users_to_check = self.wheel_timer.fetch(now) # take any action.
users_to_check = set(self.wheel_timer.fetch(now))
states = [ # Check whether the lists of syncing processes from an external
self.user_to_current_state.get( # process have expired.
user_id, UserPresenceState.default(user_id) expired_process_ids = [
process_id for process_id, last_update
in self.external_process_last_updated_ms.items()
if now - last_update > EXTERNAL_PROCESS_EXPIRY
]
for process_id in expired_process_ids:
users_to_check.update(
self.external_process_last_updated_ms.pop(process_id, ())
)
self.external_process_last_update.pop(process_id)
states = [
self.user_to_current_state.get(
user_id, UserPresenceState.default(user_id)
)
for user_id in users_to_check
]
timers_fired_counter.inc_by(len(states))
changes = handle_timeouts(
states,
is_mine_fn=self.is_mine_id,
syncing_user_ids=self.get_currently_syncing_users(),
now=now,
) )
for user_id in set(users_to_check)
]
timers_fired_counter.inc_by(len(states)) preserve_fn(self._update_states)(changes)
except:
changes = handle_timeouts( logger.exception("Exception in _handle_timeouts loop")
states,
is_mine_fn=self.is_mine_id,
user_to_num_current_syncs=self.user_to_num_current_syncs,
now=now,
)
preserve_fn(self._update_states)(changes)
@defer.inlineCallbacks @defer.inlineCallbacks
def bump_presence_active_time(self, user): def bump_presence_active_time(self, user):
@ -363,6 +397,74 @@ class PresenceHandler(object):
defer.returnValue(_user_syncing()) defer.returnValue(_user_syncing())
def get_currently_syncing_users(self):
"""Get the set of user ids that are currently syncing on this HS.
Returns:
set(str): A set of user_id strings.
"""
syncing_user_ids = {
user_id for user_id, count in self.user_to_num_current_syncs.items()
if count
}
for user_ids in self.external_process_to_current_syncs.values():
syncing_user_ids.update(user_ids)
return syncing_user_ids
@defer.inlineCallbacks
def update_external_syncs(self, process_id, syncing_user_ids):
"""Update the syncing users for an external process
Args:
process_id(str): An identifier for the process the users are
syncing against. This allows synapse to process updates
as user start and stop syncing against a given process.
syncing_user_ids(set(str)): The set of user_ids that are
currently syncing on that server.
"""
# Grab the previous list of user_ids that were syncing on that process
prev_syncing_user_ids = (
self.external_process_to_current_syncs.get(process_id, set())
)
# Grab the current presence state for both the users that are syncing
# now and the users that were syncing before this update.
prev_states = yield self.current_state_for_users(
syncing_user_ids | prev_syncing_user_ids
)
updates = []
time_now_ms = self.clock.time_msec()
# For each new user that is syncing check if we need to mark them as
# being online.
for new_user_id in syncing_user_ids - prev_syncing_user_ids:
prev_state = prev_states[new_user_id]
if prev_state.state == PresenceState.OFFLINE:
updates.append(prev_state.copy_and_replace(
state=PresenceState.ONLINE,
last_active_ts=time_now_ms,
last_user_sync_ts=time_now_ms,
))
else:
updates.append(prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
))
# For each user that is still syncing or stopped syncing update the
# last sync time so that we will correctly apply the grace period when
# they stop syncing.
for old_user_id in prev_syncing_user_ids:
prev_state = prev_states[old_user_id]
updates.append(prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
))
yield self._update_states(updates)
# Update the last updated time for the process. We expire the entries
# if we don't receive an update in the given timeframe.
self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
self.external_process_to_current_syncs[process_id] = syncing_user_ids
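As a usage sketch (the worker name and user IDs below are invented): a separate sync worker would periodically report the set of users it is currently syncing for, and the handler above marks newly-syncing users online and refreshes last_user_sync_ts for the rest; if the worker stops reporting for EXTERNAL_PROCESS_EXPIRY ms, its users fall back into the normal timeout handling.

from twisted.internet import defer

@defer.inlineCallbacks
def report_external_syncs(presence_handler):
    # "synchrotron-1" is a hypothetical process_id for a sync worker
    yield presence_handler.update_external_syncs(
        "synchrotron-1",
        {"@alice:example.com", "@bob:example.com"},
    )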
@defer.inlineCallbacks @defer.inlineCallbacks
def current_state_for_user(self, user_id): def current_state_for_user(self, user_id):
"""Get the current presence state for a user. """Get the current presence state for a user.
@ -879,13 +981,13 @@ class PresenceEventSource(object):
user_ids_changed = set() user_ids_changed = set()
changed = None changed = None
if from_key and max_token - from_key < 100: if from_key:
# For small deltas, its quicker to get all changes and then
# work out if we share a room or they're in our presence list
changed = stream_change_cache.get_all_entities_changed(from_key) changed = stream_change_cache.get_all_entities_changed(from_key)
# get_all_entities_changed can return None if changed is not None and len(changed) < 500:
if changed is not None: # For small deltas, its quicker to get all changes and then
# work out if we share a room or they're in our presence list
get_updates_counter.inc("stream")
for other_user_id in changed: for other_user_id in changed:
if other_user_id in friends: if other_user_id in friends:
user_ids_changed.add(other_user_id) user_ids_changed.add(other_user_id)
@ -897,6 +999,8 @@ class PresenceEventSource(object):
else: else:
# Too many possible updates. Find all users we can see and check # Too many possible updates. Find all users we can see and check
# if any of them have changed. # if any of them have changed.
get_updates_counter.inc("full")
user_ids_to_check = set() user_ids_to_check = set()
for room_id in room_ids: for room_id in room_ids:
users = yield self.store.get_users_in_room(room_id) users = yield self.store.get_users_in_room(room_id)
@ -935,15 +1039,14 @@ class PresenceEventSource(object):
return self.get_new_events(user, from_key=None, include_offline=False) return self.get_new_events(user, from_key=None, include_offline=False)
def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now): def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
"""Checks the presence of users that have timed out and updates as """Checks the presence of users that have timed out and updates as
appropriate. appropriate.
Args: Args:
user_states(list): List of UserPresenceState's to check. user_states(list): List of UserPresenceState's to check.
is_mine_fn (fn): Function that returns if a user_id is ours is_mine_fn (fn): Function that returns if a user_id is ours
user_to_num_current_syncs (dict): Mapping of user_id to number of currently syncing_user_ids (set): Set of user_ids with active syncs.
active syncs.
now (int): Current time in ms. now (int): Current time in ms.
Returns: Returns:
@ -954,21 +1057,20 @@ def handle_timeouts(user_states, is_mine_fn, user_to_num_current_syncs, now):
for state in user_states: for state in user_states:
is_mine = is_mine_fn(state.user_id) is_mine = is_mine_fn(state.user_id)
new_state = handle_timeout(state, is_mine, user_to_num_current_syncs, now) new_state = handle_timeout(state, is_mine, syncing_user_ids, now)
if new_state: if new_state:
changes[state.user_id] = new_state changes[state.user_id] = new_state
return changes.values() return changes.values()
def handle_timeout(state, is_mine, user_to_num_current_syncs, now): def handle_timeout(state, is_mine, syncing_user_ids, now):
"""Checks the presence of the user to see if any of the timers have elapsed """Checks the presence of the user to see if any of the timers have elapsed
Args: Args:
state (UserPresenceState) state (UserPresenceState)
is_mine (bool): Whether the user is ours is_mine (bool): Whether the user is ours
user_to_num_current_syncs (dict): Mapping of user_id to number of currently syncing_user_ids (set): Set of user_ids with active syncs.
active syncs.
now (int): Current time in ms. now (int): Current time in ms.
Returns: Returns:
@ -1002,7 +1104,7 @@ def handle_timeout(state, is_mine, user_to_num_current_syncs, now):
# If there are have been no sync for a while (and none ongoing), # If there are have been no sync for a while (and none ongoing),
# set presence to offline # set presence to offline
if not user_to_num_current_syncs.get(user_id, 0): if user_id not in syncing_user_ids:
if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT: if now - state.last_user_sync_ts > SYNC_ONLINE_TIMEOUT:
state = state.copy_and_replace( state = state.copy_and_replace(
state=PresenceState.OFFLINE, state=PresenceState.OFFLINE,


@ -13,15 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
 from twisted.internet import defer

+import synapse.types
 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
-from synapse.types import UserID, Requester
+from synapse.types import UserID
 from ._base import BaseHandler

-import logging
-
 logger = logging.getLogger(__name__)
@ -36,13 +36,6 @@ class ProfileHandler(BaseHandler):
             "profile", self.on_profile_query
         )

-        distributor = hs.get_distributor()
-        distributor.observe("registered_user", self.registered_user)
-
-    def registered_user(self, user):
-        return self.store.create_profile(user.localpart)
-
     @defer.inlineCallbacks
     def get_displayname(self, target_user):
         if self.hs.is_mine(target_user):
@ -172,7 +165,9 @@ class ProfileHandler(BaseHandler):
             try:
                 # Assume the user isn't a guest because we don't let guests set
                 # profile or avatar data.
-                requester = Requester(user, "", False)
+                # XXX why are we recreating `requester` here for each room?
+                # what was wrong with the `requester` we were passed?
+                requester = synapse.types.create_requester(user)
                 yield handler.update_membership(
                     requester,
                     user,


@ -14,19 +14,19 @@
# limitations under the License. # limitations under the License.
"""Contains functions for registering clients.""" """Contains functions for registering clients."""
import logging
import urllib
from twisted.internet import defer from twisted.internet import defer
from synapse.types import UserID import synapse.types
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
) )
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
from synapse.util.distributor import registered_user from synapse.types import UserID
from synapse.util.async import run_on_reactor
import logging from ._base import BaseHandler
import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -37,8 +37,6 @@ class RegistrationHandler(BaseHandler):
super(RegistrationHandler, self).__init__(hs) super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.distributor = hs.get_distributor()
self.distributor.declare("registered_user")
self.captcha_client = CaptchaServerHttpClient(hs) self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None self._next_generated_user_id = None
@ -55,6 +53,13 @@ class RegistrationHandler(BaseHandler):
Codes.INVALID_USERNAME Codes.INVALID_USERNAME
) )
if localpart[0] == '_':
raise SynapseError(
400,
"User ID may not begin with _",
Codes.INVALID_USERNAME
)
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
@ -93,7 +98,8 @@ class RegistrationHandler(BaseHandler):
password=None, password=None,
generate_token=True, generate_token=True,
guest_access_token=None, guest_access_token=None,
make_guest=False make_guest=False,
admin=False,
): ):
"""Registers a new client on the server. """Registers a new client on the server.
@ -101,8 +107,13 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None, localpart : The local part of the user ID to register. If None,
one will be generated. one will be generated.
password (str) : The password to assign to this user so they can password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user). via a password (e.g. the user is an application service user).
generate_token (bool): Whether a new access token should be
generated. Having this be True should be considered deprecated,
since it offers no means of associating a device_id with the
access_token. Instead you should call auth_handler.issue_access_token
after registration.
Returns: Returns:
A tuple of (user_id, access_token). A tuple of (user_id, access_token).
Raises: Raises:
@ -140,9 +151,12 @@ class RegistrationHandler(BaseHandler):
password_hash=password_hash, password_hash=password_hash,
was_guest=was_guest, was_guest=was_guest,
make_guest=make_guest, make_guest=make_guest,
create_profile_with_localpart=(
# If the user was a guest then they already have a profile
None if was_guest else user.localpart
),
admin=admin,
) )
yield registered_user(self.distributor, user)
else: else:
# autogen a sequential user ID # autogen a sequential user ID
attempts = 0 attempts = 0
@ -160,7 +174,8 @@ class RegistrationHandler(BaseHandler):
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
make_guest=make_guest make_guest=make_guest,
create_profile_with_localpart=user.localpart,
) )
except SynapseError: except SynapseError:
# if user id is taken, just generate another # if user id is taken, just generate another
@ -168,7 +183,6 @@ class RegistrationHandler(BaseHandler):
user_id = None user_id = None
token = None token = None
attempts += 1 attempts += 1
yield registered_user(self.distributor, user)
# We used to generate default identicons here, but nowadays # We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding # we want clients to generate their own as part of their branding
@ -195,15 +209,13 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service user_id, allowed_appservice=service
) )
token = self.auth_handler().generate_access_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token,
password_hash="", password_hash="",
appservice_id=service_id, appservice_id=service_id,
create_profile_with_localpart=user.localpart,
) )
yield registered_user(self.distributor, user) defer.returnValue(user_id)
defer.returnValue((user_id, token))
@defer.inlineCallbacks @defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response): def check_recaptcha(self, ip, private_key, challenge, response):
@ -248,9 +260,9 @@ class RegistrationHandler(BaseHandler):
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=None password_hash=None,
create_profile_with_localpart=user.localpart,
) )
yield registered_user(self.distributor, user)
except Exception as e: except Exception as e:
yield self.store.add_access_token_to_user(user_id, token) yield self.store.add_access_token_to_user(user_id, token)
# Ignore Registration errors # Ignore Registration errors
@ -359,8 +371,10 @@ class RegistrationHandler(BaseHandler):
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_or_create_user(self, localpart, displayname, duration_seconds): def get_or_create_user(self, localpart, displayname, duration_in_ms,
"""Creates a new user or returns an access token for an existing one password_hash=None):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
Args: Args:
localpart : The local part of the user ID to register. If None, localpart : The local part of the user ID to register. If None,
@ -387,32 +401,32 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
auth_handler = self.hs.get_handlers().auth_handler token = self.auth_handler().generate_access_token(
token = auth_handler.generate_short_term_login_token(user_id, duration_seconds) user_id, None, duration_in_ms)
if need_register: if need_register:
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=None password_hash=password_hash,
create_profile_with_localpart=user.localpart,
) )
yield registered_user(self.distributor, user)
else: else:
yield self.store.flush_user(user_id=user_id) yield self.store.user_delete_access_tokens(user_id=user_id)
yield self.store.add_access_token_to_user(user_id=user_id, token=token) yield self.store.add_access_token_to_user(user_id=user_id, token=token)
if displayname is not None: if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname) logger.info("setting user display name: %s -> %s", user_id, displayname)
profile_handler = self.hs.get_handlers().profile_handler profile_handler = self.hs.get_handlers().profile_handler
requester = synapse.types.create_requester(user)
yield profile_handler.set_displayname( yield profile_handler.set_displayname(
user, user, displayname user, requester, displayname
) )
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
def auth_handler(self): def auth_handler(self):
return self.hs.get_handlers().auth_handler return self.hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def guest_access_token_for(self, medium, address, inviter_user_id): def guest_access_token_for(self, medium, address, inviter_user_id):


@ -20,7 +20,7 @@ from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, JoinRules, RoomCreationPreset, EventTypes, JoinRules, RoomCreationPreset, Membership,
) )
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils from synapse.util import stringutils
@ -36,6 +36,8 @@ import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
id_server_scheme = "https://" id_server_scheme = "https://"
@ -343,9 +345,15 @@ class RoomCreationHandler(BaseHandler):
class RoomListHandler(BaseHandler): class RoomListHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(RoomListHandler, self).__init__(hs) super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache() self.response_cache = ResponseCache(hs)
self.remote_list_request_cache = ResponseCache(hs)
self.remote_list_cache = {}
self.fetch_looping_call = hs.get_clock().looping_call(
self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL
)
self.fetch_all_remote_lists()
def get_public_room_list(self): def get_local_public_room_list(self):
result = self.response_cache.get(()) result = self.response_cache.get(())
if not result: if not result:
result = self.response_cache.set((), self._get_public_room_list()) result = self.response_cache.set((), self._get_public_room_list())
@ -359,14 +367,10 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_room(room_id): def handle_room(room_id):
# We pull each bit of state out indvidually to avoid pulling the current_state = yield self.state_handler.get_current_state(room_id)
# full state into memory. Due to how the caching works this should
# be fairly quick, even if not originally in the cache.
def get_state(etype, state_key):
return self.state_handler.get_current_state(room_id, etype, state_key)
# Double check that this is actually a public room. # Double check that this is actually a public room.
join_rules_event = yield get_state(EventTypes.JoinRules, "") join_rules_event = current_state.get((EventTypes.JoinRules, ""))
if join_rules_event: if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None) join_rule = join_rules_event.content.get("join_rule", None)
if join_rule and join_rule != JoinRules.PUBLIC: if join_rule and join_rule != JoinRules.PUBLIC:
@ -374,47 +378,51 @@ class RoomListHandler(BaseHandler):
result = {"room_id": room_id} result = {"room_id": room_id}
joined_users = yield self.store.get_users_in_room(room_id) num_joined_users = len([
if len(joined_users) == 0: 1 for _, event in current_state.items()
if event.type == EventTypes.Member
and event.membership == Membership.JOIN
])
if num_joined_users == 0:
return return
result["num_joined_members"] = len(joined_users) result["num_joined_members"] = num_joined_users
aliases = yield self.store.get_aliases_for_room(room_id) aliases = yield self.store.get_aliases_for_room(room_id)
if aliases: if aliases:
result["aliases"] = aliases result["aliases"] = aliases
name_event = yield get_state(EventTypes.Name, "") name_event = yield current_state.get((EventTypes.Name, ""))
if name_event: if name_event:
name = name_event.content.get("name", None) name = name_event.content.get("name", None)
if name: if name:
result["name"] = name result["name"] = name
topic_event = yield get_state(EventTypes.Topic, "") topic_event = current_state.get((EventTypes.Topic, ""))
if topic_event: if topic_event:
topic = topic_event.content.get("topic", None) topic = topic_event.content.get("topic", None)
if topic: if topic:
result["topic"] = topic result["topic"] = topic
canonical_event = yield get_state(EventTypes.CanonicalAlias, "") canonical_event = current_state.get((EventTypes.CanonicalAlias, ""))
if canonical_event: if canonical_event:
canonical_alias = canonical_event.content.get("alias", None) canonical_alias = canonical_event.content.get("alias", None)
if canonical_alias: if canonical_alias:
result["canonical_alias"] = canonical_alias result["canonical_alias"] = canonical_alias
visibility_event = yield get_state(EventTypes.RoomHistoryVisibility, "") visibility_event = current_state.get((EventTypes.RoomHistoryVisibility, ""))
visibility = None visibility = None
if visibility_event: if visibility_event:
visibility = visibility_event.content.get("history_visibility", None) visibility = visibility_event.content.get("history_visibility", None)
result["world_readable"] = visibility == "world_readable" result["world_readable"] = visibility == "world_readable"
guest_event = yield get_state(EventTypes.GuestAccess, "") guest_event = current_state.get((EventTypes.GuestAccess, ""))
guest = None guest = None
if guest_event: if guest_event:
guest = guest_event.content.get("guest_access", None) guest = guest_event.content.get("guest_access", None)
result["guest_can_join"] = guest == "can_join" result["guest_can_join"] = guest == "can_join"
avatar_event = yield get_state("m.room.avatar", "") avatar_event = current_state.get(("m.room.avatar", ""))
if avatar_event: if avatar_event:
avatar_url = avatar_event.content.get("url", None) avatar_url = avatar_event.content.get("url", None)
if avatar_url: if avatar_url:
@ -427,6 +435,55 @@ class RoomListHandler(BaseHandler):
# FIXME (erikj): START is no longer a valid value # FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": results}) defer.returnValue({"start": "START", "end": "END", "chunk": results})
@defer.inlineCallbacks
def fetch_all_remote_lists(self):
deferred = self.hs.get_replication_layer().get_public_rooms(
self.hs.config.secondary_directory_servers
)
self.remote_list_request_cache.set((), deferred)
self.remote_list_cache = yield deferred
@defer.inlineCallbacks
def get_aggregated_public_room_list(self):
"""
Get the public room list from this server and the servers
specified in the secondary_directory_servers config option.
XXX: Pagination...
"""
    # We return the results from our cache which is updated by a looping call,
# unless we're missing a cache entry, in which case wait for the result
# of the fetch if there's one in progress. If not, omit that server.
wait = False
for s in self.hs.config.secondary_directory_servers:
if s not in self.remote_list_cache:
logger.warn("No cached room list from %s: waiting for fetch", s)
wait = True
break
if wait and self.remote_list_request_cache.get(()):
yield self.remote_list_request_cache.get(())
public_rooms = yield self.get_local_public_room_list()
# keep track of which room IDs we've seen so we can de-dup
room_ids = set()
# tag all the ones in our list with our server name.
    # Also add them to the de-duping set
for room in public_rooms['chunk']:
room["server_name"] = self.hs.hostname
room_ids.add(room["room_id"])
# Now add the results from federation
for server_name, server_result in self.remote_list_cache.items():
for room in server_result["chunk"]:
if room["room_id"] not in room_ids:
room["server_name"] = server_name
public_rooms["chunk"].append(room)
room_ids.add(room["room_id"])
defer.returnValue(public_rooms)
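A toy illustration of the merge/de-dup step above (server names and room IDs are invented): local rooms are tagged with our own hostname first, then federated results are appended only if their room_id has not already been seen.

local = {"chunk": [{"room_id": "!a:example.com"}]}
remote_lists = {
    "other.org": {"chunk": [{"room_id": "!a:example.com"}, {"room_id": "!b:other.org"}]},
}

seen_room_ids = set()
for room in local["chunk"]:
    room["server_name"] = "example.com"   # stand-in for self.hs.hostname
    seen_room_ids.add(room["room_id"])

for server_name, result in remote_lists.items():
    for room in result["chunk"]:
        if room["room_id"] not in seen_room_ids:
            room["server_name"] = server_name
            local["chunk"].append(room)
            seen_room_ids.add(room["room_id"])

print([(r["room_id"], r["server_name"]) for r in local["chunk"]])
# [('!a:example.com', 'example.com'), ('!b:other.org', 'other.org')]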
class RoomContextHandler(BaseHandler): class RoomContextHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks


@ -14,24 +14,22 @@
 # limitations under the License.

+import logging
+
+from signedjson.key import decode_verify_key_bytes
+from signedjson.sign import verify_signed_json
 from twisted.internet import defer
+from unpaddedbase64 import decode_base64

-from ._base import BaseHandler
-from synapse.types import UserID, RoomID, Requester
+import synapse.types
 from synapse.api.constants import (
     EventTypes, Membership,
 )
 from synapse.api.errors import AuthError, SynapseError, Codes
+from synapse.types import UserID, RoomID
 from synapse.util.async import Linearizer
 from synapse.util.distributor import user_left_room, user_joined_room

-from signedjson.sign import verify_signed_json
-from signedjson.key import decode_verify_key_bytes
-from unpaddedbase64 import decode_base64
-
-import logging
+from ._base import BaseHandler

 logger = logging.getLogger(__name__)
@ -315,7 +313,7 @@ class RoomMemberHandler(BaseHandler):
             )
             assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
         else:
-            requester = Requester(target_user, None, False)
+            requester = synapse.types.create_requester(target_user)

         message_handler = self.hs.get_handlers().message_handler
         prev_event = message_handler.deduplicate_state_event(event, context)

File diff suppressed because it is too large


@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
# A tiny object useful for storing a user's membership in a room, as a mapping # A tiny object useful for storing a user's membership in a room, as a mapping
# key # key
RoomMember = namedtuple("RoomMember", ("room_id", "user")) RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
class TypingHandler(object): class TypingHandler(object):
@ -38,7 +38,7 @@ class TypingHandler(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.server_name = hs.config.server_name self.server_name = hs.config.server_name
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@ -67,20 +67,23 @@ class TypingHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout): def started_typing(self, target_user, auth_user, room_id, timeout):
if not self.is_mine(target_user): target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state") raise AuthError(400, "Cannot set another user's typing state")
yield self.auth.check_joined_room(room_id, target_user.to_string()) yield self.auth.check_joined_room(room_id, target_user_id)
logger.debug( logger.debug(
"%s has started typing in %s", target_user.to_string(), room_id "%s has started typing in %s", target_user_id, room_id
) )
until = self.clock.time_msec() + timeout until = self.clock.time_msec() + timeout
member = RoomMember(room_id=room_id, user=target_user) member = RoomMember(room_id=room_id, user_id=target_user_id)
was_present = member in self._member_typing_until was_present = member in self._member_typing_until
@ -104,25 +107,28 @@ class TypingHandler(object):
yield self._push_update( yield self._push_update(
room_id=room_id, room_id=room_id,
user=target_user, user_id=target_user_id,
typing=True, typing=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def stopped_typing(self, target_user, auth_user, room_id): def stopped_typing(self, target_user, auth_user, room_id):
if not self.is_mine(target_user): target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
if target_user != auth_user: if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state") raise AuthError(400, "Cannot set another user's typing state")
yield self.auth.check_joined_room(room_id, target_user.to_string()) yield self.auth.check_joined_room(room_id, target_user_id)
logger.debug( logger.debug(
"%s has stopped typing in %s", target_user.to_string(), room_id "%s has stopped typing in %s", target_user_id, room_id
) )
member = RoomMember(room_id=room_id, user=target_user) member = RoomMember(room_id=room_id, user_id=target_user_id)
if member in self._member_typing_timer: if member in self._member_typing_timer:
self.clock.cancel_call_later(self._member_typing_timer[member]) self.clock.cancel_call_later(self._member_typing_timer[member])
@ -132,8 +138,9 @@ class TypingHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def user_left_room(self, user, room_id): def user_left_room(self, user, room_id):
if self.is_mine(user): user_id = user.to_string()
member = RoomMember(room_id=room_id, user=user) if self.is_mine_id(user_id):
member = RoomMember(room_id=room_id, user_id=user_id)
yield self._stopped_typing(member) yield self._stopped_typing(member)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -144,7 +151,7 @@ class TypingHandler(object):
yield self._push_update( yield self._push_update(
room_id=member.room_id, room_id=member.room_id,
user=member.user, user_id=member.user_id,
typing=False, typing=False,
) )
@ -156,7 +163,7 @@ class TypingHandler(object):
del self._member_typing_timer[member] del self._member_typing_timer[member]
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_update(self, room_id, user, typing): def _push_update(self, room_id, user_id, typing):
domains = yield self.store.get_joined_hosts_for_room(room_id) domains = yield self.store.get_joined_hosts_for_room(room_id)
deferreds = [] deferreds = []
@ -164,7 +171,7 @@ class TypingHandler(object):
if domain == self.server_name: if domain == self.server_name:
self._push_update_local( self._push_update_local(
room_id=room_id, room_id=room_id,
user=user, user_id=user_id,
typing=typing typing=typing
) )
else: else:
@ -173,7 +180,7 @@ class TypingHandler(object):
edu_type="m.typing", edu_type="m.typing",
content={ content={
"room_id": room_id, "room_id": room_id,
"user_id": user.to_string(), "user_id": user_id,
"typing": typing, "typing": typing,
}, },
)) ))
@ -183,23 +190,26 @@ class TypingHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _recv_edu(self, origin, content): def _recv_edu(self, origin, content):
room_id = content["room_id"] room_id = content["room_id"]
user = UserID.from_string(content["user_id"]) user_id = content["user_id"]
# Check that the string is a valid user id
UserID.from_string(user_id)
domains = yield self.store.get_joined_hosts_for_room(room_id) domains = yield self.store.get_joined_hosts_for_room(room_id)
if self.server_name in domains: if self.server_name in domains:
self._push_update_local( self._push_update_local(
room_id=room_id, room_id=room_id,
user=user, user_id=user_id,
typing=content["typing"] typing=content["typing"]
) )
def _push_update_local(self, room_id, user, typing): def _push_update_local(self, room_id, user_id, typing):
room_set = self._room_typing.setdefault(room_id, set()) room_set = self._room_typing.setdefault(room_id, set())
if typing: if typing:
room_set.add(user) room_set.add(user_id)
else: else:
room_set.discard(user) room_set.discard(user_id)
self._latest_room_serial += 1 self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial self._room_serials[room_id] = self._latest_room_serial
@ -211,13 +221,14 @@ class TypingHandler(object):
def get_all_typing_updates(self, last_id, current_id): def get_all_typing_updates(self, last_id, current_id):
# TODO: Work out a way to do this without scanning the entire state. # TODO: Work out a way to do this without scanning the entire state.
if last_id == current_id:
return []
rows = [] rows = []
for room_id, serial in self._room_serials.items(): for room_id, serial in self._room_serials.items():
if last_id < serial and serial <= current_id: if last_id < serial and serial <= current_id:
typing = self._room_typing[room_id] typing = self._room_typing[room_id]
typing_bytes = json.dumps([ typing_bytes = json.dumps(list(typing), ensure_ascii=False)
u.to_string() for u in typing
], ensure_ascii=False)
rows.append((serial, room_id, typing_bytes)) rows.append((serial, room_id, typing_bytes))
rows.sort() rows.sort()
return rows return rows
@ -239,7 +250,7 @@ class TypingNotificationEventSource(object):
"type": "m.typing", "type": "m.typing",
"room_id": room_id, "room_id": room_id,
"content": { "content": {
"user_ids": [u.to_string() for u in typing], "user_ids": list(typing),
}, },
} }
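For orientation, a hedged illustration of the typing event the handler now emits once members are tracked as plain user_id strings (the room and user IDs below are invented, not taken from the code above):

    typing_event = {
        "type": "m.typing",
        "room_id": "!abc123:example.org",  # placeholder room
        "content": {
            "user_ids": ["@alice:example.org", "@bob:example.org"],
        },
    }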


@ -24,12 +24,13 @@ from synapse.http.endpoint import SpiderEndpoint
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from twisted.internet import defer, reactor, ssl, protocol from twisted.internet import defer, reactor, ssl, protocol, task
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
from twisted.web.client import ( from twisted.web.client import (
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent, BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
readBody, FileBodyProducer, PartialDownloadError, readBody, PartialDownloadError,
) )
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
from twisted.web.http import PotentialDataLoss from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
@ -468,3 +469,26 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
def creatorForNetloc(self, hostname, port): def creatorForNetloc(self, hostname, port):
return self return self
class FileBodyProducer(TwistedFileBodyProducer):
"""Workaround for https://twistedmatrix.com/trac/ticket/8473
We override the pauseProducing and resumeProducing methods in twisted's
FileBodyProducer so that they do not raise exceptions if the task has
already completed.
"""
def pauseProducing(self):
try:
super(FileBodyProducer, self).pauseProducing()
except task.TaskDone:
# task has already completed
pass
def resumeProducing(self):
try:
super(FileBodyProducer, self).resumeProducing()
except task.NotPaused:
# task was not paused (probably because it had already completed)
pass
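A minimal usage sketch, assuming the subclass above is used exactly like twisted's own FileBodyProducer; the URL and payload are placeholders, not taken from the synapse code:

    from io import BytesIO
    from twisted.web.http_headers import Headers

    def upload(agent, data):
        # FileBodyProducer here is the workaround subclass defined above.
        body = FileBodyProducer(BytesIO(data))
        return agent.request(
            b"POST",
            b"https://example.org/upload",
            Headers({b"Content-Type": [b"application/octet-stream"]}),
            body,
        )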


@ -155,9 +155,7 @@ class MatrixFederationHttpClient(object):
time_out=timeout / 1000. if timeout else 60, time_out=timeout / 1000. if timeout else 60,
) )
response = yield preserve_context_over_fn( response = yield preserve_context_over_fn(send_request)
send_request,
)
log_result = "%d %s" % (response.code, response.phrase,) log_result = "%d %s" % (response.code, response.phrase,)
break break


@ -205,6 +205,7 @@ class JsonResource(HttpServer, resource.Resource):
def register_paths(self, method, path_patterns, callback): def register_paths(self, method, path_patterns, callback):
for path_pattern in path_patterns: for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append( self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback) self._PathEntry(path_pattern, callback)
) )


@ -22,22 +22,20 @@ import functools
import os import os
import stat import stat
import time import time
import gc
from twisted.internet import reactor from twisted.internet import reactor
from .metric import ( from .metric import (
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
MemoryUsageMetric,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# We'll keep all the available metrics in a single toplevel dict, one shared all_metrics = []
# for the entire process. We don't currently support per-HomeServer instances
# of metrics, because in practice any one python VM will host only one
# HomeServer anyway. This makes a lot of implementation neater
all_metrics = {}
class Metrics(object): class Metrics(object):
@ -53,7 +51,7 @@ class Metrics(object):
metric = metric_class(full_name, *args, **kwargs) metric = metric_class(full_name, *args, **kwargs)
all_metrics[full_name] = metric all_metrics.append(metric)
return metric return metric
def register_counter(self, *args, **kwargs): def register_counter(self, *args, **kwargs):
@ -69,6 +67,21 @@ class Metrics(object):
return self._register(CacheMetric, *args, **kwargs) return self._register(CacheMetric, *args, **kwargs)
def register_memory_metrics(hs):
try:
import psutil
process = psutil.Process()
process.memory_info().rss
except (ImportError, AttributeError):
logger.warn(
"psutil is not installed or incorrect version."
" Disabling memory metrics."
)
return
metric = MemoryUsageMetric(hs, psutil)
all_metrics.append(metric)
def get_metrics_for(pkg_name): def get_metrics_for(pkg_name):
""" Returns a Metrics instance for conveniently creating metrics """ Returns a Metrics instance for conveniently creating metrics
namespaced with the given name prefix. """ namespaced with the given name prefix. """
@ -84,12 +97,12 @@ def render_all():
# TODO(paul): Internal hack # TODO(paul): Internal hack
update_resource_metrics() update_resource_metrics()
for name in sorted(all_metrics.keys()): for metric in all_metrics:
try: try:
strs += all_metrics[name].render() strs += metric.render()
except Exception: except Exception:
strs += ["# FAILED to render %s" % name] strs += ["# FAILED to render"]
logger.exception("Failed to render %s metric", name) logger.exception("Failed to render metric")
strs.append("") # to generate a final CRLF strs.append("") # to generate a final CRLF
@ -156,6 +169,13 @@ reactor_metrics = get_metrics_for("reactor")
tick_time = reactor_metrics.register_distribution("tick_time") tick_time = reactor_metrics.register_distribution("tick_time")
pending_calls_metric = reactor_metrics.register_distribution("pending_calls") pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
gc_time = reactor_metrics.register_distribution("gc_time", labels=["gen"])
gc_unreachable = reactor_metrics.register_counter("gc_unreachable", labels=["gen"])
reactor_metrics.register_callback(
"gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"]
)
def runUntilCurrentTimer(func): def runUntilCurrentTimer(func):
@ -182,6 +202,22 @@ def runUntilCurrentTimer(func):
end = time.time() * 1000 end = time.time() * 1000
tick_time.inc_by(end - start) tick_time.inc_by(end - start)
pending_calls_metric.inc_by(num_pending) pending_calls_metric.inc_by(num_pending)
# Check if we need to do a manual GC (since its been disabled), and do
# one if necessary.
threshold = gc.get_threshold()
counts = gc.get_count()
for i in (2, 1, 0):
if threshold[i] < counts[i]:
logger.info("Collecting gc %d", i)
start = time.time() * 1000
unreachable = gc.collect(i)
end = time.time() * 1000
gc_time.inc_by(end - start, i)
gc_unreachable.inc_by(unreachable, i)
return ret return ret
return f return f
@ -196,5 +232,9 @@ try:
# runUntilCurrent is called when we have pending calls. It is called once # runUntilCurrent is called when we have pending calls. It is called once
# per iteration after fd polling. # per iteration after fd polling.
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent) reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
# We manually run the GC each reactor tick so that we can get some metrics
# about time spent doing GC.
gc.disable()
except AttributeError: except AttributeError:
pass pass
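The pattern is easier to see in isolation; a self-contained sketch (not the synapse hook itself) of disabling automatic collection and sweeping manually when the generation counts exceed their thresholds:

    import gc
    import time

    gc.disable()  # stop automatic collection; we collect explicitly instead

    def collect_if_needed():
        threshold = gc.get_threshold()
        counts = gc.get_count()
        # Oldest generation first, mirroring the reactor tick hook above.
        for gen in (2, 1, 0):
            if counts[gen] > threshold[gen]:
                start = time.time()
                unreachable = gc.collect(gen)
                elapsed_ms = (time.time() - start) * 1000
                print("gen %d: %d unreachable objects, %.1f ms" % (gen, unreachable, elapsed_ms))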


@ -47,9 +47,6 @@ class BaseMetric(object):
for k, v in zip(self.labels, values)]) for k, v in zip(self.labels, values)])
) )
def render(self):
return map_concat(self.render_item, sorted(self.counts.keys()))
class CounterMetric(BaseMetric): class CounterMetric(BaseMetric):
"""The simplest kind of metric; one that stores a monotonically-increasing """The simplest kind of metric; one that stores a monotonically-increasing
@ -83,6 +80,9 @@ class CounterMetric(BaseMetric):
def render_item(self, k): def render_item(self, k):
return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])] return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
def render(self):
return map_concat(self.render_item, sorted(self.counts.keys()))
class CallbackMetric(BaseMetric): class CallbackMetric(BaseMetric):
"""A metric that returns the numeric value returned by a callback whenever """A metric that returns the numeric value returned by a callback whenever
@ -126,30 +126,70 @@ class DistributionMetric(object):
class CacheMetric(object): class CacheMetric(object):
"""A combination of two CounterMetrics, one to count cache hits and one to __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
count a total, and a callback metric to yield the current size.
This metric generates standard metric name pairs, so that monitoring rules def __init__(self, name, size_callback, cache_name):
can easily be applied to measure hit ratio."""
def __init__(self, name, size_callback, labels=[]):
self.name = name self.name = name
self.cache_name = cache_name
self.hits = CounterMetric(name + ":hits", labels=labels) self.hits = 0
self.total = CounterMetric(name + ":total", labels=labels) self.misses = 0
self.size = CallbackMetric( self.size_callback = size_callback
name + ":size",
callback=size_callback,
labels=labels,
)
def inc_hits(self, *values): def inc_hits(self):
self.hits.inc(*values) self.hits += 1
self.total.inc(*values)
def inc_misses(self, *values): def inc_misses(self):
self.total.inc(*values) self.misses += 1
def render(self): def render(self):
return self.hits.render() + self.total.render() + self.size.render() size = self.size_callback()
hits = self.hits
total = self.misses + self.hits
return [
"""%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
]
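A hedged example of what the new render() produces; the metric and cache names are placeholders rather than the ones synapse actually registers:

    metric = CacheMetric("synapse_util_caches_cache", lambda: 42, "get_users_in_room")
    metric.inc_hits()
    metric.inc_hits()
    metric.inc_hits()
    metric.inc_misses()
    for line in metric.render():
        print(line)
    # synapse_util_caches_cache:hits{name="get_users_in_room"} 3
    # synapse_util_caches_cache:total{name="get_users_in_room"} 4
    # synapse_util_caches_cache:size{name="get_users_in_room"} 42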
class MemoryUsageMetric(object):
"""Keeps track of the current memory usage, using psutil.
The class will keep the current min/max/sum/counts of rss over the last
WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
"""
UPDATE_HZ = 2 # number of times to get memory per second
WINDOW_SIZE_SEC = 30 # the size of the window in seconds
def __init__(self, hs, psutil):
clock = hs.get_clock()
self.memory_snapshots = []
self.process = psutil.Process()
clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
def _update_curr_values(self):
max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
self.memory_snapshots.append(self.process.memory_info().rss)
self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
def render(self):
if not self.memory_snapshots:
return []
max_rss = max(self.memory_snapshots)
min_rss = min(self.memory_snapshots)
sum_rss = sum(self.memory_snapshots)
len_rss = len(self.memory_snapshots)
return [
"process_psutil_rss:max %d" % max_rss,
"process_psutil_rss:min %d" % min_rss,
"process_psutil_rss:total %d" % sum_rss,
"process_psutil_rss:count %d" % len_rss,
]
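The window arithmetic implied by the constants above, spelled out (illustrative only):

    UPDATE_HZ = 2                           # samples per second
    WINDOW_SIZE_SEC = 30                    # window length
    max_size = UPDATE_HZ * WINDOW_SIZE_SEC  # 60 rss snapshots retained
    poll_interval_ms = 1000 / UPDATE_HZ     # looping_call fires every 500 ms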


@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -140,8 +140,6 @@ class Notifier(object):
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000 UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs): def __init__(self, hs):
self.hs = hs
self.user_to_user_stream = {} self.user_to_user_stream = {}
self.room_to_user_streams = {} self.room_to_user_streams = {}
self.appservice_to_user_streams = {} self.appservice_to_user_streams = {}
@ -151,10 +149,8 @@ class Notifier(object):
self.pending_new_room_events = [] self.pending_new_room_events = []
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
hs.get_distributor().observe( self.state_handler = hs.get_state_handler()
"user_joined_room", self._user_joined_room
)
self.clock.looping_call( self.clock.looping_call(
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
@ -232,9 +228,7 @@ class Notifier(object):
def _on_new_room_event(self, event, room_stream_id, extra_users=[]): def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event""" """Notify any user streams that are interested in this room event"""
# poke any interested application service. # poke any interested application service.
self.hs.get_handlers().appservice_handler.notify_interested_services( self.appservice_handler.notify_interested_services(event)
event
)
app_streams = set() app_streams = set()
@ -250,6 +244,9 @@ class Notifier(object):
) )
app_streams |= app_user_streams app_streams |= app_user_streams
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id)
self.on_new_event( self.on_new_event(
"room_key", room_stream_id, "room_key", room_stream_id,
users=extra_users, users=extra_users,
@ -449,7 +446,7 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _is_world_readable(self, room_id): def _is_world_readable(self, room_id):
state = yield self.hs.get_state_handler().get_current_state( state = yield self.state_handler.get_current_state(
room_id, room_id,
EventTypes.RoomHistoryVisibility EventTypes.RoomHistoryVisibility
) )
@ -485,9 +482,8 @@ class Notifier(object):
user_stream.appservice, set() user_stream.appservice, set()
).add(user_stream) ).add(user_stream)
def _user_joined_room(self, user, room_id): def _user_joined_room(self, user_id, room_id):
user = str(user) new_user_stream = self.user_to_user_stream.get(user_id)
new_user_stream = self.user_to_user_stream.get(user)
if new_user_stream is not None: if new_user_stream is not None:
room_streams = self.room_to_user_streams.setdefault(room_id, set()) room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream) room_streams.add(new_user_stream)

View File

@ -40,7 +40,7 @@ class ActionGenerator:
def handle_push_actions_for_event(self, event, context): def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "handle_push_actions_for_event"): with Measure(self.clock, "handle_push_actions_for_event"):
bulk_evaluator = yield evaluator_for_event( bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store event, self.hs, self.store, context.current_state
) )
actions_by_user = yield bulk_evaluator.action_for_event_by_user( actions_by_user = yield bulk_evaluator.action_for_event_by_user(


@ -14,84 +14,56 @@
# limitations under the License. # limitations under the License.
import logging import logging
import ujson as json
from twisted.internet import defer from twisted.internet import defer
from .baserules import list_with_base_rules
from .push_rule_evaluator import PushRuleEvaluatorForEvent from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes, Membership
from synapse.visibility import filter_events_for_clients from synapse.visibility import filter_events_for_clients
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def decode_rule_json(rule):
rule['conditions'] = json.loads(rule['conditions'])
rule['actions'] = json.loads(rule['actions'])
return rule
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_rules(room_id, user_ids, store): def _get_rules(room_id, user_ids, store):
rules_by_user = yield store.bulk_get_push_rules(user_ids) rules_by_user = yield store.bulk_get_push_rules(user_ids)
rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
rules_by_user = { rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
uid: list_with_base_rules([
decode_rule_json(rule_list)
for rule_list in rules_by_user.get(uid, [])
])
for uid in user_ids
}
# We apply the rules-enabled map here: bulk_get_push_rules doesn't
# fetch disabled rules, but this won't account for any server default
# rules the user has disabled, so we need to do this too.
for uid in user_ids:
if uid not in rules_enabled_by_user:
continue
user_enabled_map = rules_enabled_by_user[uid]
for i, rule in enumerate(rules_by_user[uid]):
rule_id = rule['rule_id']
if rule_id in user_enabled_map:
if rule.get('enabled', True) != bool(user_enabled_map[rule_id]):
# Rules are cached across users.
rule = dict(rule)
rule['enabled'] = bool(user_enabled_map[rule_id])
rules_by_user[uid][i] = rule
defer.returnValue(rules_by_user) defer.returnValue(rules_by_user)
@defer.inlineCallbacks @defer.inlineCallbacks
def evaluator_for_event(event, hs, store): def evaluator_for_event(event, hs, store, current_state):
room_id = event.room_id room_id = event.room_id
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
users_with_pushers = yield store.get_users_with_pushers_in_room(room_id)
# We also will want to generate notifs for other people in the room so # their unread counts are correct in the event stream, but to avoid
# their unread counts are correct in the event stream, but to avoid # generating them for bot / AS users etc, we only do so for people who've
# generating them for bot / AS users etc, we only do so for people who've # generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room. # sent a read receipt into the room.
all_in_room = yield store.get_users_in_room(room_id) local_users_in_room = set(
all_in_room = set(all_in_room) e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
and hs.is_mine_id(e.state_key)
)
receipts = yield store.get_receipts_for_room(room_id, "m.read") # users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield store.get_if_users_have_pushers(
local_users_in_room
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield store.get_users_with_read_receipts_in_room(room_id)
# any users with pushers must be ours: they have pushers # any users with pushers must be ours: they have pushers
user_ids = set(users_with_pushers) for uid in users_with_receipts:
for r in receipts: if uid in local_users_in_room:
if hs.is_mine_id(r['user_id']) and r['user_id'] in all_in_room: user_ids.add(uid)
user_ids.add(r['user_id'])
# if this event is an invite event, we may need to run rules for the user # if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited # who's been invited, otherwise they won't get told they've been invited
@ -102,8 +74,6 @@ def evaluator_for_event(event, hs, store):
if has_pusher: if has_pusher:
user_ids.add(invited_user) user_ids.add(invited_user)
user_ids = list(user_ids)
rules_by_user = yield _get_rules(room_id, user_ids, store) rules_by_user = yield _get_rules(room_id, user_ids, store)
defer.returnValue(BulkPushRuleEvaluator( defer.returnValue(BulkPushRuleEvaluator(
@ -141,7 +111,10 @@ class BulkPushRuleEvaluator:
self.store, user_tuples, [event], {event.event_id: current_state} self.store, user_tuples, [event], {event.event_id: current_state}
) )
room_members = yield self.store.get_users_in_room(self.room_id) room_members = set(
e.state_key for e in current_state.values()
if e.type == EventTypes.Member and e.membership == Membership.JOIN
)
evaluator = PushRuleEvaluatorForEvent(event, len(room_members)) evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
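A small sketch, with made-up events and a stand-in for hs.is_mine_id, of the membership filtering the evaluator now performs against current_state instead of querying the room member store:

    from collections import namedtuple

    FakeEvent = namedtuple("FakeEvent", ["type", "membership", "state_key"])

    current_state = {
        ("m.room.member", "@alice:example.org"):
            FakeEvent("m.room.member", "join", "@alice:example.org"),
        ("m.room.member", "@bob:remote.org"):
            FakeEvent("m.room.member", "join", "@bob:remote.org"),
    }

    def is_mine_id(user_id):
        return user_id.endswith(":example.org")  # placeholder for hs.is_mine_id

    local_users_in_room = set(
        e.state_key for e in current_state.values()
        if e.type == "m.room.member" and e.membership == "join"
        and is_mine_id(e.state_key)
    )
    # -> set(['@alice:example.org'])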


@ -13,29 +13,19 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.push.baserules import list_with_base_rules
from synapse.push.rulekinds import ( from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
) )
import copy import copy
import simplejson as json
def format_push_rules_for_user(user, rawrules, enabled_map): def format_push_rules_for_user(user, ruleslist):
"""Converts a list of rawrules and a enabled map into nested dictionaries """Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules""" to match the Matrix client-server format for push rules"""
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy # We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(list_with_base_rules(ruleslist)) ruleslist = copy.deepcopy(ruleslist)
rules = {'global': {}, 'device': {}} rules = {'global': {}, 'device': {}}
@ -60,9 +50,7 @@ def format_push_rules_for_user(user, rawrules, enabled_map):
template_rule = _rule_to_template(r) template_rule = _rule_to_template(r)
if template_rule: if template_rule:
if r['rule_id'] in enabled_map: if 'enabled' in r:
template_rule['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' in r:
template_rule['enabled'] = r['enabled'] template_rule['enabled'] = r['enabled']
else: else:
template_rule['enabled'] = True template_rule['enabled'] = True


@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging import logging
@ -32,12 +33,20 @@ DELAY_BEFORE_MAIL_MS = 10 * 60 * 1000
# Each room maintains its own throttle counter, but each new mail notification # Each room maintains its own throttle counter, but each new mail notification
# sends the pending notifications for all rooms. # sends the pending notifications for all rooms.
THROTTLE_START_MS = 10 * 60 * 1000 THROTTLE_START_MS = 10 * 60 * 1000
THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # (2 * 60 * 1000) * (2 ** 11) # ~3 days THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # 24h
THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours # THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours
THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day
# If no event triggers a notification for this long after the previous, # If no event triggers a notification for this long after the previous,
# the throttle is released. # the throttle is released.
THROTTLE_RESET_AFTER_MS = (2 * 60 * 1000) * (2 ** 11) # ~3 days # 12 hours - a gap of 12 hours in conversation is surely enough to merit a new
# notification when things get going again...
THROTTLE_RESET_AFTER_MS = (12 * 60 * 60 * 1000)
# does each email include all unread notifs, or just the ones which have happened
# since the last mail?
# XXX: this is currently broken as it includes ones from parted rooms(!)
INCLUDE_ALL_UNREAD_NOTIFS = False
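Spelling out the arithmetic behind the throttle constants above (illustrative only, not the pusher's actual scheduling code): the per-room delay starts at 10 minutes and is multiplied by 144 on the next notification, which lands exactly on the 24 hour cap.

    THROTTLE_START_MS = 10 * 60 * 1000     # 600,000 ms = 10 minutes
    THROTTLE_MULTIPLIER = 144
    THROTTLE_MAX_MS = 24 * 60 * 60 * 1000  # 86,400,000 ms = 24 hours

    def next_throttle_ms(current_ms=None):
        if current_ms is None:
            return THROTTLE_START_MS
        return min(current_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS)

    # 10 minutes, then 24 hours, then it stays at 24 hours:
    # 600000 * 144 = 86,400,000, i.e. the cap is hit in a single step.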
class EmailPusher(object): class EmailPusher(object):
@ -65,7 +74,12 @@ class EmailPusher(object):
self.processing = False self.processing = False
if self.hs.config.email_enable_notifs: if self.hs.config.email_enable_notifs:
self.mailer = Mailer(self.hs) if 'data' in pusherdict and 'brand' in pusherdict['data']:
app_name = pusherdict['data']['brand']
else:
app_name = self.hs.config.email_app_name
self.mailer = Mailer(self.hs, app_name)
else: else:
self.mailer = None self.mailer = None
@ -79,7 +93,11 @@ class EmailPusher(object):
def on_stop(self): def on_stop(self):
if self.timed_call: if self.timed_call:
self.timed_call.cancel() try:
self.timed_call.cancel()
except (AlreadyCalled, AlreadyCancelled):
pass
self.timed_call = None
@defer.inlineCallbacks @defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering): def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
@ -126,9 +144,9 @@ class EmailPusher(object):
up logging, measures and guards against multiple instances of it up logging, measures and guards against multiple instances of it
being run. being run.
""" """
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range( start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
self.user_id, self.last_stream_ordering, self.max_stream_ordering fn = self.store.get_unread_push_actions_for_user_in_range_for_email
) unprocessed = yield fn(self.user_id, start, self.max_stream_ordering)
soonest_due_at = None soonest_due_at = None
@ -150,7 +168,6 @@ class EmailPusher(object):
# we then consider all previously outstanding notifications # we then consider all previously outstanding notifications
# to be delivered. # to be delivered.
# debugging:
reason = { reason = {
'room_id': push_action['room_id'], 'room_id': push_action['room_id'],
'now': self.clock.time_msec(), 'now': self.clock.time_msec(),
@ -165,16 +182,22 @@ class EmailPusher(object):
yield self.save_last_stream_ordering_and_success(max([ yield self.save_last_stream_ordering_and_success(max([
ea['stream_ordering'] for ea in unprocessed ea['stream_ordering'] for ea in unprocessed
])) ]))
yield self.sent_notif_update_throttle(
push_action['room_id'], push_action # we update the throttle on all the possible unprocessed push actions
) for ea in unprocessed:
yield self.sent_notif_update_throttle(
ea['room_id'], ea
)
break break
else: else:
if soonest_due_at is None or should_notify_at < soonest_due_at: if soonest_due_at is None or should_notify_at < soonest_due_at:
soonest_due_at = should_notify_at soonest_due_at = should_notify_at
if self.timed_call is not None: if self.timed_call is not None:
self.timed_call.cancel() try:
self.timed_call.cancel()
except (AlreadyCalled, AlreadyCancelled):
pass
self.timed_call = None self.timed_call = None
if soonest_due_at is not None: if soonest_due_at is not None:
@ -263,5 +286,5 @@ class EmailPusher(object):
logger.info("Sending notif email for user %r", self.user_id) logger.info("Sending notif email for user %r", self.user_id)
yield self.mailer.send_notification_mail( yield self.mailer.send_notification_mail(
self.user_id, self.email, push_actions, reason self.app_id, self.user_id, self.email, push_actions, reason
) )

View File

@ -16,6 +16,7 @@
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
import logging import logging
import push_rule_evaluator import push_rule_evaluator
@ -38,6 +39,7 @@ class HttpPusher(object):
self.hs = hs self.hs = hs
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock() self.clock = self.hs.get_clock()
self.state_handler = self.hs.get_state_handler()
self.user_id = pusherdict['user_name'] self.user_id = pusherdict['user_name']
self.app_id = pusherdict['app_id'] self.app_id = pusherdict['app_id']
self.app_display_name = pusherdict['app_display_name'] self.app_display_name = pusherdict['app_display_name']
@ -108,7 +110,11 @@ class HttpPusher(object):
def on_stop(self): def on_stop(self):
if self.timed_call: if self.timed_call:
self.timed_call.cancel() try:
self.timed_call.cancel()
except (AlreadyCalled, AlreadyCancelled):
pass
self.timed_call = None
@defer.inlineCallbacks @defer.inlineCallbacks
def _process(self): def _process(self):
@ -140,7 +146,8 @@ class HttpPusher(object):
run once per pusher. run once per pusher.
""" """
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range( fn = self.store.get_unread_push_actions_for_user_in_range_for_http
unprocessed = yield fn(
self.user_id, self.last_stream_ordering, self.max_stream_ordering self.user_id, self.last_stream_ordering, self.max_stream_ordering
) )
@ -237,7 +244,9 @@ class HttpPusher(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge): def _build_notification_dict(self, event, tweaks, badge):
ctx = yield push_tools.get_context_for_event(self.hs.get_datastore(), event) ctx = yield push_tools.get_context_for_event(
self.state_handler, event, self.user_id
)
d = { d = {
'notification': { 'notification': {
@ -269,8 +278,8 @@ class HttpPusher(object):
if 'content' in event: if 'content' in event:
d['notification']['content'] = event.content d['notification']['content'] = event.content
if len(ctx['aliases']): # We no longer send aliases separately, instead, we send the human
d['notification']['room_alias'] = ctx['aliases'][0] # readable name of the room, which may be an alias.
if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0: if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
d['notification']['sender_display_name'] = ctx['sender_display_name'] d['notification']['sender_display_name'] = ctx['sender_display_name']
if 'name' in ctx and len(ctx['name']) > 0: if 'name' in ctx and len(ctx['name']) > 0:

View File

@ -41,11 +41,14 @@ logger = logging.getLogger(__name__)
MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \ MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \
"in the %s room..." "in the %(room)s room..."
MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..." MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..." MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
MESSAGES_IN_ROOM = "There are some messages on %(app)s for you in the %(room)s room..." MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."
MESSAGES_IN_ROOMS = "Here are some messages on %(app)s you may have missed..." MESSAGES_IN_ROOM_AND_OTHERS = \
"You have messages on %(app)s in the %(room)s room and others..."
MESSAGES_FROM_PERSON_AND_OTHERS = \
"You have messages on %(app)s from %(person)s and others..."
INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " \ INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " \
"%(room)s room on %(app)s..." "%(room)s room on %(app)s..."
INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..." INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..."
@ -75,12 +78,14 @@ ALLOWED_ATTRS = {
class Mailer(object): class Mailer(object):
def __init__(self, hs): def __init__(self, hs, app_name):
self.hs = hs self.hs = hs
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.auth_handler = self.hs.get_auth_handler()
self.state_handler = self.hs.get_state_handler() self.state_handler = self.hs.get_state_handler()
loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir) loader = jinja2.FileSystemLoader(self.hs.config.email_template_dir)
self.app_name = self.hs.config.email_app_name self.app_name = app_name
logger.info("Created Mailer for app_name %s" % app_name)
env = jinja2.Environment(loader=loader) env = jinja2.Environment(loader=loader)
env.filters["format_ts"] = format_ts_filter env.filters["format_ts"] = format_ts_filter
env.filters["mxc_to_http"] = self.mxc_to_http_filter env.filters["mxc_to_http"] = self.mxc_to_http_filter
@ -92,8 +97,16 @@ class Mailer(object):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def send_notification_mail(self, user_id, email_address, push_actions, reason): def send_notification_mail(self, app_id, user_id, email_address,
raw_from = email.utils.parseaddr(self.hs.config.email_notif_from)[1] push_actions, reason):
try:
from_string = self.hs.config.email_notif_from % {
"app": self.app_name
}
except TypeError:
from_string = self.hs.config.email_notif_from
raw_from = email.utils.parseaddr(from_string)[1]
raw_to = email.utils.parseaddr(email_address)[1] raw_to = email.utils.parseaddr(email_address)[1]
if raw_to == '': if raw_to == '':
@ -119,6 +132,8 @@ class Mailer(object):
user_display_name = yield self.store.get_profile_displayname( user_display_name = yield self.store.get_profile_displayname(
UserID.from_string(user_id).localpart UserID.from_string(user_id).localpart
) )
if user_display_name is None:
user_display_name = user_id
except StoreError: except StoreError:
user_display_name = user_id user_display_name = user_id
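A hedged illustration of the From header handling above; the brand and address are invented, and the except TypeError branch simply falls back to the raw configured value:

    app_name = "Riot"                                   # e.g. pusherdict['data']['brand']
    email_notif_from = "%(app)s <noreply@example.org>"  # hypothetical config value

    try:
        from_string = email_notif_from % {"app": app_name}
    except TypeError:
        from_string = email_notif_from

    # from_string == "Riot <noreply@example.org>"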
@ -128,9 +143,14 @@ class Mailer(object):
state_by_room[room_id] = room_state state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email # Run at most 3 of these at once: sync does 10 at a time but email
# notifs are much realtime than sync so we can afford to wait a bit. # notifs are much less realtime than sync so we can afford to wait a bit.
yield concurrently_execute(_fetch_room_state, rooms_in_order, 3) yield concurrently_execute(_fetch_room_state, rooms_in_order, 3)
# actually sort our so-called rooms_in_order list, most recent room first
rooms_in_order.sort(
key=lambda r: -(notifs_by_room[r][-1]['received_ts'] or 0)
)
rooms = [] rooms = []
for r in rooms_in_order: for r in rooms_in_order:
@ -139,17 +159,19 @@ class Mailer(object):
) )
rooms.append(roomvars) rooms.append(roomvars)
summary_text = self.make_summary_text( reason['room_name'] = calculate_room_name(
notifs_by_room, state_by_room, notif_events, user_id state_by_room[reason['room_id']], user_id, fallback_to_members=True
) )
reason['room_name'] = calculate_room_name( summary_text = self.make_summary_text(
state_by_room[reason['room_id']], user_id, fallback_to_members=False notifs_by_room, state_by_room, notif_events, user_id, reason
) )
template_vars = { template_vars = {
"user_display_name": user_display_name, "user_display_name": user_display_name,
"unsubscribe_link": self.make_unsubscribe_link(), "unsubscribe_link": self.make_unsubscribe_link(
user_id, app_id, email_address
),
"summary_text": summary_text, "summary_text": summary_text,
"app_name": self.app_name, "app_name": self.app_name,
"rooms": rooms, "rooms": rooms,
@ -164,7 +186,7 @@ class Mailer(object):
multipart_msg = MIMEMultipart('alternative') multipart_msg = MIMEMultipart('alternative')
multipart_msg['Subject'] = "[%s] %s" % (self.app_name, summary_text) multipart_msg['Subject'] = "[%s] %s" % (self.app_name, summary_text)
multipart_msg['From'] = self.hs.config.email_notif_from multipart_msg['From'] = from_string
multipart_msg['To'] = email_address multipart_msg['To'] = email_address
multipart_msg['Date'] = email.utils.formatdate() multipart_msg['Date'] = email.utils.formatdate()
multipart_msg['Message-ID'] = email.utils.make_msgid() multipart_msg['Message-ID'] = email.utils.make_msgid()
@ -251,14 +273,16 @@ class Mailer(object):
sender_state_event = room_state[("m.room.member", event.sender)] sender_state_event = room_state[("m.room.member", event.sender)]
sender_name = name_from_member_event(sender_state_event) sender_name = name_from_member_event(sender_state_event)
sender_avatar_url = sender_state_event.content["avatar_url"] sender_avatar_url = sender_state_event.content.get("avatar_url")
# 'hash' for deterministically picking default images: use # 'hash' for deterministically picking default images: use
# sender_hash % the number of default images to choose from # sender_hash % the number of default images to choose from
sender_hash = string_ordinal_total(event.sender) sender_hash = string_ordinal_total(event.sender)
msgtype = event.content.get("msgtype")
ret = { ret = {
"msgtype": event.content["msgtype"], "msgtype": msgtype,
"is_historical": event.event_id != notif['event_id'], "is_historical": event.event_id != notif['event_id'],
"id": event.event_id, "id": event.event_id,
"ts": event.origin_server_ts, "ts": event.origin_server_ts,
@ -267,9 +291,9 @@ class Mailer(object):
"sender_hash": sender_hash, "sender_hash": sender_hash,
} }
if event.content["msgtype"] == "m.text": if msgtype == "m.text":
self.add_text_message_vars(ret, event) self.add_text_message_vars(ret, event)
elif event.content["msgtype"] == "m.image": elif msgtype == "m.image":
self.add_image_message_vars(ret, event) self.add_image_message_vars(ret, event)
if "body" in event.content: if "body" in event.content:
@ -278,16 +302,17 @@ class Mailer(object):
return ret return ret
def add_text_message_vars(self, messagevars, event): def add_text_message_vars(self, messagevars, event):
if "format" in event.content: msgformat = event.content.get("format")
msgformat = event.content["format"]
else:
msgformat = None
messagevars["format"] = msgformat messagevars["format"] = msgformat
if msgformat == "org.matrix.custom.html": formatted_body = event.content.get("formatted_body")
messagevars["body_text_html"] = safe_markup(event.content["formatted_body"]) body = event.content.get("body")
else:
messagevars["body_text_html"] = safe_text(event.content["body"]) if msgformat == "org.matrix.custom.html" and formatted_body:
messagevars["body_text_html"] = safe_markup(formatted_body)
elif body:
messagevars["body_text_html"] = safe_text(body)
return messagevars return messagevars
@ -296,7 +321,8 @@ class Mailer(object):
return messagevars return messagevars
def make_summary_text(self, notifs_by_room, state_by_room, notif_events, user_id): def make_summary_text(self, notifs_by_room, state_by_room,
notif_events, user_id, reason):
if len(notifs_by_room) == 1: if len(notifs_by_room) == 1:
# Only one room has new stuff # Only one room has new stuff
room_id = notifs_by_room.keys()[0] room_id = notifs_by_room.keys()[0]
@ -371,9 +397,28 @@ class Mailer(object):
} }
else: else:
# Stuff's happened in multiple different rooms # Stuff's happened in multiple different rooms
return MESSAGES_IN_ROOMS % {
"app": self.app_name, # ...but we still refer to the 'reason' room which triggered the mail
} if reason['room_name'] is not None:
return MESSAGES_IN_ROOM_AND_OTHERS % {
"room": reason['room_name'],
"app": self.app_name,
}
else:
# If the reason room doesn't have a name, say who the messages
# are from explicitly, to avoid "messages in the Bob room"
sender_ids = list(set([
notif_events[n['event_id']].sender
for n in notifs_by_room[reason['room_id']]
]))
return MESSAGES_FROM_PERSON_AND_OTHERS % {
"person": descriptor_from_member_events([
state_by_room[reason['room_id']][("m.room.member", s)]
for s in sender_ids
]),
"app": self.app_name,
}
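Illustrative substitutions for the two new templates (room, person and app names are invented):

    MESSAGES_IN_ROOM_AND_OTHERS % {"room": "Synapse Dev", "app": "Matrix"}
    # 'You have messages on Matrix in the Synapse Dev room and others...'

    MESSAGES_FROM_PERSON_AND_OTHERS % {"person": "Alice", "app": "Matrix"}
    # 'You have messages on Matrix from Alice and others...'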
def make_room_link(self, room_id): def make_room_link(self, room_id):
# need /beta for Universal Links to work on iOS # need /beta for Universal Links to work on iOS
@ -393,9 +438,18 @@ class Mailer(object):
notif['room_id'], notif['event_id'] notif['room_id'], notif['event_id']
) )
def make_unsubscribe_link(self): def make_unsubscribe_link(self, user_id, app_id, email_address):
# XXX: matrix.to params = {
return "https://vector.im/#/settings" "access_token": self.auth_handler.generate_delete_pusher_token(user_id),
"app_id": app_id,
"pushkey": email_address,
}
# XXX: make r0 once API is stable
return "%s_matrix/client/unstable/pushers/remove?%s" % (
self.hs.config.public_baseurl,
urllib.urlencode(params),
)
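A sketch of the link this produces, with a placeholder base URL, token and pushkey (the app_id value here is also an assumption):

    import urllib

    params = {
        "access_token": "MDAxexampletoken",
        "app_id": "m.email",    # assumed app_id for email pushers
        "pushkey": "user@example.org",
    }
    link = "%s_matrix/client/unstable/pushers/remove?%s" % (
        "https://example.org/",  # public_baseurl, trailing slash included
        urllib.urlencode(params),
    )
    # e.g. https://example.org/_matrix/client/unstable/pushers/remove?access_token=...&app_id=m.email&pushkey=user%40example.org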
def mxc_to_http_filter(self, value, width, height, resize_method="crop"): def mxc_to_http_filter(self, value, width, height, resize_method="crop"):
if value[0:6] != "mxc://": if value[0:6] != "mxc://":


@ -14,6 +14,9 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.util.presentable_names import (
calculate_room_name, name_from_member_event
)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -45,24 +48,21 @@ def get_badge_count(store, user_id):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_context_for_event(store, ev): def get_context_for_event(state_handler, ev, user_id):
name_aliases = yield store.get_room_name_and_aliases( ctx = {}
ev.room_id
)
ctx = {'aliases': name_aliases[1]} room_state = yield state_handler.get_current_state(ev.room_id)
if name_aliases[0] is not None:
ctx['name'] = name_aliases[0]
their_member_events_for_room = yield store.get_current_state( # we no longer bother setting room_alias, and make room_name the
room_id=ev.room_id, # human-readable name instead, be that m.room.name, an alias or
event_type='m.room.member', # a list of people in the room
state_key=ev.user_id name = calculate_room_name(
room_state, user_id, fallback_to_single_member=False
) )
for mev in their_member_events_for_room: if name:
if mev.content['membership'] == 'join' and 'displayname' in mev.content: ctx['name'] = name
dn = mev.content['displayname']
if dn is not None: sender_state_event = room_state[("m.room.member", ev.sender)]
ctx['sender_display_name'] = dn ctx['sender_display_name'] = name_from_member_event(sender_state_event)
defer.returnValue(ctx) defer.returnValue(ctx)
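For reference, the shape of the context dict returned above (values invented; 'name' is only present when a room name can be calculated):

    ctx = {
        "name": "Synapse Dev",            # human-readable room name, if any
        "sender_display_name": "Alice",
    }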


@ -48,6 +48,12 @@ CONDITIONAL_REQUIREMENTS = {
"Jinja2>=2.8": ["Jinja2>=2.8"], "Jinja2>=2.8": ["Jinja2>=2.8"],
"bleach>=1.4.2": ["bleach>=1.4.2"], "bleach>=1.4.2": ["bleach>=1.4.2"],
}, },
"ldap": {
"ldap3>=1.0": ["ldap3>=1.0"],
},
"psutil": {
"psutil>=2.0.0": ["psutil>=2.0.0"],
},
} }


@ -0,0 +1,59 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json_bytes, request_handler
from synapse.http.servlet import parse_json_object_from_request
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
class PresenceResource(Resource):
"""
HTTP endpoint for marking users as syncing.
POST /_synapse/replication/presence HTTP/1.1
Content-Type: application/json
{
"process_id": "<process_id>",
"syncing_users": ["<user_id>"]
}
"""
def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super()
self.version_string = hs.version_string
self.clock = hs.get_clock()
self.presence_handler = hs.get_presence_handler()
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
@request_handler()
@defer.inlineCallbacks
def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
process_id = content["process_id"]
syncing_user_ids = content["syncing_users"]
yield self.presence_handler.update_external_syncs(
process_id, set(syncing_user_ids)
)
respond_with_json_bytes(request, 200, "{}")
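A hedged usage sketch of the endpoint, assuming the replication resource is mounted at /_synapse/replication (so the child registered as "syncing_users" below is reachable there) and that the listener is on localhost:9093; both the path and the port are placeholders:

    import json
    import urllib2

    body = json.dumps({
        "process_id": "synchrotron-1",
        "syncing_users": ["@alice:example.org"],
    })
    req = urllib2.Request(
        "http://localhost:9093/_synapse/replication/syncing_users",
        data=body,
        headers={"Content-Type": "application/json"},
    )
    urllib2.urlopen(req)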


@ -16,6 +16,7 @@
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request from synapse.http.server import request_handler, finish_request
from synapse.replication.pusher_resource import PusherResource from synapse.replication.pusher_resource import PusherResource
from synapse.replication.presence_resource import PresenceResource
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
@ -115,6 +116,7 @@ class ReplicationResource(Resource):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.putChild("remove_pushers", PusherResource(hs)) self.putChild("remove_pushers", PusherResource(hs))
self.putChild("syncing_users", PresenceResource(hs))
def render_GET(self, request): def render_GET(self, request):
self._async_render_GET(request) self._async_render_GET(request)


@ -15,7 +15,10 @@
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.account_data import AccountDataStore from synapse.storage.account_data import AccountDataStore
from synapse.storage.tags import TagsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedAccountDataStore(BaseSlavedStore): class SlavedAccountDataStore(BaseSlavedStore):
@ -25,6 +28,14 @@ class SlavedAccountDataStore(BaseSlavedStore):
self._account_data_id_gen = SlavedIdTracker( self._account_data_id_gen = SlavedIdTracker(
db_conn, "account_data_max_stream_id", "stream_id", db_conn, "account_data_max_stream_id", "stream_id",
) )
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache",
self._account_data_id_gen.get_current_token(),
)
get_account_data_for_user = (
AccountDataStore.__dict__["get_account_data_for_user"]
)
get_global_account_data_by_type_for_users = ( get_global_account_data_by_type_for_users = (
AccountDataStore.__dict__["get_global_account_data_by_type_for_users"] AccountDataStore.__dict__["get_global_account_data_by_type_for_users"]
@ -34,6 +45,16 @@ class SlavedAccountDataStore(BaseSlavedStore):
AccountDataStore.__dict__["get_global_account_data_by_type_for_user"] AccountDataStore.__dict__["get_global_account_data_by_type_for_user"]
) )
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
get_updated_tags = DataStore.get_updated_tags.__func__
get_updated_account_data_for_user = (
DataStore.get_updated_account_data_for_user.__func__
)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
def stream_positions(self): def stream_positions(self):
result = super(SlavedAccountDataStore, self).stream_positions() result = super(SlavedAccountDataStore, self).stream_positions()
position = self._account_data_id_gen.get_current_token() position = self._account_data_id_gen.get_current_token()
@ -47,15 +68,33 @@ class SlavedAccountDataStore(BaseSlavedStore):
if stream: if stream:
self._account_data_id_gen.advance(int(stream["position"])) self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]: for row in stream["rows"]:
user_id, data_type = row[1:3] position, user_id, data_type = row[:3]
self.get_global_account_data_by_type_for_user.invalidate( self.get_global_account_data_by_type_for_user.invalidate(
(data_type, user_id,) (data_type, user_id,)
) )
self.get_account_data_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
stream = result.get("room_account_data") stream = result.get("room_account_data")
if stream: if stream:
self._account_data_id_gen.advance(int(stream["position"])) self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_account_data_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
stream = result.get("tag_account_data") stream = result.get("tag_account_data")
if stream: if stream:
self._account_data_id_gen.advance(int(stream["position"])) self._account_data_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.get_tags_for_user.invalidate((user_id,))
self._account_data_stream_cache.entity_has_changed(
user_id, position
)
return super(SlavedAccountDataStore, self).process_replication(result)


@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.config.appservice import load_appservices
class SlavedApplicationServiceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedApplicationServiceStore, self).__init__(db_conn, hs)
self.services_cache = load_appservices(
hs.config.server_name,
hs.config.app_service_config_files
)
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__


@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage.directory import DirectoryStore
class DirectoryStore(BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[
"get_aliases_for_room"
].orig


@ -18,11 +18,11 @@ from ._slaved_id_tracker import SlavedIdTracker
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.room import RoomStore
from synapse.storage.roommember import RoomMemberStore from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.state import StateStore from synapse.storage.state import StateStore
from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json import ujson as json
@ -57,10 +57,12 @@ class SlavedEventStore(BaseSlavedStore):
"EventsRoomStreamChangeCache", min_event_val, "EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill, prefilled_cache=event_cache_prefill,
) )
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max,
)
# Cached functions can't be accessed through a class instance so we need # Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them. # to reach inside the __dict__ to extract them.
get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_latest_event_ids_in_room = EventFederationStore.__dict__[ get_latest_event_ids_in_room = EventFederationStore.__dict__[
@ -87,9 +89,15 @@ class SlavedEventStore(BaseSlavedStore):
_get_state_group_from_group = ( _get_state_group_from_group = (
StateStore.__dict__["_get_state_group_from_group"] StateStore.__dict__["_get_state_group_from_group"]
) )
get_recent_event_ids_for_room = (
StreamStore.__dict__["get_recent_event_ids_for_room"]
)
get_unread_push_actions_for_user_in_range = ( get_unread_push_actions_for_user_in_range_for_http = (
DataStore.get_unread_push_actions_for_user_in_range.__func__ DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__
)
get_unread_push_actions_for_user_in_range_for_email = (
DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__
) )
get_push_action_users_in_range = ( get_push_action_users_in_range = (
DataStore.get_push_action_users_in_range.__func__ DataStore.get_push_action_users_in_range.__func__
@ -109,24 +117,25 @@ class SlavedEventStore(BaseSlavedStore):
DataStore.get_room_events_stream_for_room.__func__ DataStore.get_room_events_stream_for_room.__func__
) )
get_events_around = DataStore.get_events_around.__func__ get_events_around = DataStore.get_events_around.__func__
get_state_for_event = DataStore.get_state_for_event.__func__
get_state_for_events = DataStore.get_state_for_events.__func__ get_state_for_events = DataStore.get_state_for_events.__func__
get_state_groups = DataStore.get_state_groups.__func__ get_state_groups = DataStore.get_state_groups.__func__
get_recent_events_for_room = DataStore.get_recent_events_for_room.__func__
get_room_events_stream_for_rooms = (
DataStore.get_room_events_stream_for_rooms.__func__
)
get_stream_token_for_event = DataStore.get_stream_token_for_event.__func__
_set_before_and_after = DataStore._set_before_and_after _set_before_and_after = staticmethod(DataStore._set_before_and_after)
_get_events = DataStore._get_events.__func__ _get_events = DataStore._get_events.__func__
_get_events_from_cache = DataStore._get_events_from_cache.__func__ _get_events_from_cache = DataStore._get_events_from_cache.__func__
_invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__ _invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
_parse_events_txn = DataStore._parse_events_txn.__func__
_get_events_txn = DataStore._get_events_txn.__func__
_get_event_txn = DataStore._get_event_txn.__func__
_enqueue_events = DataStore._enqueue_events.__func__ _enqueue_events = DataStore._enqueue_events.__func__
_do_fetch = DataStore._do_fetch.__func__ _do_fetch = DataStore._do_fetch.__func__
_fetch_events_txn = DataStore._fetch_events_txn.__func__
_fetch_event_rows = DataStore._fetch_event_rows.__func__ _fetch_event_rows = DataStore._fetch_event_rows.__func__
_get_event_from_row = DataStore._get_event_from_row.__func__ _get_event_from_row = DataStore._get_event_from_row.__func__
_get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__
_get_rooms_for_user_where_membership_is_txn = ( _get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__ DataStore._get_rooms_for_user_where_membership_is_txn.__func__
) )
@ -136,6 +145,15 @@ class SlavedEventStore(BaseSlavedStore):
_get_events_around_txn = DataStore._get_events_around_txn.__func__ _get_events_around_txn = DataStore._get_events_around_txn.__func__
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__ _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
get_backfill_events = DataStore.get_backfill_events.__func__
_get_backfill_events = DataStore._get_backfill_events.__func__
get_missing_events = DataStore.get_missing_events.__func__
_get_missing_events = DataStore._get_missing_events.__func__
get_auth_chain = DataStore.get_auth_chain.__func__
get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__
_get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__
def stream_positions(self): def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions() result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token() result["events"] = self._stream_id_gen.get_current_token()
@ -194,7 +212,6 @@ class SlavedEventStore(BaseSlavedStore):
self.get_rooms_for_user.invalidate_all() self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,)) self.get_users_in_room.invalidate((event.room_id,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,)) # self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_room_name_and_aliases.invalidate((event.room_id,))
self._invalidate_get_event_cache(event.event_id) self._invalidate_get_event_cache(event.event_id)
@ -220,9 +237,9 @@ class SlavedEventStore(BaseSlavedStore):
self.get_rooms_for_user.invalidate((event.state_key,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_users_in_room.invalidate((event.room_id,))
# self._membership_stream_cache.entity_has_changed(
# event.state_key, event.internal_metadata.stream_ordering
# )
self._membership_stream_cache.entity_has_changed(
event.state_key, event.internal_metadata.stream_ordering
)
self.get_invited_rooms_for_user.invalidate((event.state_key,))
if not event.is_state():
@ -238,9 +255,3 @@ class SlavedEventStore(BaseSlavedStore):
self._get_current_state_for_key.invalidate((
event.room_id, event.type, event.state_key
))
if event.type in [EventTypes.Name, EventTypes.Aliases]:
self.get_room_name_and_aliases.invalidate(
(event.room_id,)
)
pass
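The slaved stores above borrow read-only methods straight off DataStore with ``.__func__``. A minimal sketch of that idiom, under Python 2 semantics and with illustrative names only:

    # Python 2: MasterStore.get_thing is an unbound method; .__func__
    # unwraps it to the plain function so it can be re-attached to
    # another class.
    class MasterStore(object):
        def get_thing(self, thing_id):
            return "thing-%s" % (thing_id,)

    class SlavedStore(object):
        # Borrow the read-only method; it binds to SlavedStore instances
        # exactly as if it had been defined here.
        get_thing = MasterStore.get_thing.__func__

    assert SlavedStore().get_thing(42) == "thing-42"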
View File
@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage.filtering import FilteringStore
class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedFilteringStore, self).__init__(db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired
get_user_filter = FilteringStore.__dict__["get_user_filter"]
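Cached methods, by contrast, are copied straight out of the class ``__dict__`` so the slave receives the cache descriptor itself rather than a bound method; because filters are immutable, the borrowed cache never needs invalidating. A toy sketch of the idea (the real ``@cached`` descriptor in synapse is considerably more involved):

    # Toy stand-in for a caching descriptor; all names are illustrative.
    class cached(object):
        def __init__(self, func):
            self.func = func
            self.cache = {}

        def __get__(self, obj, objtype=None):
            def wrapper(*args):
                if args not in self.cache:
                    self.cache[args] = self.func(obj, *args)
                return self.cache[args]
            return wrapper

    class MasterFilteringStore(object):
        @cached
        def get_user_filter(self, user_id, filter_id):
            return {"user": user_id, "filter": filter_id}

    class SlavedFilteringStore(object):
        # __dict__ access returns the descriptor itself, not a method,
        # so the slave keeps the same caching behaviour.
        get_user_filter = MasterFilteringStore.__dict__["get_user_filter"]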
View File
@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.keys import KeyStore
class SlavedKeyStore(BaseSlavedStore):
_get_server_verify_key = KeyStore.__dict__[
"_get_server_verify_key"
]
get_server_verify_keys = DataStore.get_server_verify_keys.__func__
store_server_verify_key = DataStore.store_server_verify_key.__func__
get_server_certificate = DataStore.get_server_certificate.__func__
store_server_certificate = DataStore.store_server_certificate.__func__
get_server_keys_json = DataStore.get_server_keys_json.__func__
store_server_keys_json = DataStore.store_server_keys_json.__func__
View File
@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage import DataStore
class SlavedPresenceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPresenceStore, self).__init__(db_conn, hs)
self._presence_id_gen = SlavedIdTracker(
db_conn, "presence_stream", "stream_id",
)
self._presence_on_startup = self._get_active_presence(db_conn)
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)
_get_active_presence = DataStore._get_active_presence.__func__
take_presence_startup_info = DataStore.take_presence_startup_info.__func__
get_presence_for_users = DataStore.get_presence_for_users.__func__
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedPresenceStore, self).stream_positions()
position = self._presence_id_gen.get_current_token()
result["presence"] = position
return result
def process_replication(self, result):
stream = result.get("presence")
if stream:
self._presence_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
position, user_id = row[:2]
self.presence_stream_cache.entity_has_changed(
user_id, position
)
return super(SlavedPresenceStore, self).process_replication(result)
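For reference, an illustrative replication payload of the shape ``process_replication`` consumes, inferred from the fields read above (real rows may carry further columns, and the values here are made up):

    result = {
        "presence": {
            "position": 1530,
            "rows": [
                (1529, "@alice:example.com"),
                (1530, "@bob:example.com"),
            ],
        }
    }
    # Feeding this to process_replication() advances the presence id
    # tracker to 1530 and marks both users as changed in
    # presence_stream_cache at their respective positions.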
View File
@ -0,0 +1,67 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .events import SlavedEventStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.push_rule import PushRuleStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedPushRuleStore(SlavedEventStore):
def __init__(self, db_conn, hs):
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache",
self._push_rules_stream_id_gen.get_current_token(),
)
get_push_rules_for_user = PushRuleStore.__dict__["get_push_rules_for_user"]
get_push_rules_enabled_for_user = (
PushRuleStore.__dict__["get_push_rules_enabled_for_user"]
)
have_push_rules_changed_for_user = (
DataStore.have_push_rules_changed_for_user.__func__
)
def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
self._stream_id_gen.get_current_token(),
)
def stream_positions(self):
result = super(SlavedPushRuleStore, self).stream_positions()
result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("push_rules")
if stream:
for row in stream["rows"]:
position = row[0]
user_id = row[2]
self.get_push_rules_for_user.invalidate((user_id,))
self.get_push_rules_enabled_for_user.invalidate((user_id,))
self.push_rules_stream_cache.entity_has_changed(
user_id, position
)
self._push_rules_stream_id_gen.advance(int(stream["position"]))
return super(SlavedPushRuleStore, self).process_replication(result)
View File
@ -18,6 +18,7 @@ from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.receipts import ReceiptsStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
@ -37,11 +38,28 @@ class SlavedReceiptsStore(BaseSlavedStore):
db_conn, "receipts_linearized", "stream_id" db_conn, "receipts_linearized", "stream_id"
) )
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
get_linearized_receipts_for_room = (
ReceiptsStore.__dict__["get_linearized_receipts_for_room"]
)
_get_linearized_receipts_for_rooms = (
ReceiptsStore.__dict__["_get_linearized_receipts_for_rooms"]
)
get_last_receipt_event_id_for_user = (
ReceiptsStore.__dict__["get_last_receipt_event_id_for_user"]
)
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
get_linearized_receipts_for_rooms = (
DataStore.get_linearized_receipts_for_rooms.__func__
)
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
@ -52,10 +70,15 @@ class SlavedReceiptsStore(BaseSlavedStore):
if stream:
self._receipts_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
room_id, receipt_type, user_id = row[1:4]
position, room_id, receipt_type, user_id = row[:4]
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
self._receipts_stream_cache.entity_has_changed(room_id, position)
return super(SlavedReceiptsStore, self).process_replication(result)
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))
self.get_linearized_receipts_for_room.invalidate_many((room_id,))
self.get_last_receipt_event_id_for_user.invalidate(
(user_id, room_id, receipt_type)
)
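The receipts rows now carry the stream position in column 0, hence ``row[:4]`` in place of ``row[1:4]``; the position is what feeds the new ``_receipts_stream_cache``. Roughly, with made-up values:

    # (position, room_id, receipt_type, user_id, ...possibly more columns)
    row = (712, "!room:example.com", "m.read", "@alice:example.com")
    position, room_id, receipt_type, user_id = row[:4]
    # The per-user and per-room receipt caches are invalidated, and the
    # room is marked as changed at `position`, so a worker can cheaply
    # tell whether a room has new receipts since a given stream token.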
View File
@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
from synapse.storage.registration import RegistrationStore
class SlavedRegistrationStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedRegistrationStore, self).__init__(db_conn, hs)
# TODO: use the cached version and invalidate deleted tokens
get_user_by_access_token = RegistrationStore.__dict__[
"get_user_by_access_token"
].orig
_query_for_auth = DataStore._query_for_auth.__func__
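``.orig`` reaches under the ``@cached`` decorator for the plain, uncached function: per the TODO above, the worker cannot yet invalidate deleted tokens, so it deliberately takes a database hit per lookup rather than risk serving a stale token. A toy sketch, assuming the cache descriptor keeps the wrapped function on an ``orig`` attribute:

    # Toy caching descriptor; `orig` holds the undecorated function.
    class cached(object):
        def __init__(self, func):
            self.orig = func
            self.cache = {}

        def __get__(self, obj, objtype=None):
            def wrapper(*args):
                if args not in self.cache:
                    self.cache[args] = self.orig(obj, *args)
                return self.cache[args]
            return wrapper

    class MasterRegistrationStore(object):
        @cached
        def get_user_by_access_token(self, token):
            return {"token": token}  # stands in for a database lookup

    class SlavedRegistrationStore(object):
        # Borrow the uncached original and let normal method binding
        # apply, bypassing the cache entirely.
        get_user_by_access_token = (
            MasterRegistrationStore.__dict__["get_user_by_access_token"].orig
        )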
View File
@ -0,0 +1,21 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from synapse.storage import DataStore
class RoomStore(BaseSlavedStore):
get_public_room_ids = DataStore.get_public_room_ids.__func__
Some files were not shown because too many files have changed in this diff