Merge branch 'release-v0.10.0'

This commit is contained in:
Erik Johnston 2015-09-03 09:54:08 +01:00
commit efeeff29f6
110 changed files with 5473 additions and 1915 deletions

View File

@ -42,3 +42,6 @@ Ivan Shapovalov <intelfx100 at gmail.com>
Eric Myhre <hash at exultant.us> Eric Myhre <hash at exultant.us>
* Fix bug where ``media_store_path`` config option was ignored by v0 content * Fix bug where ``media_store_path`` config option was ignored by v0 content
repository API. repository API.
Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
* Add SAML2 support for registration and logins.

View File

@ -1,3 +1,130 @@
Changes in synapse v0.10.0 (2015-09-03)
=======================================
No change from release candidate.
Changes in synapse v0.10.0-rc6 (2015-09-02)
===========================================
* Remove some of the old database upgrade scripts.
* Fix database port script to work with newly created sqlite databases.
Changes in synapse v0.10.0-rc5 (2015-08-27)
===========================================
* Fix bug that broke downloading files with ascii filenames across federation.
Changes in synapse v0.10.0-rc4 (2015-08-27)
===========================================
* Allow UTF-8 filenames for upload. (PR #259)
Changes in synapse v0.10.0-rc3 (2015-08-25)
===========================================
* Add ``--keys-directory`` config option to specify where files such as
certs and signing keys should be stored in, when using ``--generate-config``
or ``--generate-keys``. (PR #250)
* Allow ``--config-path`` to specify a directory, causing synapse to use all
\*.yaml files in the directory as config files. (PR #249)
* Add ``web_client_location`` config option to specify static files to be
hosted by synapse under ``/_matrix/client``. (PR #245)
* Add helper utility to synapse to read and parse the config files and extract
the value of a given key. For example::
$ python -m synapse.config read server_name -c homeserver.yaml
localhost
(PR #246)
Changes in synapse v0.10.0-rc2 (2015-08-24)
===========================================
* Fix bug where we incorrectly populated the ``event_forward_extremities``
table, resulting in problems joining large remote rooms (e.g.
``#matrix:matrix.org``)
* Reduce the number of times we wake up pushers by not listening for presence
or typing events, reducing the CPU cost of each pusher.
Changes in synapse v0.10.0-rc1 (2015-08-21)
===========================================
Also see v0.9.4-rc1 changelog, which has been amalgamated into this release.
General:
* Upgrade to Twisted 15 (PR #173)
* Add support for serving and fetching encryption keys over federation.
(PR #208)
* Add support for logging in with email address (PR #234)
* Add support for new ``m.room.canonical_alias`` event. (PR #233)
* Change synapse to treat user IDs case insensitively during registration and
login. (If two users already exist with case insensitive matching user ids,
synapse will continue to require them to specify their user ids exactly.)
* Error if a user tries to register with an email already in use. (PR #211)
* Add extra and improve existing caches (PR #212, #219, #226, #228)
* Batch various storage request (PR #226, #228)
* Fix bug where we didn't correctly log the entity that triggered the request
if the request came in via an application service (PR #230)
* Fix bug where we needlessly regenerated the full list of rooms an AS is
interested in. (PR #232)
* Add support for AS's to use v2_alpha registration API (PR #210)
Configuration:
* Add ``--generate-keys`` that will generate any missing cert and key files in
the configuration files. This is equivalent to running ``--generate-config``
on an existing configuration file. (PR #220)
* ``--generate-config`` now no longer requires a ``--server-name`` parameter
when used on existing configuration files. (PR #220)
* Add ``--print-pidfile`` flag that controls the printing of the pid to stdout
of the demonised process. (PR #213)
Media Repository:
* Fix bug where we picked a lower resolution image than requested. (PR #205)
* Add support for specifying if a the media repository should dynamically
thumbnail images or not. (PR #206)
Metrics:
* Add statistics from the reactor to the metrics API. (PR #224, #225)
Demo Homeservers:
* Fix starting the demo homeservers without rate-limiting enabled. (PR #182)
* Fix enabling registration on demo homeservers (PR #223)
Changes in synapse v0.9.4-rc1 (2015-07-21)
==========================================
General:
* Add basic implementation of receipts. (SPEC-99)
* Add support for configuration presets in room creation API. (PR #203)
* Add auth event that limits the visibility of history for new users.
(SPEC-134)
* Add SAML2 login/registration support. (PR #201. Thanks Muthu Subramanian!)
* Add client side key management APIs for end to end encryption. (PR #198)
* Change power level semantics so that you cannot kick, ban or change power
levels of users that have equal or greater power level than you. (SYN-192)
* Improve performance by bulk inserting events where possible. (PR #193)
* Improve performance by bulk verifying signatures where possible. (PR #194)
Configuration:
* Add support for including TLS certificate chains.
Media Repository:
* Add Content-Disposition headers to content repository responses. (SYN-150)
Changes in synapse v0.9.3 (2015-07-01) Changes in synapse v0.9.3 (2015-07-01)
====================================== ======================================

View File

@ -94,6 +94,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
System requirements: System requirements:
- POSIX-compliant system (tested on Linux & OS X) - POSIX-compliant system (tested on Linux & OS X)
- Python 2.7 - Python 2.7
- At least 512 MB RAM.
Synapse is written in python but some of the libraries is uses are written in Synapse is written in python but some of the libraries is uses are written in
C. So before we can install synapse itself we need a working C compiler and the C. So before we can install synapse itself we need a working C compiler and the
@ -101,25 +102,26 @@ header files for python C extensions.
Installing prerequisites on Ubuntu or Debian:: Installing prerequisites on Ubuntu or Debian::
$ sudo apt-get install build-essential python2.7-dev libffi-dev \ sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \ python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev libssl-dev python-virtualenv libjpeg-dev
Installing prerequisites on ArchLinux:: Installing prerequisites on ArchLinux::
$ sudo pacman -S base-devel python2 python-pip \ sudo pacman -S base-devel python2 python-pip \
python-setuptools python-virtualenv sqlite3 python-setuptools python-virtualenv sqlite3
Installing prerequisites on Mac OS X:: Installing prerequisites on Mac OS X::
$ xcode-select --install xcode-select --install
$ sudo pip install virtualenv sudo easy_install pip
sudo pip install virtualenv
To install the synapse homeserver run:: To install the synapse homeserver run::
$ virtualenv -p python2.7 ~/.synapse virtualenv -p python2.7 ~/.synapse
$ source ~/.synapse/bin/activate source ~/.synapse/bin/activate
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
This installs synapse, along with the libraries it uses, into a virtual This installs synapse, along with the libraries it uses, into a virtual
environment under ``~/.synapse``. Feel free to pick a different directory environment under ``~/.synapse``. Feel free to pick a different directory
@ -132,8 +134,8 @@ above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
To set up your homeserver, run (in your virtualenv, as before):: To set up your homeserver, run (in your virtualenv, as before)::
$ cd ~/.synapse cd ~/.synapse
$ python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
@ -173,12 +175,12 @@ traditionally used for convenience and simplicity.
The advantages of Postgres include: The advantages of Postgres include:
* significant performance improvements due to the superior threading and * significant performance improvements due to the superior threading and
caching model, smarter query optimiser caching model, smarter query optimiser
* allowing the DB to be run on separate hardware * allowing the DB to be run on separate hardware
* allowing basic active/backup high-availability with a "hot spare" synapse * allowing basic active/backup high-availability with a "hot spare" synapse
pointing at the same DB master, as well as enabling DB replication in pointing at the same DB master, as well as enabling DB replication in
synapse itself. synapse itself.
The only disadvantage is that the code is relatively new as of April 2015 and The only disadvantage is that the code is relatively new as of April 2015 and
may have a few regressions relative to SQLite. may have a few regressions relative to SQLite.
@ -189,12 +191,12 @@ For information on how to install and use PostgreSQL, please see
Running Synapse Running Synapse
=============== ===============
To actually run your new homeserver, pick a working directory for Synapse to run To actually run your new homeserver, pick a working directory for Synapse to
(e.g. ``~/.synapse``), and:: run (e.g. ``~/.synapse``), and::
$ cd ~/.synapse cd ~/.synapse
$ source ./bin/activate source ./bin/activate
$ synctl start synctl start
Platform Specific Instructions Platform Specific Instructions
============================== ==============================
@ -212,12 +214,12 @@ defaults to python 3, but synapse currently assumes python 2.7 by default:
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 ):: pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
$ sudo pip2.7 install --upgrade pip sudo pip2.7 install --upgrade pip
You also may need to explicitly specify python 2.7 again during the install You also may need to explicitly specify python 2.7 again during the install
request:: request::
$ pip2.7 install --process-dependency-links \ pip2.7 install --process-dependency-links \
https://github.com/matrix-org/synapse/tarball/master https://github.com/matrix-org/synapse/tarball/master
If you encounter an error with lib bcrypt causing an Wrong ELF Class: If you encounter an error with lib bcrypt causing an Wrong ELF Class:
@ -225,13 +227,13 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
compile it under the right architecture. (This should not be needed if compile it under the right architecture. (This should not be needed if
installing under virtualenv):: installing under virtualenv)::
$ sudo pip2.7 uninstall py-bcrypt sudo pip2.7 uninstall py-bcrypt
$ sudo pip2.7 install py-bcrypt sudo pip2.7 install py-bcrypt
During setup of Synapse you need to call python2.7 directly again:: During setup of Synapse you need to call python2.7 directly again::
$ cd ~/.synapse cd ~/.synapse
$ python2.7 -m synapse.app.homeserver \ python2.7 -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
@ -242,18 +244,20 @@ Windows Install
--------------- ---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages: Synapse can be installed on Cygwin. It requires the following Cygwin packages:
- gcc - gcc
- git - git
- libffi-devel - libffi-devel
- openssl (and openssl-devel, python-openssl) - openssl (and openssl-devel, python-openssl)
- python - python
- python-setuptools - python-setuptools
The content repository requires additional packages and will be unable to process The content repository requires additional packages and will be unable to process
uploads without them: uploads without them:
- libjpeg8
- libjpeg8-devel - libjpeg8
- zlib - libjpeg8-devel
- zlib
If you choose to install Synapse without these packages, you will need to reinstall If you choose to install Synapse without these packages, you will need to reinstall
``pillow`` for changes to be applied, e.g. ``pip uninstall pillow`` ``pip install ``pillow`` for changes to be applied, e.g. ``pip uninstall pillow`` ``pip install
pillow --user`` pillow --user``
@ -279,22 +283,22 @@ Synapse requires pip 1.7 or later, so if your OS provides too old a version and
you get errors about ``error: no such option: --process-dependency-links`` you you get errors about ``error: no such option: --process-dependency-links`` you
may need to manually upgrade it:: may need to manually upgrade it::
$ sudo pip install --upgrade pip sudo pip install --upgrade pip
If pip crashes mid-installation for reason (e.g. lost terminal), pip may If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it refuse to run until you remove the temporary installation directory it
created. To reset the installation:: created. To reset the installation::
$ rm -rf /tmp/pip_install_matrix rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are happens, you will have to individually install the dependencies which are
failing, e.g.:: failing, e.g.::
$ pip install twisted pip install twisted
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you On OS X, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments. will need to export CFLAGS=-Qunused-arguments.
Troubleshooting Running Troubleshooting Running
@ -310,10 +314,11 @@ correctly, causing all tests to fail with errors about missing "sodium.h". To
fix try re-installing from PyPI or directly from fix try re-installing from PyPI or directly from
(https://github.com/pyca/pynacl):: (https://github.com/pyca/pynacl)::
$ # Install from PyPI # Install from PyPI
$ pip install --user --upgrade --force pynacl pip install --user --upgrade --force pynacl
$ # Install from github
$ pip install --user https://github.com/pyca/pynacl/tarball/master # Install from github
pip install --user https://github.com/pyca/pynacl/tarball/master
ArchLinux ArchLinux
~~~~~~~~~ ~~~~~~~~~
@ -321,7 +326,7 @@ ArchLinux
If running `$ synctl start` fails with 'returned non-zero exit status 1', If running `$ synctl start` fails with 'returned non-zero exit status 1',
you will need to explicitly call Python2.7 - either running as:: you will need to explicitly call Python2.7 - either running as::
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
...or by editing synctl with the correct python executable. ...or by editing synctl with the correct python executable.
@ -331,16 +336,16 @@ Synapse Development
To check out a synapse for development, clone the git repo into a working To check out a synapse for development, clone the git repo into a working
directory of your choice:: directory of your choice::
$ git clone https://github.com/matrix-org/synapse.git git clone https://github.com/matrix-org/synapse.git
$ cd synapse cd synapse
Synapse has a number of external dependencies, that are easiest Synapse has a number of external dependencies, that are easiest
to install using pip and a virtualenv:: to install using pip and a virtualenv::
$ virtualenv env virtualenv env
$ source env/bin/activate source env/bin/activate
$ python synapse/python_dependencies.py | xargs -n1 pip install python synapse/python_dependencies.py | xargs -n1 pip install
$ pip install setuptools_trial mock pip install setuptools_trial mock
This will run a process of downloading and installing all the needed This will run a process of downloading and installing all the needed
dependencies into a virtual env. dependencies into a virtual env.
@ -348,7 +353,7 @@ dependencies into a virtual env.
Once this is done, you may wish to run Synapse's unit tests, to Once this is done, you may wish to run Synapse's unit tests, to
check that everything is installed as it should be:: check that everything is installed as it should be::
$ python setup.py test python setup.py test
This should end with a 'PASSED' result:: This should end with a 'PASSED' result::
@ -360,14 +365,11 @@ This should end with a 'PASSED' result::
Upgrading an existing Synapse Upgrading an existing Synapse
============================= =============================
IMPORTANT: Before upgrading an existing synapse to a new version, please The instructions for upgrading synapse are in `UPGRADE.rst`_.
refer to UPGRADE.rst for any additional instructions. Please check these instructions as upgrading may require extra steps for some
versions of synapse.
Otherwise, simply re-install the new codebase over the current one - e.g.
by ``pip install --process-dependency-links
https://github.com/matrix-org/synapse/tarball/master``
if using pip, or by ``git pull`` if running off a git working copy.
.. _UPGRADE.rst: UPGRADE.rst
Setting up Federation Setting up Federation
===================== =====================
@ -389,11 +391,11 @@ IDs:
For the first form, simply pass the required hostname (of the machine) as the For the first form, simply pass the required hostname (of the machine) as the
--server-name parameter:: --server-name parameter::
$ python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
$ python -m synapse.app.homeserver --config-path homeserver.yaml python -m synapse.app.homeserver --config-path homeserver.yaml
Alternatively, you can run ``synctl start`` to guide you through the process. Alternatively, you can run ``synctl start`` to guide you through the process.
@ -410,11 +412,11 @@ record would then look something like::
At this point, you should then run the homeserver with the hostname of this At this point, you should then run the homeserver with the hostname of this
SRV record, as that is the name other machines will expect it to have:: SRV record, as that is the name other machines will expect it to have::
$ python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--server-name YOURDOMAIN \ --server-name YOURDOMAIN \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config
$ python -m synapse.app.homeserver --config-path homeserver.yaml python -m synapse.app.homeserver --config-path homeserver.yaml
You may additionally want to pass one or more "-v" options, in order to You may additionally want to pass one or more "-v" options, in order to
@ -428,7 +430,7 @@ private federation (``localhost:8080``, ``localhost:8081`` and
``localhost:8082``) which you can then access through the webclient running at ``localhost:8082``) which you can then access through the webclient running at
http://localhost:8080. Simply run:: http://localhost:8080. Simply run::
$ demo/start.sh demo/start.sh
This is mainly useful just for development purposes. This is mainly useful just for development purposes.
@ -502,10 +504,10 @@ Building Internal API Documentation
Before building internal API documentation install sphinx and Before building internal API documentation install sphinx and
sphinxcontrib-napoleon:: sphinxcontrib-napoleon::
$ pip install sphinx pip install sphinx
$ pip install sphinxcontrib-napoleon pip install sphinxcontrib-napoleon
Building internal API documentation:: Building internal API documentation::
$ python setup.py build_sphinx python setup.py build_sphinx

View File

@ -1,3 +1,36 @@
Upgrading Synapse
=================
Before upgrading check if any special steps are required to upgrade from the
what you currently have installed to current version of synapse. The extra
instructions that may be required are listed later in this document.
If synapse was installed in a virtualenv then active that virtualenv before
upgrading. If synapse is installed in a virtualenv in ``~/.synapse/`` then run:
.. code:: bash
source ~/.synapse/bin/activate
If synapse was installed using pip then upgrade to the latest version by
running:
.. code:: bash
pip install --upgrade --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
If synapse was installed using git then upgrade to the latest version by
running:
.. code:: bash
# Pull the latest version of the master branch.
git pull
# Update the versions of synapse's python dependencies.
python synapse/python_dependencies.py | xargs -n1 pip install
Upgrading to v0.9.0 Upgrading to v0.9.0
=================== ===================

View File

@ -11,7 +11,9 @@ if [ -f $PID_FILE ]; then
exit 1 exit 1
fi fi
find "$DIR" -name "*.log" -delete for port in 8080 8081 8082; do
find "$DIR" -name "*.db" -delete rm -rf $DIR/$port
rm -rf $DIR/media_store.$port
done
rm -rf $DIR/etc rm -rf $DIR/etc

View File

@ -8,14 +8,6 @@ cd "$DIR/.."
mkdir -p demo/etc mkdir -p demo/etc
# Check the --no-rate-limit param
PARAMS=""
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
fi
fi
export PYTHONPATH=$(readlink -f $(pwd)) export PYTHONPATH=$(readlink -f $(pwd))
@ -31,10 +23,20 @@ for port in 8080 8081 8082; do
#rm $DIR/etc/$port.config #rm $DIR/etc/$port.config
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--generate-config \ --generate-config \
--enable_registration \
-H "localhost:$https_port" \ -H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \
# Check script parameters
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
# Set high limits in config file to disable rate limiting
perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
fi
fi
perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--config-path "$DIR/etc/$port.config" \ --config-path "$DIR/etc/$port.config" \
-D \ -D \

View File

@ -55,9 +55,8 @@ Porting from SQLite
Overview Overview
~~~~~~~~ ~~~~~~~~
The script ``port_from_sqlite_to_postgres.py`` allows porting an existing The script ``synapse_port_db`` allows porting an existing synapse server
synapse server backed by SQLite to using PostgreSQL. This is done in as a two backed by SQLite to using PostgreSQL. This is done in as a two phase process:
phase process:
1. Copy the existing SQLite database to a separate location (while the server 1. Copy the existing SQLite database to a separate location (while the server
is down) and running the port script against that offline database. is down) and running the port script against that offline database.
@ -86,8 +85,7 @@ Assuming your new config file (as described in the section *Synapse config*)
is named ``homeserver-postgres.yaml`` and the SQLite snapshot is at is named ``homeserver-postgres.yaml`` and the SQLite snapshot is at
``homeserver.db.snapshot`` then simply run:: ``homeserver.db.snapshot`` then simply run::
python scripts/port_from_sqlite_to_postgres.py \ synapse_port_db --sqlite-database homeserver.db.snapshot \
--sqlite-database homeserver.db.snapshot \
--postgres-config homeserver-postgres.yaml --postgres-config homeserver-postgres.yaml
The flag ``--curses`` displays a coloured curses progress UI. The flag ``--curses`` displays a coloured curses progress UI.
@ -100,8 +98,7 @@ To complete the conversion shut down the synapse server and run the port
script one last time, e.g. if the SQLite database is at ``homeserver.db`` script one last time, e.g. if the SQLite database is at ``homeserver.db``
run:: run::
python scripts/port_from_sqlite_to_postgres.py \ synapse_port_db --sqlite-database homeserver.db \
--sqlite-database homeserver.db \
--postgres-config database_config.yaml --postgres-config database_config.yaml
Once that has completed, change the synapse config to point at the PostgreSQL Once that has completed, change the synapse config to point at the PostgreSQL

View File

@ -1,21 +0,0 @@
#!/bin/bash
# This is will prepare a synapse database for running with v0.0.1 of synapse.
# It will store all the user information, but will *delete* all messages and
# room data.
set -e
cp "$1" "$1.bak"
DUMP=$(sqlite3 "$1" << 'EOF'
.dump users
.dump access_tokens
.dump presence
.dump profiles
EOF
)
rm "$1"
sqlite3 "$1" <<< "$DUMP"

View File

@ -1,21 +0,0 @@
#!/bin/bash
# This is will prepare a synapse database for running with v0.5.0 of synapse.
# It will store all the user information, but will *delete* all messages and
# room data.
set -e
cp "$1" "$1.bak"
DUMP=$(sqlite3 "$1" << 'EOF'
.dump users
.dump access_tokens
.dump presence
.dump profiles
EOF
)
rm "$1"
sqlite3 "$1" <<< "$DUMP"

View File

@ -412,14 +412,17 @@ class Porter(object):
self._convert_rows("sent_transactions", headers, rows) self._convert_rows("sent_transactions", headers, rows)
inserted_rows = len(rows) inserted_rows = len(rows)
max_inserted_rowid = max(r[0] for r in rows) if inserted_rows:
max_inserted_rowid = max(r[0] for r in rows)
def insert(txn): def insert(txn):
self.postgres_store.insert_many_txn( self.postgres_store.insert_many_txn(
txn, "sent_transactions", headers[1:], rows txn, "sent_transactions", headers[1:], rows
) )
yield self.postgres_store.execute(insert) yield self.postgres_store.execute(insert)
else:
max_inserted_rowid = 0
def get_start_id(txn): def get_start_id(txn):
txn.execute( txn.execute(

View File

@ -1,331 +0,0 @@
#!/usr/bin/env python
from synapse.storage import SCHEMA_VERSION, read_schema
from synapse.storage._base import SQLBaseStore
from synapse.storage.signatures import SignatureStore
from synapse.storage.event_federation import EventFederationStore
from syutil.base64util import encode_base64, decode_base64
from synapse.crypto.event_signing import compute_event_signature
from synapse.events.builder import EventBuilder
from synapse.events.utils import prune_event
from synapse.crypto.event_signing import check_event_content_hash
from syutil.crypto.jsonsign import (
verify_signed_json, SignatureVerifyException,
)
from syutil.crypto.signing_key import decode_verify_key_bytes
from syutil.jsonutil import encode_canonical_json
import argparse
# import dns.resolver
import hashlib
import httplib
import json
import sqlite3
import syutil
import urllib2
delta_sql = """
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
internal_metadata NOT NULL,
json BLOB NOT NULL,
CONSTRAINT ev_j_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
PRAGMA user_version = 10;
"""
class Store(object):
_get_event_signatures_txn = SignatureStore.__dict__["_get_event_signatures_txn"]
_get_event_content_hashes_txn = SignatureStore.__dict__["_get_event_content_hashes_txn"]
_get_event_reference_hashes_txn = SignatureStore.__dict__["_get_event_reference_hashes_txn"]
_get_prev_event_hashes_txn = SignatureStore.__dict__["_get_prev_event_hashes_txn"]
_get_prev_events_and_state = EventFederationStore.__dict__["_get_prev_events_and_state"]
_get_auth_events = EventFederationStore.__dict__["_get_auth_events"]
cursor_to_dict = SQLBaseStore.__dict__["cursor_to_dict"]
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_list_txn = SQLBaseStore.__dict__["_simple_select_list_txn"]
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
def _generate_event_json(self, txn, rows):
events = []
for row in rows:
d = dict(row)
d.pop("stream_ordering", None)
d.pop("topological_ordering", None)
d.pop("processed", None)
if "origin_server_ts" not in d:
d["origin_server_ts"] = d.pop("ts", 0)
else:
d.pop("ts", 0)
d.pop("prev_state", None)
d.update(json.loads(d.pop("unrecognized_keys")))
d["sender"] = d.pop("user_id")
d["content"] = json.loads(d["content"])
if "age_ts" not in d:
# For compatibility
d["age_ts"] = d.get("origin_server_ts", 0)
d.setdefault("unsigned", {})["age_ts"] = d.pop("age_ts")
outlier = d.pop("outlier", False)
# d.pop("membership", None)
d.pop("state_hash", None)
d.pop("replaces_state", None)
b = EventBuilder(d)
b.internal_metadata.outlier = outlier
events.append(b)
for i, ev in enumerate(events):
signatures = self._get_event_signatures_txn(
txn, ev.event_id,
)
ev.signatures = {
n: {
k: encode_base64(v) for k, v in s.items()
}
for n, s in signatures.items()
}
hashes = self._get_event_content_hashes_txn(
txn, ev.event_id,
)
ev.hashes = {
k: encode_base64(v) for k, v in hashes.items()
}
prevs = self._get_prev_events_and_state(txn, ev.event_id)
ev.prev_events = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 0
]
# ev.auth_events = self._get_auth_events(txn, ev.event_id)
hashes = dict(ev.auth_events)
for e_id, hash in ev.prev_events:
if e_id in hashes and not hash:
hash.update(hashes[e_id])
#
# if hasattr(ev, "state_key"):
# ev.prev_state = [
# (e_id, h)
# for e_id, h, is_state in prevs
# if is_state == 1
# ]
return [e.build() for e in events]
store = Store()
# def get_key(server_name):
# print "Getting keys for: %s" % (server_name,)
# targets = []
# if ":" in server_name:
# target, port = server_name.split(":")
# targets.append((target, int(port)))
# try:
# answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
# for srv in answers:
# targets.append((srv.target, srv.port))
# except dns.resolver.NXDOMAIN:
# targets.append((server_name, 8448))
# except:
# print "Failed to lookup keys for %s" % (server_name,)
# return {}
#
# for target, port in targets:
# url = "https://%s:%i/_matrix/key/v1" % (target, port)
# try:
# keys = json.load(urllib2.urlopen(url, timeout=2))
# verify_keys = {}
# for key_id, key_base64 in keys["verify_keys"].items():
# verify_key = decode_verify_key_bytes(
# key_id, decode_base64(key_base64)
# )
# verify_signed_json(keys, server_name, verify_key)
# verify_keys[key_id] = verify_key
# print "Got keys for: %s" % (server_name,)
# return verify_keys
# except urllib2.URLError:
# pass
# except urllib2.HTTPError:
# pass
# except httplib.HTTPException:
# pass
#
# print "Failed to get keys for %s" % (server_name,)
# return {}
def reinsert_events(cursor, server_name, signing_key):
print "Running delta: v10"
cursor.executescript(delta_sql)
cursor.execute(
"SELECT * FROM events ORDER BY rowid ASC"
)
print "Getting events..."
rows = store.cursor_to_dict(cursor)
events = store._generate_event_json(cursor, rows)
print "Got events from DB."
algorithms = {
"sha256": hashlib.sha256,
}
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
verify_key = signing_key.verify_key
verify_key.alg = signing_key.alg
verify_key.version = signing_key.version
server_keys = {
server_name: {
key_id: verify_key
}
}
i = 0
N = len(events)
for event in events:
if i % 100 == 0:
print "Processed: %d/%d events" % (i,N,)
i += 1
# for alg_name in event.hashes:
# if check_event_content_hash(event, algorithms[alg_name]):
# pass
# else:
# pass
# print "FAIL content hash %s %s" % (alg_name, event.event_id, )
have_own_correctly_signed = False
for host, sigs in event.signatures.items():
pruned = prune_event(event)
for key_id in sigs:
if host not in server_keys:
server_keys[host] = {} # get_key(host)
if key_id in server_keys[host]:
try:
verify_signed_json(
pruned.get_pdu_json(),
host,
server_keys[host][key_id]
)
if host == server_name:
have_own_correctly_signed = True
except SignatureVerifyException:
print "FAIL signature check %s %s" % (
key_id, event.event_id
)
# TODO: Re sign with our own server key
if not have_own_correctly_signed:
sigs = compute_event_signature(event, server_name, signing_key)
event.signatures.update(sigs)
pruned = prune_event(event)
for key_id in event.signatures[server_name]:
verify_signed_json(
pruned.get_pdu_json(),
server_name,
server_keys[server_name][key_id]
)
event_json = encode_canonical_json(
event.get_dict()
).decode("UTF-8")
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
).decode("UTF-8")
store._simple_insert_txn(
cursor,
table="event_json",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json,
"json": event_json,
},
or_replace=True,
)
def main(database, server_name, signing_key):
conn = sqlite3.connect(database)
cursor = conn.cursor()
# Do other deltas:
cursor.execute("PRAGMA user_version")
row = cursor.fetchone()
if row and row[0]:
user_version = row[0]
# Run every version since after the current version.
for v in range(user_version + 1, 10):
print "Running delta: %d" % (v,)
sql_script = read_schema("delta/v%d" % (v,))
cursor.executescript(sql_script)
reinsert_events(cursor, server_name, signing_key)
conn.commit()
print "Success!"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("database")
parser.add_argument("server_name")
parser.add_argument(
"signing_key", type=argparse.FileType('r'),
)
args = parser.parse_args()
signing_key = syutil.crypto.signing_key.read_signing_keys(
args.signing_key
)
main(args.database, args.server_name, signing_key[0])

View File

@ -16,3 +16,6 @@ ignore =
docs/* docs/*
pylint.cfg pylint.cfg
tox.ini tox.ini
[flake8]
max-line-length = 90

View File

@ -48,7 +48,7 @@ setup(
description="Reference Synapse Home Server", description="Reference Synapse Home Server",
install_requires=dependencies['requirements'](include_conditional=True).keys(), install_requires=dependencies['requirements'](include_conditional=True).keys(),
setup_requires=[ setup_requires=[
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0 "Twisted>=15.1.0", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"setuptools_trial", "setuptools_trial",
"mock" "mock"
], ],

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.9.3" __version__ = "0.10.0"

View File

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
AuthEventTypes = ( AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels, EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules, EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
) )
@ -44,6 +44,11 @@ class Auth(object):
def check(self, event, auth_events): def check(self, event, auth_events):
""" Checks if this event is correctly authed. """ Checks if this event is correctly authed.
Args:
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
Returns: Returns:
True if the auth checks pass. True if the auth checks pass.
""" """
@ -187,6 +192,9 @@ class Auth(object):
join_rule = JoinRules.INVITE join_rule = JoinRules.INVITE
user_level = self._get_user_power_level(event.user_id, auth_events) user_level = self._get_user_power_level(event.user_id, auth_events)
target_level = self._get_user_power_level(
target_user_id, auth_events
)
# FIXME (erikj): What should we do here as the default? # FIXME (erikj): What should we do here as the default?
ban_level = self._get_named_level(auth_events, "ban", 50) ban_level = self._get_named_level(auth_events, "ban", 50)
@ -258,12 +266,12 @@ class Auth(object):
elif target_user_id != event.user_id: elif target_user_id != event.user_id:
kick_level = self._get_named_level(auth_events, "kick", 50) kick_level = self._get_named_level(auth_events, "kick", 50)
if user_level < kick_level: if user_level < kick_level or user_level <= target_level:
raise AuthError( raise AuthError(
403, "You cannot kick user %s." % target_user_id 403, "You cannot kick user %s." % target_user_id
) )
elif Membership.BAN == membership: elif Membership.BAN == membership:
if user_level < ban_level: if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban") raise AuthError(403, "You don't have permission to ban")
else: else:
raise AuthError(500, "Unknown membership %s" % membership) raise AuthError(500, "Unknown membership %s" % membership)
@ -316,7 +324,7 @@ class Auth(object):
Returns: Returns:
tuple : of UserID and device string: tuple : of UserID and device string:
User ID object of the user making the request User ID object of the user making the request
Client ID object of the client instance the user is using ClientInfo object of the client instance the user is using
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
@ -344,12 +352,14 @@ class Auth(object):
if not user_id: if not user_id:
raise KeyError raise KeyError
request.authenticated_entity = user_id
defer.returnValue( defer.returnValue(
(UserID.from_string(user_id), ClientInfo("", "")) (UserID.from_string(user_id), ClientInfo("", ""))
) )
return return
except KeyError: except KeyError:
pass # normal users won't have this query parameter set pass # normal users won't have the user_id query parameter set.
user_info = yield self.get_user_by_token(access_token) user_info = yield self.get_user_by_token(access_token)
user = user_info["user"] user = user_info["user"]
@ -417,6 +427,7 @@ class Auth(object):
"Unrecognised access token.", "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
request.authenticated_entity = service.sender
defer.returnValue(service) defer.returnValue(service)
except KeyError: except KeyError:
raise AuthError( raise AuthError(
@ -518,23 +529,22 @@ class Auth(object):
# Check state_key # Check state_key
if hasattr(event, "state_key"): if hasattr(event, "state_key"):
if not event.state_key.startswith("_"): if event.state_key.startswith("@"):
if event.state_key.startswith("@"): if event.state_key != event.user_id:
if event.state_key != event.user_id: raise AuthError(
403,
"You are not allowed to set others state"
)
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
raise AuthError( raise AuthError(
403, 403,
"You are not allowed to set others state" "You are not allowed to set others state"
) )
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True return True
@ -573,25 +583,26 @@ class Auth(object):
# Check other levels: # Check other levels:
levels_to_check = [ levels_to_check = [
("users_default", []), ("users_default", None),
("events_default", []), ("events_default", None),
("ban", []), ("state_default", None),
("redact", []), ("ban", None),
("kick", []), ("redact", None),
("invite", []), ("kick", None),
("invite", None),
] ]
old_list = current_state.content.get("users") old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()): for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append( levels_to_check.append(
(user, ["users"]) (user, "users")
) )
old_list = current_state.content.get("events") old_list = current_state.content.get("events")
new_list = event.content.get("events") new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()): for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append( levels_to_check.append(
(ev_id, ["events"]) (ev_id, "events")
) )
old_state = current_state.content old_state = current_state.content
@ -599,12 +610,10 @@ class Auth(object):
for level_to_check, dir in levels_to_check: for level_to_check, dir in levels_to_check:
old_loc = old_state old_loc = old_state
for d in dir:
old_loc = old_loc.get(d, {})
new_loc = new_state new_loc = new_state
for d in dir: if dir:
new_loc = new_loc.get(d, {}) old_loc = old_loc.get(dir, {})
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc: if level_to_check in old_loc:
old_level = int(old_loc[level_to_check]) old_level = int(old_loc[level_to_check])
@ -620,6 +629,14 @@ class Auth(object):
if new_level == old_level: if new_level == old_level:
continue continue
if dir == "users" and level_to_check != event.user_id:
if old_level == user_level:
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
)
if old_level > user_level or new_level > user_level: if old_level > user_level or new_level > user_level:
raise AuthError( raise AuthError(
403, 403,

View File

@ -75,6 +75,10 @@ class EventTypes(object):
Redaction = "m.room.redaction" Redaction = "m.room.redaction"
Feedback = "m.room.message.feedback" Feedback = "m.room.message.feedback"
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
RoomAvatar = "m.room.avatar"
# These are used for validation # These are used for validation
Message = "m.room.message" Message = "m.room.message"
Topic = "m.room.topic" Topic = "m.room.topic"
@ -85,3 +89,8 @@ class RejectedReason(object):
AUTH_ERROR = "auth_error" AUTH_ERROR = "auth_error"
REPLACED = "replaced" REPLACED = "replaced"
NOT_ANCESTOR = "not_ancestor" NOT_ANCESTOR = "not_ancestor"
class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat"

View File

@ -40,6 +40,7 @@ class Codes(object):
TOO_LARGE = "M_TOO_LARGE" TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE" EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "THREEPID_IN_USE"
class CodeMessageException(RuntimeError): class CodeMessageException(RuntimeError):

View File

@ -16,7 +16,7 @@
import sys import sys
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
from synapse.python_dependencies import check_requirements from synapse.python_dependencies import check_requirements, DEPENDENCY_LINKS
if __name__ == '__main__': if __name__ == '__main__':
check_requirements() check_requirements()
@ -97,9 +97,25 @@ class SynapseHomeServer(HomeServer):
return JsonResource(self) return JsonResource(self)
def build_resource_for_web_client(self): def build_resource_for_web_client(self):
import syweb webclient_path = self.get_config().web_client_location
syweb_path = os.path.dirname(syweb.__file__) if not webclient_path:
webclient_path = os.path.join(syweb_path, "webclient") try:
import syweb
except ImportError:
quit_with_error(
"Could not find a webclient.\n\n"
"Please either install the matrix-angular-sdk or configure\n"
"the location of the source to serve via the configuration\n"
"option `web_client_location`\n\n"
"To install the `matrix-angular-sdk` via pip, run:\n\n"
" pip install '%(dep)s'\n"
"\n"
"You can also disable hosting of the webclient via the\n"
"configuration option `web_client`\n"
% {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
)
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
# GZip is disabled here due to # GZip is disabled here due to
# https://twistedmatrix.com/trac/ticket/7678 # https://twistedmatrix.com/trac/ticket/7678
# (It can stay enabled for the API resources: they call # (It can stay enabled for the API resources: they call
@ -259,11 +275,10 @@ class SynapseHomeServer(HomeServer):
def quit_with_error(error_string): def quit_with_error(error_string):
message_lines = error_string.split("\n") message_lines = error_string.split("\n")
line_length = max([len(l) for l in message_lines]) + 2 line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
sys.stderr.write("*" * line_length + '\n') sys.stderr.write("*" * line_length + '\n')
for line in message_lines: for line in message_lines:
if line.strip(): sys.stderr.write(" %s\n" % (line.rstrip(),))
sys.stderr.write(" %s\n" % (line.strip(),))
sys.stderr.write("*" * line_length + '\n') sys.stderr.write("*" * line_length + '\n')
sys.exit(1) sys.exit(1)
@ -326,7 +341,7 @@ def get_version_string():
) )
).encode("ascii") ).encode("ascii")
except Exception as e: except Exception as e:
logger.warn("Failed to check for git repository: %s", e) logger.info("Failed to check for git repository: %s", e)
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii") return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
@ -657,7 +672,8 @@ def run(hs):
if hs.config.daemonize: if hs.config.daemonize:
print hs.config.pid_file if hs.config.print_pidfile:
print hs.config.pid_file
daemon = Daemonize( daemon = Daemonize(
app="synapse-homeserver", app="synapse-homeserver",

View File

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if __name__ == "__main__":
import sys
from homeserver import HomeServerConfig
action = sys.argv[1]
if action == "read":
key = sys.argv[2]
config = HomeServerConfig.load_config("", sys.argv[3:])
print getattr(config, key)
sys.exit(0)
else:
sys.stderr.write("Unknown command %r\n" % (action,))
sys.exit(1)

View File

@ -131,71 +131,107 @@ class Config(object):
"-c", "--config-path", "-c", "--config-path",
action="append", action="append",
metavar="CONFIG_FILE", metavar="CONFIG_FILE",
help="Specify config file" help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files."
) )
config_parser.add_argument( config_parser.add_argument(
"--generate-config", "--generate-config",
action="store_true", action="store_true",
help="Generate a config file for the server name" help="Generate a config file for the server name"
) )
config_parser.add_argument(
"--generate-keys",
action="store_true",
help="Generate any missing key files then exit"
)
config_parser.add_argument(
"--keys-directory",
metavar="DIRECTORY",
help="Used with 'generate-*' options to specify where files such as"
" certs and signing keys should be stored in, unless explicitly"
" specified in the config."
)
config_parser.add_argument( config_parser.add_argument(
"-H", "--server-name", "-H", "--server-name",
help="The server name to generate a config file for" help="The server name to generate a config file for"
) )
config_args, remaining_args = config_parser.parse_known_args(argv) config_args, remaining_args = config_parser.parse_known_args(argv)
generate_keys = config_args.generate_keys
config_files = []
if config_args.config_path:
for config_path in config_args.config_path:
if os.path.isdir(config_path):
# We accept specifying directories as config paths, we search
# inside that directory for all files matching *.yaml, and then
# we apply them in *sorted* order.
files = []
for entry in os.listdir(config_path):
entry_path = os.path.join(config_path, entry)
if not os.path.isfile(entry_path):
print (
"Found subdirectory in config directory: %r. IGNORING."
) % (entry_path, )
continue
if not entry.endswith(".yaml"):
print (
"Found file in config directory that does not"
" end in '.yaml': %r. IGNORING."
) % (entry_path, )
continue
files.append(entry_path)
config_files.extend(sorted(files))
else:
config_files.append(config_path)
if config_args.generate_config: if config_args.generate_config:
if not config_args.config_path: if not config_files:
config_parser.error( config_parser.error(
"Must supply a config file.\nA config file can be automatically" "Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -H SERVER_NAME" " generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\"" " -c CONFIG-FILE\""
) )
(config_path,) = config_files
if not os.path.exists(config_path):
if config_args.keys_directory:
config_dir_path = config_args.keys_directory
else:
config_dir_path = os.path.dirname(config_path)
config_dir_path = os.path.abspath(config_dir_path)
config_dir_path = os.path.dirname(config_args.config_path[0]) server_name = config_args.server_name
config_dir_path = os.path.abspath(config_dir_path) if not server_name:
print "Must specify a server_name to a generate config for."
server_name = config_args.server_name
if not server_name:
print "Must specify a server_name to a generate config for."
sys.exit(1)
(config_path,) = config_args.config_path
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
if os.path.exists(config_path):
print "Config file %r already exists" % (config_path,)
yaml_config = cls.read_config_file(config_path)
yaml_name = yaml_config["server_name"]
if server_name != yaml_name:
print (
"Config file %r has a different server_name: "
" %r != %r" % (config_path, server_name, yaml_name)
)
sys.exit(1) sys.exit(1)
config_bytes, config = obj.generate_config( if not os.path.exists(config_dir_path):
config_dir_path, server_name os.makedirs(config_dir_path)
) with open(config_path, "wb") as config_file:
config.update(yaml_config) config_bytes, config = obj.generate_config(
print "Generating any missing keys for %r" % (server_name,) config_dir_path, server_name
obj.invoke_all("generate_files", config) )
sys.exit(0) obj.invoke_all("generate_files", config)
with open(config_path, "wb") as config_file: config_file.write(config_bytes)
config_bytes, config = obj.generate_config(
config_dir_path, server_name
)
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
print ( print (
"A config file has been generated in %s for server name" "A config file has been generated in %r for server name"
" '%s' with corresponding SSL keys and self-signed" " %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it to" " certificates. Please review this file and customise it"
" your needs." " to your needs."
) % (config_path, server_name) ) % (config_path, server_name)
print ( print (
"If this server name is incorrect, you will need to regenerate" "If this server name is incorrect, you will need to"
" the SSL certificates" " regenerate the SSL certificates"
) )
sys.exit(0) sys.exit(0)
else:
print (
"Config file %r already exists. Generating any missing key"
" files."
) % (config_path,)
generate_keys = True
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
parents=[config_parser], parents=[config_parser],
@ -206,19 +242,22 @@ class Config(object):
obj.invoke_all("add_arguments", parser) obj.invoke_all("add_arguments", parser)
args = parser.parse_args(remaining_args) args = parser.parse_args(remaining_args)
if not config_args.config_path: if not config_files:
config_parser.error( config_parser.error(
"Must supply a config file.\nA config file can be automatically" "Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -H SERVER_NAME" " generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\"" " -c CONFIG-FILE\""
) )
config_dir_path = os.path.dirname(config_args.config_path[0]) if config_args.keys_directory:
config_dir_path = config_args.keys_directory
else:
config_dir_path = os.path.dirname(config_args.config_path[-1])
config_dir_path = os.path.abspath(config_dir_path) config_dir_path = os.path.abspath(config_dir_path)
specified_config = {} specified_config = {}
for config_path in config_args.config_path: for config_file in config_files:
yaml_config = cls.read_config_file(config_path) yaml_config = cls.read_config_file(config_file)
specified_config.update(yaml_config) specified_config.update(yaml_config)
server_name = specified_config["server_name"] server_name = specified_config["server_name"]
@ -226,6 +265,10 @@ class Config(object):
config.pop("log_config") config.pop("log_config")
config.update(specified_config) config.update(specified_config)
if generate_keys:
obj.invoke_all("generate_files", config)
sys.exit(0)
obj.invoke_all("read_config", config) obj.invoke_all("read_config", config)
obj.invoke_all("read_arguments", args) obj.invoke_all("read_arguments", args)

View File

@ -29,10 +29,10 @@ class CaptchaConfig(Config):
## Captcha ## ## Captcha ##
# This Home Server's ReCAPTCHA public key. # This Home Server's ReCAPTCHA public key.
recaptcha_private_key: "YOUR_PUBLIC_KEY" recaptcha_private_key: "YOUR_PRIVATE_KEY"
# This Home Server's ReCAPTCHA private key. # This Home Server's ReCAPTCHA private key.
recaptcha_public_key: "YOUR_PRIVATE_KEY" recaptcha_public_key: "YOUR_PUBLIC_KEY"
# Enables ReCaptcha checks when registering, preventing signup # Enables ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha # unless a captcha is answered. Requires a valid ReCaptcha

View File

@ -25,12 +25,13 @@ from .registration import RegistrationConfig
from .metrics import MetricsConfig from .metrics import MetricsConfig
from .appservice import AppServiceConfig from .appservice import AppServiceConfig
from .key import KeyConfig from .key import KeyConfig
from .saml2 import SAML2Config
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, VoipConfig, RegistrationConfig, MetricsConfig,
MetricsConfig, AppServiceConfig, KeyConfig,): AppServiceConfig, KeyConfig, SAML2Config, ):
pass pass

View File

@ -14,6 +14,39 @@
# limitations under the License. # limitations under the License.
from ._base import Config from ._base import Config
from collections import namedtuple
ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys
and creates a map from image media types to the thumbnail size, thumnailing
method, and thumbnail media type to precalculate
Args:
thumbnail_sizes(list): List of dicts with "width", "height", and
"method" keys
Returns:
Dictionary mapping from media type string to list of
ThumbnailRequirement tuples.
"""
requirements = {}
for size in thumbnail_sizes:
width = size["width"]
height = size["height"]
method = size["method"]
jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg")
png_thumbnail = ThumbnailRequirement(width, height, method, "image/png")
requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail)
requirements.setdefault("image/gif", []).append(png_thumbnail)
requirements.setdefault("image/png", []).append(png_thumbnail)
return {
media_type: tuple(thumbnails)
for media_type, thumbnails in requirements.items()
}
class ContentRepositoryConfig(Config): class ContentRepositoryConfig(Config):
@ -22,6 +55,10 @@ class ContentRepositoryConfig(Config):
self.max_image_pixels = self.parse_size(config["max_image_pixels"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.media_store_path = self.ensure_directory(config["media_store_path"]) self.media_store_path = self.ensure_directory(config["media_store_path"])
self.uploads_path = self.ensure_directory(config["uploads_path"]) self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements(
config["thumbnail_sizes"]
)
def default_config(self, config_dir_path, server_name): def default_config(self, config_dir_path, server_name):
media_store = self.default_path("media_store") media_store = self.default_path("media_store")
@ -38,4 +75,26 @@ class ContentRepositoryConfig(Config):
# Maximum number of pixels that will be thumbnailed # Maximum number of pixels that will be thumbnailed
max_image_pixels: "32M" max_image_pixels: "32M"
# Whether to generate new thumbnails on the fly to precisely match
# the resolution requested by the client. If true then whenever
# a new resolution is requested by the client the server will
# generate a new thumbnail. If false the server will pick a thumbnail
# from a precalcualted list.
dynamic_thumbnails: false
# List of thumbnail to precalculate when an image is uploaded.
thumbnail_sizes:
- width: 32
height: 32
method: crop
- width: 96
height: 96
method: crop
- width: 320
height: 240
method: scale
- width: 640
height: 480
method: scale
""" % locals() """ % locals()

54
synapse/config/saml2.py Normal file
View File

@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Ericsson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class SAML2Config(Config):
"""SAML2 Configuration
Synapse uses pysaml2 libraries for providing SAML2 support
config_path: Path to the sp_conf.py configuration file
idp_redirect_url: Identity provider URL which will redirect
the user back to /login/saml2 with proper info.
sp_conf.py file is something like:
https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example
More information: https://pythonhosted.org/pysaml2/howto/config.html
"""
def read_config(self, config):
saml2_config = config.get("saml2_config", None)
if saml2_config:
self.saml2_enabled = True
self.saml2_config_path = saml2_config["config_path"]
self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
else:
self.saml2_enabled = False
self.saml2_config_path = None
self.saml2_idp_redirect_url = None
def default_config(self, config_dir_path, server_name):
return """
# Enable SAML2 for registration and login. Uses pysaml2
# config_path: Path to the sp_conf.py configuration file
# idp_redirect_url: Identity provider URL which will redirect
# the user back to /login/saml2 with proper info.
# See pysaml2 docs for format of config.
#saml2_config:
# config_path: "%s/sp_conf.py"
# idp_redirect_url: "http://%s/idp"
""" % (config_dir_path, server_name)

View File

@ -22,8 +22,10 @@ class ServerConfig(Config):
self.server_name = config["server_name"] self.server_name = config["server_name"]
self.pid_file = self.abspath(config.get("pid_file")) self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"] self.web_client = config["web_client"]
self.web_client_location = config.get("web_client_location", None)
self.soft_file_limit = config["soft_file_limit"] self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize") self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile")
self.use_frozen_dicts = config.get("use_frozen_dicts", True) self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.listeners = config.get("listeners", []) self.listeners = config.get("listeners", [])
@ -208,12 +210,18 @@ class ServerConfig(Config):
self.manhole = args.manhole self.manhole = args.manhole
if args.daemonize is not None: if args.daemonize is not None:
self.daemonize = args.daemonize self.daemonize = args.daemonize
if args.print_pidfile is not None:
self.print_pidfile = args.print_pidfile
def add_arguments(self, parser): def add_arguments(self, parser):
server_group = parser.add_argument_group("server") server_group = parser.add_argument_group("server")
server_group.add_argument("-D", "--daemonize", action='store_true', server_group.add_argument("-D", "--daemonize", action='store_true',
default=None, default=None,
help="Daemonize the home server") help="Daemonize the home server")
server_group.add_argument("--print-pidfile", action='store_true',
default=None,
help="Print the path to the pidfile just"
" before daemonizing")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole", server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int, type=int,
help="Turn on the twisted telnet manhole" help="Turn on the twisted telnet manhole"

View File

@ -27,6 +27,7 @@ class TlsConfig(Config):
self.tls_certificate = self.read_tls_certificate( self.tls_certificate = self.read_tls_certificate(
config.get("tls_certificate_path") config.get("tls_certificate_path")
) )
self.tls_certificate_file = config.get("tls_certificate_path")
self.no_tls = config.get("no_tls", False) self.no_tls = config.get("no_tls", False)
@ -49,7 +50,11 @@ class TlsConfig(Config):
tls_dh_params_path = base_key_name + ".tls.dh" tls_dh_params_path = base_key_name + ".tls.dh"
return """\ return """\
# PEM encoded X509 certificate for TLS # PEM encoded X509 certificate for TLS.
# You can replace the self-signed certificate that synapse
# autogenerates on launch with your own SSL certificate + key pair
# if you like. Any required intermediary certificates can be
# appended after the primary certificate in hierarchical order.
tls_certificate_path: "%(tls_certificate_path)s" tls_certificate_path: "%(tls_certificate_path)s"
# PEM encoded private key for TLS # PEM encoded private key for TLS

View File

@ -35,9 +35,9 @@ class ServerContextFactory(ssl.ContextFactory):
_ecCurve = _OpenSSLECCurve(_defaultCurveName) _ecCurve = _OpenSSLECCurve(_defaultCurveName)
_ecCurve.addECKeyToContext(context) _ecCurve.addECKeyToContext(context)
except: except:
logger.exception("Failed to enable eliptic curve for TLS") logger.exception("Failed to enable elliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
context.use_certificate(config.tls_certificate) context.use_certificate_chain_file(config.tls_certificate_file)
if not config.no_tls: if not config.no_tls:
context.use_privatekey(config.tls_private_key) context.use_privatekey(config.tls_private_key)

View File

@ -25,11 +25,13 @@ from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from OpenSSL import crypto from OpenSSL import crypto
from collections import namedtuple
import urllib import urllib
import hashlib import hashlib
import logging import logging
@ -38,6 +40,9 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
class Keyring(object): class Keyring(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@ -49,141 +54,325 @@ class Keyring(object):
self.key_downloads = {} self.key_downloads = {}
@defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object):
logger.debug("Verifying for %s", server_name) return self.verify_json_objects_for_server(
key_ids = signature_ids(json_object, server_name) [(server_name, json_object)]
if not key_ids: )[0]
raise SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
)
try:
verify_key = yield self.get_server_verify_key(server_name, key_ids)
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.warn(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
try: def verify_json_objects_for_server(self, server_and_json):
verify_signed_json(json_object, server_name, verify_key) """Bulk verfies signatures of json objects, bulk fetching keys as
except: necessary.
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
@defer.inlineCallbacks
def get_server_verify_key(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
Trys to fetch the key from a trusted perspective server first.
Args: Args:
server_name(str): The name of the server to fetch a key for. server_and_json (list): List of pairs of (server_name, json_object)
keys_ids (list of str): The key_ids to check for.
Returns:
list of deferreds indicating success or failure to verify each
json object's signature for the given server_name.
""" """
cached = yield self.store.get_server_verify_keys(server_name, key_ids) group_id_to_json = {}
group_id_to_group = {}
group_ids = []
if cached: next_group_id = 0
defer.returnValue(cached[0]) deferreds = {}
return
download = self.key_downloads.get(server_name) for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
if download is None: key_ids = signature_ids(json_object, server_name)
download = self._get_server_verify_key_impl(server_name, key_ids) if not key_ids:
download = ObservableDeferred( deferreds[group_id] = defer.fail(SynapseError(
download, 400,
consumeErrors=True "Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
))
else:
deferreds[group_id] = defer.Deferred()
group = KeyGroup(server_name, group_id, key_ids)
group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
@defer.inlineCallbacks
def handle_key_deferred(group, deferred):
server_name = group.server_name
try:
_, _, key_id, verify_key = yield deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
json_object = group_id_to_json[group.group_id]
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
server_to_deferred = {
server_name: defer.Deferred()
for server_name, _ in server_and_json
}
# We want to wait for any previous lookups to complete before
# proceeding.
wait_on_deferred = self.wait_for_previous_lookups(
[server_name for server_name, _ in server_and_json],
server_to_deferred,
)
# Actually start fetching keys.
wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
server_to_gids = {}
def remove_deferreds(res, server_name, group_id):
server_to_gids[server_name].discard(group_id)
if not server_to_gids[server_name]:
server_to_deferred.pop(server_name).callback(None)
return res
for g_id, deferred in deferreds.items():
server_name = group_id_to_group[g_id].server_name
server_to_gids.setdefault(server_name, set()).add(g_id)
deferred.addBoth(remove_deferreds, server_name, g_id)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
handle_key_deferred(
group_id_to_group[g_id],
deferreds[g_id],
) )
self.key_downloads[server_name] = download for g_id in group_ids
]
@download.addBoth
def callback(ret):
del self.key_downloads[server_name]
return ret
r = yield download.observe()
defer.returnValue(r)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_server_verify_key_impl(self, server_name, key_ids): def wait_for_previous_lookups(self, server_names, server_to_deferred):
keys = None """Waits for any previous key lookups for the given servers to finish.
Args:
server_names (list): list of server_names we want to lookup
server_to_deferred (dict): server_name to deferred which gets
resolved once we've finished looking up keys for that server
"""
while True:
wait_on = [
self.key_downloads[server_name]
for server_name in server_names
if server_name in self.key_downloads
]
if wait_on:
yield defer.DeferredList(wait_on)
else:
break
for server_name, deferred in server_to_deferred:
self.key_downloads[server_name] = ObservableDeferred(deferred)
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
# These are functions that produce keys given a list of key ids
key_fetch_fns = (
self.get_keys_from_store, # First try the local store
self.get_keys_from_perspectives, # Then try via perspectives
self.get_keys_from_server, # Then try directly
)
@defer.inlineCallbacks
def do_iterations():
merged_results = {}
missing_keys = {
group.server_name: key_id
for group in group_id_to_group.values()
for key_id in group.key_ids
}
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which groups we have keys for
# and which we don't
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
key_id,
merged_results[group.server_name][key_id],
))
break
else:
missing_groups.setdefault(
group.server_name, []
).append(group)
if not missing_groups:
break
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
)
for server_name, groups in missing_groups.items()
}
for group in missing_groups.values():
group_id_to_deferred[group.group_id].errback(SynapseError(
401,
"No key for %s with id %s" % (
group.server_name, group.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
for deferred in group_id_to_deferred.values():
if not deferred.called:
deferred.errback(err)
do_iterations().addErrback(on_err)
return group_id_to_deferred
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
[
self.store.get_server_verify_keys(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(dict(zip(
[server_name for server_name, _ in server_name_and_key_ids],
res
)))
@defer.inlineCallbacks
def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_key(perspective_name, perspective_keys): def get_key(perspective_name, perspective_keys):
try: try:
result = yield self.get_server_verify_key_v2_indirect( result = yield self.get_server_verify_key_v2_indirect(
server_name, key_ids, perspective_name, perspective_keys server_name_and_key_ids, perspective_name, perspective_keys
) )
defer.returnValue(result) defer.returnValue(result)
except Exception as e: except Exception as e:
logging.info( logger.exception(
"Unable to getting key %r for %r from %r: %s %s", "Unable to get key from %r: %s %s",
key_ids, server_name, perspective_name, perspective_name,
type(e).__name__, str(e.message), type(e).__name__, str(e.message),
) )
defer.returnValue({})
perspective_results = yield defer.gatherResults([ results = yield defer.gatherResults(
get_key(p_name, p_keys) [
for p_name, p_keys in self.perspective_servers.items() get_key(p_name, p_keys)
]) for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
for results in perspective_results: union_of_keys = {}
if results is not None: for result in results:
keys = results for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
limiter = yield get_retry_limiter( defer.returnValue(union_of_keys)
server_name,
self.clock,
self.store,
)
with limiter: @defer.inlineCallbacks
if not keys: def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
)
with limiter:
keys = None
try: try:
keys = yield self.get_server_verify_key_v2_direct( keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids server_name, key_ids
) )
except Exception as e: except Exception as e:
logging.info( logger.info(
"Unable to getting key %r for %r directly: %s %s", "Unable to getting key %r for %r directly: %s %s",
key_ids, server_name, key_ids, server_name,
type(e).__name__, str(e.message), type(e).__name__, str(e.message),
) )
if not keys: if not keys:
keys = yield self.get_server_verify_key_v1_direct( keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids server_name, key_ids
) )
for key_id in key_ids: keys = {server_name: keys}
if key_id in keys:
defer.returnValue(keys[key_id]) defer.returnValue(keys)
return
raise ValueError("No verification key found for given key ids") results = yield defer.gatherResults(
[
get_key(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
merged = {}
for result in results:
merged.update(result)
defer.returnValue({
server_name: keys
for server_name, keys in merged.items()
if keys
})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, server_name, key_ids, def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name, perspective_name,
perspective_keys): perspective_keys):
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
@ -204,6 +393,7 @@ class Keyring(object):
u"minimum_valid_until_ts": 0 u"minimum_valid_until_ts": 0
} for key_id in key_ids } for key_id in key_ids
} }
for server_name, key_ids in server_names_and_key_ids
} }
}, },
) )
@ -243,23 +433,29 @@ class Keyring(object):
" server %r" % (perspective_name,) " server %r" % (perspective_name,)
) )
response_keys = yield self.process_v2_response( processed_response = yield self.process_v2_response(
server_name, perspective_name, response perspective_name, response
) )
keys.update(response_keys) for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
yield self.store_keys( yield defer.gatherResults(
server_name=server_name, [
from_server=perspective_name, self.store_keys(
verify_keys=keys, server_name=server_name,
) from_server=perspective_name,
verify_keys=response_keys,
)
for server_name, response_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
defer.returnValue(keys) defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids): def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} keys = {}
for requested_key_id in key_ids: for requested_key_id in key_ids:
@ -295,25 +491,30 @@ class Keyring(object):
raise ValueError("TLS certificate not allowed by fingerprints") raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
server_name=server_name,
from_server=server_name, from_server=server_name,
requested_id=requested_key_id, requested_ids=[requested_key_id],
response_json=response, response_json=response,
) )
keys.update(response_keys) keys.update(response_keys)
yield self.store_keys( yield defer.gatherResults(
server_name=server_name, [
from_server=server_name, self.store_keys(
verify_keys=keys, server_name=key_server_name,
) from_server=server_name,
verify_keys=verify_keys,
)
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
defer.returnValue(keys) defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response(self, server_name, from_server, response_json, def process_v2_response(self, from_server, response_json,
requested_id=None): requested_ids=[]):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
response_keys = {} response_keys = {}
verify_keys = {} verify_keys = {}
@ -335,6 +536,8 @@ class Keyring(object):
verify_key.time_added = time_now_ms verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key old_verify_keys[key_id] = verify_key
results = {}
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}): for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]: if key_id not in response_json["verify_keys"]:
raise ValueError( raise ValueError(
@ -357,28 +560,31 @@ class Keyring(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json) signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"] ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
updated_key_ids = set() updated_key_ids = set(requested_ids)
if requested_id is not None:
updated_key_ids.add(requested_id)
updated_key_ids.update(verify_keys) updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys) updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys) response_keys.update(verify_keys)
response_keys.update(old_verify_keys) response_keys.update(old_verify_keys)
for key_id in updated_key_ids: yield defer.gatherResults(
yield self.store.store_server_keys_json( [
server_name=server_name, self.store.store_server_keys_json(
key_id=key_id, server_name=server_name,
from_server=server_name, key_id=key_id,
ts_now_ms=time_now_ms, from_server=server_name,
ts_expires_ms=ts_valid_until_ms, ts_now_ms=time_now_ms,
key_json_bytes=signed_key_json_bytes, ts_expires_ms=ts_valid_until_ms,
) key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(response_keys) results[server_name] = response_keys
raise ValueError("No verification key found for given key ids") defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids): def get_server_verify_key_v1_direct(self, server_name, key_ids):
@ -462,8 +668,13 @@ class Keyring(object):
Returns: Returns:
A deferred that completes when the keys are stored. A deferred that completes when the keys are stored.
""" """
for key_id, key in verify_keys.items(): # TODO(markjh): Store whether the keys have expired.
# TODO(markjh): Store whether the keys have expired. yield defer.gatherResults(
yield self.store.store_server_verify_key( [
server_name, server_name, key.time_added, key self.store.store_server_verify_key(
) server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)

View File

@ -90,7 +90,7 @@ class EventBase(object):
d = dict(self._event_dict) d = dict(self._event_dict)
d.update({ d.update({
"signatures": self.signatures, "signatures": self.signatures,
"unsigned": self.unsigned, "unsigned": dict(self.unsigned),
}) })
return d return d
@ -109,6 +109,9 @@ class EventBase(object):
pdu_json.setdefault("unsigned", {})["age"] = int(age) pdu_json.setdefault("unsigned", {})["age"] = int(age)
del pdu_json["unsigned"]["age_ts"] del pdu_json["unsigned"]["age_ts"]
# This may be a frozen event
pdu_json["unsigned"].pop("redacted_because", None)
return pdu_json return pdu_json
def __set__(self, instance, value): def __set__(self, instance, value):

View File

@ -74,6 +74,8 @@ def prune_event(event):
) )
elif event_type == EventTypes.Aliases: elif event_type == EventTypes.Aliases:
add_fields("aliases") add_fields("aliases")
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")
allowed_fields = { allowed_fields = {
k: v k: v

View File

@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class FederationBase(object): class FederationBase(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False): def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
include_none=False):
"""Takes a list of PDUs and checks the signatures and hashs of each """Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of the database and if not then request if from the originating server of
@ -50,84 +51,108 @@ class FederationBase(object):
Returns: Returns:
Deferred : A list of PDUs that have valid signatures and hashes. Deferred : A list of PDUs that have valid signatures and hashes.
""" """
deferreds = self._check_sigs_and_hashes(pdus)
signed_pdus = [] def callback(pdu):
return pdu
@defer.inlineCallbacks def errback(failure, pdu):
def do(pdu): failure.trap(SynapseError)
try: return None
new_pdu = yield self._check_sigs_and_hash(pdu)
signed_pdus.append(new_pdu)
except SynapseError:
# FIXME: We should handle signature failures more gracefully.
def try_local_db(res, pdu):
if not res:
# Check local db. # Check local db.
new_pdu = yield self.store.get_event( return self.store.get_event(
pdu.event_id, pdu.event_id,
allow_rejected=True, allow_rejected=True,
allow_none=True, allow_none=True,
) )
if new_pdu: return res
signed_pdus.append(new_pdu)
return
# Check pdu.origin def try_remote(res, pdu):
if pdu.origin != origin: if not res and pdu.origin != origin:
try: return self.get_pdu(
new_pdu = yield self.get_pdu( destinations=[pdu.origin],
destinations=[pdu.origin], event_id=pdu.event_id,
event_id=pdu.event_id, outlier=outlier,
outlier=outlier, timeout=10000,
timeout=10000, ).addErrback(lambda e: None)
) return res
if new_pdu:
signed_pdus.append(new_pdu)
return
except:
pass
def warn(res, pdu):
if not res:
logger.warn( logger.warn(
"Failed to find copy of %s with valid signature", "Failed to find copy of %s with valid signature",
pdu.event_id, pdu.event_id,
) )
return res
yield defer.gatherResults( for pdu, deferred in zip(pdus, deferreds):
[do(pdu) for pdu in pdus], deferred.addCallbacks(
callback, errback, errbackArgs=[pdu]
).addCallback(
try_local_db, pdu
).addCallback(
try_remote, pdu
).addCallback(
warn, pdu
)
valid_pdus = yield defer.gatherResults(
deferreds,
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
defer.returnValue(signed_pdus) if include_none:
defer.returnValue(valid_pdus)
else:
defer.returnValue([p for p in valid_pdus if p])
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu): def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct return self._check_sigs_and_hashes([pdu])[0]
def _check_sigs_and_hashes(self, pdus):
"""Throws a SynapseError if a PDU does not have the correct
signatures. signatures.
Returns: Returns:
FrozenEvent: Either the given event or it redacted if it failed the FrozenEvent: Either the given event or it redacted if it failed the
content hash check. content hash check.
""" """
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try: redacted_pdus = [
yield self.keyring.verify_json_for_server( prune_event(pdu)
pdu.origin, redacted_pdu_json for pdu in pdus
) ]
except SynapseError:
deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
def callback(_, pdu, redacted):
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
)
return redacted
return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
logger.warn( logger.warn(
"Signature check failed for %s", "Signature check failed for %s",
pdu.event_id, pdu.event_id,
) )
raise return failure
if not check_event_content_hash(pdu): for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
logger.warn( deferred.addCallbacks(
"Event content has been tampered, redacting.", callback, errback,
pdu.event_id, callbackArgs=[pdu, redacted],
errbackArgs=[pdu],
) )
defer.returnValue(redacted_event)
defer.returnValue(pdu) return deferreds

View File

@ -23,13 +23,14 @@ from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError, CodeMessageException, HttpResponseException, SynapseError,
) )
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
import synapse.metrics import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
import copy
import itertools import itertools
import logging import logging
import random import random
@ -133,6 +134,36 @@ class FederationClient(FederationBase):
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
) )
@log_function
def query_client_keys(self, destination, content):
"""Query device keys for a device hosted on a remote server.
Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc("client_device_keys")
return self.transport_layer.query_client_keys(destination, content)
@log_function
def claim_client_keys(self, destination, content):
"""Claims one-time keys for a device hosted on a remote server.
Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc("client_one_time_keys")
return self.transport_layer.claim_client_keys(destination, content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def backfill(self, dest, context, limit, extremities): def backfill(self, dest, context, limit, extremities):
@ -167,7 +198,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully. # FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults( pdus[:] = yield defer.gatherResults(
[self._check_sigs_and_hash(pdu) for pdu in pdus], self._check_sigs_and_hashes(pdus),
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
@ -230,7 +261,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0] pdu = pdu_list[0]
# Check signatures are correct. # Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu) pdu = yield self._check_sigs_and_hashes([pdu])[0]
break break
@ -327,6 +358,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id): def make_join(self, destinations, room_id, user_id):
for destination in destinations: for destination in destinations:
if destination == self.server_name:
continue
try: try:
ret = yield self.transport_layer.make_join( ret = yield self.transport_layer.make_join(
destination, room_id, user_id destination, room_id, user_id
@ -353,6 +387,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_join(self, destinations, pdu): def send_join(self, destinations, pdu):
for destination in destinations: for destination in destinations:
if destination == self.server_name:
continue
try: try:
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join( _, content = yield self.transport_layer.send_join(
@ -374,17 +411,39 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", []) for p in content.get("auth_chain", [])
] ]
signed_state, signed_auth = yield defer.gatherResults( pdus = {
[ p.event_id: p
self._check_sigs_and_hash_and_fetch( for p in itertools.chain(state, auth_chain)
destination, state, outlier=True }
),
self._check_sigs_and_hash_and_fetch( valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True destination, pdus.values(),
) outlier=True,
], )
consumeErrors=True
).addErrback(unwrapFirstError) valid_pdus_map = {
p.event_id: p
for p in valid_pdus
}
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
signed_state = [
copy.copy(valid_pdus_map[p.event_id])
for p in state
if p.event_id in valid_pdus_map
]
signed_auth = [
valid_pdus_map[p.event_id]
for p in auth_chain
if p.event_id in valid_pdus_map
]
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
@ -396,7 +455,7 @@ class FederationClient(FederationBase):
except CodeMessageException: except CodeMessageException:
raise raise
except Exception as e: except Exception as e:
logger.warn( logger.exception(
"Failed to send_join via %s: %s", "Failed to send_join via %s: %s",
destination, e.message destination, e.message
) )

View File

@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature from synapse.crypto.event_signing import compute_event_signature
import simplejson as json
import logging import logging
@ -312,6 +313,48 @@ class FederationServer(FederationBase):
(200, send_content) (200, send_content)
) )
@defer.inlineCallbacks
@log_function
def on_query_client_keys(self, origin, content):
query = []
for user_id, device_ids in content.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
defer.returnValue({"device_keys": json_result})
@defer.inlineCallbacks
@log_function
def on_claim_client_keys(self, origin, content):
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
results = yield self.store.claim_e2e_one_time_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
defer.returnValue({"one_time_keys": json_result})
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_get_missing_events(self, origin, room_id, earliest_events, def on_get_missing_events(self, origin, room_id, earliest_events,

View File

@ -222,6 +222,76 @@ class TransportLayerClient(object):
defer.returnValue(content) defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def query_client_keys(self, destination, query_content):
"""Query the device keys for a list of user ids hosted on a remote
server.
Request:
{
"device_keys": {
"<user_id>": ["<device_id>"]
} }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {...}
} } }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/user/keys/query"
content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def claim_client_keys(self, destination, query_content):
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": "<algorithm>"
} } }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
} } } }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the one-time keys.
"""
path = PREFIX + "/user/keys/claim"
content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
)
defer.returnValue(content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_missing_events(self, destination, room_id, earliest_events, def get_missing_events(self, destination, room_id, earliest_events,

View File

@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query"
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
response = yield self.handler.on_query_client_keys(origin, content)
defer.returnValue((200, response))
class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim"
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
response = yield self.handler.on_claim_client_keys(origin, content)
defer.returnValue((200, response))
class FederationQueryAuthServlet(BaseFederationServlet): class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/([^/]*)/([^/]*)" PATH = "/query_auth/([^/]*)/([^/]*)"
@ -373,4 +391,6 @@ SERVLET_CLASSES = (
FederationQueryAuthServlet, FederationQueryAuthServlet,
FederationGetMissingEventsServlet, FederationGetMissingEventsServlet,
FederationEventAuthServlet, FederationEventAuthServlet,
FederationClientKeysQueryServlet,
FederationClientKeysClaimServlet,
) )

View File

@ -22,7 +22,6 @@ from .room import (
from .message import MessageHandler from .message import MessageHandler
from .events import EventStreamHandler, EventHandler from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler from .federation import FederationHandler
from .login import LoginHandler
from .profile import ProfileHandler from .profile import ProfileHandler
from .presence import PresenceHandler from .presence import PresenceHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
@ -32,6 +31,7 @@ from .appservice import ApplicationServicesHandler
from .sync import SyncHandler from .sync import SyncHandler
from .auth import AuthHandler from .auth import AuthHandler
from .identity import IdentityHandler from .identity import IdentityHandler
from .receipts import ReceiptsHandler
class Handlers(object): class Handlers(object):
@ -53,10 +53,10 @@ class Handlers(object):
self.profile_handler = ProfileHandler(hs) self.profile_handler = ProfileHandler(hs)
self.presence_handler = PresenceHandler(hs) self.presence_handler = PresenceHandler(hs)
self.room_list_handler = RoomListHandler(hs) self.room_list_handler = RoomListHandler(hs)
self.login_handler = LoginHandler(hs)
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs)
asapi = ApplicationServiceApi(hs) asapi = ApplicationServiceApi(hs)
self.appservice_handler = ApplicationServicesHandler( self.appservice_handler = ApplicationServicesHandler(
hs, asapi, AppServiceScheduler( hs, asapi, AppServiceScheduler(

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError, SynapseError from synapse.api.errors import LimitExceededError, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID from synapse.types import UserID, RoomAlias
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
@ -107,6 +107,22 @@ class BaseHandler(object):
if not suppress_auth: if not suppress_auth:
self.auth.check(event, auth_events=context.current_state) self.auth.check(event, auth_events=context.current_state)
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
room_alias_str = event.content.get("alias", None)
if room_alias_str:
room_alias = RoomAlias.from_string(room_alias_str)
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (
room_alias_str,
)
)
(event_stream_id, max_stream_id) = yield self.store.persist_event( (event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context event, context=context
) )

View File

@ -47,17 +47,24 @@ class AuthHandler(BaseHandler):
self.sessions = {} self.sessions = {}
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip=None): def check_auth(self, flows, clientdict, clientip):
""" """
Takes a dictionary sent by the client in the login / registration Takes a dictionary sent by the client in the login / registration
protocol and handles the login flow. protocol and handles the login flow.
As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
Args: Args:
flows: list of list of stages flows (list): A list of login flows. Each flow is an ordered list of
authdict: The dictionary from the client root level, not the strings representing auth-types. At least one full
'auth' key: this method prompts for auth if none is sent. flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client.
Returns: Returns:
A tuple of authed, dict, dict where authed is true if the client A tuple of (authed, dict, dict) where authed is true if the client
has successfully completed an auth flow. If it is true, the first has successfully completed an auth flow. If it is true, the first
dict contains the authenticated credentials of each stage. dict contains the authenticated credentials of each stage.
@ -75,7 +82,7 @@ class AuthHandler(BaseHandler):
del clientdict['auth'] del clientdict['auth']
if 'session' in authdict: if 'session' in authdict:
sid = authdict['session'] sid = authdict['session']
sess = self._get_session_info(sid) session = self._get_session_info(sid)
if len(clientdict) > 0: if len(clientdict) > 0:
# This was designed to allow the client to omit the parameters # This was designed to allow the client to omit the parameters
@ -85,20 +92,21 @@ class AuthHandler(BaseHandler):
# email auth link on there). It's probably too open to abuse # email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects # because it lets unauthenticated clients store arbitrary objects
# on a home server. # on a home server.
# sess['clientdict'] = clientdict # Revisit: Assumimg the REST APIs do sensible validation, the data
# self._save_session(sess) # isn't arbintrary.
pass session['clientdict'] = clientdict
elif 'clientdict' in sess: self._save_session(session)
clientdict = sess['clientdict'] elif 'clientdict' in session:
clientdict = session['clientdict']
if not authdict: if not authdict:
defer.returnValue( defer.returnValue(
(False, self._auth_dict_for_flows(flows, sess), clientdict) (False, self._auth_dict_for_flows(flows, session), clientdict)
) )
if 'creds' not in sess: if 'creds' not in session:
sess['creds'] = {} session['creds'] = {}
creds = sess['creds'] creds = session['creds']
# check auth type currently being presented # check auth type currently being presented
if 'type' in authdict: if 'type' in authdict:
@ -107,15 +115,15 @@ class AuthHandler(BaseHandler):
result = yield self.checkers[authdict['type']](authdict, clientip) result = yield self.checkers[authdict['type']](authdict, clientip)
if result: if result:
creds[authdict['type']] = result creds[authdict['type']] = result
self._save_session(sess) self._save_session(session)
for f in flows: for f in flows:
if len(set(f) - set(creds.keys())) == 0: if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds) logger.info("Auth completed with creds: %r", creds)
self._remove_session(sess) self._remove_session(session)
defer.returnValue((True, creds, clientdict)) defer.returnValue((True, creds, clientdict))
ret = self._auth_dict_for_flows(flows, sess) ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys() ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict)) defer.returnValue((False, ret, clientdict))
@ -149,22 +157,14 @@ class AuthHandler(BaseHandler):
if "user" not in authdict or "password" not in authdict: if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM) raise LoginError(400, "", Codes.MISSING_PARAM)
user = authdict["user"] user_id = authdict["user"]
password = authdict["password"] password = authdict["password"]
if not user.startswith('@'): if not user_id.startswith('@'):
user = UserID.create(user, self.hs.hostname).to_string() user_id = UserID.create(user_id, self.hs.hostname).to_string()
user_info = yield self.store.get_user_by_id(user_id=user) user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
if not user_info: self._check_password(user_id, password, password_hash)
logger.warn("Attempted to login as %s but they do not exist", user) defer.returnValue(user_id)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash):
defer.returnValue(user)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks @defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip): def _check_recaptcha(self, authdict, clientip):
@ -268,6 +268,79 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id] return self.sessions[session_id]
@defer.inlineCallbacks
def login_with_password(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
user_id (str): User ID
password (str): Password
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
reg_handler = self.hs.get_handlers().registration_handler
access_token = reg_handler.generate_token(user_id)
logger.info("Logging in user %s", user_id)
yield self.store.add_access_token_to_user(user_id, access_token)
defer.returnValue((user_id, access_token))
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will throw if there are multiple inexact matches.
Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)`
"""
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
if len(user_infos) > 1:
if user_id not in user_infos:
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id, user_infos.keys()
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue((user_id, user_infos[user_id]))
else:
defer.returnValue(user_infos.popitem())
def _check_password(self, user_id, password, stored_hash):
"""Checks that user_id has passed password, raises LoginError if not."""
if not bcrypt.checkpw(password, stored_hash):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword):
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens(user_id)
yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
yield self.store.flush_user(user_id)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
yield self.store.user_add_threepid(
user_id, medium, address, validated_at,
self.hs.get_clock().time_msec()
)
def _save_session(self, session): def _save_session(self, session):
# TODO: Persistent storage # TODO: Persistent storage
logger.debug("Saving session %s", session) logger.debug("Saving session %s", session)

View File

@ -49,7 +49,12 @@ class EventStreamHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True, affect_presence=True): as_client_event=True, affect_presence=True,
only_room_events=False):
"""Fetches the events stream for a given user.
If `only_room_events` is `True` only room events will be returned.
"""
auth_user = UserID.from_string(auth_user_id) auth_user = UserID.from_string(auth_user_id)
try: try:
@ -70,7 +75,15 @@ class EventStreamHandler(BaseHandler):
self._streams_per_user[auth_user] += 1 self._streams_per_user[auth_user] += 1
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
app_service = yield self.store.get_app_service_by_user_id(
auth_user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
if timeout: if timeout:
# If they've set a timeout set a minimum limit. # If they've set a timeout set a minimum limit.
@ -81,7 +94,8 @@ class EventStreamHandler(BaseHandler):
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout auth_user, room_ids, pagin_config, timeout,
only_room_events=only_room_events
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View File

@ -31,6 +31,8 @@ from synapse.crypto.event_signing import (
) )
from synapse.types import UserID from synapse.types import UserID
from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from twisted.internet import defer from twisted.internet import defer
@ -138,26 +140,29 @@ class FederationHandler(BaseHandler):
if state and auth_chain is not None: if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication # If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.) # layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state): for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids: if e.event_id in seen_ids:
continue continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: auth_ids = [e_id for e_id, _ in e.auth_events]
auth_ids = [e_id for e_id, _ in e.auth_events] auth = {
auth = { (e.type, e.state_key): e for e in auth_chain
(e.type, e.state_key): e for e in auth_chain if e.event_id in auth_ids
if e.event_id in auth_ids }
} event_infos.append({
yield self._handle_new_event( "event": e,
origin, e, auth_events=auth "auth_events": auth,
) })
seen_ids.add(e.event_id) seen_ids.add(e.event_id)
except:
logger.exception( yield self._handle_new_events(
"Failed to handle state event %s", origin,
e.event_id, event_infos,
) outliers=True
)
try: try:
_, event_stream_id, max_stream_id = yield self._handle_new_event( _, event_stream_id, max_stream_id = yield self._handle_new_event(
@ -222,6 +227,55 @@ class FederationHandler(BaseHandler):
"user_joined_room", user=user, room_id=event.room_id "user_joined_room", user=user, room_id=event.room_id
) )
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
event_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
)
)
def redact_disallowed(event, state):
if not state:
return event
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
if visibility in ["invited", "joined"]:
# We now loop through all state events looking for
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
for ev in state.values():
if ev.type != EventTypes.Member:
continue
try:
domain = UserID.from_string(ev.state_key).domain
except:
continue
if domain != server_name:
continue
memtype = ev.membership
if memtype == Membership.JOIN:
return event
elif memtype == Membership.INVITE:
if visibility == "invited":
return event
else:
return prune_event(event)
return event
defer.returnValue([
redact_disallowed(e, event_to_state[e.event_id])
for e in events
])
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities=[]): def backfill(self, dest, room_id, limit, extremities=[]):
@ -292,38 +346,29 @@ class FederationHandler(BaseHandler):
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results}) auth_events.update({a.event_id: a for a in results})
yield defer.gatherResults( ev_infos = []
[ for a in auth_events.values():
self._handle_new_event( if a.event_id in seen_events:
dest, a, continue
auth_events={ ev_infos.append({
(auth_events[a_id].type, auth_events[a_id].state_key): "event": a,
auth_events[a_id] "auth_events": {
for a_id, _ in a.auth_events (auth_events[a_id].type, auth_events[a_id].state_key):
}, auth_events[a_id]
) for a_id, _ in a.auth_events
for a in auth_events.values() }
if a.event_id not in seen_events })
],
consumeErrors=True,
).addErrback(unwrapFirstError)
yield defer.gatherResults( for e_id in events_to_state:
[ ev_infos.append({
self._handle_new_event( "event": event_map[e_id],
dest, event_map[e_id], "state": events_to_state[e_id],
state=events_to_state[e_id], "auth_events": {
backfilled=True, (auth_events[a_id].type, auth_events[a_id].state_key):
auth_events={ auth_events[a_id]
(auth_events[a_id].type, auth_events[a_id].state_key): for a_id, _ in event_map[e_id].auth_events
auth_events[a_id] }
for a_id, _ in event_map[e_id].auth_events })
},
)
for e_id in events_to_state
],
consumeErrors=True
).addErrback(unwrapFirstError)
events.sort(key=lambda e: e.depth) events.sort(key=lambda e: e.depth)
@ -331,10 +376,14 @@ class FederationHandler(BaseHandler):
if event in events_to_state: if event in events_to_state:
continue continue
yield self._handle_new_event( ev_infos.append({
dest, event, "event": event,
backfilled=True, })
)
yield self._handle_new_events(
dest, ev_infos,
backfilled=True,
)
defer.returnValue(events) defer.returnValue(events)
@ -453,7 +502,7 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys()) event_ids = list(extremities.keys())
states = yield defer.gatherResults([ states = yield defer.gatherResults([
self.state_handler.resolve_state_groups([e]) self.state_handler.resolve_state_groups(room_id, [e])
for e in event_ids for e in event_ids
]) ])
states = dict(zip(event_ids, [s[1] for s in states])) states = dict(zip(event_ids, [s[1] for s in states]))
@ -600,32 +649,22 @@ class FederationHandler(BaseHandler):
# FIXME # FIXME
pass pass
yield self._handle_auth_events( ev_infos = []
origin, [e for e in auth_chain if e.event_id != event.event_id] for e in itertools.chain(state, auth_chain):
)
@defer.inlineCallbacks
def handle_state(e):
if e.event_id == event.event_id: if e.event_id == event.event_id:
return continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: auth_ids = [e_id for e_id, _ in e.auth_events]
auth_ids = [e_id for e_id, _ in e.auth_events] ev_infos.append({
auth = { "event": e,
"auth_events": {
(e.type, e.state_key): e for e in auth_chain (e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids if e.event_id in auth_ids
} }
yield self._handle_new_event( })
origin, e, auth_events=auth
)
except:
logger.exception(
"Failed to handle state event %s",
e.event_id,
)
yield defer.DeferredList([handle_state(e) for e in state]) yield self._handle_new_events(origin, ev_infos, outliers=True)
auth_ids = [e_id for e_id, _ in event.auth_events] auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = { auth_events = {
@ -835,7 +874,7 @@ class FederationHandler(BaseHandler):
raise AuthError(403, "Host not in room.") raise AuthError(403, "Host not in room.")
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
[event_id] room_id, [event_id]
) )
if state_groups: if state_groups:
@ -882,6 +921,8 @@ class FederationHandler(BaseHandler):
limit limit
) )
events = yield self._filter_events_for_server(origin, room_id, events)
defer.returnValue(events) defer.returnValue(events)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -940,11 +981,54 @@ class FederationHandler(BaseHandler):
def _handle_new_event(self, origin, event, state=None, backfilled=False, def _handle_new_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None): current_state=None, auth_events=None):
logger.debug( outlier = event.internal_metadata.is_outlier()
"_handle_new_event: %s, sigs: %s",
event.event_id, event.signatures, context = yield self._prep_event(
origin, event,
state=state,
backfilled=backfilled,
current_state=current_state,
auth_events=auth_events,
) )
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=(not outlier and not backfilled),
current_state=current_state,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False,
outliers=False):
contexts = yield defer.gatherResults(
[
self._prep_event(
origin,
ev_info["event"],
state=ev_info.get("state"),
backfilled=backfilled,
auth_events=ev_info.get("auth_events"),
)
for ev_info in event_infos
]
)
yield self.store.persist_events(
[
(ev_info["event"], context)
for ev_info, context in itertools.izip(event_infos, contexts)
],
backfilled=backfilled,
is_new_state=(not outliers and not backfilled),
)
@defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier() outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context( context = yield self.state_handler.compute_event_context(
@ -954,13 +1038,6 @@ class FederationHandler(BaseHandler):
if not auth_events: if not auth_events:
auth_events = context.current_state auth_events = context.current_state
logger.debug(
"_handle_new_event: %s, auth_events: %s",
event.event_id, auth_events,
)
is_new_state = not outlier
# This is a hack to fix some old rooms where the initial join event # This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events. # didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events: if event.type == EventTypes.Member and not event.auth_events:
@ -984,26 +1061,7 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
# FIXME: Don't store as rejected with AUTH_ERROR if we haven't defer.returnValue(context)
# seen all the auth events.
yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=False,
current_state=current_state,
)
raise
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=(is_new_state and not backfilled),
current_state=current_state,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
@ -1066,14 +1124,24 @@ class FederationHandler(BaseHandler):
@log_function @log_function
def do_auth(self, origin, event, context, auth_events): def do_auth(self, origin, event, context, auth_events):
# Check if we have all the auth events. # Check if we have all the auth events.
have_events = yield self.store.have_events( current_state = set(e.event_id for e in auth_events.values())
[e_id for e_id, _ in event.auth_events]
)
event_auth_events = set(e_id for e_id, _ in event.auth_events) event_auth_events = set(e_id for e_id, _ in event.auth_events)
if event_auth_events - current_state:
have_events = yield self.store.have_events(
event_auth_events - current_state
)
else:
have_events = {}
have_events.update({
e.event_id: ""
for e in auth_events.values()
})
seen_events = set(have_events.keys()) seen_events = set(have_events.keys())
missing_auth = event_auth_events - seen_events missing_auth = event_auth_events - seen_events - current_state
if missing_auth: if missing_auth:
logger.info("Missing auth: %s", missing_auth) logger.info("Missing auth: %s", missing_auth)

View File

@ -44,7 +44,7 @@ class IdentityHandler(BaseHandler):
http_client = SimpleHttpClient(self.hs) http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable! # XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090'] # trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org'] trustedIdServers = ['matrix.org', 'vector.im']
if 'id_server' in creds: if 'id_server' in creds:
id_server = creds['id_server'] id_server = creds['id_server']
@ -117,3 +117,28 @@ class IdentityHandler(BaseHandler):
except CodeMessageException as e: except CodeMessageException as e:
data = json.loads(e.msg) data = json.loads(e.msg)
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
yield run_on_reactor()
http_client = SimpleHttpClient(self.hs)
params = {
'email': email,
'client_secret': client_secret,
'send_attempt': send_attempt,
}
params.update(kwargs)
try:
data = yield http_client.post_urlencoded_get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/validate/email/requestToken"
),
params
)
defer.returnValue(data)
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e

View File

@ -1,83 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes
import bcrypt
import logging
logger = logging.getLogger(__name__)
class LoginHandler(BaseHandler):
def __init__(self, hs):
super(LoginHandler, self).__init__(hs)
self.hs = hs
@defer.inlineCallbacks
def login(self, user, password):
"""Login as the specified user with the specified password.
Args:
user (str): The user ID.
password (str): The password.
Returns:
The newly allocated access token.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
# TODO do this better, it can't go in __init__ else it cyclic loops
if not hasattr(self, "reg_handler"):
self.reg_handler = self.hs.get_handlers().registration_handler
# pull out the hash for this user if they exist
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash):
# generate an access token and store it.
token = self.reg_handler._generate_token(user)
logger.info("Adding token %s for user %s", token, user)
yield self.store.add_access_token_to_user(user, token)
defer.returnValue(token)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, token_id=None):
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
user_id, token_id
)
yield self.store.flush_user(user_id)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
yield self.store.user_add_threepid(
user_id, medium, address, validated_at,
self.hs.get_clock().time_msec()
)

View File

@ -113,11 +113,21 @@ class MessageHandler(BaseHandler):
"room_key", next_key "room_key", next_key
) )
if not events:
defer.returnValue({
"chunk": [],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
})
events = yield self._filter_events_for_client(user_id, room_id, events)
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
chunk = { chunk = {
"chunk": [ "chunk": [
serialize_event(e, time_now, as_client_event) for e in events serialize_event(e, time_now, as_client_event)
for e in events
], ],
"start": pagin_config.from_token.to_string(), "start": pagin_config.from_token.to_string(),
"end": next_token.to_string(), "end": next_token.to_string(),
@ -125,6 +135,52 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events):
event_id_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state):
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
defer.returnValue([
event
for event in events
if allowed(event, event_id_to_state[event.event_id])
])
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True, def create_and_send_event(self, event_dict, ratelimit=True,
client=None, txn_id=None): client=None, txn_id=None):
@ -278,6 +334,11 @@ class MessageHandler(BaseHandler):
user, pagination_config.get_source_config("presence"), None user, pagination_config.get_source_config("presence"), None
) )
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
public_room_ids = yield self.store.get_public_room_ids() public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit limit = pagin_config.limit
@ -316,6 +377,10 @@ class MessageHandler(BaseHandler):
] ]
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client(
user_id, event.room_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1]) end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
@ -336,15 +401,20 @@ class MessageHandler(BaseHandler):
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
yield defer.gatherResults( # Only do N rooms at once
[handle_room(e) for e in room_list], n = 5
consumeErrors=True d_list = [handle_room(e) for e in room_list]
).addErrback(unwrapFirstError) for i in range(0, len(d_list), n):
yield defer.gatherResults(
d_list[i:i + n],
consumeErrors=True
).addErrback(unwrapFirstError)
ret = { ret = {
"rooms": rooms_ret, "rooms": rooms_ret,
"presence": presence, "presence": presence,
"end": now_token.to_string() "receipts": receipt,
"end": now_token.to_string(),
} }
defer.returnValue(ret) defer.returnValue(ret)
@ -390,24 +460,21 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence(): def get_presence():
presence_defs = yield defer.DeferredList( states = yield presence_handler.get_states(
[ target_users=[UserID.from_string(m.user_id) for m in room_members],
presence_handler.get_state( auth_user=auth_user,
target_user=UserID.from_string(m.user_id), as_event=True,
auth_user=auth_user, check_auth=False,
as_event=True,
check_auth=False,
)
for m in room_members
],
consumeErrors=True,
) )
defer.returnValue([p for success, p in presence_defs if success]) defer.returnValue(states.values())
presence, (messages, token) = yield defer.gatherResults( receipts_handler = self.hs.get_handlers().receipts_handler
presence, receipts, (messages, token) = yield defer.gatherResults(
[ [
get_presence(), get_presence(),
receipts_handler.get_receipts_for_room(room_id, now_token.receipt_key),
self.store.get_recent_events_for_room( self.store.get_recent_events_for_room(
room_id, room_id,
limit=limit, limit=limit,
@ -417,6 +484,10 @@ class MessageHandler(BaseHandler):
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client(
user_id, room_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1]) end_token = now_token.copy_and_replace("room_key", token[1])
@ -431,5 +502,6 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(), "end": end_token.to_string(),
}, },
"state": state, "state": state,
"presence": presence "presence": presence,
"receipts": receipts,
}) })

View File

@ -192,6 +192,20 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state(self, target_user, auth_user, as_event=False, check_auth=True): def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
"""Get the current presence state of the given user.
Args:
target_user (UserID): The user whose presence we want
auth_user (UserID): The user requesting the presence, used for
checking if said user is allowed to see the persence of the
`target_user`
as_event (bool): Format the return as an event or not?
check_auth (bool): Perform the auth checks or not?
Returns:
dict: The presence state of the `target_user`, whose format depends
on the `as_event` argument.
"""
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
if check_auth: if check_auth:
visible = yield self.is_presence_visible( visible = yield self.is_presence_visible(
@ -232,6 +246,81 @@ class PresenceHandler(BaseHandler):
else: else:
defer.returnValue(state) defer.returnValue(state)
@defer.inlineCallbacks
def get_states(self, target_users, auth_user, as_event=False, check_auth=True):
"""A batched version of the `get_state` method that accepts a list of
`target_users`
Args:
target_users (list): The list of UserID's whose presence we want
auth_user (UserID): The user requesting the presence, used for
checking if said user is allowed to see the persence of the
`target_users`
as_event (bool): Format the return as an event or not?
check_auth (bool): Perform the auth checks or not?
Returns:
dict: A mapping from user -> presence_state
"""
local_users, remote_users = partitionbool(
target_users,
lambda u: self.hs.is_mine(u)
)
if check_auth:
for user in local_users:
visible = yield self.is_presence_visible(
observer_user=auth_user,
observed_user=user
)
if not visible:
raise SynapseError(404, "Presence information not visible")
results = {}
if local_users:
for user in local_users:
if user in self._user_cachemap:
results[user] = self._user_cachemap[user].get_state()
local_to_user = {u.localpart: u for u in local_users}
states = yield self.store.get_presence_states(
[u.localpart for u in local_users if u not in results]
)
for local_part, state in states.items():
if state is None:
continue
res = {"presence": state["state"]}
if "status_msg" in state and state["status_msg"]:
res["status_msg"] = state["status_msg"]
results[local_to_user[local_part]] = res
for user in remote_users:
# TODO(paul): Have remote server send us permissions set
results[user] = self._get_or_offline_usercache(user).get_state()
for state in results.values():
if "last_active" in state:
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
if as_event:
for user, state in results.items():
content = state
content["user_id"] = user.to_string()
if "last_active" in content:
content["last_active_ago"] = int(
self._clock.time_msec() - content.pop("last_active")
)
results[user] = {"type": "m.presence", "content": content}
defer.returnValue(results)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def set_state(self, target_user, auth_user, state): def set_state(self, target_user, auth_user, state):
@ -992,7 +1081,7 @@ class PresenceHandler(BaseHandler):
room_ids([str]): List of room_ids to notify. room_ids([str]): List of room_ids to notify.
""" """
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notifier.on_new_user_event( self.notifier.on_new_event(
"presence_key", "presence_key",
self._user_cachemap_latest_serial, self._user_cachemap_latest_serial,
users_to_push, users_to_push,

View File

@ -0,0 +1,210 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseHandler
from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
import logging
logger = logging.getLogger(__name__)
class ReceiptsHandler(BaseHandler):
def __init__(self, hs):
super(ReceiptsHandler, self).__init__(hs)
self.hs = hs
self.federation = hs.get_replication_layer()
self.federation.register_edu_handler(
"m.receipt", self._received_remote_receipt
)
self.clock = self.hs.get_clock()
self._receipt_cache = None
@defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id,
event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
receipt = {
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": [event_id],
"data": {
"ts": int(self.clock.time_msec()),
}
}
is_new = yield self._handle_new_receipts([receipt])
if is_new:
self._push_remotes([receipt])
@defer.inlineCallbacks
def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = [
{
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": user_values["event_ids"],
"data": user_values.get("data", {}),
}
for room_id, room_values in content.items()
for receipt_type, users in room_values.items()
for user_id, user_values in users.items()
]
yield self._handle_new_receipts(receipts)
@defer.inlineCallbacks
def _handle_new_receipts(self, receipts):
"""Takes a list of receipts, stores them and informs the notifier.
"""
for receipt in receipts:
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
res = yield self.store.insert_receipt(
room_id, receipt_type, user_id, event_ids, data
)
if not res:
# res will be None if this read receipt is 'old'
defer.returnValue(False)
stream_id, max_persisted_id = res
with PreserveLoggingContext():
self.notifier.on_new_event(
"receipt_key", max_persisted_id, rooms=[room_id]
)
defer.returnValue(True)
@defer.inlineCallbacks
def _push_remotes(self, receipts):
"""Given a list of receipts, works out which remote servers should be
poked and pokes them.
"""
# TODO: Some of this stuff should be coallesced.
for receipt in receipts:
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
remotedomains = set()
rm_handler = self.hs.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=None, remotedomains=remotedomains
)
logger.debug("Sending receipt to: %r", remotedomains)
for domain in remotedomains:
self.federation.send_edu(
destination=domain,
edu_type="m.receipt",
content={
room_id: {
receipt_type: {
user_id: {
"event_ids": event_ids,
"data": data,
}
}
},
},
)
@defer.inlineCallbacks
def get_receipts_for_room(self, room_id, to_key):
"""Gets all receipts for a room, upto the given key.
"""
result = yield self.store.get_linearized_receipts_for_room(
room_id,
to_key=to_key,
)
if not result:
defer.returnValue([])
event = {
"type": "m.receipt",
"room_id": room_id,
"content": result,
}
defer.returnValue([event])
class ReceiptEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key)
to_key = yield self.get_current_key()
if from_key == to_key:
defer.returnValue(([], to_key))
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms(
rooms,
from_key=from_key,
to_key=to_key,
)
defer.returnValue((events, to_key))
def get_current_key(self, direction='f'):
return self.store.get_max_receipt_stream_id()
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key)
if config.to_key:
from_key = int(config.to_key)
else:
from_key = None
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms(
rooms,
from_key=from_key,
to_key=to_key,
)
defer.returnValue((events, to_key))

View File

@ -57,8 +57,8 @@ class RegistrationHandler(BaseHandler):
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
u = yield self.store.get_user_by_id(user_id) users = yield self.store.get_users_by_id_case_insensitive(user_id)
if u: if users:
raise SynapseError( raise SynapseError(
400, 400,
"User ID already taken.", "User ID already taken.",
@ -73,7 +73,8 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None, localpart : The local part of the user ID to register. If None,
one will be randomly generated. one will be randomly generated.
password (str) : The password to assign to this user so they can password (str) : The password to assign to this user so they can
login again. login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
Returns: Returns:
A tuple of (user_id, access_token). A tuple of (user_id, access_token).
Raises: Raises:
@ -90,7 +91,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
token = self._generate_token(user_id) token = self.generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -110,7 +111,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id) yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id) token = self.generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -160,7 +161,7 @@ class RegistrationHandler(BaseHandler):
400, "Invalid user localpart for this application service.", 400, "Invalid user localpart for this application service.",
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
token = self._generate_token(user_id) token = self.generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
token=token, token=token,
@ -192,6 +193,35 @@ class RegistrationHandler(BaseHandler):
else: else:
logger.info("Valid captcha entered from %s", ip) logger.info("Valid captcha entered from %s", ip)
@defer.inlineCallbacks
def register_saml2(self, localpart):
"""
Registers email_id as SAML2 Based Auth.
"""
if urllib.quote(localpart) != localpart:
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self.generate_token(user_id)
try:
yield self.store.register(
user_id=user_id,
token=token,
password_hash=None
)
yield self.distributor.fire("registered_user", user)
except Exception, e:
yield self.store.add_access_token_to_user(user_id, token)
# Ignore Registration errors
logger.exception(e)
defer.returnValue((user_id, token))
@defer.inlineCallbacks @defer.inlineCallbacks
def register_email(self, threepidCreds): def register_email(self, threepidCreds):
""" """
@ -243,7 +273,7 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE errcode=Codes.EXCLUSIVE
) )
def _generate_token(self, user_id): def generate_token(self, user_id):
# urlsafe variant uses _ and - so use . as the separator and replace # urlsafe variant uses _ and - so use . as the separator and replace
# all =s with .s so http clients don't quote =s when it is used as # all =s with .s so http clients don't quote =s when it is used as
# query params. # query params.

View File

@ -19,12 +19,15 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset,
)
from synapse.api.errors import StoreError, SynapseError from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError from synapse.util import stringutils, unwrapFirstError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from collections import OrderedDict
import logging import logging
import string import string
@ -33,6 +36,19 @@ logger = logging.getLogger(__name__)
class RoomCreationHandler(BaseHandler): class RoomCreationHandler(BaseHandler):
PRESETS_DICT = {
RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": "invited",
"original_invitees_have_ops": False,
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
"history_visibility": "shared",
"original_invitees_have_ops": False,
},
}
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room(self, user_id, room_id, config): def create_room(self, user_id, room_id, config):
""" Creates a new room. """ Creates a new room.
@ -121,9 +137,25 @@ class RoomCreationHandler(BaseHandler):
servers=[self.hs.hostname], servers=[self.hs.hostname],
) )
preset_config = config.get(
"preset",
RoomCreationPreset.PUBLIC_CHAT
if is_public
else RoomCreationPreset.PRIVATE_CHAT
)
raw_initial_state = config.get("initial_state", [])
initial_state = OrderedDict()
for val in raw_initial_state:
initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
creation_events = self._create_events_for_new_room( creation_events = self._create_events_for_new_room(
user, room_id, is_public=is_public user, room_id,
preset_config=preset_config,
invite_list=invite_list,
initial_state=initial_state,
) )
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
@ -170,7 +202,10 @@ class RoomCreationHandler(BaseHandler):
defer.returnValue(result) defer.returnValue(result)
def _create_events_for_new_room(self, creator, room_id, is_public=False): def _create_events_for_new_room(self, creator, room_id, preset_config,
invite_list, initial_state):
config = RoomCreationHandler.PRESETS_DICT[preset_config]
creator_id = creator.to_string() creator_id = creator.to_string()
event_keys = { event_keys = {
@ -203,16 +238,20 @@ class RoomCreationHandler(BaseHandler):
}, },
) )
power_levels_event = create( returned_events = [creation_event, join_event]
etype=EventTypes.PowerLevels,
content={ if (EventTypes.PowerLevels, '') not in initial_state:
power_level_content = {
"users": { "users": {
creator.to_string(): 100, creator.to_string(): 100,
}, },
"users_default": 0, "users_default": 0,
"events": { "events": {
EventTypes.Name: 100, EventTypes.Name: 50,
EventTypes.PowerLevels: 100, EventTypes.PowerLevels: 100,
EventTypes.RoomHistoryVisibility: 100,
EventTypes.CanonicalAlias: 50,
EventTypes.RoomAvatar: 50,
}, },
"events_default": 0, "events_default": 0,
"state_default": 50, "state_default": 50,
@ -220,21 +259,43 @@ class RoomCreationHandler(BaseHandler):
"kick": 50, "kick": 50,
"redact": 50, "redact": 50,
"invite": 0, "invite": 0,
}, }
)
join_rule = JoinRules.PUBLIC if is_public else JoinRules.INVITE if config["original_invitees_have_ops"]:
join_rules_event = create( for invitee in invite_list:
etype=EventTypes.JoinRules, power_level_content["users"][invitee] = 100
content={"join_rule": join_rule},
)
return [ power_levels_event = create(
creation_event, etype=EventTypes.PowerLevels,
join_event, content=power_level_content,
power_levels_event, )
join_rules_event,
] returned_events.append(power_levels_event)
if (EventTypes.JoinRules, '') not in initial_state:
join_rules_event = create(
etype=EventTypes.JoinRules,
content={"join_rule": config["join_rules"]},
)
returned_events.append(join_rules_event)
if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
history_event = create(
etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]}
)
returned_events.append(history_event)
for (etype, state_key), content in initial_state.items():
returned_events.append(create(
etype=etype,
state_key=state_key,
content=content,
))
return returned_events
class RoomMemberHandler(BaseHandler): class RoomMemberHandler(BaseHandler):
@ -498,15 +559,9 @@ class RoomMemberHandler(BaseHandler):
"""Returns a list of roomids that the user has any of the given """Returns a list of roomids that the user has any of the given
membership states in.""" membership states in."""
app_service = yield self.store.get_app_service_by_user_id( rooms = yield self.store.get_rooms_for_user(
user.to_string() user.to_string(),
) )
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
else:
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates # For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should # TODO(paul): work out why because I really don't think it should

View File

@ -96,9 +96,18 @@ class SyncHandler(BaseHandler):
return self.current_sync_for_user(sync_config, since_token) return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string()
) )
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
result = yield self.notifier.wait_for_events( result = yield self.notifier.wait_for_events(
sync_config.user, room_ids, sync_config.user, room_ids,
sync_config.filter, timeout, current_sync_callback sync_config.filter, timeout, current_sync_callback
@ -229,7 +238,16 @@ class SyncHandler(BaseHandler):
logger.debug("Typing %r", typing_by_room) logger.debug("Typing %r", typing_by_room)
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(sync_config.user) app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
# TODO (mjark): Does public mean "published"? # TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True) published_rooms = yield self.store.get_rooms(is_public=True)
@ -292,6 +310,52 @@ class SyncHandler(BaseHandler):
next_batch=now_token, next_batch=now_token,
)) ))
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events):
event_id_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state):
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
defer.returnValue([
event
for event in events
if allowed(event, event_id_to_state[event.event_id])
])
@defer.inlineCallbacks @defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token, def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None): since_token=None):
@ -313,6 +377,9 @@ class SyncHandler(BaseHandler):
(room_key, _) = keys (room_key, _) = keys
end_key = "s" + room_key.split('-')[-1] end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_events(events) loaded_recents = sync_config.filter.filter_room_events(events)
loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), room_id, loaded_recents,
)
loaded_recents.extend(recents) loaded_recents.extend(recents)
recents = loaded_recents recents = loaded_recents
if len(events) <= load_limit: if len(events) <= load_limit:

View File

@ -204,21 +204,17 @@ class TypingNotificationHandler(BaseHandler):
) )
def _push_update_local(self, room_id, user, typing): def _push_update_local(self, room_id, user, typing):
if room_id not in self._room_serials: room_set = self._room_typing.setdefault(room_id, set())
self._room_serials[room_id] = 0
self._room_typing[room_id] = set()
room_set = self._room_typing[room_id]
if typing: if typing:
room_set.add(user) room_set.add(user)
elif user in room_set: else:
room_set.remove(user) room_set.discard(user)
self._latest_room_serial += 1 self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial self._room_serials[room_id] = self._latest_room_serial
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notifier.on_new_user_event( self.notifier.on_new_event(
"typing_key", self._latest_room_serial, rooms=[room_id] "typing_key", self._latest_room_serial, rooms=[room_id]
) )
@ -260,8 +256,8 @@ class TypingNotificationEventSource(object):
) )
events = [] events = []
for room_id in handler._room_serials: for room_id in joined_room_ids:
if room_id not in joined_room_ids: if room_id not in handler._room_serials:
continue continue
if handler._room_serials[room_id] <= from_key: if handler._room_serials[room_id] <= from_key:
continue continue

View File

@ -16,7 +16,7 @@
from twisted.internet import defer, reactor, protocol from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool from twisted.web.client import readBody, HTTPConnectionPool, Agent
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
@ -55,41 +55,17 @@ incoming_responses_counter = metrics.register_counter(
) )
class MatrixFederationHttpAgent(_AgentBase): class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.tls_context_factory = hs.tls_context_factory
def __init__(self, reactor, pool=None): def endpointForURI(self, uri):
_AgentBase.__init__(self, reactor, pool) destination = uri.netloc
def request(self, destination, endpoint, method, path, params, query, return matrix_federation_endpoint(
headers, body_producer): reactor, destination, timeout=10,
ssl_context_factory=self.tls_context_factory
outgoing_requests_counter.inc(method) )
host = b""
port = 0
fragment = b""
parsed_URI = _URI(b"http", destination, host, port, path, params,
query, fragment)
# Set the connection pool key to be the destination.
key = destination
d = self._requestWithEndpoint(key, endpoint, method, parsed_URI,
headers, body_producer,
parsed_URI.originForm)
def _cb(response):
incoming_responses_counter.inc(method, response.code)
return response
def _eb(failure):
incoming_responses_counter.inc(method, "ERR")
return failure
d.addCallbacks(_cb, _eb)
return d
class MatrixFederationHttpClient(object): class MatrixFederationHttpClient(object):
@ -107,12 +83,18 @@ class MatrixFederationHttpClient(object):
self.server_name = hs.hostname self.server_name = hs.hostname
pool = HTTPConnectionPool(reactor) pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10 pool.maxPersistentPerHost = 10
self.agent = MatrixFederationHttpAgent(reactor, pool=pool) self.agent = Agent.usingEndpointFactory(
reactor, MatrixFederationEndpointFactory(hs), pool=pool
)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.version_string = hs.version_string self.version_string = hs.version_string
self._next_id = 1 self._next_id = 1
def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
return urlparse.urlunparse(
("matrix", destination, path_bytes, param_bytes, query_bytes, "")
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes, def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"", body_callback, headers_dict={}, param_bytes=b"",
@ -123,8 +105,8 @@ class MatrixFederationHttpClient(object):
headers_dict[b"User-Agent"] = [self.version_string] headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination] headers_dict[b"Host"] = [destination]
url_bytes = urlparse.urlunparse( url_bytes = self._create_url(
("", "", path_bytes, param_bytes, query_bytes, "",) destination, path_bytes, param_bytes, query_bytes
) )
txn_id = "%s-O-%s" % (method, self._next_id) txn_id = "%s-O-%s" % (method, self._next_id)
@ -139,8 +121,8 @@ class MatrixFederationHttpClient(object):
# (once we have reliable transactions in place) # (once we have reliable transactions in place)
retries_left = 5 retries_left = 5
endpoint = preserve_context_over_fn( http_url_bytes = urlparse.urlunparse(
self._getEndpoint, reactor, destination ("", "", path_bytes, param_bytes, query_bytes, "")
) )
log_result = None log_result = None
@ -148,17 +130,14 @@ class MatrixFederationHttpClient(object):
while True: while True:
producer = None producer = None
if body_callback: if body_callback:
producer = body_callback(method, url_bytes, headers_dict) producer = body_callback(method, http_url_bytes, headers_dict)
try: try:
def send_request(): def send_request():
request_deferred = self.agent.request( request_deferred = preserve_context_over_fn(
destination, self.agent.request,
endpoint,
method, method,
path_bytes, url_bytes,
param_bytes,
query_bytes,
Headers(headers_dict), Headers(headers_dict),
producer producer
) )
@ -452,12 +431,6 @@ class MatrixFederationHttpClient(object):
defer.returnValue((length, headers)) defer.returnValue((length, headers))
def _getEndpoint(self, reactor, destination):
return matrix_federation_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.hs.tls_context_factory
)
class _ReadBodyToFileProtocol(protocol.Protocol): class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size): def __init__(self, stream, deferred, max_size):

View File

@ -207,7 +207,7 @@ class JsonResource(HttpServer, resource.Resource):
incoming_requests_counter.inc(request.method, servlet_classname) incoming_requests_counter.inc(request.method, servlet_classname)
args = [ args = [
urllib.unquote(u).decode("UTF-8") for u in m.groups() urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
] ]
callback_return = yield callback(request, *args) callback_return = yield callback(request, *args)

View File

@ -18,8 +18,12 @@ from __future__ import absolute_import
import logging import logging
from resource import getrusage, getpagesize, RUSAGE_SELF from resource import getrusage, getpagesize, RUSAGE_SELF
import functools
import os import os
import stat import stat
import time
from twisted.internet import reactor
from .metric import ( from .metric import (
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
@ -144,3 +148,50 @@ def _process_fds():
return counts return counts
get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"]) get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
reactor_metrics = get_metrics_for("reactor")
tick_time = reactor_metrics.register_distribution("tick_time")
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
def runUntilCurrentTimer(func):
@functools.wraps(func)
def f(*args, **kwargs):
now = reactor.seconds()
num_pending = 0
# _newTimedCalls is one long list of *all* pending calls. Below loop
# is based off of impl of reactor.runUntilCurrent
for delayed_call in reactor._newTimedCalls:
if delayed_call.time > now:
break
if delayed_call.delayed_time > 0:
continue
num_pending += 1
num_pending += len(reactor.threadCallQueue)
start = time.time() * 1000
ret = func(*args, **kwargs)
end = time.time() * 1000
tick_time.inc_by(end - start)
pending_calls_metric.inc_by(num_pending)
return ret
return f
try:
# Ensure the reactor has all the attributes we expect
reactor.runUntilCurrent
reactor._newTimedCalls
reactor.threadCallQueue
# runUntilCurrent is called when we have pending calls. It is called once
# per iteratation after fd polling.
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
except AttributeError:
pass

View File

@ -221,16 +221,7 @@ class Notifier(object):
event event
) )
room_id = event.room_id app_streams = set()
room_user_streams = self.room_to_user_streams.get(room_id, set())
user_streams = room_user_streams.copy()
for user in extra_users:
user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
user_streams.add(user_stream)
for appservice in self.appservice_to_user_streams: for appservice in self.appservice_to_user_streams:
# TODO (kegan): Redundant appservice listener checks? # TODO (kegan): Redundant appservice listener checks?
@ -242,24 +233,20 @@ class Notifier(object):
app_user_streams = self.appservice_to_user_streams.get( app_user_streams = self.appservice_to_user_streams.get(
appservice, set() appservice, set()
) )
user_streams |= app_user_streams app_streams |= app_user_streams
logger.debug("on_new_room_event listeners %s", user_streams) self.on_new_event(
"room_key", room_stream_id,
time_now_ms = self.clock.time_msec() users=extra_users,
for user_stream in user_streams: rooms=[event.room_id],
try: extra_streams=app_streams,
user_stream.notify( )
"room_key", "s%d" % (room_stream_id,), time_now_ms
)
except:
logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]): def on_new_event(self, stream_key, new_token, users=[], rooms=[],
""" Used to inform listeners that something has happend extra_streams=set()):
presence/user event wise. """ Used to inform listeners that something has happend event wise.
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
@ -283,7 +270,7 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, rooms, timeout, callback, def wait_for_events(self, user, rooms, timeout, callback,
from_token=StreamToken("s0", "0", "0")): from_token=StreamToken("s0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
@ -341,10 +328,13 @@ class Notifier(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events_for(self, user, rooms, pagination_config, timeout): def get_events_for(self, user, rooms, pagination_config, timeout,
only_room_events=False):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning. new events to happen before returning.
If `only_room_events` is `True` only room events will be returned.
""" """
from_token = pagination_config.from_token from_token = pagination_config.from_token
if not from_token: if not from_token:
@ -365,10 +355,12 @@ class Notifier(object):
after_id = getattr(after_token, keyname) after_id = getattr(after_token, keyname)
if before_id == after_id: if before_id == after_id:
continue continue
stuff, new_key = yield source.get_new_events_for_user( if only_room_events and name != "room":
continue
new_events, new_key = yield source.get_new_events_for_user(
user, getattr(from_token, keyname), limit, user, getattr(from_token, keyname), limit,
) )
events.extend(stuff) events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key) end_token = end_token.copy_and_replace(keyname, new_key)
if events: if events:

View File

@ -249,7 +249,9 @@ class Pusher(object):
# we fail to dispatch the push) # we fail to dispatch the push)
config = PaginationConfig(from_token=None, limit='1') config = PaginationConfig(from_token=None, limit='1')
chunk = yield self.evStreamHandler.get_stream( chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, timeout=0) self.user_name, config, timeout=0, affect_presence=False,
only_room_events=True
)
self.last_token = chunk['end'] self.last_token = chunk['end']
self.store.update_pusher_last_token( self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.user_name, self.last_token self.app_id, self.pushkey, self.user_name, self.last_token
@ -280,8 +282,8 @@ class Pusher(object):
config = PaginationConfig(from_token=from_tok, limit='1') config = PaginationConfig(from_token=from_tok, limit='1')
timeout = (300 + random.randint(-60, 60)) * 1000 timeout = (300 + random.randint(-60, 60)) * 1000
chunk = yield self.evStreamHandler.get_stream( chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, self.user_name, config, timeout=timeout, affect_presence=False,
timeout=timeout, affect_presence=False only_room_events=True
) )
# limiting to 1 may get 1 event plus 1 presence event, so # limiting to 1 may get 1 event plus 1 presence event, so
@ -294,6 +296,12 @@ class Pusher(object):
if not single_event: if not single_event:
self.last_token = chunk['end'] self.last_token = chunk['end']
logger.debug("Event stream timeout for pushkey %s", self.pushkey) logger.debug("Event stream timeout for pushkey %s", self.pushkey)
yield self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.user_name,
self.last_token
)
return return
if not self.alive: if not self.alive:
@ -345,7 +353,7 @@ class Pusher(object):
if processed: if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end'] self.last_token = chunk['end']
self.store.update_pusher_last_token_and_success( yield self.store.update_pusher_last_token_and_success(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name, self.user_name,
@ -354,7 +362,7 @@ class Pusher(object):
) )
if self.failing_since: if self.failing_since:
self.failing_since = None self.failing_since = None
self.store.update_pusher_failing_since( yield self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name, self.user_name,
@ -362,7 +370,7 @@ class Pusher(object):
else: else:
if not self.failing_since: if not self.failing_since:
self.failing_since = self.clock.time_msec() self.failing_since = self.clock.time_msec()
self.store.update_pusher_failing_since( yield self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name, self.user_name,
@ -380,7 +388,7 @@ class Pusher(object):
self.user_name, self.pushkey) self.user_name, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end'] self.last_token = chunk['end']
self.store.update_pusher_last_token( yield self.store.update_pusher_last_token(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name, self.user_name,
@ -388,7 +396,7 @@ class Pusher(object):
) )
self.failing_since = None self.failing_since = None
self.store.update_pusher_failing_since( yield self.store.update_pusher_failing_since(
self.app_id, self.app_id,
self.pushkey, self.pushkey,
self.user_name, self.user_name,

View File

@ -164,7 +164,7 @@ def make_base_append_underride_rules(user):
] ]
}, },
{ {
'rule_id': 'global/override/.m.rule.contains_display_name', 'rule_id': 'global/underride/.m.rule.contains_display_name',
'conditions': [ 'conditions': [
{ {
'kind': 'contains_display_name' 'kind': 'contains_display_name'

View File

@ -94,17 +94,14 @@ class PusherPool:
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name']) self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks @defer.inlineCallbacks
def remove_pushers_by_user_access_token(self, user_id, not_access_token_id): def remove_pushers_by_user(self, user_id):
all = yield self.store.get_all_pushers() all = yield self.store.get_all_pushers()
logger.info( logger.info(
"Removing all pushers for user %s except access token %s", "Removing all pushers for user %s",
user_id, not_access_token_id user_id,
) )
for p in all: for p in all:
if ( if p['user_name'] == user_id:
p['user_name'] == user_id and
p['access_token'] != not_access_token_id
):
logger.info( logger.info(
"Removing pusher for app id %s, pushkey %s, user %s", "Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name'] p['app_id'], p['pushkey'], p['user_name']

View File

@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"syutil>=0.0.7": ["syutil>=0.0.7"], "syutil>=0.0.7": ["syutil>=0.0.7"],
"Twisted==14.0.2": ["twisted==14.0.2"], "Twisted>=15.1.0": ["twisted>=15.1.0"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"], "pyyaml": ["yaml"],
@ -31,6 +31,8 @@ REQUIREMENTS = {
"pillow": ["PIL"], "pillow": ["PIL"],
"pydenticon": ["pydenticon"], "pydenticon": ["pydenticon"],
"ujson": ["ujson"], "ujson": ["ujson"],
"blist": ["blist"],
"pysaml2": ["saml2"],
} }
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
"web_client": { "web_client": {
@ -41,8 +43,8 @@ CONDITIONAL_REQUIREMENTS = {
def requirements(config=None, include_conditional=False): def requirements(config=None, include_conditional=False):
reqs = REQUIREMENTS.copy() reqs = REQUIREMENTS.copy()
for key, req in CONDITIONAL_REQUIREMENTS.items(): if include_conditional:
if (config and getattr(config, key)) or include_conditional: for _, req in CONDITIONAL_REQUIREMENTS.items():
reqs.update(req) reqs.update(req)
return reqs return reqs
@ -50,18 +52,18 @@ def requirements(config=None, include_conditional=False):
def github_link(project, version, egg): def github_link(project, version, egg):
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg) return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
DEPENDENCY_LINKS = [ DEPENDENCY_LINKS = {
github_link( "syutil": github_link(
project="matrix-org/syutil", project="matrix-org/syutil",
version="v0.0.7", version="v0.0.7",
egg="syutil-0.0.7", egg="syutil-0.0.7",
), ),
github_link( "matrix-angular-sdk": github_link(
project="matrix-org/matrix-angular-sdk", project="matrix-org/matrix-angular-sdk",
version="v0.6.6", version="v0.6.6",
egg="matrix_angular_sdk-0.6.6", egg="matrix_angular_sdk-0.6.6",
), ),
] }
class MissingRequirementError(Exception): class MissingRequirementError(Exception):
@ -129,7 +131,7 @@ def check_requirements(config=None):
def list_requirements(): def list_requirements():
result = [] result = []
linked = [] linked = []
for link in DEPENDENCY_LINKS: for link in DEPENDENCY_LINKS.values():
egg = link.split("#egg=")[1] egg = link.split("#egg=")[1]
linked.append(egg.split('-')[0]) linked.append(egg.split('-')[0])
result.append(link) result.append(link)

View File

@ -20,14 +20,32 @@ from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_pattern
import simplejson as json import simplejson as json
import urllib
import logging
from saml2 import BINDING_HTTP_POST
from saml2 import config
from saml2.client import Saml2Client
logger = logging.getLogger(__name__)
class LoginRestServlet(ClientV1RestServlet): class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$") PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password" PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2"
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.saml2_enabled = hs.config.saml2_enabled
def on_GET(self, request): def on_GET(self, request):
return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]}) flows = [{"type": LoginRestServlet.PASS_TYPE}]
if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SAML2_TYPE})
return (200, {"flows": flows})
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
return (200, {}) return (200, {})
@ -39,6 +57,16 @@ class LoginRestServlet(ClientV1RestServlet):
if login_submission["type"] == LoginRestServlet.PASS_TYPE: if login_submission["type"] == LoginRestServlet.PASS_TYPE:
result = yield self.do_password_login(login_submission) result = yield self.do_password_login(login_submission)
defer.returnValue(result) defer.returnValue(result)
elif self.saml2_enabled and (login_submission["type"] ==
LoginRestServlet.SAML2_TYPE):
relay_state = ""
if "relay_state" in login_submission:
relay_state = "&RelayState="+urllib.quote(
login_submission["relay_state"])
result = {
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
}
defer.returnValue((200, result))
else: else:
raise SynapseError(400, "Bad login type.") raise SynapseError(400, "Bad login type.")
except KeyError: except KeyError:
@ -46,17 +74,24 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def do_password_login(self, login_submission): def do_password_login(self, login_submission):
if not login_submission["user"].startswith('@'): if 'medium' in login_submission and 'address' in login_submission:
login_submission["user"] = UserID.create( user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
login_submission["user"], self.hs.hostname).to_string() login_submission['medium'], login_submission['address']
)
else:
user_id = login_submission['user']
handler = self.handlers.login_handler if not user_id.startswith('@'):
token = yield handler.login( user_id = UserID.create(
user=login_submission["user"], user_id, self.hs.hostname
).to_string()
user_id, token = yield self.handlers.auth_handler.login_with_password(
user_id=user_id,
password=login_submission["password"]) password=login_submission["password"])
result = { result = {
"user_id": login_submission["user"], # may have changed "user_id": user_id, # may have changed
"access_token": token, "access_token": token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }
@ -94,6 +129,49 @@ class PasswordResetRestServlet(ClientV1RestServlet):
) )
class SAML2RestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/saml2")
def __init__(self, hs):
super(SAML2RestServlet, self).__init__(hs)
self.sp_config = hs.config.saml2_config_path
@defer.inlineCallbacks
def on_POST(self, request):
saml2_auth = None
try:
conf = config.SPConfig()
conf.load_file(self.sp_config)
SP = Saml2Client(conf)
saml2_auth = SP.parse_authn_request_response(
request.args['SAMLResponse'][0], BINDING_HTTP_POST)
except Exception, e: # Not authenticated
logger.exception(e)
if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
username = saml2_auth.name_id.text
handler = self.handlers.registration_handler
(user_id, token) = yield handler.register_saml2(username)
# Forward to the RelayState callback along with ava
if 'RelayState' in request.args:
request.redirect(urllib.unquote(
request.args['RelayState'][0]) +
'?status=authenticated&access_token=' +
token + '&user_id=' + user_id + '&ava=' +
urllib.quote(json.dumps(saml2_auth.ava)))
request.finish()
defer.returnValue(None)
defer.returnValue((200, {"status": "authenticated",
"user_id": user_id, "token": token,
"ava": saml2_auth.ava}))
elif 'RelayState' in request.args:
request.redirect(urllib.unquote(
request.args['RelayState'][0]) +
'?status=not_authenticated')
request.finish()
defer.returnValue(None)
defer.returnValue((200, {"status": "not_authenticated"}))
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -106,4 +184,6 @@ def _parse_json(request):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server) LoginRestServlet(hs).register(http_server)
if hs.config.saml2_enabled:
SAML2RestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server)

View File

@ -412,6 +412,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
if "user_id" not in content: if "user_id" not in content:
raise SynapseError(400, "Missing user_id key.") raise SynapseError(400, "Missing user_id key.")
state_key = content["user_id"] state_key = content["user_id"]
# make sure it looks like a user ID; it'll throw if it's invalid.
UserID.from_string(state_key)
if membership_action == "kick": if membership_action == "kick":
membership_action = "leave" membership_action = "leave"

View File

@ -18,7 +18,9 @@ from . import (
filter, filter,
account, account,
register, register,
auth auth,
receipts,
keys,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -38,3 +40,5 @@ class ClientV2AlphaRestResource(JsonResource):
account.register_servlets(hs, client_resource) account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource) register.register_servlets(hs, client_resource)
auth.register_servlets(hs, client_resource) auth.register_servlets(hs, client_resource)
receipts.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource)

View File

@ -36,7 +36,6 @@ class PasswordRestServlet(RestServlet):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_handlers().auth_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -47,7 +46,7 @@ class PasswordRestServlet(RestServlet):
authed, result, params = yield self.auth_handler.check_auth([ authed, result, params = yield self.auth_handler.check_auth([
[LoginType.PASSWORD], [LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY]
], body) ], body, self.hs.get_ip_from_request(request))
if not authed: if not authed:
defer.returnValue((401, result)) defer.returnValue((401, result))
@ -79,8 +78,8 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(400, "", Codes.MISSING_PARAM) raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password'] new_password = params['new_password']
yield self.login_handler.set_password( yield self.auth_handler.set_password(
user_id, new_password, None user_id, new_password
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -95,7 +94,6 @@ class ThreepidRestServlet(RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(ThreepidRestServlet, self).__init__() super(ThreepidRestServlet, self).__init__()
self.hs = hs self.hs = hs
self.login_handler = hs.get_handlers().login_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth() self.auth = hs.get_auth()
@ -135,7 +133,7 @@ class ThreepidRestServlet(RestServlet):
logger.warn("Couldn't add 3pid: invalid response from ID sevrer") logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
raise SynapseError(500, "Invalid response from ID Server") raise SynapseError(500, "Invalid response from ID Server")
yield self.login_handler.add_threepid( yield self.auth_handler.add_threepid(
auth_user.to_string(), auth_user.to_string(),
threepid['medium'], threepid['medium'],
threepid['address'], threepid['address'],

View File

@ -0,0 +1,316 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.types import UserID
from syutil.jsonutil import encode_canonical_json
from ._base import client_v2_pattern
import simplejson as json
import logging
logger = logging.getLogger(__name__)
class KeyUploadServlet(RestServlet):
"""
POST /keys/upload/<device_id> HTTP/1.1
Content-Type: application/json
{
"device_keys": {
"user_id": "<user_id>",
"device_id": "<device_id>",
"valid_until_ts": <millisecond_timestamp>,
"algorithms": [
"m.olm.curve25519-aes-sha256",
]
"keys": {
"<algorithm>:<device_id>": "<key_base64>",
},
"signatures:" {
"<user_id>" {
"<algorithm>:<device_id>": "<signature_base64>"
} } },
"one_time_keys": {
"<algorithm>:<key_id>": "<key_base64>"
},
}
"""
PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)")
def __init__(self, hs):
super(KeyUploadServlet, self).__init__()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request, device_id):
auth_user, client_info = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string()
# TODO: Check that the device_id matches that in the authentication
# or derive the device_id from the authentication instead.
try:
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
device_keys = body.get("device_keys", None)
if device_keys:
logger.info(
"Updating device_keys for device %r for user %r at %d",
device_id, auth_user, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
)
one_time_keys = body.get("one_time_keys", None)
if one_time_keys:
logger.info(
"Adding %d one_time_keys for device %r for user %r at %d",
len(one_time_keys), device_id, user_id, time_now
)
key_list = []
for key_id, key_json in one_time_keys.items():
algorithm, key_id = key_id.split(":")
key_list.append((
algorithm, key_id, encode_canonical_json(key_json)
))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, key_list
)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
@defer.inlineCallbacks
def on_GET(self, request, device_id):
auth_user, client_info = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string()
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue((200, {"one_time_key_counts": result}))
class KeyQueryServlet(RestServlet):
"""
GET /keys/query/<user_id> HTTP/1.1
GET /keys/query/<user_id>/<device_id> HTTP/1.1
POST /keys/query HTTP/1.1
Content-Type: application/json
{
"device_keys": {
"<user_id>": ["<device_id>"]
} }
HTTP/1.1 200 OK
{
"device_keys": {
"<user_id>": {
"<device_id>": {
"user_id": "<user_id>", // Duplicated to be signed
"device_id": "<device_id>", // Duplicated to be signed
"valid_until_ts": <millisecond_timestamp>,
"algorithms": [ // List of supported algorithms
"m.olm.curve25519-aes-sha256",
],
"keys": { // Must include a ed25519 signing key
"<algorithm>:<key_id>": "<key_base64>",
},
"signatures:" {
// Must be signed with device's ed25519 key
"<user_id>/<device_id>": {
"<algorithm>:<key_id>": "<signature_base64>"
}
// Must be signed by this server.
"<server_name>": {
"<algorithm>:<key_id>": "<signature_base64>"
} } } } } }
"""
PATTERN = client_v2_pattern(
"/keys/query(?:"
"/(?P<user_id>[^/]*)(?:"
"/(?P<device_id>[^/]*)"
")?"
")?"
)
def __init__(self, hs):
super(KeyQueryServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id):
yield self.auth.get_user_by_req(request)
try:
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
result = yield self.handle_request(body)
defer.returnValue(result)
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id):
auth_user, client_info = yield self.auth.get_user_by_req(request)
auth_user_id = auth_user.to_string()
user_id = user_id if user_id else auth_user_id
device_ids = [device_id] if device_id else []
result = yield self.handle_request(
{"device_keys": {user_id: device_ids}}
)
defer.returnValue(result)
@defer.inlineCallbacks
def handle_request(self, body):
local_query = []
remote_queries = {}
for user_id, device_ids in body.get("device_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
if not device_ids:
local_query.append((user_id, None))
else:
for device_id in device_ids:
local_query.append((user_id, device_id))
else:
remote_queries.setdefault(user.domain, {})[user_id] = list(
device_ids
)
results = yield self.store.get_e2e_device_keys(local_query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.query_client_keys(
destination, {"device_keys": device_keys}
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
defer.returnValue((200, {"device_keys": json_result}))
class OneTimeKeyServlet(RestServlet):
"""
GET /keys/claim/<user-id>/<device-id>/<algorithm> HTTP/1.1
POST /keys/claim HTTP/1.1
{
"one_time_keys": {
"<user_id>": {
"<device_id>": "<algorithm>"
} } }
HTTP/1.1 200 OK
{
"one_time_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
} } } }
"""
PATTERN = client_v2_pattern(
"/keys/claim(?:/?|(?:/"
"(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
")?)"
)
def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.federation = hs.get_replication_layer()
self.is_mine = hs.is_mine
@defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
result = yield self.handle_request(
{"one_time_keys": {user_id: {device_id: algorithm}}}
)
defer.returnValue(result)
@defer.inlineCallbacks
def on_POST(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request)
try:
body = json.loads(request.content.read())
except:
raise SynapseError(400, "Invalid key JSON")
result = yield self.handle_request(body)
defer.returnValue(result)
@defer.inlineCallbacks
def handle_request(self, body):
local_query = []
remote_queries = {}
for user_id, device_keys in body.get("one_time_keys", {}).items():
user = UserID.from_string(user_id)
if self.is_mine(user):
for device_id, algorithm in device_keys.items():
local_query.append((user_id, device_id, algorithm))
else:
remote_queries.setdefault(user.domain, {})[user_id] = (
device_keys
)
results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
for destination, device_keys in remote_queries.items():
remote_result = yield self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
defer.returnValue((200, {"one_time_keys": json_result}))
def register_servlets(hs, http_server):
KeyUploadServlet(hs).register(http_server)
KeyQueryServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)

View File

@ -0,0 +1,55 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern
import logging
logger = logging.getLogger(__name__)
class ReceiptRestServlet(RestServlet):
PATTERN = client_v2_pattern(
"/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
super(ReceiptRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.receipts_handler = hs.get_handlers().receipts_handler
@defer.inlineCallbacks
def on_POST(self, request, room_id, receipt_type, event_id):
user, client = yield self.auth.get_user_by_req(request)
yield self.receipts_handler.received_client_receipt(
room_id,
receipt_type,
user_id=user.to_string(),
event_id=event_id
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
ReceiptRestServlet(hs).register(http_server)

View File

@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_request_allow_empty from ._base import client_v2_pattern, parse_json_dict_from_request
import logging import logging
import hmac import hmac
@ -50,26 +50,64 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_handlers().auth_handler self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
yield run_on_reactor() yield run_on_reactor()
body = parse_request_allow_empty(request) if '/register/email/requestToken' in request.path:
if 'password' not in body: ret = yield self.onEmailTokenRequest(request)
raise SynapseError(400, "", Codes.MISSING_PARAM) defer.returnValue(ret)
body = parse_json_dict_from_request(request)
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
desired_password = None
if 'password' in body:
if (not isinstance(body['password'], basestring) or
len(body['password']) > 512):
raise SynapseError(400, "Invalid password")
desired_password = body["password"]
desired_username = None
if 'username' in body: if 'username' in body:
if (not isinstance(body['username'], basestring) or
len(body['username']) > 512):
raise SynapseError(400, "Invalid username")
desired_username = body['username'] desired_username = body['username']
yield self.registration_handler.check_username(desired_username)
is_using_shared_secret = False appservice = None
is_application_server = False
service = None
if 'access_token' in request.args: if 'access_token' in request.args:
service = yield self.auth.get_appservice_by_req(request) appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
# have completely different registration flows to normal users
# == Application Service Registration ==
if appservice:
result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0]
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
# == Shared Secret Registration == (e.g. create new user scripts)
if 'mac' in body:
# FIXME: Should we really be determining if this is shared secret
# auth based purely on the 'mac' key?
result = yield self._do_shared_secret_registration(
desired_username, desired_password, body["mac"]
)
defer.returnValue((200, result)) # we throw for non 200 responses
return
# == Normal User Registration == (everyone else)
if self.hs.config.disable_registration:
raise SynapseError(403, "Registration has been disabled")
if desired_username is not None:
yield self.registration_handler.check_username(desired_username)
if self.hs.config.enable_registration_captcha: if self.hs.config.enable_registration_captcha:
flows = [ flows = [
@ -82,39 +120,20 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY]
] ]
result = None authed, result, params = yield self.auth_handler.check_auth(
if service: flows, body, self.hs.get_ip_from_request(request)
is_application_server = True
params = body
elif 'mac' in body:
# Check registration-specific shared secret auth
if 'username' not in body:
raise SynapseError(400, "", Codes.MISSING_PARAM)
self._check_shared_secret_auth(
body['username'], body['mac']
)
is_using_shared_secret = True
params = body
else:
authed, result, params = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
if not authed:
defer.returnValue((401, result))
can_register = (
not self.hs.config.disable_registration
or is_application_server
or is_using_shared_secret
) )
if not can_register:
raise SynapseError(403, "Registration has been disabled")
if not authed:
defer.returnValue((401, result))
return
# NB: This may be from the auth handler and NOT from the POST
if 'password' not in params: if 'password' not in params:
raise SynapseError(400, "", Codes.MISSING_PARAM) raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
desired_username = params['username'] if 'username' in params else None
new_password = params['password'] desired_username = params.get("username", None)
new_password = params.get("password", None)
(user_id, token) = yield self.registration_handler.register( (user_id, token) = yield self.registration_handler.register(
localpart=desired_username, localpart=desired_username,
@ -128,7 +147,7 @@ class RegisterRestServlet(RestServlet):
if reqd not in threepid: if reqd not in threepid:
logger.info("Can't add incomplete 3pid") logger.info("Can't add incomplete 3pid")
else: else:
yield self.login_handler.add_threepid( yield self.auth_handler.add_threepid(
user_id, user_id,
threepid['medium'], threepid['medium'],
threepid['address'], threepid['address'],
@ -147,18 +166,21 @@ class RegisterRestServlet(RestServlet):
else: else:
logger.info("bind_email not specified: not binding email") logger.info("bind_email not specified: not binding email")
result = { result = self._create_registration_details(user_id, token)
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result)) defer.returnValue((200, result))
def on_OPTIONS(self, _): def on_OPTIONS(self, _):
return 200, {} return 200, {}
def _check_shared_secret_auth(self, username, mac): @defer.inlineCallbacks
def _do_appservice_registration(self, username, as_token):
(user_id, token) = yield self.registration_handler.appservice_register(
username, as_token
)
defer.returnValue(self._create_registration_details(user_id, token))
@defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, mac):
if not self.hs.config.registration_shared_secret: if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled") raise SynapseError(400, "Shared secret registration is not enabled")
@ -174,13 +196,46 @@ class RegisterRestServlet(RestServlet):
digestmod=sha1, digestmod=sha1,
).hexdigest() ).hexdigest()
if compare_digest(want_mac, got_mac): if not compare_digest(want_mac, got_mac):
return True
else:
raise SynapseError( raise SynapseError(
403, "HMAC incorrect", 403, "HMAC incorrect",
) )
(user_id, token) = yield self.registration_handler.register(
localpart=username, password=password
)
defer.returnValue(self._create_registration_details(user_id, token))
def _create_registration_details(self, user_id, token):
return {
"user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname,
}
@defer.inlineCallbacks
def onEmailTokenRequest(self, request):
body = parse_json_dict_from_request(request)
required = ['id_server', 'client_secret', 'email', 'send_attempt']
absent = []
for k in required:
if k not in body:
absent.append(k)
if len(absent) > 0:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
if existingUid is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
defer.returnValue((200, ret))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server)

View File

@ -27,18 +27,30 @@ from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from synapse.util.async import ObservableDeferred from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
import os import os
import cgi
import logging import logging
import urllib
import urlparse
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def parse_media_id(request): def parse_media_id(request):
try: try:
server_name, media_id = request.postpath # This allows users to append e.g. /test.png to the URL. Useful for
return (server_name, media_id) # clients that parse the URL to see content type.
server_name, media_id = request.postpath[:2]
file_name = None
if len(request.postpath) > 2:
try:
file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
except UnicodeDecodeError:
pass
return server_name, media_id, file_name
except: except:
raise SynapseError( raise SynapseError(
404, 404,
@ -62,6 +74,8 @@ class BaseMediaResource(Resource):
self.filepaths = filepaths self.filepaths = filepaths
self.version_string = hs.version_string self.version_string = hs.version_string
self.downloads = {} self.downloads = {}
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
def _respond_404(self, request): def _respond_404(self, request):
respond_with_json( respond_with_json(
@ -128,12 +142,38 @@ class BaseMediaResource(Resource):
media_type = headers["Content-Type"][0] media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get("filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith("utf-8''"):
upload_name = upload_name_utf8[7:]
# If there isn't check for an ascii name.
if not upload_name:
upload_name_ascii = params.get("filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
upload_name = upload_name_ascii
if upload_name:
upload_name = urlparse.unquote(upload_name)
try:
upload_name = upload_name.decode("utf-8")
except UnicodeDecodeError:
upload_name = None
else:
upload_name = None
yield self.store.store_cached_remote_media( yield self.store.store_cached_remote_media(
origin=server_name, origin=server_name,
media_id=media_id, media_id=media_id,
media_type=media_type, media_type=media_type,
time_now_ms=self.clock.time_msec(), time_now_ms=self.clock.time_msec(),
upload_name=None, upload_name=upload_name,
media_length=length, media_length=length,
filesystem_id=file_id, filesystem_id=file_id,
) )
@ -144,7 +184,7 @@ class BaseMediaResource(Resource):
media_info = { media_info = {
"media_type": media_type, "media_type": media_type,
"media_length": length, "media_length": length,
"upload_name": None, "upload_name": upload_name,
"created_ts": time_now_ms, "created_ts": time_now_ms,
"filesystem_id": file_id, "filesystem_id": file_id,
} }
@ -157,11 +197,26 @@ class BaseMediaResource(Resource):
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_with_file(self, request, media_type, file_path, def _respond_with_file(self, request, media_type, file_path,
file_size=None): file_size=None, upload_name=None):
logger.debug("Responding with %r", file_path) logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path): if os.path.isfile(file_path):
request.setHeader(b"Content-Type", media_type.encode("UTF-8")) request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
if upload_name:
if is_ascii(upload_name):
request.setHeader(
b"Content-Disposition",
b"inline; filename=%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
else:
request.setHeader(
b"Content-Disposition",
b"inline; filename*=utf-8''%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
# cache for at least a day. # cache for at least a day.
# XXX: we might want to turn this off for data we don't want to # XXX: we might want to turn this off for data we don't want to
@ -187,22 +242,74 @@ class BaseMediaResource(Resource):
self._respond_404(request) self._respond_404(request)
def _get_thumbnail_requirements(self, media_type): def _get_thumbnail_requirements(self, media_type):
if media_type == "image/jpeg": return self.thumbnail_requirements.get(media_type, ())
return (
(32, 32, "crop", "image/jpeg"), def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
(96, 96, "crop", "image/jpeg"), t_method, t_type):
(320, 240, "scale", "image/jpeg"), thumbnailer = Thumbnailer(input_path)
(640, 480, "scale", "image/jpeg"), m_width = thumbnailer.width
) m_height = thumbnailer.height
elif (media_type == "image/png") or (media_type == "image/gif"):
return ( if m_width * m_height >= self.max_image_pixels:
(32, 32, "crop", "image/png"), logger.info(
(96, 96, "crop", "image/png"), "Image too large to thumbnail %r x %r > %r",
(320, 240, "scale", "image/png"), m_width, m_height, self.max_image_pixels
(640, 480, "scale", "image/png"),
) )
return
if t_method == "crop":
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
elif t_method == "scale":
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
else: else:
return () t_len = None
return t_len
@defer.inlineCallbacks
def _generate_local_exact_thumbnail(self, media_id, t_width, t_height,
t_method, t_type):
input_path = self.filepaths.local_media_filepath(media_id)
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = yield threads.deferToThread(
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
if t_len:
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
defer.returnValue(t_path)
@defer.inlineCallbacks
def _generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type):
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = yield threads.deferToThread(
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
if t_len:
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
defer.returnValue(t_path)
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_local_thumbnails(self, media_id, media_info): def _generate_local_thumbnails(self, media_id, media_info):
@ -223,43 +330,52 @@ class BaseMediaResource(Resource):
) )
return return
scales = set() local_thumbnails = []
crops = set()
for r_width, r_height, r_method, r_type in requirements: def generate_thumbnails():
if r_method == "scale": scales = set()
t_width, t_height = thumbnailer.aspect(r_width, r_height) crops = set()
scales.add(( for r_width, r_height, r_method, r_type in requirements:
min(m_width, t_width), min(m_height, t_height), r_type, if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
)) ))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales: for t_width, t_height, t_type in crops:
t_method = "scale" if (t_width, t_height, t_type) in scales:
t_path = self.filepaths.local_media_thumbnail( # If the aspect ratio of the cropped thumbnail matches a purely
media_id, t_width, t_height, t_type, t_method # scaled one then there is no point in calculating a separate
) # thumbnail.
self._makedirs(t_path) continue
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) t_method = "crop"
yield self.store.store_local_thumbnail( t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len media_id, t_width, t_height, t_type, t_method
) )
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
for t_width, t_height, t_type in crops: yield threads.deferToThread(generate_thumbnails)
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely for l in local_thumbnails:
# scaled one then there is no point in calculating a separate yield self.store.store_local_thumbnail(*l)
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
defer.returnValue({ defer.returnValue({
"width": m_width, "width": m_width,

View File

@ -32,14 +32,16 @@ class DownloadResource(BaseMediaResource):
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_GET(self, request): def _async_render_GET(self, request):
server_name, media_id = parse_media_id(request) server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name: if server_name == self.server_name:
yield self._respond_local_file(request, media_id) yield self._respond_local_file(request, media_id, name)
else: else:
yield self._respond_remote_file(request, server_name, media_id) yield self._respond_remote_file(
request, server_name, media_id, name
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_local_file(self, request, media_id): def _respond_local_file(self, request, media_id, name):
media_info = yield self.store.get_local_media(media_id) media_info = yield self.store.get_local_media(media_id)
if not media_info: if not media_info:
self._respond_404(request) self._respond_404(request)
@ -47,24 +49,28 @@ class DownloadResource(BaseMediaResource):
media_type = media_info["media_type"] media_type = media_info["media_type"]
media_length = media_info["media_length"] media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.local_media_filepath(media_id) file_path = self.filepaths.local_media_filepath(media_id)
yield self._respond_with_file( yield self._respond_with_file(
request, media_type, file_path, media_length request, media_type, file_path, media_length,
upload_name=upload_name,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_remote_file(self, request, server_name, media_id): def _respond_remote_file(self, request, server_name, media_id, name):
media_info = yield self._get_remote_media(server_name, media_id) media_info = yield self._get_remote_media(server_name, media_id)
media_type = media_info["media_type"] media_type = media_info["media_type"]
media_length = media_info["media_length"] media_length = media_info["media_length"]
filesystem_id = media_info["filesystem_id"] filesystem_id = media_info["filesystem_id"]
upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.remote_media_filepath( file_path = self.filepaths.remote_media_filepath(
server_name, filesystem_id server_name, filesystem_id
) )
yield self._respond_with_file( yield self._respond_with_file(
request, media_type, file_path, media_length request, media_type, file_path, media_length,
upload_name=upload_name,
) )

View File

@ -36,21 +36,32 @@ class ThumbnailResource(BaseMediaResource):
@request_handler @request_handler
@defer.inlineCallbacks @defer.inlineCallbacks
def _async_render_GET(self, request): def _async_render_GET(self, request):
server_name, media_id = parse_media_id(request) server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width") width = parse_integer(request, "width")
height = parse_integer(request, "height") height = parse_integer(request, "height")
method = parse_string(request, "method", "scale") method = parse_string(request, "method", "scale")
m_type = parse_string(request, "type", "image/png") m_type = parse_string(request, "type", "image/png")
if server_name == self.server_name: if server_name == self.server_name:
yield self._respond_local_thumbnail( if self.dynamic_thumbnails:
request, media_id, width, height, method, m_type yield self._select_or_generate_local_thumbnail(
) request, media_id, width, height, method, m_type
)
else:
yield self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
)
else: else:
yield self._respond_remote_thumbnail( if self.dynamic_thumbnails:
request, server_name, media_id, yield self._select_or_generate_remote_thumbnail(
width, height, method, m_type request, server_name, media_id,
) width, height, method, m_type
)
else:
yield self._respond_remote_thumbnail(
request, server_name, media_id,
width, height, method, m_type
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height, def _respond_local_thumbnail(self, request, media_id, width, height,
@ -82,6 +93,87 @@ class ThumbnailResource(BaseMediaResource):
request, media_info, width, height, method, m_type, request, media_info, width, height, method, m_type,
) )
@defer.inlineCallbacks
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
desired_height, desired_method,
desired_type):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
self._respond_404(request)
return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
t_w = info["thumbnail_width"] == desired_width
t_h = info["thumbnail_height"] == desired_height
t_method = info["thumbnail_method"] == desired_method
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
file_path = self.filepaths.local_media_thumbnail(
media_id, desired_width, desired_height, desired_type, desired_method,
)
yield self._respond_with_file(request, desired_type, file_path)
return
logger.debug("We don't have a local thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self._generate_local_exact_thumbnail(
media_id, desired_width, desired_height, desired_method, desired_type
)
if file_path:
yield self._respond_with_file(request, desired_type, file_path)
else:
yield self._respond_default_thumbnail(
request, media_info, desired_width, desired_height,
desired_method, desired_type,
)
@defer.inlineCallbacks
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
desired_width, desired_height,
desired_method, desired_type):
media_info = yield self._get_remote_media(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
)
file_id = media_info["filesystem_id"]
for info in thumbnail_infos:
t_w = info["thumbnail_width"] == desired_width
t_h = info["thumbnail_height"] == desired_height
t_method = info["thumbnail_method"] == desired_method
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
file_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, desired_width, desired_height,
desired_type, desired_method,
)
yield self._respond_with_file(request, desired_type, file_path)
return
logger.debug("We don't have a local thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self._generate_remote_exact_thumbnail(
server_name, file_id, media_id, desired_width,
desired_height, desired_method, desired_type
)
if file_path:
yield self._respond_with_file(request, desired_type, file_path)
else:
yield self._respond_default_thumbnail(
request, media_info, desired_width, desired_height,
desired_method, desired_type,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _respond_remote_thumbnail(self, request, server_name, media_id, width, def _respond_remote_thumbnail(self, request, server_name, media_id, width,
height, method, m_type): height, method, m_type):
@ -162,11 +254,12 @@ class ThumbnailResource(BaseMediaResource):
t_method = info["thumbnail_method"] t_method = info["thumbnail_method"]
if t_method == "scale" or t_method == "crop": if t_method == "scale" or t_method == "crop":
aspect_quality = abs(d_w * t_h - d_h * t_w) aspect_quality = abs(d_w * t_h - d_h * t_w)
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h)) size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"] type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"] length_quality = info["thumbnail_length"]
info_list.append(( info_list.append((
aspect_quality, size_quality, type_quality, aspect_quality, min_quality, size_quality, type_quality,
length_quality, info length_quality, info
)) ))
if info_list: if info_list:

View File

@ -82,7 +82,7 @@ class Thumbnailer(object):
def save_image(self, output_image, output_type, output_path): def save_image(self, output_image, output_type, output_path):
output_bytes_io = BytesIO() output_bytes_io = BytesIO()
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=70) output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
output_bytes = output_bytes_io.getvalue() output_bytes = output_bytes_io.getvalue()
with open(output_path, "wb") as output_file: with open(output_path, "wb") as output_file:
output_file.write(output_bytes) output_file.write(output_bytes)

View File

@ -84,6 +84,16 @@ class UploadResource(BaseMediaResource):
code=413, code=413,
) )
upload_name = request.args.get("filename", None)
if upload_name:
try:
upload_name = upload_name[0].decode('UTF-8')
except UnicodeDecodeError:
raise SynapseError(
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
code=400,
)
headers = request.requestHeaders headers = request.requestHeaders
if headers.hasHeader("Content-Type"): if headers.hasHeader("Content-Type"):
@ -99,7 +109,7 @@ class UploadResource(BaseMediaResource):
# TODO(markjh): parse content-dispostion # TODO(markjh): parse content-dispostion
content_uri = yield self.create_content( content_uri = yield self.create_content(
media_type, None, request.content.read(), media_type, upload_name, request.content.read(),
content_length, auth_user content_length, auth_user
) )

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes from synapse.api.auth import AuthEventTypes
@ -96,7 +96,7 @@ class StateHandler(object):
cache.ts = self.clock.time_msec() cache.ts = self.clock.time_msec()
state = cache.state state = cache.state
else: else:
res = yield self.resolve_state_groups(event_ids) res = yield self.resolve_state_groups(room_id, event_ids)
state = res[1] state = res[1]
if event_type: if event_type:
@ -155,13 +155,13 @@ class StateHandler(object):
if event.is_state(): if event.is_state():
ret = yield self.resolve_state_groups( ret = yield self.resolve_state_groups(
[e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
event_type=event.type, event_type=event.type,
state_key=event.state_key, state_key=event.state_key,
) )
else: else:
ret = yield self.resolve_state_groups( ret = yield self.resolve_state_groups(
[e for e, _ in event.prev_events], event.room_id, [e for e, _ in event.prev_events],
) )
group, curr_state, prev_state = ret group, curr_state, prev_state = ret
@ -180,7 +180,7 @@ class StateHandler(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def resolve_state_groups(self, event_ids, event_type=None, state_key=""): def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
""" Given a list of event_ids this method fetches the state at each """ Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
@ -205,7 +205,7 @@ class StateHandler(object):
) )
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
event_ids room_id, event_ids
) )
logger.debug( logger.debug(

View File

@ -37,6 +37,9 @@ from .rejections import RejectionsStore
from .state import StateStore from .state import StateStore
from .signatures import SignatureStore from .signatures import SignatureStore
from .filtering import FilteringStore from .filtering import FilteringStore
from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore
import fnmatch import fnmatch
@ -51,7 +54,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 20 SCHEMA_VERSION = 22
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))
@ -74,6 +77,8 @@ class DataStore(RoomMemberStore, RoomStore,
PushRuleStore, PushRuleStore,
ApplicationServiceTransactionStore, ApplicationServiceTransactionStore,
EventsStore, EventsStore,
ReceiptsStore,
EndToEndKeyStore,
): ):
def __init__(self, hs): def __init__(self, hs):
@ -94,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
key = (user.to_string(), access_token, device_id, ip) key = (user.to_string(), access_token, device_id, ip)
try: try:
last_seen = self.client_ip_last_seen.get(*key) last_seen = self.client_ip_last_seen.get(key)
except KeyError: except KeyError:
last_seen = None last_seen = None
@ -102,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
defer.returnValue(None) defer.returnValue(None)
self.client_ip_last_seen.prefill(*key + (now,)) self.client_ip_last_seen.prefill(key, now)
# It's safe not to lock here: a) no unique constraint, # It's safe not to lock here: a) no unique constraint,
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
@ -349,6 +354,11 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
) )
logger.debug("Running script %s", relative_path) logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine) module.run_upgrade(cur, database_engine)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
pass
elif ext == ".sql": elif ext == ".sql":
# A plain old .sql file, just read and execute it # A plain old .sql file, just read and execute it
logger.debug("Applying schema %s", relative_path) logger.debug("Applying schema %s", relative_path)

View File

@ -17,21 +17,20 @@ import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
import synapse.metrics import synapse.metrics
from util.id_generators import IdGenerator, StreamIdGenerator from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple, OrderedDict from collections import namedtuple
import functools
import sys import sys
import time import time
import threading import threading
DEBUG_CACHES = False
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,159 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"]) sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"]) sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
caches_by_name = {}
cache_counter = metrics.register_cache(
"cache",
lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
labels=["name"],
)
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=False):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
self.name = name
self.keylen = keylen
self.sequence = 0
self.thread = None
caches_by_name[name] = self.cache
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, *keyargs):
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
if keyargs in self.cache:
cache_counter.inc_hits(self.name)
return self.cache[keyargs]
cache_counter.inc_misses(self.name)
raise KeyError()
def update(self, sequence, *args):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(*args)
def prefill(self, *args): # because I can't *keyargs, value
keyargs = args[:-1]
value = args[-1]
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[keyargs] = value
def invalidate(self, *keyargs):
self.check_thread()
if len(keyargs) != self.keylen:
raise ValueError("Expected a key to have %d items", self.keylen)
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(keyargs, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
The wrapped function has an additional member, a callable called
"invalidate". This can be used to remove individual entries from the cache.
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
self.orig = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
def __get__(self, obj, objtype=None):
cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
@functools.wraps(self.orig)
@defer.inlineCallbacks
def wrapped(*keyargs):
try:
cached_result = cache.get(*keyargs[:self.num_args])
if DEBUG_CACHES:
actual_result = yield self.orig(obj, *keyargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, keyargs,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = yield self.orig(obj, *keyargs)
cache.update(sequence, *keyargs[:self.num_args] + (ret,))
defer.returnValue(ret)
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
def cached(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
class LoggingTransaction(object): class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object """An object that almost-transparently proxies for the 'txn' object
@ -321,6 +167,8 @@ class SQLBaseStore(object):
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True, self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
max_entries=hs.config.event_cache_size) max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache("*stateGroupCache*", 2000)
self._event_fetch_lock = threading.Condition() self._event_fetch_lock = threading.Condition()
self._event_fetch_list = [] self._event_fetch_list = []
self._event_fetch_ongoing = 0 self._event_fetch_ongoing = 0
@ -329,13 +177,14 @@ class SQLBaseStore(object):
self.database_engine = hs.database_engine self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator() self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
self._pushers_id_gen = IdGenerator("pushers", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self)
self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
def start_profiling(self): def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec() self._previous_loop_ts = self._clock.time_msec()

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
@ -104,7 +105,7 @@ class DirectoryStore(SQLBaseStore):
}, },
desc="create_room_alias_association", desc="create_room_alias_association",
) )
self.get_aliases_for_room.invalidate(room_id) self.get_aliases_for_room.invalidate((room_id,))
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_room_alias(self, room_alias): def delete_room_alias(self, room_alias):
@ -114,7 +115,7 @@ class DirectoryStore(SQLBaseStore):
room_alias, room_alias,
) )
self.get_aliases_for_room.invalidate(room_id) self.get_aliases_for_room.invalidate((room_id,))
defer.returnValue(room_id) defer.returnValue(room_id)
def _delete_room_alias_txn(self, txn, room_alias): def _delete_room_alias_txn(self, txn, room_alias):

View File

@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from _base import SQLBaseStore
class EndToEndKeyStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
return self._simple_upsert(
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": json_bytes,
}
)
def get_e2e_device_keys(self, query_list):
"""Fetch a list of device keys.
Args:
query_list(list): List of pairs of user_ids and device_ids.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key json byte strings.
"""
def _get_e2e_device_keys(txn):
result = {}
for user_id, device_id in query_list:
user_result = result.setdefault(user_id, {})
keyvalues = {"user_id": user_id}
if device_id:
keyvalues["device_id"] = device_id
rows = self._simple_select_list_txn(
txn, table="e2e_device_keys_json",
keyvalues=keyvalues,
retcols=["device_id", "key_json"]
)
for row in rows:
user_result[row["device_id"]] = row["key_json"]
return result
return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
def _add_e2e_one_time_keys(txn):
for (algorithm, key_id, json_bytes) in key_list:
self._simple_upsert_txn(
txn, table="e2e_one_time_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
"key_id": key_id,
},
values={
"ts_added_ms": time_now,
"key_json": json_bytes,
}
)
return self.runInteraction(
"add_e2e_one_time_keys", _add_e2e_one_time_keys
)
def count_e2e_one_time_keys(self, user_id, device_id):
""" Count the number of one time keys the server has for a device
Returns:
Dict mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
sql = (
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ?"
" GROUP BY algorithm"
)
txn.execute(sql, (user_id, device_id))
result = {}
for algorithm, key_count in txn.fetchall():
result[algorithm] = key_count
return result
return self.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
def _claim_e2e_one_time_keys(txn):
sql = (
"SELECT key_id, key_json FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" LIMIT 1"
)
result = {}
delete = []
for user_id, device_id, algorithm in query_list:
user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm))
for key_id, key_json in txn.fetchall():
device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id))
sql = (
"DELETE FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" AND key_id = ?"
)
for user_id, device_id, algorithm, key_id in delete:
txn.execute(sql, (user_id, device_id, algorithm, key_id))
return result
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)

View File

@ -15,7 +15,8 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from syutil.base64util import encode_base64 from syutil.base64util import encode_base64
import logging import logging
@ -49,14 +50,22 @@ class EventFederationStore(SQLBaseStore):
results = set() results = set()
base_sql = ( base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id = ?" "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
) )
front = set(event_ids) front = set(event_ids)
while front: while front:
new_front = set() new_front = set()
for f in front: front_list = list(front)
txn.execute(base_sql, (f,)) chunks = [
front_list[x:x+100]
for x in xrange(0, len(front), 100)
]
for chunk in chunks:
txn.execute(
base_sql % (",".join(["?"] * len(chunk)),),
chunk
)
new_front.update([r[0] for r in txn.fetchall()]) new_front.update([r[0] for r in txn.fetchall()])
new_front -= results new_front -= results
@ -274,8 +283,7 @@ class EventFederationStore(SQLBaseStore):
}, },
) )
def _handle_prev_events(self, txn, outlier, event_id, prev_events, def _handle_mult_prev_events(self, txn, events):
room_id):
""" """
For the given event, update the event edges table and forward and For the given event, update the event edges table and forward and
backward extremities tables. backward extremities tables.
@ -285,70 +293,83 @@ class EventFederationStore(SQLBaseStore):
table="event_edges", table="event_edges",
values=[ values=[
{ {
"event_id": event_id, "event_id": ev.event_id,
"prev_event_id": e_id, "prev_event_id": e_id,
"room_id": room_id, "room_id": ev.room_id,
"is_state": False, "is_state": False,
} }
for e_id, _ in prev_events for ev in events
for e_id, _ in ev.prev_events
], ],
) )
# Update the extremities table if this is not an outlier. events_by_room = {}
if not outlier: for ev in events:
for e_id, _ in prev_events: events_by_room.setdefault(ev.room_id, []).append(ev)
# TODO (erikj): This could be done as a bulk insert
self._simple_delete_txn( for room_id, room_events in events_by_room.items():
txn, prevs = [
table="event_forward_extremities", e_id for ev in room_events for e_id, _ in ev.prev_events
keyvalues={ if not ev.internal_metadata.is_outlier()
"event_id": e_id, ]
"room_id": room_id, if prevs:
} txn.execute(
"DELETE FROM event_forward_extremities"
" WHERE room_id = ?"
" AND event_id in (%s)" % (
",".join(["?"] * len(prevs)),
),
[room_id] + prevs,
) )
# We only insert as a forward extremity the new event if there are query = (
# no other events that reference it as a prev event "INSERT INTO event_forward_extremities (event_id, room_id)"
query = ( " SELECT ?, ? WHERE NOT EXISTS ("
"SELECT 1 FROM event_edges WHERE prev_event_id = ?" " SELECT 1 FROM event_edges WHERE prev_event_id = ?"
) " )"
)
txn.execute(query, (event_id,)) txn.executemany(
query,
[
(ev.event_id, ev.room_id, ev.event_id) for ev in events
if not ev.internal_metadata.is_outlier()
]
)
if not txn.fetchone(): query = (
query = ( "INSERT INTO event_backward_extremities (event_id, room_id)"
"INSERT INTO event_forward_extremities" " SELECT ?, ? WHERE NOT EXISTS ("
" (event_id, room_id)" " SELECT 1 FROM event_backward_extremities"
" VALUES (?, ?)" " WHERE event_id = ? AND room_id = ?"
) " )"
" AND NOT EXISTS ("
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
" AND outlier = ?"
" )"
)
txn.execute(query, (event_id, room_id)) txn.executemany(query, [
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
for ev in events for e_id, _ in ev.prev_events
if not ev.internal_metadata.is_outlier()
])
query = ( query = (
"INSERT INTO event_backward_extremities (event_id, room_id)" "DELETE FROM event_backward_extremities"
" SELECT ?, ? WHERE NOT EXISTS (" " WHERE event_id = ? AND room_id = ?"
" SELECT 1 FROM event_backward_extremities" )
" WHERE event_id = ? AND room_id = ?" txn.executemany(
" )" query,
" AND NOT EXISTS (" [
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " (ev.event_id, ev.room_id) for ev in events
" AND outlier = ?" if not ev.internal_metadata.is_outlier()
" )" ]
) )
txn.executemany(query, [
(e_id, room_id, e_id, room_id, e_id, room_id, False)
for e_id, _ in prev_events
])
query = (
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
txn.execute(query, (event_id, room_id))
for room_id in events_by_room:
txn.call_after( txn.call_after(
self.get_latest_event_ids_in_room.invalidate, room_id self.get_latest_event_ids_in_room.invalidate, (room_id,)
) )
def get_backfill_events(self, room_id, event_list, limit): def get_backfill_events(self, room_id, event_list, limit):
@ -400,10 +421,12 @@ class EventFederationStore(SQLBaseStore):
keyvalues={ keyvalues={
"event_id": event_id, "event_id": event_id,
}, },
retcol="depth" retcol="depth",
allow_none=True,
) )
queue.put((-depth, event_id)) if depth:
queue.put((-depth, event_id))
while not queue.empty() and len(event_results) < limit: while not queue.empty() and len(event_results) < limit:
try: try:
@ -489,4 +512,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?" query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,)) txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))

View File

@ -23,9 +23,7 @@ from synapse.events.utils import prune_event
from synapse.util.logcontext import preserve_context_over_deferred from synapse.util.logcontext import preserve_context_over_deferred
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_json from syutil.jsonutil import encode_json
from contextlib import contextmanager from contextlib import contextmanager
@ -46,6 +44,48 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
class EventsStore(SQLBaseStore): class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False,
is_new_state=True):
if not events_and_contexts:
return
if backfilled:
if not self.min_token_deferred.called:
yield self.min_token_deferred
start = self.min_token - 1
self.min_token -= len(events_and_contexts) + 1
stream_orderings = range(start, self.min_token, -1)
@contextmanager
def stream_ordering_manager():
yield stream_orderings
stream_ordering_manager = stream_ordering_manager()
else:
stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
self, len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings:
for (event, _), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
chunks = [
events_and_contexts[x:x+100]
for x in xrange(0, len(events_and_contexts), 100)
]
for chunk in chunks:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
yield self.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
is_new_state=is_new_state,
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, backfilled=False, def persist_event(self, event, context, backfilled=False,
@ -67,13 +107,13 @@ class EventsStore(SQLBaseStore):
try: try:
with stream_ordering_manager as stream_ordering: with stream_ordering_manager as stream_ordering:
event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction( yield self.runInteraction(
"persist_event", "persist_event",
self._persist_event_txn, self._persist_event_txn,
event=event, event=event,
context=context, context=context,
backfilled=backfilled, backfilled=backfilled,
stream_ordering=stream_ordering,
is_new_state=is_new_state, is_new_state=is_new_state,
current_state=current_state, current_state=current_state,
) )
@ -116,19 +156,14 @@ class EventsStore(SQLBaseStore):
@log_function @log_function
def _persist_event_txn(self, txn, event, context, backfilled, def _persist_event_txn(self, txn, event, context, backfilled,
stream_ordering=None, is_new_state=True, is_new_state=True, current_state=None):
current_state=None):
# Remove the any existing cache entries for the event_id
txn.call_after(self._invalidate_get_event_cache, event.event_id)
# We purposefully do this first since if we include a `current_state` # We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table # key, we *want* to update the `current_state_events` table
if current_state: if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all) txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all) txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, event.room_id) txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases, event.room_id) txn.call_after(self.get_room_name_and_aliases, event.room_id)
self._simple_delete_txn( self._simple_delete_txn(
@ -149,37 +184,78 @@ class EventsStore(SQLBaseStore):
} }
) )
outlier = event.internal_metadata.is_outlier() return self._persist_events_txn(
if not outlier:
self._update_min_depth_for_room_txn(
txn,
event.room_id,
event.depth
)
have_persisted = self._simple_select_one_txn(
txn, txn,
table="events", [(event, context)],
keyvalues={"event_id": event.event_id}, backfilled=backfilled,
retcols=["event_id", "outlier"], is_new_state=is_new_state,
allow_none=True,
) )
metadata_json = encode_json( @log_function
event.internal_metadata.get_dict(), def _persist_events_txn(self, txn, events_and_contexts, backfilled,
using_frozen_dicts=USE_FROZEN_DICTS is_new_state=True):
).decode("UTF-8")
# If we have already persisted this event, we don't need to do any # Remove the any existing cache entries for the event_ids
# more processing. for event, _ in events_and_contexts:
# The processing above must be done on every call to persist event, txn.call_after(self._invalidate_get_event_cache, event.event_id)
# since they might not have happened on previous calls. For example,
# if we are persisting an event that we had persisted as an outlier, depth_updates = {}
# but is no longer one. for event, _ in events_and_contexts:
if have_persisted: if event.internal_metadata.is_outlier():
if not outlier and have_persisted["outlier"]: continue
self._store_state_groups_txn(txn, event, context) depth_updates[event.room_id] = max(
event.depth, depth_updates.get(event.room_id, event.depth)
)
for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth)
txn.execute(
"SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
",".join(["?"] * len(events_and_contexts)),
),
[event.event_id for event, _ in events_and_contexts]
)
have_persisted = {
event_id: outlier
for event_id, outlier in txn.fetchall()
}
event_map = {}
to_remove = set()
for event, context in events_and_contexts:
# Handle the case of the list including the same event multiple
# times. The tricky thing here is when they differ by whether
# they are an outlier.
if event.event_id in event_map:
other = event_map[event.event_id]
if not other.internal_metadata.is_outlier():
to_remove.add(event)
continue
elif not event.internal_metadata.is_outlier():
to_remove.add(event)
continue
else:
to_remove.add(other)
event_map[event.event_id] = event
if event.event_id not in have_persisted:
continue
to_remove.add(event)
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
self._store_state_groups_txn(
txn, event, context,
)
metadata_json = encode_json(
event.internal_metadata.get_dict(),
using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8")
sql = ( sql = (
"UPDATE event_json SET internal_metadata = ?" "UPDATE event_json SET internal_metadata = ?"
@ -198,94 +274,91 @@ class EventsStore(SQLBaseStore):
sql, sql,
(False, event.event_id,) (False, event.event_id,)
) )
events_and_contexts = filter(
lambda ec: ec[0] not in to_remove,
events_and_contexts
)
if not events_and_contexts:
return return
if not outlier: self._store_mult_state_groups_txn(txn, [
self._store_state_groups_txn(txn, event, context) (event, context)
for event, context in events_and_contexts
if not event.internal_metadata.is_outlier()
])
self._handle_prev_events( self._handle_mult_prev_events(
txn, txn,
outlier=outlier, events=[event for event, _ in events_and_contexts],
event_id=event.event_id,
prev_events=event.prev_events,
room_id=event.room_id,
) )
if event.type == EventTypes.Member: for event, _ in events_and_contexts:
self._store_room_member_txn(txn, event) if event.type == EventTypes.Name:
elif event.type == EventTypes.Name: self._store_room_name_txn(txn, event)
self._store_room_name_txn(txn, event) elif event.type == EventTypes.Topic:
elif event.type == EventTypes.Topic: self._store_room_topic_txn(txn, event)
self._store_room_topic_txn(txn, event) elif event.type == EventTypes.Redaction:
elif event.type == EventTypes.Redaction: self._store_redaction(txn, event)
self._store_redaction(txn, event)
event_dict = { self._store_room_members_txn(
k: v txn,
for k, v in event.get_dict().items() [
if k not in [ event
"redacted", for event, _ in events_and_contexts
"redacted_because", if event.type == EventTypes.Member
] ]
} )
self._simple_insert_txn( def event_dict(event):
return {
k: v
for k, v in event.get_dict().items()
if k not in [
"redacted",
"redacted_because",
]
}
self._simple_insert_many_txn(
txn, txn,
table="event_json", table="event_json",
values={ values=[
"event_id": event.event_id, {
"room_id": event.room_id, "event_id": event.event_id,
"internal_metadata": metadata_json, "room_id": event.room_id,
"json": encode_json( "internal_metadata": encode_json(
event_dict, using_frozen_dicts=USE_FROZEN_DICTS event.internal_metadata.get_dict(),
).decode("UTF-8"), using_frozen_dicts=USE_FROZEN_DICTS
}, ).decode("UTF-8"),
"json": encode_json(
event_dict(event), using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8"),
}
for event, _ in events_and_contexts
],
) )
content = encode_json( self._simple_insert_many_txn(
event.content, using_frozen_dicts=USE_FROZEN_DICTS txn,
).decode("UTF-8") table="events",
values=[
vals = { {
"topological_ordering": event.depth, "stream_ordering": event.internal_metadata.stream_ordering,
"event_id": event.event_id, "topological_ordering": event.depth,
"type": event.type, "depth": event.depth,
"room_id": event.room_id, "event_id": event.event_id,
"content": content, "room_id": event.room_id,
"processed": True, "type": event.type,
"outlier": outlier, "processed": True,
"depth": event.depth, "outlier": event.internal_metadata.is_outlier(),
} "content": encode_json(
event.content, using_frozen_dicts=USE_FROZEN_DICTS
unrec = { ).decode("UTF-8"),
k: v }
for k, v in event.get_dict().items() for event, _ in events_and_contexts
if k not in vals.keys() and k not in [ ],
"redacted",
"redacted_because",
"signatures",
"hashes",
"prev_events",
]
}
vals["unrecognized_keys"] = encode_json(
unrec, using_frozen_dicts=USE_FROZEN_DICTS
).decode("UTF-8")
sql = (
"INSERT INTO events"
" (stream_ordering, topological_ordering, event_id, type,"
" room_id, content, processed, outlier, depth)"
" VALUES (?,?,?,?,?,?,?,?,?)"
)
txn.execute(
sql,
(
stream_ordering, event.depth, event.event_id, event.type,
event.room_id, content, True, outlier, event.depth
)
) )
if context.rejected: if context.rejected:
@ -293,20 +366,6 @@ class EventsStore(SQLBaseStore):
txn, event.event_id, context.rejected txn, event.event_id, context.rejected
) )
for hash_alg, hash_base64 in event.hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_event_content_hash_txn(
txn, event.event_id, hash_alg, hash_bytes,
)
for prev_event_id, prev_hashes in event.prev_events:
for alg, hash_base64 in prev_hashes.items():
hash_bytes = decode_base64(hash_base64)
self._store_prev_event_hash_txn(
txn, event.event_id, prev_event_id, alg,
hash_bytes
)
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
table="event_auth", table="event_auth",
@ -316,16 +375,22 @@ class EventsStore(SQLBaseStore):
"room_id": event.room_id, "room_id": event.room_id,
"auth_id": auth_id, "auth_id": auth_id,
} }
for event, _ in events_and_contexts
for auth_id, _ in event.auth_events for auth_id, _ in event.auth_events
], ],
) )
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) self._store_event_reference_hashes_txn(
self._store_event_reference_hash_txn( txn, [event for event, _ in events_and_contexts]
txn, event.event_id, ref_alg, ref_hash_bytes
) )
if event.is_state(): state_events_and_contexts = filter(
lambda i: i[0].is_state(),
events_and_contexts,
)
state_values = []
for event, context in state_events_and_contexts:
vals = { vals = {
"event_id": event.event_id, "event_id": event.event_id,
"room_id": event.room_id, "room_id": event.room_id,
@ -337,51 +402,55 @@ class EventsStore(SQLBaseStore):
if hasattr(event, "replaces_state"): if hasattr(event, "replaces_state"):
vals["prev_state"] = event.replaces_state vals["prev_state"] = event.replaces_state
self._simple_insert_txn( state_values.append(vals)
txn,
"state_events",
vals,
)
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
table="event_edges", table="state_events",
values=[ values=state_values,
{ )
"event_id": event.event_id,
"prev_event_id": e_id,
"room_id": event.room_id,
"is_state": True,
}
for e_id, h in event.prev_state
],
)
if is_new_state and not context.rejected: self._simple_insert_many_txn(
txn.call_after( txn,
self.get_current_state_for_key.invalidate, table="event_edges",
event.room_id, event.type, event.state_key values=[
) {
"event_id": event.event_id,
"prev_event_id": prev_id,
"room_id": event.room_id,
"is_state": True,
}
for event, _ in state_events_and_contexts
for prev_id, _ in event.prev_state
],
)
if (event.type == EventTypes.Name if is_new_state:
or event.type == EventTypes.Aliases): for event, _ in state_events_and_contexts:
if not context.rejected:
txn.call_after( txn.call_after(
self.get_room_name_and_aliases.invalidate, self.get_current_state_for_key.invalidate,
event.room_id (event.room_id, event.type, event.state_key,)
) )
self._simple_upsert_txn( if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn, txn.call_after(
"current_state_events", self.get_room_name_and_aliases.invalidate,
keyvalues={ (event.room_id,)
"room_id": event.room_id, )
"type": event.type,
"state_key": event.state_key, self._simple_upsert_txn(
}, txn,
values={ "current_state_events",
"event_id": event.event_id, keyvalues={
} "room_id": event.room_id,
) "type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
return return
@ -498,8 +567,9 @@ class EventsStore(SQLBaseStore):
def _invalidate_get_event_cache(self, event_id): def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True): for check_redacted in (False, True):
for get_prev_content in (False, True): for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted, self._get_event_cache.invalidate(
get_prev_content) (event_id, check_redacted, get_prev_content)
)
def _get_event_txn(self, txn, event_id, check_redacted=True, def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False): get_prev_content=False, allow_rejected=False):
@ -520,7 +590,7 @@ class EventsStore(SQLBaseStore):
for event_id in events: for event_id in events:
try: try:
ret = self._get_event_cache.get( ret = self._get_event_cache.get(
event_id, check_redacted, get_prev_content (event_id, check_redacted, get_prev_content,)
) )
if allow_rejected or not ret.rejected_reason: if allow_rejected or not ret.rejected_reason:
@ -741,6 +811,8 @@ class EventsStore(SQLBaseStore):
) )
if because: if because:
# It's fine to do add the event directly, since get_pdu_json
# will serialise this field correctly
ev.unsigned["redacted_because"] = because ev.unsigned["redacted_because"] = because
if get_prev_content and "replaces_state" in ev.unsigned: if get_prev_content and "replaces_state" in ev.unsigned:
@ -753,7 +825,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill( self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev (ev.event_id, check_redacted, get_prev_content), ev
) )
defer.returnValue(ev) defer.returnValue(ev)
@ -810,7 +882,7 @@ class EventsStore(SQLBaseStore):
ev.unsigned["prev_content"] = prev.get_dict()["content"] ev.unsigned["prev_content"] = prev.get_dict()["content"]
self._get_event_cache.prefill( self._get_event_cache.prefill(
ev.event_id, check_redacted, get_prev_content, ev (ev.event_id, check_redacted, get_prev_content), ev
) )
return ev return ev

View File

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from _base import SQLBaseStore from _base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
@ -71,6 +72,24 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate", desc="store_server_certificate",
) )
@cachedInlineCallbacks()
def get_all_server_verify_keys(self, server_name):
rows = yield self._simple_select_list(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
},
retcols=["key_id", "verify_key"],
desc="get_all_server_verify_keys",
)
defer.returnValue({
row["key_id"]: decode_verify_key_bytes(
row["key_id"], str(row["verify_key"])
)
for row in rows
})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids): def get_server_verify_keys(self, server_name, key_ids):
"""Retrieve the NACL verification key for a given server for the given """Retrieve the NACL verification key for a given server for the given
@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore):
Returns: Returns:
(list of VerifyKey): The verification keys. (list of VerifyKey): The verification keys.
""" """
sql = ( keys = yield self.get_all_server_verify_keys(server_name)
"SELECT key_id, verify_key FROM server_signature_keys" defer.returnValue({
" WHERE server_name = ?" k: keys[k]
" AND key_id in (" + ",".join("?" for key_id in key_ids) + ")" for k in key_ids
) if k in keys and keys[k]
})
rows = yield self._execute_and_decode(
"get_server_verify_keys", sql, server_name, *key_ids
)
keys = []
for row in rows:
key_id = row["key_id"]
key_bytes = row["verify_key"]
key = decode_verify_key_bytes(key_id, str(key_bytes))
keys.append(key)
defer.returnValue(keys)
@defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms, def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key): verify_key):
"""Stores a NACL verification key for the given server. """Stores a NACL verification key for the given server.
@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore):
ts_now_ms (int): The time now in milliseconds ts_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key. verification_key (VerifyKey): The NACL verify key.
""" """
return self._simple_upsert( yield self._simple_upsert(
table="server_signature_keys", table="server_signature_keys",
keyvalues={ keyvalues={
"server_name": server_name, "server_name": server_name,
@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key", desc="store_server_verify_key",
) )
self.get_all_server_verify_keys.invalidate((server_name,))
def store_server_keys_json(self, server_name, key_id, from_server, def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes): ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server """Stores the JSON bytes for a set of keys from a server
@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore):
"ts_valid_until_ms": ts_expires_ms, "ts_valid_until_ms": ts_expires_ms,
"key_json": buffer(key_json_bytes), "key_json": buffer(key_json_bytes),
}, },
desc="store_server_keys_json",
) )
def get_server_keys_json(self, server_keys): def get_server_keys_json(self, server_keys):

View File

@ -13,19 +13,23 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
from twisted.internet import defer from twisted.internet import defer
class PresenceStore(SQLBaseStore): class PresenceStore(SQLBaseStore):
def create_presence(self, user_localpart): def create_presence(self, user_localpart):
return self._simple_insert( res = self._simple_insert(
table="presence", table="presence",
values={"user_id": user_localpart}, values={"user_id": user_localpart},
desc="create_presence", desc="create_presence",
) )
self.get_presence_state.invalidate((user_localpart,))
return res
def has_presence_state(self, user_localpart): def has_presence_state(self, user_localpart):
return self._simple_select_one( return self._simple_select_one(
table="presence", table="presence",
@ -35,6 +39,7 @@ class PresenceStore(SQLBaseStore):
desc="has_presence_state", desc="has_presence_state",
) )
@cached(max_entries=2000)
def get_presence_state(self, user_localpart): def get_presence_state(self, user_localpart):
return self._simple_select_one( return self._simple_select_one(
table="presence", table="presence",
@ -43,8 +48,27 @@ class PresenceStore(SQLBaseStore):
desc="get_presence_state", desc="get_presence_state",
) )
@cachedList(get_presence_state.cache, list_name="user_localparts")
def get_presence_states(self, user_localparts):
def f(txn):
results = {}
for user_localpart in user_localparts:
res = self._simple_select_one_txn(
txn,
table="presence",
keyvalues={"user_id": user_localpart},
retcols=["state", "status_msg", "mtime"],
allow_none=True,
)
if res:
results[user_localpart] = res
return results
return self.runInteraction("get_presence_states", f)
def set_presence_state(self, user_localpart, new_state): def set_presence_state(self, user_localpart, new_state):
return self._simple_update_one( res = self._simple_update_one(
table="presence", table="presence",
keyvalues={"user_id": user_localpart}, keyvalues={"user_id": user_localpart},
updatevalues={"state": new_state["state"], updatevalues={"state": new_state["state"],
@ -53,6 +77,9 @@ class PresenceStore(SQLBaseStore):
desc="set_presence_state", desc="set_presence_state",
) )
self.get_presence_state.invalidate((user_localpart,))
return res
def allow_presence_visible(self, observed_localpart, observer_userid): def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert( return self._simple_insert(
table="presence_allow_inbound", table="presence_allow_inbound",
@ -98,7 +125,7 @@ class PresenceStore(SQLBaseStore):
updatevalues={"accepted": True}, updatevalues={"accepted": True},
desc="set_presence_list_accepted", desc="set_presence_list_accepted",
) )
self.get_presence_list_accepted.invalidate(observer_localpart) self.get_presence_list_accepted.invalidate((observer_localpart,))
defer.returnValue(result) defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None): def get_presence_list(self, observer_localpart, accepted=None):
@ -133,4 +160,4 @@ class PresenceStore(SQLBaseStore):
"observed_user_id": observed_userid}, "observed_user_id": observed_userid},
desc="del_presence_list", desc="del_presence_list",
) )
self.get_presence_list_accepted.invalidate(observer_localpart) self.get_presence_list_accepted.invalidate((observer_localpart,))

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer from twisted.internet import defer
import logging import logging
@ -23,8 +24,7 @@ logger = logging.getLogger(__name__)
class PushRuleStore(SQLBaseStore): class PushRuleStore(SQLBaseStore):
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_push_rules_for_user(self, user_name): def get_push_rules_for_user(self, user_name):
rows = yield self._simple_select_list( rows = yield self._simple_select_list(
table=PushRuleTable.table_name, table=PushRuleTable.table_name,
@ -41,8 +41,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows) defer.returnValue(rows)
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_push_rules_enabled_for_user(self, user_name): def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list( results = yield self._simple_select_list(
table=PushRuleEnableTable.table_name, table=PushRuleEnableTable.table_name,
@ -153,11 +152,11 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority)) txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -189,10 +188,10 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority'] = new_prio new_rule['priority'] = new_prio
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )
self._simple_insert_txn( self._simple_insert_txn(
@ -218,8 +217,8 @@ class PushRuleStore(SQLBaseStore):
desc="delete_push_rule", desc="delete_push_rule",
) )
self.get_push_rules_for_user.invalidate(user_name) self.get_push_rules_for_user.invalidate((user_name,))
self.get_push_rules_enabled_for_user.invalidate(user_name) self.get_push_rules_enabled_for_user.invalidate((user_name,))
@defer.inlineCallbacks @defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled): def set_push_rule_enabled(self, user_name, rule_id, enabled):
@ -240,10 +239,10 @@ class PushRuleStore(SQLBaseStore):
{'id': new_id}, {'id': new_id},
) )
txn.call_after( txn.call_after(
self.get_push_rules_for_user.invalidate, user_name self.get_push_rules_for_user.invalidate, (user_name,)
) )
txn.call_after( txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name self.get_push_rules_enabled_for_user.invalidate, (user_name,)
) )

406
synapse/storage/receipts.py Normal file
View File

@ -0,0 +1,406 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches import cache_counter, caches_by_name
from twisted.internet import defer
from blist import sorteddict
import logging
import ujson as json
logger = logging.getLogger(__name__)
class ReceiptsStore(SQLBaseStore):
def __init__(self, hs):
super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = _RoomStreamChangeCache()
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
"""Get receipts for multiple rooms for sending to clients.
Args:
room_ids (list): List of room_ids.
to_key (int): Max stream id to fetch receipts upto.
from_key (int): Min stream id to fetch receipts from. None fetches
from the start.
Returns:
list: A list of receipts.
"""
room_ids = set(room_ids)
if from_key:
room_ids = yield self._receipts_stream_cache.get_rooms_changed(
self, room_ids, from_key
)
results = yield self._get_linearized_receipts_for_rooms(
room_ids, to_key, from_key=from_key
)
defer.returnValue([ev for res in results.values() for ev in res])
@cachedInlineCallbacks(num_args=3, max_entries=5000)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
Args:
room_ids (str): The room id.
to_key (int): Max stream id to fetch receipts upto.
from_key (int): Min stream id to fetch receipts from. None fetches
from the start.
Returns:
list: A list of receipts.
"""
def f(txn):
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id = ? AND stream_id > ? AND stream_id <= ?"
)
txn.execute(
sql,
(room_id, from_key, to_key)
)
else:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id = ? AND stream_id <= ?"
)
txn.execute(
sql,
(room_id, to_key)
)
rows = self.cursor_to_dict(txn)
return rows
rows = yield self.runInteraction(
"get_linearized_receipts_for_room", f
)
if not rows:
defer.returnValue([])
content = {}
for row in rows:
content.setdefault(
row["event_id"], {}
).setdefault(
row["receipt_type"], {}
)[row["user_id"]] = json.loads(row["data"])
defer.returnValue([{
"type": "m.receipt",
"room_id": room_id,
"content": content,
}])
@cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids",
num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
defer.returnValue({})
def f(txn):
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
) % (
",".join(["?"] * len(room_ids))
)
args = list(room_ids)
args.extend([from_key, to_key])
txn.execute(sql, args)
else:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id <= ?"
) % (
",".join(["?"] * len(room_ids))
)
args = list(room_ids)
args.append(to_key)
txn.execute(sql, args)
return self.cursor_to_dict(txn)
txn_results = yield self.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
results = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(row["room_id"], {
"type": "m.receipt",
"room_id": row["room_id"],
"content": {},
})
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = json.loads(row["data"])
results = {
room_id: [results[room_id]] if room_id in results else []
for room_id in room_ids
}
defer.returnValue(results)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self)
@cachedInlineCallbacks()
def get_graph_receipts_for_room(self, room_id):
"""Get receipts for sending to remote servers.
"""
rows = yield self._simple_select_list(
table="receipts_graph",
keyvalues={"room_id": room_id},
retcols=["receipt_type", "user_id", "event_id"],
desc="get_linearized_receipts_for_room",
)
result = {}
for row in rows:
result.setdefault(
row["user_id"], {}
).setdefault(
row["receipt_type"], []
).append(row["event_id"])
defer.returnValue(result)
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
sql = (
"SELECT topological_ordering, stream_ordering, event_id FROM events"
" INNER JOIN receipts_linearized as r USING (event_id, room_id)"
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
)
txn.execute(sql, (room_id, receipt_type, user_id))
results = txn.fetchall()
if results:
res = self._simple_select_one_txn(
txn,
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
)
topological_ordering = int(res["topological_ordering"])
stream_ordering = int(res["stream_ordering"])
for to, so, _ in results:
if int(to) > topological_ordering:
return False
elif int(to) == topological_ordering and int(so) >= stream_ordering:
return False
self._simple_delete_txn(
txn,
table="receipts_linearized",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
}
)
self._simple_insert_txn(
txn,
table="receipts_linearized",
values={
"stream_id": stream_id,
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_id": event_id,
"data": json.dumps(data),
}
)
return True
@defer.inlineCallbacks
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
representations.
"""
if not event_ids:
return
if len(event_ids) == 1:
linearized_event_id = event_ids[0]
else:
# we need to points in graph -> linearized form.
# TODO: Make this better.
def graph_to_linear(txn):
query = (
"SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
" SELECT max(stream_ordering) WHERE event_id IN (%s)"
")"
) % (",".join(["?"] * len(event_ids)))
txn.execute(query, [room_id] + event_ids)
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
linearized_event_id = yield self.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = yield self._receipts_id_gen.get_next(self)
with stream_id_manager as stream_id:
yield self._receipts_stream_cache.room_has_changed(
self, room_id, stream_id
)
have_persisted = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id, receipt_type, user_id, linearized_event_id,
data,
stream_id=stream_id,
)
if not have_persisted:
defer.returnValue(None)
yield self.insert_graph_receipt(
room_id, receipt_type, user_id, event_ids, data
)
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
defer.returnValue((stream_id, max_persisted_id))
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
data):
return self.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id, receipt_type, user_id, event_ids, data
)
def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_ids, data):
self._simple_delete_txn(
txn,
table="receipts_graph",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
}
)
self._simple_insert_txn(
txn,
table="receipts_graph",
values={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": json.dumps(event_ids),
"data": json.dumps(data),
}
)
class _RoomStreamChangeCache(object):
"""Keeps track of the stream_id of the latest change in rooms.
Given a list of rooms and stream key, it will give a subset of rooms that
may have changed since that key. If the key is too old then the cache
will simply return all rooms.
"""
def __init__(self, size_of_cache=10000):
self._size_of_cache = size_of_cache
self._room_to_key = {}
self._cache = sorteddict()
self._earliest_key = None
self.name = "ReceiptsRoomChangeCache"
caches_by_name[self.name] = self._cache
@defer.inlineCallbacks
def get_rooms_changed(self, store, room_ids, key):
"""Returns subset of room ids that have had new receipts since the
given key. If the key is too old it will just return the given list.
"""
if key > (yield self._get_earliest_key(store)):
keys = self._cache.keys()
i = keys.bisect_right(key)
result = set(
self._cache[k] for k in keys[i:]
).intersection(room_ids)
cache_counter.inc_hits(self.name)
else:
result = room_ids
cache_counter.inc_misses(self.name)
defer.returnValue(result)
@defer.inlineCallbacks
def room_has_changed(self, store, room_id, key):
"""Informs the cache that the room has been changed at the given key.
"""
if key > (yield self._get_earliest_key(store)):
old_key = self._room_to_key.get(room_id, None)
if old_key:
key = max(key, old_key)
self._cache.pop(old_key, None)
self._cache[key] = room_id
while len(self._cache) > self._size_of_cache:
k, r = self._cache.popitem()
self._earliest_key = max(k, self._earliest_key)
self._room_to_key.pop(r, None)
@defer.inlineCallbacks
def _get_earliest_key(self, store):
if self._earliest_key is None:
self._earliest_key = yield store.get_max_receipt_stream_id()
self._earliest_key = int(self._earliest_key)
defer.returnValue(self._earliest_key)

View File

@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
class RegistrationStore(SQLBaseStore): class RegistrationStore(SQLBaseStore):
@ -97,6 +98,20 @@ class RegistrationStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
def f(txn):
sql = (
"SELECT name, password_hash FROM users"
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn.fetchall())
return self.runInteraction("get_users_by_id_case_insensitive", f)
@defer.inlineCallbacks @defer.inlineCallbacks
def user_set_password_hash(self, user_id, password_hash): def user_set_password_hash(self, user_id, password_hash):
""" """
@ -111,16 +126,16 @@ class RegistrationStore(SQLBaseStore):
}) })
@defer.inlineCallbacks @defer.inlineCallbacks
def user_delete_access_tokens_apart_from(self, user_id, token_id): def user_delete_access_tokens(self, user_id):
yield self.runInteraction( yield self.runInteraction(
"user_delete_access_tokens_apart_from", "user_delete_access_tokens",
self._user_delete_access_tokens_apart_from, user_id, token_id self._user_delete_access_tokens, user_id
) )
def _user_delete_access_tokens_apart_from(self, txn, user_id, token_id): def _user_delete_access_tokens(self, txn, user_id):
txn.execute( txn.execute(
"DELETE FROM access_tokens WHERE user_id = ? AND id != ?", "DELETE FROM access_tokens WHERE user_id = ?",
(user_id, token_id) (user_id, )
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -131,7 +146,7 @@ class RegistrationStore(SQLBaseStore):
user_id user_id
) )
for r in rows: for r in rows:
self.get_user_by_token.invalidate(r) self.get_user_by_token.invalidate((r,))
@cached() @cached()
def get_user_by_token(self, token): def get_user_by_token(self, token):

View File

@ -17,7 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
import collections import collections
import logging import logging
@ -186,8 +187,7 @@ class RoomStore(SQLBaseStore):
} }
) )
@cached() @cachedInlineCallbacks()
@defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id): def get_room_name_and_aliases(self, room_id):
def f(txn): def f(txn):
sql = ( sql = (

View File

@ -17,7 +17,8 @@ from twisted.internet import defer
from collections import namedtuple from collections import namedtuple
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.types import UserID from synapse.types import UserID
@ -35,38 +36,28 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore): class RoomMemberStore(SQLBaseStore):
def _store_room_member_txn(self, txn, event): def _store_room_members_txn(self, txn, events):
"""Store a room member in the database. """Store a room member in the database.
""" """
try: self._simple_insert_many_txn(
target_user_id = event.state_key
except:
logger.exception(
"Failed to parse target_user_id=%s", target_user_id
)
raise
logger.debug(
"_store_room_member_txn: target_user_id=%s, membership=%s",
target_user_id,
event.membership,
)
self._simple_insert_txn(
txn, txn,
"room_memberships", table="room_memberships",
{ values=[
"event_id": event.event_id, {
"user_id": target_user_id, "event_id": event.event_id,
"sender": event.user_id, "user_id": event.state_key,
"room_id": event.room_id, "sender": event.user_id,
"membership": event.membership, "room_id": event.room_id,
} "membership": event.membership,
}
for event in events
]
) )
txn.call_after(self.get_rooms_for_user.invalidate, target_user_id) for event in events:
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id) txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_users_in_room.invalidate, event.room_id) txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
def get_room_member(self, user_id, room_id): def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member. """Retrieve the current state of a room member.
@ -88,7 +79,7 @@ class RoomMemberStore(SQLBaseStore):
lambda events: events[0] if events else None lambda events: events[0] if events else None
) )
@cached() @cached(max_entries=5000)
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
def f(txn): def f(txn):
@ -164,7 +155,7 @@ class RoomMemberStore(SQLBaseStore):
RoomsForUser(**r) for r in self.cursor_to_dict(txn) RoomsForUser(**r) for r in self.cursor_to_dict(txn)
] ]
@cached() @cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id): def get_joined_hosts_for_room(self, room_id):
return self.runInteraction( return self.runInteraction(
"get_joined_hosts_for_room", "get_joined_hosts_for_room",

View File

@ -0,0 +1,34 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS e2e_device_keys_json (
user_id TEXT NOT NULL, -- The user these keys are for.
device_id TEXT NOT NULL, -- Which of the user's devices these keys are for.
ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded.
key_json TEXT NOT NULL, -- The keys for the device as a JSON blob.
CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id)
);
CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json (
user_id TEXT NOT NULL, -- The user this one-time key is for.
device_id TEXT NOT NULL, -- The device this one-time key is for.
algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for.
key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
ts_added_ms BIGINT NOT NULL, -- When this key was uploaded.
key_json TEXT NOT NULL, -- The key as a JSON blob.
CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id)
);

View File

@ -0,0 +1,38 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS receipts_graph(
room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL,
event_ids TEXT NOT NULL,
data TEXT NOT NULL,
CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id)
);
CREATE TABLE IF NOT EXISTS receipts_linearized (
stream_id BIGINT NOT NULL,
room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL,
event_id TEXT NOT NULL,
data TEXT NOT NULL,
CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id)
);
CREATE INDEX receipts_linearized_id ON receipts_linearized(
stream_id
);

View File

@ -0,0 +1,18 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX receipts_linearized_room_stream ON receipts_linearized(
room_id, stream_id
);

View File

@ -0,0 +1,19 @@
CREATE TABLE IF NOT EXISTS user_threepids2 (
user_id TEXT NOT NULL,
medium TEXT NOT NULL,
address TEXT NOT NULL,
validated_at BIGINT NOT NULL,
added_at BIGINT NOT NULL,
CONSTRAINT medium_address UNIQUE (medium, address)
);
INSERT INTO user_threepids2
SELECT * FROM user_threepids WHERE added_at IN (
SELECT max(added_at) FROM user_threepids GROUP BY medium, address
)
;
DROP TABLE user_threepids;
ALTER TABLE user_threepids2 RENAME TO user_threepids;
CREATE INDEX user_threepids_user_id ON user_threepids(user_id);

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from _base import SQLBaseStore from _base import SQLBaseStore
from syutil.base64util import encode_base64 from syutil.base64util import encode_base64
from synapse.crypto.event_signing import compute_event_reference_hash
class SignatureStore(SQLBaseStore): class SignatureStore(SQLBaseStore):
@ -101,23 +102,26 @@ class SignatureStore(SQLBaseStore):
txn.execute(query, (event_id, )) txn.execute(query, (event_id, ))
return {k: v for k, v in txn.fetchall()} return {k: v for k, v in txn.fetchall()}
def _store_event_reference_hash_txn(self, txn, event_id, algorithm, def _store_event_reference_hashes_txn(self, txn, events):
hash_bytes):
"""Store a hash for a PDU """Store a hash for a PDU
Args: Args:
txn (cursor): txn (cursor):
event_id (str): Id for the Event. events (list): list of Events.
algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes.
""" """
self._simple_insert_txn(
vals = []
for event in events:
ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
vals.append({
"event_id": event.event_id,
"algorithm": ref_alg,
"hash": buffer(ref_hash_bytes),
})
self._simple_insert_many_txn(
txn, txn,
"event_reference_hashes", table="event_reference_hashes",
{ values=vals,
"event_id": event_id,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
) )
def _get_event_signatures_txn(self, txn, event_id): def _get_event_signatures_txn(self, txn, event_id):

View File

@ -13,7 +13,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import (
cached, cachedInlineCallbacks, cachedList
)
from twisted.internet import defer from twisted.internet import defer
@ -44,72 +47,44 @@ class StateStore(SQLBaseStore):
""" """
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_groups(self, event_ids): def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids """ Get the state groups for the given list of event_ids
The return value is a dict mapping group names to lists of events. The return value is a dict mapping group names to lists of events.
""" """
if not event_ids:
defer.returnValue({})
def f(txn): event_to_groups = yield self._get_state_group_for_events(
groups = set() room_id, event_ids,
for event_id in event_ids:
group = self._simple_select_one_onecol_txn(
txn,
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
allow_none=True,
)
if group:
groups.add(group)
res = {}
for group in groups:
state_ids = self._simple_select_onecol_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": group},
retcol="event_id",
)
res[group] = state_ids
return res
states = yield self.runInteraction(
"get_state_groups",
f,
) )
state_list = yield defer.gatherResults( groups = set(event_to_groups.values())
[ group_to_state = yield self._get_state_for_groups(groups)
self._fetch_events_for_group(group, vals)
for group, vals in states.items()
],
consumeErrors=True,
)
defer.returnValue(dict(state_list)) defer.returnValue({
group: state_map.values()
@cached(num_args=1) for group, state_map in group_to_state.items()
def _fetch_events_for_group(self, state_group, events): })
return self._get_events(
events, get_prev_content=False
).addCallback(
lambda evs: (state_group, evs)
)
def _store_state_groups_txn(self, txn, event, context): def _store_state_groups_txn(self, txn, event, context):
if context.current_state is None: return self._store_mult_state_groups_txn(txn, [(event, context)])
return
state_events = dict(context.current_state) def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if context.current_state is None:
continue
if event.is_state(): if context.state_group is not None:
state_events[(event.type, event.state_key)] = event state_groups[event.event_id] = context.state_group
continue
state_events = dict(context.current_state)
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = context.state_group
if not state_group:
state_group = self._state_groups_id_gen.get_next_txn(txn) state_group = self._state_groups_id_gen.get_next_txn(txn)
self._simple_insert_txn( self._simple_insert_txn(
txn, txn,
@ -135,14 +110,19 @@ class StateStore(SQLBaseStore):
for state in state_events.values() for state in state_events.values()
], ],
) )
state_groups[event.event_id] = state_group
self._simple_insert_txn( self._simple_insert_many_txn(
txn, txn,
table="event_to_state_groups", table="event_to_state_groups",
values={ values=[
"state_group": state_group, {
"event_id": event.event_id, "state_group": state_groups[event.event_id],
}, "event_id": event.event_id,
}
for event, context in events_and_contexts
if context.current_state is not None
],
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -177,8 +157,7 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False) events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)
@cached(num_args=3) @cachedInlineCallbacks(num_args=3)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key): def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn): def f(txn):
sql = ( sql = (
@ -194,6 +173,262 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False) events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events) defer.returnValue(events)
def _get_state_groups_from_groups(self, groups_and_types):
"""Returns dictionary state_group -> state event ids
Args:
groups_and_types (list): list of 2-tuple (`group`, `types`)
"""
def f(txn):
results = {}
for group, types in groups_and_types:
if types is not None:
where_clause = "AND (%s)" % (
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
)
else:
where_clause = ""
sql = (
"SELECT event_id FROM state_groups_state WHERE"
" state_group = ? %s"
) % (where_clause,)
args = [group]
if types is not None:
args.extend([i for typ in types for i in typ])
txn.execute(sql, args)
results[group] = [r[0] for r in txn.fetchall()]
return results
return self.runInteraction(
"_get_state_groups_from_groups",
f,
)
@defer.inlineCallbacks
def get_state_for_events(self, room_id, event_ids, types):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event. The state dicts will only have the type/state_keys
that are in the `types` list.
Args:
room_id (str)
event_ids (list)
types (list): List of (type, state_key) tuples which are used to
filter the state fetched. `state_key` may be None, which matches
any `state_key`
Returns:
deferred: A list of dicts corresponding to the event_ids given.
The dicts are mappings from (type, state_key) -> state_events
"""
event_to_groups = yield self._get_state_group_for_events(
room_id, event_ids,
)
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups, types)
event_to_state = {
event_id: group_to_state[group]
for event_id, group in event_to_groups.items()
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@cached(num_args=2, lru=True, max_entries=10000)
def _get_state_group_for_event(self, room_id, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={
"event_id": event_id,
},
retcol="state_group",
allow_none=True,
desc="_get_state_group_for_event",
)
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
num_args=2)
def _get_state_group_for_events(self, room_id, event_ids):
"""Returns mapping event_id -> state_group
"""
def f(txn):
results = {}
for event_id in event_ids:
results[event_id] = self._simple_select_one_onecol_txn(
txn,
table="event_to_state_groups",
keyvalues={
"event_id": event_id,
},
retcol="state_group",
allow_none=True,
)
return results
return self.runInteraction("_get_state_group_for_events", f)
def _get_some_state_from_cache(self, group, types):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
`missing_types` is the list of types that aren't in the cache for that
group. `got_all` is a bool indicating if we successfully retrieved all
requests state from the cache, if False we need to query the DB for the
missing state.
Args:
group: The state group to lookup
types (list): List of 2-tuples of the form (`type`, `state_key`),
where a `state_key` of `None` matches all state_keys for the
`type`.
"""
is_all, state_dict = self._state_group_cache.get(group)
type_to_key = {}
missing_types = set()
for typ, state_key in types:
if state_key is None:
type_to_key[typ] = None
missing_types.add((typ, state_key))
else:
if type_to_key.get(typ, object()) is not None:
type_to_key.setdefault(typ, set()).add(state_key)
if (typ, state_key) not in state_dict:
missing_types.add((typ, state_key))
sentinel = object()
def include(typ, state_key):
valid_state_keys = type_to_key.get(typ, sentinel)
if valid_state_keys is sentinel:
return False
if valid_state_keys is None:
return True
if state_key in valid_state_keys:
return True
return False
got_all = not (missing_types or types is None)
return {
k: v for k, v in state_dict.items()
if include(k[0], k[1])
}, missing_types, got_all
def _get_all_state_from_cache(self, group):
"""Checks if group is in cache. See `_get_state_for_groups`
Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
indicating if we successfully retrieved all requests state from the
cache, if False we need to query the DB for the missing state.
Args:
group: The state group to lookup
"""
is_all, state_dict = self._state_group_cache.get(group)
return state_dict, is_all
@defer.inlineCallbacks
def _get_state_for_groups(self, groups, types=None):
"""Given list of groups returns dict of group -> list of state events
with matching types. `types` is a list of `(type, state_key)`, where
a `state_key` of None matches all state_keys. If `types` is None then
all events are returned.
"""
results = {}
missing_groups_and_types = []
if types is not None:
for group in set(groups):
state_dict, missing_types, got_all = self._get_some_state_from_cache(
group, types
)
results[group] = state_dict
if not got_all:
missing_groups_and_types.append((group, missing_types))
else:
for group in set(groups):
state_dict, got_all = self._get_all_state_from_cache(
group
)
results[group] = state_dict
if not got_all:
missing_groups_and_types.append((group, None))
if not missing_groups_and_types:
defer.returnValue({
group: {
type_tuple: event
for type_tuple, event in state.items()
if event
}
for group, state in results.items()
})
# Okay, so we have some missing_types, lets fetch them.
cache_seq_num = self._state_group_cache.sequence
group_state_dict = yield self._get_state_groups_from_groups(
missing_groups_and_types
)
state_events = yield self._get_events(
[e_id for l in group_state_dict.values() for e_id in l],
get_prev_content=False
)
state_events = {e.event_id: e for e in state_events}
# Now we want to update the cache with all the things we fetched
# from the database.
for group, state_ids in group_state_dict.items():
if types:
# We delibrately put key -> None mappings into the cache to
# cache absence of the key, on the assumption that if we've
# explicitly asked for some types then we will probably ask
# for them again.
state_dict = {key: None for key in types}
state_dict.update(results[group])
results[group] = state_dict
else:
state_dict = results[group]
for event_id in state_ids:
try:
state_event = state_events[event_id]
state_dict[(state_event.type, state_event.state_key)] = state_event
except KeyError:
# Hmm. So we do don't have that state event? Interesting.
logger.warn(
"Can't find state event %r for state group %r",
event_id, group,
)
self._state_group_cache.update(
cache_seq_num,
key=group,
value=state_dict,
full=(types is None),
)
# Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache.
for group, state_dict in results.items():
results[group] = {
key: event for key, event in state_dict.items() if event
}
defer.returnValue(results)
def _make_group_id(clock): def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5) return str(int(clock.time_msec())) + random_string(5)

View File

@ -36,6 +36,7 @@ what sort order was used:
from twisted.internet import defer from twisted.internet import defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
@ -299,9 +300,8 @@ class StreamStore(SQLBaseStore):
defer.returnValue((events, token)) defer.returnValue((events, token))
@defer.inlineCallbacks @cachedInlineCallbacks(num_args=4)
def get_recent_events_for_room(self, room_id, limit, end_token, def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
with_feedback=False, from_token=None):
# TODO (erikj): Handle compressed feedback # TODO (erikj): Handle compressed feedback
end_token = RoomStreamToken.parse_stream_token(end_token) end_token = RoomStreamToken.parse_stream_token(end_token)

View File

@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, cached from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from collections import namedtuple from collections import namedtuple

View File

@ -72,7 +72,10 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next_txn(txn) as stream_id: with stream_id_gen.get_next_txn(txn) as stream_id:
# ... persist event ... # ... persist event ...
""" """
def __init__(self): def __init__(self, table, column):
self.table = table
self.column = column
self._lock = threading.Lock() self._lock = threading.Lock()
self._current_max = None self._current_max = None
@ -107,6 +110,37 @@ class StreamIdGenerator(object):
defer.returnValue(manager()) defer.returnValue(manager())
@defer.inlineCallbacks
def get_next_mult(self, store, n):
"""
Usage:
with yield stream_id_gen.get_next(store, n) as stream_ids:
# ... persist events ...
"""
if not self._current_max:
yield store.runInteraction(
"_compute_current_max",
self._get_or_compute_current_max,
)
with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n
for next_id in next_ids:
self._unfinished_ids.append(next_id)
@contextlib.contextmanager
def manager():
try:
yield next_ids
finally:
with self._lock:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
defer.returnValue(manager())
@defer.inlineCallbacks @defer.inlineCallbacks
def get_max_token(self, store): def get_max_token(self, store):
"""Returns the maximum stream id such that all stream ids less than or """Returns the maximum stream id such that all stream ids less than or
@ -126,7 +160,7 @@ class StreamIdGenerator(object):
def _get_or_compute_current_max(self, txn): def _get_or_compute_current_max(self, txn):
with self._lock: with self._lock:
txn.execute("SELECT MAX(stream_ordering) FROM events") txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
rows = txn.fetchall() rows = txn.fetchall()
val, = rows[0] val, = rows[0]

View File

@ -20,6 +20,7 @@ from synapse.types import StreamToken
from synapse.handlers.presence import PresenceEventSource from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.room import RoomEventSource from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource
class NullSource(object): class NullSource(object):
@ -43,6 +44,7 @@ class EventSources(object):
"room": RoomEventSource, "room": RoomEventSource,
"presence": PresenceEventSource, "presence": PresenceEventSource,
"typing": TypingNotificationEventSource, "typing": TypingNotificationEventSource,
"receipt": ReceiptEventSource,
} }
def __init__(self, hs): def __init__(self, hs):
@ -62,7 +64,10 @@ class EventSources(object):
), ),
typing_key=( typing_key=(
yield self.sources["typing"].get_current_key() yield self.sources["typing"].get_current_key()
) ),
receipt_key=(
yield self.sources["receipt"].get_current_key()
),
) )
defer.returnValue(token) defer.returnValue(token)

View File

@ -100,7 +100,7 @@ class EventID(DomainSpecificString):
class StreamToken( class StreamToken(
namedtuple( namedtuple(
"Token", "Token",
("room_key", "presence_key", "typing_key") ("room_key", "presence_key", "typing_key", "receipt_key")
) )
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -109,6 +109,9 @@ class StreamToken(
def from_string(cls, string): def from_string(cls, string):
try: try:
keys = string.split(cls._SEPARATOR) keys = string.split(cls._SEPARATOR)
if len(keys) == len(cls._fields) - 1:
# i.e. old token from before receipt_key
keys.append("0")
return cls(*keys) return cls(*keys)
except: except:
raise SynapseError(400, "Invalid Token") raise SynapseError(400, "Invalid Token")
@ -131,6 +134,7 @@ class StreamToken(
(other_token.room_stream_id < self.room_stream_id) (other_token.room_stream_id < self.room_stream_id)
or (int(other_token.presence_key) < int(self.presence_key)) or (int(other_token.presence_key) < int(self.presence_key))
or (int(other_token.typing_key) < int(self.typing_key)) or (int(other_token.typing_key) < int(self.typing_key))
or (int(other_token.receipt_key) < int(self.receipt_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):
@ -174,7 +178,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
Live tokens start with an "s" followed by the "stream_ordering" id of the Live tokens start with an "s" followed by the "stream_ordering" id of the
event it comes after. Historic tokens start with a "t" followed by the event it comes after. Historic tokens start with a "t" followed by the
"topological_ordering" id of the event it comes after, follewed by "-", "topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after. followed by the "stream_ordering" id of the event it comes after.
""" """
__slots__ = [] __slots__ = []
@ -207,4 +211,5 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
# token_id is the primary key ID of the access token, not the access token itself.
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))

View File

@ -51,7 +51,7 @@ class ObservableDeferred(object):
object.__setattr__(self, "_observers", set()) object.__setattr__(self, "_observers", set())
def callback(r): def callback(r):
self._result = (True, r) object.__setattr__(self, "_result", (True, r))
while self._observers: while self._observers:
try: try:
self._observers.pop().callback(r) self._observers.pop().callback(r)
@ -60,7 +60,7 @@ class ObservableDeferred(object):
return r return r
def errback(f): def errback(f):
self._result = (False, f) object.__setattr__(self, "_result", (False, f))
while self._observers: while self._observers:
try: try:
self._observers.pop().errback(f) self._observers.pop().errback(f)
@ -97,3 +97,8 @@ class ObservableDeferred(object):
def __setattr__(self, name, value): def __setattr__(self, name, value):
setattr(self._deferred, name, value) setattr(self._deferred, name, value)
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self), self._result, self._deferred,
)

View File

@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.metrics
DEBUG_CACHES = False
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
caches_by_name = {}
cache_counter = metrics.register_cache(
"cache",
lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
labels=["name"],
)

View File

@ -0,0 +1,377 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from . import caches_by_name, DEBUG_CACHES, cache_counter
from twisted.internet import defer
from collections import OrderedDict
import functools
import inspect
import threading
logger = logging.getLogger(__name__)
_CacheSentinel = object()
class Cache(object):
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
if lru:
self.cache = LruCache(max_size=max_entries)
self.max_entries = None
else:
self.cache = OrderedDict()
self.max_entries = max_entries
self.name = name
self.keylen = keylen
self.sequence = 0
self.thread = None
caches_by_name[name] = self.cache
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, key, default=_CacheSentinel):
val = self.cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
cache_counter.inc_hits(self.name)
return val
cache_counter.inc_misses(self.name)
if default is _CacheSentinel:
raise KeyError()
else:
return default
def update(self, sequence, key, value):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value)
def prefill(self, key, value):
if self.max_entries is not None:
while len(self.cache) >= self.max_entries:
self.cache.popitem(last=False)
self.cache[key] = value
def invalidate(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
class CacheDescriptor(object):
""" A method decorator that applies a memoizing cache around the function.
This caches deferreds, rather than the results themselves. Deferreds that
fail are removed from the cache.
The function is presumed to take zero or more arguments, which are used in
a tuple as the key for the cache. Hits are served directly from the cache;
misses use the function body to generate the value.
The wrapped function has an additional member, a callable called
"invalidate". This can be used to remove individual entries from the cache.
The wrapped function has another additional callable, called "prefill",
which can be used to insert values into the cache specifically, without
calling the calculation function.
"""
def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
inlineCallbacks=False):
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.max_entries = max_entries
self.num_args = num_args
self.lru = lru
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
self.cache = Cache(
name=self.orig.__name__,
max_entries=self.max_entries,
keylen=self.num_args,
lru=self.lru,
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
try:
cached_result_d = self.cache.get(cache_key)
observer = cached_result_d.observe()
if DEBUG_CACHES:
@defer.inlineCallbacks
def check_result(cached_result):
actual_result = yield self.function_to_call(obj, *args, **kwargs)
if actual_result != cached_result:
logger.error(
"Stale cache entry %s%r: cached: %r, actual %r",
self.orig.__name__, cache_key,
cached_result, actual_result,
)
raise ValueError("Stale cache entry")
defer.returnValue(cached_result)
observer.addCallback(check_result)
return observer
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = self.cache.sequence
ret = defer.maybeDeferred(
self.function_to_call,
obj, *args, **kwargs
)
def onErr(f):
self.cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
self.cache.update(sequence, cache_key, ret)
return ret.observe()
wrapped.invalidate = self.cache.invalidate
wrapped.invalidate_all = self.cache.invalidate_all
wrapped.prefill = self.cache.prefill
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
class CacheListDescriptor(object):
"""Wraps an existing cache to support bulk fetching of keys.
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped fucntion.
"""
def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
"""
Args:
orig (function)
cache (Cache)
list_name (str): Name of the argument which is the bulk lookup list
num_args (int)
inlineCallbacks (bool): Whether orig is a generator that should
be wrapped by defer.inlineCallbacks
"""
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
self.num_args = num_args
self.list_name = list_name
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
self.list_pos = self.arg_names.index(self.list_name)
self.cache = cache
self.sentinel = object()
if len(self.arg_names) < self.num_args:
raise Exception(
"Not enough explicit positional arguments to key off of for %r."
" (@cached cannot key off of *args or **kwars)"
% (orig.__name__,)
)
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
% (self.list_name, cache.name,)
)
def __get__(self, obj, objtype=None):
@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]
# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
cached = {}
missing = []
for arg in list_args:
key = list(keyargs)
key[self.list_pos] = arg
try:
res = self.cache.get(tuple(key)).observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
except KeyError:
missing.append(arg)
if missing:
sequence = self.cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
ret_d = defer.maybeDeferred(
self.function_to_call,
**args_to_call
)
ret_d = ObservableDeferred(ret_d)
# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
observer = ObservableDeferred(observer)
key = list(keyargs)
key[self.list_pos] = arg
self.cache.update(sequence, tuple(key), observer)
def invalidate(f, key):
self.cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))
res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached[arg] = res
return defer.gatherResults(
cached.values(),
consumeErrors=True,
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
def cached(max_entries=1000, num_args=1, lru=True):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
lru=lru,
inlineCallbacks=True,
)
def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
is specified as a list that is iterated through to lookup keys in the
original cache. A new list consisting of the keys that weren't in the cache
get passed to the original function, the result of which is stored in the
cache.
Args:
cache (Cache): The underlying cache to use.
list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache.
inlineCallbacks (bool): Should the function be wrapped in an
`defer.inlineCallbacks`?
Example:
class Example(object):
@cached(num_args=2)
def do_something(self, first_arg):
...
@cachedList(do_something.cache, list_name="second_args", num_args=2)
def batch_do_something(self, first_arg, second_args):
...
"""
return lambda orig: CacheListDescriptor(
orig,
cache=cache,
list_name=list_name,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
)

View File

@ -0,0 +1,103 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.caches.lrucache import LruCache
from collections import namedtuple
from . import caches_by_name, cache_counter
import threading
import logging
logger = logging.getLogger(__name__)
DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
class DictionaryCache(object):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries)
self.name = name
self.sequence = 0
self.thread = None
# caches_by_name[name] = self.cache
class Sentinel(object):
__slots__ = []
self.sentinel = Sentinel()
caches_by_name[name] = self.cache
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
self.thread = threading.current_thread()
else:
if expected_thread is not threading.current_thread():
raise ValueError(
"Cache objects can only be accessed from the main thread"
)
def get(self, key, dict_keys=None):
entry = self.cache.get(key, self.sentinel)
if entry is not self.sentinel:
cache_counter.inc_hits(self.name)
if dict_keys is None:
return DictionaryEntry(entry.full, dict(entry.value))
else:
return DictionaryEntry(entry.full, {
k: entry.value[k]
for k in dict_keys
if k in entry.value
})
cache_counter.inc_misses(self.name)
return DictionaryEntry(False, {})
def invalidate(self, key):
self.check_thread()
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(key, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
def update(self, sequence, key, value, full=False):
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
if full:
self._insert(key, value)
else:
self._update_or_insert(key, value)
def _update_or_insert(self, key, value):
entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
entry.value.update(value)
def _insert(self, key, value):
self.cache[key] = DictionaryEntry(True, value)

Some files were not shown because too many files have changed in this diff Show More