Merge branch 'release-v0.9.0' of github.com:matrix-org/synapse

This commit is contained in:
Erik Johnston 2015-05-07 19:07:00 +01:00
commit 89c0cd4acc
200 changed files with 9548 additions and 3091 deletions

37
AUTHORS.rst Normal file
View File

@ -0,0 +1,37 @@
Erik Johnston <erik at matrix.org>
* HS core
* Federation API impl
Mark Haines <mark at matrix.org>
* HS core
* Crypto
* Content repository
* CS v2 API impl
Kegan Dougal <kegan at matrix.org>
* HS core
* CS v1 API impl
* AS API impl
Paul "LeoNerd" Evans <paul at matrix.org>
* HS core
* Presence
* Typing Notifications
* Performance metrics and caching layer
Dave Baker <dave at matrix.org>
* Push notifications
* Auth CS v2 impl
Matthew Hodgson <matthew at matrix.org>
* General doc & housekeeping
* Vertobot/vertobridge matrix<->verto PoC
Emmanuel Rohee <manu at matrix.org>
* Supporting iOS clients (testability and fallback registration)
Turned to Dust <dwinslow86 at gmail.com>
* ArchLinux installation instructions
Brabo <brabo at riseup.net>
* Installation instruction fixes

View File

@ -1,3 +1,53 @@
Changes in synapse v0.9.0 (2015-05-07)
======================================
General:
* Add support for using a PostgreSQL database instead of SQLite. See
`docs/postgres.rst`_ for details.
* Add password change and reset APIs. See `Registration`_ in the spec.
* Fix memory leak due to not releasing stale notifiers - SYN-339.
* Fix race in caches that occasionally caused some presence updates to be
dropped - SYN-369.
* Check server name has not changed on restart.
Federation:
* Add key distribution mechanisms for fetching public keys of unavailable
remote home servers. See `Retrieving Server Keys`_ in the spec.
Configuration:
* Add support for multiple config files.
* Add support for dictionaries in config files.
* Remove support for specifying config options on the command line, except
for:
* ``--daemonize`` - Daemonize the home server.
* ``--manhole`` - Turn on the twisted telnet manhole service on the given
port.
* ``--database-path`` - The path to a sqlite database to use.
* ``--verbose`` - The verbosity level.
* ``--log-file`` - File to log to.
* ``--log-config`` - Python logging config file.
* ``--enable-registration`` - Enable registration for new users.
Application services:
* Reliably retry sending of events from Synapse to application services, as per
`Application Services`_ spec.
* Application services can no longer register via the ``/register`` API,
instead their configuration should be saved to a file and listed in the
synapse ``app_service_config_files`` config option. The AS configuration file
has the same format as the old ``/register`` request.
See `docs/application_services.rst`_ for more information.
.. _`docs/postgres.rst`: docs/postgres.rst
.. _`docs/application_services.rst`: docs/application_services.rst
.. _`Registration`: https://github.com/matrix-org/matrix-doc/blob/master/specification/10_client_server_api.rst#registration
.. _`Retrieving Server Keys`: https://github.com/matrix-org/matrix-doc/blob/6f2698/specification/30_server_server_api.rst#retrieving-server-keys
.. _`Application Services`: https://github.com/matrix-org/matrix-doc/blob/0c6bd9/specification/25_application_service_api.rst#home-server---application-service-api
Changes in synapse v0.8.1 (2015-03-18) Changes in synapse v0.8.1 (2015-03-18)
====================================== ======================================

118
CONTRIBUTING.rst Normal file
View File

@ -0,0 +1,118 @@
Contributing code to Matrix
===========================
Everyone is welcome to contribute code to Matrix
(https://github.com/matrix-org), provided that they are willing to license
their contributions under the same license as the project itself. We follow a
simple 'inbound=outbound' model for contributions: the act of submitting an
'inbound' contribution means that the contributor agrees to license the code
under the same terms as the project's overall 'outbound' license - in our
case, this is almost always Apache Software License v2 (see LICENSE).
How to contribute
~~~~~~~~~~~~~~~~~
The preferred and easiest way to contribute changes to Matrix is to fork the
relevant project on github, and then create a pull request to ask us to pull
your changes into our repo
(https://help.github.com/articles/using-pull-requests/)
**The single biggest thing you need to know is: please base your changes on
the develop branch - /not/ master.**
We use the master branch to track the most recent release, so that folks who
blindly clone the repo and automatically check out master get something that
works. Develop is the unstable branch where all the development actually
happens: the workflow is that contributors should fork the develop branch to
make a 'feature' branch for a particular contribution, and then make a pull
request to merge this back into the matrix.org 'official' develop branch. We
use github's pull request workflow to review the contribution, and either ask
you to make any refinements needed or merge it and make them ourselves. The
changes will then land on master when we next do a release.
We use Jenkins for continuous integration (http://matrix.org/jenkins), and
typically all pull requests get automatically tested Jenkins: if your change breaks the build, Jenkins will yell about it in #matrix-dev:matrix.org so please lurk there and keep an eye open.
Code style
~~~~~~~~~~
All Matrix projects have a well-defined code-style - and sometimes we've even
got as far as documenting it... For instance, synapse's code style doc lives
at https://github.com/matrix-org/synapse/tree/master/docs/code_style.rst.
Please ensure your changes match the cosmetic style of the existing project,
and **never** mix cosmetic and functional changes in the same commit, as it
makes it horribly hard to review otherwise.
Attribution
~~~~~~~~~~~
Everyone who contributes anything to Matrix is welcome to be listed in the
AUTHORS.rst file for the project in question. Please feel free to include a
change to AUTHORS.rst in your pull request to list yourself and a short
description of the area(s) you've worked on. Also, we sometimes have swag to
give away to contributors - if you feel that Matrix-branded apparel is missing
from your life, please mail us your shipping address to matrix at matrix.org and we'll try to fix it :)
Sign off
~~~~~~~~
In order to have a concrete record that your contribution is intentional
and you agree to license it under the same terms as the project's license, we've adopted the
same lightweight approach that the Linux Kernel
(https://www.kernel.org/doc/Documentation/SubmittingPatches), Docker
(https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other
projects use: the DCO (Developer Certificate of Origin:
http://developercertificate.org/). This is a simple declaration that you wrote
the contribution or otherwise have the right to contribute it to Matrix::
Developer Certificate of Origin
Version 1.1
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
660 York Street, Suite 102,
San Francisco, CA 94110 USA
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.
Developer's Certificate of Origin 1.1
By making a contribution to this project, I certify that:
(a) The contribution was created in whole or in part by me and I
have the right to submit it under the open source license
indicated in the file; or
(b) The contribution is based upon previous work that, to the best
of my knowledge, is covered under an appropriate open source
license and I have the right under that license to submit that
work with modifications, whether created in whole or in part
by me, under the same open source license (unless I am
permitted to submit under a different license), as indicated
in the file; or
(c) The contribution was provided directly to me by some other
person who certified (a), (b) or (c) and I have not modified
it.
(d) I understand and agree that this project and the contribution
are public and that a record of the contribution (including all
personal information I submit with it, including my sign-off) is
maintained indefinitely and may be redistributed consistent with
this project or the open source license(s) involved.
If you agree to this for your contribution, then all that's needed is to
include the line in your commit or pull request comment::
Signed-off-by: Your Name <your@email.example.org>
...using your real name; unfortunately pseudonyms and anonymous contributions
can't be accepted. Git makes this trivial - just use the -s flag when you do
``git commit``, having first set ``user.name`` and ``user.email`` git configs
(which you should have done anyway :)
Conclusion
~~~~~~~~~~
That's it! Matrix is a very open and collaborative project as you might expect given our obsession with open communication. If we're going to successfully matrix together all the fragmented communication technologies out there we are reliant on contributions and collaboration from the community to do so. So please get involved - and we hope you have as much fun hacking on Matrix as we do!

View File

@ -20,7 +20,7 @@ The overall architecture is::
https://somewhere.org/_matrix https://elsewhere.net/_matrix https://somewhere.org/_matrix https://elsewhere.net/_matrix
``#matrix:matrix.org`` is the official support room for Matrix, and can be ``#matrix:matrix.org`` is the official support room for Matrix, and can be
accessed by the web client at http://matrix.org/alpha or via an IRC bridge at accessed by the web client at http://matrix.org/beta or via an IRC bridge at
irc://irc.freenode.net/matrix. irc://irc.freenode.net/matrix.
Synapse is currently in rapid development, but as of version 0.5 we believe it Synapse is currently in rapid development, but as of version 0.5 we believe it
@ -69,24 +69,30 @@ Synapse ships with two basic demo Matrix clients: webclient (a basic group chat
web client demo implemented in AngularJS) and cmdclient (a basic Python web client demo implemented in AngularJS) and cmdclient (a basic Python
command line utility which lets you easily see what the JSON APIs are up to). command line utility which lets you easily see what the JSON APIs are up to).
Meanwhile, iOS and Android SDKs and clients are currently in development and available from: Meanwhile, iOS and Android SDKs and clients are available from:
- https://github.com/matrix-org/matrix-ios-sdk - https://github.com/matrix-org/matrix-ios-sdk
- https://github.com/matrix-org/matrix-ios-kit
- https://github.com/matrix-org/matrix-ios-console
- https://github.com/matrix-org/matrix-android-sdk - https://github.com/matrix-org/matrix-android-sdk
We'd like to invite you to join #matrix:matrix.org (via http://matrix.org/alpha), run a homeserver, take a look at the Matrix spec at We'd like to invite you to join #matrix:matrix.org (via
http://matrix.org/docs/spec, experiment with the APIs and the demo https://matrix.org/beta), run a homeserver, take a look at the Matrix spec at
clients, and report any bugs via http://matrix.org/jira. https://matrix.org/docs/spec and API docs at https://matrix.org/docs/api,
experiment with the APIs and the demo clients, and report any bugs via
https://matrix.org/jira.
Thanks for using Matrix! Thanks for using Matrix!
[1] End-to-end encryption is currently in development [1] End-to-end encryption is currently in development
Homeserver Installation Synapse Installation
======================= ====================
Synapse is the reference python/twisted Matrix homeserver implementation.
System requirements: System requirements:
- POSIX-compliant system (tested on Linux & OSX) - POSIX-compliant system (tested on Linux & OS X)
- Python 2.7 - Python 2.7
Synapse is written in python but some of the libraries is uses are written in Synapse is written in python but some of the libraries is uses are written in
@ -118,6 +124,9 @@ To install the synapse homeserver run::
This installs synapse, along with the libraries it uses, into a virtual This installs synapse, along with the libraries it uses, into a virtual
environment under ``~/.synapse``. environment under ``~/.synapse``.
Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
To set up your homeserver, run (in your virtualenv, as before):: To set up your homeserver, run (in your virtualenv, as before)::
$ cd ~/.synapse $ cd ~/.synapse
@ -128,8 +137,18 @@ To set up your homeserver, run (in your virtualenv, as before)::
Substituting your host and domain name as appropriate. Substituting your host and domain name as appropriate.
This will generate you a config file that you can then customise, but it will
also generate a set of keys for you. These keys will allow your Home Server to
identify itself to other Home Servers, so don't lose or delete them. It would be
wise to back them up somewhere safe. If, for whatever reason, you do need to
change your Home Server's keys, you may find that other Home Servers have the
old key cached. If you update the signing key, you should change the name of the
key in the <server name>.signing.key file (the second word, which by default is
, 'auto') to something different.
By default, registration of new users is disabled. You can either enable By default, registration of new users is disabled. You can either enable
registration in the config (it is then recommended to also set up CAPTCHA), or registration in the config by specifying ``enable_registration: true``
(it is then recommended to also set up CAPTCHA), or
you can use the command line to register new users:: you can use the command line to register new users::
$ source ~/.synapse/bin/activate $ source ~/.synapse/bin/activate
@ -142,36 +161,51 @@ you can use the command line to register new users::
For reliable VoIP calls to be routed via this homeserver, you MUST configure For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See docs/turn-howto.rst for details. a TURN server. See docs/turn-howto.rst for details.
Troubleshooting Installation Using PostgreSQL
---------------------------- ================
Synapse requires pip 1.7 or later, so if your OS provides too old a version and As of Synapse 0.9, `PostgreSQL <http://www.postgresql.org>`_ is supported as an
you get errors about ``error: no such option: --process-dependency-links`` you alternative to the `SQLite <http://sqlite.org/>`_ database that Synapse has
may need to manually upgrade it:: traditionally used for convenience and simplicity.
$ sudo pip install --upgrade pip The advantages of Postgres include:
If pip crashes mid-installation for reason (e.g. lost terminal), pip may * significant performance improvements due to the superior threading and
refuse to run until you remove the temporary installation directory it caching model, smarter query optimiser
created. To reset the installation:: * allowing the DB to be run on separate hardware
* allowing basic active/backup high-availability with a "hot spare" synapse
pointing at the same DB master, as well as enabling DB replication in
synapse itself.
The only disadvantage is that the code is relatively new as of April 2015 and
may have a few regressions relative to SQLite.
$ rm -rf /tmp/pip_install_matrix For information on how to install and use PostgreSQL, please see
`docs/postgres.rst <docs/postgres.rst>`_.
pip seems to leak *lots* of memory during installation. For instance, a Linux Running Synapse
host with 512MB of RAM may run out of memory whilst installing Twisted. If this ===============
happens, you will have to individually install the dependencies which are
failing, e.g.::
$ pip install twisted To actually run your new homeserver, pick a working directory for Synapse to run
(e.g. ``~/.synapse``), and::
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you $ cd ~/.synapse
will need to export CFLAGS=-Qunused-arguments. $ source ./bin/activate
$ synctl start
Platform Specific Instructions
==============================
ArchLinux ArchLinux
--------- ---------
Installation on ArchLinux may encounter a few hiccups as Arch defaults to The quickest way to get up and running with ArchLinux is probably with Ivan
python 3, but synapse currently assumes python 2.7 by default. Shapovalov's AUR package from
https://aur.archlinux.org/packages/matrix-synapse/, which should pull in all
the necessary dependencies.
Alternatively, to install using pip a few changes may be needed as ArchLinux
defaults to python 3, but synapse currently assumes python 2.7 by default:
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 ):: pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
@ -191,7 +225,7 @@ installing under virtualenv)::
$ sudo pip2.7 uninstall py-bcrypt $ sudo pip2.7 uninstall py-bcrypt
$ sudo pip2.7 install py-bcrypt $ sudo pip2.7 install py-bcrypt
During setup of homeserver you need to call python2.7 directly again:: During setup of Synapse you need to call python2.7 directly again::
$ cd ~/.synapse $ cd ~/.synapse
$ python2.7 -m synapse.app.homeserver \ $ python2.7 -m synapse.app.homeserver \
@ -232,15 +266,33 @@ Troubleshooting:
you do, you may need to create a symlink to ``libsodium.a`` so ``ld`` can find you do, you may need to create a symlink to ``libsodium.a`` so ``ld`` can find
it: ``ln -s /usr/local/lib/libsodium.a /usr/lib/libsodium.a`` it: ``ln -s /usr/local/lib/libsodium.a /usr/lib/libsodium.a``
Running Your Homeserver Troubleshooting
======================= ===============
To actually run your new homeserver, pick a working directory for Synapse to run Troubleshooting Installation
(e.g. ``~/.synapse``), and:: ----------------------------
$ cd ~/.synapse Synapse requires pip 1.7 or later, so if your OS provides too old a version and
$ source ./bin/activate you get errors about ``error: no such option: --process-dependency-links`` you
$ synctl start may need to manually upgrade it::
$ sudo pip install --upgrade pip
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it
created. To reset the installation::
$ rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.::
$ pip install twisted
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments.
Troubleshooting Running Troubleshooting Running
----------------------- -----------------------
@ -261,25 +313,25 @@ fix try re-installing from PyPI or directly from
$ pip install --user https://github.com/pyca/pynacl/tarball/master $ pip install --user https://github.com/pyca/pynacl/tarball/master
ArchLinux ArchLinux
--------- ~~~~~~~~~
If running `$ synctl start` fails with 'returned non-zero exit status 1', If running `$ synctl start` fails with 'returned non-zero exit status 1',
you will need to explicitly call Python2.7 - either running as:: you will need to explicitly call Python2.7 - either running as::
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml --pid-file homeserver.pid $ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
...or by editing synctl with the correct python executable. ...or by editing synctl with the correct python executable.
Homeserver Development Synapse Development
====================== ===================
To check out a homeserver for development, clone the git repo into a working To check out a synapse for development, clone the git repo into a working
directory of your choice:: directory of your choice::
$ git clone https://github.com/matrix-org/synapse.git $ git clone https://github.com/matrix-org/synapse.git
$ cd synapse $ cd synapse
The homeserver has a number of external dependencies, that are easiest Synapse has a number of external dependencies, that are easiest
to install using pip and a virtualenv:: to install using pip and a virtualenv::
$ virtualenv env $ virtualenv env
@ -290,7 +342,7 @@ to install using pip and a virtualenv::
This will run a process of downloading and installing all the needed This will run a process of downloading and installing all the needed
dependencies into a virtual env. dependencies into a virtual env.
Once this is done, you may wish to run the homeserver's unit tests, to Once this is done, you may wish to run Synapse's unit tests, to
check that everything is installed as it should be:: check that everything is installed as it should be::
$ python setup.py test $ python setup.py test
@ -302,10 +354,10 @@ This should end with a 'PASSED' result::
PASSED (successes=143) PASSED (successes=143)
Upgrading an existing homeserver Upgrading an existing Synapse
================================ =============================
IMPORTANT: Before upgrading an existing homeserver to a new version, please IMPORTANT: Before upgrading an existing synapse to a new version, please
refer to UPGRADE.rst for any additional instructions. refer to UPGRADE.rst for any additional instructions.
Otherwise, simply re-install the new codebase over the current one - e.g. Otherwise, simply re-install the new codebase over the current one - e.g.
@ -348,7 +400,7 @@ and port where the server is running. (At the current time synapse does not
support clustering multiple servers into a single logical homeserver). The DNS support clustering multiple servers into a single logical homeserver). The DNS
record would then look something like:: record would then look something like::
$ dig -t srv _matrix._tcp.machine.my.domaine.name $ dig -t srv _matrix._tcp.machine.my.domain.name
_matrix._tcp IN SRV 10 0 8448 machine.my.domain.name. _matrix._tcp IN SRV 10 0 8448 machine.my.domain.name.
@ -357,7 +409,6 @@ SRV record, as that is the name other machines will expect it to have::
$ python -m synapse.app.homeserver \ $ python -m synapse.app.homeserver \
--server-name YOURDOMAIN \ --server-name YOURDOMAIN \
--bind-port 8448 \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
$ python -m synapse.app.homeserver --config-path homeserver.yaml $ python -m synapse.app.homeserver --config-path homeserver.yaml
@ -366,12 +417,8 @@ SRV record, as that is the name other machines will expect it to have::
You may additionally want to pass one or more "-v" options, in order to You may additionally want to pass one or more "-v" options, in order to
increase the verbosity of logging output; at least for initial testing. increase the verbosity of logging output; at least for initial testing.
For the initial alpha release, the homeserver is not speaking TLS for Running a Demo Federation of Synapses
either client-server or server-server traffic for ease of debugging. We have -------------------------------------
also not spent any time yet getting the homeserver to run behind loadbalancers.
Running a Demo Federation of Homeservers
----------------------------------------
If you want to get up and running quickly with a trio of homeservers in a If you want to get up and running quickly with a trio of homeservers in a
private federation (``localhost:8080``, ``localhost:8081`` and private federation (``localhost:8080``, ``localhost:8081`` and
@ -406,7 +453,10 @@ account. Your name will take the form of::
Specify your desired localpart in the topmost box of the "Register for an Specify your desired localpart in the topmost box of the "Register for an
account" form, and click the "Register" button. Hostnames can contain ports if account" form, and click the "Register" button. Hostnames can contain ports if
required due to lack of SRV records (e.g. @matthew:localhost:8448 on an required due to lack of SRV records (e.g. @matthew:localhost:8448 on an
internal synapse sandbox running on localhost) internal synapse sandbox running on localhost).
If registration fails, you may need to enable it in the homeserver (see
`Synapse Installation`_ above)
Logging In To An Existing Account Logging In To An Existing Account
@ -432,7 +482,7 @@ track 3PID logins and publish end-user public keys.
It's currently early days for identity servers as Matrix is not yet using 3PIDs It's currently early days for identity servers as Matrix is not yet using 3PIDs
as the primary means of identity and E2E encryption is not complete. As such, as the primary means of identity and E2E encryption is not complete. As such,
we are running a single identity server (http://matrix.org:8090) at the current we are running a single identity server (https://matrix.org) at the current
time. time.

View File

@ -1,3 +1,37 @@
Upgrading to v0.x.x
===================
Application services have had a breaking API change in this version.
They can no longer register themselves with a home server using the AS HTTP API. This
decision was made because a compromised application service with free reign to register
any regex in effect grants full read/write access to the home server if a regex of ``.*``
is used. An attack where a compromised AS re-registers itself with ``.*`` was deemed too
big of a security risk to ignore, and so the ability to register with the HS remotely has
been removed.
It has been replaced by specifying a list of application service registrations in
``homeserver.yaml``::
app_service_config_files: ["registration-01.yaml", "registration-02.yaml"]
Where ``registration-01.yaml`` looks like::
url: <String> # e.g. "https://my.application.service.com"
as_token: <String>
hs_token: <String>
sender_localpart: <String> # This is a new field which denotes the user_id localpart when using the AS token
namespaces:
users:
- exclusive: <Boolean>
regex: <String> # e.g. "@prefix_.*"
aliases:
- exclusive: <Boolean>
regex: <String>
rooms:
- exclusive: <Boolean>
regex: <String>
Upgrading to v0.8.0 Upgrading to v0.8.0
=================== ===================

93
contrib/scripts/kick_users.py Executable file
View File

@ -0,0 +1,93 @@
#!/usr/bin/env python
from argparse import ArgumentParser
import json
import requests
import sys
import urllib
def _mkurl(template, kws):
for key in kws:
template = template.replace(key, kws[key])
return template
def main(hs, room_id, access_token, user_id_prefix, why):
if not why:
why = "Automated kick."
print "Kicking members on %s in room %s matching %s" % (hs, room_id, user_id_prefix)
room_state_url = _mkurl(
"$HS/_matrix/client/api/v1/rooms/$ROOM/state?access_token=$TOKEN",
{
"$HS": hs,
"$ROOM": room_id,
"$TOKEN": access_token
}
)
print "Getting room state => %s" % room_state_url
res = requests.get(room_state_url)
print "HTTP %s" % res.status_code
state_events = res.json()
if "error" in state_events:
print "FATAL"
print state_events
return
kick_list = []
room_name = room_id
for event in state_events:
if not event["type"] == "m.room.member":
if event["type"] == "m.room.name":
room_name = event["content"].get("name")
continue
if not event["content"].get("membership") == "join":
continue
if event["state_key"].startswith(user_id_prefix):
kick_list.append(event["state_key"])
if len(kick_list) == 0:
print "No user IDs match the prefix '%s'" % user_id_prefix
return
print "The following user IDs will be kicked from %s" % room_name
for uid in kick_list:
print uid
doit = raw_input("Continue? [Y]es\n")
if len(doit) > 0 and doit.lower() == 'y':
print "Kicking members..."
# encode them all
kick_list = [urllib.quote(uid) for uid in kick_list]
for uid in kick_list:
kick_url = _mkurl(
"$HS/_matrix/client/api/v1/rooms/$ROOM/state/m.room.member/$UID?access_token=$TOKEN",
{
"$HS": hs,
"$UID": uid,
"$ROOM": room_id,
"$TOKEN": access_token
}
)
kick_body = {
"membership": "leave",
"reason": why
}
print "Kicking %s" % uid
res = requests.put(kick_url, data=json.dumps(kick_body))
if res.status_code != 200:
print "ERROR: HTTP %s" % res.status_code
if res.json().get("error"):
print "ERROR: JSON %s" % res.json()
if __name__ == "__main__":
parser = ArgumentParser("Kick members in a room matching a certain user ID prefix.")
parser.add_argument("-u","--user-id",help="The user ID prefix e.g. '@irc_'")
parser.add_argument("-t","--token",help="Your access_token")
parser.add_argument("-r","--room",help="The room ID to kick members in")
parser.add_argument("-s","--homeserver",help="The base HS url e.g. http://matrix.org")
parser.add_argument("-w","--why",help="Reason for the kick. Optional.")
args = parser.parse_args()
if not args.room or not args.token or not args.user_id or not args.homeserver:
parser.print_help()
sys.exit(1)
else:
main(args.homeserver, args.room, args.token, args.user_id, args.why)

View File

@ -0,0 +1,23 @@
version: 1
# In systemd's journal, loglevel is implicitly stored, so let's omit it
# from the message text.
formatters:
journal_fmt:
format: '%(name)s: [%(request)s] %(message)s'
filters:
context:
(): synapse.util.logcontext.LoggingContextFilter
request: ""
handlers:
journal:
class: systemd.journal.JournalHandler
formatter: journal_fmt
filters: [context]
SYSLOG_IDENTIFIER: synapse
root:
level: INFO
handlers: [journal]

View File

@ -0,0 +1,16 @@
# This assumes that Synapse has been installed as a system package
# (e.g. https://aur.archlinux.org/packages/matrix-synapse/ for ArchLinux)
# rather than in a user home directory or similar under virtualenv.
[Unit]
Description=Synapse Matrix homeserver
[Service]
Type=simple
User=synapse
Group=synapse
WorkingDirectory=/var/lib/synapse
ExecStart=/usr/bin/python2.7 -m synapse.app.homeserver --config-path=/etc/synapse/homeserver.yaml --log-config=/etc/synapse/log_config.yaml
[Install]
WantedBy=multi-user.target

View File

@ -7,6 +7,9 @@ matrix:
matrix-bot: matrix-bot:
user_id: '@vertobot:matrix.org' user_id: '@vertobot:matrix.org'
password: '' password: ''
domain: 'matrix.org"
as_url: 'http://localhost:8009'
as_token: 'vertobot123'
verto-bot: verto-bot:
host: webrtc.freeswitch.org host: webrtc.freeswitch.org

View File

@ -16,29 +16,30 @@ if [ $# -eq 1 ]; then
fi fi
fi fi
export PYTHONPATH=$(readlink -f $(pwd))
echo $PYTHONPATH
for port in 8080 8081 8082; do for port in 8080 8081 8082; do
echo "Starting server on port $port... " echo "Starting server on port $port... "
https_port=$((port + 400)) https_port=$((port + 400))
mkdir -p demo/$port
pushd demo/$port
#rm $DIR/etc/$port.config
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--generate-config \ --generate-config \
--config-path "demo/etc/$port.config" \
-p "$https_port" \
--unsecure-port "$port" \
-H "localhost:$https_port" \ -H "localhost:$https_port" \
-f "$DIR/$port.log" \ --config-path "$DIR/etc/$port.config" \
-d "$DIR/$port.db" \
-D --pid-file "$DIR/$port.pid" \
--manhole $((port + 1000)) \
--tls-dh-params-path "demo/demo.tls.dh" \
--media-store-path "demo/media_store.$port" \
$PARAMS $SYNAPSE_PARAMS \
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--config-path "demo/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \
-D \
-vv \ -vv \
popd
done done
cd "$CWD" cd "$CWD"

31
docs/CAPTCHA_SETUP Normal file
View File

@ -0,0 +1,31 @@
Captcha can be enabled for this home server. This file explains how to do that.
The captcha mechanism used is Google's ReCaptcha. This requires API keys from Google.
Getting keys
------------
Requires a public/private key pair from:
https://developers.google.com/recaptcha/
Setting ReCaptcha Keys
----------------------
The keys are a config option on the home server config. If they are not
visible, you can generate them via --generate-config. Set the following value:
recaptcha_public_key: YOUR_PUBLIC_KEY
recaptcha_private_key: YOUR_PRIVATE_KEY
In addition, you MUST enable captchas via:
enable_registration_captcha: true
Configuring IP used for auth
----------------------------
The ReCaptcha API requires that the IP address of the user who solved the
captcha is sent. If the client is connecting through a proxy or load balancer,
it may be required to use the X-Forwarded-For (XFF) header instead of the origin
IP address. This can be configured as an option on the home server like so:
captcha_ip_origin_is_x_forwarded: true

View File

@ -0,0 +1,36 @@
Registering an Application Service
==================================
The registration of new application services depends on the homeserver used.
In synapse, you need to create a new configuration file for your AS and add it
to the list specified under the ``app_service_config_files`` config
option in your synapse config.
For example:
.. code-block:: yaml
app_service_config_files:
- /home/matrix/.synapse/<your-AS>.yaml
The format of the AS configuration file is as follows:
.. code-block:: yaml
url: <base url of AS>
as_token: <token AS will add to requests to HS>
hs_token: <token HS will ad to requests to AS>
sender_localpart: <localpart of AS user>
namespaces:
users: # List of users we're interested in
- exclusive: <bool>
regex: <regex>
- ...
aliases: [] # List of aliases we're interested in
rooms: [] # List of room ids we're interested in
See the spec_ for further details on how application services work.
.. _spec: https://github.com/matrix-org/matrix-doc/blob/master/specification/25_application_service_api.rst#application-service-api

50
docs/metrics-howto.rst Normal file
View File

@ -0,0 +1,50 @@
How to monitor Synapse metrics using Prometheus
===============================================
1: Install prometheus:
Follow instructions at http://prometheus.io/docs/introduction/install/
2: Enable synapse metrics:
Simply setting a (local) port number will enable it. Pick a port.
prometheus itself defaults to 9090, so starting just above that for
locally monitored services seems reasonable. E.g. 9092:
Add to homeserver.yaml
metrics_port: 9092
Restart synapse
3: Check out synapse-prometheus-config
https://github.com/matrix-org/synapse-prometheus-config
4: Add ``synapse.html`` and ``synapse.rules``
The ``.html`` file needs to appear in prometheus's ``consoles`` directory,
and the ``.rules`` file needs to be invoked somewhere in the main config
file. A symlink to each from the git checkout into the prometheus directory
might be easiest to ensure ``git pull`` keeps it updated.
5: Add a prometheus target for synapse
This is easiest if prometheus runs on the same machine as synapse, as it can
then just use localhost::
global: {
rule_file: "synapse.rules"
}
job: {
name: "synapse"
target_group: {
target: "http://localhost:9092/"
}
}
6: Start prometheus::
./prometheus -config.file=prometheus.conf
7: Wait a few seconds for it to start and perform the first scrape,
then visit the console:
http://server-where-prometheus-runs:9090/consoles/synapse.html

110
docs/postgres.rst Normal file
View File

@ -0,0 +1,110 @@
Using Postgres
--------------
Set up database
===============
The PostgreSQL database used *must* have the correct encoding set, otherwise
would not be able to store UTF8 strings. To create a database with the correct
encoding use, e.g.::
CREATE DATABASE synapse
ENCODING 'UTF8'
LC_COLLATE='C'
LC_CTYPE='C'
template=template0
OWNER synapse_user;
This would create an appropriate database named ``synapse`` owned by the
``synapse_user`` user (which must already exist).
Set up client
=============
Postgres support depends on the postgres python connector ``psycopg2``. In the
virtual env::
sudo apt-get install libpq-dev
pip install psycopg2
Synapse config
==============
When you are ready to start using PostgreSQL, add the following line to your
config file::
database:
name: psycopg2
args:
user: <user>
password: <pass>
database: <db>
host: <host>
cp_min: 5
cp_max: 10
All key, values in ``args`` are passed to the ``psycopg2.connect(..)``
function, except keys beginning with ``cp_``, which are consumed by the twisted
adbapi connection pool.
Porting from SQLite
===================
Overview
~~~~~~~~
The script ``port_from_sqlite_to_postgres.py`` allows porting an existing
synapse server backed by SQLite to using PostgreSQL. This is done in as a two
phase process:
1. Copy the existing SQLite database to a separate location (while the server
is down) and running the port script against that offline database.
2. Shut down the server. Rerun the port script to port any data that has come
in since taking the first snapshot. Restart server against the PostgreSQL
database.
The port script is designed to be run repeatedly against newer snapshots of the
SQLite database file. This makes it safe to repeat step 1 if there was a delay
between taking the previous snapshot and being ready to do step 2.
It is safe to at any time kill the port script and restart it.
Using the port script
~~~~~~~~~~~~~~~~~~~~~
Firstly, shut down the currently running synapse server and copy its database
file (typically ``homeserver.db``) to another location. Once the copy is
complete, restart synapse. For instance::
./synctl stop
cp homeserver.db homeserver.db.snapshot
./synctl start
Assuming your database config file (as described in the section *Synapse
config*) is named ``database_config.yaml`` and the SQLite snapshot is at
``homeserver.db.snapshot`` then simply run::
python scripts/port_from_sqlite_to_postgres.py \
--sqlite-database homeserver.db.snapshot \
--postgres-config database_config.yaml
The flag ``--curses`` displays a coloured curses progress UI.
If the script took a long time to complete, or time has otherwise passed since
the original snapshot was taken, repeat the previous steps with a newer
snapshot.
To complete the conversion shut down the synapse server and run the port
script one last time, e.g. if the SQLite database is at ``homeserver.db``
run::
python scripts/port_from_sqlite_to_postgres.py \
--sqlite-database homeserver.db \
--postgres-config database_config.yaml
Once that has completed, change the synapse config to point at the PostgreSQL
database configuration file using the ``database_config`` parameter (see
`Synapse Config`_) and restart synapse. Synapse should now be running against
PostgreSQL.

View File

@ -0,0 +1,759 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
from twisted.enterprise import adbapi
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
import argparse
import curses
import logging
import sys
import time
import traceback
import yaml
logger = logging.getLogger("port_from_sqlite_to_postgres")
BOOLEAN_COLUMNS = {
"events": ["processed", "outlier"],
"rooms": ["is_public"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
}
APPEND_ONLY_TABLES = [
"event_content_hashes",
"event_reference_hashes",
"event_signatures",
"event_edge_hashes",
"events",
"event_json",
"state_events",
"room_memberships",
"feedback",
"topics",
"room_names",
"rooms",
"local_media_repository",
"local_media_repository_thumbnails",
"remote_media_cache",
"remote_media_cache_thumbnails",
"redactions",
"event_edges",
"event_auth",
"received_transactions",
"sent_transactions",
"transaction_id_to_pdu",
"users",
"state_groups",
"state_groups_state",
"event_to_state_groups",
"rejections",
]
end_error_exec_info = None
class Store(object):
"""This object is used to pull out some of the convenience API from the
Storage layer.
*All* database interactions should go through this object.
"""
def __init__(self, db_pool, engine):
self.db_pool = db_pool
self.database_engine = engine
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
_simple_insert = SQLBaseStore.__dict__["_simple_insert"]
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
_simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
_simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
_execute_and_decode = SQLBaseStore.__dict__["_execute_and_decode"]
def runInteraction(self, desc, func, *args, **kwargs):
def r(conn):
try:
i = 0
N = 5
while True:
try:
txn = conn.cursor()
return func(
LoggingTransaction(txn, desc, self.database_engine),
*args, **kwargs
)
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
if i < N:
i += 1
conn.rollback()
continue
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", desc, e)
raise
return self.db_pool.runWithConnection(r)
def execute(self, f, *args, **kwargs):
return self.runInteraction(f.__name__, f, *args, **kwargs)
def execute_sql(self, sql, *args):
def r(txn):
txn.execute(sql, args)
return txn.fetchall()
return self.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in headers),
", ".join("%s" for _ in headers)
)
try:
txn.executemany(sql, rows)
except:
logger.exception(
"Failed to insert: %s",
table,
)
raise
class Porter(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
@defer.inlineCallbacks
def setup_table(self, table):
if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting.
next_chunk = yield self.postgres_store._simple_select_one_onecol(
table="port_from_sqlite3",
keyvalues={"table_name": table},
retcol="rowid",
allow_none=True,
)
total_to_port = None
if next_chunk is None:
if table == "sent_transactions":
next_chunk, already_ported, total_to_port = (
yield self._setup_sent_transactions()
)
else:
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={"table_name": table, "rowid": 1}
)
next_chunk = 1
already_ported = 0
if total_to_port is None:
already_ported, total_to_port = yield self._get_total_count_to_port(
table, next_chunk
)
else:
def delete_all(txn):
txn.execute(
"DELETE FROM port_from_sqlite3 WHERE table_name = %s",
(table,)
)
txn.execute("TRUNCATE %s CASCADE" % (table,))
yield self.postgres_store.execute(delete_all)
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={"table_name": table, "rowid": 0}
)
next_chunk = 1
already_ported, total_to_port = yield self._get_total_count_to_port(
table, next_chunk
)
defer.returnValue((table, already_ported, total_to_port, next_chunk))
@defer.inlineCallbacks
def handle_table(self, table, postgres_size, table_size, next_chunk):
if not table_size:
return
self.progress.add_table(table, postgres_size, table_size)
select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
% (table,)
)
while True:
def r(txn):
txn.execute(select, (next_chunk, self.batch_size,))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
return headers, rows
headers, rows = yield self.sqlite_store.runInteraction("select", r)
if rows:
next_chunk = rows[-1][0] + 1
self._convert_rows(table, headers, rows)
def insert(txn):
self.postgres_store.insert_many_txn(
txn, table, headers[1:], rows
)
self.postgres_store._simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
updatevalues={"rowid": next_chunk},
)
yield self.postgres_store.execute(insert)
postgres_size += len(rows)
self.progress.update(table, postgres_size)
else:
return
def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect(
**{
k: v for k, v in db_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
database_engine.prepare_database(db_conn)
db_conn.commit()
@defer.inlineCallbacks
def run(self):
try:
sqlite_db_pool = adbapi.ConnectionPool(
self.sqlite_config["name"],
**self.sqlite_config["args"]
)
postgres_db_pool = adbapi.ConnectionPool(
self.postgres_config["name"],
**self.postgres_config["args"]
)
sqlite_engine = create_engine("sqlite3")
postgres_engine = create_engine("psycopg2")
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine)
yield self.postgres_store.execute(
postgres_engine.check_database
)
# Step 1. Set up databases.
self.progress.set_state("Preparing SQLite3")
self.setup_db(sqlite_config, sqlite_engine)
self.progress.set_state("Preparing PostgreSQL")
self.setup_db(postgres_config, postgres_engine)
# Step 2. Get tables.
self.progress.set_state("Fetching tables")
sqlite_tables = yield self.sqlite_store._simple_select_onecol(
table="sqlite_master",
keyvalues={
"type": "table",
},
retcol="name",
)
postgres_tables = yield self.postgres_store._simple_select_onecol(
table="information_schema.tables",
keyvalues={
"table_schema": "public",
},
retcol="distinct table_name",
)
tables = set(sqlite_tables) & set(postgres_tables)
self.progress.set_state("Creating tables")
logger.info("Found %d tables", len(tables))
def create_port_table(txn):
txn.execute(
"CREATE TABLE port_from_sqlite3 ("
" table_name varchar(100) NOT NULL UNIQUE,"
" rowid bigint NOT NULL"
")"
)
try:
yield self.postgres_store.runInteraction(
"create_port_table", create_port_table
)
except Exception as e:
logger.info("Failed to create port table: %s", e)
self.progress.set_state("Setting up")
# Set up tables.
setup_res = yield defer.gatherResults(
[
self.setup_table(table)
for table in tables
if table not in ["schema_version", "applied_schema_deltas"]
and not table.startswith("sqlite_")
],
consumeErrors=True,
)
# Process tables.
yield defer.gatherResults(
[
self.handle_table(*res)
for res in setup_res
],
consumeErrors=True,
)
self.progress.done()
except:
global end_error_exec_info
end_error_exec_info = sys.exc_info()
logger.exception("")
finally:
reactor.stop()
def _convert_rows(self, table, headers, rows):
bool_col_names = BOOLEAN_COLUMNS.get(table, [])
bool_cols = [
i for i, h in enumerate(headers) if h in bool_col_names
]
def conv(j, col):
if j in bool_cols:
return bool(col)
return col
for i, row in enumerate(rows):
rows[i] = tuple(
self.postgres_store.database_engine.encode_parameter(
conv(j, col)
)
for j, col in enumerate(row)
if j > 0
)
@defer.inlineCallbacks
def _setup_sent_transactions(self):
# Only save things from the last day
yesterday = int(time.time()*1000) - 86400000
# And save the max transaction id from each destination
select = (
"SELECT rowid, * FROM sent_transactions WHERE rowid IN ("
"SELECT max(rowid) FROM sent_transactions"
" GROUP BY destination"
")"
)
def r(txn):
txn.execute(select)
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
ts_ind = headers.index('ts')
return headers, [r for r in rows if r[ts_ind] < yesterday]
headers, rows = yield self.sqlite_store.runInteraction(
"select", r,
)
self._convert_rows("sent_transactions", headers, rows)
inserted_rows = len(rows)
max_inserted_rowid = max(r[0] for r in rows)
def insert(txn):
self.postgres_store.insert_many_txn(
txn, "sent_transactions", headers[1:], rows
)
yield self.postgres_store.execute(insert)
def get_start_id(txn):
txn.execute(
"SELECT rowid FROM sent_transactions WHERE ts >= ?"
" ORDER BY rowid ASC LIMIT 1",
(yesterday,)
)
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
return 1
next_chunk = yield self.sqlite_store.execute(get_start_id)
next_chunk = max(max_inserted_rowid + 1, next_chunk)
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={"table_name": "sent_transactions", "rowid": next_chunk}
)
def get_sent_table_size(txn):
txn.execute(
"SELECT count(*) FROM sent_transactions"
" WHERE ts >= ?",
(yesterday,)
)
size, = txn.fetchone()
return int(size)
remaining_count = yield self.sqlite_store.execute(
get_sent_table_size
)
total_count = remaining_count + inserted_rows
defer.returnValue((next_chunk, inserted_rows, total_count))
@defer.inlineCallbacks
def _get_remaining_count_to_port(self, table, next_chunk):
rows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
next_chunk,
)
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
def _get_already_ported_count(self, table):
rows = yield self.postgres_store.execute_sql(
"SELECT count(*) FROM %s" % (table,),
)
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
def _get_total_count_to_port(self, table, next_chunk):
remaining, done = yield defer.gatherResults(
[
self._get_remaining_count_to_port(table, next_chunk),
self._get_already_ported_count(table),
],
consumeErrors=True,
)
remaining = int(remaining) if remaining else 0
done = int(done) if done else 0
defer.returnValue((done, remaining + done))
##############################################
###### The following is simply UI stuff ######
##############################################
class Progress(object):
"""Used to report progress of the port
"""
def __init__(self):
self.tables = {}
self.start_time = int(time.time())
def add_table(self, table, cur, size):
self.tables[table] = {
"start": cur,
"num_done": cur,
"total": size,
"perc": int(cur * 100 / size),
}
def update(self, table, num_done):
data = self.tables[table]
data["num_done"] = num_done
data["perc"] = int(num_done * 100 / data["total"])
def done(self):
pass
class CursesProgress(Progress):
"""Reports progress to a curses window
"""
def __init__(self, stdscr):
self.stdscr = stdscr
curses.use_default_colors()
curses.curs_set(0)
curses.init_pair(1, curses.COLOR_RED, -1)
curses.init_pair(2, curses.COLOR_GREEN, -1)
self.last_update = 0
self.finished = False
self.total_processed = 0
self.total_remaining = 0
super(CursesProgress, self).__init__()
def update(self, table, num_done):
super(CursesProgress, self).update(table, num_done)
self.total_processed = 0
self.total_remaining = 0
for table, data in self.tables.items():
self.total_processed += data["num_done"] - data["start"]
self.total_remaining += data["total"] - data["num_done"]
self.render()
def render(self, force=False):
now = time.time()
if not force and now - self.last_update < 0.2:
# reactor.callLater(1, self.render)
return
self.stdscr.clear()
rows, cols = self.stdscr.getmaxyx()
duration = int(now) - int(self.start_time)
minutes, seconds = divmod(duration, 60)
duration_str = '%02dm %02ds' % (minutes, seconds,)
if self.finished:
status = "Time spent: %s (Done!)" % (duration_str,)
else:
if self.total_processed > 0:
left = float(self.total_remaining) / self.total_processed
est_remaining = (int(now) - self.start_time) * left
est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
else:
est_remaining_str = "Unknown"
status = (
"Time spent: %s (est. remaining: %s)"
% (duration_str, est_remaining_str,)
)
self.stdscr.addstr(
0, 0,
status,
curses.A_BOLD,
)
max_len = max([len(t) for t in self.tables.keys()])
left_margin = 5
middle_space = 1
items = self.tables.items()
items.sort(
key=lambda i: (i[1]["perc"], i[0]),
)
for i, (table, data) in enumerate(items):
if i + 2 >= rows:
break
perc = data["perc"]
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
self.stdscr.addstr(
i+2, left_margin + max_len - len(table),
table,
curses.A_BOLD | color,
)
size = 20
progress = "[%s%s]" % (
"#" * int(perc*size/100),
" " * (size - int(perc*size/100)),
)
self.stdscr.addstr(
i+2, left_margin + max_len + middle_space,
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
)
if self.finished:
self.stdscr.addstr(
rows-1, 0,
"Press any key to exit...",
)
self.stdscr.refresh()
self.last_update = time.time()
def done(self):
self.finished = True
self.render(True)
self.stdscr.getch()
def set_state(self, state):
self.stdscr.clear()
self.stdscr.addstr(
0, 0,
state + "...",
curses.A_BOLD,
)
self.stdscr.refresh()
class TerminalProgress(Progress):
"""Just prints progress to the terminal
"""
def update(self, table, num_done):
super(TerminalProgress, self).update(table, num_done)
data = self.tables[table]
print "%s: %d%% (%d/%d)" % (
table, data["perc"],
data["num_done"], data["total"],
)
def set_state(self, state):
print state + "..."
##############################################
##############################################
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="A script to port an existing synapse SQLite database to"
" a new PostgreSQL database."
)
parser.add_argument("-v", action='store_true')
parser.add_argument(
"--sqlite-database", required=True,
help="The snapshot of the SQLite database file. This must not be"
" currently used by a running synapse server"
)
parser.add_argument(
"--postgres-config", type=argparse.FileType('r'), required=True,
help="The database config file for the PostgreSQL database"
)
parser.add_argument(
"--curses", action='store_true',
help="display a curses based progress UI"
)
parser.add_argument(
"--batch-size", type=int, default=1000,
help="The number of rows to select from the SQLite table each"
" iteration [default=1000]",
)
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
}
if args.curses:
logging_config["filename"] = "port-synapse.log"
logging.basicConfig(**logging_config)
sqlite_config = {
"name": "sqlite3",
"args": {
"database": args.sqlite_database,
"cp_min": 1,
"cp_max": 1,
"check_same_thread": False,
},
}
postgres_config = yaml.safe_load(args.postgres_config)
if "name" not in postgres_config:
sys.stderr.write("Malformed database config: no 'name'")
sys.exit(2)
if postgres_config["name"] != "psycopg2":
sys.stderr.write("Database must use 'psycopg2' connector.")
sys.exit(3)
def start(stdscr=None):
if stdscr:
progress = CursesProgress(stdscr)
else:
progress = TerminalProgress()
porter = Porter(
sqlite_config=sqlite_config,
postgres_config=postgres_config,
progress=progress,
batch_size=args.batch_size,
)
reactor.callWhenRunning(porter.run)
reactor.run()
if args.curses:
curses.wrapper(start)
else:
start()
if end_error_exec_info:
exc_type, exc_value, exc_traceback = end_error_exec_info
traceback.print_exception(exc_type, exc_value, exc_traceback)

View File

@ -33,10 +33,9 @@ def request_registration(user, password, server_location, shared_secret):
).hexdigest() ).hexdigest()
data = { data = {
"user": user, "username": user,
"password": password, "password": password,
"mac": mac, "mac": mac,
"type": "org.matrix.login.shared_secret",
} }
server_location = server_location.rstrip("/") server_location = server_location.rstrip("/")
@ -44,7 +43,7 @@ def request_registration(user, password, server_location, shared_secret):
print "Sending registration request..." print "Sending registration request..."
req = urllib2.Request( req = urllib2.Request(
"%s/_matrix/client/api/v1/register" % (server_location,), "%s/_matrix/client/v2_alpha/register" % (server_location,),
data=json.dumps(data), data=json.dumps(data),
headers={'Content-Type': 'application/json'} headers={'Content-Type': 'application/json'}
) )

2
scripts/upgrade_db_to_v0.6.0.py Normal file → Executable file
View File

@ -1,4 +1,4 @@
#!/usr/bin/env python
from synapse.storage import SCHEMA_VERSION, read_schema from synapse.storage import SCHEMA_VERSION, read_schema
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.signatures import SignatureStore from synapse.storage.signatures import SignatureStore

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import glob
import os import os
from setuptools import setup, find_packages from setuptools import setup, find_packages
@ -55,5 +56,5 @@ setup(
include_package_data=True, include_package_data=True,
zip_safe=False, zip_safe=False,
long_description=long_description, long_description=long_description,
scripts=["synctl", "register_new_matrix_user"], scripts=["synctl"] + glob.glob("scripts/*"),
) )

View File

@ -37,9 +37,13 @@ textarea, input {
margin: auto margin: auto
} }
.g-recaptcha div {
margin: auto;
}
#registrationForm { #registrationForm {
text-align: left; text-align: left;
padding: 1em; padding: 5px;
margin-bottom: 40px; margin-bottom: 40px;
display: inline-block; display: inline-block;

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.8.1-r4" __version__ = "0.9.0"

View File

@ -18,9 +18,8 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.types import UserID, ClientInfo from synapse.types import UserID, ClientInfo
import logging import logging
@ -40,6 +39,7 @@ class Auth(object):
self.hs = hs self.hs = hs
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
def check(self, event, auth_events): def check(self, event, auth_events):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
@ -64,7 +64,10 @@ class Auth(object):
if event.type == EventTypes.Aliases: if event.type == EventTypes.Aliases:
return True return True
logger.debug("Auth events: %s", auth_events) logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
allowed = self.is_membership_change_allowed( allowed = self.is_membership_change_allowed(
@ -183,18 +186,10 @@ class Auth(object):
else: else:
join_rule = JoinRules.INVITE join_rule = JoinRules.INVITE
user_level = self._get_power_level_from_event_state( user_level = self._get_user_power_level(event.user_id, auth_events)
event,
event.user_id,
auth_events,
)
ban_level, kick_level, redact_level = ( # FIXME (erikj): What should we do here as the default?
self._get_ops_level_from_event_state( ban_level = self._get_named_level(auth_events, "ban", 50)
event,
auth_events,
)
)
logger.debug( logger.debug(
"is_membership_change_allowed: %s", "is_membership_change_allowed: %s",
@ -210,28 +205,33 @@ class Auth(object):
} }
) )
if ban_level: if Membership.JOIN != membership:
ban_level = int(ban_level) # JOIN is the only action you can perform if you're not in the room
else: if not caller_in_room: # caller isn't joined
ban_level = 50 # FIXME (erikj): What should we do here? raise AuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
if Membership.INVITE == membership: if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently # TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules. # PRIVATE join rules.
# Invites are valid iff caller is in the room and target isn't. # Invites are valid iff caller is in the room and target isn't.
if not caller_in_room: # caller isn't joined if target_banned:
raise AuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
elif target_banned:
raise AuthError( raise AuthError(
403, "%s is banned from the room" % (target_user_id,) 403, "%s is banned from the room" % (target_user_id,)
) )
elif target_in_room: # the target is already in the room. elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." % raise AuthError(403, "%s is already in the room." %
target_user_id) target_user_id)
else:
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You cannot invite user %s." % target_user_id
)
elif Membership.JOIN == membership: elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were: # Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation # invited: They are accepting the invitation
@ -251,21 +251,12 @@ class Auth(object):
raise AuthError(403, "You are not allowed to join this room") raise AuthError(403, "You are not allowed to join this room")
elif Membership.LEAVE == membership: elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks. # TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
if not caller_in_room: # trying to leave a room you aren't joined
raise AuthError(
403,
"%s not in room %s." % (target_user_id, event.room_id,)
)
elif target_banned and user_level < ban_level:
raise AuthError( raise AuthError(
403, "You cannot unban user &s." % (target_user_id,) 403, "You cannot unban user &s." % (target_user_id,)
) )
elif target_user_id != event.user_id: elif target_user_id != event.user_id:
if kick_level: kick_level = self._get_named_level(auth_events, "kick", 50)
kick_level = int(kick_level)
else:
kick_level = 50 # FIXME (erikj): What should we do here?
if user_level < kick_level: if user_level < kick_level:
raise AuthError( raise AuthError(
@ -279,34 +270,42 @@ class Auth(object):
return True return True
def _get_power_level_from_event_state(self, event, user_id, auth_events): def _get_power_level_event(self, auth_events):
key = (EventTypes.PowerLevels, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = auth_events.get(key) return auth_events.get(key)
level = None
def _get_user_power_level(self, user_id, auth_events):
power_level_event = self._get_power_level_event(auth_events)
if power_level_event: if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id) level = power_level_event.content.get("users", {}).get(user_id)
if not level: if not level:
level = power_level_event.content.get("users_default", 0) level = power_level_event.content.get("users_default", 0)
if level is None:
return 0
else:
return int(level)
else: else:
key = (EventTypes.Create, "", ) key = (EventTypes.Create, "", )
create_event = auth_events.get(key) create_event = auth_events.get(key)
if (create_event is not None and if (create_event is not None and
create_event.content["creator"] == user_id): create_event.content["creator"] == user_id):
return 100 return 100
else:
return 0
return level def _get_named_level(self, auth_events, name, default):
power_level_event = self._get_power_level_event(auth_events)
def _get_ops_level_from_event_state(self, event, auth_events): if not power_level_event:
key = (EventTypes.PowerLevels, "", ) return default
power_level_event = auth_events.get(key)
if power_level_event: level = power_level_event.content.get(name, None)
return ( if level is not None:
power_level_event.content.get("ban", 50), return int(level)
power_level_event.content.get("kick", 50), else:
power_level_event.content.get("redact", 50), return default
)
return None, None, None,
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_req(self, request): def get_user_by_req(self, request):
@ -363,7 +362,7 @@ class Auth(object):
default=[""] default=[""]
)[0] )[0]
if user and access_token and ip_addr: if user and access_token and ip_addr:
yield self.store.insert_client_ip( self.store.insert_client_ip(
user=user, user=user,
access_token=access_token, access_token=access_token,
device_id=user_info["device_id"], device_id=user_info["device_id"],
@ -373,7 +372,10 @@ class Auth(object):
defer.returnValue((user, ClientInfo(device_id, token_id))) defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError: except KeyError:
raise AuthError(403, "Missing access token.") raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
errcode=Codes.MISSING_TOKEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_user_by_token(self, token): def get_user_by_token(self, token):
@ -387,21 +389,20 @@ class Auth(object):
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
try: ret = yield self.store.get_user_by_token(token)
ret = yield self.store.get_user_by_token(token) if not ret:
if not ret: raise AuthError(
raise StoreError(400, "Unknown token") self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
user_info = { errcode=Codes.UNKNOWN_TOKEN
"admin": bool(ret.get("admin", False)), )
"device_id": ret.get("device_id"), user_info = {
"user": UserID.from_string(ret.get("name")), "admin": bool(ret.get("admin", False)),
"token_id": ret.get("token_id", None), "device_id": ret.get("device_id"),
} "user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
}
defer.returnValue(user_info) defer.returnValue(user_info)
except StoreError:
raise AuthError(403, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
@ -409,19 +410,22 @@ class Auth(object):
token = request.args["access_token"][0] token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token) service = yield self.store.get_app_service_by_token(token)
if not service: if not service:
raise AuthError(403, "Unrecognised access token.", raise AuthError(
errcode=Codes.UNKNOWN_TOKEN) self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
defer.returnValue(service) defer.returnValue(service)
except KeyError: except KeyError:
raise AuthError(403, "Missing access token.") raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
)
def is_server_admin(self, user): def is_server_admin(self, user):
return self.store.is_server_admin(user) return self.store.is_server_admin(user)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_auth_events(self, builder, context): def add_auth_events(self, builder, context):
yield run_on_reactor()
auth_ids = self.compute_auth_events(builder, context.current_state) auth_ids = self.compute_auth_events(builder, context.current_state)
auth_events_entries = yield self.store.add_event_hashes( auth_events_entries = yield self.store.add_event_hashes(
@ -486,7 +490,7 @@ class Auth(object):
send_level = send_level_event.content.get("events", {}).get( send_level = send_level_event.content.get("events", {}).get(
event.type event.type
) )
if not send_level: if send_level is None:
if hasattr(event, "state_key"): if hasattr(event, "state_key"):
send_level = send_level_event.content.get( send_level = send_level_event.content.get(
"state_default", 50 "state_default", 50
@ -501,16 +505,7 @@ class Auth(object):
else: else:
send_level = 0 send_level = 0
user_level = self._get_power_level_from_event_state( user_level = self._get_user_power_level(event.user_id, auth_events)
event,
event.user_id,
auth_events,
)
if user_level:
user_level = int(user_level)
else:
user_level = 0
if user_level < send_level: if user_level < send_level:
raise AuthError( raise AuthError(
@ -542,16 +537,9 @@ class Auth(object):
return True return True
def _check_redaction(self, event, auth_events): def _check_redaction(self, event, auth_events):
user_level = self._get_power_level_from_event_state( user_level = self._get_user_power_level(event.user_id, auth_events)
event,
event.user_id,
auth_events,
)
_, _, redact_level = self._get_ops_level_from_event_state( redact_level = self._get_named_level(auth_events, "redact", 50)
event,
auth_events,
)
if user_level < redact_level: if user_level < redact_level:
raise AuthError( raise AuthError(
@ -579,11 +567,7 @@ class Auth(object):
if not current_state: if not current_state:
return return
user_level = self._get_power_level_from_event_state( user_level = self._get_user_power_level(event.user_id, auth_events)
event,
event.user_id,
auth_events,
)
# Check other levels: # Check other levels:
levels_to_check = [ levels_to_check = [
@ -592,6 +576,7 @@ class Auth(object):
("ban", []), ("ban", []),
("redact", []), ("redact", []),
("kick", []), ("kick", []),
("invite", []),
] ]
old_list = current_state.content.get("users") old_list = current_state.content.get("users")

View File

@ -59,6 +59,9 @@ class LoginType(object):
EMAIL_URL = u"m.login.email.url" EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity" EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha" RECAPTCHA = u"m.login.recaptcha"
DUMMY = u"m.login.dummy"
# Only for C/S API v1
APPLICATION_SERVICE = u"m.login.application_service" APPLICATION_SERVICE = u"m.login.application_service"
SHARED_SECRET = u"org.matrix.login.shared_secret" SHARED_SECRET = u"org.matrix.login.shared_secret"

View File

@ -31,13 +31,15 @@ class Codes(object):
BAD_PAGINATION = "M_BAD_PAGINATION" BAD_PAGINATION = "M_BAD_PAGINATION"
UNKNOWN = "M_UNKNOWN" UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND" NOT_FOUND = "M_NOT_FOUND"
MISSING_TOKEN = "M_MISSING_TOKEN"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN" UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID" CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
MISSING_PARAM = "M_MISSING_PARAM", MISSING_PARAM = "M_MISSING_PARAM"
TOO_LARGE = "M_TOO_LARGE", TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE" EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):

View File

@ -22,5 +22,6 @@ STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client" WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content" CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/v1" MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1" APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View File

@ -16,14 +16,18 @@
import sys import sys
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
from synapse.python_dependencies import check_requirements
if __name__ == '__main__':
check_requirements()
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import ( from synapse.storage import (
prepare_database, prepare_sqlite3_database, UpgradeDatabaseException, are_all_users_on_domain, UpgradeDatabaseException,
) )
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.python_dependencies import check_requirements
from twisted.internet import reactor from twisted.internet import reactor
from twisted.application import service from twisted.application import service
@ -31,16 +35,17 @@ from twisted.enterprise import adbapi
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site from twisted.web.server import Site
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.appservice.v1 import AppServiceRestResource
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.http.server_key_resource import LocalKey from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import ( from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX, SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, STATIC_PREFIX,
STATIC_PREFIX SERVER_KEY_V2_PREFIX,
) )
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
@ -59,9 +64,9 @@ import os
import re import re
import resource import resource
import subprocess import subprocess
import sqlite3
logger = logging.getLogger(__name__)
logger = logging.getLogger("synapse.app.homeserver")
class SynapseHomeServer(HomeServer): class SynapseHomeServer(HomeServer):
@ -78,9 +83,6 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_federation(self): def build_resource_for_federation(self):
return JsonResource(self) return JsonResource(self)
def build_resource_for_app_services(self):
return AppServiceRestResource(self)
def build_resource_for_web_client(self): def build_resource_for_web_client(self):
import syweb import syweb
syweb_path = os.path.dirname(syweb.__file__) syweb_path = os.path.dirname(syweb.__file__)
@ -101,6 +103,9 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_server_key(self): def build_resource_for_server_key(self):
return LocalKey(self) return LocalKey(self)
def build_resource_for_server_key_v2(self):
return KeyApiV2Resource(self)
def build_resource_for_metrics(self): def build_resource_for_metrics(self):
if self.get_config().enable_metrics: if self.get_config().enable_metrics:
return MetricsResource(self) return MetricsResource(self)
@ -108,13 +113,11 @@ class SynapseHomeServer(HomeServer):
return None return None
def build_db_pool(self): def build_db_pool(self):
name = self.db_config["name"]
return adbapi.ConnectionPool( return adbapi.ConnectionPool(
"sqlite3", self.get_db_name(), name,
check_same_thread=False, **self.db_config.get("args", {})
cp_min=1,
cp_max=1,
cp_openfun=prepare_database, # Prepare the database for each conn
# so that :memory: sqlite works
) )
def create_resource_tree(self, redirect_root_to_web_client): def create_resource_tree(self, redirect_root_to_web_client):
@ -140,8 +143,8 @@ class SynapseHomeServer(HomeServer):
(FEDERATION_PREFIX, self.get_resource_for_federation()), (FEDERATION_PREFIX, self.get_resource_for_federation()),
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()), (CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()), (SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
(SERVER_KEY_V2_PREFIX, self.get_resource_for_server_key_v2()),
(MEDIA_PREFIX, self.get_resource_for_media_repository()), (MEDIA_PREFIX, self.get_resource_for_media_repository()),
(APP_SERVICE_PREFIX, self.get_resource_for_app_services()),
(STATIC_PREFIX, self.get_resource_for_static_content()), (STATIC_PREFIX, self.get_resource_for_static_content()),
] ]
@ -226,7 +229,11 @@ class SynapseHomeServer(HomeServer):
if not config.no_tls and config.bind_port is not None: if not config.no_tls and config.bind_port is not None:
reactor.listenSSL( reactor.listenSSL(
config.bind_port, config.bind_port,
Site(self.root_resource), SynapseSite(
"synapse.access.https",
config,
self.root_resource,
),
self.tls_context_factory, self.tls_context_factory,
interface=config.bind_host interface=config.bind_host
) )
@ -235,7 +242,11 @@ class SynapseHomeServer(HomeServer):
if config.unsecure_port is not None: if config.unsecure_port is not None:
reactor.listenTCP( reactor.listenTCP(
config.unsecure_port, config.unsecure_port,
Site(self.root_resource), SynapseSite(
"synapse.access.http",
config,
self.root_resource,
),
interface=config.bind_host interface=config.bind_host
) )
logger.info("Synapse now listening on port %d", config.unsecure_port) logger.info("Synapse now listening on port %d", config.unsecure_port)
@ -243,10 +254,43 @@ class SynapseHomeServer(HomeServer):
metrics_resource = self.get_resource_for_metrics() metrics_resource = self.get_resource_for_metrics()
if metrics_resource and config.metrics_port is not None: if metrics_resource and config.metrics_port is not None:
reactor.listenTCP( reactor.listenTCP(
config.metrics_port, Site(metrics_resource), interface="127.0.0.1", config.metrics_port,
SynapseSite(
"synapse.access.metrics",
config,
metrics_resource,
),
interface="127.0.0.1",
) )
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port) logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain(
db_conn.cursor(), database_engine, self.hostname
)
if not all_users_native:
quit_with_error(
"Found users in database not native to %s!\n"
"You cannot changed a synapse server_name after it's been configured"
% (self.hostname,)
)
try:
database_engine.check_database(db_conn.cursor())
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
def quit_with_error(error_string):
message_lines = error_string.split("\n")
line_length = max([len(l) for l in message_lines]) + 2
sys.stderr.write("*" * line_length + '\n')
for line in message_lines:
if line.strip():
sys.stderr.write(" %s\n" % (line.strip(),))
sys.stderr.write("*" * line_length + '\n')
sys.exit(1)
def get_version_string(): def get_version_string():
try: try:
@ -358,29 +402,39 @@ def setup(config_options):
tls_context_factory = context_factory.ServerContextFactory(config) tls_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config.database_config["name"])
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer( hs = SynapseHomeServer(
config.server_name, config.server_name,
domain_with_port=domain_with_port, domain_with_port=domain_with_port,
upload_dir=os.path.abspath("uploads"), upload_dir=os.path.abspath("uploads"),
db_name=config.database_path, db_config=config.database_config,
tls_context_factory=tls_context_factory, tls_context_factory=tls_context_factory,
config=config, config=config,
content_addr=config.content_addr, content_addr=config.content_addr,
version_string=version_string, version_string=version_string,
database_engine=database_engine,
) )
hs.create_resource_tree( hs.create_resource_tree(
redirect_root_to_web_client=True, redirect_root_to_web_client=True,
) )
db_name = hs.get_db_name() logger.info("Preparing database: %r...", config.database_config)
logger.info("Preparing database: %s...", db_name)
try: try:
with sqlite3.connect(db_name) as db_conn: db_conn = database_engine.module.connect(
prepare_sqlite3_database(db_conn) **{
prepare_database(db_conn) k: v for k, v in config.database_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
database_engine.prepare_database(db_conn)
hs.run_startup_checks(db_conn, database_engine)
db_conn.commit()
except UpgradeDatabaseException: except UpgradeDatabaseException:
sys.stderr.write( sys.stderr.write(
"\nFailed to upgrade database.\n" "\nFailed to upgrade database.\n"
@ -389,7 +443,7 @@ def setup(config_options):
) )
sys.exit(1) sys.exit(1)
logger.info("Database prepared in %s.", db_name) logger.info("Database prepared in %r.", config.database_config)
if config.manhole: if config.manhole:
f = twisted.manhole.telnet.ShellFactory() f = twisted.manhole.telnet.ShellFactory()
@ -423,6 +477,24 @@ class SynapseService(service.Service):
return self._port.stopListening() return self._port.stopListening()
class SynapseSite(Site):
"""
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
def __init__(self, logger_name, config, resource, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
if config.captcha_ip_origin_is_x_forwarded:
self._log_formatter = proxiedLogFormatter
else:
self._log_formatter = combinedLogFormatter
self.access_logger = logging.getLogger(logger_name)
def log(self, request):
line = self._log_formatter(self._logDateTime, request)
self.access_logger.info(line)
def run(hs): def run(hs):
def in_thread(): def in_thread():

View File

@ -18,15 +18,18 @@ import sys
import os import os
import subprocess import subprocess
import signal import signal
import yaml
SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"] SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
CONFIGFILE = "homeserver.yaml" CONFIGFILE = "homeserver.yaml"
PIDFILE = "homeserver.pid"
GREEN = "\x1b[1;32m" GREEN = "\x1b[1;32m"
NORMAL = "\x1b[m" NORMAL = "\x1b[m"
CONFIG = yaml.load(open(CONFIGFILE))
PIDFILE = CONFIG["pid_file"]
def start(): def start():
if not os.path.exists(CONFIGFILE): if not os.path.exists(CONFIGFILE):
@ -40,7 +43,7 @@ def start():
sys.exit(1) sys.exit(1)
print "Starting ...", print "Starting ...",
args = SYNAPSE args = SYNAPSE
args.extend(["--daemonize", "-c", CONFIGFILE, "--pid-file", PIDFILE]) args.extend(["--daemonize", "-c", CONFIGFILE])
subprocess.check_call(args) subprocess.check_call(args)
print GREEN + "started" + NORMAL print GREEN + "started" + NORMAL

View File

@ -20,6 +20,50 @@ import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ApplicationServiceState(object):
DOWN = "down"
UP = "up"
class AppServiceTransaction(object):
"""Represents an application service transaction."""
def __init__(self, service, id, events):
self.service = service
self.id = id
self.events = events
def send(self, as_api):
"""Sends this transaction using the provided AS API interface.
Args:
as_api(ApplicationServiceApi): The API to use to send.
Returns:
A Deferred which resolves to True if the transaction was sent.
"""
return as_api.push_bulk(
service=self.service,
events=self.events,
txn_id=self.id
)
def complete(self, store):
"""Completes this transaction as successful.
Marks this transaction ID on the application service and removes the
transaction contents from the database.
Args:
store: The database store to operate on.
Returns:
A Deferred which resolves to True if the transaction was completed.
"""
return store.complete_appservice_txn(
service=self.service,
txn_id=self.id
)
class ApplicationService(object): class ApplicationService(object):
"""Defines an application service. This definition is mostly what is """Defines an application service. This definition is mostly what is
provided to the /register AS API. provided to the /register AS API.
@ -35,13 +79,13 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None, def __init__(self, token, url=None, namespaces=None, hs_token=None,
sender=None, txn_id=None): sender=None, id=None):
self.token = token self.token = token
self.url = url self.url = url
self.hs_token = hs_token self.hs_token = hs_token
self.sender = sender self.sender = sender
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.txn_id = txn_id self.id = id
def _check_namespaces(self, namespaces): def _check_namespaces(self, namespaces):
# Sanity check that it is of the form: # Sanity check that it is of the form:
@ -51,7 +95,7 @@ class ApplicationService(object):
# rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...], # rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# } # }
if not namespaces: if not namespaces:
return None namespaces = {}
for ns in ApplicationService.NS_LIST: for ns in ApplicationService.NS_LIST:
if ns not in namespaces: if ns not in namespaces:
@ -155,7 +199,10 @@ class ApplicationService(object):
return self._matches_user(event, member_list) return self._matches_user(event, member_list)
def is_interested_in_user(self, user_id): def is_interested_in_user(self, user_id):
return self._matches_regex(user_id, ApplicationService.NS_USERS) return (
self._matches_regex(user_id, ApplicationService.NS_USERS)
or user_id == self.sender
)
def is_interested_in_alias(self, alias): def is_interested_in_alias(self, alias):
return self._matches_regex(alias, ApplicationService.NS_ALIASES) return self._matches_regex(alias, ApplicationService.NS_ALIASES)
@ -164,7 +211,10 @@ class ApplicationService(object):
return self._matches_regex(room_id, ApplicationService.NS_ROOMS) return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
def is_exclusive_user(self, user_id): def is_exclusive_user(self, user_id):
return self._is_exclusive(ApplicationService.NS_USERS, user_id) return (
self._is_exclusive(ApplicationService.NS_USERS, user_id)
or user_id == self.sender
)
def is_exclusive_alias(self, alias): def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias) return self._is_exclusive(ApplicationService.NS_ALIASES, alias)

View File

@ -72,14 +72,19 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_bulk(self, service, events): def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events) events = self._serialize(events)
if txn_id is None:
logger.warning("push_bulk: Missing txn ID sending events to %s",
service.url)
txn_id = str(0)
txn_id = str(txn_id)
uri = service.url + ("/transactions/%s" % uri = service.url + ("/transactions/%s" %
urllib.quote(str(0))) # TODO txn_ids urllib.quote(txn_id))
response = None
try: try:
response = yield self.put_json( yield self.put_json(
uri=uri, uri=uri,
json_body={ json_body={
"events": events "events": events
@ -87,9 +92,8 @@ class ApplicationServiceApi(SimpleHttpClient):
args={ args={
"access_token": service.hs_token "access_token": service.hs_token
}) })
if response: # just an empty json object defer.returnValue(True)
# TODO: Mark txn as sent successfully return
defer.returnValue(True)
except CodeMessageException as e: except CodeMessageException as e:
logger.warning("push_bulk to %s received %s", uri, e.code) logger.warning("push_bulk to %s received %s", uri, e.code)
except Exception as ex: except Exception as ex:
@ -97,8 +101,8 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(False) defer.returnValue(False)
@defer.inlineCallbacks @defer.inlineCallbacks
def push(self, service, event): def push(self, service, event, txn_id=None):
response = yield self.push_bulk(service, [event]) response = yield self.push_bulk(service, [event], txn_id)
defer.returnValue(response) defer.returnValue(response)
def _serialize(self, events): def _serialize(self, events):

View File

@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module controls the reliability for application service transactions.
The nominal flow through this module looks like:
__________
1---ASa[e]-->| Service |--> Queue ASa[f]
2----ASb[e]->| Queuer |
3--ASa[f]--->|__________|-----------+ ASa[e], ASb[e]
V
-````````- +------------+
|````````|<--StoreTxn-|Transaction |
|Database| | Controller |---> SEND TO AS
`--------` +------------+
What happens on SEND TO AS depends on the state of the Application Service:
- If the AS is marked as DOWN, do nothing.
- If the AS is marked as UP, send the transaction.
* SUCCESS : Increment where the AS is up to txn-wise and nuke the txn
contents from the db.
* FAILURE : Marked AS as DOWN and start Recoverer.
Recoverer attempts to recover ASes who have died. The flow for this looks like:
,--------------------- backoff++ --------------.
V |
START ---> Wait exp ------> Get oldest txn ID from ----> FAILURE
backoff DB and try to send it
^ |___________
Mark AS as | V
UP & quit +---------- YES SUCCESS
| | |
NO <--- Have more txns? <------ Mark txn success & nuke <-+
from db; incr AS pos.
Reset backoff.
This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
from synapse.appservice import ApplicationServiceState
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class AppServiceScheduler(object):
""" Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this
case is a simple array.
"""
def __init__(self, clock, store, as_api):
self.clock = clock
self.store = store
self.as_api = as_api
def create_recoverer(service, callback):
return _Recoverer(clock, store, as_api, service, callback)
self.txn_ctrl = _TransactionController(
clock, store, as_api, create_recoverer
)
self.queuer = _ServiceQueuer(self.txn_ctrl)
@defer.inlineCallbacks
def start(self):
logger.info("Starting appservice scheduler")
# check for any DOWN ASes and start recoverers for them.
recoverers = yield _Recoverer.start(
self.clock, self.store, self.as_api, self.txn_ctrl.on_recovered
)
self.txn_ctrl.add_recoverers(recoverers)
def submit_event_for_as(self, service, event):
self.queuer.enqueue(service, event)
class _ServiceQueuer(object):
"""Queues events for the same application service together, sending
transactions as soon as possible. Once a transaction is sent successfully,
this schedules any other events in the queue to run.
"""
def __init__(self, txn_ctrl):
self.queued_events = {} # dict of {service_id: [events]}
self.pending_requests = {} # dict of {service_id: Deferred}
self.txn_ctrl = txn_ctrl
def enqueue(self, service, event):
# if this service isn't being sent something
if not self.pending_requests.get(service.id):
self._send_request(service, [event])
else:
# add to queue for this service
if service.id not in self.queued_events:
self.queued_events[service.id] = []
self.queued_events[service.id].append(event)
def _send_request(self, service, events):
# send request and add callbacks
d = self.txn_ctrl.send(service, events)
d.addBoth(self._on_request_finish)
d.addErrback(self._on_request_fail)
self.pending_requests[service.id] = d
def _on_request_finish(self, service):
self.pending_requests[service.id] = None
# if there are queued events, then send them.
if (service.id in self.queued_events
and len(self.queued_events[service.id]) > 0):
self._send_request(service, self.queued_events[service.id])
self.queued_events[service.id] = []
def _on_request_fail(self, err):
logger.error("AS request failed: %s", err)
class _TransactionController(object):
def __init__(self, clock, store, as_api, recoverer_fn):
self.clock = clock
self.store = store
self.as_api = as_api
self.recoverer_fn = recoverer_fn
# keep track of how many recoverers there are
self.recoverers = []
@defer.inlineCallbacks
def send(self, service, events):
try:
txn = yield self.store.create_appservice_txn(
service=service,
events=events
)
service_is_up = yield self._is_service_up(service)
if service_is_up:
sent = yield txn.send(self.as_api)
if sent:
txn.complete(self.store)
else:
self._start_recoverer(service)
except Exception as e:
logger.exception(e)
self._start_recoverer(service)
# request has finished
defer.returnValue(service)
@defer.inlineCallbacks
def on_recovered(self, recoverer):
self.recoverers.remove(recoverer)
logger.info("Successfully recovered application service AS ID %s",
recoverer.service.id)
logger.info("Remaining active recoverers: %s", len(self.recoverers))
yield self.store.set_appservice_state(
recoverer.service,
ApplicationServiceState.UP
)
def add_recoverers(self, recoverers):
for r in recoverers:
self.recoverers.append(r)
if len(recoverers) > 0:
logger.info("New active recoverers: %s", len(self.recoverers))
@defer.inlineCallbacks
def _start_recoverer(self, service):
yield self.store.set_appservice_state(
service,
ApplicationServiceState.DOWN
)
logger.info(
"Application service falling behind. Starting recoverer. AS ID %s",
service.id
)
recoverer = self.recoverer_fn(service, self.on_recovered)
self.add_recoverers([recoverer])
recoverer.recover()
@defer.inlineCallbacks
def _is_service_up(self, service):
state = yield self.store.get_appservice_state(service)
defer.returnValue(state == ApplicationServiceState.UP or state is None)
class _Recoverer(object):
@staticmethod
@defer.inlineCallbacks
def start(clock, store, as_api, callback):
services = yield store.get_appservices_by_state(
ApplicationServiceState.DOWN
)
recoverers = [
_Recoverer(clock, store, as_api, s, callback) for s in services
]
for r in recoverers:
logger.info("Starting recoverer for AS ID %s which was marked as "
"DOWN", r.service.id)
r.recover()
defer.returnValue(recoverers)
def __init__(self, clock, store, as_api, service, callback):
self.clock = clock
self.store = store
self.as_api = as_api
self.service = service
self.callback = callback
self.backoff_counter = 1
def recover(self):
self.clock.call_later((2 ** self.backoff_counter), self.retry)
def _backoff(self):
# cap the backoff to be around 18h => (2^16) = 65536 secs
if self.backoff_counter < 16:
self.backoff_counter += 1
self.recover()
@defer.inlineCallbacks
def retry(self):
try:
txn = yield self.store.get_oldest_unsent_txn(self.service)
if txn:
logger.info("Retrying transaction %s for AS ID %s",
txn.id, txn.service.id)
sent = yield txn.send(self.as_api)
if sent:
yield txn.complete(self.store)
# reset the backoff counter and retry immediately
self.backoff_counter = 1
yield self.retry()
else:
self._backoff()
else:
self._set_service_recovered()
except Exception as e:
logger.exception(e)
self._backoff()
def _set_service_recovered(self):
self.callback(self)

View File

@ -14,9 +14,10 @@
# limitations under the License. # limitations under the License.
import argparse import argparse
import sys
import os import os
import yaml import yaml
import sys
from textwrap import dedent
class ConfigError(Exception): class ConfigError(Exception):
@ -24,18 +25,35 @@ class ConfigError(Exception):
class Config(object): class Config(object):
def __init__(self, args):
pass
@staticmethod @staticmethod
def parse_size(string): def parse_size(value):
if isinstance(value, int) or isinstance(value, long):
return value
sizes = {"K": 1024, "M": 1024 * 1024} sizes = {"K": 1024, "M": 1024 * 1024}
size = 1 size = 1
suffix = string[-1] suffix = value[-1]
if suffix in sizes: if suffix in sizes:
string = string[:-1] value = value[:-1]
size = sizes[suffix] size = sizes[suffix]
return int(string) * size return int(value) * size
@staticmethod
def parse_duration(value):
if isinstance(value, int) or isinstance(value, long):
return value
second = 1000
hour = 60 * 60 * second
day = 24 * hour
week = 7 * day
year = 365 * day
sizes = {"s": second, "h": hour, "d": day, "w": week, "y": year}
size = 1
suffix = value[-1]
if suffix in sizes:
value = value[:-1]
size = sizes[suffix]
return int(value) * size
@staticmethod @staticmethod
def abspath(file_path): def abspath(file_path):
@ -86,83 +104,130 @@ class Config(object):
with open(file_path) as file_stream: with open(file_path) as file_stream:
return yaml.load(file_stream) return yaml.load(file_stream)
@classmethod def invoke_all(self, name, *args, **kargs):
def add_arguments(cls, parser): results = []
pass for cls in type(self).mro():
if name in cls.__dict__:
results.append(getattr(cls, name)(self, *args, **kargs))
return results
@classmethod def generate_config(self, config_dir_path, server_name):
def generate_config(cls, args, config_dir_path): default_config = "# vim:ft=yaml\n"
pass
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config", config_dir_path, server_name
))
config = yaml.load(default_config)
return default_config, config
@classmethod @classmethod
def load_config(cls, description, argv, generate_section=None): def load_config(cls, description, argv, generate_section=None):
obj = cls()
config_parser = argparse.ArgumentParser(add_help=False) config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument( config_parser.add_argument(
"-c", "--config-path", "-c", "--config-path",
action="append",
metavar="CONFIG_FILE", metavar="CONFIG_FILE",
help="Specify config file" help="Specify config file"
) )
config_parser.add_argument( config_parser.add_argument(
"--generate-config", "--generate-config",
action="store_true", action="store_true",
help="Generate config file" help="Generate a config file for the server name"
)
config_parser.add_argument(
"-H", "--server-name",
help="The server name to generate a config file for"
) )
config_args, remaining_args = config_parser.parse_known_args(argv) config_args, remaining_args = config_parser.parse_known_args(argv)
if config_args.generate_config: if config_args.generate_config:
if not config_args.config_path: if not config_args.config_path:
config_parser.error( config_parser.error(
"Must specify where to generate the config file" "Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -h SERVER_NAME"
" -c CONFIG-FILE\""
) )
config_dir_path = os.path.dirname(config_args.config_path)
if os.path.exists(config_args.config_path):
defaults = cls.read_config_file(config_args.config_path)
else:
defaults = {}
else:
if config_args.config_path:
defaults = cls.read_config_file(config_args.config_path)
else:
defaults = {}
parser = argparse.ArgumentParser( config_dir_path = os.path.dirname(config_args.config_path[0])
parents=[config_parser],
description=description,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
cls.add_arguments(parser)
parser.set_defaults(**defaults)
args = parser.parse_args(remaining_args)
if config_args.generate_config:
config_dir_path = os.path.dirname(config_args.config_path)
config_dir_path = os.path.abspath(config_dir_path) config_dir_path = os.path.abspath(config_dir_path)
server_name = config_args.server_name
if not server_name:
print "Most specify a server_name to a generate config for."
sys.exit(1)
(config_path,) = config_args.config_path
if not os.path.exists(config_dir_path): if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path) os.makedirs(config_dir_path)
cls.generate_config(args, config_dir_path) if os.path.exists(config_path):
config = {} print "Config file %r already exists" % (config_path,)
for key, value in vars(args).items(): yaml_config = cls.read_config_file(config_path)
if (key not in set(["config_path", "generate_config"]) yaml_name = yaml_config["server_name"]
and value is not None): if server_name != yaml_name:
config[key] = value print (
with open(config_args.config_path, "w") as config_file: "Config file %r has a different server_name: "
# TODO(paul) it would be lovely if we wrote out vim- and emacs- " %r != %r" % (config_path, server_name, yaml_name)
# style mode markers into the file, to hint to people that )
# this is a YAML file. sys.exit(1)
yaml.dump(config, config_file, default_flow_style=False) config_bytes, config = obj.generate_config(
print ( config_dir_path, server_name
"A config file has been generated in %s for server name" )
" '%s' with corresponding SSL keys and self-signed" config.update(yaml_config)
" certificates. Please review this file and customise it to" print "Generating any missing keys for %r" % (server_name,)
" your needs." obj.invoke_all("generate_files", config)
) % ( sys.exit(0)
config_args.config_path, config['server_name'] with open(config_path, "wb") as config_file:
) config_bytes, config = obj.generate_config(
config_dir_path, server_name
)
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
print (
"A config file has been generated in %s for server name"
" '%s' with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it to"
" your needs."
) % (config_path, server_name)
print ( print (
"If this server name is incorrect, you will need to regenerate" "If this server name is incorrect, you will need to regenerate"
" the SSL certificates" " the SSL certificates"
) )
sys.exit(0) sys.exit(0)
return cls(args) parser = argparse.ArgumentParser(
parents=[config_parser],
description=description,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
obj.invoke_all("add_arguments", parser)
args = parser.parse_args(remaining_args)
if not config_args.config_path:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -h SERVER_NAME"
" -c CONFIG-FILE\""
)
config_dir_path = os.path.dirname(config_args.config_path[0])
config_dir_path = os.path.abspath(config_dir_path)
specified_config = {}
for config_path in config_args.config_path:
yaml_config = cls.read_config_file(config_path)
specified_config.update(yaml_config)
server_name = specified_config["server_name"]
_, config = obj.generate_config(config_dir_path, server_name)
config.pop("log_config")
config.update(specified_config)
obj.invoke_all("read_config", config)
obj.invoke_all("read_arguments", args)
return obj

View File

@ -0,0 +1,27 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class AppServiceConfig(Config):
def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", [])
def default_config(cls, config_dir_path, server_name):
return """\
# A list of application service config file to use
app_service_config_files: []
"""

View File

@ -17,35 +17,35 @@ from ._base import Config
class CaptchaConfig(Config): class CaptchaConfig(Config):
def __init__(self, args): def read_config(self, config):
super(CaptchaConfig, self).__init__(args) self.recaptcha_private_key = config["recaptcha_private_key"]
self.recaptcha_private_key = args.recaptcha_private_key self.recaptcha_public_key = config["recaptcha_public_key"]
self.enable_registration_captcha = args.enable_registration_captcha self.enable_registration_captcha = config["enable_registration_captcha"]
# XXX: This is used for more than just captcha
self.captcha_ip_origin_is_x_forwarded = ( self.captcha_ip_origin_is_x_forwarded = (
args.captcha_ip_origin_is_x_forwarded config["captcha_ip_origin_is_x_forwarded"]
) )
self.captcha_bypass_secret = args.captcha_bypass_secret self.captcha_bypass_secret = config.get("captcha_bypass_secret")
@classmethod def default_config(self, config_dir_path, server_name):
def add_arguments(cls, parser): return """\
super(CaptchaConfig, cls).add_arguments(parser) ## Captcha ##
group = parser.add_argument_group("recaptcha")
group.add_argument( # This Home Server's ReCAPTCHA public key.
"--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY", recaptcha_private_key: "YOUR_PUBLIC_KEY"
help="The matching private key for the web client's public key."
) # This Home Server's ReCAPTCHA private key.
group.add_argument( recaptcha_public_key: "YOUR_PRIVATE_KEY"
"--enable-registration-captcha", type=bool, default=False,
help="Enables ReCaptcha checks when registering, preventing signup" # Enables ReCaptcha checks when registering, preventing signup
+ " unless a captcha is answered. Requires a valid ReCaptcha " # unless a captcha is answered. Requires a valid ReCaptcha
+ "public/private key." # public/private key.
) enable_registration_captcha: False
group.add_argument(
"--captcha_ip_origin_is_x_forwarded", type=bool, default=False, # When checking captchas, use the X-Forwarded-For (XFF) header
help="When checking captchas, use the X-Forwarded-For (XFF) header" # as the client IP and not the actual client IP.
+ " as the client IP and not the actual client IP." captcha_ip_origin_is_x_forwarded: False
)
group.add_argument( # A secret key used to bypass the captcha test entirely.
"--captcha_bypass_secret", type=str, #captcha_bypass_secret: "YOUR_SECRET_HERE"
help="A secret key used to bypass the captcha test entirely." """
)

View File

@ -14,32 +14,66 @@
# limitations under the License. # limitations under the License.
from ._base import Config from ._base import Config
import os
class DatabaseConfig(Config): class DatabaseConfig(Config):
def __init__(self, args):
super(DatabaseConfig, self).__init__(args)
if args.database_path == ":memory:":
self.database_path = ":memory:"
else:
self.database_path = self.abspath(args.database_path)
self.event_cache_size = self.parse_size(args.event_cache_size)
@classmethod def read_config(self, config):
def add_arguments(cls, parser): self.event_cache_size = self.parse_size(
super(DatabaseConfig, cls).add_arguments(parser) config.get("event_cache_size", "10K")
)
self.database_config = config.get("database")
if self.database_config is None:
self.database_config = {
"name": "sqlite3",
"args": {},
}
name = self.database_config.get("name", None)
if name == "psycopg2":
pass
elif name == "sqlite3":
self.database_config.setdefault("args", {}).update({
"cp_min": 1,
"cp_max": 1,
"check_same_thread": False,
})
else:
raise RuntimeError("Unsupported database type '%s'" % (name,))
self.set_databasepath(config.get("database_path"))
def default_config(self, config, config_dir_path):
database_path = self.abspath("homeserver.db")
return """\
# Database configuration
database:
# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
args:
# Path to the database
database: "%(database_path)s"
# Number of events to cache in memory.
event_cache_size: "10K"
""" % locals()
def read_arguments(self, args):
self.set_databasepath(args.database_path)
def set_databasepath(self, database_path):
if database_path != ":memory:":
database_path = self.abspath(database_path)
if self.database_config.get("name", None) == "sqlite3":
if database_path is not None:
self.database_config["args"]["database"] = database_path
def add_arguments(self, parser):
db_group = parser.add_argument_group("database") db_group = parser.add_argument_group("database")
db_group.add_argument( db_group.add_argument(
"-d", "--database-path", default="homeserver.db", "-d", "--database-path", metavar="SQLITE_DATABASE_PATH",
help="The database name." help="The path to a sqlite database to use."
) )
db_group.add_argument(
"--event-cache-size", default="100K",
help="Number of events to cache in memory."
)
@classmethod
def generate_config(cls, args, config_dir_path):
super(DatabaseConfig, cls).generate_config(args, config_dir_path)
args.database_path = os.path.abspath(args.database_path)

View File

@ -1,42 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class EmailConfig(Config):
def __init__(self, args):
super(EmailConfig, self).__init__(args)
self.email_from_address = args.email_from_address
self.email_smtp_server = args.email_smtp_server
@classmethod
def add_arguments(cls, parser):
super(EmailConfig, cls).add_arguments(parser)
email_group = parser.add_argument_group("email")
email_group.add_argument(
"--email-from-address",
default="FROM@EXAMPLE.COM",
help="The address to send emails from (e.g. for password resets)."
)
email_group.add_argument(
"--email-smtp-server",
default="",
help=(
"The SMTP server to send emails from (e.g. for password"
" resets)."
)
)

View File

@ -20,19 +20,22 @@ from .database import DatabaseConfig
from .ratelimiting import RatelimitConfig from .ratelimiting import RatelimitConfig
from .repository import ContentRepositoryConfig from .repository import ContentRepositoryConfig
from .captcha import CaptchaConfig from .captcha import CaptchaConfig
from .email import EmailConfig
from .voip import VoipConfig from .voip import VoipConfig
from .registration import RegistrationConfig from .registration import RegistrationConfig
from .metrics import MetricsConfig from .metrics import MetricsConfig
from .appservice import AppServiceConfig
from .key import KeyConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
EmailConfig, VoipConfig, RegistrationConfig, VoipConfig, RegistrationConfig,
MetricsConfig,): MetricsConfig, AppServiceConfig, KeyConfig,):
pass pass
if __name__ == '__main__': if __name__ == '__main__':
import sys import sys
HomeServerConfig.load_config("Generate config", sys.argv[1:], "HomeServer") sys.stdout.write(
HomeServerConfig().generate_config(sys.argv[1], sys.argv[2])[0]
)

133
synapse/config/key.py Normal file
View File

@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from ._base import Config, ConfigError
import syutil.crypto.signing_key
from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes
)
from syutil.base64util import decode_base64
from synapse.util.stringutils import random_string
class KeyConfig(Config):
def read_config(self, config):
self.signing_key = self.read_signing_key(config["signing_key_path"])
self.old_signing_keys = self.read_old_signing_keys(
config["old_signing_keys"]
)
self.key_refresh_interval = self.parse_duration(
config["key_refresh_interval"]
)
self.perspectives = self.read_perspectives(
config["perspectives"]
)
def default_config(self, config_dir_path, server_name):
base_key_name = os.path.join(config_dir_path, server_name)
return """\
## Signing Keys ##
# Path to the signing key to sign messages with
signing_key_path: "%(base_key_name)s.signing.key"
# The keys that the server used to sign messages with but won't use
# to sign new messages. E.g. it has lost its private key
old_signing_keys: {}
# "ed25519:auto":
# # Base64 encoded public key
# key: "The public part of your old signing key."
# # Millisecond POSIX timestamp when the key expired.
# expired_ts: 123456789123
# How long key response published by this server is valid for.
# Used to set the valid_until_ts in /key/v2 APIs.
# Determines how quickly servers will query to check which keys
# are still valid.
key_refresh_interval: "1d" # 1 Day.
# The trusted servers to download signing keys from.
perspectives:
servers:
"matrix.org":
verify_keys:
"ed25519:auto":
key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
""" % locals()
def read_perspectives(self, perspectives_config):
servers = {}
for server_name, server_config in perspectives_config["servers"].items():
for key_id, key_data in server_config["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
servers.setdefault(server_name, {})[key_id] = verify_key
return servers
def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key")
try:
return syutil.crypto.signing_key.read_signing_keys(
signing_keys.splitlines(True)
)
except Exception:
raise ConfigError(
"Error reading signing_key."
" Try running again with --generate-config"
)
def read_old_signing_keys(self, old_signing_keys):
keys = {}
for key_id, key_data in old_signing_keys.items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.expired_ts = key_data["expired_ts"]
keys[key_id] = verify_key
else:
raise ConfigError(
"Unsupported signing algorithm for old key: %r" % (key_id,)
)
return keys
def generate_files(self, config):
signing_key_path = config["signing_key_path"]
if not os.path.exists(signing_key_path):
with open(signing_key_path, "w") as signing_key_file:
key_id = "a_" + random_string(4)
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(syutil.crypto.signing_key.generate_signing_key(key_id),),
)
else:
signing_keys = self.read_file(signing_key_path, "signing_key")
if len(signing_keys.split("\n")[0].split()) == 1:
# handle keys in the old format.
key_id = "a_" + random_string(4)
key = syutil.crypto.signing_key.decode_signing_key_base64(
syutil.crypto.signing_key.NACL_ED25519,
key_id,
signing_keys.split("\n")[0]
)
with open(signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(key,),
)

View File

@ -19,25 +19,88 @@ from twisted.python.log import PythonLoggingObserver
import logging import logging
import logging.config import logging.config
import yaml import yaml
from string import Template
import os
DEFAULT_LOG_CONFIG = Template("""
version: 1
formatters:
precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
- %(message)s'
filters:
context:
(): synapse.util.logcontext.LoggingContextFilter
request: ""
handlers:
file:
class: logging.handlers.RotatingFileHandler
formatter: precise
filename: ${log_file}
maxBytes: 104857600
backupCount: 10
filters: [context]
level: INFO
console:
class: logging.StreamHandler
formatter: precise
loggers:
synapse:
level: INFO
synapse.storage.SQL:
level: INFO
root:
level: INFO
handlers: [file, console]
""")
class LoggingConfig(Config): class LoggingConfig(Config):
def __init__(self, args):
super(LoggingConfig, self).__init__(args)
self.verbosity = int(args.verbose) if args.verbose else None
self.log_config = self.abspath(args.log_config)
self.log_file = self.abspath(args.log_file)
@classmethod def read_config(self, config):
self.verbosity = config.get("verbose", 0)
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name):
log_file = self.abspath("homeserver.log")
log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config")
)
return """
# Logging verbosity level.
verbose: 0
# File to write logging to
log_file: "%(log_file)s"
# A yaml python logging config file
log_config: "%(log_config)s"
""" % locals()
def read_arguments(self, args):
if args.verbose is not None:
self.verbosity = args.verbose
if args.log_config is not None:
self.log_config = args.log_config
if args.log_file is not None:
self.log_file = args.log_file
def add_arguments(cls, parser): def add_arguments(cls, parser):
super(LoggingConfig, cls).add_arguments(parser)
logging_group = parser.add_argument_group("logging") logging_group = parser.add_argument_group("logging")
logging_group.add_argument( logging_group.add_argument(
'-v', '--verbose', dest="verbose", action='count', '-v', '--verbose', dest="verbose", action='count',
help="The verbosity level." help="The verbosity level."
) )
logging_group.add_argument( logging_group.add_argument(
'-f', '--log-file', dest="log_file", default="homeserver.log", '-f', '--log-file', dest="log_file",
help="File to log to." help="File to log to."
) )
logging_group.add_argument( logging_group.add_argument(
@ -45,6 +108,14 @@ class LoggingConfig(Config):
help="Python logging config file" help="Python logging config file"
) )
def generate_files(self, config):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
with open(log_config, "wb") as log_config_file:
log_config_file.write(
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
)
def setup_logging(self): def setup_logging(self):
log_format = ( log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s" "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
@ -78,7 +149,6 @@ class LoggingConfig(Config):
handler.addFilter(LoggingContextFilter(request="")) handler.addFilter(LoggingContextFilter(request=""))
logger.addHandler(handler) logger.addHandler(handler)
logger.info("Test")
else: else:
with open(self.log_config, 'r') as f: with open(self.log_config, 'r') as f:
logging.config.dictConfig(yaml.load(f)) logging.config.dictConfig(yaml.load(f))

View File

@ -17,20 +17,17 @@ from ._base import Config
class MetricsConfig(Config): class MetricsConfig(Config):
def __init__(self, args): def read_config(self, config):
super(MetricsConfig, self).__init__(args) self.enable_metrics = config["enable_metrics"]
self.enable_metrics = args.enable_metrics self.metrics_port = config.get("metrics_port")
self.metrics_port = args.metrics_port
@classmethod def default_config(self, config_dir_path, server_name):
def add_arguments(cls, parser): return """\
super(MetricsConfig, cls).add_arguments(parser) ## Metrics ###
metrics_group = parser.add_argument_group("metrics")
metrics_group.add_argument( # Enable collection and rendering of performance metrics
'--enable-metrics', dest="enable_metrics", action="store_true", enable_metrics: False
help="Enable collection and rendering of performance metrics"
) # Separate port to accept metrics requests on (on localhost)
metrics_group.add_argument( # metrics_port: 8081
'--metrics-port', metavar="PORT", type=int, """
help="Separate port to accept metrics requests on (on localhost)"
)

View File

@ -17,56 +17,42 @@ from ._base import Config
class RatelimitConfig(Config): class RatelimitConfig(Config):
def __init__(self, args): def read_config(self, config):
super(RatelimitConfig, self).__init__(args) self.rc_messages_per_second = config["rc_messages_per_second"]
self.rc_messages_per_second = args.rc_messages_per_second self.rc_message_burst_count = config["rc_message_burst_count"]
self.rc_message_burst_count = args.rc_message_burst_count
self.federation_rc_window_size = args.federation_rc_window_size self.federation_rc_window_size = config["federation_rc_window_size"]
self.federation_rc_sleep_limit = args.federation_rc_sleep_limit self.federation_rc_sleep_limit = config["federation_rc_sleep_limit"]
self.federation_rc_sleep_delay = args.federation_rc_sleep_delay self.federation_rc_sleep_delay = config["federation_rc_sleep_delay"]
self.federation_rc_reject_limit = args.federation_rc_reject_limit self.federation_rc_reject_limit = config["federation_rc_reject_limit"]
self.federation_rc_concurrent = args.federation_rc_concurrent self.federation_rc_concurrent = config["federation_rc_concurrent"]
@classmethod def default_config(self, config_dir_path, server_name):
def add_arguments(cls, parser): return """\
super(RatelimitConfig, cls).add_arguments(parser) ## Ratelimiting ##
rc_group = parser.add_argument_group("ratelimiting")
rc_group.add_argument(
"--rc-messages-per-second", type=float, default=0.2,
help="number of messages a client can send per second"
)
rc_group.add_argument(
"--rc-message-burst-count", type=float, default=10,
help="number of message a client can send before being throttled"
)
rc_group.add_argument( # Number of messages a client can send per second
"--federation-rc-window-size", type=int, default=10000, rc_messages_per_second: 0.2
help="The federation window size in milliseconds",
)
rc_group.add_argument( # Number of message a client can send before being throttled
"--federation-rc-sleep-limit", type=int, default=10, rc_message_burst_count: 10.0
help="The number of federation requests from a single server"
" in a window before the server will delay processing the"
" request.",
)
rc_group.add_argument( # The federation window size in milliseconds
"--federation-rc-sleep-delay", type=int, default=500, federation_rc_window_size: 1000
help="The duration in milliseconds to delay processing events from"
" remote servers by if they go over the sleep limit.",
)
rc_group.add_argument( # The number of federation requests from a single server in a window
"--federation-rc-reject-limit", type=int, default=50, # before the server will delay processing the request.
help="The maximum number of concurrent federation requests allowed" federation_rc_sleep_limit: 10
" from a single server",
)
rc_group.add_argument( # The duration in milliseconds to delay processing events from
"--federation-rc-concurrent", type=int, default=3, # remote servers by if they go over the sleep limit.
help="The number of federation requests to concurrently process" federation_rc_sleep_delay: 500
" from a single server",
) # The maximum number of concurrent federation requests allowed
# from a single server
federation_rc_reject_limit: 50
# The number of federation requests to concurrently process from a
# single server
federation_rc_concurrent: 3
"""

View File

@ -17,44 +17,44 @@ from ._base import Config
from synapse.util.stringutils import random_string_with_symbols from synapse.util.stringutils import random_string_with_symbols
import distutils.util from distutils.util import strtobool
class RegistrationConfig(Config): class RegistrationConfig(Config):
def __init__(self, args): def read_config(self, config):
super(RegistrationConfig, self).__init__(args) self.disable_registration = not bool(
strtobool(str(config["enable_registration"]))
# `args.disable_registration` may either be a bool or a string depending
# on if the option was given a value (e.g. --disable-registration=false
# would set `args.disable_registration` to "false" not False.)
self.disable_registration = bool(
distutils.util.strtobool(str(args.disable_registration))
) )
self.registration_shared_secret = args.registration_shared_secret if "disable_registration" in config:
self.disable_registration = bool(
strtobool(str(config["disable_registration"]))
)
@classmethod self.registration_shared_secret = config.get("registration_shared_secret")
def add_arguments(cls, parser):
super(RegistrationConfig, cls).add_arguments(parser) def default_config(self, config_dir, server_name):
registration_shared_secret = random_string_with_symbols(50)
return """\
## Registration ##
# Enable registration for new users.
enable_registration: True
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
""" % locals()
def add_arguments(self, parser):
reg_group = parser.add_argument_group("registration") reg_group = parser.add_argument_group("registration")
reg_group.add_argument( reg_group.add_argument(
"--disable-registration", "--enable-registration", action="store_true", default=None,
const=True, help="Enable registration for new users."
default=True,
nargs='?',
help="Disable registration of new users.",
)
reg_group.add_argument(
"--registration-shared-secret", type=str,
help="If set, allows registration by anyone who also has the shared"
" secret, even if registration is otherwise disabled.",
) )
@classmethod def read_arguments(self, args):
def generate_config(cls, args, config_dir_path): if args.enable_registration is not None:
if args.disable_registration is None: self.disable_registration = not bool(
args.disable_registration = True strtobool(str(args.enable_registration))
)
if args.registration_shared_secret is None:
args.registration_shared_secret = random_string_with_symbols(50)

View File

@ -17,32 +17,20 @@ from ._base import Config
class ContentRepositoryConfig(Config): class ContentRepositoryConfig(Config):
def __init__(self, args): def read_config(self, config):
super(ContentRepositoryConfig, self).__init__(args) self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_upload_size = self.parse_size(args.max_upload_size) self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.max_image_pixels = self.parse_size(args.max_image_pixels) self.media_store_path = self.ensure_directory(config["media_store_path"])
self.media_store_path = self.ensure_directory(args.media_store_path)
def parse_size(self, string): def default_config(self, config_dir_path, server_name):
sizes = {"K": 1024, "M": 1024 * 1024} media_store = self.default_path("media_store")
size = 1 return """
suffix = string[-1] # Directory where uploaded images and attachments are stored.
if suffix in sizes: media_store_path: "%(media_store)s"
string = string[:-1]
size = sizes[suffix]
return int(string) * size
@classmethod # The largest allowed upload size in bytes
def add_arguments(cls, parser): max_upload_size: "10M"
super(ContentRepositoryConfig, cls).add_arguments(parser)
db_group = parser.add_argument_group("content_repository") # Maximum number of pixels that will be thumbnailed
db_group.add_argument( max_image_pixels: "32M"
"--max-upload-size", default="10M" """ % locals()
)
db_group.add_argument(
"--media-store-path", default=cls.default_path("media_store")
)
db_group.add_argument(
"--max-image-pixels", default="32M",
help="Maximum number of pixels that will be thumbnailed"
)

View File

@ -13,116 +13,92 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os from ._base import Config
from ._base import Config, ConfigError
import syutil.crypto.signing_key
class ServerConfig(Config): class ServerConfig(Config):
def __init__(self, args):
super(ServerConfig, self).__init__(args)
self.server_name = args.server_name
self.signing_key = self.read_signing_key(args.signing_key_path)
self.bind_port = args.bind_port
self.bind_host = args.bind_host
self.unsecure_port = args.unsecure_port
self.daemonize = args.daemonize
self.pid_file = self.abspath(args.pid_file)
self.web_client = args.web_client
self.manhole = args.manhole
self.soft_file_limit = args.soft_file_limit
if not args.content_addr: def read_config(self, config):
host = args.server_name self.server_name = config["server_name"]
self.bind_port = config["bind_port"]
self.bind_host = config["bind_host"]
self.unsecure_port = config["unsecure_port"]
self.manhole = config.get("manhole")
self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"]
self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize")
# Attempt to guess the content_addr for the v0 content repostitory
content_addr = config.get("content_addr")
if not content_addr:
host = self.server_name
if ':' not in host: if ':' not in host:
host = "%s:%d" % (host, args.unsecure_port) host = "%s:%d" % (host, self.unsecure_port)
else: else:
host = host.split(':')[0] host = host.split(':')[0]
host = "%s:%d" % (host, args.unsecure_port) host = "%s:%d" % (host, self.unsecure_port)
args.content_addr = "http://%s" % (host,) content_addr = "http://%s" % (host,)
self.content_addr = args.content_addr self.content_addr = content_addr
@classmethod def default_config(self, config_dir_path, server_name):
def add_arguments(cls, parser): if ":" in server_name:
super(ServerConfig, cls).add_arguments(parser) bind_port = int(server_name.split(":")[1])
unsecure_port = bind_port - 400
else:
bind_port = 8448
unsecure_port = 8008
pid_file = self.abspath("homeserver.pid")
return """\
## Server ##
# The domain name of the server, with optional explicit port.
# This is used by remote servers to connect to this server,
# e.g. matrix.org, localhost:8080, etc.
server_name: "%(server_name)s"
# The port to listen for HTTPS requests on.
# For when matrix traffic is sent directly to synapse.
bind_port: %(bind_port)s
# The port to listen for HTTP requests on.
# For when matrix traffic passes through loadbalancer that unwraps TLS.
unsecure_port: %(unsecure_port)s
# Local interface to listen on.
# The empty string will cause synapse to listen on all interfaces.
bind_host: ""
# When running as a daemon, the file to store the pid in
pid_file: %(pid_file)s
# Whether to serve a web client from the HTTP/HTTPS root resource.
web_client: True
# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the
# hard limit.
soft_file_limit: 0
# Turn on the twisted telnet manhole service on localhost on the given
# port.
#manhole: 9000
""" % locals()
def read_arguments(self, args):
if args.manhole is not None:
self.manhole = args.manhole
if args.daemonize is not None:
self.daemonize = args.daemonize
def add_arguments(self, parser):
server_group = parser.add_argument_group("server") server_group = parser.add_argument_group("server")
server_group.add_argument(
"-H", "--server-name", default="localhost",
help="The domain name of the server, with optional explicit port. "
"This is used by remote servers to connect to this server, "
"e.g. matrix.org, localhost:8080, etc."
)
server_group.add_argument("--signing-key-path",
help="The signing key to sign messages with")
server_group.add_argument("-p", "--bind-port", metavar="PORT",
type=int, help="https port to listen on",
default=8448)
server_group.add_argument("--unsecure-port", metavar="PORT",
type=int, help="http port to listen on",
default=8008)
server_group.add_argument("--bind-host", default="",
help="Local interface to listen on")
server_group.add_argument("-D", "--daemonize", action='store_true', server_group.add_argument("-D", "--daemonize", action='store_true',
default=None,
help="Daemonize the home server") help="Daemonize the home server")
server_group.add_argument('--pid-file', default="homeserver.pid",
help="When running as a daemon, the file to"
" store the pid in")
server_group.add_argument('--web_client', default=True, type=bool,
help="Whether or not to serve a web client")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole", server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int, type=int,
help="Turn on the twisted telnet manhole" help="Turn on the twisted telnet manhole"
" service on the given port.") " service on the given port.")
server_group.add_argument("--content-addr", default=None,
help="The host and scheme to use for the "
"content repository")
server_group.add_argument("--soft-file-limit", type=int, default=0,
help="Set the soft limit on the number of "
"file descriptors synapse can use. "
"Zero is used to indicate synapse "
"should set the soft limit to the hard"
"limit.")
def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key")
try:
return syutil.crypto.signing_key.read_signing_keys(
signing_keys.splitlines(True)
)
except Exception:
raise ConfigError(
"Error reading signing_key."
" Try running again with --generate-config"
)
@classmethod
def generate_config(cls, args, config_dir_path):
super(ServerConfig, cls).generate_config(args, config_dir_path)
base_key_name = os.path.join(config_dir_path, args.server_name)
args.pid_file = os.path.abspath(args.pid_file)
if not args.signing_key_path:
args.signing_key_path = base_key_name + ".signing.key"
if not os.path.exists(args.signing_key_path):
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(syutil.crypto.signing_key.generate_singing_key("auto"),),
)
else:
signing_keys = cls.read_file(args.signing_key_path, "signing_key")
if len(signing_keys.split("\n")[0].split()) == 1:
# handle keys in the old format.
key = syutil.crypto.signing_key.decode_signing_key_base64(
syutil.crypto.signing_key.NACL_ED25519,
"auto",
signing_keys.split("\n")[0]
)
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(key,),
)

View File

@ -23,37 +23,44 @@ GENERATE_DH_PARAMS = False
class TlsConfig(Config): class TlsConfig(Config):
def __init__(self, args): def read_config(self, config):
super(TlsConfig, self).__init__(args)
self.tls_certificate = self.read_tls_certificate( self.tls_certificate = self.read_tls_certificate(
args.tls_certificate_path config.get("tls_certificate_path")
) )
self.no_tls = args.no_tls self.no_tls = config.get("no_tls", False)
if self.no_tls: if self.no_tls:
self.tls_private_key = None self.tls_private_key = None
else: else:
self.tls_private_key = self.read_tls_private_key( self.tls_private_key = self.read_tls_private_key(
args.tls_private_key_path config.get("tls_private_key_path")
) )
self.tls_dh_params_path = self.check_file( self.tls_dh_params_path = self.check_file(
args.tls_dh_params_path, "tls_dh_params" config.get("tls_dh_params_path"), "tls_dh_params"
) )
@classmethod def default_config(self, config_dir_path, server_name):
def add_arguments(cls, parser): base_key_name = os.path.join(config_dir_path, server_name)
super(TlsConfig, cls).add_arguments(parser)
tls_group = parser.add_argument_group("tls") tls_certificate_path = base_key_name + ".tls.crt"
tls_group.add_argument("--tls-certificate-path", tls_private_key_path = base_key_name + ".tls.key"
help="PEM encoded X509 certificate for TLS") tls_dh_params_path = base_key_name + ".tls.dh"
tls_group.add_argument("--tls-private-key-path",
help="PEM encoded private key for TLS") return """\
tls_group.add_argument("--tls-dh-params-path", # PEM encoded X509 certificate for TLS
help="PEM dh parameters for ephemeral keys") tls_certificate_path: "%(tls_certificate_path)s"
tls_group.add_argument("--no-tls", action='store_true',
help="Don't bind to the https port.") # PEM encoded private key for TLS
tls_private_key_path: "%(tls_private_key_path)s"
# PEM dh parameters for ephemeral keys
tls_dh_params_path: "%(tls_dh_params_path)s"
# Don't bind to the https port
no_tls: False
""" % locals()
def read_tls_certificate(self, cert_path): def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate") cert_pem = self.read_file(cert_path, "tls_certificate")
@ -63,22 +70,13 @@ class TlsConfig(Config):
private_key_pem = self.read_file(private_key_path, "tls_private_key") private_key_pem = self.read_file(private_key_path, "tls_private_key")
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem) return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
@classmethod def generate_files(self, config):
def generate_config(cls, args, config_dir_path): tls_certificate_path = config["tls_certificate_path"]
super(TlsConfig, cls).generate_config(args, config_dir_path) tls_private_key_path = config["tls_private_key_path"]
base_key_name = os.path.join(config_dir_path, args.server_name) tls_dh_params_path = config["tls_dh_params_path"]
if args.tls_certificate_path is None: if not os.path.exists(tls_private_key_path):
args.tls_certificate_path = base_key_name + ".tls.crt" with open(tls_private_key_path, "w") as private_key_file:
if args.tls_private_key_path is None:
args.tls_private_key_path = base_key_name + ".tls.key"
if args.tls_dh_params_path is None:
args.tls_dh_params_path = base_key_name + ".tls.dh"
if not os.path.exists(args.tls_private_key_path):
with open(args.tls_private_key_path, "w") as private_key_file:
tls_private_key = crypto.PKey() tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048) tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
private_key_pem = crypto.dump_privatekey( private_key_pem = crypto.dump_privatekey(
@ -86,17 +84,17 @@ class TlsConfig(Config):
) )
private_key_file.write(private_key_pem) private_key_file.write(private_key_pem)
else: else:
with open(args.tls_private_key_path) as private_key_file: with open(tls_private_key_path) as private_key_file:
private_key_pem = private_key_file.read() private_key_pem = private_key_file.read()
tls_private_key = crypto.load_privatekey( tls_private_key = crypto.load_privatekey(
crypto.FILETYPE_PEM, private_key_pem crypto.FILETYPE_PEM, private_key_pem
) )
if not os.path.exists(args.tls_certificate_path): if not os.path.exists(tls_certificate_path):
with open(args.tls_certificate_path, "w") as certifcate_file: with open(tls_certificate_path, "w") as certifcate_file:
cert = crypto.X509() cert = crypto.X509()
subject = cert.get_subject() subject = cert.get_subject()
subject.CN = args.server_name subject.CN = config["server_name"]
cert.set_serial_number(1000) cert.set_serial_number(1000)
cert.gmtime_adj_notBefore(0) cert.gmtime_adj_notBefore(0)
@ -110,16 +108,16 @@ class TlsConfig(Config):
certifcate_file.write(cert_pem) certifcate_file.write(cert_pem)
if not os.path.exists(args.tls_dh_params_path): if not os.path.exists(tls_dh_params_path):
if GENERATE_DH_PARAMS: if GENERATE_DH_PARAMS:
subprocess.check_call([ subprocess.check_call([
"openssl", "dhparam", "openssl", "dhparam",
"-outform", "PEM", "-outform", "PEM",
"-out", args.tls_dh_params_path, "-out", tls_dh_params_path,
"2048" "2048"
]) ])
else: else:
with open(args.tls_dh_params_path, "w") as dh_params_file: with open(tls_dh_params_path, "w") as dh_params_file:
dh_params_file.write( dh_params_file.write(
"2048-bit DH parameters taken from rfc3526\n" "2048-bit DH parameters taken from rfc3526\n"
"-----BEGIN DH PARAMETERS-----\n" "-----BEGIN DH PARAMETERS-----\n"

View File

@ -17,28 +17,21 @@ from ._base import Config
class VoipConfig(Config): class VoipConfig(Config):
def __init__(self, args): def read_config(self, config):
super(VoipConfig, self).__init__(args) self.turn_uris = config.get("turn_uris", [])
self.turn_uris = args.turn_uris self.turn_shared_secret = config["turn_shared_secret"]
self.turn_shared_secret = args.turn_shared_secret self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
self.turn_user_lifetime = args.turn_user_lifetime
@classmethod def default_config(self, config_dir_path, server_name):
def add_arguments(cls, parser): return """\
super(VoipConfig, cls).add_arguments(parser) ## Turn ##
group = parser.add_argument_group("voip")
group.add_argument( # The public URIs of the TURN server to give to clients
"--turn-uris", type=str, default=None, action='append', turn_uris: []
help="The public URIs of the TURN server to give to clients"
) # The shared secret used to compute passwords for the TURN server
group.add_argument( turn_shared_secret: "YOUR_SHARED_SECRET"
"--turn-shared-secret", type=str, default=None,
help=( # How long generated TURN credentials last
"The shared secret used to compute passwords for the TURN" turn_user_lifetime: "1h"
" server" """
)
)
group.add_argument(
"--turn-user-lifetime", type=int, default=(1000 * 60 * 60),
help="How long generated TURN credentials last, in ms"
)

View File

@ -25,12 +25,15 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
KEY_API_V1 = b"/_matrix/key/v1/"
@defer.inlineCallbacks @defer.inlineCallbacks
def fetch_server_key(server_name, ssl_context_factory): def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
"""Fetch the keys for a remote server.""" """Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory() factory = SynapseKeyClientFactory()
factory.path = path
endpoint = matrix_federation_endpoint( endpoint = matrix_federation_endpoint(
reactor, server_name, ssl_context_factory, timeout=30 reactor, server_name, ssl_context_factory, timeout=30
) )
@ -42,13 +45,19 @@ def fetch_server_key(server_name, ssl_context_factory):
server_response, server_certificate = yield protocol.remote_key server_response, server_certificate = yield protocol.remote_key
defer.returnValue((server_response, server_certificate)) defer.returnValue((server_response, server_certificate))
return return
except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"):
# Don't retry for 4xx responses.
raise IOError("Cannot get key for %r" % server_name)
except Exception as e: except Exception as e:
logger.exception(e) logger.exception(e)
raise IOError("Cannot get key for %s" % server_name) raise IOError("Cannot get key for %r" % server_name)
class SynapseKeyClientError(Exception): class SynapseKeyClientError(Exception):
"""The key wasn't retrieved from the remote server.""" """The key wasn't retrieved from the remote server."""
status = None
pass pass
@ -66,17 +75,30 @@ class SynapseKeyClientProtocol(HTTPClient):
def connectionMade(self): def connectionMade(self):
self.host = self.transport.getHost() self.host = self.transport.getHost()
logger.debug("Connected to %s", self.host) logger.debug("Connected to %s", self.host)
self.sendCommand(b"GET", b"/_matrix/key/v1/") self.sendCommand(b"GET", self.path)
self.endHeaders() self.endHeaders()
self.timer = reactor.callLater( self.timer = reactor.callLater(
self.timeout, self.timeout,
self.on_timeout self.on_timeout
) )
def errback(self, error):
if not self.remote_key.called:
self.remote_key.errback(error)
def callback(self, result):
if not self.remote_key.called:
self.remote_key.callback(result)
def handleStatus(self, version, status, message): def handleStatus(self, version, status, message):
if status != b"200": if status != b"200":
# logger.info("Non-200 response from %s: %s %s", # logger.info("Non-200 response from %s: %s %s",
# self.transport.getHost(), status, message) # self.transport.getHost(), status, message)
error = SynapseKeyClientError(
"Non-200 response %r from %r" % (status, self.host)
)
error.status = status
self.errback(error)
self.transport.abortConnection() self.transport.abortConnection()
def handleResponse(self, response_body_bytes): def handleResponse(self, response_body_bytes):
@ -89,15 +111,18 @@ class SynapseKeyClientProtocol(HTTPClient):
return return
certificate = self.transport.getPeerCertificate() certificate = self.transport.getPeerCertificate()
self.remote_key.callback((json_response, certificate)) self.callback((json_response, certificate))
self.transport.abortConnection() self.transport.abortConnection()
self.timer.cancel() self.timer.cancel()
def on_timeout(self): def on_timeout(self):
logger.debug("Timeout waiting for response from %s", self.host) logger.debug("Timeout waiting for response from %s", self.host)
self.remote_key.errback(IOError("Timeout waiting for response")) self.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection() self.transport.abortConnection()
class SynapseKeyClientFactory(Factory): class SynapseKeyClientFactory(Factory):
protocol = SynapseKeyClientProtocol def protocol(self):
protocol = SynapseKeyClientProtocol()
protocol.path = self.path
return protocol

View File

@ -15,7 +15,9 @@
from synapse.crypto.keyclient import fetch_server_key from synapse.crypto.keyclient import fetch_server_key
from twisted.internet import defer from twisted.internet import defer
from syutil.crypto.jsonsign import verify_signed_json, signature_ids from syutil.crypto.jsonsign import (
verify_signed_json, signature_ids, sign_json, encode_canonical_json
)
from syutil.crypto.signing_key import ( from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes is_signing_algorithm_supported, decode_verify_key_bytes
) )
@ -24,8 +26,12 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter from synapse.util.retryutils import get_retry_limiter
from synapse.util.async import create_observer
from OpenSSL import crypto from OpenSSL import crypto
import urllib
import hashlib
import logging import logging
@ -36,8 +42,13 @@ class Keyring(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_http_client()
self.config = hs.get_config()
self.perspective_servers = self.config.perspectives
self.hs = hs self.hs = hs
self.key_downloads = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object):
logger.debug("Verifying for %s", server_name) logger.debug("Verifying for %s", server_name)
@ -85,19 +96,56 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key(self, server_name, key_ids): def get_server_verify_key(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids. """Finds a verification key for the server with one of the key ids.
Trys to fetch the key from a trusted perspective server first.
Args: Args:
server_name (str): The name of the server to fetch a key for. server_name(str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for. keys_ids (list of str): The key_ids to check for.
""" """
# Check the datastore to see if we have one cached.
cached = yield self.store.get_server_verify_keys(server_name, key_ids) cached = yield self.store.get_server_verify_keys(server_name, key_ids)
if cached: if cached:
defer.returnValue(cached[0]) defer.returnValue(cached[0])
return return
# Try to fetch the key from the remote server. download = self.key_downloads.get(server_name)
if download is None:
download = self._get_server_verify_key_impl(server_name, key_ids)
self.key_downloads[server_name] = download
@download.addBoth
def callback(ret):
del self.key_downloads[server_name]
return ret
r = yield create_observer(download)
defer.returnValue(r)
@defer.inlineCallbacks
def _get_server_verify_key_impl(self, server_name, key_ids):
keys = None
perspective_results = []
for perspective_name, perspective_keys in self.perspective_servers.items():
@defer.inlineCallbacks
def get_key():
try:
result = yield self.get_server_verify_key_v2_indirect(
server_name, key_ids, perspective_name, perspective_keys
)
defer.returnValue(result)
except:
logging.info(
"Unable to getting key %r for %r from %r",
key_ids, server_name, perspective_name,
)
perspective_results.append(get_key())
perspective_results = yield defer.gatherResults(perspective_results)
for results in perspective_results:
if results is not None:
keys = results
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
server_name, server_name,
@ -106,10 +154,234 @@ class Keyring(object):
) )
with limiter: with limiter:
(response, tls_certificate) = yield fetch_server_key( if keys is None:
server_name, self.hs.tls_context_factory try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except:
pass
keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids
) )
for key_id in key_ids:
if key_id in keys:
defer.returnValue(keys[key_id])
return
raise ValueError("No verification key found for given key ids")
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, server_name, key_ids,
perspective_name,
perspective_keys):
limiter = yield get_retry_limiter(
perspective_name, self.clock, self.store
)
with limiter:
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
responses = yield self.client.post_json(
destination=perspective_name,
path=b"/_matrix/key/v2/query",
data={
u"server_keys": {
server_name: {
key_id: {
u"minimum_valid_until_ts": 0
} for key_id in key_ids
}
}
},
)
keys = {}
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
raise ValueError(
"Key response not signed by perspective server"
" %r" % (perspective_name,)
)
verified = False
for key_id in response[u"signatures"][perspective_name]:
if key_id in perspective_keys:
verify_signed_json(
response,
perspective_name,
perspective_keys[key_id]
)
verified = True
if not verified:
logging.info(
"Response from perspective server %r not signed with a"
" known key, signed with: %r, known keys: %r",
perspective_name,
list(response[u"signatures"][perspective_name]),
list(perspective_keys)
)
raise ValueError(
"Response not signed with a known key for perspective"
" server %r" % (perspective_name,)
)
response_keys = yield self.process_v2_response(
server_name, perspective_name, response
)
keys.update(response_keys)
yield self.store_keys(
server_name=server_name,
from_server=perspective_name,
verify_keys=keys,
)
defer.returnValue(keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {}
for requested_key_id in key_ids:
if requested_key_id in keys:
continue
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory,
path=(b"/_matrix/key/v2/server/%s" % (
urllib.quote(requested_key_id),
)).encode("ascii"),
)
if (u"signatures" not in response
or server_name not in response[u"signatures"]):
raise ValueError("Key response not signed by remote server")
if "tls_fingerprints" not in response:
raise ValueError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
)
sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
response_sha256_fingerprints = set()
for fingerprint in response[u"tls_fingerprints"]:
if u"sha256" in fingerprint:
response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
server_name=server_name,
from_server=server_name,
requested_id=requested_key_id,
response_json=response,
)
keys.update(response_keys)
yield self.store_keys(
server_name=server_name,
from_server=server_name,
verify_keys=keys,
)
defer.returnValue(keys)
@defer.inlineCallbacks
def process_v2_response(self, server_name, from_server, response_json,
requested_id=None):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
for key_id, key_data in response_json["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.time_added = time_now_ms
verify_keys[key_id] = verify_key
old_verify_keys = {}
for key_id, key_data in response_json["old_verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.expired = key_data["expired_ts"]
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
for key_id in response_json["signatures"][server_name]:
if key_id not in response_json["verify_keys"]:
raise ValueError(
"Key response must include verification keys for all"
" signatures"
)
if key_id in verify_keys:
verify_signed_json(
response_json,
server_name,
verify_keys[key_id]
)
signed_key_json = sign_json(
response_json,
self.config.server_name,
self.config.signing_key[0],
)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
updated_key_ids = set()
if requested_id is not None:
updated_key_ids.add(requested_id)
updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
for key_id in updated_key_ids:
yield self.store.store_server_keys_json(
server_name=server_name,
key_id=key_id,
from_server=server_name,
ts_now_ms=time_now_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
defer.returnValue(response_keys)
raise ValueError("No verification key found for given key ids")
@defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
Args:
server_name (str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
# Try to fetch the key from the remote server.
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory
)
# Check the response. # Check the response.
x509_certificate_bytes = crypto.dump_certificate( x509_certificate_bytes = crypto.dump_certificate(
@ -128,11 +400,16 @@ class Keyring(object):
if encode_base64(x509_certificate_bytes) != tls_certificate_b64: if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise ValueError("TLS certificate doesn't match") raise ValueError("TLS certificate doesn't match")
# Cache the result in the datastore.
time_now_ms = self.clock.time_msec()
verify_keys = {} verify_keys = {}
for key_id, key_base64 in response["verify_keys"].items(): for key_id, key_base64 in response["verify_keys"].items():
if is_signing_algorithm_supported(key_id): if is_signing_algorithm_supported(key_id):
key_bytes = decode_base64(key_base64) key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes) verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.time_added = time_now_ms
verify_keys[key_id] = verify_key verify_keys[key_id] = verify_key
for key_id in response["signatures"][server_name]: for key_id in response["signatures"][server_name]:
@ -148,10 +425,6 @@ class Keyring(object):
verify_keys[key_id] verify_keys[key_id]
) )
# Cache the result in the datastore.
time_now_ms = self.clock.time_msec()
yield self.store.store_server_certificate( yield self.store.store_server_certificate(
server_name, server_name,
server_name, server_name,
@ -159,14 +432,26 @@ class Keyring(object):
tls_certificate, tls_certificate,
) )
yield self.store_keys(
server_name=server_name,
from_server=server_name,
verify_keys=verify_keys,
)
defer.returnValue(verify_keys)
@defer.inlineCallbacks
def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Args:
server_name(str): The name of the server the keys are for.
from_server(str): The server the keys were downloaded from.
verify_keys(dict): A mapping of key_id to VerifyKey.
Returns:
A deferred that completes when the keys are stored.
"""
for key_id, key in verify_keys.items(): for key_id, key in verify_keys.items():
# TODO(markjh): Store whether the keys have expired.
yield self.store.store_server_verify_key( yield self.store.store_server_verify_key(
server_name, server_name, time_now_ms, key server_name, server_name, key.time_added, key
) )
for key_id in key_ids:
if key_id in verify_keys:
defer.returnValue(verify_keys[key_id])
return
raise ValueError("No verification key found for given key ids")

View File

@ -46,9 +46,10 @@ def _event_dict_property(key):
class EventBase(object): class EventBase(object):
def __init__(self, event_dict, signatures={}, unsigned={}, def __init__(self, event_dict, signatures={}, unsigned={},
internal_metadata_dict={}): internal_metadata_dict={}, rejected_reason=None):
self.signatures = signatures self.signatures = signatures
self.unsigned = unsigned self.unsigned = unsigned
self.rejected_reason = rejected_reason
self._event_dict = event_dict self._event_dict = event_dict
@ -109,7 +110,7 @@ class EventBase(object):
class FrozenEvent(EventBase): class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}): def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
event_dict = dict(event_dict) event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a # Signatures is a dict of dicts, and this is faster than doing a
@ -128,6 +129,7 @@ class FrozenEvent(EventBase):
signatures=signatures, signatures=signatures,
unsigned=unsigned, unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict, internal_metadata_dict=internal_metadata_dict,
rejected_reason=rejected_reason,
) )
@staticmethod @staticmethod

View File

@ -491,7 +491,7 @@ class FederationClient(FederationBase):
] ]
signed_events = yield self._check_sigs_and_hash_and_fetch( signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=True destination, events, outlier=False
) )
have_gotten_all_from_destination = True have_gotten_all_from_destination = True

View File

@ -417,13 +417,13 @@ class FederationServer(FederationBase):
pdu.internal_metadata.outlier = True pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth: elif min_depth and pdu.depth > min_depth:
if get_missing and prevs - seen: if get_missing and prevs - seen:
latest_tuples = yield self.store.get_latest_events_in_room( latest = yield self.store.get_latest_event_ids_in_room(
pdu.room_id pdu.room_id
) )
# We add the prev events that we have seen to the latest # We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us # list to ensure the remote server doesn't give them to us
latest = set(e_id for e_id, _, _ in latest_tuples) latest = set(latest)
latest |= seen latest |= seen
missing_events = yield self.get_missing_events( missing_events = yield self.get_missing_events(

View File

@ -23,8 +23,6 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from syutil.jsonutil import encode_canonical_json
import logging import logging
@ -71,7 +69,7 @@ class TransactionActions(object):
transaction.transaction_id, transaction.transaction_id,
transaction.origin, transaction.origin,
code, code,
encode_canonical_json(response) response,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -101,5 +99,5 @@ class TransactionActions(object):
transaction.transaction_id, transaction.transaction_id,
transaction.destination, transaction.destination,
response_code, response_code,
encode_canonical_json(response_dict) response_dict,
) )

View File

@ -104,7 +104,6 @@ class TransactionQueue(object):
return not destination.startswith("localhost") return not destination.startswith("localhost")
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, destinations, order): def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have # We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus # a transaction in progress. If we do, stick it in the pending_pdus

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.appservice.scheduler import AppServiceScheduler
from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler from .register import RegistrationHandler
from .room import ( from .room import (
@ -29,6 +30,8 @@ from .typing import TypingNotificationHandler
from .admin import AdminHandler from .admin import AdminHandler
from .appservice import ApplicationServicesHandler from .appservice import ApplicationServicesHandler
from .sync import SyncHandler from .sync import SyncHandler
from .auth import AuthHandler
from .identity import IdentityHandler
class Handlers(object): class Handlers(object):
@ -54,7 +57,14 @@ class Handlers(object):
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
asapi = ApplicationServiceApi(hs)
self.appservice_handler = ApplicationServicesHandler( self.appservice_handler = ApplicationServicesHandler(
hs, ApplicationServiceApi(hs) hs, asapi, AppServiceScheduler(
clock=hs.get_clock(),
store=hs.get_datastore(),
as_api=asapi
)
) )
self.sync_handler = SyncHandler(hs) self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs)

View File

@ -16,7 +16,6 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import LimitExceededError, SynapseError from synapse.api.errors import LimitExceededError, SynapseError
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID from synapse.types import UserID
@ -58,8 +57,6 @@ class BaseHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder): def _create_new_client_event(self, builder):
yield run_on_reactor()
latest_ret = yield self.store.get_latest_events_in_room( latest_ret = yield self.store.get_latest_events_in_room(
builder.room_id, builder.room_id,
) )
@ -101,8 +98,6 @@ class BaseHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_destinations=[], def handle_new_client_event(self, event, context, extra_destinations=[],
extra_users=[], suppress_auth=False): extra_users=[], suppress_auth=False):
yield run_on_reactor()
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth: if not suppress_auth:
@ -143,7 +138,9 @@ class BaseHandler(object):
) )
# Don't block waiting on waking up all the listeners. # Don't block waiting on waking up all the listeners.
d = self.notifier.on_new_room_event(event, extra_users=extra_users) notify_d = self.notifier.on_new_room_event(
event, extra_users=extra_users
)
def log_failure(f): def log_failure(f):
logger.warn( logger.warn(
@ -151,8 +148,8 @@ class BaseHandler(object):
event.event_id, f.value event.event_id, f.value
) )
d.addErrback(log_failure) notify_d.addErrback(log_failure)
yield federation_handler.handle_new_event( federation_handler.handle_new_event(
event, destinations=destinations, event, destinations=destinations,
) )

View File

@ -16,57 +16,36 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.types import UserID from synapse.types import UserID
import synapse.util.stringutils as stringutils
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def log_failure(failure):
logger.error(
"Application Services Failure",
exc_info=(
failure.type,
failure.value,
failure.getTracebackObject()
)
)
# NB: Purposefully not inheriting BaseHandler since that contains way too much # NB: Purposefully not inheriting BaseHandler since that contains way too much
# setup code which this handler does not need or use. This makes testing a lot # setup code which this handler does not need or use. This makes testing a lot
# easier. # easier.
class ApplicationServicesHandler(object): class ApplicationServicesHandler(object):
def __init__(self, hs, appservice_api): def __init__(self, hs, appservice_api, appservice_scheduler):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.hs = hs self.hs = hs
self.appservice_api = appservice_api self.appservice_api = appservice_api
self.scheduler = appservice_scheduler
@defer.inlineCallbacks self.started_scheduler = False
def register(self, app_service):
logger.info("Register -> %s", app_service)
# check the token is recognised
try:
stored_service = yield self.store.get_app_service_by_token(
app_service.token
)
if not stored_service:
raise StoreError(404, "Application service not found")
except StoreError:
raise SynapseError(
403, "Unrecognised application services token. "
"Consult the home server admin.",
errcode=Codes.FORBIDDEN
)
app_service.hs_token = self._generate_hs_token()
# create a sender for this application service which is used when
# creating rooms, etc..
account = yield self.hs.get_handlers().registration_handler.register()
app_service.sender = account[0]
yield self.store.update_app_service(app_service)
defer.returnValue(app_service)
@defer.inlineCallbacks
def unregister(self, token):
logger.info("Unregister as_token=%s", token)
yield self.store.unregister_app_service(token)
@defer.inlineCallbacks @defer.inlineCallbacks
def notify_interested_services(self, event): def notify_interested_services(self, event):
@ -90,9 +69,13 @@ class ApplicationServicesHandler(object):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key) yield self._check_user_exists(event.state_key)
# Fork off pushes to these services - XXX First cut, best effort if not self.started_scheduler:
self.scheduler.start().addErrback(log_failure)
self.started_scheduler = True
# Fork off pushes to these services
for service in services: for service in services:
self.appservice_api.push(service, event) self.scheduler.submit_event_for_as(service, event)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_user_exists(self, user_id): def query_user_exists(self, user_id):
@ -197,7 +180,14 @@ class ApplicationServicesHandler(object):
return return
user_info = yield self.store.get_user_by_id(user_id) user_info = yield self.store.get_user_by_id(user_id)
defer.returnValue(len(user_info) == 0) if not user_info:
defer.returnValue(False)
return
# user not found; could be the AS though, so check.
services = yield self.store.get_app_services()
service_list = [s for s in services if s.sender == user_id]
defer.returnValue(len(service_list) == 0)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_user_exists(self, user_id): def _check_user_exists(self, user_id):
@ -206,6 +196,3 @@ class ApplicationServicesHandler(object):
exists = yield self.query_user_exists(user_id) exists = yield self.query_user_exists(user_id)
defer.returnValue(exists) defer.returnValue(exists)
defer.returnValue(True) defer.returnValue(True)
def _generate_hs_token(self):
return stringutils.random_string(24)

277
synapse/handlers/auth.py Normal file
View File

@ -0,0 +1,277 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError
import logging
import bcrypt
import simplejson
import synapse.util.stringutils as stringutils
logger = logging.getLogger(__name__)
class AuthHandler(BaseHandler):
def __init__(self, hs):
super(AuthHandler, self).__init__(hs)
self.checkers = {
LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.DUMMY: self._check_dummy_auth,
}
self.sessions = {}
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip=None):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the login flow.
Args:
flows: list of list of stages
authdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
Returns:
A tuple of authed, dict, dict where authed is true if the client
has successfully completed an auth flow. If it is true, the first
dict contains the authenticated credentials of each stage.
If authed is false, the first dictionary is the server response to
the login request and should be passed back to the client.
In either case, the second dict contains the parameters for this
request (which may have been given only in a previous call).
"""
authdict = None
sid = None
if clientdict and 'auth' in clientdict:
authdict = clientdict['auth']
del clientdict['auth']
if 'session' in authdict:
sid = authdict['session']
sess = self._get_session_info(sid)
if len(clientdict) > 0:
# This was designed to allow the client to omit the parameters
# and just supply the session in subsequent calls so it split
# auth between devices by just sharing the session, (eg. so you
# could continue registration from your phone having clicked the
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
# on a home server.
# sess['clientdict'] = clientdict
# self._save_session(sess)
pass
elif 'clientdict' in sess:
clientdict = sess['clientdict']
if not authdict:
defer.returnValue(
(False, self._auth_dict_for_flows(flows, sess), clientdict)
)
if 'creds' not in sess:
sess['creds'] = {}
creds = sess['creds']
# check auth type currently being presented
if 'type' in authdict:
if authdict['type'] not in self.checkers:
raise LoginError(400, "", Codes.UNRECOGNIZED)
result = yield self.checkers[authdict['type']](authdict, clientip)
if result:
creds[authdict['type']] = result
self._save_session(sess)
for f in flows:
if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds)
self._remove_session(sess)
defer.returnValue((True, creds, clientdict))
ret = self._auth_dict_for_flows(flows, sess)
ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict))
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
"""
if stagetype not in self.checkers:
raise LoginError(400, "", Codes.MISSING_PARAM)
if 'session' not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
sess = self._get_session_info(
authdict['session']
)
if 'creds' not in sess:
sess['creds'] = {}
creds = sess['creds']
result = yield self.checkers[stagetype](authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
user = authdict["user"]
password = authdict["password"]
if not user.startswith('@'):
user = UserID.create(user, self.hs.hostname).to_string()
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash):
defer.returnValue(user)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
try:
user_response = authdict["response"]
except KeyError:
# Client tried to provide captcha but didn't give the parameter:
# bad request.
raise LoginError(
400, "Captcha response is required",
errcode=Codes.CAPTCHA_NEEDED
)
logger.info(
"Submitting recaptcha response %s with remoteip %s",
user_response, clientip
)
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
client = SimpleHttpClient(self.hs)
data = yield client.post_urlencoded_get_json(
"https://www.google.com/recaptcha/api/siteverify",
args={
'secret': self.hs.config.recaptcha_private_key,
'response': user_response,
'remoteip': clientip,
}
)
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = simplejson.loads(data)
if 'success' in resp_body and resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
def _check_email_identity(self, authdict, _):
yield run_on_reactor()
if 'threepid_creds' not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict['threepid_creds']
identity_handler = self.hs.get_handlers().identity_handler
logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
threepid = yield identity_handler.threepid_from_creds(threepid_creds)
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
threepid['threepid_creds'] = authdict['threepid_creds']
defer.returnValue(threepid)
@defer.inlineCallbacks
def _check_dummy_auth(self, authdict, _):
yield run_on_reactor()
defer.returnValue(True)
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}
def _auth_dict_for_flows(self, flows, session):
public_flows = []
for f in flows:
public_flows.append(f)
get_params = {
LoginType.RECAPTCHA: self._get_params_recaptcha,
}
params = {}
for f in public_flows:
for stage in f:
if stage in get_params and stage not in params:
params[stage] = get_params[stage]()
return {
"session": session['id'],
"flows": [{"stages": f} for f in public_flows],
"params": params
}
def _get_session_info(self, session_id):
if session_id not in self.sessions:
session_id = None
if not session_id:
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
self.sessions[session_id] = {
"id": session_id,
}
return self.sessions[session_id]
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)
self.sessions[session["id"]] = session
def _remove_session(self, session):
logger.debug("Removing session %s", session)
del self.sessions[session["id"]]

View File

@ -73,8 +73,6 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up # When joining a room we need to queue any events for that room up
self.room_queues = {} self.room_queues = {}
@log_function
@defer.inlineCallbacks
def handle_new_event(self, event, destinations): def handle_new_event(self, event, destinations):
""" Takes in an event from the client to server side, that has already """ Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any been authed and handled by the state module, and sends it to any
@ -89,9 +87,7 @@ class FederationHandler(BaseHandler):
processing. processing.
""" """
yield run_on_reactor() return self.replication_layer.send_pdu(event, destinations)
self.replication_layer.send_pdu(event, destinations)
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
@ -179,7 +175,7 @@ class FederationHandler(BaseHandler):
# it's probably a good idea to mark it as not in retry-state # it's probably a good idea to mark it as not in retry-state
# for sending (although this is a bit of a leap) # for sending (although this is a bit of a leap)
retry_timings = yield self.store.get_destination_retry_timings(origin) retry_timings = yield self.store.get_destination_retry_timings(origin)
if (retry_timings and retry_timings.retry_last_ts): if retry_timings and retry_timings["retry_last_ts"]:
self.store.set_destination_retry_timings(origin, 0, 0) self.store.set_destination_retry_timings(origin, 0, 0)
room = yield self.store.get_room(event.room_id) room = yield self.store.get_room(event.room_id)
@ -201,10 +197,18 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id) target_user = UserID.from_string(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
yield self.notifier.on_new_room_event( d = self.notifier.on_new_room_event(
event, extra_users=extra_users event, extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
@ -427,10 +431,18 @@ class FederationHandler(BaseHandler):
auth_events=auth_events, auth_events=auth_events,
) )
yield self.notifier.on_new_room_event( d = self.notifier.on_new_room_event(
new_event, extra_users=[joinee] new_event, extra_users=[joinee]
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
new_event.event_id, f.value
)
d.addErrback(log_failure)
logger.debug("Finished joining %s to %s", joinee, room_id) logger.debug("Finished joining %s to %s", joinee, room_id)
finally: finally:
room_queue = self.room_queues[room_id] room_queue = self.room_queues[room_id]
@ -500,10 +512,18 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id) target_user = UserID.from_string(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
yield self.notifier.on_new_room_event( d = self.notifier.on_new_room_event(
event, extra_users=extra_users event, extra_users=extra_users
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
@ -574,10 +594,18 @@ class FederationHandler(BaseHandler):
) )
target_user = UserID.from_string(event.state_key) target_user = UserID.from_string(event.state_key)
yield self.notifier.on_new_room_event( d = self.notifier.on_new_room_event(
event, extra_users=[target_user], event, extra_users=[target_user],
) )
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for interacting with Identity Servers"""
from twisted.internet import defer
from synapse.api.errors import (
CodeMessageException
)
from ._base import BaseHandler
from synapse.http.client import SimpleHttpClient
from synapse.util.async import run_on_reactor
from synapse.api.errors import SynapseError
import json
import logging
logger = logging.getLogger(__name__)
class IdentityHandler(BaseHandler):
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
yield run_on_reactor()
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org']
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
id_server = creds['idServer']
else:
raise SynapseError(400, "No id_server in creds")
if 'client_secret' in creds:
client_secret = creds['client_secret']
elif 'clientSecret' in creds:
client_secret = creds['clientSecret']
else:
raise SynapseError(400, "No client_secret in creds")
if id_server not in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', id_server)
defer.returnValue(None)
data = {}
try:
data = yield http_client.get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'client_secret': client_secret}
)
except CodeMessageException as e:
data = json.loads(e.msg)
if 'medium' in data:
defer.returnValue(data)
defer.returnValue(None)
@defer.inlineCallbacks
def bind_threepid(self, creds, mxid):
yield run_on_reactor()
logger.debug("binding threepid %r to %s", creds, mxid)
http_client = SimpleHttpClient(self.hs)
data = None
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
id_server = creds['idServer']
else:
raise SynapseError(400, "No id_server in creds")
if 'client_secret' in creds:
client_secret = creds['client_secret']
elif 'clientSecret' in creds:
client_secret = creds['clientSecret']
else:
raise SynapseError(400, "No client_secret in creds")
try:
data = yield http_client.post_urlencoded_get_json(
"https://%s%s" % (
id_server, "/_matrix/identity/api/v1/3pid/bind"
),
{
'sid': creds['sid'],
'client_secret': client_secret,
'mxid': mxid,
}
)
logger.debug("bound threepid %r to %s", creds, mxid)
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)

View File

@ -16,13 +16,9 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes, CodeMessageException from synapse.api.errors import LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.util.emailutils import EmailException
import synapse.util.emailutils as emailutils
import bcrypt import bcrypt
import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -57,7 +53,7 @@ class LoginHandler(BaseHandler):
logger.warn("Attempted to login as %s but they do not exist", user) logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info[0]["password_hash"] stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash): if bcrypt.checkpw(password, stored_hash):
# generate an access token and store it. # generate an access token and store it.
token = self.reg_handler._generate_token(user) token = self.reg_handler._generate_token(user)
@ -69,48 +65,19 @@ class LoginHandler(BaseHandler):
raise LoginError(403, "", errcode=Codes.FORBIDDEN) raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks @defer.inlineCallbacks
def reset_password(self, user_id, email): def set_password(self, user_id, newpassword, token_id=None):
is_valid = yield self._check_valid_association(user_id, email) password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
logger.info("reset_password user=%s email=%s valid=%s", user_id, email,
is_valid) yield self.store.user_set_password_hash(user_id, password_hash)
if is_valid: yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
try: yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
# send an email out user_id, token_id
emailutils.send_email( )
smtp_server=self.hs.config.email_smtp_server, yield self.store.flush_user(user_id)
from_addr=self.hs.config.email_from_address,
to_addr=email,
subject="Password Reset",
body="TODO."
)
except EmailException as e:
logger.exception(e)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_valid_association(self, user_id, email): def add_threepid(self, user_id, medium, address, validated_at):
identity = yield self._query_email(email) yield self.store.user_add_threepid(
if identity and "mxid" in identity: user_id, medium, address, validated_at,
if identity["mxid"] == user_id: self.hs.get_clock().time_msec()
defer.returnValue(True) )
return
defer.returnValue(False)
@defer.inlineCallbacks
def _query_email(self, email):
http_client = SimpleHttpClient(self.hs)
try:
data = yield http_client.get_json(
# TODO FIXME This should be configurable.
# XXX: ID servers need to use HTTPS
"http://%s%s" % (
"matrix.org:8090", "/_matrix/identity/api/v1/lookup"
),
{
'medium': 'email',
'address': email
}
)
defer.returnValue(data)
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)

View File

@ -267,14 +267,14 @@ class MessageHandler(BaseHandler):
user, pagination_config.get_source_config("presence"), None user, pagination_config.get_source_config("presence"), None
) )
public_rooms = yield self.store.get_rooms(is_public=True) public_room_ids = yield self.store.get_public_room_ids()
public_room_ids = [r["room_id"] for r in public_rooms]
limit = pagin_config.limit limit = pagin_config.limit
if limit is None: if limit is None:
limit = 10 limit = 10
for event in room_list: @defer.inlineCallbacks
def handle_room(event):
d = { d = {
"room_id": event.room_id, "room_id": event.room_id,
"membership": event.membership, "membership": event.membership,
@ -290,12 +290,19 @@ class MessageHandler(BaseHandler):
rooms_ret.append(d) rooms_ret.append(d)
if event.membership != Membership.JOIN: if event.membership != Membership.JOIN:
continue return
try: try:
messages, token = yield self.store.get_recent_events_for_room( (messages, token), current_state = yield defer.gatherResults(
event.room_id, [
limit=limit, self.store.get_recent_events_for_room(
end_token=now_token.room_key, event.room_id,
limit=limit,
end_token=now_token.room_key,
),
self.state_handler.get_current_state(
event.room_id
),
]
) )
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
@ -311,9 +318,6 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(), "end": end_token.to_string(),
} }
current_state = yield self.state_handler.get_current_state(
event.room_id
)
d["state"] = [ d["state"] = [
serialize_event(c, time_now, as_client_event) serialize_event(c, time_now, as_client_event)
for c in current_state.values() for c in current_state.values()
@ -321,6 +325,11 @@ class MessageHandler(BaseHandler):
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
yield defer.gatherResults(
[handle_room(e) for e in room_list],
consumeErrors=True
)
ret = { ret = {
"rooms": rooms_ret, "rooms": rooms_ret,
"presence": presence, "presence": presence,

View File

@ -33,6 +33,13 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__) metrics = synapse.metrics.get_metrics_for(__name__)
# Don't bother bumping "last active" time if it differs by less than 60 seconds
LAST_ACTIVE_GRANULARITY = 60*1000
# Keep no more than this number of offline serial revisions
MAX_OFFLINE_SERIALS = 1000
# TODO(paul): Maybe there's one of these I can steal from somewhere # TODO(paul): Maybe there's one of these I can steal from somewhere
def partition(l, func): def partition(l, func):
"""Partition the list by the result of func applied to each element.""" """Partition the list by the result of func applied to each element."""
@ -131,6 +138,9 @@ class PresenceHandler(BaseHandler):
self._remote_sendmap = {} self._remote_sendmap = {}
# map remote users to sets of local users who're interested in them # map remote users to sets of local users who're interested in them
self._remote_recvmap = {} self._remote_recvmap = {}
# list of (serial, set of(userids)) tuples, ordered by serial, latest
# first
self._remote_offline_serials = []
# map any user to a UserPresenceCache # map any user to a UserPresenceCache
self._user_cachemap = {} self._user_cachemap = {}
@ -282,6 +292,10 @@ class PresenceHandler(BaseHandler):
if now is None: if now is None:
now = self.clock.time_msec() now = self.clock.time_msec()
prev_state = self._get_or_make_usercache(user)
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
return
self.changed_presencelike_data(user, {"last_active": now}) self.changed_presencelike_data(user, {"last_active": now})
def changed_presencelike_data(self, user, state): def changed_presencelike_data(self, user, state):
@ -706,8 +720,24 @@ class PresenceHandler(BaseHandler):
statuscache=statuscache, statuscache=statuscache,
) )
user_id = user.to_string()
if state["presence"] == PresenceState.OFFLINE: if state["presence"] == PresenceState.OFFLINE:
self._remote_offline_serials.insert(
0,
(self._user_cachemap_latest_serial, set([user_id]))
)
while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS:
self._remote_offline_serials.pop() # remove the oldest
del self._user_cachemap[user] del self._user_cachemap[user]
else:
# Remove the user from remote_offline_serials now that they're
# no longer offline
for idx, elem in enumerate(self._remote_offline_serials):
(_, user_ids) = elem
user_ids.discard(user_id)
if not user_ids:
self._remote_offline_serials.pop(idx)
for poll in content.get("poll", []): for poll in content.get("poll", []):
user = UserID.from_string(poll) user = UserID.from_string(poll)
@ -829,26 +859,47 @@ class PresenceEventSource(object):
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap cachemap = presence._user_cachemap
max_serial = presence._user_cachemap_latest_serial
clock = self.clock
latest_serial = 0
updates = [] updates = []
# TODO(paul): use a DeferredList ? How to limit concurrency. # TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys(): for observed_user in cachemap.keys():
cached = cachemap[observed_user] cached = cachemap[observed_user]
if cached.serial <= from_key: if cached.serial <= from_key or cached.serial > max_serial:
continue continue
if (yield self.is_visible(observer_user, observed_user)): if not (yield self.is_visible(observer_user, observed_user)):
updates.append((observed_user, cached)) continue
latest_serial = max(cached.serial, latest_serial)
updates.append(cached.make_event(user=observed_user, clock=clock))
# TODO(paul): limit # TODO(paul): limit
for serial, user_ids in presence._remote_offline_serials:
if serial <= from_key:
break
if serial > max_serial:
continue
latest_serial = max(latest_serial, serial)
for u in user_ids:
updates.append({
"type": "m.presence",
"content": {"user_id": u, "presence": PresenceState.OFFLINE},
})
# TODO(paul): For the v2 API we want to tell the client their from_key
# is too old if we fell off the end of the _remote_offline_serials
# list, and get them to invalidate+resync. In v1 we have no such
# concept so this is a best-effort result.
if updates: if updates:
clock = self.clock defer.returnValue((updates, latest_serial))
latest_serial = max([x[1].serial for x in updates])
data = [x[1].make_event(user=x[0], clock=clock) for x in updates]
defer.returnValue((data, latest_serial))
else: else:
defer.returnValue(([], presence._user_cachemap_latest_serial)) defer.returnValue(([], presence._user_cachemap_latest_serial))

View File

@ -18,18 +18,15 @@ from twisted.internet import defer
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError, AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
CodeMessageException
) )
from ._base import BaseHandler from ._base import BaseHandler
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.http.client import SimpleHttpClient
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
import base64 import base64
import bcrypt import bcrypt
import json
import logging import logging
import urllib import urllib
@ -44,6 +41,30 @@ class RegistrationHandler(BaseHandler):
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.distributor.declare("registered_user") self.distributor.declare("registered_user")
@defer.inlineCallbacks
def check_username(self, localpart):
yield run_on_reactor()
if urllib.quote(localpart) != localpart:
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
u = yield self.store.get_user_by_id(user_id)
if u:
raise SynapseError(
400,
"User ID already taken.",
errcode=Codes.USER_IN_USE,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def register(self, localpart=None, password=None): def register(self, localpart=None, password=None):
"""Registers a new client on the server. """Registers a new client on the server.
@ -64,18 +85,11 @@ class RegistrationHandler(BaseHandler):
password_hash = bcrypt.hashpw(password, bcrypt.gensalt()) password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
if localpart: if localpart:
if localpart and urllib.quote(localpart) != localpart: yield self.check_username(localpart)
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id) token = self._generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -157,7 +171,11 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response): def check_recaptcha(self, ip, private_key, challenge, response):
"""Checks a recaptcha is correct.""" """
Checks a recaptcha is correct.
Used only by c/s api v1
"""
captcha_response = yield self._validate_captcha( captcha_response = yield self._validate_captcha(
ip, ip,
@ -176,13 +194,18 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def register_email(self, threepidCreds): def register_email(self, threepidCreds):
"""Registers emails with an identity server.""" """
Registers emails with an identity server.
Used only by c/s api v1
"""
for c in threepidCreds: for c in threepidCreds:
logger.info("validating theeepidcred sid %s on id server %s", logger.info("validating theeepidcred sid %s on id server %s",
c['sid'], c['idServer']) c['sid'], c['idServer'])
try: try:
threepid = yield self._threepid_from_creds(c) identity_handler = self.hs.get_handlers().identity_handler
threepid = yield identity_handler.threepid_from_creds(c)
except: except:
logger.exception("Couldn't validate 3pid") logger.exception("Couldn't validate 3pid")
raise RegistrationError(400, "Couldn't validate 3pid") raise RegistrationError(400, "Couldn't validate 3pid")
@ -194,12 +217,16 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds): def bind_emails(self, user_id, threepidCreds):
"""Links emails with a user ID and informs an identity server.""" """Links emails with a user ID and informs an identity server.
Used only by c/s api v1
"""
# Now we have a matrix ID, bind it to the threepids we were given # Now we have a matrix ID, bind it to the threepids we were given
for c in threepidCreds: for c in threepidCreds:
identity_handler = self.hs.get_handlers().identity_handler
# XXX: This should be a deferred list, shouldn't it? # XXX: This should be a deferred list, shouldn't it?
yield self._bind_threepid(c, user_id) yield identity_handler.bind_threepid(c, user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_user_id_is_valid(self, user_id): def check_user_id_is_valid(self, user_id):
@ -226,62 +253,12 @@ class RegistrationHandler(BaseHandler):
def _generate_user_id(self): def _generate_user_id(self):
return "-" + stringutils.random_string(18) return "-" + stringutils.random_string(18)
@defer.inlineCallbacks
def _threepid_from_creds(self, creds):
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable!
trustedIdServers = ['matrix.org:8090', 'matrix.org']
if not creds['idServer'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['idServer'])
defer.returnValue(None)
data = {}
try:
data = yield http_client.get_json(
# XXX: This should be HTTPS
"http://%s%s" % (
creds['idServer'],
"/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
)
except CodeMessageException as e:
data = json.loads(e.msg)
if 'medium' in data:
defer.returnValue(data)
defer.returnValue(None)
@defer.inlineCallbacks
def _bind_threepid(self, creds, mxid):
yield
logger.debug("binding threepid")
http_client = SimpleHttpClient(self.hs)
data = None
try:
data = yield http_client.post_urlencoded_get_json(
# XXX: Change when ID servers are all HTTPS
"http://%s%s" % (
creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
),
{
'sid': creds['sid'],
'clientSecret': creds['clientSecret'],
'mxid': mxid,
}
)
logger.debug("bound threepid")
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response): def _validate_captcha(self, ip_addr, private_key, challenge, response):
"""Validates the captcha provided. """Validates the captcha provided.
Used only by c/s api v1
Returns: Returns:
dict: Containing 'valid'(bool) and 'error_url'(str) if invalid. dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.
@ -299,6 +276,9 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _submit_captcha(self, ip_addr, private_key, challenge, response): def _submit_captcha(self, ip_addr, private_key, challenge, response):
"""
Used only by c/s api v1
"""
# TODO: get this from the homeserver rather than creating a new one for # TODO: get this from the homeserver rather than creating a new one for
# each request # each request
client = CaptchaServerHttpClient(self.hs) client = CaptchaServerHttpClient(self.hs)

View File

@ -124,7 +124,7 @@ class RoomCreationHandler(BaseHandler):
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
for event in creation_events: for event in creation_events:
yield msg_handler.create_and_send_event(event) yield msg_handler.create_and_send_event(event, ratelimit=False)
if "name" in config: if "name" in config:
name = config["name"] name = config["name"]
@ -134,7 +134,7 @@ class RoomCreationHandler(BaseHandler):
"sender": user_id, "sender": user_id,
"state_key": "", "state_key": "",
"content": {"name": name}, "content": {"name": name},
}) }, ratelimit=False)
if "topic" in config: if "topic" in config:
topic = config["topic"] topic = config["topic"]
@ -144,7 +144,7 @@ class RoomCreationHandler(BaseHandler):
"sender": user_id, "sender": user_id,
"state_key": "", "state_key": "",
"content": {"topic": topic}, "content": {"topic": topic},
}) }, ratelimit=False)
for invitee in invite_list: for invitee in invite_list:
yield msg_handler.create_and_send_event({ yield msg_handler.create_and_send_event({
@ -153,7 +153,7 @@ class RoomCreationHandler(BaseHandler):
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": user_id,
"content": {"membership": Membership.INVITE}, "content": {"membership": Membership.INVITE},
}) }, ratelimit=False)
result = {"room_id": room_id} result = {"room_id": room_id}
@ -213,7 +213,8 @@ class RoomCreationHandler(BaseHandler):
"state_default": 50, "state_default": 50,
"ban": 50, "ban": 50,
"kick": 50, "kick": 50,
"redact": 50 "redact": 50,
"invite": 0,
}, },
) )
@ -310,25 +311,6 @@ class RoomMemberHandler(BaseHandler):
# paginating # paginating
defer.returnValue(chunk_data) defer.returnValue(chunk_data)
@defer.inlineCallbacks
def get_room_member(self, room_id, member_user_id, auth_user_id):
"""Retrieve a room member from a room.
Args:
room_id : The room the member is in.
member_user_id : The member's user ID
auth_user_id : The user ID of the user making this request.
Returns:
The room member, or None if this member does not exist.
Raises:
SynapseError if something goes wrong.
"""
yield self.auth.check_joined_room(room_id, auth_user_id)
member = yield self.store.get_room_member(user_id=member_user_id,
room_id=room_id)
defer.returnValue(member)
@defer.inlineCallbacks @defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True): def change_membership(self, event, context, do_auth=True):
""" Change the membership status of a user in a room. """ Change the membership status of a user in a room.
@ -547,11 +529,19 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_public_room_list(self): def get_public_room_list(self):
chunk = yield self.store.get_rooms(is_public=True) chunk = yield self.store.get_rooms(is_public=True)
for room in chunk: results = yield defer.gatherResults(
joined_users = yield self.store.get_users_in_room( [
room_id=room["room_id"], self.store.get_users_in_room(
) room_id=room["room_id"],
room["num_joined_members"] = len(joined_users) )
for room in chunk
],
consumeErrors=True,
)
for i, room in enumerate(chunk):
room["num_joined_members"] = len(results[i])
# FIXME (erikj): START is no longer a valid value # FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": chunk}) defer.returnValue({"start": "START", "end": "END", "chunk": chunk})

View File

@ -223,6 +223,7 @@ class TypingNotificationEventSource(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self._handler = None self._handler = None
self._room_member_handler = None
def handler(self): def handler(self):
# Avoid cyclic dependency in handler setup # Avoid cyclic dependency in handler setup
@ -230,6 +231,11 @@ class TypingNotificationEventSource(object):
self._handler = self.hs.get_handlers().typing_notification_handler self._handler = self.hs.get_handlers().typing_notification_handler
return self._handler return self._handler
def room_member_handler(self):
if not self._room_member_handler:
self._room_member_handler = self.hs.get_handlers().room_member_handler
return self._room_member_handler
def _make_event_for(self, room_id): def _make_event_for(self, room_id):
typing = self.handler()._room_typing[room_id] typing = self.handler()._room_typing[room_id]
return { return {
@ -240,19 +246,25 @@ class TypingNotificationEventSource(object):
}, },
} }
@defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key) from_key = int(from_key)
handler = self.handler() handler = self.handler()
joined_room_ids = (
yield self.room_member_handler().get_joined_rooms_for_user(user)
)
events = [] events = []
for room_id in handler._room_serials: for room_id in handler._room_serials:
if room_id not in joined_room_ids:
continue
if handler._room_serials[room_id] <= from_key: if handler._room_serials[room_id] <= from_key:
continue continue
# TODO: check if user is in room
events.append(self._make_event_for(room_id)) events.append(self._make_event_for(room_id))
return (events, handler._latest_room_serial) defer.returnValue((events, handler._latest_room_serial))
def get_current_key(self): def get_current_key(self):
return self.handler()._latest_room_serial return self.handler()._latest_room_serial

View File

@ -200,6 +200,8 @@ class CaptchaServerHttpClient(SimpleHttpClient):
""" """
Separate HTTP client for talking to google's captcha servers Separate HTTP client for talking to google's captcha servers
Only slightly special because accepts partial download responses Only slightly special because accepts partial download responses
used only by c/s api v1
""" """
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -24,7 +24,7 @@ from syutil.jsonutil import (
encode_canonical_json, encode_pretty_printed_json encode_canonical_json, encode_pretty_printed_json
) )
from twisted.internet import defer, reactor from twisted.internet import defer
from twisted.web import server, resource from twisted.web import server, resource
from twisted.web.server import NOT_DONE_YET from twisted.web.server import NOT_DONE_YET
from twisted.web.util import redirectTo from twisted.web.util import redirectTo
@ -51,16 +51,90 @@ response_timer = metrics.register_distribution(
labels=["method", "servlet"] labels=["method", "servlet"]
) )
_next_request_id = 0
def request_handler(request_handler):
"""Wraps a method that acts as a request handler with the necessary logging
and exception handling.
The method must have a signature of "handle_foo(self, request)". The
argument "self" must have "version_string" and "clock" attributes. The
argument "request" must be a twisted HTTP request.
The method must return a deferred. If the deferred succeeds we assume that
a response has been sent. If the deferred fails with a SynapseError we use
it to send a JSON response with the appropriate HTTP reponse code. If the
deferred fails with any other type of error we send a 500 reponse.
We insert a unique request-id into the logging context for this request and
log the response and duration for this request.
"""
@defer.inlineCallbacks
def wrapped_request_handler(self, request):
global _next_request_id
request_id = "%s-%s" % (request.method, _next_request_id)
_next_request_id += 1
with LoggingContext(request_id) as request_context:
request_context.request = request_id
code = None
start = self.clock.time_msec()
try:
logger.info(
"Received request: %s %s",
request.method, request.path
)
yield request_handler(self, request)
code = request.code
except CodeMessageException as e:
code = e.code
if isinstance(e, SynapseError):
logger.info(
"%s SynapseError: %s - %s", request, code, e.msg
)
else:
logger.exception(e)
outgoing_responses_counter.inc(request.method, str(code))
respond_with_json(
request, code, cs_exception(e), send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
)
except:
code = 500
logger.exception(
"Failed handle request %s.%s on %r: %r",
request_handler.__module__,
request_handler.__name__,
self,
request
)
respond_with_json(
request,
500,
{"error": "Internal server error"},
send_cors=True
)
finally:
code = str(code) if code else "-"
end = self.clock.time_msec()
logger.info(
"Processed request: %dms %s %s %s",
end-start, code, request.method, request.path
)
return wrapped_request_handler
class HttpServer(object): class HttpServer(object):
""" Interface for registering callbacks on a HTTP server """ Interface for registering callbacks on a HTTP server
""" """
def register_path(self, method, path_pattern, callback): def register_path(self, method, path_pattern, callback):
""" Register a callback that get's fired if we receive a http request """ Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex. with the given method for a path that matches the given regex.
If the regex contains groups these get's passed to the calback via If the regex contains groups these gets passed to the calback via
an unpacked tuple. an unpacked tuple.
Args: Args:
@ -79,6 +153,13 @@ class JsonResource(HttpServer, resource.Resource):
Resources. Resources.
Register callbacks via register_path() Register callbacks via register_path()
Callbacks can return a tuple of status code and a dict in which case the
the dict will automatically be sent to the client as a JSON object.
The JsonResource is primarily intended for returning JSON, but callbacks
may send something other than JSON, they may do so by using the methods
on the request object and instead returning None.
""" """
isLeaf = True isLeaf = True
@ -98,119 +179,61 @@ class JsonResource(HttpServer, resource.Resource):
self._PathEntry(path_pattern, callback) self._PathEntry(path_pattern, callback)
) )
def start_listening(self, port):
""" Registers the http server with the twisted reactor.
Args:
port (int): The port to listen on.
"""
reactor.listenTCP(
port,
server.Site(self),
interface=self.hs.config.bind_host
)
# Gets called by twisted
def render(self, request): def render(self, request):
""" This get's called by twisted every time someone sends us a request. """ This gets called by twisted every time someone sends us a request.
""" """
self._async_render_with_logging_context(request) self._async_render(request)
return server.NOT_DONE_YET return server.NOT_DONE_YET
_request_id = 0 @request_handler
@defer.inlineCallbacks
def _async_render_with_logging_context(self, request):
request_id = "%s-%s" % (request.method, JsonResource._request_id)
JsonResource._request_id += 1
with LoggingContext(request_id) as request_context:
request_context.request = request_id
yield self._async_render(request)
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render(self, request): def _async_render(self, request):
""" This get's called by twisted every time someone sends us a request. """ This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and This checks if anyone has registered a callback for that method and
path. path.
""" """
code = None
start = self.clock.time_msec() start = self.clock.time_msec()
try: if request.method == "OPTIONS":
# Just say yes to OPTIONS. self._send_response(request, 200, {})
if request.method == "OPTIONS": return
self._send_response(request, 200, {}) # Loop through all the registered callbacks to check if the method
return # and path regex match
for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path)
if not m:
continue
# Loop through all the registered callbacks to check if the method # We found a match! Trigger callback and then return the
# and path regex match # returned response. We pass both the request and any
for path_entry in self.path_regexs.get(request.method, []): # matched groups from the regex to the callback.
m = path_entry.pattern.match(request.path)
if not m:
continue
# We found a match! Trigger callback and then return the callback = path_entry.callback
# returned response. We pass both the request and any
# matched groups from the regex to the callback.
callback = path_entry.callback servlet_instance = getattr(callback, "__self__", None)
if servlet_instance is not None:
servlet_instance = getattr(callback, "__self__", None) servlet_classname = servlet_instance.__class__.__name__
if servlet_instance is not None:
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
incoming_requests_counter.inc(request.method, servlet_classname)
args = [
urllib.unquote(u).decode("UTF-8") for u in m.groups()
]
logger.info(
"Received request: %s %s",
request.method, request.path
)
code, response = yield callback(request, *args)
self._send_response(request, code, response)
response_timer.inc_by(
self.clock.time_msec() - start, request.method, servlet_classname
)
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
raise UnrecognizedRequestError()
except CodeMessageException as e:
if isinstance(e, SynapseError):
logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)
else: else:
logger.exception(e) servlet_classname = "%r" % callback
incoming_requests_counter.inc(request.method, servlet_classname)
code = e.code args = [
self._send_response( urllib.unquote(u).decode("UTF-8") for u in m.groups()
request, ]
code,
cs_exception(e),
response_code_message=e.response_code_message
)
except Exception as e:
logger.exception(e)
self._send_response(
request,
500,
{"error": "Internal server error"}
)
finally:
code = str(code) if code else "-"
end = self.clock.time_msec() callback_return = yield callback(request, *args)
logger.info( if callback_return is not None:
"Processed request: %dms %s %s %s", code, response = callback_return
end-start, code, request.method, request.path self._send_response(request, code, response)
response_timer.inc_by(
self.clock.time_msec() - start, request.method, servlet_classname
) )
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
raise UnrecognizedRequestError()
def _send_response(self, request, code, response_json_object, def _send_response(self, request, code, response_json_object,
response_code_message=None): response_code_message=None):
# could alternatively use request.notifyFinish() and flip a flag when # could alternatively use request.notifyFinish() and flip a flag when
@ -229,20 +252,10 @@ class JsonResource(HttpServer, resource.Resource):
request, code, response_json_object, request, code, response_json_object,
send_cors=True, send_cors=True,
response_code_message=response_code_message, response_code_message=response_code_message,
pretty_print=self._request_user_agent_is_curl, pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string, version_string=self.version_string,
) )
@staticmethod
def _request_user_agent_is_curl(request):
user_agents = request.requestHeaders.getRawHeaders(
"User-Agent", default=[]
)
for user_agent in user_agents:
if "curl" in user_agent:
return True
return False
class RootRedirect(resource.Resource): class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path.""" """Redirects the root '/' path to another path."""
@ -263,8 +276,8 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False, def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False, response_code_message=None, pretty_print=False,
version_string=""): version_string=""):
if not pretty_print: if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) json_bytes = encode_pretty_printed_json(json_object) + "\n"
else: else:
json_bytes = encode_canonical_json(json_object) json_bytes = encode_canonical_json(json_object)
@ -304,3 +317,13 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
request.write(json_bytes) request.write(json_bytes)
request.finish() request.finish()
return NOT_DONE_YET return NOT_DONE_YET
def _request_user_agent_is_curl(request):
user_agents = request.requestHeaders.getRawHeaders(
"User-Agent", default=[]
)
for user_agent in user_agents:
if "curl" in user_agent:
return True
return False

View File

@ -23,6 +23,61 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def parse_integer(request, name, default=None, required=False):
if name in request.args:
try:
return int(request.args[name][0])
except:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing integer query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
def parse_boolean(request, name, default=None, required=False):
if name in request.args:
try:
return {
"true": True,
"false": False,
}[request.args[name][0]]
except:
message = (
"Boolean query parameter %r must be one of"
" ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing boolean query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"):
if name in request.args:
value = request.args[name][0]
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values)
)
raise SynapseError(message)
else:
return value
else:
if required:
message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message)
else:
return default
class RestServlet(object): class RestServlet(object):
""" A Synapse REST Servlet. """ A Synapse REST Servlet.
@ -56,58 +111,3 @@ class RestServlet(object):
http_server.register_path(method, pattern, method_handler) http_server.register_path(method, pattern, method_handler)
else: else:
raise NotImplementedError("RestServlet must register something.") raise NotImplementedError("RestServlet must register something.")
@staticmethod
def parse_integer(request, name, default=None, required=False):
if name in request.args:
try:
return int(request.args[name][0])
except:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing integer query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
@staticmethod
def parse_boolean(request, name, default=None, required=False):
if name in request.args:
try:
return {
"true": True,
"false": False,
}[request.args[name][0]]
except:
message = (
"Boolean query parameter %r must be one of"
" ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing boolean query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
@staticmethod
def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"):
if name in request.args:
value = request.args[name][0]
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values)
)
raise SynapseError(message)
else:
return value
else:
if required:
message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message)
else:
return default

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
import logging import logging
from resource import getrusage, getpagesize, RUSAGE_SELF from resource import getrusage, getpagesize, RUSAGE_SELF
import os
import stat
from .metric import ( from .metric import (
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
@ -109,3 +111,36 @@ resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
# pages # pages
resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * PAGE_SIZE) resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * PAGE_SIZE)
TYPES = {
stat.S_IFSOCK: "SOCK",
stat.S_IFLNK: "LNK",
stat.S_IFREG: "REG",
stat.S_IFBLK: "BLK",
stat.S_IFDIR: "DIR",
stat.S_IFCHR: "CHR",
stat.S_IFIFO: "FIFO",
}
def _process_fds():
counts = {(k,): 0 for k in TYPES.values()}
counts[("other",)] = 0
for fd in os.listdir("/proc/self/fd"):
try:
s = os.stat("/proc/self/fd/%s" % (fd))
fmt = stat.S_IFMT(s.st_mode)
if fmt in TYPES:
t = TYPES[fmt]
else:
t = "other"
counts[(t,)] += 1
except OSError:
# the dirh itself used by listdir() is usually missing by now
pass
return counts
get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])

View File

@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.async import run_on_reactor
from synapse.types import StreamToken from synapse.types import StreamToken
import synapse.metrics import synapse.metrics
@ -59,10 +58,11 @@ class _NotificationListener(object):
self.limit = limit self.limit = limit
self.timeout = timeout self.timeout = timeout
self.deferred = deferred self.deferred = deferred
self.rooms = rooms self.rooms = rooms
self.timer = None
self.pending_notifications = [] def notified(self):
return self.deferred.called
def notify(self, notifier, events, start_token, end_token): def notify(self, notifier, events, start_token, end_token):
""" Inform whoever is listening about the new events. This will """ Inform whoever is listening about the new events. This will
@ -78,16 +78,27 @@ class _NotificationListener(object):
except defer.AlreadyCalledError: except defer.AlreadyCalledError:
pass pass
# Should the following be done be using intrusively linked lists?
# -- erikj
for room in self.rooms: for room in self.rooms:
lst = notifier.room_to_listeners.get(room, set()) lst = notifier.room_to_listeners.get(room, set())
lst.discard(self) lst.discard(self)
notifier.user_to_listeners.get(self.user, set()).discard(self) notifier.user_to_listeners.get(self.user, set()).discard(self)
if self.appservice: if self.appservice:
notifier.appservice_to_listeners.get( notifier.appservice_to_listeners.get(
self.appservice, set() self.appservice, set()
).discard(self) ).discard(self)
# Cancel the timeout for this notifer if one exists.
if self.timer is not None:
try:
notifier.clock.cancel_call_later(self.timer)
except:
logger.warn("Failed to cancel notifier timer")
class Notifier(object): class Notifier(object):
""" This class is responsible for notifying any listeners when there are """ This class is responsible for notifying any listeners when there are
@ -150,8 +161,6 @@ class Notifier(object):
listening to the room, and any listeners for the users in the listening to the room, and any listeners for the users in the
`extra_users` param. `extra_users` param.
""" """
yield run_on_reactor()
# poke any interested application service. # poke any interested application service.
self.hs.get_handlers().appservice_handler.notify_interested_services( self.hs.get_handlers().appservice_handler.notify_interested_services(
event event
@ -161,10 +170,18 @@ class Notifier(object):
room_source = self.event_sources.sources["room"] room_source = self.event_sources.sources["room"]
listeners = self.room_to_listeners.get(room_id, set()).copy() room_listeners = self.room_to_listeners.get(room_id, set())
_discard_if_notified(room_listeners)
listeners = room_listeners.copy()
for user in extra_users: for user in extra_users:
listeners |= self.user_to_listeners.get(user, set()).copy() user_listeners = self.user_to_listeners.get(user, set())
_discard_if_notified(user_listeners)
listeners |= user_listeners
for appservice in self.appservice_to_listeners: for appservice in self.appservice_to_listeners:
# TODO (kegan): Redundant appservice listener checks? # TODO (kegan): Redundant appservice listener checks?
@ -173,9 +190,13 @@ class Notifier(object):
# receive *invites* for users they are interested in. Does this # receive *invites* for users they are interested in. Does this
# make the room_to_listeners check somewhat obselete? # make the room_to_listeners check somewhat obselete?
if appservice.is_interested(event): if appservice.is_interested(event):
listeners |= self.appservice_to_listeners.get( app_listeners = self.appservice_to_listeners.get(
appservice, set() appservice, set()
).copy() )
_discard_if_notified(app_listeners)
listeners |= app_listeners
logger.debug("on_new_room_event listeners %s", listeners) logger.debug("on_new_room_event listeners %s", listeners)
@ -216,8 +237,6 @@ class Notifier(object):
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
yield run_on_reactor()
# TODO(paul): This is horrible, having to manually list every event # TODO(paul): This is horrible, having to manually list every event
# source here individually # source here individually
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources["presence"]
@ -226,10 +245,18 @@ class Notifier(object):
listeners = set() listeners = set()
for user in users: for user in users:
listeners |= self.user_to_listeners.get(user, set()).copy() user_listeners = self.user_to_listeners.get(user, set())
_discard_if_notified(user_listeners)
listeners |= user_listeners
for room in rooms: for room in rooms:
listeners |= self.room_to_listeners.get(room, set()).copy() room_listeners = self.room_to_listeners.get(room, set())
_discard_if_notified(room_listeners)
listeners |= room_listeners
@defer.inlineCallbacks @defer.inlineCallbacks
def notify(listener): def notify(listener):
@ -300,14 +327,20 @@ class Notifier(object):
self._register_with_keys(listener[0]) self._register_with_keys(listener[0])
result = yield callback() result = yield callback()
timer = [None]
if timeout: if timeout:
timed_out = [False] timed_out = [False]
def _timeout_listener(): def _timeout_listener():
timed_out[0] = True timed_out[0] = True
timer[0] = None
listener[0].notify(self, [], from_token, from_token) listener[0].notify(self, [], from_token, from_token)
self.clock.call_later(timeout/1000., _timeout_listener) # We create multiple notification listeners so we have to manage
# canceling the timeout ourselves.
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
while not result and not timed_out[0]: while not result and not timed_out[0]:
yield deferred yield deferred
deferred = defer.Deferred() deferred = defer.Deferred()
@ -322,6 +355,12 @@ class Notifier(object):
self._register_with_keys(listener[0]) self._register_with_keys(listener[0])
result = yield callback() result = yield callback()
if timer[0] is not None:
try:
self.clock.cancel_call_later(timer[0])
except:
logger.exception("Failed to cancel notifer timer")
defer.returnValue(result) defer.returnValue(result)
def get_events_for(self, user, rooms, pagination_config, timeout): def get_events_for(self, user, rooms, pagination_config, timeout):
@ -360,6 +399,8 @@ class Notifier(object):
def _timeout_listener(): def _timeout_listener():
# TODO (erikj): We should probably set to_token to the current # TODO (erikj): We should probably set to_token to the current
# max rather than reusing from_token. # max rather than reusing from_token.
# Remove the timer from the listener so we don't try to cancel it.
listener.timer = None
listener.notify( listener.notify(
self, self,
[], [],
@ -375,8 +416,11 @@ class Notifier(object):
if not timeout: if not timeout:
_timeout_listener() _timeout_listener()
else: else:
self.clock.call_later(timeout/1000.0, _timeout_listener) # Only add the timer if the listener hasn't been notified
if not listener.notified():
listener.timer = self.clock.call_later(
timeout/1000.0, _timeout_listener
)
return return
@log_function @log_function
@ -427,3 +471,17 @@ class Notifier(object):
listeners = self.room_to_listeners.setdefault(room_id, set()) listeners = self.room_to_listeners.setdefault(room_id, set())
listeners |= new_listeners listeners |= new_listeners
for l in new_listeners:
l.rooms.add(room_id)
def _discard_if_notified(listener_set):
"""Remove any 'stale' listeners from the given set.
"""
to_discard = set()
for l in listener_set:
if l.notified():
to_discard.add(l)
listener_set -= to_discard

View File

@ -253,7 +253,8 @@ class Pusher(object):
self.user_name, config, timeout=0) self.user_name, config, timeout=0)
self.last_token = chunk['end'] self.last_token = chunk['end']
self.store.update_pusher_last_token( self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.last_token) self.app_id, self.pushkey, self.user_name, self.last_token
)
logger.info("Pusher %s for user %s starting from token %s", logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token) self.pushkey, self.user_name, self.last_token)
@ -314,7 +315,7 @@ class Pusher(object):
pk pk
) )
yield self.hs.get_pusherpool().remove_pusher( yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk self.app_id, pk, self.user_name
) )
if not self.alive: if not self.alive:
@ -326,6 +327,7 @@ class Pusher(object):
self.store.update_pusher_last_token_and_success( self.store.update_pusher_last_token_and_success(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.last_token, self.last_token,
self.clock.time_msec() self.clock.time_msec()
) )
@ -334,6 +336,7 @@ class Pusher(object):
self.store.update_pusher_failing_since( self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.failing_since) self.failing_since)
else: else:
if not self.failing_since: if not self.failing_since:
@ -341,6 +344,7 @@ class Pusher(object):
self.store.update_pusher_failing_since( self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.failing_since self.failing_since
) )
@ -358,6 +362,7 @@ class Pusher(object):
self.store.update_pusher_last_token( self.store.update_pusher_last_token(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.last_token self.last_token
) )
@ -365,6 +370,7 @@ class Pusher(object):
self.store.update_pusher_failing_since( self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name,
self.failing_since self.failing_since
) )
else: else:

View File

@ -1,3 +1,17 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
@ -112,7 +126,25 @@ def make_base_prepend_override_rules():
def make_base_append_override_rules(): def make_base_append_override_rules():
return [ return [
{ {
'rule_id': 'global/override/.m.rule.call', 'rule_id': 'global/override/.m.rule.suppress_notices',
'conditions': [
{
'kind': 'event_match',
'key': 'content.msgtype',
'pattern': 'm.notice',
}
],
'actions': [
'dont_notify',
]
}
]
def make_base_append_underride_rules(user):
return [
{
'rule_id': 'global/underride/.m.rule.call',
'conditions': [ 'conditions': [
{ {
'kind': 'event_match', 'kind': 'event_match',
@ -131,19 +163,6 @@ def make_base_append_override_rules():
} }
] ]
}, },
{
'rule_id': 'global/override/.m.rule.suppress_notices',
'conditions': [
{
'kind': 'event_match',
'key': 'content.msgtype',
'pattern': 'm.notice',
}
],
'actions': [
'dont_notify',
]
},
{ {
'rule_id': 'global/override/.m.rule.contains_display_name', 'rule_id': 'global/override/.m.rule.contains_display_name',
'conditions': [ 'conditions': [
@ -162,7 +181,7 @@ def make_base_append_override_rules():
] ]
}, },
{ {
'rule_id': 'global/override/.m.rule.room_one_to_one', 'rule_id': 'global/underride/.m.rule.room_one_to_one',
'conditions': [ 'conditions': [
{ {
'kind': 'room_member_count', 'kind': 'room_member_count',
@ -179,12 +198,7 @@ def make_base_append_override_rules():
'value': False 'value': False
} }
] ]
} },
]
def make_base_append_underride_rules(user):
return [
{ {
'rule_id': 'global/underride/.m.rule.invite_for_me', 'rule_id': 'global/underride/.m.rule.invite_for_me',
'conditions': [ 'conditions': [

View File

@ -19,10 +19,7 @@ from twisted.internet import defer
from httppusher import HttpPusher from httppusher import HttpPusher
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from syutil.jsonutil import encode_canonical_json
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -52,12 +49,10 @@ class PusherPool:
@defer.inlineCallbacks @defer.inlineCallbacks
def start(self): def start(self):
pushers = yield self.store.get_all_pushers() pushers = yield self.store.get_all_pushers()
for p in pushers:
p['data'] = json.loads(p['data'])
self._start_pushers(pushers) self._start_pushers(pushers)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_pusher(self, user_name, profile_tag, kind, app_id, def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data): app_display_name, device_display_name, pushkey, lang, data):
# we try to create the pusher just to validate the config: it # we try to create the pusher just to validate the config: it
# will then get pulled out of the database, # will then get pulled out of the database,
@ -71,7 +66,7 @@ class PusherPool:
"app_display_name": app_display_name, "app_display_name": app_display_name,
"device_display_name": device_display_name, "device_display_name": device_display_name,
"pushkey": pushkey, "pushkey": pushkey,
"pushkey_ts": self.hs.get_clock().time_msec(), "ts": self.hs.get_clock().time_msec(),
"lang": lang, "lang": lang,
"data": data, "data": data,
"last_token": None, "last_token": None,
@ -79,17 +74,50 @@ class PusherPool:
"failing_since": None "failing_since": None
}) })
yield self._add_pusher_to_store( yield self._add_pusher_to_store(
user_name, profile_tag, kind, app_id, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, app_display_name, device_display_name,
pushkey, lang, data pushkey, lang, data
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, profile_tag, kind, app_id, def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
app_display_name, device_display_name, not_user_id):
to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id, pushkey
)
for p in to_remove:
if p['user_name'] != not_user_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
app_id, pushkey, p['user_name']
)
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user_access_token(self, user_id, not_access_token_id):
all = yield self.store.get_all_pushers()
logger.info(
"Removing all pushers for user %s except access token %s",
user_id, not_access_token_id
)
for p in all:
if (
p['user_name'] == user_id and
p['access_token'] != not_access_token_id
):
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']
)
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, lang, data): pushkey, lang, data):
yield self.store.add_pusher( yield self.store.add_pusher(
user_name=user_name, user_name=user_name,
access_token=access_token,
profile_tag=profile_tag, profile_tag=profile_tag,
kind=kind, kind=kind,
app_id=app_id, app_id=app_id,
@ -98,9 +126,9 @@ class PusherPool:
pushkey=pushkey, pushkey=pushkey,
pushkey_ts=self.hs.get_clock().time_msec(), pushkey_ts=self.hs.get_clock().time_msec(),
lang=lang, lang=lang,
data=encode_canonical_json(data).decode("UTF-8"), data=data,
) )
self._refresh_pusher((app_id, pushkey)) self._refresh_pusher(app_id, pushkey, user_name)
def _create_pusher(self, pusherdict): def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http': if pusherdict['kind'] == 'http':
@ -112,7 +140,7 @@ class PusherPool:
app_display_name=pusherdict['app_display_name'], app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'], device_display_name=pusherdict['device_display_name'],
pushkey=pusherdict['pushkey'], pushkey=pusherdict['pushkey'],
pushkey_ts=pusherdict['pushkey_ts'], pushkey_ts=pusherdict['ts'],
data=pusherdict['data'], data=pusherdict['data'],
last_token=pusherdict['last_token'], last_token=pusherdict['last_token'],
last_success=pusherdict['last_success'], last_success=pusherdict['last_success'],
@ -125,30 +153,48 @@ class PusherPool:
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _refresh_pusher(self, app_id_pushkey): def _refresh_pusher(self, app_id, pushkey, user_name):
p = yield self.store.get_pushers_by_app_id_and_pushkey( resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id_pushkey app_id, pushkey
) )
p['data'] = json.loads(p['data'])
self._start_pushers([p]) p = None
for r in resultlist:
if r['user_name'] == user_name:
p = r
if p:
self._start_pushers([p])
def _start_pushers(self, pushers): def _start_pushers(self, pushers):
logger.info("Starting %d pushers", len(pushers)) logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers: for pusherdict in pushers:
p = self._create_pusher(pusherdict) try:
p = self._create_pusher(pusherdict)
except PusherConfigException:
logger.exception("Couldn't start a pusher: caught PusherConfigException")
continue
if p: if p:
fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey']) fullid = "%s:%s:%s" % (
pusherdict['app_id'],
pusherdict['pushkey'],
pusherdict['user_name']
)
if fullid in self.pushers: if fullid in self.pushers:
self.pushers[fullid].stop() self.pushers[fullid].stop()
self.pushers[fullid] = p self.pushers[fullid] = p
p.start() p.start()
logger.info("Started pushers")
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey): def remove_pusher(self, app_id, pushkey, user_name):
fullid = "%s:%s" % (app_id, pushkey) fullid = "%s:%s:%s" % (app_id, pushkey, user_name)
if fullid in self.pushers: if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid) logger.info("Stopping pusher %s", fullid)
self.pushers[fullid].stop() self.pushers[fullid].stop()
del self.pushers[fullid] del self.pushers[fullid]
yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey) yield self.store.delete_pusher_by_app_id_pushkey_user_name(
app_id, pushkey, user_name
)

View File

@ -1,3 +1,17 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
PRIORITY_CLASS_MAP = { PRIORITY_CLASS_MAP = {
'underride': 1, 'underride': 1,
'sender': 2, 'sender': 2,

View File

@ -1,10 +1,24 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging import logging
from distutils.version import LooseVersion from distutils.version import LooseVersion
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"syutil>=0.0.3": ["syutil"], "syutil>=0.0.6": ["syutil>=0.0.6"],
"Twisted==14.0.2": ["twisted==14.0.2"], "Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
@ -19,7 +33,7 @@ REQUIREMENTS = {
} }
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
"web_client": { "web_client": {
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"], "matrix_angular_sdk>=0.6.6": ["syweb>=0.6.6"],
} }
} }
@ -43,13 +57,13 @@ DEPENDENCY_LINKS = [
), ),
github_link( github_link(
project="matrix-org/syutil", project="matrix-org/syutil",
version="v0.0.3", version="v0.0.6",
egg="syutil-0.0.3", egg="syutil-0.0.6",
), ),
github_link( github_link(
project="matrix-org/matrix-angular-sdk", project="matrix-org/matrix-angular-sdk",
version="v0.6.5", version="v0.6.6",
egg="matrix_angular_sdk-0.6.5", egg="matrix_angular_sdk-0.6.6",
), ),
] ]

View File

@ -1,48 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
from synapse.http.servlet import RestServlet
from synapse.api.urls import APP_SERVICE_PREFIX
import re
import logging
logger = logging.getLogger(__name__)
def as_path_pattern(path_regex):
"""Creates a regex compiled appservice path with the correct path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
return re.compile("^" + APP_SERVICE_PREFIX + path_regex)
class AppServiceRestServlet(RestServlet):
"""A base Synapse REST Servlet for the application services version 1 API.
"""
def __init__(self, hs):
self.hs = hs
self.handler = hs.get_handlers().appservice_handler

View File

@ -1,99 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains REST servlets to do with registration: /register"""
from twisted.internet import defer
from base import AppServiceRestServlet, as_path_pattern
from synapse.api.errors import CodeMessageException, SynapseError
from synapse.storage.appservice import ApplicationService
import json
import logging
logger = logging.getLogger(__name__)
class RegisterRestServlet(AppServiceRestServlet):
"""Handles AS registration with the home server.
"""
PATTERN = as_path_pattern("/register$")
@defer.inlineCallbacks
def on_POST(self, request):
params = _parse_json(request)
# sanity check required params
try:
as_token = params["as_token"]
as_url = params["url"]
if (not isinstance(as_token, basestring) or
not isinstance(as_url, basestring)):
raise ValueError
except (KeyError, ValueError):
raise SynapseError(
400, "Missed required keys: as_token(str) / url(str)."
)
try:
app_service = ApplicationService(
as_token, as_url, params["namespaces"]
)
except ValueError as e:
raise SynapseError(400, e.message)
app_service = yield self.handler.register(app_service)
hs_token = app_service.hs_token
defer.returnValue((200, {
"hs_token": hs_token
}))
class UnregisterRestServlet(AppServiceRestServlet):
"""Handles AS registration with the home server.
"""
PATTERN = as_path_pattern("/unregister$")
def on_POST(self, request):
params = _parse_json(request)
try:
as_token = params["as_token"]
if not isinstance(as_token, basestring):
raise ValueError
except (KeyError, ValueError):
raise SynapseError(400, "Missing required key: as_token(str)")
yield self.handler.unregister(as_token)
raise CodeMessageException(500, "Not implemented")
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError as e:
logger.warn(e)
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)
UnregisterRestServlet(hs).register(http_server)

View File

@ -48,5 +48,5 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs self.hs = hs
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory() self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth() self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore() self.txns = HttpTransactionStore()

View File

@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
user, _ = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
@ -37,7 +37,7 @@ class PusherRestServlet(ClientV1RestServlet):
and 'kind' in content and and 'kind' in content and
content['kind'] is None): content['kind'] is None):
yield pusher_pool.remove_pusher( yield pusher_pool.remove_pusher(
content['app_id'], content['pushkey'] content['app_id'], content['pushkey'], user_name=user.to_string()
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -51,9 +51,21 @@ class PusherRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Missing parameters: "+','.join(missing), raise SynapseError(400, "Missing parameters: "+','.join(missing),
errcode=Codes.MISSING_PARAM) errcode=Codes.MISSING_PARAM)
append = False
if 'append' in content:
append = content['append']
if not append:
yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content['app_id'],
pushkey=content['pushkey'],
not_user_id=user.to_string()
)
try: try:
yield pusher_pool.add_pusher( yield pusher_pool.add_pusher(
user_name=user.to_string(), user_name=user.to_string(),
access_token=client.token_id,
profile_tag=content['profile_tag'], profile_tag=content['profile_tag'],
kind=content['kind'], kind=content['kind'],
app_id=content['app_id'], app_id=content['app_id'],

View File

@ -15,7 +15,10 @@
from . import ( from . import (
sync, sync,
filter filter,
account,
register,
auth
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -32,3 +35,6 @@ class ClientV2AlphaRestResource(JsonResource):
def register_servlets(client_resource, hs): def register_servlets(client_resource, hs):
sync.register_servlets(hs, client_resource) sync.register_servlets(hs, client_resource)
filter.register_servlets(hs, client_resource) filter.register_servlets(hs, client_resource)
account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
auth.register_servlets(hs, client_resource)

View File

@ -17,9 +17,11 @@
""" """
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.api.errors import SynapseError
import re import re
import logging import logging
import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,3 +38,23 @@ def client_v2_pattern(path_regex):
SRE_Pattern SRE_Pattern
""" """
return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)
def parse_request_allow_empty(request):
content = request.content.read()
if content is None or content == '':
return None
try:
return simplejson.loads(content)
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")
def parse_json_dict_from_request(request):
try:
content = simplejson.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except simplejson.JSONDecodeError:
raise SynapseError(400, "Content not JSON.")

View File

@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import LoginError, SynapseError, Codes
from synapse.http.servlet import RestServlet
from synapse.util.async import run_on_reactor
from ._base import client_v2_pattern, parse_json_dict_from_request
import logging
logger = logging.getLogger(__name__)
class PasswordRestServlet(RestServlet):
PATTERN = client_v2_pattern("/account/password")
def __init__(self, hs):
super(PasswordRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_dict_from_request(request)
authed, result, params = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY]
], body)
if not authed:
defer.returnValue((401, result))
user_id = None
if LoginType.PASSWORD in result:
# if using password, they should also be logged in
auth_user, client = yield self.auth.get_user_by_req(request)
if auth_user.to_string() != result[LoginType.PASSWORD]:
raise LoginError(400, "", Codes.UNKNOWN)
user_id = auth_user.to_string()
elif LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid:
raise SynapseError(500, "Malformed threepid")
# if using email, we must know about the email they're authing with!
threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
threepid['medium'], threepid['address']
)
if not threepid_user_id:
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
user_id = threepid_user_id
else:
logger.error("Auth succeeded but no known type!", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
if 'new_password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password']
yield self.login_handler.set_password(
user_id, new_password, None
)
defer.returnValue((200, {}))
def on_OPTIONS(self, _):
return 200, {}
class ThreepidRestServlet(RestServlet):
PATTERN = client_v2_pattern("/account/3pid")
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
self.hs = hs
self.login_handler = hs.get_handlers().login_handler
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
yield run_on_reactor()
auth_user, _ = yield self.auth.get_user_by_req(request)
threepids = yield self.hs.get_datastore().user_get_threepids(
auth_user.to_string()
)
defer.returnValue((200, {'threepids': threepids}))
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_dict_from_request(request)
if 'threePidCreds' not in body:
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
threePidCreds = body['threePidCreds']
auth_user, client = yield self.auth.get_user_by_req(request)
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
if not threepid:
raise SynapseError(
400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED
)
for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid:
logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
raise SynapseError(500, "Invalid response from ID Server")
yield self.login_handler.add_threepid(
auth_user.to_string(),
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
if 'bind' in body and body['bind']:
logger.debug(
"Binding emails %s to %s",
threepid, auth_user.to_string()
)
yield self.identity_handler.bind_threepid(
threePidCreds, auth_user.to_string()
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
PasswordRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)

View File

@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern
import logging
logger = logging.getLogger(__name__)
RECAPTCHA_TEMPLATE = """
<html>
<head>
<title>Authentication</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<script src="https://www.google.com/recaptcha/api.js"
async defer></script>
<script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script>
function captchaDone() {
$('#registrationForm').submit();
}
</script>
</head>
<body>
<form id="registrationForm" method="post" action="%(myurl)s">
<div>
<p>
Hello! We need to prevent computer programs and other automated
things from creating accounts on this server.
</p>
<p>
Please verify that you're not a robot.
</p>
<input type="hidden" name="session" value="%(session)s" />
<div class="g-recaptcha"
data-sitekey="%(sitekey)s"
data-callback="captchaDone">
</div>
<noscript>
<input type="submit" value="All Done" />
</noscript>
</div>
</div>
</form>
</body>
</html>
"""
SUCCESS_TEMPLATE = """
<html>
<head>
<title>Success!</title>
<meta name='viewport' content='width=device-width, initial-scale=1,
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
<script>
if (window.onAuthDone != undefined) {
window.onAuthDone();
}
</script>
</head>
<body>
<div>
<p>Thank you</p>
<p>You may now close this window and return to the application</p>
</div>
</body>
</html>
"""
class AuthRestServlet(RestServlet):
"""
Handles Client / Server API authentication in any situations where it
cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth.
"""
PATTERN = client_v2_pattern("/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
super(AuthRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler
@defer.inlineCallbacks
def on_GET(self, request, stagetype):
yield
if stagetype == LoginType.RECAPTCHA:
if ('session' not in request.args or
len(request.args['session']) == 0):
raise SynapseError(400, "No session supplied")
session = request.args["session"][0]
html = RECAPTCHA_TEMPLATE % {
'session': session,
'myurl': "%s/auth/%s/fallback/web" % (
CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
),
'sitekey': self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
request.finish()
defer.returnValue(None)
else:
raise SynapseError(404, "Unknown auth stage type")
@defer.inlineCallbacks
def on_POST(self, request, stagetype):
yield
if stagetype == "m.login.recaptcha":
if ('g-recaptcha-response' not in request.args or
len(request.args['g-recaptcha-response'])) == 0:
raise SynapseError(400, "No captcha response supplied")
if ('session' not in request.args or
len(request.args['session'])) == 0:
raise SynapseError(400, "No session supplied")
session = request.args['session'][0]
authdict = {
'response': request.args['g-recaptcha-response'][0],
'session': session,
}
success = yield self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA,
authdict,
self.hs.get_ip_from_request(request)
)
if success:
html = SUCCESS_TEMPLATE
else:
html = RECAPTCHA_TEMPLATE % {
'session': session,
'myurl': "%s/auth/%s/fallback/web" % (
CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
),
'sitekey': self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
request.finish()
defer.returnValue(None)
else:
raise SynapseError(404, "Unknown auth stage type")
def on_OPTIONS(self, _):
return 200, {}
def register_servlets(hs, http_server):
AuthRestServlet(hs).register(http_server)

View File

@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_request_allow_empty
import logging
import hmac
from hashlib import sha1
from synapse.util.async import run_on_reactor
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
# because the timing attack is so obscured by all the other code here it's
# unlikely to make much difference
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
compare_digest = lambda a, b: a == b
logger = logging.getLogger(__name__)
class RegisterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/register")
def __init__(self, hs):
super(RegisterRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_request_allow_empty(request)
if 'password' not in body:
raise SynapseError(400, "", Codes.MISSING_PARAM)
if 'username' in body:
desired_username = body['username']
yield self.registration_handler.check_username(desired_username)
is_using_shared_secret = False
is_application_server = False
service = None
if 'access_token' in request.args:
service = yield self.auth.get_appservice_by_req(request)
if self.hs.config.enable_registration_captcha:
flows = [
[LoginType.RECAPTCHA],
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]
]
else:
flows = [
[LoginType.DUMMY],
[LoginType.EMAIL_IDENTITY]
]
if service:
is_application_server = True
elif 'mac' in body:
# Check registration-specific shared secret auth
if 'username' not in body:
raise SynapseError(400, "", Codes.MISSING_PARAM)
self._check_shared_secret_auth(
body['username'], body['mac']
)
is_using_shared_secret = True
else:
authed, result, params = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not authed:
defer.returnValue((401, result))
can_register = (
not self.hs.config.disable_registration
or is_application_server
or is_using_shared_secret
)
if not can_register:
raise SynapseError(403, "Registration has been disabled")
if 'password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM)
desired_username = params['username'] if 'username' in params else None
new_password = params['password']
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password
)
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid:
logger.info("Can't add incomplete 3pid")
else:
yield self.login_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
threepid['validated_at'],
)
if 'bind_email' in params and params['bind_email']:
logger.info("bind_email specified: binding")
emailThreepid = result[LoginType.EMAIL_IDENTITY]
threepid_creds = emailThreepid['threepid_creds']
logger.debug("Binding emails %s to %s" % (
emailThreepid, user_id
))
yield self.identity_handler.bind_threepid(threepid_creds, user_id)
else:
logger.info("bind_email not specified: not binding email")
result = {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
def on_OPTIONS(self, _):
return 200, {}
def _check_shared_secret_auth(self, username, mac):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
user = username.encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(mac)
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret,
msg=user,
digestmod=sha1,
).hexdigest()
if compare_digest(want_mac, got_mac):
return True
else:
raise SynapseError(
403, "HMAC incorrect",
)
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)

View File

@ -15,7 +15,9 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.http.servlet import RestServlet from synapse.http.servlet import (
RestServlet, parse_string, parse_integer, parse_boolean
)
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.events.utils import ( from synapse.events.utils import (
@ -87,20 +89,20 @@ class SyncRestServlet(RestServlet):
def on_GET(self, request): def on_GET(self, request):
user, client = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
timeout = self.parse_integer(request, "timeout", default=0) timeout = parse_integer(request, "timeout", default=0)
limit = self.parse_integer(request, "limit", required=True) limit = parse_integer(request, "limit", required=True)
gap = self.parse_boolean(request, "gap", default=True) gap = parse_boolean(request, "gap", default=True)
sort = self.parse_string( sort = parse_string(
request, "sort", default="timeline,asc", request, "sort", default="timeline,asc",
allowed_values=self.ALLOWED_SORT allowed_values=self.ALLOWED_SORT
) )
since = self.parse_string(request, "since") since = parse_string(request, "since")
set_presence = self.parse_string( set_presence = parse_string(
request, "set_presence", default="online", request, "set_presence", default="online",
allowed_values=self.ALLOWED_PRESENCE allowed_values=self.ALLOWED_PRESENCE
) )
backfill = self.parse_boolean(request, "backfill", default=False) backfill = parse_boolean(request, "backfill", default=False)
filter_id = self.parse_string(request, "filter", default=None) filter_id = parse_string(request, "filter", default=None)
logger.info( logger.info(
"/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r," "/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r,"

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -12,18 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import register
from synapse.http.server import JsonResource from twisted.web.resource import Resource
from .local_key_resource import LocalKey
from .remote_key_resource import RemoteKey
class AppServiceRestResource(JsonResource): class KeyApiV2Resource(Resource):
"""A resource for version 1 of the matrix application service API."""
def __init__(self, hs): def __init__(self, hs):
JsonResource.__init__(self, hs) Resource.__init__(self)
self.register_servlets(self, hs) self.putChild("server", LocalKey(hs))
self.putChild("query", RemoteKey(hs))
@staticmethod
def register_servlets(appservice_resource, hs):
register.register_servlets(hs, appservice_resource)

View File

@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.web.resource import Resource
from synapse.http.server import respond_with_json_bytes
from syutil.crypto.jsonsign import sign_json
from syutil.base64util import encode_base64
from syutil.jsonutil import encode_canonical_json
from hashlib import sha256
from OpenSSL import crypto
import logging
logger = logging.getLogger(__name__)
class LocalKey(Resource):
"""HTTP resource containing encoding the TLS X.509 certificate and NACL
signature verification keys for this server::
GET /_matrix/key/v2/server/a.key.id HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
{
"valid_until_ts": # integer posix timestamp when this result expires.
"server_name": "this.server.example.com"
"verify_keys": {
"algorithm:version": {
"key": # base64 encoded NACL verification key.
}
},
"old_verify_keys": {
"algorithm:version": {
"expired_ts": # integer posix timestamp when the key expired.
"key": # base64 encoded NACL verification key.
}
}
"tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
"signatures": {
"this.server.example.com": {
"algorithm:version": # NACL signature for this server
}
}
}
"""
isLeaf = True
def __init__(self, hs):
self.version_string = hs.version_string
self.config = hs.config
self.clock = hs.clock
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
def update_response_body(self, time_now_msec):
refresh_interval = self.config.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval)
self.response_body = encode_canonical_json(self.response_json_object())
def response_json_object(self):
verify_keys = {}
for key in self.config.signing_key:
verify_key_bytes = key.verify_key.encode()
key_id = "%s:%s" % (key.alg, key.version)
verify_keys[key_id] = {
u"key": encode_base64(verify_key_bytes)
}
old_verify_keys = {}
for key in self.config.old_signing_keys:
key_id = "%s:%s" % (key.alg, key.version)
verify_key_bytes = key.encode()
old_verify_keys[key_id] = {
u"key": encode_base64(verify_key_bytes),
u"expired_ts": key.expired,
}
x509_certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1,
self.config.tls_certificate
)
sha256_fingerprint = sha256(x509_certificate_bytes).digest()
json_object = {
u"valid_until_ts": self.valid_until_ts,
u"server_name": self.config.server_name,
u"verify_keys": verify_keys,
u"old_verify_keys": old_verify_keys,
u"tls_fingerprints": [{
u"sha256": encode_base64(sha256_fingerprint),
}]
}
for key in self.config.signing_key:
json_object = sign_json(
json_object,
self.config.server_name,
key,
)
return json_object
def render_GET(self, request):
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
self.update_response_body(time_now)
return respond_with_json_bytes(
request, 200, self.response_body,
version_string=self.version_string
)

View File

@ -0,0 +1,242 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import request_handler, respond_with_json_bytes
from synapse.http.servlet import parse_integer
from synapse.api.errors import SynapseError, Codes
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from io import BytesIO
import json
import logging
logger = logging.getLogger(__name__)
class RemoteKey(Resource):
"""HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
that the NACL signature for the remote server is valid. Returns a dict of
JSON signed by both the remote server and by this server.
Supports individual GET APIs and a bulk query POST API.
Requsts:
GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
GET /_matrix/key/v2/query/remote.server.example.com/a.key.id HTTP/1.1
POST /_matrix/v2/query HTTP/1.1
Content-Type: application/json
{
"server_keys": {
"remote.server.example.com": {
"a.key.id": {
"minimum_valid_until_ts": 1234567890123
}
}
}
}
Response:
HTTP/1.1 200 OK
Content-Type: application/json
{
"server_keys": [
{
"server_name": "remote.server.example.com"
"valid_until_ts": # posix timestamp
"verify_keys": {
"a.key.id": { # The identifier for a key.
key: "" # base64 encoded verification key.
}
}
"old_verify_keys": {
"an.old.key.id": { # The identifier for an old key.
key: "", # base64 encoded key
"expired_ts": 0, # when the key stop being used.
}
}
"tls_fingerprints": [
{ "sha256": # fingerprint }
]
"signatures": {
"remote.server.example.com": {...}
"this.server.example.com": {...}
}
}
]
}
"""
isLeaf = True
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.store = hs.get_datastore()
self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_GET(self, request):
self.async_render_GET(request)
return NOT_DONE_YET
@request_handler
@defer.inlineCallbacks
def async_render_GET(self, request):
if len(request.postpath) == 1:
server, = request.postpath
query = {server: {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
minimum_valid_until_ts = parse_integer(
request, "minimum_valid_until_ts"
)
arguments = {}
if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
query = {server: {key_id: arguments}}
else:
raise SynapseError(
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
)
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
def render_POST(self, request):
self.async_render_POST(request)
return NOT_DONE_YET
@request_handler
@defer.inlineCallbacks
def async_render_POST(self, request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise ValueError()
except ValueError:
raise SynapseError(
400, "Content must be JSON object.", errcode=Codes.NOT_JSON
)
query = content["server_keys"]
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
@defer.inlineCallbacks
def query_keys(self, request, query, query_remote_on_cache_miss=False):
logger.info("Handling query for keys %r", query)
store_queries = []
for server_name, key_ids in query.items():
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
cached = yield self.store.get_server_keys_json(store_queries)
json_results = set()
time_now_ms = self.clock.time_msec()
cache_misses = dict()
for (server_name, key_id, from_server), results in cached.items():
results = [
(result["ts_added_ms"], result) for result in results
]
if not results and key_id is not None:
cache_misses.setdefault(server_name, set()).add(key_id)
continue
if key_id is not None:
ts_added_ms, most_recent_result = max(results)
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
miss = False
if req_valid_until is not None:
if ts_valid_until_ms < req_valid_until:
logger.debug(
"Cached response for %r/%r is older than requested"
": valid_until (%r) < minimum_valid_until (%r)",
server_name, key_id,
ts_valid_until_ms, req_valid_until
)
miss = True
else:
logger.debug(
"Cached response for %r/%r is newer than requested"
": valid_until (%r) >= minimum_valid_until (%r)",
server_name, key_id,
ts_valid_until_ms, req_valid_until
)
elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
logger.debug(
"Cached response for %r/%r is too old"
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
server_name, key_id,
ts_added_ms, ts_valid_until_ms, time_now_ms
)
# We more than half way through the lifetime of the
# response. We should fetch a fresh copy.
miss = True
else:
logger.debug(
"Cached response for %r/%r is still valid"
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
server_name, key_id,
ts_added_ms, ts_valid_until_ms, time_now_ms
)
if miss:
cache_misses.setdefault(server_name, set()).add(key_id)
json_results.add(bytes(most_recent_result["key_json"]))
else:
for ts_added, result in results:
json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss:
for server_name, key_ids in cache_misses.items():
try:
yield self.keyring.get_server_verify_key_v2_direct(
server_name, key_ids
)
except:
logger.exception("Failed to get key for %r", server_name)
pass
yield self.query_keys(
request, query, query_remote_on_cache_miss=False
)
else:
result_io = BytesIO()
result_io.write(b"{\"server_keys\":")
sep = b"["
for json_bytes in json_results:
result_io.write(sep)
result_io.write(json_bytes)
sep = b","
if sep == b"[":
result_io.write(sep)
result_io.write(b"]}")
respond_with_json_bytes(
request, 200, result_io.getvalue(),
version_string=self.version_string
)

Some files were not shown because too many files have changed in this diff Show More