mirror of
https://git.anonymousland.org/anonymousland/synapse-product.git
synced 2024-12-12 08:24:19 -05:00
Merge branch 'release-v0.10.0'
This commit is contained in:
commit
efeeff29f6
@ -42,3 +42,6 @@ Ivan Shapovalov <intelfx100 at gmail.com>
|
|||||||
Eric Myhre <hash at exultant.us>
|
Eric Myhre <hash at exultant.us>
|
||||||
* Fix bug where ``media_store_path`` config option was ignored by v0 content
|
* Fix bug where ``media_store_path`` config option was ignored by v0 content
|
||||||
repository API.
|
repository API.
|
||||||
|
|
||||||
|
Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
|
||||||
|
* Add SAML2 support for registration and logins.
|
||||||
|
127
CHANGES.rst
127
CHANGES.rst
@ -1,3 +1,130 @@
|
|||||||
|
Changes in synapse v0.10.0 (2015-09-03)
|
||||||
|
=======================================
|
||||||
|
|
||||||
|
No change from release candidate.
|
||||||
|
|
||||||
|
Changes in synapse v0.10.0-rc6 (2015-09-02)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
* Remove some of the old database upgrade scripts.
|
||||||
|
* Fix database port script to work with newly created sqlite databases.
|
||||||
|
|
||||||
|
Changes in synapse v0.10.0-rc5 (2015-08-27)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
* Fix bug that broke downloading files with ascii filenames across federation.
|
||||||
|
|
||||||
|
Changes in synapse v0.10.0-rc4 (2015-08-27)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
* Allow UTF-8 filenames for upload. (PR #259)
|
||||||
|
|
||||||
|
Changes in synapse v0.10.0-rc3 (2015-08-25)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
* Add ``--keys-directory`` config option to specify where files such as
|
||||||
|
certs and signing keys should be stored in, when using ``--generate-config``
|
||||||
|
or ``--generate-keys``. (PR #250)
|
||||||
|
* Allow ``--config-path`` to specify a directory, causing synapse to use all
|
||||||
|
\*.yaml files in the directory as config files. (PR #249)
|
||||||
|
* Add ``web_client_location`` config option to specify static files to be
|
||||||
|
hosted by synapse under ``/_matrix/client``. (PR #245)
|
||||||
|
* Add helper utility to synapse to read and parse the config files and extract
|
||||||
|
the value of a given key. For example::
|
||||||
|
|
||||||
|
$ python -m synapse.config read server_name -c homeserver.yaml
|
||||||
|
localhost
|
||||||
|
|
||||||
|
(PR #246)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.10.0-rc2 (2015-08-24)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
* Fix bug where we incorrectly populated the ``event_forward_extremities``
|
||||||
|
table, resulting in problems joining large remote rooms (e.g.
|
||||||
|
``#matrix:matrix.org``)
|
||||||
|
* Reduce the number of times we wake up pushers by not listening for presence
|
||||||
|
or typing events, reducing the CPU cost of each pusher.
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.10.0-rc1 (2015-08-21)
|
||||||
|
===========================================
|
||||||
|
|
||||||
|
Also see v0.9.4-rc1 changelog, which has been amalgamated into this release.
|
||||||
|
|
||||||
|
General:
|
||||||
|
|
||||||
|
* Upgrade to Twisted 15 (PR #173)
|
||||||
|
* Add support for serving and fetching encryption keys over federation.
|
||||||
|
(PR #208)
|
||||||
|
* Add support for logging in with email address (PR #234)
|
||||||
|
* Add support for new ``m.room.canonical_alias`` event. (PR #233)
|
||||||
|
* Change synapse to treat user IDs case insensitively during registration and
|
||||||
|
login. (If two users already exist with case insensitive matching user ids,
|
||||||
|
synapse will continue to require them to specify their user ids exactly.)
|
||||||
|
* Error if a user tries to register with an email already in use. (PR #211)
|
||||||
|
* Add extra and improve existing caches (PR #212, #219, #226, #228)
|
||||||
|
* Batch various storage request (PR #226, #228)
|
||||||
|
* Fix bug where we didn't correctly log the entity that triggered the request
|
||||||
|
if the request came in via an application service (PR #230)
|
||||||
|
* Fix bug where we needlessly regenerated the full list of rooms an AS is
|
||||||
|
interested in. (PR #232)
|
||||||
|
* Add support for AS's to use v2_alpha registration API (PR #210)
|
||||||
|
|
||||||
|
|
||||||
|
Configuration:
|
||||||
|
|
||||||
|
* Add ``--generate-keys`` that will generate any missing cert and key files in
|
||||||
|
the configuration files. This is equivalent to running ``--generate-config``
|
||||||
|
on an existing configuration file. (PR #220)
|
||||||
|
* ``--generate-config`` now no longer requires a ``--server-name`` parameter
|
||||||
|
when used on existing configuration files. (PR #220)
|
||||||
|
* Add ``--print-pidfile`` flag that controls the printing of the pid to stdout
|
||||||
|
of the demonised process. (PR #213)
|
||||||
|
|
||||||
|
Media Repository:
|
||||||
|
|
||||||
|
* Fix bug where we picked a lower resolution image than requested. (PR #205)
|
||||||
|
* Add support for specifying if a the media repository should dynamically
|
||||||
|
thumbnail images or not. (PR #206)
|
||||||
|
|
||||||
|
Metrics:
|
||||||
|
|
||||||
|
* Add statistics from the reactor to the metrics API. (PR #224, #225)
|
||||||
|
|
||||||
|
Demo Homeservers:
|
||||||
|
|
||||||
|
* Fix starting the demo homeservers without rate-limiting enabled. (PR #182)
|
||||||
|
* Fix enabling registration on demo homeservers (PR #223)
|
||||||
|
|
||||||
|
|
||||||
|
Changes in synapse v0.9.4-rc1 (2015-07-21)
|
||||||
|
==========================================
|
||||||
|
|
||||||
|
General:
|
||||||
|
|
||||||
|
* Add basic implementation of receipts. (SPEC-99)
|
||||||
|
* Add support for configuration presets in room creation API. (PR #203)
|
||||||
|
* Add auth event that limits the visibility of history for new users.
|
||||||
|
(SPEC-134)
|
||||||
|
* Add SAML2 login/registration support. (PR #201. Thanks Muthu Subramanian!)
|
||||||
|
* Add client side key management APIs for end to end encryption. (PR #198)
|
||||||
|
* Change power level semantics so that you cannot kick, ban or change power
|
||||||
|
levels of users that have equal or greater power level than you. (SYN-192)
|
||||||
|
* Improve performance by bulk inserting events where possible. (PR #193)
|
||||||
|
* Improve performance by bulk verifying signatures where possible. (PR #194)
|
||||||
|
|
||||||
|
|
||||||
|
Configuration:
|
||||||
|
|
||||||
|
* Add support for including TLS certificate chains.
|
||||||
|
|
||||||
|
Media Repository:
|
||||||
|
|
||||||
|
* Add Content-Disposition headers to content repository responses. (SYN-150)
|
||||||
|
|
||||||
|
|
||||||
Changes in synapse v0.9.3 (2015-07-01)
|
Changes in synapse v0.9.3 (2015-07-01)
|
||||||
======================================
|
======================================
|
||||||
|
|
||||||
|
140
README.rst
140
README.rst
@ -94,6 +94,7 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
|
|||||||
System requirements:
|
System requirements:
|
||||||
- POSIX-compliant system (tested on Linux & OS X)
|
- POSIX-compliant system (tested on Linux & OS X)
|
||||||
- Python 2.7
|
- Python 2.7
|
||||||
|
- At least 512 MB RAM.
|
||||||
|
|
||||||
Synapse is written in python but some of the libraries is uses are written in
|
Synapse is written in python but some of the libraries is uses are written in
|
||||||
C. So before we can install synapse itself we need a working C compiler and the
|
C. So before we can install synapse itself we need a working C compiler and the
|
||||||
@ -101,25 +102,26 @@ header files for python C extensions.
|
|||||||
|
|
||||||
Installing prerequisites on Ubuntu or Debian::
|
Installing prerequisites on Ubuntu or Debian::
|
||||||
|
|
||||||
$ sudo apt-get install build-essential python2.7-dev libffi-dev \
|
sudo apt-get install build-essential python2.7-dev libffi-dev \
|
||||||
python-pip python-setuptools sqlite3 \
|
python-pip python-setuptools sqlite3 \
|
||||||
libssl-dev python-virtualenv libjpeg-dev
|
libssl-dev python-virtualenv libjpeg-dev
|
||||||
|
|
||||||
Installing prerequisites on ArchLinux::
|
Installing prerequisites on ArchLinux::
|
||||||
|
|
||||||
$ sudo pacman -S base-devel python2 python-pip \
|
sudo pacman -S base-devel python2 python-pip \
|
||||||
python-setuptools python-virtualenv sqlite3
|
python-setuptools python-virtualenv sqlite3
|
||||||
|
|
||||||
Installing prerequisites on Mac OS X::
|
Installing prerequisites on Mac OS X::
|
||||||
|
|
||||||
$ xcode-select --install
|
xcode-select --install
|
||||||
$ sudo pip install virtualenv
|
sudo easy_install pip
|
||||||
|
sudo pip install virtualenv
|
||||||
|
|
||||||
To install the synapse homeserver run::
|
To install the synapse homeserver run::
|
||||||
|
|
||||||
$ virtualenv -p python2.7 ~/.synapse
|
virtualenv -p python2.7 ~/.synapse
|
||||||
$ source ~/.synapse/bin/activate
|
source ~/.synapse/bin/activate
|
||||||
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
|
pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
|
||||||
|
|
||||||
This installs synapse, along with the libraries it uses, into a virtual
|
This installs synapse, along with the libraries it uses, into a virtual
|
||||||
environment under ``~/.synapse``. Feel free to pick a different directory
|
environment under ``~/.synapse``. Feel free to pick a different directory
|
||||||
@ -132,8 +134,8 @@ above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
|
|||||||
|
|
||||||
To set up your homeserver, run (in your virtualenv, as before)::
|
To set up your homeserver, run (in your virtualenv, as before)::
|
||||||
|
|
||||||
$ cd ~/.synapse
|
cd ~/.synapse
|
||||||
$ python -m synapse.app.homeserver \
|
python -m synapse.app.homeserver \
|
||||||
--server-name machine.my.domain.name \
|
--server-name machine.my.domain.name \
|
||||||
--config-path homeserver.yaml \
|
--config-path homeserver.yaml \
|
||||||
--generate-config
|
--generate-config
|
||||||
@ -173,12 +175,12 @@ traditionally used for convenience and simplicity.
|
|||||||
|
|
||||||
The advantages of Postgres include:
|
The advantages of Postgres include:
|
||||||
|
|
||||||
* significant performance improvements due to the superior threading and
|
* significant performance improvements due to the superior threading and
|
||||||
caching model, smarter query optimiser
|
caching model, smarter query optimiser
|
||||||
* allowing the DB to be run on separate hardware
|
* allowing the DB to be run on separate hardware
|
||||||
* allowing basic active/backup high-availability with a "hot spare" synapse
|
* allowing basic active/backup high-availability with a "hot spare" synapse
|
||||||
pointing at the same DB master, as well as enabling DB replication in
|
pointing at the same DB master, as well as enabling DB replication in
|
||||||
synapse itself.
|
synapse itself.
|
||||||
|
|
||||||
The only disadvantage is that the code is relatively new as of April 2015 and
|
The only disadvantage is that the code is relatively new as of April 2015 and
|
||||||
may have a few regressions relative to SQLite.
|
may have a few regressions relative to SQLite.
|
||||||
@ -189,12 +191,12 @@ For information on how to install and use PostgreSQL, please see
|
|||||||
Running Synapse
|
Running Synapse
|
||||||
===============
|
===============
|
||||||
|
|
||||||
To actually run your new homeserver, pick a working directory for Synapse to run
|
To actually run your new homeserver, pick a working directory for Synapse to
|
||||||
(e.g. ``~/.synapse``), and::
|
run (e.g. ``~/.synapse``), and::
|
||||||
|
|
||||||
$ cd ~/.synapse
|
cd ~/.synapse
|
||||||
$ source ./bin/activate
|
source ./bin/activate
|
||||||
$ synctl start
|
synctl start
|
||||||
|
|
||||||
Platform Specific Instructions
|
Platform Specific Instructions
|
||||||
==============================
|
==============================
|
||||||
@ -212,12 +214,12 @@ defaults to python 3, but synapse currently assumes python 2.7 by default:
|
|||||||
|
|
||||||
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
|
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
|
||||||
|
|
||||||
$ sudo pip2.7 install --upgrade pip
|
sudo pip2.7 install --upgrade pip
|
||||||
|
|
||||||
You also may need to explicitly specify python 2.7 again during the install
|
You also may need to explicitly specify python 2.7 again during the install
|
||||||
request::
|
request::
|
||||||
|
|
||||||
$ pip2.7 install --process-dependency-links \
|
pip2.7 install --process-dependency-links \
|
||||||
https://github.com/matrix-org/synapse/tarball/master
|
https://github.com/matrix-org/synapse/tarball/master
|
||||||
|
|
||||||
If you encounter an error with lib bcrypt causing an Wrong ELF Class:
|
If you encounter an error with lib bcrypt causing an Wrong ELF Class:
|
||||||
@ -225,13 +227,13 @@ ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
|
|||||||
compile it under the right architecture. (This should not be needed if
|
compile it under the right architecture. (This should not be needed if
|
||||||
installing under virtualenv)::
|
installing under virtualenv)::
|
||||||
|
|
||||||
$ sudo pip2.7 uninstall py-bcrypt
|
sudo pip2.7 uninstall py-bcrypt
|
||||||
$ sudo pip2.7 install py-bcrypt
|
sudo pip2.7 install py-bcrypt
|
||||||
|
|
||||||
During setup of Synapse you need to call python2.7 directly again::
|
During setup of Synapse you need to call python2.7 directly again::
|
||||||
|
|
||||||
$ cd ~/.synapse
|
cd ~/.synapse
|
||||||
$ python2.7 -m synapse.app.homeserver \
|
python2.7 -m synapse.app.homeserver \
|
||||||
--server-name machine.my.domain.name \
|
--server-name machine.my.domain.name \
|
||||||
--config-path homeserver.yaml \
|
--config-path homeserver.yaml \
|
||||||
--generate-config
|
--generate-config
|
||||||
@ -242,18 +244,20 @@ Windows Install
|
|||||||
---------------
|
---------------
|
||||||
Synapse can be installed on Cygwin. It requires the following Cygwin packages:
|
Synapse can be installed on Cygwin. It requires the following Cygwin packages:
|
||||||
|
|
||||||
- gcc
|
- gcc
|
||||||
- git
|
- git
|
||||||
- libffi-devel
|
- libffi-devel
|
||||||
- openssl (and openssl-devel, python-openssl)
|
- openssl (and openssl-devel, python-openssl)
|
||||||
- python
|
- python
|
||||||
- python-setuptools
|
- python-setuptools
|
||||||
|
|
||||||
The content repository requires additional packages and will be unable to process
|
The content repository requires additional packages and will be unable to process
|
||||||
uploads without them:
|
uploads without them:
|
||||||
- libjpeg8
|
|
||||||
- libjpeg8-devel
|
- libjpeg8
|
||||||
- zlib
|
- libjpeg8-devel
|
||||||
|
- zlib
|
||||||
|
|
||||||
If you choose to install Synapse without these packages, you will need to reinstall
|
If you choose to install Synapse without these packages, you will need to reinstall
|
||||||
``pillow`` for changes to be applied, e.g. ``pip uninstall pillow`` ``pip install
|
``pillow`` for changes to be applied, e.g. ``pip uninstall pillow`` ``pip install
|
||||||
pillow --user``
|
pillow --user``
|
||||||
@ -279,22 +283,22 @@ Synapse requires pip 1.7 or later, so if your OS provides too old a version and
|
|||||||
you get errors about ``error: no such option: --process-dependency-links`` you
|
you get errors about ``error: no such option: --process-dependency-links`` you
|
||||||
may need to manually upgrade it::
|
may need to manually upgrade it::
|
||||||
|
|
||||||
$ sudo pip install --upgrade pip
|
sudo pip install --upgrade pip
|
||||||
|
|
||||||
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
|
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
|
||||||
refuse to run until you remove the temporary installation directory it
|
refuse to run until you remove the temporary installation directory it
|
||||||
created. To reset the installation::
|
created. To reset the installation::
|
||||||
|
|
||||||
$ rm -rf /tmp/pip_install_matrix
|
rm -rf /tmp/pip_install_matrix
|
||||||
|
|
||||||
pip seems to leak *lots* of memory during installation. For instance, a Linux
|
pip seems to leak *lots* of memory during installation. For instance, a Linux
|
||||||
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
|
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
|
||||||
happens, you will have to individually install the dependencies which are
|
happens, you will have to individually install the dependencies which are
|
||||||
failing, e.g.::
|
failing, e.g.::
|
||||||
|
|
||||||
$ pip install twisted
|
pip install twisted
|
||||||
|
|
||||||
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
|
On OS X, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
|
||||||
will need to export CFLAGS=-Qunused-arguments.
|
will need to export CFLAGS=-Qunused-arguments.
|
||||||
|
|
||||||
Troubleshooting Running
|
Troubleshooting Running
|
||||||
@ -310,10 +314,11 @@ correctly, causing all tests to fail with errors about missing "sodium.h". To
|
|||||||
fix try re-installing from PyPI or directly from
|
fix try re-installing from PyPI or directly from
|
||||||
(https://github.com/pyca/pynacl)::
|
(https://github.com/pyca/pynacl)::
|
||||||
|
|
||||||
$ # Install from PyPI
|
# Install from PyPI
|
||||||
$ pip install --user --upgrade --force pynacl
|
pip install --user --upgrade --force pynacl
|
||||||
$ # Install from github
|
|
||||||
$ pip install --user https://github.com/pyca/pynacl/tarball/master
|
# Install from github
|
||||||
|
pip install --user https://github.com/pyca/pynacl/tarball/master
|
||||||
|
|
||||||
ArchLinux
|
ArchLinux
|
||||||
~~~~~~~~~
|
~~~~~~~~~
|
||||||
@ -321,7 +326,7 @@ ArchLinux
|
|||||||
If running `$ synctl start` fails with 'returned non-zero exit status 1',
|
If running `$ synctl start` fails with 'returned non-zero exit status 1',
|
||||||
you will need to explicitly call Python2.7 - either running as::
|
you will need to explicitly call Python2.7 - either running as::
|
||||||
|
|
||||||
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
|
python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
|
||||||
|
|
||||||
...or by editing synctl with the correct python executable.
|
...or by editing synctl with the correct python executable.
|
||||||
|
|
||||||
@ -331,16 +336,16 @@ Synapse Development
|
|||||||
To check out a synapse for development, clone the git repo into a working
|
To check out a synapse for development, clone the git repo into a working
|
||||||
directory of your choice::
|
directory of your choice::
|
||||||
|
|
||||||
$ git clone https://github.com/matrix-org/synapse.git
|
git clone https://github.com/matrix-org/synapse.git
|
||||||
$ cd synapse
|
cd synapse
|
||||||
|
|
||||||
Synapse has a number of external dependencies, that are easiest
|
Synapse has a number of external dependencies, that are easiest
|
||||||
to install using pip and a virtualenv::
|
to install using pip and a virtualenv::
|
||||||
|
|
||||||
$ virtualenv env
|
virtualenv env
|
||||||
$ source env/bin/activate
|
source env/bin/activate
|
||||||
$ python synapse/python_dependencies.py | xargs -n1 pip install
|
python synapse/python_dependencies.py | xargs -n1 pip install
|
||||||
$ pip install setuptools_trial mock
|
pip install setuptools_trial mock
|
||||||
|
|
||||||
This will run a process of downloading and installing all the needed
|
This will run a process of downloading and installing all the needed
|
||||||
dependencies into a virtual env.
|
dependencies into a virtual env.
|
||||||
@ -348,7 +353,7 @@ dependencies into a virtual env.
|
|||||||
Once this is done, you may wish to run Synapse's unit tests, to
|
Once this is done, you may wish to run Synapse's unit tests, to
|
||||||
check that everything is installed as it should be::
|
check that everything is installed as it should be::
|
||||||
|
|
||||||
$ python setup.py test
|
python setup.py test
|
||||||
|
|
||||||
This should end with a 'PASSED' result::
|
This should end with a 'PASSED' result::
|
||||||
|
|
||||||
@ -360,14 +365,11 @@ This should end with a 'PASSED' result::
|
|||||||
Upgrading an existing Synapse
|
Upgrading an existing Synapse
|
||||||
=============================
|
=============================
|
||||||
|
|
||||||
IMPORTANT: Before upgrading an existing synapse to a new version, please
|
The instructions for upgrading synapse are in `UPGRADE.rst`_.
|
||||||
refer to UPGRADE.rst for any additional instructions.
|
Please check these instructions as upgrading may require extra steps for some
|
||||||
|
versions of synapse.
|
||||||
Otherwise, simply re-install the new codebase over the current one - e.g.
|
|
||||||
by ``pip install --process-dependency-links
|
|
||||||
https://github.com/matrix-org/synapse/tarball/master``
|
|
||||||
if using pip, or by ``git pull`` if running off a git working copy.
|
|
||||||
|
|
||||||
|
.. _UPGRADE.rst: UPGRADE.rst
|
||||||
|
|
||||||
Setting up Federation
|
Setting up Federation
|
||||||
=====================
|
=====================
|
||||||
@ -389,11 +391,11 @@ IDs:
|
|||||||
For the first form, simply pass the required hostname (of the machine) as the
|
For the first form, simply pass the required hostname (of the machine) as the
|
||||||
--server-name parameter::
|
--server-name parameter::
|
||||||
|
|
||||||
$ python -m synapse.app.homeserver \
|
python -m synapse.app.homeserver \
|
||||||
--server-name machine.my.domain.name \
|
--server-name machine.my.domain.name \
|
||||||
--config-path homeserver.yaml \
|
--config-path homeserver.yaml \
|
||||||
--generate-config
|
--generate-config
|
||||||
$ python -m synapse.app.homeserver --config-path homeserver.yaml
|
python -m synapse.app.homeserver --config-path homeserver.yaml
|
||||||
|
|
||||||
Alternatively, you can run ``synctl start`` to guide you through the process.
|
Alternatively, you can run ``synctl start`` to guide you through the process.
|
||||||
|
|
||||||
@ -410,11 +412,11 @@ record would then look something like::
|
|||||||
At this point, you should then run the homeserver with the hostname of this
|
At this point, you should then run the homeserver with the hostname of this
|
||||||
SRV record, as that is the name other machines will expect it to have::
|
SRV record, as that is the name other machines will expect it to have::
|
||||||
|
|
||||||
$ python -m synapse.app.homeserver \
|
python -m synapse.app.homeserver \
|
||||||
--server-name YOURDOMAIN \
|
--server-name YOURDOMAIN \
|
||||||
--config-path homeserver.yaml \
|
--config-path homeserver.yaml \
|
||||||
--generate-config
|
--generate-config
|
||||||
$ python -m synapse.app.homeserver --config-path homeserver.yaml
|
python -m synapse.app.homeserver --config-path homeserver.yaml
|
||||||
|
|
||||||
|
|
||||||
You may additionally want to pass one or more "-v" options, in order to
|
You may additionally want to pass one or more "-v" options, in order to
|
||||||
@ -428,7 +430,7 @@ private federation (``localhost:8080``, ``localhost:8081`` and
|
|||||||
``localhost:8082``) which you can then access through the webclient running at
|
``localhost:8082``) which you can then access through the webclient running at
|
||||||
http://localhost:8080. Simply run::
|
http://localhost:8080. Simply run::
|
||||||
|
|
||||||
$ demo/start.sh
|
demo/start.sh
|
||||||
|
|
||||||
This is mainly useful just for development purposes.
|
This is mainly useful just for development purposes.
|
||||||
|
|
||||||
@ -502,10 +504,10 @@ Building Internal API Documentation
|
|||||||
Before building internal API documentation install sphinx and
|
Before building internal API documentation install sphinx and
|
||||||
sphinxcontrib-napoleon::
|
sphinxcontrib-napoleon::
|
||||||
|
|
||||||
$ pip install sphinx
|
pip install sphinx
|
||||||
$ pip install sphinxcontrib-napoleon
|
pip install sphinxcontrib-napoleon
|
||||||
|
|
||||||
Building internal API documentation::
|
Building internal API documentation::
|
||||||
|
|
||||||
$ python setup.py build_sphinx
|
python setup.py build_sphinx
|
||||||
|
|
||||||
|
33
UPGRADE.rst
33
UPGRADE.rst
@ -1,3 +1,36 @@
|
|||||||
|
Upgrading Synapse
|
||||||
|
=================
|
||||||
|
|
||||||
|
Before upgrading check if any special steps are required to upgrade from the
|
||||||
|
what you currently have installed to current version of synapse. The extra
|
||||||
|
instructions that may be required are listed later in this document.
|
||||||
|
|
||||||
|
If synapse was installed in a virtualenv then active that virtualenv before
|
||||||
|
upgrading. If synapse is installed in a virtualenv in ``~/.synapse/`` then run:
|
||||||
|
|
||||||
|
.. code:: bash
|
||||||
|
|
||||||
|
source ~/.synapse/bin/activate
|
||||||
|
|
||||||
|
If synapse was installed using pip then upgrade to the latest version by
|
||||||
|
running:
|
||||||
|
|
||||||
|
.. code:: bash
|
||||||
|
|
||||||
|
pip install --upgrade --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
|
||||||
|
|
||||||
|
If synapse was installed using git then upgrade to the latest version by
|
||||||
|
running:
|
||||||
|
|
||||||
|
.. code:: bash
|
||||||
|
|
||||||
|
# Pull the latest version of the master branch.
|
||||||
|
git pull
|
||||||
|
# Update the versions of synapse's python dependencies.
|
||||||
|
python synapse/python_dependencies.py | xargs -n1 pip install
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Upgrading to v0.9.0
|
Upgrading to v0.9.0
|
||||||
===================
|
===================
|
||||||
|
|
||||||
|
@ -11,7 +11,9 @@ if [ -f $PID_FILE ]; then
|
|||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
find "$DIR" -name "*.log" -delete
|
for port in 8080 8081 8082; do
|
||||||
find "$DIR" -name "*.db" -delete
|
rm -rf $DIR/$port
|
||||||
|
rm -rf $DIR/media_store.$port
|
||||||
|
done
|
||||||
|
|
||||||
rm -rf $DIR/etc
|
rm -rf $DIR/etc
|
||||||
|
@ -8,14 +8,6 @@ cd "$DIR/.."
|
|||||||
|
|
||||||
mkdir -p demo/etc
|
mkdir -p demo/etc
|
||||||
|
|
||||||
# Check the --no-rate-limit param
|
|
||||||
PARAMS=""
|
|
||||||
if [ $# -eq 1 ]; then
|
|
||||||
if [ $1 = "--no-rate-limit" ]; then
|
|
||||||
PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
export PYTHONPATH=$(readlink -f $(pwd))
|
export PYTHONPATH=$(readlink -f $(pwd))
|
||||||
|
|
||||||
|
|
||||||
@ -31,10 +23,20 @@ for port in 8080 8081 8082; do
|
|||||||
#rm $DIR/etc/$port.config
|
#rm $DIR/etc/$port.config
|
||||||
python -m synapse.app.homeserver \
|
python -m synapse.app.homeserver \
|
||||||
--generate-config \
|
--generate-config \
|
||||||
--enable_registration \
|
|
||||||
-H "localhost:$https_port" \
|
-H "localhost:$https_port" \
|
||||||
--config-path "$DIR/etc/$port.config" \
|
--config-path "$DIR/etc/$port.config" \
|
||||||
|
|
||||||
|
# Check script parameters
|
||||||
|
if [ $# -eq 1 ]; then
|
||||||
|
if [ $1 = "--no-rate-limit" ]; then
|
||||||
|
# Set high limits in config file to disable rate limiting
|
||||||
|
perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
|
||||||
|
perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config
|
||||||
|
|
||||||
python -m synapse.app.homeserver \
|
python -m synapse.app.homeserver \
|
||||||
--config-path "$DIR/etc/$port.config" \
|
--config-path "$DIR/etc/$port.config" \
|
||||||
-D \
|
-D \
|
||||||
|
@ -55,9 +55,8 @@ Porting from SQLite
|
|||||||
Overview
|
Overview
|
||||||
~~~~~~~~
|
~~~~~~~~
|
||||||
|
|
||||||
The script ``port_from_sqlite_to_postgres.py`` allows porting an existing
|
The script ``synapse_port_db`` allows porting an existing synapse server
|
||||||
synapse server backed by SQLite to using PostgreSQL. This is done in as a two
|
backed by SQLite to using PostgreSQL. This is done in as a two phase process:
|
||||||
phase process:
|
|
||||||
|
|
||||||
1. Copy the existing SQLite database to a separate location (while the server
|
1. Copy the existing SQLite database to a separate location (while the server
|
||||||
is down) and running the port script against that offline database.
|
is down) and running the port script against that offline database.
|
||||||
@ -86,8 +85,7 @@ Assuming your new config file (as described in the section *Synapse config*)
|
|||||||
is named ``homeserver-postgres.yaml`` and the SQLite snapshot is at
|
is named ``homeserver-postgres.yaml`` and the SQLite snapshot is at
|
||||||
``homeserver.db.snapshot`` then simply run::
|
``homeserver.db.snapshot`` then simply run::
|
||||||
|
|
||||||
python scripts/port_from_sqlite_to_postgres.py \
|
synapse_port_db --sqlite-database homeserver.db.snapshot \
|
||||||
--sqlite-database homeserver.db.snapshot \
|
|
||||||
--postgres-config homeserver-postgres.yaml
|
--postgres-config homeserver-postgres.yaml
|
||||||
|
|
||||||
The flag ``--curses`` displays a coloured curses progress UI.
|
The flag ``--curses`` displays a coloured curses progress UI.
|
||||||
@ -100,8 +98,7 @@ To complete the conversion shut down the synapse server and run the port
|
|||||||
script one last time, e.g. if the SQLite database is at ``homeserver.db``
|
script one last time, e.g. if the SQLite database is at ``homeserver.db``
|
||||||
run::
|
run::
|
||||||
|
|
||||||
python scripts/port_from_sqlite_to_postgres.py \
|
synapse_port_db --sqlite-database homeserver.db \
|
||||||
--sqlite-database homeserver.db \
|
|
||||||
--postgres-config database_config.yaml
|
--postgres-config database_config.yaml
|
||||||
|
|
||||||
Once that has completed, change the synapse config to point at the PostgreSQL
|
Once that has completed, change the synapse config to point at the PostgreSQL
|
||||||
|
@ -1,21 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# This is will prepare a synapse database for running with v0.0.1 of synapse.
|
|
||||||
# It will store all the user information, but will *delete* all messages and
|
|
||||||
# room data.
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
cp "$1" "$1.bak"
|
|
||||||
|
|
||||||
DUMP=$(sqlite3 "$1" << 'EOF'
|
|
||||||
.dump users
|
|
||||||
.dump access_tokens
|
|
||||||
.dump presence
|
|
||||||
.dump profiles
|
|
||||||
EOF
|
|
||||||
)
|
|
||||||
|
|
||||||
rm "$1"
|
|
||||||
|
|
||||||
sqlite3 "$1" <<< "$DUMP"
|
|
@ -1,21 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# This is will prepare a synapse database for running with v0.5.0 of synapse.
|
|
||||||
# It will store all the user information, but will *delete* all messages and
|
|
||||||
# room data.
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
cp "$1" "$1.bak"
|
|
||||||
|
|
||||||
DUMP=$(sqlite3 "$1" << 'EOF'
|
|
||||||
.dump users
|
|
||||||
.dump access_tokens
|
|
||||||
.dump presence
|
|
||||||
.dump profiles
|
|
||||||
EOF
|
|
||||||
)
|
|
||||||
|
|
||||||
rm "$1"
|
|
||||||
|
|
||||||
sqlite3 "$1" <<< "$DUMP"
|
|
@ -412,14 +412,17 @@ class Porter(object):
|
|||||||
self._convert_rows("sent_transactions", headers, rows)
|
self._convert_rows("sent_transactions", headers, rows)
|
||||||
|
|
||||||
inserted_rows = len(rows)
|
inserted_rows = len(rows)
|
||||||
max_inserted_rowid = max(r[0] for r in rows)
|
if inserted_rows:
|
||||||
|
max_inserted_rowid = max(r[0] for r in rows)
|
||||||
|
|
||||||
def insert(txn):
|
def insert(txn):
|
||||||
self.postgres_store.insert_many_txn(
|
self.postgres_store.insert_many_txn(
|
||||||
txn, "sent_transactions", headers[1:], rows
|
txn, "sent_transactions", headers[1:], rows
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.postgres_store.execute(insert)
|
yield self.postgres_store.execute(insert)
|
||||||
|
else:
|
||||||
|
max_inserted_rowid = 0
|
||||||
|
|
||||||
def get_start_id(txn):
|
def get_start_id(txn):
|
||||||
txn.execute(
|
txn.execute(
|
@ -1,331 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
from synapse.storage import SCHEMA_VERSION, read_schema
|
|
||||||
from synapse.storage._base import SQLBaseStore
|
|
||||||
from synapse.storage.signatures import SignatureStore
|
|
||||||
from synapse.storage.event_federation import EventFederationStore
|
|
||||||
|
|
||||||
from syutil.base64util import encode_base64, decode_base64
|
|
||||||
|
|
||||||
from synapse.crypto.event_signing import compute_event_signature
|
|
||||||
|
|
||||||
from synapse.events.builder import EventBuilder
|
|
||||||
from synapse.events.utils import prune_event
|
|
||||||
|
|
||||||
from synapse.crypto.event_signing import check_event_content_hash
|
|
||||||
|
|
||||||
from syutil.crypto.jsonsign import (
|
|
||||||
verify_signed_json, SignatureVerifyException,
|
|
||||||
)
|
|
||||||
from syutil.crypto.signing_key import decode_verify_key_bytes
|
|
||||||
|
|
||||||
from syutil.jsonutil import encode_canonical_json
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
# import dns.resolver
|
|
||||||
import hashlib
|
|
||||||
import httplib
|
|
||||||
import json
|
|
||||||
import sqlite3
|
|
||||||
import syutil
|
|
||||||
import urllib2
|
|
||||||
|
|
||||||
|
|
||||||
delta_sql = """
|
|
||||||
CREATE TABLE IF NOT EXISTS event_json(
|
|
||||||
event_id TEXT NOT NULL,
|
|
||||||
room_id TEXT NOT NULL,
|
|
||||||
internal_metadata NOT NULL,
|
|
||||||
json BLOB NOT NULL,
|
|
||||||
CONSTRAINT ev_j_uniq UNIQUE (event_id)
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
|
|
||||||
|
|
||||||
PRAGMA user_version = 10;
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class Store(object):
|
|
||||||
_get_event_signatures_txn = SignatureStore.__dict__["_get_event_signatures_txn"]
|
|
||||||
_get_event_content_hashes_txn = SignatureStore.__dict__["_get_event_content_hashes_txn"]
|
|
||||||
_get_event_reference_hashes_txn = SignatureStore.__dict__["_get_event_reference_hashes_txn"]
|
|
||||||
_get_prev_event_hashes_txn = SignatureStore.__dict__["_get_prev_event_hashes_txn"]
|
|
||||||
_get_prev_events_and_state = EventFederationStore.__dict__["_get_prev_events_and_state"]
|
|
||||||
_get_auth_events = EventFederationStore.__dict__["_get_auth_events"]
|
|
||||||
cursor_to_dict = SQLBaseStore.__dict__["cursor_to_dict"]
|
|
||||||
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
|
|
||||||
_simple_select_list_txn = SQLBaseStore.__dict__["_simple_select_list_txn"]
|
|
||||||
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
|
|
||||||
|
|
||||||
def _generate_event_json(self, txn, rows):
|
|
||||||
events = []
|
|
||||||
for row in rows:
|
|
||||||
d = dict(row)
|
|
||||||
|
|
||||||
d.pop("stream_ordering", None)
|
|
||||||
d.pop("topological_ordering", None)
|
|
||||||
d.pop("processed", None)
|
|
||||||
|
|
||||||
if "origin_server_ts" not in d:
|
|
||||||
d["origin_server_ts"] = d.pop("ts", 0)
|
|
||||||
else:
|
|
||||||
d.pop("ts", 0)
|
|
||||||
|
|
||||||
d.pop("prev_state", None)
|
|
||||||
d.update(json.loads(d.pop("unrecognized_keys")))
|
|
||||||
|
|
||||||
d["sender"] = d.pop("user_id")
|
|
||||||
|
|
||||||
d["content"] = json.loads(d["content"])
|
|
||||||
|
|
||||||
if "age_ts" not in d:
|
|
||||||
# For compatibility
|
|
||||||
d["age_ts"] = d.get("origin_server_ts", 0)
|
|
||||||
|
|
||||||
d.setdefault("unsigned", {})["age_ts"] = d.pop("age_ts")
|
|
||||||
|
|
||||||
outlier = d.pop("outlier", False)
|
|
||||||
|
|
||||||
# d.pop("membership", None)
|
|
||||||
|
|
||||||
d.pop("state_hash", None)
|
|
||||||
|
|
||||||
d.pop("replaces_state", None)
|
|
||||||
|
|
||||||
b = EventBuilder(d)
|
|
||||||
b.internal_metadata.outlier = outlier
|
|
||||||
|
|
||||||
events.append(b)
|
|
||||||
|
|
||||||
for i, ev in enumerate(events):
|
|
||||||
signatures = self._get_event_signatures_txn(
|
|
||||||
txn, ev.event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
ev.signatures = {
|
|
||||||
n: {
|
|
||||||
k: encode_base64(v) for k, v in s.items()
|
|
||||||
}
|
|
||||||
for n, s in signatures.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
hashes = self._get_event_content_hashes_txn(
|
|
||||||
txn, ev.event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
ev.hashes = {
|
|
||||||
k: encode_base64(v) for k, v in hashes.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
prevs = self._get_prev_events_and_state(txn, ev.event_id)
|
|
||||||
|
|
||||||
ev.prev_events = [
|
|
||||||
(e_id, h)
|
|
||||||
for e_id, h, is_state in prevs
|
|
||||||
if is_state == 0
|
|
||||||
]
|
|
||||||
|
|
||||||
# ev.auth_events = self._get_auth_events(txn, ev.event_id)
|
|
||||||
|
|
||||||
hashes = dict(ev.auth_events)
|
|
||||||
|
|
||||||
for e_id, hash in ev.prev_events:
|
|
||||||
if e_id in hashes and not hash:
|
|
||||||
hash.update(hashes[e_id])
|
|
||||||
#
|
|
||||||
# if hasattr(ev, "state_key"):
|
|
||||||
# ev.prev_state = [
|
|
||||||
# (e_id, h)
|
|
||||||
# for e_id, h, is_state in prevs
|
|
||||||
# if is_state == 1
|
|
||||||
# ]
|
|
||||||
|
|
||||||
return [e.build() for e in events]
|
|
||||||
|
|
||||||
|
|
||||||
store = Store()
|
|
||||||
|
|
||||||
|
|
||||||
# def get_key(server_name):
|
|
||||||
# print "Getting keys for: %s" % (server_name,)
|
|
||||||
# targets = []
|
|
||||||
# if ":" in server_name:
|
|
||||||
# target, port = server_name.split(":")
|
|
||||||
# targets.append((target, int(port)))
|
|
||||||
# try:
|
|
||||||
# answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
|
|
||||||
# for srv in answers:
|
|
||||||
# targets.append((srv.target, srv.port))
|
|
||||||
# except dns.resolver.NXDOMAIN:
|
|
||||||
# targets.append((server_name, 8448))
|
|
||||||
# except:
|
|
||||||
# print "Failed to lookup keys for %s" % (server_name,)
|
|
||||||
# return {}
|
|
||||||
#
|
|
||||||
# for target, port in targets:
|
|
||||||
# url = "https://%s:%i/_matrix/key/v1" % (target, port)
|
|
||||||
# try:
|
|
||||||
# keys = json.load(urllib2.urlopen(url, timeout=2))
|
|
||||||
# verify_keys = {}
|
|
||||||
# for key_id, key_base64 in keys["verify_keys"].items():
|
|
||||||
# verify_key = decode_verify_key_bytes(
|
|
||||||
# key_id, decode_base64(key_base64)
|
|
||||||
# )
|
|
||||||
# verify_signed_json(keys, server_name, verify_key)
|
|
||||||
# verify_keys[key_id] = verify_key
|
|
||||||
# print "Got keys for: %s" % (server_name,)
|
|
||||||
# return verify_keys
|
|
||||||
# except urllib2.URLError:
|
|
||||||
# pass
|
|
||||||
# except urllib2.HTTPError:
|
|
||||||
# pass
|
|
||||||
# except httplib.HTTPException:
|
|
||||||
# pass
|
|
||||||
#
|
|
||||||
# print "Failed to get keys for %s" % (server_name,)
|
|
||||||
# return {}
|
|
||||||
|
|
||||||
|
|
||||||
def reinsert_events(cursor, server_name, signing_key):
|
|
||||||
print "Running delta: v10"
|
|
||||||
|
|
||||||
cursor.executescript(delta_sql)
|
|
||||||
|
|
||||||
cursor.execute(
|
|
||||||
"SELECT * FROM events ORDER BY rowid ASC"
|
|
||||||
)
|
|
||||||
|
|
||||||
print "Getting events..."
|
|
||||||
|
|
||||||
rows = store.cursor_to_dict(cursor)
|
|
||||||
|
|
||||||
events = store._generate_event_json(cursor, rows)
|
|
||||||
|
|
||||||
print "Got events from DB."
|
|
||||||
|
|
||||||
algorithms = {
|
|
||||||
"sha256": hashlib.sha256,
|
|
||||||
}
|
|
||||||
|
|
||||||
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
|
|
||||||
verify_key = signing_key.verify_key
|
|
||||||
verify_key.alg = signing_key.alg
|
|
||||||
verify_key.version = signing_key.version
|
|
||||||
|
|
||||||
server_keys = {
|
|
||||||
server_name: {
|
|
||||||
key_id: verify_key
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
i = 0
|
|
||||||
N = len(events)
|
|
||||||
|
|
||||||
for event in events:
|
|
||||||
if i % 100 == 0:
|
|
||||||
print "Processed: %d/%d events" % (i,N,)
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
# for alg_name in event.hashes:
|
|
||||||
# if check_event_content_hash(event, algorithms[alg_name]):
|
|
||||||
# pass
|
|
||||||
# else:
|
|
||||||
# pass
|
|
||||||
# print "FAIL content hash %s %s" % (alg_name, event.event_id, )
|
|
||||||
|
|
||||||
have_own_correctly_signed = False
|
|
||||||
for host, sigs in event.signatures.items():
|
|
||||||
pruned = prune_event(event)
|
|
||||||
|
|
||||||
for key_id in sigs:
|
|
||||||
if host not in server_keys:
|
|
||||||
server_keys[host] = {} # get_key(host)
|
|
||||||
if key_id in server_keys[host]:
|
|
||||||
try:
|
|
||||||
verify_signed_json(
|
|
||||||
pruned.get_pdu_json(),
|
|
||||||
host,
|
|
||||||
server_keys[host][key_id]
|
|
||||||
)
|
|
||||||
|
|
||||||
if host == server_name:
|
|
||||||
have_own_correctly_signed = True
|
|
||||||
except SignatureVerifyException:
|
|
||||||
print "FAIL signature check %s %s" % (
|
|
||||||
key_id, event.event_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: Re sign with our own server key
|
|
||||||
if not have_own_correctly_signed:
|
|
||||||
sigs = compute_event_signature(event, server_name, signing_key)
|
|
||||||
event.signatures.update(sigs)
|
|
||||||
|
|
||||||
pruned = prune_event(event)
|
|
||||||
|
|
||||||
for key_id in event.signatures[server_name]:
|
|
||||||
verify_signed_json(
|
|
||||||
pruned.get_pdu_json(),
|
|
||||||
server_name,
|
|
||||||
server_keys[server_name][key_id]
|
|
||||||
)
|
|
||||||
|
|
||||||
event_json = encode_canonical_json(
|
|
||||||
event.get_dict()
|
|
||||||
).decode("UTF-8")
|
|
||||||
|
|
||||||
metadata_json = encode_canonical_json(
|
|
||||||
event.internal_metadata.get_dict()
|
|
||||||
).decode("UTF-8")
|
|
||||||
|
|
||||||
store._simple_insert_txn(
|
|
||||||
cursor,
|
|
||||||
table="event_json",
|
|
||||||
values={
|
|
||||||
"event_id": event.event_id,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"internal_metadata": metadata_json,
|
|
||||||
"json": event_json,
|
|
||||||
},
|
|
||||||
or_replace=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main(database, server_name, signing_key):
|
|
||||||
conn = sqlite3.connect(database)
|
|
||||||
cursor = conn.cursor()
|
|
||||||
|
|
||||||
# Do other deltas:
|
|
||||||
cursor.execute("PRAGMA user_version")
|
|
||||||
row = cursor.fetchone()
|
|
||||||
|
|
||||||
if row and row[0]:
|
|
||||||
user_version = row[0]
|
|
||||||
# Run every version since after the current version.
|
|
||||||
for v in range(user_version + 1, 10):
|
|
||||||
print "Running delta: %d" % (v,)
|
|
||||||
sql_script = read_schema("delta/v%d" % (v,))
|
|
||||||
cursor.executescript(sql_script)
|
|
||||||
|
|
||||||
reinsert_events(cursor, server_name, signing_key)
|
|
||||||
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
print "Success!"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
|
|
||||||
parser.add_argument("database")
|
|
||||||
parser.add_argument("server_name")
|
|
||||||
parser.add_argument(
|
|
||||||
"signing_key", type=argparse.FileType('r'),
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
signing_key = syutil.crypto.signing_key.read_signing_keys(
|
|
||||||
args.signing_key
|
|
||||||
)
|
|
||||||
|
|
||||||
main(args.database, args.server_name, signing_key[0])
|
|
@ -16,3 +16,6 @@ ignore =
|
|||||||
docs/*
|
docs/*
|
||||||
pylint.cfg
|
pylint.cfg
|
||||||
tox.ini
|
tox.ini
|
||||||
|
|
||||||
|
[flake8]
|
||||||
|
max-line-length = 90
|
||||||
|
2
setup.py
2
setup.py
@ -48,7 +48,7 @@ setup(
|
|||||||
description="Reference Synapse Home Server",
|
description="Reference Synapse Home Server",
|
||||||
install_requires=dependencies['requirements'](include_conditional=True).keys(),
|
install_requires=dependencies['requirements'](include_conditional=True).keys(),
|
||||||
setup_requires=[
|
setup_requires=[
|
||||||
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
|
"Twisted>=15.1.0", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
|
||||||
"setuptools_trial",
|
"setuptools_trial",
|
||||||
"mock"
|
"mock"
|
||||||
],
|
],
|
||||||
|
@ -16,4 +16,4 @@
|
|||||||
""" This is a reference implementation of a Matrix home server.
|
""" This is a reference implementation of a Matrix home server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__version__ = "0.9.3"
|
__version__ = "0.10.0"
|
||||||
|
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
AuthEventTypes = (
|
AuthEventTypes = (
|
||||||
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
|
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
|
||||||
EventTypes.JoinRules,
|
EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -44,6 +44,11 @@ class Auth(object):
|
|||||||
def check(self, event, auth_events):
|
def check(self, event, auth_events):
|
||||||
""" Checks if this event is correctly authed.
|
""" Checks if this event is correctly authed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event: the event being checked.
|
||||||
|
auth_events (dict: event-key -> event): the existing room state.
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if the auth checks pass.
|
True if the auth checks pass.
|
||||||
"""
|
"""
|
||||||
@ -187,6 +192,9 @@ class Auth(object):
|
|||||||
join_rule = JoinRules.INVITE
|
join_rule = JoinRules.INVITE
|
||||||
|
|
||||||
user_level = self._get_user_power_level(event.user_id, auth_events)
|
user_level = self._get_user_power_level(event.user_id, auth_events)
|
||||||
|
target_level = self._get_user_power_level(
|
||||||
|
target_user_id, auth_events
|
||||||
|
)
|
||||||
|
|
||||||
# FIXME (erikj): What should we do here as the default?
|
# FIXME (erikj): What should we do here as the default?
|
||||||
ban_level = self._get_named_level(auth_events, "ban", 50)
|
ban_level = self._get_named_level(auth_events, "ban", 50)
|
||||||
@ -258,12 +266,12 @@ class Auth(object):
|
|||||||
elif target_user_id != event.user_id:
|
elif target_user_id != event.user_id:
|
||||||
kick_level = self._get_named_level(auth_events, "kick", 50)
|
kick_level = self._get_named_level(auth_events, "kick", 50)
|
||||||
|
|
||||||
if user_level < kick_level:
|
if user_level < kick_level or user_level <= target_level:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, "You cannot kick user %s." % target_user_id
|
403, "You cannot kick user %s." % target_user_id
|
||||||
)
|
)
|
||||||
elif Membership.BAN == membership:
|
elif Membership.BAN == membership:
|
||||||
if user_level < ban_level:
|
if user_level < ban_level or user_level <= target_level:
|
||||||
raise AuthError(403, "You don't have permission to ban")
|
raise AuthError(403, "You don't have permission to ban")
|
||||||
else:
|
else:
|
||||||
raise AuthError(500, "Unknown membership %s" % membership)
|
raise AuthError(500, "Unknown membership %s" % membership)
|
||||||
@ -316,7 +324,7 @@ class Auth(object):
|
|||||||
Returns:
|
Returns:
|
||||||
tuple : of UserID and device string:
|
tuple : of UserID and device string:
|
||||||
User ID object of the user making the request
|
User ID object of the user making the request
|
||||||
Client ID object of the client instance the user is using
|
ClientInfo object of the client instance the user is using
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if no user by that token exists or the token is invalid.
|
AuthError if no user by that token exists or the token is invalid.
|
||||||
"""
|
"""
|
||||||
@ -344,12 +352,14 @@ class Auth(object):
|
|||||||
if not user_id:
|
if not user_id:
|
||||||
raise KeyError
|
raise KeyError
|
||||||
|
|
||||||
|
request.authenticated_entity = user_id
|
||||||
|
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
(UserID.from_string(user_id), ClientInfo("", ""))
|
(UserID.from_string(user_id), ClientInfo("", ""))
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass # normal users won't have this query parameter set
|
pass # normal users won't have the user_id query parameter set.
|
||||||
|
|
||||||
user_info = yield self.get_user_by_token(access_token)
|
user_info = yield self.get_user_by_token(access_token)
|
||||||
user = user_info["user"]
|
user = user_info["user"]
|
||||||
@ -417,6 +427,7 @@ class Auth(object):
|
|||||||
"Unrecognised access token.",
|
"Unrecognised access token.",
|
||||||
errcode=Codes.UNKNOWN_TOKEN
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
)
|
)
|
||||||
|
request.authenticated_entity = service.sender
|
||||||
defer.returnValue(service)
|
defer.returnValue(service)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
@ -518,23 +529,22 @@ class Auth(object):
|
|||||||
|
|
||||||
# Check state_key
|
# Check state_key
|
||||||
if hasattr(event, "state_key"):
|
if hasattr(event, "state_key"):
|
||||||
if not event.state_key.startswith("_"):
|
if event.state_key.startswith("@"):
|
||||||
if event.state_key.startswith("@"):
|
if event.state_key != event.user_id:
|
||||||
if event.state_key != event.user_id:
|
raise AuthError(
|
||||||
|
403,
|
||||||
|
"You are not allowed to set others state"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sender_domain = UserID.from_string(
|
||||||
|
event.user_id
|
||||||
|
).domain
|
||||||
|
|
||||||
|
if sender_domain != event.state_key:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
"You are not allowed to set others state"
|
"You are not allowed to set others state"
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
sender_domain = UserID.from_string(
|
|
||||||
event.user_id
|
|
||||||
).domain
|
|
||||||
|
|
||||||
if sender_domain != event.state_key:
|
|
||||||
raise AuthError(
|
|
||||||
403,
|
|
||||||
"You are not allowed to set others state"
|
|
||||||
)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -573,25 +583,26 @@ class Auth(object):
|
|||||||
|
|
||||||
# Check other levels:
|
# Check other levels:
|
||||||
levels_to_check = [
|
levels_to_check = [
|
||||||
("users_default", []),
|
("users_default", None),
|
||||||
("events_default", []),
|
("events_default", None),
|
||||||
("ban", []),
|
("state_default", None),
|
||||||
("redact", []),
|
("ban", None),
|
||||||
("kick", []),
|
("redact", None),
|
||||||
("invite", []),
|
("kick", None),
|
||||||
|
("invite", None),
|
||||||
]
|
]
|
||||||
|
|
||||||
old_list = current_state.content.get("users")
|
old_list = current_state.content.get("users")
|
||||||
for user in set(old_list.keys() + user_list.keys()):
|
for user in set(old_list.keys() + user_list.keys()):
|
||||||
levels_to_check.append(
|
levels_to_check.append(
|
||||||
(user, ["users"])
|
(user, "users")
|
||||||
)
|
)
|
||||||
|
|
||||||
old_list = current_state.content.get("events")
|
old_list = current_state.content.get("events")
|
||||||
new_list = event.content.get("events")
|
new_list = event.content.get("events")
|
||||||
for ev_id in set(old_list.keys() + new_list.keys()):
|
for ev_id in set(old_list.keys() + new_list.keys()):
|
||||||
levels_to_check.append(
|
levels_to_check.append(
|
||||||
(ev_id, ["events"])
|
(ev_id, "events")
|
||||||
)
|
)
|
||||||
|
|
||||||
old_state = current_state.content
|
old_state = current_state.content
|
||||||
@ -599,12 +610,10 @@ class Auth(object):
|
|||||||
|
|
||||||
for level_to_check, dir in levels_to_check:
|
for level_to_check, dir in levels_to_check:
|
||||||
old_loc = old_state
|
old_loc = old_state
|
||||||
for d in dir:
|
|
||||||
old_loc = old_loc.get(d, {})
|
|
||||||
|
|
||||||
new_loc = new_state
|
new_loc = new_state
|
||||||
for d in dir:
|
if dir:
|
||||||
new_loc = new_loc.get(d, {})
|
old_loc = old_loc.get(dir, {})
|
||||||
|
new_loc = new_loc.get(dir, {})
|
||||||
|
|
||||||
if level_to_check in old_loc:
|
if level_to_check in old_loc:
|
||||||
old_level = int(old_loc[level_to_check])
|
old_level = int(old_loc[level_to_check])
|
||||||
@ -620,6 +629,14 @@ class Auth(object):
|
|||||||
if new_level == old_level:
|
if new_level == old_level:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if dir == "users" and level_to_check != event.user_id:
|
||||||
|
if old_level == user_level:
|
||||||
|
raise AuthError(
|
||||||
|
403,
|
||||||
|
"You don't have permission to remove ops level equal "
|
||||||
|
"to your own"
|
||||||
|
)
|
||||||
|
|
||||||
if old_level > user_level or new_level > user_level:
|
if old_level > user_level or new_level > user_level:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
|
@ -75,6 +75,10 @@ class EventTypes(object):
|
|||||||
Redaction = "m.room.redaction"
|
Redaction = "m.room.redaction"
|
||||||
Feedback = "m.room.message.feedback"
|
Feedback = "m.room.message.feedback"
|
||||||
|
|
||||||
|
RoomHistoryVisibility = "m.room.history_visibility"
|
||||||
|
CanonicalAlias = "m.room.canonical_alias"
|
||||||
|
RoomAvatar = "m.room.avatar"
|
||||||
|
|
||||||
# These are used for validation
|
# These are used for validation
|
||||||
Message = "m.room.message"
|
Message = "m.room.message"
|
||||||
Topic = "m.room.topic"
|
Topic = "m.room.topic"
|
||||||
@ -85,3 +89,8 @@ class RejectedReason(object):
|
|||||||
AUTH_ERROR = "auth_error"
|
AUTH_ERROR = "auth_error"
|
||||||
REPLACED = "replaced"
|
REPLACED = "replaced"
|
||||||
NOT_ANCESTOR = "not_ancestor"
|
NOT_ANCESTOR = "not_ancestor"
|
||||||
|
|
||||||
|
|
||||||
|
class RoomCreationPreset(object):
|
||||||
|
PRIVATE_CHAT = "private_chat"
|
||||||
|
PUBLIC_CHAT = "public_chat"
|
||||||
|
@ -40,6 +40,7 @@ class Codes(object):
|
|||||||
TOO_LARGE = "M_TOO_LARGE"
|
TOO_LARGE = "M_TOO_LARGE"
|
||||||
EXCLUSIVE = "M_EXCLUSIVE"
|
EXCLUSIVE = "M_EXCLUSIVE"
|
||||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||||
|
THREEPID_IN_USE = "THREEPID_IN_USE"
|
||||||
|
|
||||||
|
|
||||||
class CodeMessageException(RuntimeError):
|
class CodeMessageException(RuntimeError):
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
sys.dont_write_bytecode = True
|
sys.dont_write_bytecode = True
|
||||||
from synapse.python_dependencies import check_requirements
|
from synapse.python_dependencies import check_requirements, DEPENDENCY_LINKS
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
check_requirements()
|
check_requirements()
|
||||||
@ -97,9 +97,25 @@ class SynapseHomeServer(HomeServer):
|
|||||||
return JsonResource(self)
|
return JsonResource(self)
|
||||||
|
|
||||||
def build_resource_for_web_client(self):
|
def build_resource_for_web_client(self):
|
||||||
import syweb
|
webclient_path = self.get_config().web_client_location
|
||||||
syweb_path = os.path.dirname(syweb.__file__)
|
if not webclient_path:
|
||||||
webclient_path = os.path.join(syweb_path, "webclient")
|
try:
|
||||||
|
import syweb
|
||||||
|
except ImportError:
|
||||||
|
quit_with_error(
|
||||||
|
"Could not find a webclient.\n\n"
|
||||||
|
"Please either install the matrix-angular-sdk or configure\n"
|
||||||
|
"the location of the source to serve via the configuration\n"
|
||||||
|
"option `web_client_location`\n\n"
|
||||||
|
"To install the `matrix-angular-sdk` via pip, run:\n\n"
|
||||||
|
" pip install '%(dep)s'\n"
|
||||||
|
"\n"
|
||||||
|
"You can also disable hosting of the webclient via the\n"
|
||||||
|
"configuration option `web_client`\n"
|
||||||
|
% {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
|
||||||
|
)
|
||||||
|
syweb_path = os.path.dirname(syweb.__file__)
|
||||||
|
webclient_path = os.path.join(syweb_path, "webclient")
|
||||||
# GZip is disabled here due to
|
# GZip is disabled here due to
|
||||||
# https://twistedmatrix.com/trac/ticket/7678
|
# https://twistedmatrix.com/trac/ticket/7678
|
||||||
# (It can stay enabled for the API resources: they call
|
# (It can stay enabled for the API resources: they call
|
||||||
@ -259,11 +275,10 @@ class SynapseHomeServer(HomeServer):
|
|||||||
|
|
||||||
def quit_with_error(error_string):
|
def quit_with_error(error_string):
|
||||||
message_lines = error_string.split("\n")
|
message_lines = error_string.split("\n")
|
||||||
line_length = max([len(l) for l in message_lines]) + 2
|
line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
|
||||||
sys.stderr.write("*" * line_length + '\n')
|
sys.stderr.write("*" * line_length + '\n')
|
||||||
for line in message_lines:
|
for line in message_lines:
|
||||||
if line.strip():
|
sys.stderr.write(" %s\n" % (line.rstrip(),))
|
||||||
sys.stderr.write(" %s\n" % (line.strip(),))
|
|
||||||
sys.stderr.write("*" * line_length + '\n')
|
sys.stderr.write("*" * line_length + '\n')
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
@ -326,7 +341,7 @@ def get_version_string():
|
|||||||
)
|
)
|
||||||
).encode("ascii")
|
).encode("ascii")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn("Failed to check for git repository: %s", e)
|
logger.info("Failed to check for git repository: %s", e)
|
||||||
|
|
||||||
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
|
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
|
||||||
|
|
||||||
@ -657,7 +672,8 @@ def run(hs):
|
|||||||
|
|
||||||
if hs.config.daemonize:
|
if hs.config.daemonize:
|
||||||
|
|
||||||
print hs.config.pid_file
|
if hs.config.print_pidfile:
|
||||||
|
print hs.config.pid_file
|
||||||
|
|
||||||
daemon = Daemonize(
|
daemon = Daemonize(
|
||||||
app="synapse-homeserver",
|
app="synapse-homeserver",
|
||||||
|
30
synapse/config/__main__.py
Normal file
30
synapse/config/__main__.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import sys
|
||||||
|
from homeserver import HomeServerConfig
|
||||||
|
|
||||||
|
action = sys.argv[1]
|
||||||
|
|
||||||
|
if action == "read":
|
||||||
|
key = sys.argv[2]
|
||||||
|
config = HomeServerConfig.load_config("", sys.argv[3:])
|
||||||
|
|
||||||
|
print getattr(config, key)
|
||||||
|
sys.exit(0)
|
||||||
|
else:
|
||||||
|
sys.stderr.write("Unknown command %r\n" % (action,))
|
||||||
|
sys.exit(1)
|
@ -131,71 +131,107 @@ class Config(object):
|
|||||||
"-c", "--config-path",
|
"-c", "--config-path",
|
||||||
action="append",
|
action="append",
|
||||||
metavar="CONFIG_FILE",
|
metavar="CONFIG_FILE",
|
||||||
help="Specify config file"
|
help="Specify config file. Can be given multiple times and"
|
||||||
|
" may specify directories containing *.yaml files."
|
||||||
)
|
)
|
||||||
config_parser.add_argument(
|
config_parser.add_argument(
|
||||||
"--generate-config",
|
"--generate-config",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Generate a config file for the server name"
|
help="Generate a config file for the server name"
|
||||||
)
|
)
|
||||||
|
config_parser.add_argument(
|
||||||
|
"--generate-keys",
|
||||||
|
action="store_true",
|
||||||
|
help="Generate any missing key files then exit"
|
||||||
|
)
|
||||||
|
config_parser.add_argument(
|
||||||
|
"--keys-directory",
|
||||||
|
metavar="DIRECTORY",
|
||||||
|
help="Used with 'generate-*' options to specify where files such as"
|
||||||
|
" certs and signing keys should be stored in, unless explicitly"
|
||||||
|
" specified in the config."
|
||||||
|
)
|
||||||
config_parser.add_argument(
|
config_parser.add_argument(
|
||||||
"-H", "--server-name",
|
"-H", "--server-name",
|
||||||
help="The server name to generate a config file for"
|
help="The server name to generate a config file for"
|
||||||
)
|
)
|
||||||
config_args, remaining_args = config_parser.parse_known_args(argv)
|
config_args, remaining_args = config_parser.parse_known_args(argv)
|
||||||
|
|
||||||
|
generate_keys = config_args.generate_keys
|
||||||
|
|
||||||
|
config_files = []
|
||||||
|
if config_args.config_path:
|
||||||
|
for config_path in config_args.config_path:
|
||||||
|
if os.path.isdir(config_path):
|
||||||
|
# We accept specifying directories as config paths, we search
|
||||||
|
# inside that directory for all files matching *.yaml, and then
|
||||||
|
# we apply them in *sorted* order.
|
||||||
|
files = []
|
||||||
|
for entry in os.listdir(config_path):
|
||||||
|
entry_path = os.path.join(config_path, entry)
|
||||||
|
if not os.path.isfile(entry_path):
|
||||||
|
print (
|
||||||
|
"Found subdirectory in config directory: %r. IGNORING."
|
||||||
|
) % (entry_path, )
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not entry.endswith(".yaml"):
|
||||||
|
print (
|
||||||
|
"Found file in config directory that does not"
|
||||||
|
" end in '.yaml': %r. IGNORING."
|
||||||
|
) % (entry_path, )
|
||||||
|
continue
|
||||||
|
|
||||||
|
files.append(entry_path)
|
||||||
|
|
||||||
|
config_files.extend(sorted(files))
|
||||||
|
else:
|
||||||
|
config_files.append(config_path)
|
||||||
|
|
||||||
if config_args.generate_config:
|
if config_args.generate_config:
|
||||||
if not config_args.config_path:
|
if not config_files:
|
||||||
config_parser.error(
|
config_parser.error(
|
||||||
"Must supply a config file.\nA config file can be automatically"
|
"Must supply a config file.\nA config file can be automatically"
|
||||||
" generated using \"--generate-config -H SERVER_NAME"
|
" generated using \"--generate-config -H SERVER_NAME"
|
||||||
" -c CONFIG-FILE\""
|
" -c CONFIG-FILE\""
|
||||||
)
|
)
|
||||||
|
(config_path,) = config_files
|
||||||
|
if not os.path.exists(config_path):
|
||||||
|
if config_args.keys_directory:
|
||||||
|
config_dir_path = config_args.keys_directory
|
||||||
|
else:
|
||||||
|
config_dir_path = os.path.dirname(config_path)
|
||||||
|
config_dir_path = os.path.abspath(config_dir_path)
|
||||||
|
|
||||||
config_dir_path = os.path.dirname(config_args.config_path[0])
|
server_name = config_args.server_name
|
||||||
config_dir_path = os.path.abspath(config_dir_path)
|
if not server_name:
|
||||||
|
print "Must specify a server_name to a generate config for."
|
||||||
server_name = config_args.server_name
|
|
||||||
if not server_name:
|
|
||||||
print "Must specify a server_name to a generate config for."
|
|
||||||
sys.exit(1)
|
|
||||||
(config_path,) = config_args.config_path
|
|
||||||
if not os.path.exists(config_dir_path):
|
|
||||||
os.makedirs(config_dir_path)
|
|
||||||
if os.path.exists(config_path):
|
|
||||||
print "Config file %r already exists" % (config_path,)
|
|
||||||
yaml_config = cls.read_config_file(config_path)
|
|
||||||
yaml_name = yaml_config["server_name"]
|
|
||||||
if server_name != yaml_name:
|
|
||||||
print (
|
|
||||||
"Config file %r has a different server_name: "
|
|
||||||
" %r != %r" % (config_path, server_name, yaml_name)
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
config_bytes, config = obj.generate_config(
|
if not os.path.exists(config_dir_path):
|
||||||
config_dir_path, server_name
|
os.makedirs(config_dir_path)
|
||||||
)
|
with open(config_path, "wb") as config_file:
|
||||||
config.update(yaml_config)
|
config_bytes, config = obj.generate_config(
|
||||||
print "Generating any missing keys for %r" % (server_name,)
|
config_dir_path, server_name
|
||||||
obj.invoke_all("generate_files", config)
|
)
|
||||||
sys.exit(0)
|
obj.invoke_all("generate_files", config)
|
||||||
with open(config_path, "wb") as config_file:
|
config_file.write(config_bytes)
|
||||||
config_bytes, config = obj.generate_config(
|
|
||||||
config_dir_path, server_name
|
|
||||||
)
|
|
||||||
obj.invoke_all("generate_files", config)
|
|
||||||
config_file.write(config_bytes)
|
|
||||||
print (
|
print (
|
||||||
"A config file has been generated in %s for server name"
|
"A config file has been generated in %r for server name"
|
||||||
" '%s' with corresponding SSL keys and self-signed"
|
" %r with corresponding SSL keys and self-signed"
|
||||||
" certificates. Please review this file and customise it to"
|
" certificates. Please review this file and customise it"
|
||||||
" your needs."
|
" to your needs."
|
||||||
) % (config_path, server_name)
|
) % (config_path, server_name)
|
||||||
print (
|
print (
|
||||||
"If this server name is incorrect, you will need to regenerate"
|
"If this server name is incorrect, you will need to"
|
||||||
" the SSL certificates"
|
" regenerate the SSL certificates"
|
||||||
)
|
)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
else:
|
||||||
|
print (
|
||||||
|
"Config file %r already exists. Generating any missing key"
|
||||||
|
" files."
|
||||||
|
) % (config_path,)
|
||||||
|
generate_keys = True
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
parents=[config_parser],
|
parents=[config_parser],
|
||||||
@ -206,19 +242,22 @@ class Config(object):
|
|||||||
obj.invoke_all("add_arguments", parser)
|
obj.invoke_all("add_arguments", parser)
|
||||||
args = parser.parse_args(remaining_args)
|
args = parser.parse_args(remaining_args)
|
||||||
|
|
||||||
if not config_args.config_path:
|
if not config_files:
|
||||||
config_parser.error(
|
config_parser.error(
|
||||||
"Must supply a config file.\nA config file can be automatically"
|
"Must supply a config file.\nA config file can be automatically"
|
||||||
" generated using \"--generate-config -H SERVER_NAME"
|
" generated using \"--generate-config -H SERVER_NAME"
|
||||||
" -c CONFIG-FILE\""
|
" -c CONFIG-FILE\""
|
||||||
)
|
)
|
||||||
|
|
||||||
config_dir_path = os.path.dirname(config_args.config_path[0])
|
if config_args.keys_directory:
|
||||||
|
config_dir_path = config_args.keys_directory
|
||||||
|
else:
|
||||||
|
config_dir_path = os.path.dirname(config_args.config_path[-1])
|
||||||
config_dir_path = os.path.abspath(config_dir_path)
|
config_dir_path = os.path.abspath(config_dir_path)
|
||||||
|
|
||||||
specified_config = {}
|
specified_config = {}
|
||||||
for config_path in config_args.config_path:
|
for config_file in config_files:
|
||||||
yaml_config = cls.read_config_file(config_path)
|
yaml_config = cls.read_config_file(config_file)
|
||||||
specified_config.update(yaml_config)
|
specified_config.update(yaml_config)
|
||||||
|
|
||||||
server_name = specified_config["server_name"]
|
server_name = specified_config["server_name"]
|
||||||
@ -226,6 +265,10 @@ class Config(object):
|
|||||||
config.pop("log_config")
|
config.pop("log_config")
|
||||||
config.update(specified_config)
|
config.update(specified_config)
|
||||||
|
|
||||||
|
if generate_keys:
|
||||||
|
obj.invoke_all("generate_files", config)
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
obj.invoke_all("read_config", config)
|
obj.invoke_all("read_config", config)
|
||||||
|
|
||||||
obj.invoke_all("read_arguments", args)
|
obj.invoke_all("read_arguments", args)
|
||||||
|
@ -29,10 +29,10 @@ class CaptchaConfig(Config):
|
|||||||
## Captcha ##
|
## Captcha ##
|
||||||
|
|
||||||
# This Home Server's ReCAPTCHA public key.
|
# This Home Server's ReCAPTCHA public key.
|
||||||
recaptcha_private_key: "YOUR_PUBLIC_KEY"
|
recaptcha_private_key: "YOUR_PRIVATE_KEY"
|
||||||
|
|
||||||
# This Home Server's ReCAPTCHA private key.
|
# This Home Server's ReCAPTCHA private key.
|
||||||
recaptcha_public_key: "YOUR_PRIVATE_KEY"
|
recaptcha_public_key: "YOUR_PUBLIC_KEY"
|
||||||
|
|
||||||
# Enables ReCaptcha checks when registering, preventing signup
|
# Enables ReCaptcha checks when registering, preventing signup
|
||||||
# unless a captcha is answered. Requires a valid ReCaptcha
|
# unless a captcha is answered. Requires a valid ReCaptcha
|
||||||
|
@ -25,12 +25,13 @@ from .registration import RegistrationConfig
|
|||||||
from .metrics import MetricsConfig
|
from .metrics import MetricsConfig
|
||||||
from .appservice import AppServiceConfig
|
from .appservice import AppServiceConfig
|
||||||
from .key import KeyConfig
|
from .key import KeyConfig
|
||||||
|
from .saml2 import SAML2Config
|
||||||
|
|
||||||
|
|
||||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||||
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
|
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
|
||||||
VoipConfig, RegistrationConfig,
|
VoipConfig, RegistrationConfig, MetricsConfig,
|
||||||
MetricsConfig, AppServiceConfig, KeyConfig,):
|
AppServiceConfig, KeyConfig, SAML2Config, ):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -14,6 +14,39 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import Config
|
from ._base import Config
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
ThumbnailRequirement = namedtuple(
|
||||||
|
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_thumbnail_requirements(thumbnail_sizes):
|
||||||
|
""" Takes a list of dictionaries with "width", "height", and "method" keys
|
||||||
|
and creates a map from image media types to the thumbnail size, thumnailing
|
||||||
|
method, and thumbnail media type to precalculate
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thumbnail_sizes(list): List of dicts with "width", "height", and
|
||||||
|
"method" keys
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping from media type string to list of
|
||||||
|
ThumbnailRequirement tuples.
|
||||||
|
"""
|
||||||
|
requirements = {}
|
||||||
|
for size in thumbnail_sizes:
|
||||||
|
width = size["width"]
|
||||||
|
height = size["height"]
|
||||||
|
method = size["method"]
|
||||||
|
jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg")
|
||||||
|
png_thumbnail = ThumbnailRequirement(width, height, method, "image/png")
|
||||||
|
requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail)
|
||||||
|
requirements.setdefault("image/gif", []).append(png_thumbnail)
|
||||||
|
requirements.setdefault("image/png", []).append(png_thumbnail)
|
||||||
|
return {
|
||||||
|
media_type: tuple(thumbnails)
|
||||||
|
for media_type, thumbnails in requirements.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class ContentRepositoryConfig(Config):
|
class ContentRepositoryConfig(Config):
|
||||||
@ -22,6 +55,10 @@ class ContentRepositoryConfig(Config):
|
|||||||
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
|
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
|
||||||
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
||||||
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
||||||
|
self.dynamic_thumbnails = config["dynamic_thumbnails"]
|
||||||
|
self.thumbnail_requirements = parse_thumbnail_requirements(
|
||||||
|
config["thumbnail_sizes"]
|
||||||
|
)
|
||||||
|
|
||||||
def default_config(self, config_dir_path, server_name):
|
def default_config(self, config_dir_path, server_name):
|
||||||
media_store = self.default_path("media_store")
|
media_store = self.default_path("media_store")
|
||||||
@ -38,4 +75,26 @@ class ContentRepositoryConfig(Config):
|
|||||||
|
|
||||||
# Maximum number of pixels that will be thumbnailed
|
# Maximum number of pixels that will be thumbnailed
|
||||||
max_image_pixels: "32M"
|
max_image_pixels: "32M"
|
||||||
|
|
||||||
|
# Whether to generate new thumbnails on the fly to precisely match
|
||||||
|
# the resolution requested by the client. If true then whenever
|
||||||
|
# a new resolution is requested by the client the server will
|
||||||
|
# generate a new thumbnail. If false the server will pick a thumbnail
|
||||||
|
# from a precalcualted list.
|
||||||
|
dynamic_thumbnails: false
|
||||||
|
|
||||||
|
# List of thumbnail to precalculate when an image is uploaded.
|
||||||
|
thumbnail_sizes:
|
||||||
|
- width: 32
|
||||||
|
height: 32
|
||||||
|
method: crop
|
||||||
|
- width: 96
|
||||||
|
height: 96
|
||||||
|
method: crop
|
||||||
|
- width: 320
|
||||||
|
height: 240
|
||||||
|
method: scale
|
||||||
|
- width: 640
|
||||||
|
height: 480
|
||||||
|
method: scale
|
||||||
""" % locals()
|
""" % locals()
|
||||||
|
54
synapse/config/saml2.py
Normal file
54
synapse/config/saml2.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 Ericsson
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import Config
|
||||||
|
|
||||||
|
|
||||||
|
class SAML2Config(Config):
|
||||||
|
"""SAML2 Configuration
|
||||||
|
Synapse uses pysaml2 libraries for providing SAML2 support
|
||||||
|
|
||||||
|
config_path: Path to the sp_conf.py configuration file
|
||||||
|
idp_redirect_url: Identity provider URL which will redirect
|
||||||
|
the user back to /login/saml2 with proper info.
|
||||||
|
|
||||||
|
sp_conf.py file is something like:
|
||||||
|
https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example
|
||||||
|
|
||||||
|
More information: https://pythonhosted.org/pysaml2/howto/config.html
|
||||||
|
"""
|
||||||
|
|
||||||
|
def read_config(self, config):
|
||||||
|
saml2_config = config.get("saml2_config", None)
|
||||||
|
if saml2_config:
|
||||||
|
self.saml2_enabled = True
|
||||||
|
self.saml2_config_path = saml2_config["config_path"]
|
||||||
|
self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
|
||||||
|
else:
|
||||||
|
self.saml2_enabled = False
|
||||||
|
self.saml2_config_path = None
|
||||||
|
self.saml2_idp_redirect_url = None
|
||||||
|
|
||||||
|
def default_config(self, config_dir_path, server_name):
|
||||||
|
return """
|
||||||
|
# Enable SAML2 for registration and login. Uses pysaml2
|
||||||
|
# config_path: Path to the sp_conf.py configuration file
|
||||||
|
# idp_redirect_url: Identity provider URL which will redirect
|
||||||
|
# the user back to /login/saml2 with proper info.
|
||||||
|
# See pysaml2 docs for format of config.
|
||||||
|
#saml2_config:
|
||||||
|
# config_path: "%s/sp_conf.py"
|
||||||
|
# idp_redirect_url: "http://%s/idp"
|
||||||
|
""" % (config_dir_path, server_name)
|
@ -22,8 +22,10 @@ class ServerConfig(Config):
|
|||||||
self.server_name = config["server_name"]
|
self.server_name = config["server_name"]
|
||||||
self.pid_file = self.abspath(config.get("pid_file"))
|
self.pid_file = self.abspath(config.get("pid_file"))
|
||||||
self.web_client = config["web_client"]
|
self.web_client = config["web_client"]
|
||||||
|
self.web_client_location = config.get("web_client_location", None)
|
||||||
self.soft_file_limit = config["soft_file_limit"]
|
self.soft_file_limit = config["soft_file_limit"]
|
||||||
self.daemonize = config.get("daemonize")
|
self.daemonize = config.get("daemonize")
|
||||||
|
self.print_pidfile = config.get("print_pidfile")
|
||||||
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
|
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
|
||||||
|
|
||||||
self.listeners = config.get("listeners", [])
|
self.listeners = config.get("listeners", [])
|
||||||
@ -208,12 +210,18 @@ class ServerConfig(Config):
|
|||||||
self.manhole = args.manhole
|
self.manhole = args.manhole
|
||||||
if args.daemonize is not None:
|
if args.daemonize is not None:
|
||||||
self.daemonize = args.daemonize
|
self.daemonize = args.daemonize
|
||||||
|
if args.print_pidfile is not None:
|
||||||
|
self.print_pidfile = args.print_pidfile
|
||||||
|
|
||||||
def add_arguments(self, parser):
|
def add_arguments(self, parser):
|
||||||
server_group = parser.add_argument_group("server")
|
server_group = parser.add_argument_group("server")
|
||||||
server_group.add_argument("-D", "--daemonize", action='store_true',
|
server_group.add_argument("-D", "--daemonize", action='store_true',
|
||||||
default=None,
|
default=None,
|
||||||
help="Daemonize the home server")
|
help="Daemonize the home server")
|
||||||
|
server_group.add_argument("--print-pidfile", action='store_true',
|
||||||
|
default=None,
|
||||||
|
help="Print the path to the pidfile just"
|
||||||
|
" before daemonizing")
|
||||||
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
|
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
|
||||||
type=int,
|
type=int,
|
||||||
help="Turn on the twisted telnet manhole"
|
help="Turn on the twisted telnet manhole"
|
||||||
|
@ -27,6 +27,7 @@ class TlsConfig(Config):
|
|||||||
self.tls_certificate = self.read_tls_certificate(
|
self.tls_certificate = self.read_tls_certificate(
|
||||||
config.get("tls_certificate_path")
|
config.get("tls_certificate_path")
|
||||||
)
|
)
|
||||||
|
self.tls_certificate_file = config.get("tls_certificate_path")
|
||||||
|
|
||||||
self.no_tls = config.get("no_tls", False)
|
self.no_tls = config.get("no_tls", False)
|
||||||
|
|
||||||
@ -49,7 +50,11 @@ class TlsConfig(Config):
|
|||||||
tls_dh_params_path = base_key_name + ".tls.dh"
|
tls_dh_params_path = base_key_name + ".tls.dh"
|
||||||
|
|
||||||
return """\
|
return """\
|
||||||
# PEM encoded X509 certificate for TLS
|
# PEM encoded X509 certificate for TLS.
|
||||||
|
# You can replace the self-signed certificate that synapse
|
||||||
|
# autogenerates on launch with your own SSL certificate + key pair
|
||||||
|
# if you like. Any required intermediary certificates can be
|
||||||
|
# appended after the primary certificate in hierarchical order.
|
||||||
tls_certificate_path: "%(tls_certificate_path)s"
|
tls_certificate_path: "%(tls_certificate_path)s"
|
||||||
|
|
||||||
# PEM encoded private key for TLS
|
# PEM encoded private key for TLS
|
||||||
|
@ -35,9 +35,9 @@ class ServerContextFactory(ssl.ContextFactory):
|
|||||||
_ecCurve = _OpenSSLECCurve(_defaultCurveName)
|
_ecCurve = _OpenSSLECCurve(_defaultCurveName)
|
||||||
_ecCurve.addECKeyToContext(context)
|
_ecCurve.addECKeyToContext(context)
|
||||||
except:
|
except:
|
||||||
logger.exception("Failed to enable eliptic curve for TLS")
|
logger.exception("Failed to enable elliptic curve for TLS")
|
||||||
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
|
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
|
||||||
context.use_certificate(config.tls_certificate)
|
context.use_certificate_chain_file(config.tls_certificate_file)
|
||||||
|
|
||||||
if not config.no_tls:
|
if not config.no_tls:
|
||||||
context.use_privatekey(config.tls_private_key)
|
context.use_privatekey(config.tls_private_key)
|
||||||
|
@ -25,11 +25,13 @@ from syutil.base64util import decode_base64, encode_base64
|
|||||||
from synapse.api.errors import SynapseError, Codes
|
from synapse.api.errors import SynapseError, Codes
|
||||||
|
|
||||||
from synapse.util.retryutils import get_retry_limiter
|
from synapse.util.retryutils import get_retry_limiter
|
||||||
|
from synapse.util import unwrapFirstError
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
|
|
||||||
from OpenSSL import crypto
|
from OpenSSL import crypto
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
import urllib
|
import urllib
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
@ -38,6 +40,9 @@ import logging
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
|
||||||
|
|
||||||
|
|
||||||
class Keyring(object):
|
class Keyring(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
@ -49,141 +54,325 @@ class Keyring(object):
|
|||||||
|
|
||||||
self.key_downloads = {}
|
self.key_downloads = {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def verify_json_for_server(self, server_name, json_object):
|
def verify_json_for_server(self, server_name, json_object):
|
||||||
logger.debug("Verifying for %s", server_name)
|
return self.verify_json_objects_for_server(
|
||||||
key_ids = signature_ids(json_object, server_name)
|
[(server_name, json_object)]
|
||||||
if not key_ids:
|
)[0]
|
||||||
raise SynapseError(
|
|
||||||
400,
|
|
||||||
"Not signed with a supported algorithm",
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
verify_key = yield self.get_server_verify_key(server_name, key_ids)
|
|
||||||
except IOError as e:
|
|
||||||
logger.warn(
|
|
||||||
"Got IOError when downloading keys for %s: %s %s",
|
|
||||||
server_name, type(e).__name__, str(e.message),
|
|
||||||
)
|
|
||||||
raise SynapseError(
|
|
||||||
502,
|
|
||||||
"Error downloading keys for %s" % (server_name,),
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warn(
|
|
||||||
"Got Exception when downloading keys for %s: %s %s",
|
|
||||||
server_name, type(e).__name__, str(e.message),
|
|
||||||
)
|
|
||||||
raise SynapseError(
|
|
||||||
401,
|
|
||||||
"No key for %s with id %s" % (server_name, key_ids),
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
def verify_json_objects_for_server(self, server_and_json):
|
||||||
verify_signed_json(json_object, server_name, verify_key)
|
"""Bulk verfies signatures of json objects, bulk fetching keys as
|
||||||
except:
|
necessary.
|
||||||
raise SynapseError(
|
|
||||||
401,
|
|
||||||
"Invalid signature for server %s with key %s:%s" % (
|
|
||||||
server_name, verify_key.alg, verify_key.version
|
|
||||||
),
|
|
||||||
Codes.UNAUTHORIZED,
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_server_verify_key(self, server_name, key_ids):
|
|
||||||
"""Finds a verification key for the server with one of the key ids.
|
|
||||||
Trys to fetch the key from a trusted perspective server first.
|
|
||||||
Args:
|
Args:
|
||||||
server_name(str): The name of the server to fetch a key for.
|
server_and_json (list): List of pairs of (server_name, json_object)
|
||||||
keys_ids (list of str): The key_ids to check for.
|
|
||||||
|
Returns:
|
||||||
|
list of deferreds indicating success or failure to verify each
|
||||||
|
json object's signature for the given server_name.
|
||||||
"""
|
"""
|
||||||
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
|
group_id_to_json = {}
|
||||||
|
group_id_to_group = {}
|
||||||
|
group_ids = []
|
||||||
|
|
||||||
if cached:
|
next_group_id = 0
|
||||||
defer.returnValue(cached[0])
|
deferreds = {}
|
||||||
return
|
|
||||||
|
|
||||||
download = self.key_downloads.get(server_name)
|
for server_name, json_object in server_and_json:
|
||||||
|
logger.debug("Verifying for %s", server_name)
|
||||||
|
group_id = next_group_id
|
||||||
|
next_group_id += 1
|
||||||
|
group_ids.append(group_id)
|
||||||
|
|
||||||
if download is None:
|
key_ids = signature_ids(json_object, server_name)
|
||||||
download = self._get_server_verify_key_impl(server_name, key_ids)
|
if not key_ids:
|
||||||
download = ObservableDeferred(
|
deferreds[group_id] = defer.fail(SynapseError(
|
||||||
download,
|
400,
|
||||||
consumeErrors=True
|
"Not signed with a supported algorithm",
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
deferreds[group_id] = defer.Deferred()
|
||||||
|
|
||||||
|
group = KeyGroup(server_name, group_id, key_ids)
|
||||||
|
|
||||||
|
group_id_to_group[group_id] = group
|
||||||
|
group_id_to_json[group_id] = json_object
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def handle_key_deferred(group, deferred):
|
||||||
|
server_name = group.server_name
|
||||||
|
try:
|
||||||
|
_, _, key_id, verify_key = yield deferred
|
||||||
|
except IOError as e:
|
||||||
|
logger.warn(
|
||||||
|
"Got IOError when downloading keys for %s: %s %s",
|
||||||
|
server_name, type(e).__name__, str(e.message),
|
||||||
|
)
|
||||||
|
raise SynapseError(
|
||||||
|
502,
|
||||||
|
"Error downloading keys for %s" % (server_name,),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
"Got Exception when downloading keys for %s: %s %s",
|
||||||
|
server_name, type(e).__name__, str(e.message),
|
||||||
|
)
|
||||||
|
raise SynapseError(
|
||||||
|
401,
|
||||||
|
"No key for %s with id %s" % (server_name, key_ids),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
|
||||||
|
json_object = group_id_to_json[group.group_id]
|
||||||
|
|
||||||
|
try:
|
||||||
|
verify_signed_json(json_object, server_name, verify_key)
|
||||||
|
except:
|
||||||
|
raise SynapseError(
|
||||||
|
401,
|
||||||
|
"Invalid signature for server %s with key %s:%s" % (
|
||||||
|
server_name, verify_key.alg, verify_key.version
|
||||||
|
),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
)
|
||||||
|
|
||||||
|
server_to_deferred = {
|
||||||
|
server_name: defer.Deferred()
|
||||||
|
for server_name, _ in server_and_json
|
||||||
|
}
|
||||||
|
|
||||||
|
# We want to wait for any previous lookups to complete before
|
||||||
|
# proceeding.
|
||||||
|
wait_on_deferred = self.wait_for_previous_lookups(
|
||||||
|
[server_name for server_name, _ in server_and_json],
|
||||||
|
server_to_deferred,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Actually start fetching keys.
|
||||||
|
wait_on_deferred.addBoth(
|
||||||
|
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
||||||
|
)
|
||||||
|
|
||||||
|
# When we've finished fetching all the keys for a given server_name,
|
||||||
|
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||||
|
# any lookups waiting will proceed.
|
||||||
|
server_to_gids = {}
|
||||||
|
|
||||||
|
def remove_deferreds(res, server_name, group_id):
|
||||||
|
server_to_gids[server_name].discard(group_id)
|
||||||
|
if not server_to_gids[server_name]:
|
||||||
|
server_to_deferred.pop(server_name).callback(None)
|
||||||
|
return res
|
||||||
|
|
||||||
|
for g_id, deferred in deferreds.items():
|
||||||
|
server_name = group_id_to_group[g_id].server_name
|
||||||
|
server_to_gids.setdefault(server_name, set()).add(g_id)
|
||||||
|
deferred.addBoth(remove_deferreds, server_name, g_id)
|
||||||
|
|
||||||
|
# Pass those keys to handle_key_deferred so that the json object
|
||||||
|
# signatures can be verified
|
||||||
|
return [
|
||||||
|
handle_key_deferred(
|
||||||
|
group_id_to_group[g_id],
|
||||||
|
deferreds[g_id],
|
||||||
)
|
)
|
||||||
self.key_downloads[server_name] = download
|
for g_id in group_ids
|
||||||
|
]
|
||||||
@download.addBoth
|
|
||||||
def callback(ret):
|
|
||||||
del self.key_downloads[server_name]
|
|
||||||
return ret
|
|
||||||
|
|
||||||
r = yield download.observe()
|
|
||||||
defer.returnValue(r)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_server_verify_key_impl(self, server_name, key_ids):
|
def wait_for_previous_lookups(self, server_names, server_to_deferred):
|
||||||
keys = None
|
"""Waits for any previous key lookups for the given servers to finish.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_names (list): list of server_names we want to lookup
|
||||||
|
server_to_deferred (dict): server_name to deferred which gets
|
||||||
|
resolved once we've finished looking up keys for that server
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
wait_on = [
|
||||||
|
self.key_downloads[server_name]
|
||||||
|
for server_name in server_names
|
||||||
|
if server_name in self.key_downloads
|
||||||
|
]
|
||||||
|
if wait_on:
|
||||||
|
yield defer.DeferredList(wait_on)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
for server_name, deferred in server_to_deferred:
|
||||||
|
self.key_downloads[server_name] = ObservableDeferred(deferred)
|
||||||
|
|
||||||
|
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
|
||||||
|
"""Takes a dict of KeyGroups and tries to find at least one key for
|
||||||
|
each group.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# These are functions that produce keys given a list of key ids
|
||||||
|
key_fetch_fns = (
|
||||||
|
self.get_keys_from_store, # First try the local store
|
||||||
|
self.get_keys_from_perspectives, # Then try via perspectives
|
||||||
|
self.get_keys_from_server, # Then try directly
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def do_iterations():
|
||||||
|
merged_results = {}
|
||||||
|
|
||||||
|
missing_keys = {
|
||||||
|
group.server_name: key_id
|
||||||
|
for group in group_id_to_group.values()
|
||||||
|
for key_id in group.key_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
for fn in key_fetch_fns:
|
||||||
|
results = yield fn(missing_keys.items())
|
||||||
|
merged_results.update(results)
|
||||||
|
|
||||||
|
# We now need to figure out which groups we have keys for
|
||||||
|
# and which we don't
|
||||||
|
missing_groups = {}
|
||||||
|
for group in group_id_to_group.values():
|
||||||
|
for key_id in group.key_ids:
|
||||||
|
if key_id in merged_results[group.server_name]:
|
||||||
|
group_id_to_deferred[group.group_id].callback((
|
||||||
|
group.group_id,
|
||||||
|
group.server_name,
|
||||||
|
key_id,
|
||||||
|
merged_results[group.server_name][key_id],
|
||||||
|
))
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
missing_groups.setdefault(
|
||||||
|
group.server_name, []
|
||||||
|
).append(group)
|
||||||
|
|
||||||
|
if not missing_groups:
|
||||||
|
break
|
||||||
|
|
||||||
|
missing_keys = {
|
||||||
|
server_name: set(
|
||||||
|
key_id for group in groups for key_id in group.key_ids
|
||||||
|
)
|
||||||
|
for server_name, groups in missing_groups.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
for group in missing_groups.values():
|
||||||
|
group_id_to_deferred[group.group_id].errback(SynapseError(
|
||||||
|
401,
|
||||||
|
"No key for %s with id %s" % (
|
||||||
|
group.server_name, group.key_ids,
|
||||||
|
),
|
||||||
|
Codes.UNAUTHORIZED,
|
||||||
|
))
|
||||||
|
|
||||||
|
def on_err(err):
|
||||||
|
for deferred in group_id_to_deferred.values():
|
||||||
|
if not deferred.called:
|
||||||
|
deferred.errback(err)
|
||||||
|
|
||||||
|
do_iterations().addErrback(on_err)
|
||||||
|
|
||||||
|
return group_id_to_deferred
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_keys_from_store(self, server_name_and_key_ids):
|
||||||
|
res = yield defer.gatherResults(
|
||||||
|
[
|
||||||
|
self.store.get_server_verify_keys(server_name, key_ids)
|
||||||
|
for server_name, key_ids in server_name_and_key_ids
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
defer.returnValue(dict(zip(
|
||||||
|
[server_name for server_name, _ in server_name_and_key_ids],
|
||||||
|
res
|
||||||
|
)))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_keys_from_perspectives(self, server_name_and_key_ids):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_key(perspective_name, perspective_keys):
|
def get_key(perspective_name, perspective_keys):
|
||||||
try:
|
try:
|
||||||
result = yield self.get_server_verify_key_v2_indirect(
|
result = yield self.get_server_verify_key_v2_indirect(
|
||||||
server_name, key_ids, perspective_name, perspective_keys
|
server_name_and_key_ids, perspective_name, perspective_keys
|
||||||
)
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.info(
|
logger.exception(
|
||||||
"Unable to getting key %r for %r from %r: %s %s",
|
"Unable to get key from %r: %s %s",
|
||||||
key_ids, server_name, perspective_name,
|
perspective_name,
|
||||||
type(e).__name__, str(e.message),
|
type(e).__name__, str(e.message),
|
||||||
)
|
)
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
perspective_results = yield defer.gatherResults([
|
results = yield defer.gatherResults(
|
||||||
get_key(p_name, p_keys)
|
[
|
||||||
for p_name, p_keys in self.perspective_servers.items()
|
get_key(p_name, p_keys)
|
||||||
])
|
for p_name, p_keys in self.perspective_servers.items()
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
for results in perspective_results:
|
union_of_keys = {}
|
||||||
if results is not None:
|
for result in results:
|
||||||
keys = results
|
for server_name, keys in result.items():
|
||||||
|
union_of_keys.setdefault(server_name, {}).update(keys)
|
||||||
|
|
||||||
limiter = yield get_retry_limiter(
|
defer.returnValue(union_of_keys)
|
||||||
server_name,
|
|
||||||
self.clock,
|
|
||||||
self.store,
|
|
||||||
)
|
|
||||||
|
|
||||||
with limiter:
|
@defer.inlineCallbacks
|
||||||
if not keys:
|
def get_keys_from_server(self, server_name_and_key_ids):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_key(server_name, key_ids):
|
||||||
|
limiter = yield get_retry_limiter(
|
||||||
|
server_name,
|
||||||
|
self.clock,
|
||||||
|
self.store,
|
||||||
|
)
|
||||||
|
with limiter:
|
||||||
|
keys = None
|
||||||
try:
|
try:
|
||||||
keys = yield self.get_server_verify_key_v2_direct(
|
keys = yield self.get_server_verify_key_v2_direct(
|
||||||
server_name, key_ids
|
server_name, key_ids
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.info(
|
logger.info(
|
||||||
"Unable to getting key %r for %r directly: %s %s",
|
"Unable to getting key %r for %r directly: %s %s",
|
||||||
key_ids, server_name,
|
key_ids, server_name,
|
||||||
type(e).__name__, str(e.message),
|
type(e).__name__, str(e.message),
|
||||||
)
|
)
|
||||||
|
|
||||||
if not keys:
|
if not keys:
|
||||||
keys = yield self.get_server_verify_key_v1_direct(
|
keys = yield self.get_server_verify_key_v1_direct(
|
||||||
server_name, key_ids
|
server_name, key_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
for key_id in key_ids:
|
keys = {server_name: keys}
|
||||||
if key_id in keys:
|
|
||||||
defer.returnValue(keys[key_id])
|
defer.returnValue(keys)
|
||||||
return
|
|
||||||
raise ValueError("No verification key found for given key ids")
|
results = yield defer.gatherResults(
|
||||||
|
[
|
||||||
|
get_key(server_name, key_ids)
|
||||||
|
for server_name, key_ids in server_name_and_key_ids
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
merged = {}
|
||||||
|
for result in results:
|
||||||
|
merged.update(result)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
server_name: keys
|
||||||
|
for server_name, keys in merged.items()
|
||||||
|
if keys
|
||||||
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_server_verify_key_v2_indirect(self, server_name, key_ids,
|
def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
|
||||||
perspective_name,
|
perspective_name,
|
||||||
perspective_keys):
|
perspective_keys):
|
||||||
limiter = yield get_retry_limiter(
|
limiter = yield get_retry_limiter(
|
||||||
@ -204,6 +393,7 @@ class Keyring(object):
|
|||||||
u"minimum_valid_until_ts": 0
|
u"minimum_valid_until_ts": 0
|
||||||
} for key_id in key_ids
|
} for key_id in key_ids
|
||||||
}
|
}
|
||||||
|
for server_name, key_ids in server_names_and_key_ids
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -243,23 +433,29 @@ class Keyring(object):
|
|||||||
" server %r" % (perspective_name,)
|
" server %r" % (perspective_name,)
|
||||||
)
|
)
|
||||||
|
|
||||||
response_keys = yield self.process_v2_response(
|
processed_response = yield self.process_v2_response(
|
||||||
server_name, perspective_name, response
|
perspective_name, response
|
||||||
)
|
)
|
||||||
|
|
||||||
keys.update(response_keys)
|
for server_name, response_keys in processed_response.items():
|
||||||
|
keys.setdefault(server_name, {}).update(response_keys)
|
||||||
|
|
||||||
yield self.store_keys(
|
yield defer.gatherResults(
|
||||||
server_name=server_name,
|
[
|
||||||
from_server=perspective_name,
|
self.store_keys(
|
||||||
verify_keys=keys,
|
server_name=server_name,
|
||||||
)
|
from_server=perspective_name,
|
||||||
|
verify_keys=response_keys,
|
||||||
|
)
|
||||||
|
for server_name, response_keys in keys.items()
|
||||||
|
],
|
||||||
|
consumeErrors=True
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(keys)
|
defer.returnValue(keys)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
||||||
|
|
||||||
keys = {}
|
keys = {}
|
||||||
|
|
||||||
for requested_key_id in key_ids:
|
for requested_key_id in key_ids:
|
||||||
@ -295,25 +491,30 @@ class Keyring(object):
|
|||||||
raise ValueError("TLS certificate not allowed by fingerprints")
|
raise ValueError("TLS certificate not allowed by fingerprints")
|
||||||
|
|
||||||
response_keys = yield self.process_v2_response(
|
response_keys = yield self.process_v2_response(
|
||||||
server_name=server_name,
|
|
||||||
from_server=server_name,
|
from_server=server_name,
|
||||||
requested_id=requested_key_id,
|
requested_ids=[requested_key_id],
|
||||||
response_json=response,
|
response_json=response,
|
||||||
)
|
)
|
||||||
|
|
||||||
keys.update(response_keys)
|
keys.update(response_keys)
|
||||||
|
|
||||||
yield self.store_keys(
|
yield defer.gatherResults(
|
||||||
server_name=server_name,
|
[
|
||||||
from_server=server_name,
|
self.store_keys(
|
||||||
verify_keys=keys,
|
server_name=key_server_name,
|
||||||
)
|
from_server=server_name,
|
||||||
|
verify_keys=verify_keys,
|
||||||
|
)
|
||||||
|
for key_server_name, verify_keys in keys.items()
|
||||||
|
],
|
||||||
|
consumeErrors=True
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(keys)
|
defer.returnValue(keys)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def process_v2_response(self, server_name, from_server, response_json,
|
def process_v2_response(self, from_server, response_json,
|
||||||
requested_id=None):
|
requested_ids=[]):
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
response_keys = {}
|
response_keys = {}
|
||||||
verify_keys = {}
|
verify_keys = {}
|
||||||
@ -335,6 +536,8 @@ class Keyring(object):
|
|||||||
verify_key.time_added = time_now_ms
|
verify_key.time_added = time_now_ms
|
||||||
old_verify_keys[key_id] = verify_key
|
old_verify_keys[key_id] = verify_key
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
server_name = response_json["server_name"]
|
||||||
for key_id in response_json["signatures"].get(server_name, {}):
|
for key_id in response_json["signatures"].get(server_name, {}):
|
||||||
if key_id not in response_json["verify_keys"]:
|
if key_id not in response_json["verify_keys"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -357,28 +560,31 @@ class Keyring(object):
|
|||||||
signed_key_json_bytes = encode_canonical_json(signed_key_json)
|
signed_key_json_bytes = encode_canonical_json(signed_key_json)
|
||||||
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
|
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
|
||||||
|
|
||||||
updated_key_ids = set()
|
updated_key_ids = set(requested_ids)
|
||||||
if requested_id is not None:
|
|
||||||
updated_key_ids.add(requested_id)
|
|
||||||
updated_key_ids.update(verify_keys)
|
updated_key_ids.update(verify_keys)
|
||||||
updated_key_ids.update(old_verify_keys)
|
updated_key_ids.update(old_verify_keys)
|
||||||
|
|
||||||
response_keys.update(verify_keys)
|
response_keys.update(verify_keys)
|
||||||
response_keys.update(old_verify_keys)
|
response_keys.update(old_verify_keys)
|
||||||
|
|
||||||
for key_id in updated_key_ids:
|
yield defer.gatherResults(
|
||||||
yield self.store.store_server_keys_json(
|
[
|
||||||
server_name=server_name,
|
self.store.store_server_keys_json(
|
||||||
key_id=key_id,
|
server_name=server_name,
|
||||||
from_server=server_name,
|
key_id=key_id,
|
||||||
ts_now_ms=time_now_ms,
|
from_server=server_name,
|
||||||
ts_expires_ms=ts_valid_until_ms,
|
ts_now_ms=time_now_ms,
|
||||||
key_json_bytes=signed_key_json_bytes,
|
ts_expires_ms=ts_valid_until_ms,
|
||||||
)
|
key_json_bytes=signed_key_json_bytes,
|
||||||
|
)
|
||||||
|
for key_id in updated_key_ids
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(response_keys)
|
results[server_name] = response_keys
|
||||||
|
|
||||||
raise ValueError("No verification key found for given key ids")
|
defer.returnValue(results)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_server_verify_key_v1_direct(self, server_name, key_ids):
|
def get_server_verify_key_v1_direct(self, server_name, key_ids):
|
||||||
@ -462,8 +668,13 @@ class Keyring(object):
|
|||||||
Returns:
|
Returns:
|
||||||
A deferred that completes when the keys are stored.
|
A deferred that completes when the keys are stored.
|
||||||
"""
|
"""
|
||||||
for key_id, key in verify_keys.items():
|
# TODO(markjh): Store whether the keys have expired.
|
||||||
# TODO(markjh): Store whether the keys have expired.
|
yield defer.gatherResults(
|
||||||
yield self.store.store_server_verify_key(
|
[
|
||||||
server_name, server_name, key.time_added, key
|
self.store.store_server_verify_key(
|
||||||
)
|
server_name, server_name, key.time_added, key
|
||||||
|
)
|
||||||
|
for key_id, key in verify_keys.items()
|
||||||
|
],
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
@ -90,7 +90,7 @@ class EventBase(object):
|
|||||||
d = dict(self._event_dict)
|
d = dict(self._event_dict)
|
||||||
d.update({
|
d.update({
|
||||||
"signatures": self.signatures,
|
"signatures": self.signatures,
|
||||||
"unsigned": self.unsigned,
|
"unsigned": dict(self.unsigned),
|
||||||
})
|
})
|
||||||
|
|
||||||
return d
|
return d
|
||||||
@ -109,6 +109,9 @@ class EventBase(object):
|
|||||||
pdu_json.setdefault("unsigned", {})["age"] = int(age)
|
pdu_json.setdefault("unsigned", {})["age"] = int(age)
|
||||||
del pdu_json["unsigned"]["age_ts"]
|
del pdu_json["unsigned"]["age_ts"]
|
||||||
|
|
||||||
|
# This may be a frozen event
|
||||||
|
pdu_json["unsigned"].pop("redacted_because", None)
|
||||||
|
|
||||||
return pdu_json
|
return pdu_json
|
||||||
|
|
||||||
def __set__(self, instance, value):
|
def __set__(self, instance, value):
|
||||||
|
@ -74,6 +74,8 @@ def prune_event(event):
|
|||||||
)
|
)
|
||||||
elif event_type == EventTypes.Aliases:
|
elif event_type == EventTypes.Aliases:
|
||||||
add_fields("aliases")
|
add_fields("aliases")
|
||||||
|
elif event_type == EventTypes.RoomHistoryVisibility:
|
||||||
|
add_fields("history_visibility")
|
||||||
|
|
||||||
allowed_fields = {
|
allowed_fields = {
|
||||||
k: v
|
k: v
|
||||||
|
@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class FederationBase(object):
|
class FederationBase(object):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
|
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
||||||
|
include_none=False):
|
||||||
"""Takes a list of PDUs and checks the signatures and hashs of each
|
"""Takes a list of PDUs and checks the signatures and hashs of each
|
||||||
one. If a PDU fails its signature check then we check if we have it in
|
one. If a PDU fails its signature check then we check if we have it in
|
||||||
the database and if not then request if from the originating server of
|
the database and if not then request if from the originating server of
|
||||||
@ -50,84 +51,108 @@ class FederationBase(object):
|
|||||||
Returns:
|
Returns:
|
||||||
Deferred : A list of PDUs that have valid signatures and hashes.
|
Deferred : A list of PDUs that have valid signatures and hashes.
|
||||||
"""
|
"""
|
||||||
|
deferreds = self._check_sigs_and_hashes(pdus)
|
||||||
|
|
||||||
signed_pdus = []
|
def callback(pdu):
|
||||||
|
return pdu
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def errback(failure, pdu):
|
||||||
def do(pdu):
|
failure.trap(SynapseError)
|
||||||
try:
|
return None
|
||||||
new_pdu = yield self._check_sigs_and_hash(pdu)
|
|
||||||
signed_pdus.append(new_pdu)
|
|
||||||
except SynapseError:
|
|
||||||
# FIXME: We should handle signature failures more gracefully.
|
|
||||||
|
|
||||||
|
def try_local_db(res, pdu):
|
||||||
|
if not res:
|
||||||
# Check local db.
|
# Check local db.
|
||||||
new_pdu = yield self.store.get_event(
|
return self.store.get_event(
|
||||||
pdu.event_id,
|
pdu.event_id,
|
||||||
allow_rejected=True,
|
allow_rejected=True,
|
||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
if new_pdu:
|
return res
|
||||||
signed_pdus.append(new_pdu)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check pdu.origin
|
def try_remote(res, pdu):
|
||||||
if pdu.origin != origin:
|
if not res and pdu.origin != origin:
|
||||||
try:
|
return self.get_pdu(
|
||||||
new_pdu = yield self.get_pdu(
|
destinations=[pdu.origin],
|
||||||
destinations=[pdu.origin],
|
event_id=pdu.event_id,
|
||||||
event_id=pdu.event_id,
|
outlier=outlier,
|
||||||
outlier=outlier,
|
timeout=10000,
|
||||||
timeout=10000,
|
).addErrback(lambda e: None)
|
||||||
)
|
return res
|
||||||
|
|
||||||
if new_pdu:
|
|
||||||
signed_pdus.append(new_pdu)
|
|
||||||
return
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
def warn(res, pdu):
|
||||||
|
if not res:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Failed to find copy of %s with valid signature",
|
"Failed to find copy of %s with valid signature",
|
||||||
pdu.event_id,
|
pdu.event_id,
|
||||||
)
|
)
|
||||||
|
return res
|
||||||
|
|
||||||
yield defer.gatherResults(
|
for pdu, deferred in zip(pdus, deferreds):
|
||||||
[do(pdu) for pdu in pdus],
|
deferred.addCallbacks(
|
||||||
|
callback, errback, errbackArgs=[pdu]
|
||||||
|
).addCallback(
|
||||||
|
try_local_db, pdu
|
||||||
|
).addCallback(
|
||||||
|
try_remote, pdu
|
||||||
|
).addCallback(
|
||||||
|
warn, pdu
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_pdus = yield defer.gatherResults(
|
||||||
|
deferreds,
|
||||||
consumeErrors=True
|
consumeErrors=True
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
defer.returnValue(signed_pdus)
|
if include_none:
|
||||||
|
defer.returnValue(valid_pdus)
|
||||||
|
else:
|
||||||
|
defer.returnValue([p for p in valid_pdus if p])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def _check_sigs_and_hash(self, pdu):
|
def _check_sigs_and_hash(self, pdu):
|
||||||
"""Throws a SynapseError if the PDU does not have the correct
|
return self._check_sigs_and_hashes([pdu])[0]
|
||||||
|
|
||||||
|
def _check_sigs_and_hashes(self, pdus):
|
||||||
|
"""Throws a SynapseError if a PDU does not have the correct
|
||||||
signatures.
|
signatures.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
FrozenEvent: Either the given event or it redacted if it failed the
|
FrozenEvent: Either the given event or it redacted if it failed the
|
||||||
content hash check.
|
content hash check.
|
||||||
"""
|
"""
|
||||||
# Check signatures are correct.
|
|
||||||
redacted_event = prune_event(pdu)
|
|
||||||
redacted_pdu_json = redacted_event.get_pdu_json()
|
|
||||||
|
|
||||||
try:
|
redacted_pdus = [
|
||||||
yield self.keyring.verify_json_for_server(
|
prune_event(pdu)
|
||||||
pdu.origin, redacted_pdu_json
|
for pdu in pdus
|
||||||
)
|
]
|
||||||
except SynapseError:
|
|
||||||
|
deferreds = self.keyring.verify_json_objects_for_server([
|
||||||
|
(p.origin, p.get_pdu_json())
|
||||||
|
for p in redacted_pdus
|
||||||
|
])
|
||||||
|
|
||||||
|
def callback(_, pdu, redacted):
|
||||||
|
if not check_event_content_hash(pdu):
|
||||||
|
logger.warn(
|
||||||
|
"Event content has been tampered, redacting %s: %s",
|
||||||
|
pdu.event_id, pdu.get_pdu_json()
|
||||||
|
)
|
||||||
|
return redacted
|
||||||
|
return pdu
|
||||||
|
|
||||||
|
def errback(failure, pdu):
|
||||||
|
failure.trap(SynapseError)
|
||||||
logger.warn(
|
logger.warn(
|
||||||
"Signature check failed for %s",
|
"Signature check failed for %s",
|
||||||
pdu.event_id,
|
pdu.event_id,
|
||||||
)
|
)
|
||||||
raise
|
return failure
|
||||||
|
|
||||||
if not check_event_content_hash(pdu):
|
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
|
||||||
logger.warn(
|
deferred.addCallbacks(
|
||||||
"Event content has been tampered, redacting.",
|
callback, errback,
|
||||||
pdu.event_id,
|
callbackArgs=[pdu, redacted],
|
||||||
|
errbackArgs=[pdu],
|
||||||
)
|
)
|
||||||
defer.returnValue(redacted_event)
|
|
||||||
|
|
||||||
defer.returnValue(pdu)
|
return deferreds
|
||||||
|
@ -23,13 +23,14 @@ from synapse.api.errors import (
|
|||||||
CodeMessageException, HttpResponseException, SynapseError,
|
CodeMessageException, HttpResponseException, SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import FrozenEvent
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
|
||||||
|
|
||||||
|
import copy
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
@ -133,6 +134,36 @@ class FederationClient(FederationBase):
|
|||||||
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
|
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def query_client_keys(self, destination, content):
|
||||||
|
"""Query device keys for a device hosted on a remote server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): Domain name of the remote homeserver
|
||||||
|
content (dict): The query content.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a Deferred which will eventually yield a JSON object from the
|
||||||
|
response
|
||||||
|
"""
|
||||||
|
sent_queries_counter.inc("client_device_keys")
|
||||||
|
return self.transport_layer.query_client_keys(destination, content)
|
||||||
|
|
||||||
|
@log_function
|
||||||
|
def claim_client_keys(self, destination, content):
|
||||||
|
"""Claims one-time keys for a device hosted on a remote server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination (str): Domain name of the remote homeserver
|
||||||
|
content (dict): The query content.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a Deferred which will eventually yield a JSON object from the
|
||||||
|
response
|
||||||
|
"""
|
||||||
|
sent_queries_counter.inc("client_one_time_keys")
|
||||||
|
return self.transport_layer.claim_client_keys(destination, content)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def backfill(self, dest, context, limit, extremities):
|
def backfill(self, dest, context, limit, extremities):
|
||||||
@ -167,7 +198,7 @@ class FederationClient(FederationBase):
|
|||||||
|
|
||||||
# FIXME: We should handle signature failures more gracefully.
|
# FIXME: We should handle signature failures more gracefully.
|
||||||
pdus[:] = yield defer.gatherResults(
|
pdus[:] = yield defer.gatherResults(
|
||||||
[self._check_sigs_and_hash(pdu) for pdu in pdus],
|
self._check_sigs_and_hashes(pdus),
|
||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
@ -230,7 +261,7 @@ class FederationClient(FederationBase):
|
|||||||
pdu = pdu_list[0]
|
pdu = pdu_list[0]
|
||||||
|
|
||||||
# Check signatures are correct.
|
# Check signatures are correct.
|
||||||
pdu = yield self._check_sigs_and_hash(pdu)
|
pdu = yield self._check_sigs_and_hashes([pdu])[0]
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -327,6 +358,9 @@ class FederationClient(FederationBase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def make_join(self, destinations, room_id, user_id):
|
def make_join(self, destinations, room_id, user_id):
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
|
if destination == self.server_name:
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ret = yield self.transport_layer.make_join(
|
ret = yield self.transport_layer.make_join(
|
||||||
destination, room_id, user_id
|
destination, room_id, user_id
|
||||||
@ -353,6 +387,9 @@ class FederationClient(FederationBase):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_join(self, destinations, pdu):
|
def send_join(self, destinations, pdu):
|
||||||
for destination in destinations:
|
for destination in destinations:
|
||||||
|
if destination == self.server_name:
|
||||||
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
time_now = self._clock.time_msec()
|
time_now = self._clock.time_msec()
|
||||||
_, content = yield self.transport_layer.send_join(
|
_, content = yield self.transport_layer.send_join(
|
||||||
@ -374,17 +411,39 @@ class FederationClient(FederationBase):
|
|||||||
for p in content.get("auth_chain", [])
|
for p in content.get("auth_chain", [])
|
||||||
]
|
]
|
||||||
|
|
||||||
signed_state, signed_auth = yield defer.gatherResults(
|
pdus = {
|
||||||
[
|
p.event_id: p
|
||||||
self._check_sigs_and_hash_and_fetch(
|
for p in itertools.chain(state, auth_chain)
|
||||||
destination, state, outlier=True
|
}
|
||||||
),
|
|
||||||
self._check_sigs_and_hash_and_fetch(
|
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
|
||||||
destination, auth_chain, outlier=True
|
destination, pdus.values(),
|
||||||
)
|
outlier=True,
|
||||||
],
|
)
|
||||||
consumeErrors=True
|
|
||||||
).addErrback(unwrapFirstError)
|
valid_pdus_map = {
|
||||||
|
p.event_id: p
|
||||||
|
for p in valid_pdus
|
||||||
|
}
|
||||||
|
|
||||||
|
# NB: We *need* to copy to ensure that we don't have multiple
|
||||||
|
# references being passed on, as that causes... issues.
|
||||||
|
signed_state = [
|
||||||
|
copy.copy(valid_pdus_map[p.event_id])
|
||||||
|
for p in state
|
||||||
|
if p.event_id in valid_pdus_map
|
||||||
|
]
|
||||||
|
|
||||||
|
signed_auth = [
|
||||||
|
valid_pdus_map[p.event_id]
|
||||||
|
for p in auth_chain
|
||||||
|
if p.event_id in valid_pdus_map
|
||||||
|
]
|
||||||
|
|
||||||
|
# NB: We *need* to copy to ensure that we don't have multiple
|
||||||
|
# references being passed on, as that causes... issues.
|
||||||
|
for s in signed_state:
|
||||||
|
s.internal_metadata = copy.deepcopy(s.internal_metadata)
|
||||||
|
|
||||||
auth_chain.sort(key=lambda e: e.depth)
|
auth_chain.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
@ -396,7 +455,7 @@ class FederationClient(FederationBase):
|
|||||||
except CodeMessageException:
|
except CodeMessageException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warn(
|
logger.exception(
|
||||||
"Failed to send_join via %s: %s",
|
"Failed to send_join via %s: %s",
|
||||||
destination, e.message
|
destination, e.message
|
||||||
)
|
)
|
||||||
|
@ -27,6 +27,7 @@ from synapse.api.errors import FederationError, SynapseError
|
|||||||
|
|
||||||
from synapse.crypto.event_signing import compute_event_signature
|
from synapse.crypto.event_signing import compute_event_signature
|
||||||
|
|
||||||
|
import simplejson as json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
@ -312,6 +313,48 @@ class FederationServer(FederationBase):
|
|||||||
(200, send_content)
|
(200, send_content)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def on_query_client_keys(self, origin, content):
|
||||||
|
query = []
|
||||||
|
for user_id, device_ids in content.get("device_keys", {}).items():
|
||||||
|
if not device_ids:
|
||||||
|
query.append((user_id, None))
|
||||||
|
else:
|
||||||
|
for device_id in device_ids:
|
||||||
|
query.append((user_id, device_id))
|
||||||
|
|
||||||
|
results = yield self.store.get_e2e_device_keys(query)
|
||||||
|
|
||||||
|
json_result = {}
|
||||||
|
for user_id, device_keys in results.items():
|
||||||
|
for device_id, json_bytes in device_keys.items():
|
||||||
|
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
||||||
|
json_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({"device_keys": json_result})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def on_claim_client_keys(self, origin, content):
|
||||||
|
query = []
|
||||||
|
for user_id, device_keys in content.get("one_time_keys", {}).items():
|
||||||
|
for device_id, algorithm in device_keys.items():
|
||||||
|
query.append((user_id, device_id, algorithm))
|
||||||
|
|
||||||
|
results = yield self.store.claim_e2e_one_time_keys(query)
|
||||||
|
|
||||||
|
json_result = {}
|
||||||
|
for user_id, device_keys in results.items():
|
||||||
|
for device_id, keys in device_keys.items():
|
||||||
|
for key_id, json_bytes in keys.items():
|
||||||
|
json_result.setdefault(user_id, {})[device_id] = {
|
||||||
|
key_id: json.loads(json_bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue({"one_time_keys": json_result})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def on_get_missing_events(self, origin, room_id, earliest_events,
|
def on_get_missing_events(self, origin, room_id, earliest_events,
|
||||||
|
@ -222,6 +222,76 @@ class TransportLayerClient(object):
|
|||||||
|
|
||||||
defer.returnValue(content)
|
defer.returnValue(content)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def query_client_keys(self, destination, query_content):
|
||||||
|
"""Query the device keys for a list of user ids hosted on a remote
|
||||||
|
server.
|
||||||
|
|
||||||
|
Request:
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": ["<device_id>"]
|
||||||
|
} }
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": {...}
|
||||||
|
} } }
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination(str): The server to query.
|
||||||
|
query_content(dict): The user ids to query.
|
||||||
|
Returns:
|
||||||
|
A dict containg the device keys.
|
||||||
|
"""
|
||||||
|
path = PREFIX + "/user/keys/query"
|
||||||
|
|
||||||
|
content = yield self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=query_content,
|
||||||
|
)
|
||||||
|
defer.returnValue(content)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
@log_function
|
||||||
|
def claim_client_keys(self, destination, query_content):
|
||||||
|
"""Claim one-time keys for a list of devices hosted on a remote server.
|
||||||
|
|
||||||
|
Request:
|
||||||
|
{
|
||||||
|
"one_time_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": "<algorithm>"
|
||||||
|
} } }
|
||||||
|
|
||||||
|
Response:
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": {
|
||||||
|
"<algorithm>:<key_id>": "<key_base64>"
|
||||||
|
} } } }
|
||||||
|
|
||||||
|
Args:
|
||||||
|
destination(str): The server to query.
|
||||||
|
query_content(dict): The user ids to query.
|
||||||
|
Returns:
|
||||||
|
A dict containg the one-time keys.
|
||||||
|
"""
|
||||||
|
|
||||||
|
path = PREFIX + "/user/keys/claim"
|
||||||
|
|
||||||
|
content = yield self.client.post_json(
|
||||||
|
destination=destination,
|
||||||
|
path=path,
|
||||||
|
data=query_content,
|
||||||
|
)
|
||||||
|
defer.returnValue(content)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_missing_events(self, destination, room_id, earliest_events,
|
def get_missing_events(self, destination, room_id, earliest_events,
|
||||||
|
@ -325,6 +325,24 @@ class FederationInviteServlet(BaseFederationServlet):
|
|||||||
defer.returnValue((200, content))
|
defer.returnValue((200, content))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationClientKeysQueryServlet(BaseFederationServlet):
|
||||||
|
PATH = "/user/keys/query"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query):
|
||||||
|
response = yield self.handler.on_query_client_keys(origin, content)
|
||||||
|
defer.returnValue((200, response))
|
||||||
|
|
||||||
|
|
||||||
|
class FederationClientKeysClaimServlet(BaseFederationServlet):
|
||||||
|
PATH = "/user/keys/claim"
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, origin, content, query):
|
||||||
|
response = yield self.handler.on_claim_client_keys(origin, content)
|
||||||
|
defer.returnValue((200, response))
|
||||||
|
|
||||||
|
|
||||||
class FederationQueryAuthServlet(BaseFederationServlet):
|
class FederationQueryAuthServlet(BaseFederationServlet):
|
||||||
PATH = "/query_auth/([^/]*)/([^/]*)"
|
PATH = "/query_auth/([^/]*)/([^/]*)"
|
||||||
|
|
||||||
@ -373,4 +391,6 @@ SERVLET_CLASSES = (
|
|||||||
FederationQueryAuthServlet,
|
FederationQueryAuthServlet,
|
||||||
FederationGetMissingEventsServlet,
|
FederationGetMissingEventsServlet,
|
||||||
FederationEventAuthServlet,
|
FederationEventAuthServlet,
|
||||||
|
FederationClientKeysQueryServlet,
|
||||||
|
FederationClientKeysClaimServlet,
|
||||||
)
|
)
|
||||||
|
@ -22,7 +22,6 @@ from .room import (
|
|||||||
from .message import MessageHandler
|
from .message import MessageHandler
|
||||||
from .events import EventStreamHandler, EventHandler
|
from .events import EventStreamHandler, EventHandler
|
||||||
from .federation import FederationHandler
|
from .federation import FederationHandler
|
||||||
from .login import LoginHandler
|
|
||||||
from .profile import ProfileHandler
|
from .profile import ProfileHandler
|
||||||
from .presence import PresenceHandler
|
from .presence import PresenceHandler
|
||||||
from .directory import DirectoryHandler
|
from .directory import DirectoryHandler
|
||||||
@ -32,6 +31,7 @@ from .appservice import ApplicationServicesHandler
|
|||||||
from .sync import SyncHandler
|
from .sync import SyncHandler
|
||||||
from .auth import AuthHandler
|
from .auth import AuthHandler
|
||||||
from .identity import IdentityHandler
|
from .identity import IdentityHandler
|
||||||
|
from .receipts import ReceiptsHandler
|
||||||
|
|
||||||
|
|
||||||
class Handlers(object):
|
class Handlers(object):
|
||||||
@ -53,10 +53,10 @@ class Handlers(object):
|
|||||||
self.profile_handler = ProfileHandler(hs)
|
self.profile_handler = ProfileHandler(hs)
|
||||||
self.presence_handler = PresenceHandler(hs)
|
self.presence_handler = PresenceHandler(hs)
|
||||||
self.room_list_handler = RoomListHandler(hs)
|
self.room_list_handler = RoomListHandler(hs)
|
||||||
self.login_handler = LoginHandler(hs)
|
|
||||||
self.directory_handler = DirectoryHandler(hs)
|
self.directory_handler = DirectoryHandler(hs)
|
||||||
self.typing_notification_handler = TypingNotificationHandler(hs)
|
self.typing_notification_handler = TypingNotificationHandler(hs)
|
||||||
self.admin_handler = AdminHandler(hs)
|
self.admin_handler = AdminHandler(hs)
|
||||||
|
self.receipts_handler = ReceiptsHandler(hs)
|
||||||
asapi = ApplicationServiceApi(hs)
|
asapi = ApplicationServiceApi(hs)
|
||||||
self.appservice_handler = ApplicationServicesHandler(
|
self.appservice_handler = ApplicationServicesHandler(
|
||||||
hs, asapi, AppServiceScheduler(
|
hs, asapi, AppServiceScheduler(
|
||||||
|
@ -18,7 +18,7 @@ from twisted.internet import defer
|
|||||||
from synapse.api.errors import LimitExceededError, SynapseError
|
from synapse.api.errors import LimitExceededError, SynapseError
|
||||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||||
from synapse.api.constants import Membership, EventTypes
|
from synapse.api.constants import Membership, EventTypes
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID, RoomAlias
|
||||||
|
|
||||||
from synapse.util.logcontext import PreserveLoggingContext
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
|
|
||||||
@ -107,6 +107,22 @@ class BaseHandler(object):
|
|||||||
if not suppress_auth:
|
if not suppress_auth:
|
||||||
self.auth.check(event, auth_events=context.current_state)
|
self.auth.check(event, auth_events=context.current_state)
|
||||||
|
|
||||||
|
if event.type == EventTypes.CanonicalAlias:
|
||||||
|
# Check the alias is acually valid (at this time at least)
|
||||||
|
room_alias_str = event.content.get("alias", None)
|
||||||
|
if room_alias_str:
|
||||||
|
room_alias = RoomAlias.from_string(room_alias_str)
|
||||||
|
directory_handler = self.hs.get_handlers().directory_handler
|
||||||
|
mapping = yield directory_handler.get_association(room_alias)
|
||||||
|
|
||||||
|
if mapping["room_id"] != event.room_id:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"Room alias %s does not point to the room" % (
|
||||||
|
room_alias_str,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
(event_stream_id, max_stream_id) = yield self.store.persist_event(
|
(event_stream_id, max_stream_id) = yield self.store.persist_event(
|
||||||
event, context=context
|
event, context=context
|
||||||
)
|
)
|
||||||
|
@ -47,17 +47,24 @@ class AuthHandler(BaseHandler):
|
|||||||
self.sessions = {}
|
self.sessions = {}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_auth(self, flows, clientdict, clientip=None):
|
def check_auth(self, flows, clientdict, clientip):
|
||||||
"""
|
"""
|
||||||
Takes a dictionary sent by the client in the login / registration
|
Takes a dictionary sent by the client in the login / registration
|
||||||
protocol and handles the login flow.
|
protocol and handles the login flow.
|
||||||
|
|
||||||
|
As a side effect, this function fills in the 'creds' key on the user's
|
||||||
|
session with a map, which maps each auth-type (str) to the relevant
|
||||||
|
identity authenticated by that auth-type (mostly str, but for captcha, bool).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
flows: list of list of stages
|
flows (list): A list of login flows. Each flow is an ordered list of
|
||||||
authdict: The dictionary from the client root level, not the
|
strings representing auth-types. At least one full
|
||||||
'auth' key: this method prompts for auth if none is sent.
|
flow must be completed in order for auth to be successful.
|
||||||
|
clientdict: The dictionary from the client root level, not the
|
||||||
|
'auth' key: this method prompts for auth if none is sent.
|
||||||
|
clientip (str): The IP address of the client.
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of authed, dict, dict where authed is true if the client
|
A tuple of (authed, dict, dict) where authed is true if the client
|
||||||
has successfully completed an auth flow. If it is true, the first
|
has successfully completed an auth flow. If it is true, the first
|
||||||
dict contains the authenticated credentials of each stage.
|
dict contains the authenticated credentials of each stage.
|
||||||
|
|
||||||
@ -75,7 +82,7 @@ class AuthHandler(BaseHandler):
|
|||||||
del clientdict['auth']
|
del clientdict['auth']
|
||||||
if 'session' in authdict:
|
if 'session' in authdict:
|
||||||
sid = authdict['session']
|
sid = authdict['session']
|
||||||
sess = self._get_session_info(sid)
|
session = self._get_session_info(sid)
|
||||||
|
|
||||||
if len(clientdict) > 0:
|
if len(clientdict) > 0:
|
||||||
# This was designed to allow the client to omit the parameters
|
# This was designed to allow the client to omit the parameters
|
||||||
@ -85,20 +92,21 @@ class AuthHandler(BaseHandler):
|
|||||||
# email auth link on there). It's probably too open to abuse
|
# email auth link on there). It's probably too open to abuse
|
||||||
# because it lets unauthenticated clients store arbitrary objects
|
# because it lets unauthenticated clients store arbitrary objects
|
||||||
# on a home server.
|
# on a home server.
|
||||||
# sess['clientdict'] = clientdict
|
# Revisit: Assumimg the REST APIs do sensible validation, the data
|
||||||
# self._save_session(sess)
|
# isn't arbintrary.
|
||||||
pass
|
session['clientdict'] = clientdict
|
||||||
elif 'clientdict' in sess:
|
self._save_session(session)
|
||||||
clientdict = sess['clientdict']
|
elif 'clientdict' in session:
|
||||||
|
clientdict = session['clientdict']
|
||||||
|
|
||||||
if not authdict:
|
if not authdict:
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
(False, self._auth_dict_for_flows(flows, sess), clientdict)
|
(False, self._auth_dict_for_flows(flows, session), clientdict)
|
||||||
)
|
)
|
||||||
|
|
||||||
if 'creds' not in sess:
|
if 'creds' not in session:
|
||||||
sess['creds'] = {}
|
session['creds'] = {}
|
||||||
creds = sess['creds']
|
creds = session['creds']
|
||||||
|
|
||||||
# check auth type currently being presented
|
# check auth type currently being presented
|
||||||
if 'type' in authdict:
|
if 'type' in authdict:
|
||||||
@ -107,15 +115,15 @@ class AuthHandler(BaseHandler):
|
|||||||
result = yield self.checkers[authdict['type']](authdict, clientip)
|
result = yield self.checkers[authdict['type']](authdict, clientip)
|
||||||
if result:
|
if result:
|
||||||
creds[authdict['type']] = result
|
creds[authdict['type']] = result
|
||||||
self._save_session(sess)
|
self._save_session(session)
|
||||||
|
|
||||||
for f in flows:
|
for f in flows:
|
||||||
if len(set(f) - set(creds.keys())) == 0:
|
if len(set(f) - set(creds.keys())) == 0:
|
||||||
logger.info("Auth completed with creds: %r", creds)
|
logger.info("Auth completed with creds: %r", creds)
|
||||||
self._remove_session(sess)
|
self._remove_session(session)
|
||||||
defer.returnValue((True, creds, clientdict))
|
defer.returnValue((True, creds, clientdict))
|
||||||
|
|
||||||
ret = self._auth_dict_for_flows(flows, sess)
|
ret = self._auth_dict_for_flows(flows, session)
|
||||||
ret['completed'] = creds.keys()
|
ret['completed'] = creds.keys()
|
||||||
defer.returnValue((False, ret, clientdict))
|
defer.returnValue((False, ret, clientdict))
|
||||||
|
|
||||||
@ -149,22 +157,14 @@ class AuthHandler(BaseHandler):
|
|||||||
if "user" not in authdict or "password" not in authdict:
|
if "user" not in authdict or "password" not in authdict:
|
||||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||||
|
|
||||||
user = authdict["user"]
|
user_id = authdict["user"]
|
||||||
password = authdict["password"]
|
password = authdict["password"]
|
||||||
if not user.startswith('@'):
|
if not user_id.startswith('@'):
|
||||||
user = UserID.create(user, self.hs.hostname).to_string()
|
user_id = UserID.create(user_id, self.hs.hostname).to_string()
|
||||||
|
|
||||||
user_info = yield self.store.get_user_by_id(user_id=user)
|
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
if not user_info:
|
self._check_password(user_id, password, password_hash)
|
||||||
logger.warn("Attempted to login as %s but they do not exist", user)
|
defer.returnValue(user_id)
|
||||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
stored_hash = user_info["password_hash"]
|
|
||||||
if bcrypt.checkpw(password, stored_hash):
|
|
||||||
defer.returnValue(user)
|
|
||||||
else:
|
|
||||||
logger.warn("Failed password login for user %s", user)
|
|
||||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_recaptcha(self, authdict, clientip):
|
def _check_recaptcha(self, authdict, clientip):
|
||||||
@ -268,6 +268,79 @@ class AuthHandler(BaseHandler):
|
|||||||
|
|
||||||
return self.sessions[session_id]
|
return self.sessions[session_id]
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def login_with_password(self, user_id, password):
|
||||||
|
"""
|
||||||
|
Authenticates the user with their username and password.
|
||||||
|
|
||||||
|
Used only by the v1 login API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): User ID
|
||||||
|
password (str): Password
|
||||||
|
Returns:
|
||||||
|
The access token for the user's session.
|
||||||
|
Raises:
|
||||||
|
StoreError if there was a problem storing the token.
|
||||||
|
LoginError if there was an authentication problem.
|
||||||
|
"""
|
||||||
|
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
|
||||||
|
self._check_password(user_id, password, password_hash)
|
||||||
|
|
||||||
|
reg_handler = self.hs.get_handlers().registration_handler
|
||||||
|
access_token = reg_handler.generate_token(user_id)
|
||||||
|
logger.info("Logging in user %s", user_id)
|
||||||
|
yield self.store.add_access_token_to_user(user_id, access_token)
|
||||||
|
defer.returnValue((user_id, access_token))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _find_user_id_and_pwd_hash(self, user_id):
|
||||||
|
"""Checks to see if a user with the given id exists. Will check case
|
||||||
|
insensitively, but will throw if there are multiple inexact matches.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: A 2-tuple of `(canonical_user_id, password_hash)`
|
||||||
|
"""
|
||||||
|
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
|
if not user_infos:
|
||||||
|
logger.warn("Attempted to login as %s but they do not exist", user_id)
|
||||||
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
if len(user_infos) > 1:
|
||||||
|
if user_id not in user_infos:
|
||||||
|
logger.warn(
|
||||||
|
"Attempted to login as %s but it matches more than one user "
|
||||||
|
"inexactly: %r",
|
||||||
|
user_id, user_infos.keys()
|
||||||
|
)
|
||||||
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
defer.returnValue((user_id, user_infos[user_id]))
|
||||||
|
else:
|
||||||
|
defer.returnValue(user_infos.popitem())
|
||||||
|
|
||||||
|
def _check_password(self, user_id, password, stored_hash):
|
||||||
|
"""Checks that user_id has passed password, raises LoginError if not."""
|
||||||
|
if not bcrypt.checkpw(password, stored_hash):
|
||||||
|
logger.warn("Failed password login for user %s", user_id)
|
||||||
|
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def set_password(self, user_id, newpassword):
|
||||||
|
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
|
||||||
|
|
||||||
|
yield self.store.user_set_password_hash(user_id, password_hash)
|
||||||
|
yield self.store.user_delete_access_tokens(user_id)
|
||||||
|
yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
|
||||||
|
yield self.store.flush_user(user_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def add_threepid(self, user_id, medium, address, validated_at):
|
||||||
|
yield self.store.user_add_threepid(
|
||||||
|
user_id, medium, address, validated_at,
|
||||||
|
self.hs.get_clock().time_msec()
|
||||||
|
)
|
||||||
|
|
||||||
def _save_session(self, session):
|
def _save_session(self, session):
|
||||||
# TODO: Persistent storage
|
# TODO: Persistent storage
|
||||||
logger.debug("Saving session %s", session)
|
logger.debug("Saving session %s", session)
|
||||||
|
@ -49,7 +49,12 @@ class EventStreamHandler(BaseHandler):
|
|||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_stream(self, auth_user_id, pagin_config, timeout=0,
|
def get_stream(self, auth_user_id, pagin_config, timeout=0,
|
||||||
as_client_event=True, affect_presence=True):
|
as_client_event=True, affect_presence=True,
|
||||||
|
only_room_events=False):
|
||||||
|
"""Fetches the events stream for a given user.
|
||||||
|
|
||||||
|
If `only_room_events` is `True` only room events will be returned.
|
||||||
|
"""
|
||||||
auth_user = UserID.from_string(auth_user_id)
|
auth_user = UserID.from_string(auth_user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -70,7 +75,15 @@ class EventStreamHandler(BaseHandler):
|
|||||||
self._streams_per_user[auth_user] += 1
|
self._streams_per_user[auth_user] += 1
|
||||||
|
|
||||||
rm_handler = self.hs.get_handlers().room_member_handler
|
rm_handler = self.hs.get_handlers().room_member_handler
|
||||||
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
|
|
||||||
|
app_service = yield self.store.get_app_service_by_user_id(
|
||||||
|
auth_user.to_string()
|
||||||
|
)
|
||||||
|
if app_service:
|
||||||
|
rooms = yield self.store.get_app_service_rooms(app_service)
|
||||||
|
room_ids = set(r.room_id for r in rooms)
|
||||||
|
else:
|
||||||
|
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
|
||||||
|
|
||||||
if timeout:
|
if timeout:
|
||||||
# If they've set a timeout set a minimum limit.
|
# If they've set a timeout set a minimum limit.
|
||||||
@ -81,7 +94,8 @@ class EventStreamHandler(BaseHandler):
|
|||||||
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
|
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
|
||||||
|
|
||||||
events, tokens = yield self.notifier.get_events_for(
|
events, tokens = yield self.notifier.get_events_for(
|
||||||
auth_user, room_ids, pagin_config, timeout
|
auth_user, room_ids, pagin_config, timeout,
|
||||||
|
only_room_events=only_room_events
|
||||||
)
|
)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
@ -31,6 +31,8 @@ from synapse.crypto.event_signing import (
|
|||||||
)
|
)
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
from synapse.events.utils import prune_event
|
||||||
|
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
@ -138,26 +140,29 @@ class FederationHandler(BaseHandler):
|
|||||||
if state and auth_chain is not None:
|
if state and auth_chain is not None:
|
||||||
# If we have any state or auth_chain given to us by the replication
|
# If we have any state or auth_chain given to us by the replication
|
||||||
# layer, then we should handle them (if we haven't before.)
|
# layer, then we should handle them (if we haven't before.)
|
||||||
|
|
||||||
|
event_infos = []
|
||||||
|
|
||||||
for e in itertools.chain(auth_chain, state):
|
for e in itertools.chain(auth_chain, state):
|
||||||
if e.event_id in seen_ids:
|
if e.event_id in seen_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
e.internal_metadata.outlier = True
|
e.internal_metadata.outlier = True
|
||||||
try:
|
auth_ids = [e_id for e_id, _ in e.auth_events]
|
||||||
auth_ids = [e_id for e_id, _ in e.auth_events]
|
auth = {
|
||||||
auth = {
|
(e.type, e.state_key): e for e in auth_chain
|
||||||
(e.type, e.state_key): e for e in auth_chain
|
if e.event_id in auth_ids
|
||||||
if e.event_id in auth_ids
|
}
|
||||||
}
|
event_infos.append({
|
||||||
yield self._handle_new_event(
|
"event": e,
|
||||||
origin, e, auth_events=auth
|
"auth_events": auth,
|
||||||
)
|
})
|
||||||
seen_ids.add(e.event_id)
|
seen_ids.add(e.event_id)
|
||||||
except:
|
|
||||||
logger.exception(
|
yield self._handle_new_events(
|
||||||
"Failed to handle state event %s",
|
origin,
|
||||||
e.event_id,
|
event_infos,
|
||||||
)
|
outliers=True
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_, event_stream_id, max_stream_id = yield self._handle_new_event(
|
_, event_stream_id, max_stream_id = yield self._handle_new_event(
|
||||||
@ -222,6 +227,55 @@ class FederationHandler(BaseHandler):
|
|||||||
"user_joined_room", user=user, room_id=event.room_id
|
"user_joined_room", user=user, room_id=event.room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _filter_events_for_server(self, server_name, room_id, events):
|
||||||
|
event_to_state = yield self.store.get_state_for_events(
|
||||||
|
room_id, frozenset(e.event_id for e in events),
|
||||||
|
types=(
|
||||||
|
(EventTypes.RoomHistoryVisibility, ""),
|
||||||
|
(EventTypes.Member, None),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def redact_disallowed(event, state):
|
||||||
|
if not state:
|
||||||
|
return event
|
||||||
|
|
||||||
|
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||||
|
if history:
|
||||||
|
visibility = history.content.get("history_visibility", "shared")
|
||||||
|
if visibility in ["invited", "joined"]:
|
||||||
|
# We now loop through all state events looking for
|
||||||
|
# membership states for the requesting server to determine
|
||||||
|
# if the server is either in the room or has been invited
|
||||||
|
# into the room.
|
||||||
|
for ev in state.values():
|
||||||
|
if ev.type != EventTypes.Member:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
domain = UserID.from_string(ev.state_key).domain
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if domain != server_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
memtype = ev.membership
|
||||||
|
if memtype == Membership.JOIN:
|
||||||
|
return event
|
||||||
|
elif memtype == Membership.INVITE:
|
||||||
|
if visibility == "invited":
|
||||||
|
return event
|
||||||
|
else:
|
||||||
|
return prune_event(event)
|
||||||
|
|
||||||
|
return event
|
||||||
|
|
||||||
|
defer.returnValue([
|
||||||
|
redact_disallowed(e, event_to_state[e.event_id])
|
||||||
|
for e in events
|
||||||
|
])
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def backfill(self, dest, room_id, limit, extremities=[]):
|
def backfill(self, dest, room_id, limit, extremities=[]):
|
||||||
@ -292,38 +346,29 @@ class FederationHandler(BaseHandler):
|
|||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
auth_events.update({a.event_id: a for a in results})
|
auth_events.update({a.event_id: a for a in results})
|
||||||
|
|
||||||
yield defer.gatherResults(
|
ev_infos = []
|
||||||
[
|
for a in auth_events.values():
|
||||||
self._handle_new_event(
|
if a.event_id in seen_events:
|
||||||
dest, a,
|
continue
|
||||||
auth_events={
|
ev_infos.append({
|
||||||
(auth_events[a_id].type, auth_events[a_id].state_key):
|
"event": a,
|
||||||
auth_events[a_id]
|
"auth_events": {
|
||||||
for a_id, _ in a.auth_events
|
(auth_events[a_id].type, auth_events[a_id].state_key):
|
||||||
},
|
auth_events[a_id]
|
||||||
)
|
for a_id, _ in a.auth_events
|
||||||
for a in auth_events.values()
|
}
|
||||||
if a.event_id not in seen_events
|
})
|
||||||
],
|
|
||||||
consumeErrors=True,
|
|
||||||
).addErrback(unwrapFirstError)
|
|
||||||
|
|
||||||
yield defer.gatherResults(
|
for e_id in events_to_state:
|
||||||
[
|
ev_infos.append({
|
||||||
self._handle_new_event(
|
"event": event_map[e_id],
|
||||||
dest, event_map[e_id],
|
"state": events_to_state[e_id],
|
||||||
state=events_to_state[e_id],
|
"auth_events": {
|
||||||
backfilled=True,
|
(auth_events[a_id].type, auth_events[a_id].state_key):
|
||||||
auth_events={
|
auth_events[a_id]
|
||||||
(auth_events[a_id].type, auth_events[a_id].state_key):
|
for a_id, _ in event_map[e_id].auth_events
|
||||||
auth_events[a_id]
|
}
|
||||||
for a_id, _ in event_map[e_id].auth_events
|
})
|
||||||
},
|
|
||||||
)
|
|
||||||
for e_id in events_to_state
|
|
||||||
],
|
|
||||||
consumeErrors=True
|
|
||||||
).addErrback(unwrapFirstError)
|
|
||||||
|
|
||||||
events.sort(key=lambda e: e.depth)
|
events.sort(key=lambda e: e.depth)
|
||||||
|
|
||||||
@ -331,10 +376,14 @@ class FederationHandler(BaseHandler):
|
|||||||
if event in events_to_state:
|
if event in events_to_state:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
yield self._handle_new_event(
|
ev_infos.append({
|
||||||
dest, event,
|
"event": event,
|
||||||
backfilled=True,
|
})
|
||||||
)
|
|
||||||
|
yield self._handle_new_events(
|
||||||
|
dest, ev_infos,
|
||||||
|
backfilled=True,
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue(events)
|
defer.returnValue(events)
|
||||||
|
|
||||||
@ -453,7 +502,7 @@ class FederationHandler(BaseHandler):
|
|||||||
event_ids = list(extremities.keys())
|
event_ids = list(extremities.keys())
|
||||||
|
|
||||||
states = yield defer.gatherResults([
|
states = yield defer.gatherResults([
|
||||||
self.state_handler.resolve_state_groups([e])
|
self.state_handler.resolve_state_groups(room_id, [e])
|
||||||
for e in event_ids
|
for e in event_ids
|
||||||
])
|
])
|
||||||
states = dict(zip(event_ids, [s[1] for s in states]))
|
states = dict(zip(event_ids, [s[1] for s in states]))
|
||||||
@ -600,32 +649,22 @@ class FederationHandler(BaseHandler):
|
|||||||
# FIXME
|
# FIXME
|
||||||
pass
|
pass
|
||||||
|
|
||||||
yield self._handle_auth_events(
|
ev_infos = []
|
||||||
origin, [e for e in auth_chain if e.event_id != event.event_id]
|
for e in itertools.chain(state, auth_chain):
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def handle_state(e):
|
|
||||||
if e.event_id == event.event_id:
|
if e.event_id == event.event_id:
|
||||||
return
|
continue
|
||||||
|
|
||||||
e.internal_metadata.outlier = True
|
e.internal_metadata.outlier = True
|
||||||
try:
|
auth_ids = [e_id for e_id, _ in e.auth_events]
|
||||||
auth_ids = [e_id for e_id, _ in e.auth_events]
|
ev_infos.append({
|
||||||
auth = {
|
"event": e,
|
||||||
|
"auth_events": {
|
||||||
(e.type, e.state_key): e for e in auth_chain
|
(e.type, e.state_key): e for e in auth_chain
|
||||||
if e.event_id in auth_ids
|
if e.event_id in auth_ids
|
||||||
}
|
}
|
||||||
yield self._handle_new_event(
|
})
|
||||||
origin, e, auth_events=auth
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to handle state event %s",
|
|
||||||
e.event_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
yield defer.DeferredList([handle_state(e) for e in state])
|
yield self._handle_new_events(origin, ev_infos, outliers=True)
|
||||||
|
|
||||||
auth_ids = [e_id for e_id, _ in event.auth_events]
|
auth_ids = [e_id for e_id, _ in event.auth_events]
|
||||||
auth_events = {
|
auth_events = {
|
||||||
@ -835,7 +874,7 @@ class FederationHandler(BaseHandler):
|
|||||||
raise AuthError(403, "Host not in room.")
|
raise AuthError(403, "Host not in room.")
|
||||||
|
|
||||||
state_groups = yield self.store.get_state_groups(
|
state_groups = yield self.store.get_state_groups(
|
||||||
[event_id]
|
room_id, [event_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
if state_groups:
|
if state_groups:
|
||||||
@ -882,6 +921,8 @@ class FederationHandler(BaseHandler):
|
|||||||
limit
|
limit
|
||||||
)
|
)
|
||||||
|
|
||||||
|
events = yield self._filter_events_for_server(origin, room_id, events)
|
||||||
|
|
||||||
defer.returnValue(events)
|
defer.returnValue(events)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -940,11 +981,54 @@ class FederationHandler(BaseHandler):
|
|||||||
def _handle_new_event(self, origin, event, state=None, backfilled=False,
|
def _handle_new_event(self, origin, event, state=None, backfilled=False,
|
||||||
current_state=None, auth_events=None):
|
current_state=None, auth_events=None):
|
||||||
|
|
||||||
logger.debug(
|
outlier = event.internal_metadata.is_outlier()
|
||||||
"_handle_new_event: %s, sigs: %s",
|
|
||||||
event.event_id, event.signatures,
|
context = yield self._prep_event(
|
||||||
|
origin, event,
|
||||||
|
state=state,
|
||||||
|
backfilled=backfilled,
|
||||||
|
current_state=current_state,
|
||||||
|
auth_events=auth_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
event_stream_id, max_stream_id = yield self.store.persist_event(
|
||||||
|
event,
|
||||||
|
context=context,
|
||||||
|
backfilled=backfilled,
|
||||||
|
is_new_state=(not outlier and not backfilled),
|
||||||
|
current_state=current_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((context, event_stream_id, max_stream_id))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_new_events(self, origin, event_infos, backfilled=False,
|
||||||
|
outliers=False):
|
||||||
|
contexts = yield defer.gatherResults(
|
||||||
|
[
|
||||||
|
self._prep_event(
|
||||||
|
origin,
|
||||||
|
ev_info["event"],
|
||||||
|
state=ev_info.get("state"),
|
||||||
|
backfilled=backfilled,
|
||||||
|
auth_events=ev_info.get("auth_events"),
|
||||||
|
)
|
||||||
|
for ev_info in event_infos
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
yield self.store.persist_events(
|
||||||
|
[
|
||||||
|
(ev_info["event"], context)
|
||||||
|
for ev_info, context in itertools.izip(event_infos, contexts)
|
||||||
|
],
|
||||||
|
backfilled=backfilled,
|
||||||
|
is_new_state=(not outliers and not backfilled),
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _prep_event(self, origin, event, state=None, backfilled=False,
|
||||||
|
current_state=None, auth_events=None):
|
||||||
outlier = event.internal_metadata.is_outlier()
|
outlier = event.internal_metadata.is_outlier()
|
||||||
|
|
||||||
context = yield self.state_handler.compute_event_context(
|
context = yield self.state_handler.compute_event_context(
|
||||||
@ -954,13 +1038,6 @@ class FederationHandler(BaseHandler):
|
|||||||
if not auth_events:
|
if not auth_events:
|
||||||
auth_events = context.current_state
|
auth_events = context.current_state
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"_handle_new_event: %s, auth_events: %s",
|
|
||||||
event.event_id, auth_events,
|
|
||||||
)
|
|
||||||
|
|
||||||
is_new_state = not outlier
|
|
||||||
|
|
||||||
# This is a hack to fix some old rooms where the initial join event
|
# This is a hack to fix some old rooms where the initial join event
|
||||||
# didn't reference the create event in its auth events.
|
# didn't reference the create event in its auth events.
|
||||||
if event.type == EventTypes.Member and not event.auth_events:
|
if event.type == EventTypes.Member and not event.auth_events:
|
||||||
@ -984,26 +1061,7 @@ class FederationHandler(BaseHandler):
|
|||||||
|
|
||||||
context.rejected = RejectedReason.AUTH_ERROR
|
context.rejected = RejectedReason.AUTH_ERROR
|
||||||
|
|
||||||
# FIXME: Don't store as rejected with AUTH_ERROR if we haven't
|
defer.returnValue(context)
|
||||||
# seen all the auth events.
|
|
||||||
yield self.store.persist_event(
|
|
||||||
event,
|
|
||||||
context=context,
|
|
||||||
backfilled=backfilled,
|
|
||||||
is_new_state=False,
|
|
||||||
current_state=current_state,
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
event_stream_id, max_stream_id = yield self.store.persist_event(
|
|
||||||
event,
|
|
||||||
context=context,
|
|
||||||
backfilled=backfilled,
|
|
||||||
is_new_state=(is_new_state and not backfilled),
|
|
||||||
current_state=current_state,
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue((context, event_stream_id, max_stream_id))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
|
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
|
||||||
@ -1066,14 +1124,24 @@ class FederationHandler(BaseHandler):
|
|||||||
@log_function
|
@log_function
|
||||||
def do_auth(self, origin, event, context, auth_events):
|
def do_auth(self, origin, event, context, auth_events):
|
||||||
# Check if we have all the auth events.
|
# Check if we have all the auth events.
|
||||||
have_events = yield self.store.have_events(
|
current_state = set(e.event_id for e in auth_events.values())
|
||||||
[e_id for e_id, _ in event.auth_events]
|
|
||||||
)
|
|
||||||
|
|
||||||
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
event_auth_events = set(e_id for e_id, _ in event.auth_events)
|
||||||
|
|
||||||
|
if event_auth_events - current_state:
|
||||||
|
have_events = yield self.store.have_events(
|
||||||
|
event_auth_events - current_state
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
have_events = {}
|
||||||
|
|
||||||
|
have_events.update({
|
||||||
|
e.event_id: ""
|
||||||
|
for e in auth_events.values()
|
||||||
|
})
|
||||||
|
|
||||||
seen_events = set(have_events.keys())
|
seen_events = set(have_events.keys())
|
||||||
|
|
||||||
missing_auth = event_auth_events - seen_events
|
missing_auth = event_auth_events - seen_events - current_state
|
||||||
|
|
||||||
if missing_auth:
|
if missing_auth:
|
||||||
logger.info("Missing auth: %s", missing_auth)
|
logger.info("Missing auth: %s", missing_auth)
|
||||||
|
@ -44,7 +44,7 @@ class IdentityHandler(BaseHandler):
|
|||||||
http_client = SimpleHttpClient(self.hs)
|
http_client = SimpleHttpClient(self.hs)
|
||||||
# XXX: make this configurable!
|
# XXX: make this configurable!
|
||||||
# trustedIdServers = ['matrix.org', 'localhost:8090']
|
# trustedIdServers = ['matrix.org', 'localhost:8090']
|
||||||
trustedIdServers = ['matrix.org']
|
trustedIdServers = ['matrix.org', 'vector.im']
|
||||||
|
|
||||||
if 'id_server' in creds:
|
if 'id_server' in creds:
|
||||||
id_server = creds['id_server']
|
id_server = creds['id_server']
|
||||||
@ -117,3 +117,28 @@ class IdentityHandler(BaseHandler):
|
|||||||
except CodeMessageException as e:
|
except CodeMessageException as e:
|
||||||
data = json.loads(e.msg)
|
data = json.loads(e.msg)
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
|
||||||
|
yield run_on_reactor()
|
||||||
|
http_client = SimpleHttpClient(self.hs)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
'email': email,
|
||||||
|
'client_secret': client_secret,
|
||||||
|
'send_attempt': send_attempt,
|
||||||
|
}
|
||||||
|
params.update(kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = yield http_client.post_urlencoded_get_json(
|
||||||
|
"https://%s%s" % (
|
||||||
|
id_server,
|
||||||
|
"/_matrix/identity/api/v1/validate/email/requestToken"
|
||||||
|
),
|
||||||
|
params
|
||||||
|
)
|
||||||
|
defer.returnValue(data)
|
||||||
|
except CodeMessageException as e:
|
||||||
|
logger.info("Proxied requestToken failed: %r", e)
|
||||||
|
raise e
|
||||||
|
@ -1,83 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Copyright 2014, 2015 OpenMarket Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from ._base import BaseHandler
|
|
||||||
from synapse.api.errors import LoginError, Codes
|
|
||||||
|
|
||||||
import bcrypt
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class LoginHandler(BaseHandler):
|
|
||||||
|
|
||||||
def __init__(self, hs):
|
|
||||||
super(LoginHandler, self).__init__(hs)
|
|
||||||
self.hs = hs
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def login(self, user, password):
|
|
||||||
"""Login as the specified user with the specified password.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user (str): The user ID.
|
|
||||||
password (str): The password.
|
|
||||||
Returns:
|
|
||||||
The newly allocated access token.
|
|
||||||
Raises:
|
|
||||||
StoreError if there was a problem storing the token.
|
|
||||||
LoginError if there was an authentication problem.
|
|
||||||
"""
|
|
||||||
# TODO do this better, it can't go in __init__ else it cyclic loops
|
|
||||||
if not hasattr(self, "reg_handler"):
|
|
||||||
self.reg_handler = self.hs.get_handlers().registration_handler
|
|
||||||
|
|
||||||
# pull out the hash for this user if they exist
|
|
||||||
user_info = yield self.store.get_user_by_id(user_id=user)
|
|
||||||
if not user_info:
|
|
||||||
logger.warn("Attempted to login as %s but they do not exist", user)
|
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
|
||||||
|
|
||||||
stored_hash = user_info["password_hash"]
|
|
||||||
if bcrypt.checkpw(password, stored_hash):
|
|
||||||
# generate an access token and store it.
|
|
||||||
token = self.reg_handler._generate_token(user)
|
|
||||||
logger.info("Adding token %s for user %s", token, user)
|
|
||||||
yield self.store.add_access_token_to_user(user, token)
|
|
||||||
defer.returnValue(token)
|
|
||||||
else:
|
|
||||||
logger.warn("Failed password login for user %s", user)
|
|
||||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def set_password(self, user_id, newpassword, token_id=None):
|
|
||||||
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
|
|
||||||
|
|
||||||
yield self.store.user_set_password_hash(user_id, password_hash)
|
|
||||||
yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
|
|
||||||
yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
|
|
||||||
user_id, token_id
|
|
||||||
)
|
|
||||||
yield self.store.flush_user(user_id)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def add_threepid(self, user_id, medium, address, validated_at):
|
|
||||||
yield self.store.user_add_threepid(
|
|
||||||
user_id, medium, address, validated_at,
|
|
||||||
self.hs.get_clock().time_msec()
|
|
||||||
)
|
|
@ -113,11 +113,21 @@ class MessageHandler(BaseHandler):
|
|||||||
"room_key", next_key
|
"room_key", next_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not events:
|
||||||
|
defer.returnValue({
|
||||||
|
"chunk": [],
|
||||||
|
"start": pagin_config.from_token.to_string(),
|
||||||
|
"end": next_token.to_string(),
|
||||||
|
})
|
||||||
|
|
||||||
|
events = yield self._filter_events_for_client(user_id, room_id, events)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
chunk = {
|
chunk = {
|
||||||
"chunk": [
|
"chunk": [
|
||||||
serialize_event(e, time_now, as_client_event) for e in events
|
serialize_event(e, time_now, as_client_event)
|
||||||
|
for e in events
|
||||||
],
|
],
|
||||||
"start": pagin_config.from_token.to_string(),
|
"start": pagin_config.from_token.to_string(),
|
||||||
"end": next_token.to_string(),
|
"end": next_token.to_string(),
|
||||||
@ -125,6 +135,52 @@ class MessageHandler(BaseHandler):
|
|||||||
|
|
||||||
defer.returnValue(chunk)
|
defer.returnValue(chunk)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _filter_events_for_client(self, user_id, room_id, events):
|
||||||
|
event_id_to_state = yield self.store.get_state_for_events(
|
||||||
|
room_id, frozenset(e.event_id for e in events),
|
||||||
|
types=(
|
||||||
|
(EventTypes.RoomHistoryVisibility, ""),
|
||||||
|
(EventTypes.Member, user_id),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def allowed(event, state):
|
||||||
|
if event.type == EventTypes.RoomHistoryVisibility:
|
||||||
|
return True
|
||||||
|
|
||||||
|
membership_ev = state.get((EventTypes.Member, user_id), None)
|
||||||
|
if membership_ev:
|
||||||
|
membership = membership_ev.membership
|
||||||
|
else:
|
||||||
|
membership = Membership.LEAVE
|
||||||
|
|
||||||
|
if membership == Membership.JOIN:
|
||||||
|
return True
|
||||||
|
|
||||||
|
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||||
|
if history:
|
||||||
|
visibility = history.content.get("history_visibility", "shared")
|
||||||
|
else:
|
||||||
|
visibility = "shared"
|
||||||
|
|
||||||
|
if visibility == "public":
|
||||||
|
return True
|
||||||
|
elif visibility == "shared":
|
||||||
|
return True
|
||||||
|
elif visibility == "joined":
|
||||||
|
return membership == Membership.JOIN
|
||||||
|
elif visibility == "invited":
|
||||||
|
return membership == Membership.INVITE
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
defer.returnValue([
|
||||||
|
event
|
||||||
|
for event in events
|
||||||
|
if allowed(event, event_id_to_state[event.event_id])
|
||||||
|
])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_and_send_event(self, event_dict, ratelimit=True,
|
def create_and_send_event(self, event_dict, ratelimit=True,
|
||||||
client=None, txn_id=None):
|
client=None, txn_id=None):
|
||||||
@ -278,6 +334,11 @@ class MessageHandler(BaseHandler):
|
|||||||
user, pagination_config.get_source_config("presence"), None
|
user, pagination_config.get_source_config("presence"), None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
receipt_stream = self.hs.get_event_sources().sources["receipt"]
|
||||||
|
receipt, _ = yield receipt_stream.get_pagination_rows(
|
||||||
|
user, pagination_config.get_source_config("receipt"), None
|
||||||
|
)
|
||||||
|
|
||||||
public_room_ids = yield self.store.get_public_room_ids()
|
public_room_ids = yield self.store.get_public_room_ids()
|
||||||
|
|
||||||
limit = pagin_config.limit
|
limit = pagin_config.limit
|
||||||
@ -316,6 +377,10 @@ class MessageHandler(BaseHandler):
|
|||||||
]
|
]
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
messages = yield self._filter_events_for_client(
|
||||||
|
user_id, event.room_id, messages
|
||||||
|
)
|
||||||
|
|
||||||
start_token = now_token.copy_and_replace("room_key", token[0])
|
start_token = now_token.copy_and_replace("room_key", token[0])
|
||||||
end_token = now_token.copy_and_replace("room_key", token[1])
|
end_token = now_token.copy_and_replace("room_key", token[1])
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
@ -336,15 +401,20 @@ class MessageHandler(BaseHandler):
|
|||||||
except:
|
except:
|
||||||
logger.exception("Failed to get snapshot")
|
logger.exception("Failed to get snapshot")
|
||||||
|
|
||||||
yield defer.gatherResults(
|
# Only do N rooms at once
|
||||||
[handle_room(e) for e in room_list],
|
n = 5
|
||||||
consumeErrors=True
|
d_list = [handle_room(e) for e in room_list]
|
||||||
).addErrback(unwrapFirstError)
|
for i in range(0, len(d_list), n):
|
||||||
|
yield defer.gatherResults(
|
||||||
|
d_list[i:i + n],
|
||||||
|
consumeErrors=True
|
||||||
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
ret = {
|
ret = {
|
||||||
"rooms": rooms_ret,
|
"rooms": rooms_ret,
|
||||||
"presence": presence,
|
"presence": presence,
|
||||||
"end": now_token.to_string()
|
"receipts": receipt,
|
||||||
|
"end": now_token.to_string(),
|
||||||
}
|
}
|
||||||
|
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
@ -390,24 +460,21 @@ class MessageHandler(BaseHandler):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_presence():
|
def get_presence():
|
||||||
presence_defs = yield defer.DeferredList(
|
states = yield presence_handler.get_states(
|
||||||
[
|
target_users=[UserID.from_string(m.user_id) for m in room_members],
|
||||||
presence_handler.get_state(
|
auth_user=auth_user,
|
||||||
target_user=UserID.from_string(m.user_id),
|
as_event=True,
|
||||||
auth_user=auth_user,
|
check_auth=False,
|
||||||
as_event=True,
|
|
||||||
check_auth=False,
|
|
||||||
)
|
|
||||||
for m in room_members
|
|
||||||
],
|
|
||||||
consumeErrors=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue([p for success, p in presence_defs if success])
|
defer.returnValue(states.values())
|
||||||
|
|
||||||
presence, (messages, token) = yield defer.gatherResults(
|
receipts_handler = self.hs.get_handlers().receipts_handler
|
||||||
|
|
||||||
|
presence, receipts, (messages, token) = yield defer.gatherResults(
|
||||||
[
|
[
|
||||||
get_presence(),
|
get_presence(),
|
||||||
|
receipts_handler.get_receipts_for_room(room_id, now_token.receipt_key),
|
||||||
self.store.get_recent_events_for_room(
|
self.store.get_recent_events_for_room(
|
||||||
room_id,
|
room_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@ -417,6 +484,10 @@ class MessageHandler(BaseHandler):
|
|||||||
consumeErrors=True,
|
consumeErrors=True,
|
||||||
).addErrback(unwrapFirstError)
|
).addErrback(unwrapFirstError)
|
||||||
|
|
||||||
|
messages = yield self._filter_events_for_client(
|
||||||
|
user_id, room_id, messages
|
||||||
|
)
|
||||||
|
|
||||||
start_token = now_token.copy_and_replace("room_key", token[0])
|
start_token = now_token.copy_and_replace("room_key", token[0])
|
||||||
end_token = now_token.copy_and_replace("room_key", token[1])
|
end_token = now_token.copy_and_replace("room_key", token[1])
|
||||||
|
|
||||||
@ -431,5 +502,6 @@ class MessageHandler(BaseHandler):
|
|||||||
"end": end_token.to_string(),
|
"end": end_token.to_string(),
|
||||||
},
|
},
|
||||||
"state": state,
|
"state": state,
|
||||||
"presence": presence
|
"presence": presence,
|
||||||
|
"receipts": receipts,
|
||||||
})
|
})
|
||||||
|
@ -192,6 +192,20 @@ class PresenceHandler(BaseHandler):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
|
def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
|
||||||
|
"""Get the current presence state of the given user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_user (UserID): The user whose presence we want
|
||||||
|
auth_user (UserID): The user requesting the presence, used for
|
||||||
|
checking if said user is allowed to see the persence of the
|
||||||
|
`target_user`
|
||||||
|
as_event (bool): Format the return as an event or not?
|
||||||
|
check_auth (bool): Perform the auth checks or not?
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: The presence state of the `target_user`, whose format depends
|
||||||
|
on the `as_event` argument.
|
||||||
|
"""
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
if check_auth:
|
if check_auth:
|
||||||
visible = yield self.is_presence_visible(
|
visible = yield self.is_presence_visible(
|
||||||
@ -232,6 +246,81 @@ class PresenceHandler(BaseHandler):
|
|||||||
else:
|
else:
|
||||||
defer.returnValue(state)
|
defer.returnValue(state)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_states(self, target_users, auth_user, as_event=False, check_auth=True):
|
||||||
|
"""A batched version of the `get_state` method that accepts a list of
|
||||||
|
`target_users`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_users (list): The list of UserID's whose presence we want
|
||||||
|
auth_user (UserID): The user requesting the presence, used for
|
||||||
|
checking if said user is allowed to see the persence of the
|
||||||
|
`target_users`
|
||||||
|
as_event (bool): Format the return as an event or not?
|
||||||
|
check_auth (bool): Perform the auth checks or not?
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A mapping from user -> presence_state
|
||||||
|
"""
|
||||||
|
local_users, remote_users = partitionbool(
|
||||||
|
target_users,
|
||||||
|
lambda u: self.hs.is_mine(u)
|
||||||
|
)
|
||||||
|
|
||||||
|
if check_auth:
|
||||||
|
for user in local_users:
|
||||||
|
visible = yield self.is_presence_visible(
|
||||||
|
observer_user=auth_user,
|
||||||
|
observed_user=user
|
||||||
|
)
|
||||||
|
|
||||||
|
if not visible:
|
||||||
|
raise SynapseError(404, "Presence information not visible")
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
if local_users:
|
||||||
|
for user in local_users:
|
||||||
|
if user in self._user_cachemap:
|
||||||
|
results[user] = self._user_cachemap[user].get_state()
|
||||||
|
|
||||||
|
local_to_user = {u.localpart: u for u in local_users}
|
||||||
|
|
||||||
|
states = yield self.store.get_presence_states(
|
||||||
|
[u.localpart for u in local_users if u not in results]
|
||||||
|
)
|
||||||
|
|
||||||
|
for local_part, state in states.items():
|
||||||
|
if state is None:
|
||||||
|
continue
|
||||||
|
res = {"presence": state["state"]}
|
||||||
|
if "status_msg" in state and state["status_msg"]:
|
||||||
|
res["status_msg"] = state["status_msg"]
|
||||||
|
results[local_to_user[local_part]] = res
|
||||||
|
|
||||||
|
for user in remote_users:
|
||||||
|
# TODO(paul): Have remote server send us permissions set
|
||||||
|
results[user] = self._get_or_offline_usercache(user).get_state()
|
||||||
|
|
||||||
|
for state in results.values():
|
||||||
|
if "last_active" in state:
|
||||||
|
state["last_active_ago"] = int(
|
||||||
|
self.clock.time_msec() - state.pop("last_active")
|
||||||
|
)
|
||||||
|
|
||||||
|
if as_event:
|
||||||
|
for user, state in results.items():
|
||||||
|
content = state
|
||||||
|
content["user_id"] = user.to_string()
|
||||||
|
|
||||||
|
if "last_active" in content:
|
||||||
|
content["last_active_ago"] = int(
|
||||||
|
self._clock.time_msec() - content.pop("last_active")
|
||||||
|
)
|
||||||
|
|
||||||
|
results[user] = {"type": "m.presence", "content": content}
|
||||||
|
|
||||||
|
defer.returnValue(results)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def set_state(self, target_user, auth_user, state):
|
def set_state(self, target_user, auth_user, state):
|
||||||
@ -992,7 +1081,7 @@ class PresenceHandler(BaseHandler):
|
|||||||
room_ids([str]): List of room_ids to notify.
|
room_ids([str]): List of room_ids to notify.
|
||||||
"""
|
"""
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
self.notifier.on_new_user_event(
|
self.notifier.on_new_event(
|
||||||
"presence_key",
|
"presence_key",
|
||||||
self._user_cachemap_latest_serial,
|
self._user_cachemap_latest_serial,
|
||||||
users_to_push,
|
users_to_push,
|
||||||
|
210
synapse/handlers/receipts.py
Normal file
210
synapse/handlers/receipts.py
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiptsHandler(BaseHandler):
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ReceiptsHandler, self).__init__(hs)
|
||||||
|
|
||||||
|
self.hs = hs
|
||||||
|
self.federation = hs.get_replication_layer()
|
||||||
|
self.federation.register_edu_handler(
|
||||||
|
"m.receipt", self._received_remote_receipt
|
||||||
|
)
|
||||||
|
self.clock = self.hs.get_clock()
|
||||||
|
|
||||||
|
self._receipt_cache = None
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def received_client_receipt(self, room_id, receipt_type, user_id,
|
||||||
|
event_id):
|
||||||
|
"""Called when a client tells us a local user has read up to the given
|
||||||
|
event_id in the room.
|
||||||
|
"""
|
||||||
|
receipt = {
|
||||||
|
"room_id": room_id,
|
||||||
|
"receipt_type": receipt_type,
|
||||||
|
"user_id": user_id,
|
||||||
|
"event_ids": [event_id],
|
||||||
|
"data": {
|
||||||
|
"ts": int(self.clock.time_msec()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
is_new = yield self._handle_new_receipts([receipt])
|
||||||
|
|
||||||
|
if is_new:
|
||||||
|
self._push_remotes([receipt])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _received_remote_receipt(self, origin, content):
|
||||||
|
"""Called when we receive an EDU of type m.receipt from a remote HS.
|
||||||
|
"""
|
||||||
|
receipts = [
|
||||||
|
{
|
||||||
|
"room_id": room_id,
|
||||||
|
"receipt_type": receipt_type,
|
||||||
|
"user_id": user_id,
|
||||||
|
"event_ids": user_values["event_ids"],
|
||||||
|
"data": user_values.get("data", {}),
|
||||||
|
}
|
||||||
|
for room_id, room_values in content.items()
|
||||||
|
for receipt_type, users in room_values.items()
|
||||||
|
for user_id, user_values in users.items()
|
||||||
|
]
|
||||||
|
|
||||||
|
yield self._handle_new_receipts(receipts)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_new_receipts(self, receipts):
|
||||||
|
"""Takes a list of receipts, stores them and informs the notifier.
|
||||||
|
"""
|
||||||
|
for receipt in receipts:
|
||||||
|
room_id = receipt["room_id"]
|
||||||
|
receipt_type = receipt["receipt_type"]
|
||||||
|
user_id = receipt["user_id"]
|
||||||
|
event_ids = receipt["event_ids"]
|
||||||
|
data = receipt["data"]
|
||||||
|
|
||||||
|
res = yield self.store.insert_receipt(
|
||||||
|
room_id, receipt_type, user_id, event_ids, data
|
||||||
|
)
|
||||||
|
|
||||||
|
if not res:
|
||||||
|
# res will be None if this read receipt is 'old'
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
stream_id, max_persisted_id = res
|
||||||
|
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
self.notifier.on_new_event(
|
||||||
|
"receipt_key", max_persisted_id, rooms=[room_id]
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _push_remotes(self, receipts):
|
||||||
|
"""Given a list of receipts, works out which remote servers should be
|
||||||
|
poked and pokes them.
|
||||||
|
"""
|
||||||
|
# TODO: Some of this stuff should be coallesced.
|
||||||
|
for receipt in receipts:
|
||||||
|
room_id = receipt["room_id"]
|
||||||
|
receipt_type = receipt["receipt_type"]
|
||||||
|
user_id = receipt["user_id"]
|
||||||
|
event_ids = receipt["event_ids"]
|
||||||
|
data = receipt["data"]
|
||||||
|
|
||||||
|
remotedomains = set()
|
||||||
|
|
||||||
|
rm_handler = self.hs.get_handlers().room_member_handler
|
||||||
|
yield rm_handler.fetch_room_distributions_into(
|
||||||
|
room_id, localusers=None, remotedomains=remotedomains
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Sending receipt to: %r", remotedomains)
|
||||||
|
|
||||||
|
for domain in remotedomains:
|
||||||
|
self.federation.send_edu(
|
||||||
|
destination=domain,
|
||||||
|
edu_type="m.receipt",
|
||||||
|
content={
|
||||||
|
room_id: {
|
||||||
|
receipt_type: {
|
||||||
|
user_id: {
|
||||||
|
"event_ids": event_ids,
|
||||||
|
"data": data,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_receipts_for_room(self, room_id, to_key):
|
||||||
|
"""Gets all receipts for a room, upto the given key.
|
||||||
|
"""
|
||||||
|
result = yield self.store.get_linearized_receipts_for_room(
|
||||||
|
room_id,
|
||||||
|
to_key=to_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
defer.returnValue([])
|
||||||
|
|
||||||
|
event = {
|
||||||
|
"type": "m.receipt",
|
||||||
|
"room_id": room_id,
|
||||||
|
"content": result,
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue([event])
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiptEventSource(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_new_events_for_user(self, user, from_key, limit):
|
||||||
|
from_key = int(from_key)
|
||||||
|
to_key = yield self.get_current_key()
|
||||||
|
|
||||||
|
if from_key == to_key:
|
||||||
|
defer.returnValue(([], to_key))
|
||||||
|
|
||||||
|
rooms = yield self.store.get_rooms_for_user(user.to_string())
|
||||||
|
rooms = [room.room_id for room in rooms]
|
||||||
|
events = yield self.store.get_linearized_receipts_for_rooms(
|
||||||
|
rooms,
|
||||||
|
from_key=from_key,
|
||||||
|
to_key=to_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((events, to_key))
|
||||||
|
|
||||||
|
def get_current_key(self, direction='f'):
|
||||||
|
return self.store.get_max_receipt_stream_id()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_pagination_rows(self, user, config, key):
|
||||||
|
to_key = int(config.from_key)
|
||||||
|
|
||||||
|
if config.to_key:
|
||||||
|
from_key = int(config.to_key)
|
||||||
|
else:
|
||||||
|
from_key = None
|
||||||
|
|
||||||
|
rooms = yield self.store.get_rooms_for_user(user.to_string())
|
||||||
|
rooms = [room.room_id for room in rooms]
|
||||||
|
events = yield self.store.get_linearized_receipts_for_rooms(
|
||||||
|
rooms,
|
||||||
|
from_key=from_key,
|
||||||
|
to_key=to_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((events, to_key))
|
@ -57,8 +57,8 @@ class RegistrationHandler(BaseHandler):
|
|||||||
|
|
||||||
yield self.check_user_id_is_valid(user_id)
|
yield self.check_user_id_is_valid(user_id)
|
||||||
|
|
||||||
u = yield self.store.get_user_by_id(user_id)
|
users = yield self.store.get_users_by_id_case_insensitive(user_id)
|
||||||
if u:
|
if users:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
"User ID already taken.",
|
"User ID already taken.",
|
||||||
@ -73,7 +73,8 @@ class RegistrationHandler(BaseHandler):
|
|||||||
localpart : The local part of the user ID to register. If None,
|
localpart : The local part of the user ID to register. If None,
|
||||||
one will be randomly generated.
|
one will be randomly generated.
|
||||||
password (str) : The password to assign to this user so they can
|
password (str) : The password to assign to this user so they can
|
||||||
login again.
|
login again. This can be None which means they cannot login again
|
||||||
|
via a password (e.g. the user is an application service user).
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (user_id, access_token).
|
A tuple of (user_id, access_token).
|
||||||
Raises:
|
Raises:
|
||||||
@ -90,7 +91,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
token = self._generate_token(user_id)
|
token = self.generate_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
@ -110,7 +111,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
yield self.check_user_id_is_valid(user_id)
|
yield self.check_user_id_is_valid(user_id)
|
||||||
|
|
||||||
token = self._generate_token(user_id)
|
token = self.generate_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
@ -160,7 +161,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
400, "Invalid user localpart for this application service.",
|
400, "Invalid user localpart for this application service.",
|
||||||
errcode=Codes.EXCLUSIVE
|
errcode=Codes.EXCLUSIVE
|
||||||
)
|
)
|
||||||
token = self._generate_token(user_id)
|
token = self.generate_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
token=token,
|
token=token,
|
||||||
@ -192,6 +193,35 @@ class RegistrationHandler(BaseHandler):
|
|||||||
else:
|
else:
|
||||||
logger.info("Valid captcha entered from %s", ip)
|
logger.info("Valid captcha entered from %s", ip)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def register_saml2(self, localpart):
|
||||||
|
"""
|
||||||
|
Registers email_id as SAML2 Based Auth.
|
||||||
|
"""
|
||||||
|
if urllib.quote(localpart) != localpart:
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
"User ID must only contain characters which do not"
|
||||||
|
" require URL encoding."
|
||||||
|
)
|
||||||
|
user = UserID(localpart, self.hs.hostname)
|
||||||
|
user_id = user.to_string()
|
||||||
|
|
||||||
|
yield self.check_user_id_is_valid(user_id)
|
||||||
|
token = self.generate_token(user_id)
|
||||||
|
try:
|
||||||
|
yield self.store.register(
|
||||||
|
user_id=user_id,
|
||||||
|
token=token,
|
||||||
|
password_hash=None
|
||||||
|
)
|
||||||
|
yield self.distributor.fire("registered_user", user)
|
||||||
|
except Exception, e:
|
||||||
|
yield self.store.add_access_token_to_user(user_id, token)
|
||||||
|
# Ignore Registration errors
|
||||||
|
logger.exception(e)
|
||||||
|
defer.returnValue((user_id, token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def register_email(self, threepidCreds):
|
def register_email(self, threepidCreds):
|
||||||
"""
|
"""
|
||||||
@ -243,7 +273,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
errcode=Codes.EXCLUSIVE
|
errcode=Codes.EXCLUSIVE
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_token(self, user_id):
|
def generate_token(self, user_id):
|
||||||
# urlsafe variant uses _ and - so use . as the separator and replace
|
# urlsafe variant uses _ and - so use . as the separator and replace
|
||||||
# all =s with .s so http clients don't quote =s when it is used as
|
# all =s with .s so http clients don't quote =s when it is used as
|
||||||
# query params.
|
# query params.
|
||||||
|
@ -19,12 +19,15 @@ from twisted.internet import defer
|
|||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
from synapse.types import UserID, RoomAlias, RoomID
|
from synapse.types import UserID, RoomAlias, RoomID
|
||||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
from synapse.api.constants import (
|
||||||
|
EventTypes, Membership, JoinRules, RoomCreationPreset,
|
||||||
|
)
|
||||||
from synapse.api.errors import StoreError, SynapseError
|
from synapse.api.errors import StoreError, SynapseError
|
||||||
from synapse.util import stringutils, unwrapFirstError
|
from synapse.util import stringutils, unwrapFirstError
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
import logging
|
import logging
|
||||||
import string
|
import string
|
||||||
|
|
||||||
@ -33,6 +36,19 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class RoomCreationHandler(BaseHandler):
|
class RoomCreationHandler(BaseHandler):
|
||||||
|
|
||||||
|
PRESETS_DICT = {
|
||||||
|
RoomCreationPreset.PRIVATE_CHAT: {
|
||||||
|
"join_rules": JoinRules.INVITE,
|
||||||
|
"history_visibility": "invited",
|
||||||
|
"original_invitees_have_ops": False,
|
||||||
|
},
|
||||||
|
RoomCreationPreset.PUBLIC_CHAT: {
|
||||||
|
"join_rules": JoinRules.PUBLIC,
|
||||||
|
"history_visibility": "shared",
|
||||||
|
"original_invitees_have_ops": False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_room(self, user_id, room_id, config):
|
def create_room(self, user_id, room_id, config):
|
||||||
""" Creates a new room.
|
""" Creates a new room.
|
||||||
@ -121,9 +137,25 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
servers=[self.hs.hostname],
|
servers=[self.hs.hostname],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
preset_config = config.get(
|
||||||
|
"preset",
|
||||||
|
RoomCreationPreset.PUBLIC_CHAT
|
||||||
|
if is_public
|
||||||
|
else RoomCreationPreset.PRIVATE_CHAT
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_initial_state = config.get("initial_state", [])
|
||||||
|
|
||||||
|
initial_state = OrderedDict()
|
||||||
|
for val in raw_initial_state:
|
||||||
|
initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
|
||||||
|
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
creation_events = self._create_events_for_new_room(
|
creation_events = self._create_events_for_new_room(
|
||||||
user, room_id, is_public=is_public
|
user, room_id,
|
||||||
|
preset_config=preset_config,
|
||||||
|
invite_list=invite_list,
|
||||||
|
initial_state=initial_state,
|
||||||
)
|
)
|
||||||
|
|
||||||
msg_handler = self.hs.get_handlers().message_handler
|
msg_handler = self.hs.get_handlers().message_handler
|
||||||
@ -170,7 +202,10 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
|
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
def _create_events_for_new_room(self, creator, room_id, is_public=False):
|
def _create_events_for_new_room(self, creator, room_id, preset_config,
|
||||||
|
invite_list, initial_state):
|
||||||
|
config = RoomCreationHandler.PRESETS_DICT[preset_config]
|
||||||
|
|
||||||
creator_id = creator.to_string()
|
creator_id = creator.to_string()
|
||||||
|
|
||||||
event_keys = {
|
event_keys = {
|
||||||
@ -203,16 +238,20 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
power_levels_event = create(
|
returned_events = [creation_event, join_event]
|
||||||
etype=EventTypes.PowerLevels,
|
|
||||||
content={
|
if (EventTypes.PowerLevels, '') not in initial_state:
|
||||||
|
power_level_content = {
|
||||||
"users": {
|
"users": {
|
||||||
creator.to_string(): 100,
|
creator.to_string(): 100,
|
||||||
},
|
},
|
||||||
"users_default": 0,
|
"users_default": 0,
|
||||||
"events": {
|
"events": {
|
||||||
EventTypes.Name: 100,
|
EventTypes.Name: 50,
|
||||||
EventTypes.PowerLevels: 100,
|
EventTypes.PowerLevels: 100,
|
||||||
|
EventTypes.RoomHistoryVisibility: 100,
|
||||||
|
EventTypes.CanonicalAlias: 50,
|
||||||
|
EventTypes.RoomAvatar: 50,
|
||||||
},
|
},
|
||||||
"events_default": 0,
|
"events_default": 0,
|
||||||
"state_default": 50,
|
"state_default": 50,
|
||||||
@ -220,21 +259,43 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
"kick": 50,
|
"kick": 50,
|
||||||
"redact": 50,
|
"redact": 50,
|
||||||
"invite": 0,
|
"invite": 0,
|
||||||
},
|
}
|
||||||
)
|
|
||||||
|
|
||||||
join_rule = JoinRules.PUBLIC if is_public else JoinRules.INVITE
|
if config["original_invitees_have_ops"]:
|
||||||
join_rules_event = create(
|
for invitee in invite_list:
|
||||||
etype=EventTypes.JoinRules,
|
power_level_content["users"][invitee] = 100
|
||||||
content={"join_rule": join_rule},
|
|
||||||
)
|
|
||||||
|
|
||||||
return [
|
power_levels_event = create(
|
||||||
creation_event,
|
etype=EventTypes.PowerLevels,
|
||||||
join_event,
|
content=power_level_content,
|
||||||
power_levels_event,
|
)
|
||||||
join_rules_event,
|
|
||||||
]
|
returned_events.append(power_levels_event)
|
||||||
|
|
||||||
|
if (EventTypes.JoinRules, '') not in initial_state:
|
||||||
|
join_rules_event = create(
|
||||||
|
etype=EventTypes.JoinRules,
|
||||||
|
content={"join_rule": config["join_rules"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
returned_events.append(join_rules_event)
|
||||||
|
|
||||||
|
if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
|
||||||
|
history_event = create(
|
||||||
|
etype=EventTypes.RoomHistoryVisibility,
|
||||||
|
content={"history_visibility": config["history_visibility"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
returned_events.append(history_event)
|
||||||
|
|
||||||
|
for (etype, state_key), content in initial_state.items():
|
||||||
|
returned_events.append(create(
|
||||||
|
etype=etype,
|
||||||
|
state_key=state_key,
|
||||||
|
content=content,
|
||||||
|
))
|
||||||
|
|
||||||
|
return returned_events
|
||||||
|
|
||||||
|
|
||||||
class RoomMemberHandler(BaseHandler):
|
class RoomMemberHandler(BaseHandler):
|
||||||
@ -498,15 +559,9 @@ class RoomMemberHandler(BaseHandler):
|
|||||||
"""Returns a list of roomids that the user has any of the given
|
"""Returns a list of roomids that the user has any of the given
|
||||||
membership states in."""
|
membership states in."""
|
||||||
|
|
||||||
app_service = yield self.store.get_app_service_by_user_id(
|
rooms = yield self.store.get_rooms_for_user(
|
||||||
user.to_string()
|
user.to_string(),
|
||||||
)
|
)
|
||||||
if app_service:
|
|
||||||
rooms = yield self.store.get_app_service_rooms(app_service)
|
|
||||||
else:
|
|
||||||
rooms = yield self.store.get_rooms_for_user(
|
|
||||||
user.to_string(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# For some reason the list of events contains duplicates
|
# For some reason the list of events contains duplicates
|
||||||
# TODO(paul): work out why because I really don't think it should
|
# TODO(paul): work out why because I really don't think it should
|
||||||
|
@ -96,9 +96,18 @@ class SyncHandler(BaseHandler):
|
|||||||
return self.current_sync_for_user(sync_config, since_token)
|
return self.current_sync_for_user(sync_config, since_token)
|
||||||
|
|
||||||
rm_handler = self.hs.get_handlers().room_member_handler
|
rm_handler = self.hs.get_handlers().room_member_handler
|
||||||
room_ids = yield rm_handler.get_joined_rooms_for_user(
|
|
||||||
sync_config.user
|
app_service = yield self.store.get_app_service_by_user_id(
|
||||||
|
sync_config.user.to_string()
|
||||||
)
|
)
|
||||||
|
if app_service:
|
||||||
|
rooms = yield self.store.get_app_service_rooms(app_service)
|
||||||
|
room_ids = set(r.room_id for r in rooms)
|
||||||
|
else:
|
||||||
|
room_ids = yield rm_handler.get_joined_rooms_for_user(
|
||||||
|
sync_config.user
|
||||||
|
)
|
||||||
|
|
||||||
result = yield self.notifier.wait_for_events(
|
result = yield self.notifier.wait_for_events(
|
||||||
sync_config.user, room_ids,
|
sync_config.user, room_ids,
|
||||||
sync_config.filter, timeout, current_sync_callback
|
sync_config.filter, timeout, current_sync_callback
|
||||||
@ -229,7 +238,16 @@ class SyncHandler(BaseHandler):
|
|||||||
logger.debug("Typing %r", typing_by_room)
|
logger.debug("Typing %r", typing_by_room)
|
||||||
|
|
||||||
rm_handler = self.hs.get_handlers().room_member_handler
|
rm_handler = self.hs.get_handlers().room_member_handler
|
||||||
room_ids = yield rm_handler.get_joined_rooms_for_user(sync_config.user)
|
app_service = yield self.store.get_app_service_by_user_id(
|
||||||
|
sync_config.user.to_string()
|
||||||
|
)
|
||||||
|
if app_service:
|
||||||
|
rooms = yield self.store.get_app_service_rooms(app_service)
|
||||||
|
room_ids = set(r.room_id for r in rooms)
|
||||||
|
else:
|
||||||
|
room_ids = yield rm_handler.get_joined_rooms_for_user(
|
||||||
|
sync_config.user
|
||||||
|
)
|
||||||
|
|
||||||
# TODO (mjark): Does public mean "published"?
|
# TODO (mjark): Does public mean "published"?
|
||||||
published_rooms = yield self.store.get_rooms(is_public=True)
|
published_rooms = yield self.store.get_rooms(is_public=True)
|
||||||
@ -292,6 +310,52 @@ class SyncHandler(BaseHandler):
|
|||||||
next_batch=now_token,
|
next_batch=now_token,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _filter_events_for_client(self, user_id, room_id, events):
|
||||||
|
event_id_to_state = yield self.store.get_state_for_events(
|
||||||
|
room_id, frozenset(e.event_id for e in events),
|
||||||
|
types=(
|
||||||
|
(EventTypes.RoomHistoryVisibility, ""),
|
||||||
|
(EventTypes.Member, user_id),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def allowed(event, state):
|
||||||
|
if event.type == EventTypes.RoomHistoryVisibility:
|
||||||
|
return True
|
||||||
|
|
||||||
|
membership_ev = state.get((EventTypes.Member, user_id), None)
|
||||||
|
if membership_ev:
|
||||||
|
membership = membership_ev.membership
|
||||||
|
else:
|
||||||
|
membership = Membership.LEAVE
|
||||||
|
|
||||||
|
if membership == Membership.JOIN:
|
||||||
|
return True
|
||||||
|
|
||||||
|
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||||
|
if history:
|
||||||
|
visibility = history.content.get("history_visibility", "shared")
|
||||||
|
else:
|
||||||
|
visibility = "shared"
|
||||||
|
|
||||||
|
if visibility == "public":
|
||||||
|
return True
|
||||||
|
elif visibility == "shared":
|
||||||
|
return True
|
||||||
|
elif visibility == "joined":
|
||||||
|
return membership == Membership.JOIN
|
||||||
|
elif visibility == "invited":
|
||||||
|
return membership == Membership.INVITE
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
defer.returnValue([
|
||||||
|
event
|
||||||
|
for event in events
|
||||||
|
if allowed(event, event_id_to_state[event.event_id])
|
||||||
|
])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def load_filtered_recents(self, room_id, sync_config, now_token,
|
def load_filtered_recents(self, room_id, sync_config, now_token,
|
||||||
since_token=None):
|
since_token=None):
|
||||||
@ -313,6 +377,9 @@ class SyncHandler(BaseHandler):
|
|||||||
(room_key, _) = keys
|
(room_key, _) = keys
|
||||||
end_key = "s" + room_key.split('-')[-1]
|
end_key = "s" + room_key.split('-')[-1]
|
||||||
loaded_recents = sync_config.filter.filter_room_events(events)
|
loaded_recents = sync_config.filter.filter_room_events(events)
|
||||||
|
loaded_recents = yield self._filter_events_for_client(
|
||||||
|
sync_config.user.to_string(), room_id, loaded_recents,
|
||||||
|
)
|
||||||
loaded_recents.extend(recents)
|
loaded_recents.extend(recents)
|
||||||
recents = loaded_recents
|
recents = loaded_recents
|
||||||
if len(events) <= load_limit:
|
if len(events) <= load_limit:
|
||||||
|
@ -204,21 +204,17 @@ class TypingNotificationHandler(BaseHandler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _push_update_local(self, room_id, user, typing):
|
def _push_update_local(self, room_id, user, typing):
|
||||||
if room_id not in self._room_serials:
|
room_set = self._room_typing.setdefault(room_id, set())
|
||||||
self._room_serials[room_id] = 0
|
|
||||||
self._room_typing[room_id] = set()
|
|
||||||
|
|
||||||
room_set = self._room_typing[room_id]
|
|
||||||
if typing:
|
if typing:
|
||||||
room_set.add(user)
|
room_set.add(user)
|
||||||
elif user in room_set:
|
else:
|
||||||
room_set.remove(user)
|
room_set.discard(user)
|
||||||
|
|
||||||
self._latest_room_serial += 1
|
self._latest_room_serial += 1
|
||||||
self._room_serials[room_id] = self._latest_room_serial
|
self._room_serials[room_id] = self._latest_room_serial
|
||||||
|
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
self.notifier.on_new_user_event(
|
self.notifier.on_new_event(
|
||||||
"typing_key", self._latest_room_serial, rooms=[room_id]
|
"typing_key", self._latest_room_serial, rooms=[room_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -260,8 +256,8 @@ class TypingNotificationEventSource(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
events = []
|
events = []
|
||||||
for room_id in handler._room_serials:
|
for room_id in joined_room_ids:
|
||||||
if room_id not in joined_room_ids:
|
if room_id not in handler._room_serials:
|
||||||
continue
|
continue
|
||||||
if handler._room_serials[room_id] <= from_key:
|
if handler._room_serials[room_id] <= from_key:
|
||||||
continue
|
continue
|
||||||
|
@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
from twisted.internet import defer, reactor, protocol
|
from twisted.internet import defer, reactor, protocol
|
||||||
from twisted.internet.error import DNSLookupError
|
from twisted.internet.error import DNSLookupError
|
||||||
from twisted.web.client import readBody, _AgentBase, _URI, HTTPConnectionPool
|
from twisted.web.client import readBody, HTTPConnectionPool, Agent
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
from twisted.web._newclient import ResponseDone
|
from twisted.web._newclient import ResponseDone
|
||||||
|
|
||||||
@ -55,41 +55,17 @@ incoming_responses_counter = metrics.register_counter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MatrixFederationHttpAgent(_AgentBase):
|
class MatrixFederationEndpointFactory(object):
|
||||||
|
def __init__(self, hs):
|
||||||
|
self.tls_context_factory = hs.tls_context_factory
|
||||||
|
|
||||||
def __init__(self, reactor, pool=None):
|
def endpointForURI(self, uri):
|
||||||
_AgentBase.__init__(self, reactor, pool)
|
destination = uri.netloc
|
||||||
|
|
||||||
def request(self, destination, endpoint, method, path, params, query,
|
return matrix_federation_endpoint(
|
||||||
headers, body_producer):
|
reactor, destination, timeout=10,
|
||||||
|
ssl_context_factory=self.tls_context_factory
|
||||||
outgoing_requests_counter.inc(method)
|
)
|
||||||
|
|
||||||
host = b""
|
|
||||||
port = 0
|
|
||||||
fragment = b""
|
|
||||||
|
|
||||||
parsed_URI = _URI(b"http", destination, host, port, path, params,
|
|
||||||
query, fragment)
|
|
||||||
|
|
||||||
# Set the connection pool key to be the destination.
|
|
||||||
key = destination
|
|
||||||
|
|
||||||
d = self._requestWithEndpoint(key, endpoint, method, parsed_URI,
|
|
||||||
headers, body_producer,
|
|
||||||
parsed_URI.originForm)
|
|
||||||
|
|
||||||
def _cb(response):
|
|
||||||
incoming_responses_counter.inc(method, response.code)
|
|
||||||
return response
|
|
||||||
|
|
||||||
def _eb(failure):
|
|
||||||
incoming_responses_counter.inc(method, "ERR")
|
|
||||||
return failure
|
|
||||||
|
|
||||||
d.addCallbacks(_cb, _eb)
|
|
||||||
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
class MatrixFederationHttpClient(object):
|
class MatrixFederationHttpClient(object):
|
||||||
@ -107,12 +83,18 @@ class MatrixFederationHttpClient(object):
|
|||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
pool = HTTPConnectionPool(reactor)
|
pool = HTTPConnectionPool(reactor)
|
||||||
pool.maxPersistentPerHost = 10
|
pool.maxPersistentPerHost = 10
|
||||||
self.agent = MatrixFederationHttpAgent(reactor, pool=pool)
|
self.agent = Agent.usingEndpointFactory(
|
||||||
|
reactor, MatrixFederationEndpointFactory(hs), pool=pool
|
||||||
|
)
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.version_string = hs.version_string
|
self.version_string = hs.version_string
|
||||||
|
|
||||||
self._next_id = 1
|
self._next_id = 1
|
||||||
|
|
||||||
|
def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
|
||||||
|
return urlparse.urlunparse(
|
||||||
|
("matrix", destination, path_bytes, param_bytes, query_bytes, "")
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _create_request(self, destination, method, path_bytes,
|
def _create_request(self, destination, method, path_bytes,
|
||||||
body_callback, headers_dict={}, param_bytes=b"",
|
body_callback, headers_dict={}, param_bytes=b"",
|
||||||
@ -123,8 +105,8 @@ class MatrixFederationHttpClient(object):
|
|||||||
headers_dict[b"User-Agent"] = [self.version_string]
|
headers_dict[b"User-Agent"] = [self.version_string]
|
||||||
headers_dict[b"Host"] = [destination]
|
headers_dict[b"Host"] = [destination]
|
||||||
|
|
||||||
url_bytes = urlparse.urlunparse(
|
url_bytes = self._create_url(
|
||||||
("", "", path_bytes, param_bytes, query_bytes, "",)
|
destination, path_bytes, param_bytes, query_bytes
|
||||||
)
|
)
|
||||||
|
|
||||||
txn_id = "%s-O-%s" % (method, self._next_id)
|
txn_id = "%s-O-%s" % (method, self._next_id)
|
||||||
@ -139,8 +121,8 @@ class MatrixFederationHttpClient(object):
|
|||||||
# (once we have reliable transactions in place)
|
# (once we have reliable transactions in place)
|
||||||
retries_left = 5
|
retries_left = 5
|
||||||
|
|
||||||
endpoint = preserve_context_over_fn(
|
http_url_bytes = urlparse.urlunparse(
|
||||||
self._getEndpoint, reactor, destination
|
("", "", path_bytes, param_bytes, query_bytes, "")
|
||||||
)
|
)
|
||||||
|
|
||||||
log_result = None
|
log_result = None
|
||||||
@ -148,17 +130,14 @@ class MatrixFederationHttpClient(object):
|
|||||||
while True:
|
while True:
|
||||||
producer = None
|
producer = None
|
||||||
if body_callback:
|
if body_callback:
|
||||||
producer = body_callback(method, url_bytes, headers_dict)
|
producer = body_callback(method, http_url_bytes, headers_dict)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
def send_request():
|
def send_request():
|
||||||
request_deferred = self.agent.request(
|
request_deferred = preserve_context_over_fn(
|
||||||
destination,
|
self.agent.request,
|
||||||
endpoint,
|
|
||||||
method,
|
method,
|
||||||
path_bytes,
|
url_bytes,
|
||||||
param_bytes,
|
|
||||||
query_bytes,
|
|
||||||
Headers(headers_dict),
|
Headers(headers_dict),
|
||||||
producer
|
producer
|
||||||
)
|
)
|
||||||
@ -452,12 +431,6 @@ class MatrixFederationHttpClient(object):
|
|||||||
|
|
||||||
defer.returnValue((length, headers))
|
defer.returnValue((length, headers))
|
||||||
|
|
||||||
def _getEndpoint(self, reactor, destination):
|
|
||||||
return matrix_federation_endpoint(
|
|
||||||
reactor, destination, timeout=10,
|
|
||||||
ssl_context_factory=self.hs.tls_context_factory
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _ReadBodyToFileProtocol(protocol.Protocol):
|
class _ReadBodyToFileProtocol(protocol.Protocol):
|
||||||
def __init__(self, stream, deferred, max_size):
|
def __init__(self, stream, deferred, max_size):
|
||||||
|
@ -207,7 +207,7 @@ class JsonResource(HttpServer, resource.Resource):
|
|||||||
incoming_requests_counter.inc(request.method, servlet_classname)
|
incoming_requests_counter.inc(request.method, servlet_classname)
|
||||||
|
|
||||||
args = [
|
args = [
|
||||||
urllib.unquote(u).decode("UTF-8") for u in m.groups()
|
urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
|
||||||
]
|
]
|
||||||
|
|
||||||
callback_return = yield callback(request, *args)
|
callback_return = yield callback(request, *args)
|
||||||
|
@ -18,8 +18,12 @@ from __future__ import absolute_import
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from resource import getrusage, getpagesize, RUSAGE_SELF
|
from resource import getrusage, getpagesize, RUSAGE_SELF
|
||||||
|
import functools
|
||||||
import os
|
import os
|
||||||
import stat
|
import stat
|
||||||
|
import time
|
||||||
|
|
||||||
|
from twisted.internet import reactor
|
||||||
|
|
||||||
from .metric import (
|
from .metric import (
|
||||||
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
|
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
|
||||||
@ -144,3 +148,50 @@ def _process_fds():
|
|||||||
return counts
|
return counts
|
||||||
|
|
||||||
get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
|
get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
|
||||||
|
|
||||||
|
reactor_metrics = get_metrics_for("reactor")
|
||||||
|
tick_time = reactor_metrics.register_distribution("tick_time")
|
||||||
|
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
|
||||||
|
|
||||||
|
|
||||||
|
def runUntilCurrentTimer(func):
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
def f(*args, **kwargs):
|
||||||
|
now = reactor.seconds()
|
||||||
|
num_pending = 0
|
||||||
|
|
||||||
|
# _newTimedCalls is one long list of *all* pending calls. Below loop
|
||||||
|
# is based off of impl of reactor.runUntilCurrent
|
||||||
|
for delayed_call in reactor._newTimedCalls:
|
||||||
|
if delayed_call.time > now:
|
||||||
|
break
|
||||||
|
|
||||||
|
if delayed_call.delayed_time > 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
num_pending += 1
|
||||||
|
|
||||||
|
num_pending += len(reactor.threadCallQueue)
|
||||||
|
|
||||||
|
start = time.time() * 1000
|
||||||
|
ret = func(*args, **kwargs)
|
||||||
|
end = time.time() * 1000
|
||||||
|
tick_time.inc_by(end - start)
|
||||||
|
pending_calls_metric.inc_by(num_pending)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Ensure the reactor has all the attributes we expect
|
||||||
|
reactor.runUntilCurrent
|
||||||
|
reactor._newTimedCalls
|
||||||
|
reactor.threadCallQueue
|
||||||
|
|
||||||
|
# runUntilCurrent is called when we have pending calls. It is called once
|
||||||
|
# per iteratation after fd polling.
|
||||||
|
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
@ -221,16 +221,7 @@ class Notifier(object):
|
|||||||
event
|
event
|
||||||
)
|
)
|
||||||
|
|
||||||
room_id = event.room_id
|
app_streams = set()
|
||||||
|
|
||||||
room_user_streams = self.room_to_user_streams.get(room_id, set())
|
|
||||||
|
|
||||||
user_streams = room_user_streams.copy()
|
|
||||||
|
|
||||||
for user in extra_users:
|
|
||||||
user_stream = self.user_to_user_stream.get(str(user))
|
|
||||||
if user_stream is not None:
|
|
||||||
user_streams.add(user_stream)
|
|
||||||
|
|
||||||
for appservice in self.appservice_to_user_streams:
|
for appservice in self.appservice_to_user_streams:
|
||||||
# TODO (kegan): Redundant appservice listener checks?
|
# TODO (kegan): Redundant appservice listener checks?
|
||||||
@ -242,24 +233,20 @@ class Notifier(object):
|
|||||||
app_user_streams = self.appservice_to_user_streams.get(
|
app_user_streams = self.appservice_to_user_streams.get(
|
||||||
appservice, set()
|
appservice, set()
|
||||||
)
|
)
|
||||||
user_streams |= app_user_streams
|
app_streams |= app_user_streams
|
||||||
|
|
||||||
logger.debug("on_new_room_event listeners %s", user_streams)
|
self.on_new_event(
|
||||||
|
"room_key", room_stream_id,
|
||||||
time_now_ms = self.clock.time_msec()
|
users=extra_users,
|
||||||
for user_stream in user_streams:
|
rooms=[event.room_id],
|
||||||
try:
|
extra_streams=app_streams,
|
||||||
user_stream.notify(
|
)
|
||||||
"room_key", "s%d" % (room_stream_id,), time_now_ms
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
logger.exception("Failed to notify listener")
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]):
|
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
|
||||||
""" Used to inform listeners that something has happend
|
extra_streams=set()):
|
||||||
presence/user event wise.
|
""" Used to inform listeners that something has happend event wise.
|
||||||
|
|
||||||
Will wake up all listeners for the given users and rooms.
|
Will wake up all listeners for the given users and rooms.
|
||||||
"""
|
"""
|
||||||
@ -283,7 +270,7 @@ class Notifier(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wait_for_events(self, user, rooms, timeout, callback,
|
def wait_for_events(self, user, rooms, timeout, callback,
|
||||||
from_token=StreamToken("s0", "0", "0")):
|
from_token=StreamToken("s0", "0", "0", "0")):
|
||||||
"""Wait until the callback returns a non empty response or the
|
"""Wait until the callback returns a non empty response or the
|
||||||
timeout fires.
|
timeout fires.
|
||||||
"""
|
"""
|
||||||
@ -341,10 +328,13 @@ class Notifier(object):
|
|||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_events_for(self, user, rooms, pagination_config, timeout):
|
def get_events_for(self, user, rooms, pagination_config, timeout,
|
||||||
|
only_room_events=False):
|
||||||
""" For the given user and rooms, return any new events for them. If
|
""" For the given user and rooms, return any new events for them. If
|
||||||
there are no new events wait for up to `timeout` milliseconds for any
|
there are no new events wait for up to `timeout` milliseconds for any
|
||||||
new events to happen before returning.
|
new events to happen before returning.
|
||||||
|
|
||||||
|
If `only_room_events` is `True` only room events will be returned.
|
||||||
"""
|
"""
|
||||||
from_token = pagination_config.from_token
|
from_token = pagination_config.from_token
|
||||||
if not from_token:
|
if not from_token:
|
||||||
@ -365,10 +355,12 @@ class Notifier(object):
|
|||||||
after_id = getattr(after_token, keyname)
|
after_id = getattr(after_token, keyname)
|
||||||
if before_id == after_id:
|
if before_id == after_id:
|
||||||
continue
|
continue
|
||||||
stuff, new_key = yield source.get_new_events_for_user(
|
if only_room_events and name != "room":
|
||||||
|
continue
|
||||||
|
new_events, new_key = yield source.get_new_events_for_user(
|
||||||
user, getattr(from_token, keyname), limit,
|
user, getattr(from_token, keyname), limit,
|
||||||
)
|
)
|
||||||
events.extend(stuff)
|
events.extend(new_events)
|
||||||
end_token = end_token.copy_and_replace(keyname, new_key)
|
end_token = end_token.copy_and_replace(keyname, new_key)
|
||||||
|
|
||||||
if events:
|
if events:
|
||||||
|
@ -249,7 +249,9 @@ class Pusher(object):
|
|||||||
# we fail to dispatch the push)
|
# we fail to dispatch the push)
|
||||||
config = PaginationConfig(from_token=None, limit='1')
|
config = PaginationConfig(from_token=None, limit='1')
|
||||||
chunk = yield self.evStreamHandler.get_stream(
|
chunk = yield self.evStreamHandler.get_stream(
|
||||||
self.user_name, config, timeout=0)
|
self.user_name, config, timeout=0, affect_presence=False,
|
||||||
|
only_room_events=True
|
||||||
|
)
|
||||||
self.last_token = chunk['end']
|
self.last_token = chunk['end']
|
||||||
self.store.update_pusher_last_token(
|
self.store.update_pusher_last_token(
|
||||||
self.app_id, self.pushkey, self.user_name, self.last_token
|
self.app_id, self.pushkey, self.user_name, self.last_token
|
||||||
@ -280,8 +282,8 @@ class Pusher(object):
|
|||||||
config = PaginationConfig(from_token=from_tok, limit='1')
|
config = PaginationConfig(from_token=from_tok, limit='1')
|
||||||
timeout = (300 + random.randint(-60, 60)) * 1000
|
timeout = (300 + random.randint(-60, 60)) * 1000
|
||||||
chunk = yield self.evStreamHandler.get_stream(
|
chunk = yield self.evStreamHandler.get_stream(
|
||||||
self.user_name, config,
|
self.user_name, config, timeout=timeout, affect_presence=False,
|
||||||
timeout=timeout, affect_presence=False
|
only_room_events=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# limiting to 1 may get 1 event plus 1 presence event, so
|
# limiting to 1 may get 1 event plus 1 presence event, so
|
||||||
@ -294,6 +296,12 @@ class Pusher(object):
|
|||||||
if not single_event:
|
if not single_event:
|
||||||
self.last_token = chunk['end']
|
self.last_token = chunk['end']
|
||||||
logger.debug("Event stream timeout for pushkey %s", self.pushkey)
|
logger.debug("Event stream timeout for pushkey %s", self.pushkey)
|
||||||
|
yield self.store.update_pusher_last_token(
|
||||||
|
self.app_id,
|
||||||
|
self.pushkey,
|
||||||
|
self.user_name,
|
||||||
|
self.last_token
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.alive:
|
if not self.alive:
|
||||||
@ -345,7 +353,7 @@ class Pusher(object):
|
|||||||
if processed:
|
if processed:
|
||||||
self.backoff_delay = Pusher.INITIAL_BACKOFF
|
self.backoff_delay = Pusher.INITIAL_BACKOFF
|
||||||
self.last_token = chunk['end']
|
self.last_token = chunk['end']
|
||||||
self.store.update_pusher_last_token_and_success(
|
yield self.store.update_pusher_last_token_and_success(
|
||||||
self.app_id,
|
self.app_id,
|
||||||
self.pushkey,
|
self.pushkey,
|
||||||
self.user_name,
|
self.user_name,
|
||||||
@ -354,7 +362,7 @@ class Pusher(object):
|
|||||||
)
|
)
|
||||||
if self.failing_since:
|
if self.failing_since:
|
||||||
self.failing_since = None
|
self.failing_since = None
|
||||||
self.store.update_pusher_failing_since(
|
yield self.store.update_pusher_failing_since(
|
||||||
self.app_id,
|
self.app_id,
|
||||||
self.pushkey,
|
self.pushkey,
|
||||||
self.user_name,
|
self.user_name,
|
||||||
@ -362,7 +370,7 @@ class Pusher(object):
|
|||||||
else:
|
else:
|
||||||
if not self.failing_since:
|
if not self.failing_since:
|
||||||
self.failing_since = self.clock.time_msec()
|
self.failing_since = self.clock.time_msec()
|
||||||
self.store.update_pusher_failing_since(
|
yield self.store.update_pusher_failing_since(
|
||||||
self.app_id,
|
self.app_id,
|
||||||
self.pushkey,
|
self.pushkey,
|
||||||
self.user_name,
|
self.user_name,
|
||||||
@ -380,7 +388,7 @@ class Pusher(object):
|
|||||||
self.user_name, self.pushkey)
|
self.user_name, self.pushkey)
|
||||||
self.backoff_delay = Pusher.INITIAL_BACKOFF
|
self.backoff_delay = Pusher.INITIAL_BACKOFF
|
||||||
self.last_token = chunk['end']
|
self.last_token = chunk['end']
|
||||||
self.store.update_pusher_last_token(
|
yield self.store.update_pusher_last_token(
|
||||||
self.app_id,
|
self.app_id,
|
||||||
self.pushkey,
|
self.pushkey,
|
||||||
self.user_name,
|
self.user_name,
|
||||||
@ -388,7 +396,7 @@ class Pusher(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.failing_since = None
|
self.failing_since = None
|
||||||
self.store.update_pusher_failing_since(
|
yield self.store.update_pusher_failing_since(
|
||||||
self.app_id,
|
self.app_id,
|
||||||
self.pushkey,
|
self.pushkey,
|
||||||
self.user_name,
|
self.user_name,
|
||||||
|
@ -164,7 +164,7 @@ def make_base_append_underride_rules(user):
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'rule_id': 'global/override/.m.rule.contains_display_name',
|
'rule_id': 'global/underride/.m.rule.contains_display_name',
|
||||||
'conditions': [
|
'conditions': [
|
||||||
{
|
{
|
||||||
'kind': 'contains_display_name'
|
'kind': 'contains_display_name'
|
||||||
|
@ -94,17 +94,14 @@ class PusherPool:
|
|||||||
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def remove_pushers_by_user_access_token(self, user_id, not_access_token_id):
|
def remove_pushers_by_user(self, user_id):
|
||||||
all = yield self.store.get_all_pushers()
|
all = yield self.store.get_all_pushers()
|
||||||
logger.info(
|
logger.info(
|
||||||
"Removing all pushers for user %s except access token %s",
|
"Removing all pushers for user %s",
|
||||||
user_id, not_access_token_id
|
user_id,
|
||||||
)
|
)
|
||||||
for p in all:
|
for p in all:
|
||||||
if (
|
if p['user_name'] == user_id:
|
||||||
p['user_name'] == user_id and
|
|
||||||
p['access_token'] != not_access_token_id
|
|
||||||
):
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||||
p['app_id'], p['pushkey'], p['user_name']
|
p['app_id'], p['pushkey'], p['user_name']
|
||||||
|
@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
REQUIREMENTS = {
|
REQUIREMENTS = {
|
||||||
"syutil>=0.0.7": ["syutil>=0.0.7"],
|
"syutil>=0.0.7": ["syutil>=0.0.7"],
|
||||||
"Twisted==14.0.2": ["twisted==14.0.2"],
|
"Twisted>=15.1.0": ["twisted>=15.1.0"],
|
||||||
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
||||||
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
||||||
"pyyaml": ["yaml"],
|
"pyyaml": ["yaml"],
|
||||||
@ -31,6 +31,8 @@ REQUIREMENTS = {
|
|||||||
"pillow": ["PIL"],
|
"pillow": ["PIL"],
|
||||||
"pydenticon": ["pydenticon"],
|
"pydenticon": ["pydenticon"],
|
||||||
"ujson": ["ujson"],
|
"ujson": ["ujson"],
|
||||||
|
"blist": ["blist"],
|
||||||
|
"pysaml2": ["saml2"],
|
||||||
}
|
}
|
||||||
CONDITIONAL_REQUIREMENTS = {
|
CONDITIONAL_REQUIREMENTS = {
|
||||||
"web_client": {
|
"web_client": {
|
||||||
@ -41,8 +43,8 @@ CONDITIONAL_REQUIREMENTS = {
|
|||||||
|
|
||||||
def requirements(config=None, include_conditional=False):
|
def requirements(config=None, include_conditional=False):
|
||||||
reqs = REQUIREMENTS.copy()
|
reqs = REQUIREMENTS.copy()
|
||||||
for key, req in CONDITIONAL_REQUIREMENTS.items():
|
if include_conditional:
|
||||||
if (config and getattr(config, key)) or include_conditional:
|
for _, req in CONDITIONAL_REQUIREMENTS.items():
|
||||||
reqs.update(req)
|
reqs.update(req)
|
||||||
return reqs
|
return reqs
|
||||||
|
|
||||||
@ -50,18 +52,18 @@ def requirements(config=None, include_conditional=False):
|
|||||||
def github_link(project, version, egg):
|
def github_link(project, version, egg):
|
||||||
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
|
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
|
||||||
|
|
||||||
DEPENDENCY_LINKS = [
|
DEPENDENCY_LINKS = {
|
||||||
github_link(
|
"syutil": github_link(
|
||||||
project="matrix-org/syutil",
|
project="matrix-org/syutil",
|
||||||
version="v0.0.7",
|
version="v0.0.7",
|
||||||
egg="syutil-0.0.7",
|
egg="syutil-0.0.7",
|
||||||
),
|
),
|
||||||
github_link(
|
"matrix-angular-sdk": github_link(
|
||||||
project="matrix-org/matrix-angular-sdk",
|
project="matrix-org/matrix-angular-sdk",
|
||||||
version="v0.6.6",
|
version="v0.6.6",
|
||||||
egg="matrix_angular_sdk-0.6.6",
|
egg="matrix_angular_sdk-0.6.6",
|
||||||
),
|
),
|
||||||
]
|
}
|
||||||
|
|
||||||
|
|
||||||
class MissingRequirementError(Exception):
|
class MissingRequirementError(Exception):
|
||||||
@ -129,7 +131,7 @@ def check_requirements(config=None):
|
|||||||
def list_requirements():
|
def list_requirements():
|
||||||
result = []
|
result = []
|
||||||
linked = []
|
linked = []
|
||||||
for link in DEPENDENCY_LINKS:
|
for link in DEPENDENCY_LINKS.values():
|
||||||
egg = link.split("#egg=")[1]
|
egg = link.split("#egg=")[1]
|
||||||
linked.append(egg.split('-')[0])
|
linked.append(egg.split('-')[0])
|
||||||
result.append(link)
|
result.append(link)
|
||||||
|
@ -20,14 +20,32 @@ from synapse.types import UserID
|
|||||||
from base import ClientV1RestServlet, client_path_pattern
|
from base import ClientV1RestServlet, client_path_pattern
|
||||||
|
|
||||||
import simplejson as json
|
import simplejson as json
|
||||||
|
import urllib
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from saml2 import BINDING_HTTP_POST
|
||||||
|
from saml2 import config
|
||||||
|
from saml2.client import Saml2Client
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LoginRestServlet(ClientV1RestServlet):
|
class LoginRestServlet(ClientV1RestServlet):
|
||||||
PATTERN = client_path_pattern("/login$")
|
PATTERN = client_path_pattern("/login$")
|
||||||
PASS_TYPE = "m.login.password"
|
PASS_TYPE = "m.login.password"
|
||||||
|
SAML2_TYPE = "m.login.saml2"
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(LoginRestServlet, self).__init__(hs)
|
||||||
|
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
|
||||||
|
self.saml2_enabled = hs.config.saml2_enabled
|
||||||
|
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]})
|
flows = [{"type": LoginRestServlet.PASS_TYPE}]
|
||||||
|
if self.saml2_enabled:
|
||||||
|
flows.append({"type": LoginRestServlet.SAML2_TYPE})
|
||||||
|
return (200, {"flows": flows})
|
||||||
|
|
||||||
def on_OPTIONS(self, request):
|
def on_OPTIONS(self, request):
|
||||||
return (200, {})
|
return (200, {})
|
||||||
@ -39,6 +57,16 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
|
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
|
||||||
result = yield self.do_password_login(login_submission)
|
result = yield self.do_password_login(login_submission)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
elif self.saml2_enabled and (login_submission["type"] ==
|
||||||
|
LoginRestServlet.SAML2_TYPE):
|
||||||
|
relay_state = ""
|
||||||
|
if "relay_state" in login_submission:
|
||||||
|
relay_state = "&RelayState="+urllib.quote(
|
||||||
|
login_submission["relay_state"])
|
||||||
|
result = {
|
||||||
|
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
|
||||||
|
}
|
||||||
|
defer.returnValue((200, result))
|
||||||
else:
|
else:
|
||||||
raise SynapseError(400, "Bad login type.")
|
raise SynapseError(400, "Bad login type.")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -46,17 +74,24 @@ class LoginRestServlet(ClientV1RestServlet):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_password_login(self, login_submission):
|
def do_password_login(self, login_submission):
|
||||||
if not login_submission["user"].startswith('@'):
|
if 'medium' in login_submission and 'address' in login_submission:
|
||||||
login_submission["user"] = UserID.create(
|
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
login_submission["user"], self.hs.hostname).to_string()
|
login_submission['medium'], login_submission['address']
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
user_id = login_submission['user']
|
||||||
|
|
||||||
handler = self.handlers.login_handler
|
if not user_id.startswith('@'):
|
||||||
token = yield handler.login(
|
user_id = UserID.create(
|
||||||
user=login_submission["user"],
|
user_id, self.hs.hostname
|
||||||
|
).to_string()
|
||||||
|
|
||||||
|
user_id, token = yield self.handlers.auth_handler.login_with_password(
|
||||||
|
user_id=user_id,
|
||||||
password=login_submission["password"])
|
password=login_submission["password"])
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"user_id": login_submission["user"], # may have changed
|
"user_id": user_id, # may have changed
|
||||||
"access_token": token,
|
"access_token": token,
|
||||||
"home_server": self.hs.hostname,
|
"home_server": self.hs.hostname,
|
||||||
}
|
}
|
||||||
@ -94,6 +129,49 @@ class PasswordResetRestServlet(ClientV1RestServlet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SAML2RestServlet(ClientV1RestServlet):
|
||||||
|
PATTERN = client_path_pattern("/login/saml2")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(SAML2RestServlet, self).__init__(hs)
|
||||||
|
self.sp_config = hs.config.saml2_config_path
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
saml2_auth = None
|
||||||
|
try:
|
||||||
|
conf = config.SPConfig()
|
||||||
|
conf.load_file(self.sp_config)
|
||||||
|
SP = Saml2Client(conf)
|
||||||
|
saml2_auth = SP.parse_authn_request_response(
|
||||||
|
request.args['SAMLResponse'][0], BINDING_HTTP_POST)
|
||||||
|
except Exception, e: # Not authenticated
|
||||||
|
logger.exception(e)
|
||||||
|
if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
|
||||||
|
username = saml2_auth.name_id.text
|
||||||
|
handler = self.handlers.registration_handler
|
||||||
|
(user_id, token) = yield handler.register_saml2(username)
|
||||||
|
# Forward to the RelayState callback along with ava
|
||||||
|
if 'RelayState' in request.args:
|
||||||
|
request.redirect(urllib.unquote(
|
||||||
|
request.args['RelayState'][0]) +
|
||||||
|
'?status=authenticated&access_token=' +
|
||||||
|
token + '&user_id=' + user_id + '&ava=' +
|
||||||
|
urllib.quote(json.dumps(saml2_auth.ava)))
|
||||||
|
request.finish()
|
||||||
|
defer.returnValue(None)
|
||||||
|
defer.returnValue((200, {"status": "authenticated",
|
||||||
|
"user_id": user_id, "token": token,
|
||||||
|
"ava": saml2_auth.ava}))
|
||||||
|
elif 'RelayState' in request.args:
|
||||||
|
request.redirect(urllib.unquote(
|
||||||
|
request.args['RelayState'][0]) +
|
||||||
|
'?status=not_authenticated')
|
||||||
|
request.finish()
|
||||||
|
defer.returnValue(None)
|
||||||
|
defer.returnValue((200, {"status": "not_authenticated"}))
|
||||||
|
|
||||||
|
|
||||||
def _parse_json(request):
|
def _parse_json(request):
|
||||||
try:
|
try:
|
||||||
content = json.loads(request.content.read())
|
content = json.loads(request.content.read())
|
||||||
@ -106,4 +184,6 @@ def _parse_json(request):
|
|||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
LoginRestServlet(hs).register(http_server)
|
LoginRestServlet(hs).register(http_server)
|
||||||
|
if hs.config.saml2_enabled:
|
||||||
|
SAML2RestServlet(hs).register(http_server)
|
||||||
# TODO PasswordResetRestServlet(hs).register(http_server)
|
# TODO PasswordResetRestServlet(hs).register(http_server)
|
||||||
|
@ -412,6 +412,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||||||
if "user_id" not in content:
|
if "user_id" not in content:
|
||||||
raise SynapseError(400, "Missing user_id key.")
|
raise SynapseError(400, "Missing user_id key.")
|
||||||
state_key = content["user_id"]
|
state_key = content["user_id"]
|
||||||
|
# make sure it looks like a user ID; it'll throw if it's invalid.
|
||||||
|
UserID.from_string(state_key)
|
||||||
|
|
||||||
if membership_action == "kick":
|
if membership_action == "kick":
|
||||||
membership_action = "leave"
|
membership_action = "leave"
|
||||||
|
@ -18,7 +18,9 @@ from . import (
|
|||||||
filter,
|
filter,
|
||||||
account,
|
account,
|
||||||
register,
|
register,
|
||||||
auth
|
auth,
|
||||||
|
receipts,
|
||||||
|
keys,
|
||||||
)
|
)
|
||||||
|
|
||||||
from synapse.http.server import JsonResource
|
from synapse.http.server import JsonResource
|
||||||
@ -38,3 +40,5 @@ class ClientV2AlphaRestResource(JsonResource):
|
|||||||
account.register_servlets(hs, client_resource)
|
account.register_servlets(hs, client_resource)
|
||||||
register.register_servlets(hs, client_resource)
|
register.register_servlets(hs, client_resource)
|
||||||
auth.register_servlets(hs, client_resource)
|
auth.register_servlets(hs, client_resource)
|
||||||
|
receipts.register_servlets(hs, client_resource)
|
||||||
|
keys.register_servlets(hs, client_resource)
|
||||||
|
@ -36,7 +36,6 @@ class PasswordRestServlet(RestServlet):
|
|||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_handlers().auth_handler
|
self.auth_handler = hs.get_handlers().auth_handler
|
||||||
self.login_handler = hs.get_handlers().login_handler
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
@ -47,7 +46,7 @@ class PasswordRestServlet(RestServlet):
|
|||||||
authed, result, params = yield self.auth_handler.check_auth([
|
authed, result, params = yield self.auth_handler.check_auth([
|
||||||
[LoginType.PASSWORD],
|
[LoginType.PASSWORD],
|
||||||
[LoginType.EMAIL_IDENTITY]
|
[LoginType.EMAIL_IDENTITY]
|
||||||
], body)
|
], body, self.hs.get_ip_from_request(request))
|
||||||
|
|
||||||
if not authed:
|
if not authed:
|
||||||
defer.returnValue((401, result))
|
defer.returnValue((401, result))
|
||||||
@ -79,8 +78,8 @@ class PasswordRestServlet(RestServlet):
|
|||||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||||
new_password = params['new_password']
|
new_password = params['new_password']
|
||||||
|
|
||||||
yield self.login_handler.set_password(
|
yield self.auth_handler.set_password(
|
||||||
user_id, new_password, None
|
user_id, new_password
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
@ -95,7 +94,6 @@ class ThreepidRestServlet(RestServlet):
|
|||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(ThreepidRestServlet, self).__init__()
|
super(ThreepidRestServlet, self).__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.login_handler = hs.get_handlers().login_handler
|
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
@ -135,7 +133,7 @@ class ThreepidRestServlet(RestServlet):
|
|||||||
logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
|
logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
|
||||||
raise SynapseError(500, "Invalid response from ID Server")
|
raise SynapseError(500, "Invalid response from ID Server")
|
||||||
|
|
||||||
yield self.login_handler.add_threepid(
|
yield self.auth_handler.add_threepid(
|
||||||
auth_user.to_string(),
|
auth_user.to_string(),
|
||||||
threepid['medium'],
|
threepid['medium'],
|
||||||
threepid['address'],
|
threepid['address'],
|
||||||
|
316
synapse/rest/client/v2_alpha/keys.py
Normal file
316
synapse/rest/client/v2_alpha/keys.py
Normal file
@ -0,0 +1,316 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import SynapseError
|
||||||
|
from synapse.http.servlet import RestServlet
|
||||||
|
from synapse.types import UserID
|
||||||
|
from syutil.jsonutil import encode_canonical_json
|
||||||
|
|
||||||
|
from ._base import client_v2_pattern
|
||||||
|
|
||||||
|
import simplejson as json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class KeyUploadServlet(RestServlet):
|
||||||
|
"""
|
||||||
|
POST /keys/upload/<device_id> HTTP/1.1
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"user_id": "<user_id>",
|
||||||
|
"device_id": "<device_id>",
|
||||||
|
"valid_until_ts": <millisecond_timestamp>,
|
||||||
|
"algorithms": [
|
||||||
|
"m.olm.curve25519-aes-sha256",
|
||||||
|
]
|
||||||
|
"keys": {
|
||||||
|
"<algorithm>:<device_id>": "<key_base64>",
|
||||||
|
},
|
||||||
|
"signatures:" {
|
||||||
|
"<user_id>" {
|
||||||
|
"<algorithm>:<device_id>": "<signature_base64>"
|
||||||
|
} } },
|
||||||
|
"one_time_keys": {
|
||||||
|
"<algorithm>:<key_id>": "<key_base64>"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)")
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(KeyUploadServlet, self).__init__()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request, device_id):
|
||||||
|
auth_user, client_info = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = auth_user.to_string()
|
||||||
|
# TODO: Check that the device_id matches that in the authentication
|
||||||
|
# or derive the device_id from the authentication instead.
|
||||||
|
try:
|
||||||
|
body = json.loads(request.content.read())
|
||||||
|
except:
|
||||||
|
raise SynapseError(400, "Invalid key JSON")
|
||||||
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
|
# TODO: Validate the JSON to make sure it has the right keys.
|
||||||
|
device_keys = body.get("device_keys", None)
|
||||||
|
if device_keys:
|
||||||
|
logger.info(
|
||||||
|
"Updating device_keys for device %r for user %r at %d",
|
||||||
|
device_id, auth_user, time_now
|
||||||
|
)
|
||||||
|
# TODO: Sign the JSON with the server key
|
||||||
|
yield self.store.set_e2e_device_keys(
|
||||||
|
user_id, device_id, time_now,
|
||||||
|
encode_canonical_json(device_keys)
|
||||||
|
)
|
||||||
|
|
||||||
|
one_time_keys = body.get("one_time_keys", None)
|
||||||
|
if one_time_keys:
|
||||||
|
logger.info(
|
||||||
|
"Adding %d one_time_keys for device %r for user %r at %d",
|
||||||
|
len(one_time_keys), device_id, user_id, time_now
|
||||||
|
)
|
||||||
|
key_list = []
|
||||||
|
for key_id, key_json in one_time_keys.items():
|
||||||
|
algorithm, key_id = key_id.split(":")
|
||||||
|
key_list.append((
|
||||||
|
algorithm, key_id, encode_canonical_json(key_json)
|
||||||
|
))
|
||||||
|
|
||||||
|
yield self.store.add_e2e_one_time_keys(
|
||||||
|
user_id, device_id, time_now, key_list
|
||||||
|
)
|
||||||
|
|
||||||
|
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||||
|
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, device_id):
|
||||||
|
auth_user, client_info = yield self.auth.get_user_by_req(request)
|
||||||
|
user_id = auth_user.to_string()
|
||||||
|
|
||||||
|
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||||
|
defer.returnValue((200, {"one_time_key_counts": result}))
|
||||||
|
|
||||||
|
|
||||||
|
class KeyQueryServlet(RestServlet):
|
||||||
|
"""
|
||||||
|
GET /keys/query/<user_id> HTTP/1.1
|
||||||
|
|
||||||
|
GET /keys/query/<user_id>/<device_id> HTTP/1.1
|
||||||
|
|
||||||
|
POST /keys/query HTTP/1.1
|
||||||
|
Content-Type: application/json
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": ["<device_id>"]
|
||||||
|
} }
|
||||||
|
|
||||||
|
HTTP/1.1 200 OK
|
||||||
|
{
|
||||||
|
"device_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": {
|
||||||
|
"user_id": "<user_id>", // Duplicated to be signed
|
||||||
|
"device_id": "<device_id>", // Duplicated to be signed
|
||||||
|
"valid_until_ts": <millisecond_timestamp>,
|
||||||
|
"algorithms": [ // List of supported algorithms
|
||||||
|
"m.olm.curve25519-aes-sha256",
|
||||||
|
],
|
||||||
|
"keys": { // Must include a ed25519 signing key
|
||||||
|
"<algorithm>:<key_id>": "<key_base64>",
|
||||||
|
},
|
||||||
|
"signatures:" {
|
||||||
|
// Must be signed with device's ed25519 key
|
||||||
|
"<user_id>/<device_id>": {
|
||||||
|
"<algorithm>:<key_id>": "<signature_base64>"
|
||||||
|
}
|
||||||
|
// Must be signed by this server.
|
||||||
|
"<server_name>": {
|
||||||
|
"<algorithm>:<key_id>": "<signature_base64>"
|
||||||
|
} } } } } }
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERN = client_v2_pattern(
|
||||||
|
"/keys/query(?:"
|
||||||
|
"/(?P<user_id>[^/]*)(?:"
|
||||||
|
"/(?P<device_id>[^/]*)"
|
||||||
|
")?"
|
||||||
|
")?"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(KeyQueryServlet, self).__init__()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.federation = hs.get_replication_layer()
|
||||||
|
self.is_mine = hs.is_mine
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request, user_id, device_id):
|
||||||
|
yield self.auth.get_user_by_req(request)
|
||||||
|
try:
|
||||||
|
body = json.loads(request.content.read())
|
||||||
|
except:
|
||||||
|
raise SynapseError(400, "Invalid key JSON")
|
||||||
|
result = yield self.handle_request(body)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, user_id, device_id):
|
||||||
|
auth_user, client_info = yield self.auth.get_user_by_req(request)
|
||||||
|
auth_user_id = auth_user.to_string()
|
||||||
|
user_id = user_id if user_id else auth_user_id
|
||||||
|
device_ids = [device_id] if device_id else []
|
||||||
|
result = yield self.handle_request(
|
||||||
|
{"device_keys": {user_id: device_ids}}
|
||||||
|
)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def handle_request(self, body):
|
||||||
|
local_query = []
|
||||||
|
remote_queries = {}
|
||||||
|
for user_id, device_ids in body.get("device_keys", {}).items():
|
||||||
|
user = UserID.from_string(user_id)
|
||||||
|
if self.is_mine(user):
|
||||||
|
if not device_ids:
|
||||||
|
local_query.append((user_id, None))
|
||||||
|
else:
|
||||||
|
for device_id in device_ids:
|
||||||
|
local_query.append((user_id, device_id))
|
||||||
|
else:
|
||||||
|
remote_queries.setdefault(user.domain, {})[user_id] = list(
|
||||||
|
device_ids
|
||||||
|
)
|
||||||
|
results = yield self.store.get_e2e_device_keys(local_query)
|
||||||
|
|
||||||
|
json_result = {}
|
||||||
|
for user_id, device_keys in results.items():
|
||||||
|
for device_id, json_bytes in device_keys.items():
|
||||||
|
json_result.setdefault(user_id, {})[device_id] = json.loads(
|
||||||
|
json_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
for destination, device_keys in remote_queries.items():
|
||||||
|
remote_result = yield self.federation.query_client_keys(
|
||||||
|
destination, {"device_keys": device_keys}
|
||||||
|
)
|
||||||
|
for user_id, keys in remote_result["device_keys"].items():
|
||||||
|
if user_id in device_keys:
|
||||||
|
json_result[user_id] = keys
|
||||||
|
defer.returnValue((200, {"device_keys": json_result}))
|
||||||
|
|
||||||
|
|
||||||
|
class OneTimeKeyServlet(RestServlet):
|
||||||
|
"""
|
||||||
|
GET /keys/claim/<user-id>/<device-id>/<algorithm> HTTP/1.1
|
||||||
|
|
||||||
|
POST /keys/claim HTTP/1.1
|
||||||
|
{
|
||||||
|
"one_time_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": "<algorithm>"
|
||||||
|
} } }
|
||||||
|
|
||||||
|
HTTP/1.1 200 OK
|
||||||
|
{
|
||||||
|
"one_time_keys": {
|
||||||
|
"<user_id>": {
|
||||||
|
"<device_id>": {
|
||||||
|
"<algorithm>:<key_id>": "<key_base64>"
|
||||||
|
} } } }
|
||||||
|
|
||||||
|
"""
|
||||||
|
PATTERN = client_v2_pattern(
|
||||||
|
"/keys/claim(?:/?|(?:/"
|
||||||
|
"(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
|
||||||
|
")?)"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(OneTimeKeyServlet, self).__init__()
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.clock = hs.get_clock()
|
||||||
|
self.federation = hs.get_replication_layer()
|
||||||
|
self.is_mine = hs.is_mine
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_GET(self, request, user_id, device_id, algorithm):
|
||||||
|
yield self.auth.get_user_by_req(request)
|
||||||
|
result = yield self.handle_request(
|
||||||
|
{"one_time_keys": {user_id: {device_id: algorithm}}}
|
||||||
|
)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request, user_id, device_id, algorithm):
|
||||||
|
yield self.auth.get_user_by_req(request)
|
||||||
|
try:
|
||||||
|
body = json.loads(request.content.read())
|
||||||
|
except:
|
||||||
|
raise SynapseError(400, "Invalid key JSON")
|
||||||
|
result = yield self.handle_request(body)
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def handle_request(self, body):
|
||||||
|
local_query = []
|
||||||
|
remote_queries = {}
|
||||||
|
for user_id, device_keys in body.get("one_time_keys", {}).items():
|
||||||
|
user = UserID.from_string(user_id)
|
||||||
|
if self.is_mine(user):
|
||||||
|
for device_id, algorithm in device_keys.items():
|
||||||
|
local_query.append((user_id, device_id, algorithm))
|
||||||
|
else:
|
||||||
|
remote_queries.setdefault(user.domain, {})[user_id] = (
|
||||||
|
device_keys
|
||||||
|
)
|
||||||
|
results = yield self.store.claim_e2e_one_time_keys(local_query)
|
||||||
|
|
||||||
|
json_result = {}
|
||||||
|
for user_id, device_keys in results.items():
|
||||||
|
for device_id, keys in device_keys.items():
|
||||||
|
for key_id, json_bytes in keys.items():
|
||||||
|
json_result.setdefault(user_id, {})[device_id] = {
|
||||||
|
key_id: json.loads(json_bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
for destination, device_keys in remote_queries.items():
|
||||||
|
remote_result = yield self.federation.claim_client_keys(
|
||||||
|
destination, {"one_time_keys": device_keys}
|
||||||
|
)
|
||||||
|
for user_id, keys in remote_result["one_time_keys"].items():
|
||||||
|
if user_id in device_keys:
|
||||||
|
json_result[user_id] = keys
|
||||||
|
|
||||||
|
defer.returnValue((200, {"one_time_keys": json_result}))
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
KeyUploadServlet(hs).register(http_server)
|
||||||
|
KeyQueryServlet(hs).register(http_server)
|
||||||
|
OneTimeKeyServlet(hs).register(http_server)
|
55
synapse/rest/client/v2_alpha/receipts.py
Normal file
55
synapse/rest/client/v2_alpha/receipts.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.http.servlet import RestServlet
|
||||||
|
from ._base import client_v2_pattern
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiptRestServlet(RestServlet):
|
||||||
|
PATTERN = client_v2_pattern(
|
||||||
|
"/rooms/(?P<room_id>[^/]*)"
|
||||||
|
"/receipt/(?P<receipt_type>[^/]*)"
|
||||||
|
"/(?P<event_id>[^/]*)$"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ReceiptRestServlet, self).__init__()
|
||||||
|
self.hs = hs
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.receipts_handler = hs.get_handlers().receipts_handler
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request, room_id, receipt_type, event_id):
|
||||||
|
user, client = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
yield self.receipts_handler.received_client_receipt(
|
||||||
|
room_id,
|
||||||
|
receipt_type,
|
||||||
|
user_id=user.to_string(),
|
||||||
|
event_id=event_id
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs, http_server):
|
||||||
|
ReceiptRestServlet(hs).register(http_server)
|
@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
|
|||||||
from synapse.api.errors import SynapseError, Codes
|
from synapse.api.errors import SynapseError, Codes
|
||||||
from synapse.http.servlet import RestServlet
|
from synapse.http.servlet import RestServlet
|
||||||
|
|
||||||
from ._base import client_v2_pattern, parse_request_allow_empty
|
from ._base import client_v2_pattern, parse_json_dict_from_request
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import hmac
|
import hmac
|
||||||
@ -50,26 +50,64 @@ class RegisterRestServlet(RestServlet):
|
|||||||
self.auth_handler = hs.get_handlers().auth_handler
|
self.auth_handler = hs.get_handlers().auth_handler
|
||||||
self.registration_handler = hs.get_handlers().registration_handler
|
self.registration_handler = hs.get_handlers().registration_handler
|
||||||
self.identity_handler = hs.get_handlers().identity_handler
|
self.identity_handler = hs.get_handlers().identity_handler
|
||||||
self.login_handler = hs.get_handlers().login_handler
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
body = parse_request_allow_empty(request)
|
if '/register/email/requestToken' in request.path:
|
||||||
if 'password' not in body:
|
ret = yield self.onEmailTokenRequest(request)
|
||||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
defer.returnValue(ret)
|
||||||
|
|
||||||
|
body = parse_json_dict_from_request(request)
|
||||||
|
|
||||||
|
# we do basic sanity checks here because the auth layer will store these
|
||||||
|
# in sessions. Pull out the username/password provided to us.
|
||||||
|
desired_password = None
|
||||||
|
if 'password' in body:
|
||||||
|
if (not isinstance(body['password'], basestring) or
|
||||||
|
len(body['password']) > 512):
|
||||||
|
raise SynapseError(400, "Invalid password")
|
||||||
|
desired_password = body["password"]
|
||||||
|
|
||||||
|
desired_username = None
|
||||||
if 'username' in body:
|
if 'username' in body:
|
||||||
|
if (not isinstance(body['username'], basestring) or
|
||||||
|
len(body['username']) > 512):
|
||||||
|
raise SynapseError(400, "Invalid username")
|
||||||
desired_username = body['username']
|
desired_username = body['username']
|
||||||
yield self.registration_handler.check_username(desired_username)
|
|
||||||
|
|
||||||
is_using_shared_secret = False
|
appservice = None
|
||||||
is_application_server = False
|
|
||||||
|
|
||||||
service = None
|
|
||||||
if 'access_token' in request.args:
|
if 'access_token' in request.args:
|
||||||
service = yield self.auth.get_appservice_by_req(request)
|
appservice = yield self.auth.get_appservice_by_req(request)
|
||||||
|
|
||||||
|
# fork off as soon as possible for ASes and shared secret auth which
|
||||||
|
# have completely different registration flows to normal users
|
||||||
|
|
||||||
|
# == Application Service Registration ==
|
||||||
|
if appservice:
|
||||||
|
result = yield self._do_appservice_registration(
|
||||||
|
desired_username, request.args["access_token"][0]
|
||||||
|
)
|
||||||
|
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||||
|
return
|
||||||
|
|
||||||
|
# == Shared Secret Registration == (e.g. create new user scripts)
|
||||||
|
if 'mac' in body:
|
||||||
|
# FIXME: Should we really be determining if this is shared secret
|
||||||
|
# auth based purely on the 'mac' key?
|
||||||
|
result = yield self._do_shared_secret_registration(
|
||||||
|
desired_username, desired_password, body["mac"]
|
||||||
|
)
|
||||||
|
defer.returnValue((200, result)) # we throw for non 200 responses
|
||||||
|
return
|
||||||
|
|
||||||
|
# == Normal User Registration == (everyone else)
|
||||||
|
if self.hs.config.disable_registration:
|
||||||
|
raise SynapseError(403, "Registration has been disabled")
|
||||||
|
|
||||||
|
if desired_username is not None:
|
||||||
|
yield self.registration_handler.check_username(desired_username)
|
||||||
|
|
||||||
if self.hs.config.enable_registration_captcha:
|
if self.hs.config.enable_registration_captcha:
|
||||||
flows = [
|
flows = [
|
||||||
@ -82,39 +120,20 @@ class RegisterRestServlet(RestServlet):
|
|||||||
[LoginType.EMAIL_IDENTITY]
|
[LoginType.EMAIL_IDENTITY]
|
||||||
]
|
]
|
||||||
|
|
||||||
result = None
|
authed, result, params = yield self.auth_handler.check_auth(
|
||||||
if service:
|
flows, body, self.hs.get_ip_from_request(request)
|
||||||
is_application_server = True
|
|
||||||
params = body
|
|
||||||
elif 'mac' in body:
|
|
||||||
# Check registration-specific shared secret auth
|
|
||||||
if 'username' not in body:
|
|
||||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
|
||||||
self._check_shared_secret_auth(
|
|
||||||
body['username'], body['mac']
|
|
||||||
)
|
|
||||||
is_using_shared_secret = True
|
|
||||||
params = body
|
|
||||||
else:
|
|
||||||
authed, result, params = yield self.auth_handler.check_auth(
|
|
||||||
flows, body, self.hs.get_ip_from_request(request)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not authed:
|
|
||||||
defer.returnValue((401, result))
|
|
||||||
|
|
||||||
can_register = (
|
|
||||||
not self.hs.config.disable_registration
|
|
||||||
or is_application_server
|
|
||||||
or is_using_shared_secret
|
|
||||||
)
|
)
|
||||||
if not can_register:
|
|
||||||
raise SynapseError(403, "Registration has been disabled")
|
|
||||||
|
|
||||||
|
if not authed:
|
||||||
|
defer.returnValue((401, result))
|
||||||
|
return
|
||||||
|
|
||||||
|
# NB: This may be from the auth handler and NOT from the POST
|
||||||
if 'password' not in params:
|
if 'password' not in params:
|
||||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
|
||||||
desired_username = params['username'] if 'username' in params else None
|
|
||||||
new_password = params['password']
|
desired_username = params.get("username", None)
|
||||||
|
new_password = params.get("password", None)
|
||||||
|
|
||||||
(user_id, token) = yield self.registration_handler.register(
|
(user_id, token) = yield self.registration_handler.register(
|
||||||
localpart=desired_username,
|
localpart=desired_username,
|
||||||
@ -128,7 +147,7 @@ class RegisterRestServlet(RestServlet):
|
|||||||
if reqd not in threepid:
|
if reqd not in threepid:
|
||||||
logger.info("Can't add incomplete 3pid")
|
logger.info("Can't add incomplete 3pid")
|
||||||
else:
|
else:
|
||||||
yield self.login_handler.add_threepid(
|
yield self.auth_handler.add_threepid(
|
||||||
user_id,
|
user_id,
|
||||||
threepid['medium'],
|
threepid['medium'],
|
||||||
threepid['address'],
|
threepid['address'],
|
||||||
@ -147,18 +166,21 @@ class RegisterRestServlet(RestServlet):
|
|||||||
else:
|
else:
|
||||||
logger.info("bind_email not specified: not binding email")
|
logger.info("bind_email not specified: not binding email")
|
||||||
|
|
||||||
result = {
|
result = self._create_registration_details(user_id, token)
|
||||||
"user_id": user_id,
|
|
||||||
"access_token": token,
|
|
||||||
"home_server": self.hs.hostname,
|
|
||||||
}
|
|
||||||
|
|
||||||
defer.returnValue((200, result))
|
defer.returnValue((200, result))
|
||||||
|
|
||||||
def on_OPTIONS(self, _):
|
def on_OPTIONS(self, _):
|
||||||
return 200, {}
|
return 200, {}
|
||||||
|
|
||||||
def _check_shared_secret_auth(self, username, mac):
|
@defer.inlineCallbacks
|
||||||
|
def _do_appservice_registration(self, username, as_token):
|
||||||
|
(user_id, token) = yield self.registration_handler.appservice_register(
|
||||||
|
username, as_token
|
||||||
|
)
|
||||||
|
defer.returnValue(self._create_registration_details(user_id, token))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _do_shared_secret_registration(self, username, password, mac):
|
||||||
if not self.hs.config.registration_shared_secret:
|
if not self.hs.config.registration_shared_secret:
|
||||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||||
|
|
||||||
@ -174,13 +196,46 @@ class RegisterRestServlet(RestServlet):
|
|||||||
digestmod=sha1,
|
digestmod=sha1,
|
||||||
).hexdigest()
|
).hexdigest()
|
||||||
|
|
||||||
if compare_digest(want_mac, got_mac):
|
if not compare_digest(want_mac, got_mac):
|
||||||
return True
|
|
||||||
else:
|
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403, "HMAC incorrect",
|
403, "HMAC incorrect",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
(user_id, token) = yield self.registration_handler.register(
|
||||||
|
localpart=username, password=password
|
||||||
|
)
|
||||||
|
defer.returnValue(self._create_registration_details(user_id, token))
|
||||||
|
|
||||||
|
def _create_registration_details(self, user_id, token):
|
||||||
|
return {
|
||||||
|
"user_id": user_id,
|
||||||
|
"access_token": token,
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
}
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def onEmailTokenRequest(self, request):
|
||||||
|
body = parse_json_dict_from_request(request)
|
||||||
|
|
||||||
|
required = ['id_server', 'client_secret', 'email', 'send_attempt']
|
||||||
|
absent = []
|
||||||
|
for k in required:
|
||||||
|
if k not in body:
|
||||||
|
absent.append(k)
|
||||||
|
|
||||||
|
if len(absent) > 0:
|
||||||
|
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||||
|
'email', body['email']
|
||||||
|
)
|
||||||
|
|
||||||
|
if existingUid is not None:
|
||||||
|
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||||
|
|
||||||
|
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||||
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
RegisterRestServlet(hs).register(http_server)
|
RegisterRestServlet(hs).register(http_server)
|
||||||
|
@ -27,18 +27,30 @@ from twisted.web.resource import Resource
|
|||||||
from twisted.protocols.basic import FileSender
|
from twisted.protocols.basic import FileSender
|
||||||
|
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
|
from synapse.util.stringutils import is_ascii
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import cgi
|
||||||
import logging
|
import logging
|
||||||
|
import urllib
|
||||||
|
import urlparse
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def parse_media_id(request):
|
def parse_media_id(request):
|
||||||
try:
|
try:
|
||||||
server_name, media_id = request.postpath
|
# This allows users to append e.g. /test.png to the URL. Useful for
|
||||||
return (server_name, media_id)
|
# clients that parse the URL to see content type.
|
||||||
|
server_name, media_id = request.postpath[:2]
|
||||||
|
file_name = None
|
||||||
|
if len(request.postpath) > 2:
|
||||||
|
try:
|
||||||
|
file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
pass
|
||||||
|
return server_name, media_id, file_name
|
||||||
except:
|
except:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
404,
|
404,
|
||||||
@ -62,6 +74,8 @@ class BaseMediaResource(Resource):
|
|||||||
self.filepaths = filepaths
|
self.filepaths = filepaths
|
||||||
self.version_string = hs.version_string
|
self.version_string = hs.version_string
|
||||||
self.downloads = {}
|
self.downloads = {}
|
||||||
|
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||||
|
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||||
|
|
||||||
def _respond_404(self, request):
|
def _respond_404(self, request):
|
||||||
respond_with_json(
|
respond_with_json(
|
||||||
@ -128,12 +142,38 @@ class BaseMediaResource(Resource):
|
|||||||
media_type = headers["Content-Type"][0]
|
media_type = headers["Content-Type"][0]
|
||||||
time_now_ms = self.clock.time_msec()
|
time_now_ms = self.clock.time_msec()
|
||||||
|
|
||||||
|
content_disposition = headers.get("Content-Disposition", None)
|
||||||
|
if content_disposition:
|
||||||
|
_, params = cgi.parse_header(content_disposition[0],)
|
||||||
|
upload_name = None
|
||||||
|
|
||||||
|
# First check if there is a valid UTF-8 filename
|
||||||
|
upload_name_utf8 = params.get("filename*", None)
|
||||||
|
if upload_name_utf8:
|
||||||
|
if upload_name_utf8.lower().startswith("utf-8''"):
|
||||||
|
upload_name = upload_name_utf8[7:]
|
||||||
|
|
||||||
|
# If there isn't check for an ascii name.
|
||||||
|
if not upload_name:
|
||||||
|
upload_name_ascii = params.get("filename", None)
|
||||||
|
if upload_name_ascii and is_ascii(upload_name_ascii):
|
||||||
|
upload_name = upload_name_ascii
|
||||||
|
|
||||||
|
if upload_name:
|
||||||
|
upload_name = urlparse.unquote(upload_name)
|
||||||
|
try:
|
||||||
|
upload_name = upload_name.decode("utf-8")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
upload_name = None
|
||||||
|
else:
|
||||||
|
upload_name = None
|
||||||
|
|
||||||
yield self.store.store_cached_remote_media(
|
yield self.store.store_cached_remote_media(
|
||||||
origin=server_name,
|
origin=server_name,
|
||||||
media_id=media_id,
|
media_id=media_id,
|
||||||
media_type=media_type,
|
media_type=media_type,
|
||||||
time_now_ms=self.clock.time_msec(),
|
time_now_ms=self.clock.time_msec(),
|
||||||
upload_name=None,
|
upload_name=upload_name,
|
||||||
media_length=length,
|
media_length=length,
|
||||||
filesystem_id=file_id,
|
filesystem_id=file_id,
|
||||||
)
|
)
|
||||||
@ -144,7 +184,7 @@ class BaseMediaResource(Resource):
|
|||||||
media_info = {
|
media_info = {
|
||||||
"media_type": media_type,
|
"media_type": media_type,
|
||||||
"media_length": length,
|
"media_length": length,
|
||||||
"upload_name": None,
|
"upload_name": upload_name,
|
||||||
"created_ts": time_now_ms,
|
"created_ts": time_now_ms,
|
||||||
"filesystem_id": file_id,
|
"filesystem_id": file_id,
|
||||||
}
|
}
|
||||||
@ -157,11 +197,26 @@ class BaseMediaResource(Resource):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _respond_with_file(self, request, media_type, file_path,
|
def _respond_with_file(self, request, media_type, file_path,
|
||||||
file_size=None):
|
file_size=None, upload_name=None):
|
||||||
logger.debug("Responding with %r", file_path)
|
logger.debug("Responding with %r", file_path)
|
||||||
|
|
||||||
if os.path.isfile(file_path):
|
if os.path.isfile(file_path):
|
||||||
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
|
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
|
||||||
|
if upload_name:
|
||||||
|
if is_ascii(upload_name):
|
||||||
|
request.setHeader(
|
||||||
|
b"Content-Disposition",
|
||||||
|
b"inline; filename=%s" % (
|
||||||
|
urllib.quote(upload_name.encode("utf-8")),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
request.setHeader(
|
||||||
|
b"Content-Disposition",
|
||||||
|
b"inline; filename*=utf-8''%s" % (
|
||||||
|
urllib.quote(upload_name.encode("utf-8")),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# cache for at least a day.
|
# cache for at least a day.
|
||||||
# XXX: we might want to turn this off for data we don't want to
|
# XXX: we might want to turn this off for data we don't want to
|
||||||
@ -187,22 +242,74 @@ class BaseMediaResource(Resource):
|
|||||||
self._respond_404(request)
|
self._respond_404(request)
|
||||||
|
|
||||||
def _get_thumbnail_requirements(self, media_type):
|
def _get_thumbnail_requirements(self, media_type):
|
||||||
if media_type == "image/jpeg":
|
return self.thumbnail_requirements.get(media_type, ())
|
||||||
return (
|
|
||||||
(32, 32, "crop", "image/jpeg"),
|
def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
|
||||||
(96, 96, "crop", "image/jpeg"),
|
t_method, t_type):
|
||||||
(320, 240, "scale", "image/jpeg"),
|
thumbnailer = Thumbnailer(input_path)
|
||||||
(640, 480, "scale", "image/jpeg"),
|
m_width = thumbnailer.width
|
||||||
)
|
m_height = thumbnailer.height
|
||||||
elif (media_type == "image/png") or (media_type == "image/gif"):
|
|
||||||
return (
|
if m_width * m_height >= self.max_image_pixels:
|
||||||
(32, 32, "crop", "image/png"),
|
logger.info(
|
||||||
(96, 96, "crop", "image/png"),
|
"Image too large to thumbnail %r x %r > %r",
|
||||||
(320, 240, "scale", "image/png"),
|
m_width, m_height, self.max_image_pixels
|
||||||
(640, 480, "scale", "image/png"),
|
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if t_method == "crop":
|
||||||
|
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||||
|
elif t_method == "scale":
|
||||||
|
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||||
else:
|
else:
|
||||||
return ()
|
t_len = None
|
||||||
|
|
||||||
|
return t_len
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _generate_local_exact_thumbnail(self, media_id, t_width, t_height,
|
||||||
|
t_method, t_type):
|
||||||
|
input_path = self.filepaths.local_media_filepath(media_id)
|
||||||
|
|
||||||
|
t_path = self.filepaths.local_media_thumbnail(
|
||||||
|
media_id, t_width, t_height, t_type, t_method
|
||||||
|
)
|
||||||
|
self._makedirs(t_path)
|
||||||
|
|
||||||
|
t_len = yield threads.deferToThread(
|
||||||
|
self._generate_thumbnail,
|
||||||
|
input_path, t_path, t_width, t_height, t_method, t_type
|
||||||
|
)
|
||||||
|
|
||||||
|
if t_len:
|
||||||
|
yield self.store.store_local_thumbnail(
|
||||||
|
media_id, t_width, t_height, t_type, t_method, t_len
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(t_path)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
|
||||||
|
t_width, t_height, t_method, t_type):
|
||||||
|
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||||
|
|
||||||
|
t_path = self.filepaths.remote_media_thumbnail(
|
||||||
|
server_name, file_id, t_width, t_height, t_type, t_method
|
||||||
|
)
|
||||||
|
self._makedirs(t_path)
|
||||||
|
|
||||||
|
t_len = yield threads.deferToThread(
|
||||||
|
self._generate_thumbnail,
|
||||||
|
input_path, t_path, t_width, t_height, t_method, t_type
|
||||||
|
)
|
||||||
|
|
||||||
|
if t_len:
|
||||||
|
yield self.store.store_remote_media_thumbnail(
|
||||||
|
server_name, media_id, file_id,
|
||||||
|
t_width, t_height, t_type, t_method, t_len
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue(t_path)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _generate_local_thumbnails(self, media_id, media_info):
|
def _generate_local_thumbnails(self, media_id, media_info):
|
||||||
@ -223,43 +330,52 @@ class BaseMediaResource(Resource):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
scales = set()
|
local_thumbnails = []
|
||||||
crops = set()
|
|
||||||
for r_width, r_height, r_method, r_type in requirements:
|
def generate_thumbnails():
|
||||||
if r_method == "scale":
|
scales = set()
|
||||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
crops = set()
|
||||||
scales.add((
|
for r_width, r_height, r_method, r_type in requirements:
|
||||||
min(m_width, t_width), min(m_height, t_height), r_type,
|
if r_method == "scale":
|
||||||
|
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||||
|
scales.add((
|
||||||
|
min(m_width, t_width), min(m_height, t_height), r_type,
|
||||||
|
))
|
||||||
|
elif r_method == "crop":
|
||||||
|
crops.add((r_width, r_height, r_type))
|
||||||
|
|
||||||
|
for t_width, t_height, t_type in scales:
|
||||||
|
t_method = "scale"
|
||||||
|
t_path = self.filepaths.local_media_thumbnail(
|
||||||
|
media_id, t_width, t_height, t_type, t_method
|
||||||
|
)
|
||||||
|
self._makedirs(t_path)
|
||||||
|
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||||
|
|
||||||
|
local_thumbnails.append((
|
||||||
|
media_id, t_width, t_height, t_type, t_method, t_len
|
||||||
))
|
))
|
||||||
elif r_method == "crop":
|
|
||||||
crops.add((r_width, r_height, r_type))
|
|
||||||
|
|
||||||
for t_width, t_height, t_type in scales:
|
for t_width, t_height, t_type in crops:
|
||||||
t_method = "scale"
|
if (t_width, t_height, t_type) in scales:
|
||||||
t_path = self.filepaths.local_media_thumbnail(
|
# If the aspect ratio of the cropped thumbnail matches a purely
|
||||||
media_id, t_width, t_height, t_type, t_method
|
# scaled one then there is no point in calculating a separate
|
||||||
)
|
# thumbnail.
|
||||||
self._makedirs(t_path)
|
continue
|
||||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
t_method = "crop"
|
||||||
yield self.store.store_local_thumbnail(
|
t_path = self.filepaths.local_media_thumbnail(
|
||||||
media_id, t_width, t_height, t_type, t_method, t_len
|
media_id, t_width, t_height, t_type, t_method
|
||||||
)
|
)
|
||||||
|
self._makedirs(t_path)
|
||||||
|
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||||
|
local_thumbnails.append((
|
||||||
|
media_id, t_width, t_height, t_type, t_method, t_len
|
||||||
|
))
|
||||||
|
|
||||||
for t_width, t_height, t_type in crops:
|
yield threads.deferToThread(generate_thumbnails)
|
||||||
if (t_width, t_height, t_type) in scales:
|
|
||||||
# If the aspect ratio of the cropped thumbnail matches a purely
|
for l in local_thumbnails:
|
||||||
# scaled one then there is no point in calculating a separate
|
yield self.store.store_local_thumbnail(*l)
|
||||||
# thumbnail.
|
|
||||||
continue
|
|
||||||
t_method = "crop"
|
|
||||||
t_path = self.filepaths.local_media_thumbnail(
|
|
||||||
media_id, t_width, t_height, t_type, t_method
|
|
||||||
)
|
|
||||||
self._makedirs(t_path)
|
|
||||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
|
||||||
yield self.store.store_local_thumbnail(
|
|
||||||
media_id, t_width, t_height, t_type, t_method, t_len
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"width": m_width,
|
"width": m_width,
|
||||||
|
@ -32,14 +32,16 @@ class DownloadResource(BaseMediaResource):
|
|||||||
@request_handler
|
@request_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _async_render_GET(self, request):
|
def _async_render_GET(self, request):
|
||||||
server_name, media_id = parse_media_id(request)
|
server_name, media_id, name = parse_media_id(request)
|
||||||
if server_name == self.server_name:
|
if server_name == self.server_name:
|
||||||
yield self._respond_local_file(request, media_id)
|
yield self._respond_local_file(request, media_id, name)
|
||||||
else:
|
else:
|
||||||
yield self._respond_remote_file(request, server_name, media_id)
|
yield self._respond_remote_file(
|
||||||
|
request, server_name, media_id, name
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _respond_local_file(self, request, media_id):
|
def _respond_local_file(self, request, media_id, name):
|
||||||
media_info = yield self.store.get_local_media(media_id)
|
media_info = yield self.store.get_local_media(media_id)
|
||||||
if not media_info:
|
if not media_info:
|
||||||
self._respond_404(request)
|
self._respond_404(request)
|
||||||
@ -47,24 +49,28 @@ class DownloadResource(BaseMediaResource):
|
|||||||
|
|
||||||
media_type = media_info["media_type"]
|
media_type = media_info["media_type"]
|
||||||
media_length = media_info["media_length"]
|
media_length = media_info["media_length"]
|
||||||
|
upload_name = name if name else media_info["upload_name"]
|
||||||
file_path = self.filepaths.local_media_filepath(media_id)
|
file_path = self.filepaths.local_media_filepath(media_id)
|
||||||
|
|
||||||
yield self._respond_with_file(
|
yield self._respond_with_file(
|
||||||
request, media_type, file_path, media_length
|
request, media_type, file_path, media_length,
|
||||||
|
upload_name=upload_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _respond_remote_file(self, request, server_name, media_id):
|
def _respond_remote_file(self, request, server_name, media_id, name):
|
||||||
media_info = yield self._get_remote_media(server_name, media_id)
|
media_info = yield self._get_remote_media(server_name, media_id)
|
||||||
|
|
||||||
media_type = media_info["media_type"]
|
media_type = media_info["media_type"]
|
||||||
media_length = media_info["media_length"]
|
media_length = media_info["media_length"]
|
||||||
filesystem_id = media_info["filesystem_id"]
|
filesystem_id = media_info["filesystem_id"]
|
||||||
|
upload_name = name if name else media_info["upload_name"]
|
||||||
|
|
||||||
file_path = self.filepaths.remote_media_filepath(
|
file_path = self.filepaths.remote_media_filepath(
|
||||||
server_name, filesystem_id
|
server_name, filesystem_id
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self._respond_with_file(
|
yield self._respond_with_file(
|
||||||
request, media_type, file_path, media_length
|
request, media_type, file_path, media_length,
|
||||||
|
upload_name=upload_name,
|
||||||
)
|
)
|
||||||
|
@ -36,21 +36,32 @@ class ThumbnailResource(BaseMediaResource):
|
|||||||
@request_handler
|
@request_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _async_render_GET(self, request):
|
def _async_render_GET(self, request):
|
||||||
server_name, media_id = parse_media_id(request)
|
server_name, media_id, _ = parse_media_id(request)
|
||||||
width = parse_integer(request, "width")
|
width = parse_integer(request, "width")
|
||||||
height = parse_integer(request, "height")
|
height = parse_integer(request, "height")
|
||||||
method = parse_string(request, "method", "scale")
|
method = parse_string(request, "method", "scale")
|
||||||
m_type = parse_string(request, "type", "image/png")
|
m_type = parse_string(request, "type", "image/png")
|
||||||
|
|
||||||
if server_name == self.server_name:
|
if server_name == self.server_name:
|
||||||
yield self._respond_local_thumbnail(
|
if self.dynamic_thumbnails:
|
||||||
request, media_id, width, height, method, m_type
|
yield self._select_or_generate_local_thumbnail(
|
||||||
)
|
request, media_id, width, height, method, m_type
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield self._respond_local_thumbnail(
|
||||||
|
request, media_id, width, height, method, m_type
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
yield self._respond_remote_thumbnail(
|
if self.dynamic_thumbnails:
|
||||||
request, server_name, media_id,
|
yield self._select_or_generate_remote_thumbnail(
|
||||||
width, height, method, m_type
|
request, server_name, media_id,
|
||||||
)
|
width, height, method, m_type
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield self._respond_remote_thumbnail(
|
||||||
|
request, server_name, media_id,
|
||||||
|
width, height, method, m_type
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _respond_local_thumbnail(self, request, media_id, width, height,
|
def _respond_local_thumbnail(self, request, media_id, width, height,
|
||||||
@ -82,6 +93,87 @@ class ThumbnailResource(BaseMediaResource):
|
|||||||
request, media_info, width, height, method, m_type,
|
request, media_info, width, height, method, m_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
|
||||||
|
desired_height, desired_method,
|
||||||
|
desired_type):
|
||||||
|
media_info = yield self.store.get_local_media(media_id)
|
||||||
|
|
||||||
|
if not media_info:
|
||||||
|
self._respond_404(request)
|
||||||
|
return
|
||||||
|
|
||||||
|
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
|
||||||
|
for info in thumbnail_infos:
|
||||||
|
t_w = info["thumbnail_width"] == desired_width
|
||||||
|
t_h = info["thumbnail_height"] == desired_height
|
||||||
|
t_method = info["thumbnail_method"] == desired_method
|
||||||
|
t_type = info["thumbnail_type"] == desired_type
|
||||||
|
|
||||||
|
if t_w and t_h and t_method and t_type:
|
||||||
|
file_path = self.filepaths.local_media_thumbnail(
|
||||||
|
media_id, desired_width, desired_height, desired_type, desired_method,
|
||||||
|
)
|
||||||
|
yield self._respond_with_file(request, desired_type, file_path)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug("We don't have a local thumbnail of that size. Generating")
|
||||||
|
|
||||||
|
# Okay, so we generate one.
|
||||||
|
file_path = yield self._generate_local_exact_thumbnail(
|
||||||
|
media_id, desired_width, desired_height, desired_method, desired_type
|
||||||
|
)
|
||||||
|
|
||||||
|
if file_path:
|
||||||
|
yield self._respond_with_file(request, desired_type, file_path)
|
||||||
|
else:
|
||||||
|
yield self._respond_default_thumbnail(
|
||||||
|
request, media_info, desired_width, desired_height,
|
||||||
|
desired_method, desired_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
|
||||||
|
desired_width, desired_height,
|
||||||
|
desired_method, desired_type):
|
||||||
|
media_info = yield self._get_remote_media(server_name, media_id)
|
||||||
|
|
||||||
|
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
||||||
|
server_name, media_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
file_id = media_info["filesystem_id"]
|
||||||
|
|
||||||
|
for info in thumbnail_infos:
|
||||||
|
t_w = info["thumbnail_width"] == desired_width
|
||||||
|
t_h = info["thumbnail_height"] == desired_height
|
||||||
|
t_method = info["thumbnail_method"] == desired_method
|
||||||
|
t_type = info["thumbnail_type"] == desired_type
|
||||||
|
|
||||||
|
if t_w and t_h and t_method and t_type:
|
||||||
|
file_path = self.filepaths.remote_media_thumbnail(
|
||||||
|
server_name, file_id, desired_width, desired_height,
|
||||||
|
desired_type, desired_method,
|
||||||
|
)
|
||||||
|
yield self._respond_with_file(request, desired_type, file_path)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug("We don't have a local thumbnail of that size. Generating")
|
||||||
|
|
||||||
|
# Okay, so we generate one.
|
||||||
|
file_path = yield self._generate_remote_exact_thumbnail(
|
||||||
|
server_name, file_id, media_id, desired_width,
|
||||||
|
desired_height, desired_method, desired_type
|
||||||
|
)
|
||||||
|
|
||||||
|
if file_path:
|
||||||
|
yield self._respond_with_file(request, desired_type, file_path)
|
||||||
|
else:
|
||||||
|
yield self._respond_default_thumbnail(
|
||||||
|
request, media_info, desired_width, desired_height,
|
||||||
|
desired_method, desired_type,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
|
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
|
||||||
height, method, m_type):
|
height, method, m_type):
|
||||||
@ -162,11 +254,12 @@ class ThumbnailResource(BaseMediaResource):
|
|||||||
t_method = info["thumbnail_method"]
|
t_method = info["thumbnail_method"]
|
||||||
if t_method == "scale" or t_method == "crop":
|
if t_method == "scale" or t_method == "crop":
|
||||||
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
aspect_quality = abs(d_w * t_h - d_h * t_w)
|
||||||
|
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
|
||||||
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
size_quality = abs((d_w - t_w) * (d_h - t_h))
|
||||||
type_quality = desired_type != info["thumbnail_type"]
|
type_quality = desired_type != info["thumbnail_type"]
|
||||||
length_quality = info["thumbnail_length"]
|
length_quality = info["thumbnail_length"]
|
||||||
info_list.append((
|
info_list.append((
|
||||||
aspect_quality, size_quality, type_quality,
|
aspect_quality, min_quality, size_quality, type_quality,
|
||||||
length_quality, info
|
length_quality, info
|
||||||
))
|
))
|
||||||
if info_list:
|
if info_list:
|
||||||
|
@ -82,7 +82,7 @@ class Thumbnailer(object):
|
|||||||
|
|
||||||
def save_image(self, output_image, output_type, output_path):
|
def save_image(self, output_image, output_type, output_path):
|
||||||
output_bytes_io = BytesIO()
|
output_bytes_io = BytesIO()
|
||||||
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=70)
|
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
|
||||||
output_bytes = output_bytes_io.getvalue()
|
output_bytes = output_bytes_io.getvalue()
|
||||||
with open(output_path, "wb") as output_file:
|
with open(output_path, "wb") as output_file:
|
||||||
output_file.write(output_bytes)
|
output_file.write(output_bytes)
|
||||||
|
@ -84,6 +84,16 @@ class UploadResource(BaseMediaResource):
|
|||||||
code=413,
|
code=413,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
upload_name = request.args.get("filename", None)
|
||||||
|
if upload_name:
|
||||||
|
try:
|
||||||
|
upload_name = upload_name[0].decode('UTF-8')
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
raise SynapseError(
|
||||||
|
msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
|
||||||
|
code=400,
|
||||||
|
)
|
||||||
|
|
||||||
headers = request.requestHeaders
|
headers = request.requestHeaders
|
||||||
|
|
||||||
if headers.hasHeader("Content-Type"):
|
if headers.hasHeader("Content-Type"):
|
||||||
@ -99,7 +109,7 @@ class UploadResource(BaseMediaResource):
|
|||||||
# TODO(markjh): parse content-dispostion
|
# TODO(markjh): parse content-dispostion
|
||||||
|
|
||||||
content_uri = yield self.create_content(
|
content_uri = yield self.create_content(
|
||||||
media_type, None, request.content.read(),
|
media_type, upload_name, request.content.read(),
|
||||||
content_length, auth_user
|
content_length, auth_user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.api.errors import AuthError
|
from synapse.api.errors import AuthError
|
||||||
from synapse.api.auth import AuthEventTypes
|
from synapse.api.auth import AuthEventTypes
|
||||||
@ -96,7 +96,7 @@ class StateHandler(object):
|
|||||||
cache.ts = self.clock.time_msec()
|
cache.ts = self.clock.time_msec()
|
||||||
state = cache.state
|
state = cache.state
|
||||||
else:
|
else:
|
||||||
res = yield self.resolve_state_groups(event_ids)
|
res = yield self.resolve_state_groups(room_id, event_ids)
|
||||||
state = res[1]
|
state = res[1]
|
||||||
|
|
||||||
if event_type:
|
if event_type:
|
||||||
@ -155,13 +155,13 @@ class StateHandler(object):
|
|||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
ret = yield self.resolve_state_groups(
|
ret = yield self.resolve_state_groups(
|
||||||
[e for e, _ in event.prev_events],
|
event.room_id, [e for e, _ in event.prev_events],
|
||||||
event_type=event.type,
|
event_type=event.type,
|
||||||
state_key=event.state_key,
|
state_key=event.state_key,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ret = yield self.resolve_state_groups(
|
ret = yield self.resolve_state_groups(
|
||||||
[e for e, _ in event.prev_events],
|
event.room_id, [e for e, _ in event.prev_events],
|
||||||
)
|
)
|
||||||
|
|
||||||
group, curr_state, prev_state = ret
|
group, curr_state, prev_state = ret
|
||||||
@ -180,7 +180,7 @@ class StateHandler(object):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def resolve_state_groups(self, event_ids, event_type=None, state_key=""):
|
def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
|
||||||
""" Given a list of event_ids this method fetches the state at each
|
""" Given a list of event_ids this method fetches the state at each
|
||||||
event, resolves conflicts between them and returns them.
|
event, resolves conflicts between them and returns them.
|
||||||
|
|
||||||
@ -205,7 +205,7 @@ class StateHandler(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
state_groups = yield self.store.get_state_groups(
|
state_groups = yield self.store.get_state_groups(
|
||||||
event_ids
|
room_id, event_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -37,6 +37,9 @@ from .rejections import RejectionsStore
|
|||||||
from .state import StateStore
|
from .state import StateStore
|
||||||
from .signatures import SignatureStore
|
from .signatures import SignatureStore
|
||||||
from .filtering import FilteringStore
|
from .filtering import FilteringStore
|
||||||
|
from .end_to_end_keys import EndToEndKeyStore
|
||||||
|
|
||||||
|
from .receipts import ReceiptsStore
|
||||||
|
|
||||||
|
|
||||||
import fnmatch
|
import fnmatch
|
||||||
@ -51,7 +54,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# Remember to update this number every time a change is made to database
|
# Remember to update this number every time a change is made to database
|
||||||
# schema files, so the users will be informed on server restarts.
|
# schema files, so the users will be informed on server restarts.
|
||||||
SCHEMA_VERSION = 20
|
SCHEMA_VERSION = 22
|
||||||
|
|
||||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||||
|
|
||||||
@ -74,6 +77,8 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
PushRuleStore,
|
PushRuleStore,
|
||||||
ApplicationServiceTransactionStore,
|
ApplicationServiceTransactionStore,
|
||||||
EventsStore,
|
EventsStore,
|
||||||
|
ReceiptsStore,
|
||||||
|
EndToEndKeyStore,
|
||||||
):
|
):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
@ -94,7 +99,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
key = (user.to_string(), access_token, device_id, ip)
|
key = (user.to_string(), access_token, device_id, ip)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
last_seen = self.client_ip_last_seen.get(*key)
|
last_seen = self.client_ip_last_seen.get(key)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
last_seen = None
|
last_seen = None
|
||||||
|
|
||||||
@ -102,7 +107,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
|||||||
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
self.client_ip_last_seen.prefill(*key + (now,))
|
self.client_ip_last_seen.prefill(key, now)
|
||||||
|
|
||||||
# It's safe not to lock here: a) no unique constraint,
|
# It's safe not to lock here: a) no unique constraint,
|
||||||
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
|
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
|
||||||
@ -349,6 +354,11 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
|||||||
)
|
)
|
||||||
logger.debug("Running script %s", relative_path)
|
logger.debug("Running script %s", relative_path)
|
||||||
module.run_upgrade(cur, database_engine)
|
module.run_upgrade(cur, database_engine)
|
||||||
|
elif ext == ".pyc":
|
||||||
|
# Sometimes .pyc files turn up anyway even though we've
|
||||||
|
# disabled their generation; e.g. from distribution package
|
||||||
|
# installers. Silently skip it
|
||||||
|
pass
|
||||||
elif ext == ".sql":
|
elif ext == ".sql":
|
||||||
# A plain old .sql file, just read and execute it
|
# A plain old .sql file, just read and execute it
|
||||||
logger.debug("Applying schema %s", relative_path)
|
logger.debug("Applying schema %s", relative_path)
|
||||||
|
@ -17,21 +17,20 @@ import logging
|
|||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
|
||||||
from synapse.util.lrucache import LruCache
|
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||||
|
from synapse.util.caches.descriptors import Cache
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
from util.id_generators import IdGenerator, StreamIdGenerator
|
from util.id_generators import IdGenerator, StreamIdGenerator
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from collections import namedtuple, OrderedDict
|
from collections import namedtuple
|
||||||
|
|
||||||
import functools
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
DEBUG_CACHES = False
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -47,159 +46,6 @@ sql_scheduling_timer = metrics.register_distribution("schedule_time")
|
|||||||
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
|
sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
|
||||||
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
|
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
|
||||||
|
|
||||||
caches_by_name = {}
|
|
||||||
cache_counter = metrics.register_cache(
|
|
||||||
"cache",
|
|
||||||
lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
|
|
||||||
labels=["name"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Cache(object):
|
|
||||||
|
|
||||||
def __init__(self, name, max_entries=1000, keylen=1, lru=False):
|
|
||||||
if lru:
|
|
||||||
self.cache = LruCache(max_size=max_entries)
|
|
||||||
self.max_entries = None
|
|
||||||
else:
|
|
||||||
self.cache = OrderedDict()
|
|
||||||
self.max_entries = max_entries
|
|
||||||
|
|
||||||
self.name = name
|
|
||||||
self.keylen = keylen
|
|
||||||
self.sequence = 0
|
|
||||||
self.thread = None
|
|
||||||
caches_by_name[name] = self.cache
|
|
||||||
|
|
||||||
def check_thread(self):
|
|
||||||
expected_thread = self.thread
|
|
||||||
if expected_thread is None:
|
|
||||||
self.thread = threading.current_thread()
|
|
||||||
else:
|
|
||||||
if expected_thread is not threading.current_thread():
|
|
||||||
raise ValueError(
|
|
||||||
"Cache objects can only be accessed from the main thread"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get(self, *keyargs):
|
|
||||||
if len(keyargs) != self.keylen:
|
|
||||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
|
||||||
|
|
||||||
if keyargs in self.cache:
|
|
||||||
cache_counter.inc_hits(self.name)
|
|
||||||
return self.cache[keyargs]
|
|
||||||
|
|
||||||
cache_counter.inc_misses(self.name)
|
|
||||||
raise KeyError()
|
|
||||||
|
|
||||||
def update(self, sequence, *args):
|
|
||||||
self.check_thread()
|
|
||||||
if self.sequence == sequence:
|
|
||||||
# Only update the cache if the caches sequence number matches the
|
|
||||||
# number that the cache had before the SELECT was started (SYN-369)
|
|
||||||
self.prefill(*args)
|
|
||||||
|
|
||||||
def prefill(self, *args): # because I can't *keyargs, value
|
|
||||||
keyargs = args[:-1]
|
|
||||||
value = args[-1]
|
|
||||||
|
|
||||||
if len(keyargs) != self.keylen:
|
|
||||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
|
||||||
|
|
||||||
if self.max_entries is not None:
|
|
||||||
while len(self.cache) >= self.max_entries:
|
|
||||||
self.cache.popitem(last=False)
|
|
||||||
|
|
||||||
self.cache[keyargs] = value
|
|
||||||
|
|
||||||
def invalidate(self, *keyargs):
|
|
||||||
self.check_thread()
|
|
||||||
if len(keyargs) != self.keylen:
|
|
||||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
|
||||||
# Increment the sequence number so that any SELECT statements that
|
|
||||||
# raced with the INSERT don't update the cache (SYN-369)
|
|
||||||
self.sequence += 1
|
|
||||||
self.cache.pop(keyargs, None)
|
|
||||||
|
|
||||||
def invalidate_all(self):
|
|
||||||
self.check_thread()
|
|
||||||
self.sequence += 1
|
|
||||||
self.cache.clear()
|
|
||||||
|
|
||||||
|
|
||||||
class CacheDescriptor(object):
|
|
||||||
""" A method decorator that applies a memoizing cache around the function.
|
|
||||||
|
|
||||||
The function is presumed to take zero or more arguments, which are used in
|
|
||||||
a tuple as the key for the cache. Hits are served directly from the cache;
|
|
||||||
misses use the function body to generate the value.
|
|
||||||
|
|
||||||
The wrapped function has an additional member, a callable called
|
|
||||||
"invalidate". This can be used to remove individual entries from the cache.
|
|
||||||
|
|
||||||
The wrapped function has another additional callable, called "prefill",
|
|
||||||
which can be used to insert values into the cache specifically, without
|
|
||||||
calling the calculation function.
|
|
||||||
"""
|
|
||||||
def __init__(self, orig, max_entries=1000, num_args=1, lru=False):
|
|
||||||
self.orig = orig
|
|
||||||
|
|
||||||
self.max_entries = max_entries
|
|
||||||
self.num_args = num_args
|
|
||||||
self.lru = lru
|
|
||||||
|
|
||||||
def __get__(self, obj, objtype=None):
|
|
||||||
cache = Cache(
|
|
||||||
name=self.orig.__name__,
|
|
||||||
max_entries=self.max_entries,
|
|
||||||
keylen=self.num_args,
|
|
||||||
lru=self.lru,
|
|
||||||
)
|
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def wrapped(*keyargs):
|
|
||||||
try:
|
|
||||||
cached_result = cache.get(*keyargs[:self.num_args])
|
|
||||||
if DEBUG_CACHES:
|
|
||||||
actual_result = yield self.orig(obj, *keyargs)
|
|
||||||
if actual_result != cached_result:
|
|
||||||
logger.error(
|
|
||||||
"Stale cache entry %s%r: cached: %r, actual %r",
|
|
||||||
self.orig.__name__, keyargs,
|
|
||||||
cached_result, actual_result,
|
|
||||||
)
|
|
||||||
raise ValueError("Stale cache entry")
|
|
||||||
defer.returnValue(cached_result)
|
|
||||||
except KeyError:
|
|
||||||
# Get the sequence number of the cache before reading from the
|
|
||||||
# database so that we can tell if the cache is invalidated
|
|
||||||
# while the SELECT is executing (SYN-369)
|
|
||||||
sequence = cache.sequence
|
|
||||||
|
|
||||||
ret = yield self.orig(obj, *keyargs)
|
|
||||||
|
|
||||||
cache.update(sequence, *keyargs[:self.num_args] + (ret,))
|
|
||||||
|
|
||||||
defer.returnValue(ret)
|
|
||||||
|
|
||||||
wrapped.invalidate = cache.invalidate
|
|
||||||
wrapped.invalidate_all = cache.invalidate_all
|
|
||||||
wrapped.prefill = cache.prefill
|
|
||||||
|
|
||||||
obj.__dict__[self.orig.__name__] = wrapped
|
|
||||||
|
|
||||||
return wrapped
|
|
||||||
|
|
||||||
|
|
||||||
def cached(max_entries=1000, num_args=1, lru=False):
|
|
||||||
return lambda orig: CacheDescriptor(
|
|
||||||
orig,
|
|
||||||
max_entries=max_entries,
|
|
||||||
num_args=num_args,
|
|
||||||
lru=lru
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LoggingTransaction(object):
|
class LoggingTransaction(object):
|
||||||
"""An object that almost-transparently proxies for the 'txn' object
|
"""An object that almost-transparently proxies for the 'txn' object
|
||||||
@ -321,6 +167,8 @@ class SQLBaseStore(object):
|
|||||||
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
|
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
|
||||||
max_entries=hs.config.event_cache_size)
|
max_entries=hs.config.event_cache_size)
|
||||||
|
|
||||||
|
self._state_group_cache = DictionaryCache("*stateGroupCache*", 2000)
|
||||||
|
|
||||||
self._event_fetch_lock = threading.Condition()
|
self._event_fetch_lock = threading.Condition()
|
||||||
self._event_fetch_list = []
|
self._event_fetch_list = []
|
||||||
self._event_fetch_ongoing = 0
|
self._event_fetch_ongoing = 0
|
||||||
@ -329,13 +177,14 @@ class SQLBaseStore(object):
|
|||||||
|
|
||||||
self.database_engine = hs.database_engine
|
self.database_engine = hs.database_engine
|
||||||
|
|
||||||
self._stream_id_gen = StreamIdGenerator()
|
self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
|
||||||
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
||||||
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
||||||
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
||||||
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
||||||
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
|
||||||
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
|
||||||
|
self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
|
||||||
|
|
||||||
def start_profiling(self):
|
def start_profiling(self):
|
||||||
self._previous_loop_ts = self._clock.time_msec()
|
self._previous_loop_ts = self._clock.time_msec()
|
||||||
|
@ -13,7 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError
|
||||||
|
|
||||||
@ -104,7 +105,7 @@ class DirectoryStore(SQLBaseStore):
|
|||||||
},
|
},
|
||||||
desc="create_room_alias_association",
|
desc="create_room_alias_association",
|
||||||
)
|
)
|
||||||
self.get_aliases_for_room.invalidate(room_id)
|
self.get_aliases_for_room.invalidate((room_id,))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_room_alias(self, room_alias):
|
def delete_room_alias(self, room_alias):
|
||||||
@ -114,7 +115,7 @@ class DirectoryStore(SQLBaseStore):
|
|||||||
room_alias,
|
room_alias,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_aliases_for_room.invalidate(room_id)
|
self.get_aliases_for_room.invalidate((room_id,))
|
||||||
defer.returnValue(room_id)
|
defer.returnValue(room_id)
|
||||||
|
|
||||||
def _delete_room_alias_txn(self, txn, room_alias):
|
def _delete_room_alias_txn(self, txn, room_alias):
|
||||||
|
125
synapse/storage/end_to_end_keys.py
Normal file
125
synapse/storage/end_to_end_keys.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from _base import SQLBaseStore
|
||||||
|
|
||||||
|
|
||||||
|
class EndToEndKeyStore(SQLBaseStore):
|
||||||
|
def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
|
||||||
|
return self._simple_upsert(
|
||||||
|
table="e2e_device_keys_json",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
},
|
||||||
|
values={
|
||||||
|
"ts_added_ms": time_now,
|
||||||
|
"key_json": json_bytes,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_e2e_device_keys(self, query_list):
|
||||||
|
"""Fetch a list of device keys.
|
||||||
|
Args:
|
||||||
|
query_list(list): List of pairs of user_ids and device_ids.
|
||||||
|
Returns:
|
||||||
|
Dict mapping from user-id to dict mapping from device_id to
|
||||||
|
key json byte strings.
|
||||||
|
"""
|
||||||
|
def _get_e2e_device_keys(txn):
|
||||||
|
result = {}
|
||||||
|
for user_id, device_id in query_list:
|
||||||
|
user_result = result.setdefault(user_id, {})
|
||||||
|
keyvalues = {"user_id": user_id}
|
||||||
|
if device_id:
|
||||||
|
keyvalues["device_id"] = device_id
|
||||||
|
rows = self._simple_select_list_txn(
|
||||||
|
txn, table="e2e_device_keys_json",
|
||||||
|
keyvalues=keyvalues,
|
||||||
|
retcols=["device_id", "key_json"]
|
||||||
|
)
|
||||||
|
for row in rows:
|
||||||
|
user_result[row["device_id"]] = row["key_json"]
|
||||||
|
return result
|
||||||
|
return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
|
||||||
|
|
||||||
|
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
|
||||||
|
def _add_e2e_one_time_keys(txn):
|
||||||
|
for (algorithm, key_id, json_bytes) in key_list:
|
||||||
|
self._simple_upsert_txn(
|
||||||
|
txn, table="e2e_one_time_keys_json",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"device_id": device_id,
|
||||||
|
"algorithm": algorithm,
|
||||||
|
"key_id": key_id,
|
||||||
|
},
|
||||||
|
values={
|
||||||
|
"ts_added_ms": time_now,
|
||||||
|
"key_json": json_bytes,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return self.runInteraction(
|
||||||
|
"add_e2e_one_time_keys", _add_e2e_one_time_keys
|
||||||
|
)
|
||||||
|
|
||||||
|
def count_e2e_one_time_keys(self, user_id, device_id):
|
||||||
|
""" Count the number of one time keys the server has for a device
|
||||||
|
Returns:
|
||||||
|
Dict mapping from algorithm to number of keys for that algorithm.
|
||||||
|
"""
|
||||||
|
def _count_e2e_one_time_keys(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
|
||||||
|
" WHERE user_id = ? AND device_id = ?"
|
||||||
|
" GROUP BY algorithm"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id, device_id))
|
||||||
|
result = {}
|
||||||
|
for algorithm, key_count in txn.fetchall():
|
||||||
|
result[algorithm] = key_count
|
||||||
|
return result
|
||||||
|
return self.runInteraction(
|
||||||
|
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||||
|
)
|
||||||
|
|
||||||
|
def claim_e2e_one_time_keys(self, query_list):
|
||||||
|
"""Take a list of one time keys out of the database"""
|
||||||
|
def _claim_e2e_one_time_keys(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT key_id, key_json FROM e2e_one_time_keys_json"
|
||||||
|
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||||
|
" LIMIT 1"
|
||||||
|
)
|
||||||
|
result = {}
|
||||||
|
delete = []
|
||||||
|
for user_id, device_id, algorithm in query_list:
|
||||||
|
user_result = result.setdefault(user_id, {})
|
||||||
|
device_result = user_result.setdefault(device_id, {})
|
||||||
|
txn.execute(sql, (user_id, device_id, algorithm))
|
||||||
|
for key_id, key_json in txn.fetchall():
|
||||||
|
device_result[algorithm + ":" + key_id] = key_json
|
||||||
|
delete.append((user_id, device_id, algorithm, key_id))
|
||||||
|
sql = (
|
||||||
|
"DELETE FROM e2e_one_time_keys_json"
|
||||||
|
" WHERE user_id = ? AND device_id = ? AND algorithm = ?"
|
||||||
|
" AND key_id = ?"
|
||||||
|
)
|
||||||
|
for user_id, device_id, algorithm, key_id in delete:
|
||||||
|
txn.execute(sql, (user_id, device_id, algorithm, key_id))
|
||||||
|
return result
|
||||||
|
return self.runInteraction(
|
||||||
|
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
|
||||||
|
)
|
@ -15,7 +15,8 @@
|
|||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import cached
|
||||||
from syutil.base64util import encode_base64
|
from syutil.base64util import encode_base64
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -49,14 +50,22 @@ class EventFederationStore(SQLBaseStore):
|
|||||||
results = set()
|
results = set()
|
||||||
|
|
||||||
base_sql = (
|
base_sql = (
|
||||||
"SELECT auth_id FROM event_auth WHERE event_id = ?"
|
"SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
|
||||||
)
|
)
|
||||||
|
|
||||||
front = set(event_ids)
|
front = set(event_ids)
|
||||||
while front:
|
while front:
|
||||||
new_front = set()
|
new_front = set()
|
||||||
for f in front:
|
front_list = list(front)
|
||||||
txn.execute(base_sql, (f,))
|
chunks = [
|
||||||
|
front_list[x:x+100]
|
||||||
|
for x in xrange(0, len(front), 100)
|
||||||
|
]
|
||||||
|
for chunk in chunks:
|
||||||
|
txn.execute(
|
||||||
|
base_sql % (",".join(["?"] * len(chunk)),),
|
||||||
|
chunk
|
||||||
|
)
|
||||||
new_front.update([r[0] for r in txn.fetchall()])
|
new_front.update([r[0] for r in txn.fetchall()])
|
||||||
|
|
||||||
new_front -= results
|
new_front -= results
|
||||||
@ -274,8 +283,7 @@ class EventFederationStore(SQLBaseStore):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_prev_events(self, txn, outlier, event_id, prev_events,
|
def _handle_mult_prev_events(self, txn, events):
|
||||||
room_id):
|
|
||||||
"""
|
"""
|
||||||
For the given event, update the event edges table and forward and
|
For the given event, update the event edges table and forward and
|
||||||
backward extremities tables.
|
backward extremities tables.
|
||||||
@ -285,70 +293,83 @@ class EventFederationStore(SQLBaseStore):
|
|||||||
table="event_edges",
|
table="event_edges",
|
||||||
values=[
|
values=[
|
||||||
{
|
{
|
||||||
"event_id": event_id,
|
"event_id": ev.event_id,
|
||||||
"prev_event_id": e_id,
|
"prev_event_id": e_id,
|
||||||
"room_id": room_id,
|
"room_id": ev.room_id,
|
||||||
"is_state": False,
|
"is_state": False,
|
||||||
}
|
}
|
||||||
for e_id, _ in prev_events
|
for ev in events
|
||||||
|
for e_id, _ in ev.prev_events
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update the extremities table if this is not an outlier.
|
events_by_room = {}
|
||||||
if not outlier:
|
for ev in events:
|
||||||
for e_id, _ in prev_events:
|
events_by_room.setdefault(ev.room_id, []).append(ev)
|
||||||
# TODO (erikj): This could be done as a bulk insert
|
|
||||||
self._simple_delete_txn(
|
for room_id, room_events in events_by_room.items():
|
||||||
txn,
|
prevs = [
|
||||||
table="event_forward_extremities",
|
e_id for ev in room_events for e_id, _ in ev.prev_events
|
||||||
keyvalues={
|
if not ev.internal_metadata.is_outlier()
|
||||||
"event_id": e_id,
|
]
|
||||||
"room_id": room_id,
|
if prevs:
|
||||||
}
|
txn.execute(
|
||||||
|
"DELETE FROM event_forward_extremities"
|
||||||
|
" WHERE room_id = ?"
|
||||||
|
" AND event_id in (%s)" % (
|
||||||
|
",".join(["?"] * len(prevs)),
|
||||||
|
),
|
||||||
|
[room_id] + prevs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We only insert as a forward extremity the new event if there are
|
query = (
|
||||||
# no other events that reference it as a prev event
|
"INSERT INTO event_forward_extremities (event_id, room_id)"
|
||||||
query = (
|
" SELECT ?, ? WHERE NOT EXISTS ("
|
||||||
"SELECT 1 FROM event_edges WHERE prev_event_id = ?"
|
" SELECT 1 FROM event_edges WHERE prev_event_id = ?"
|
||||||
)
|
" )"
|
||||||
|
)
|
||||||
|
|
||||||
txn.execute(query, (event_id,))
|
txn.executemany(
|
||||||
|
query,
|
||||||
|
[
|
||||||
|
(ev.event_id, ev.room_id, ev.event_id) for ev in events
|
||||||
|
if not ev.internal_metadata.is_outlier()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
if not txn.fetchone():
|
query = (
|
||||||
query = (
|
"INSERT INTO event_backward_extremities (event_id, room_id)"
|
||||||
"INSERT INTO event_forward_extremities"
|
" SELECT ?, ? WHERE NOT EXISTS ("
|
||||||
" (event_id, room_id)"
|
" SELECT 1 FROM event_backward_extremities"
|
||||||
" VALUES (?, ?)"
|
" WHERE event_id = ? AND room_id = ?"
|
||||||
)
|
" )"
|
||||||
|
" AND NOT EXISTS ("
|
||||||
|
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
|
||||||
|
" AND outlier = ?"
|
||||||
|
" )"
|
||||||
|
)
|
||||||
|
|
||||||
txn.execute(query, (event_id, room_id))
|
txn.executemany(query, [
|
||||||
|
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
|
||||||
|
for ev in events for e_id, _ in ev.prev_events
|
||||||
|
if not ev.internal_metadata.is_outlier()
|
||||||
|
])
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
"INSERT INTO event_backward_extremities (event_id, room_id)"
|
"DELETE FROM event_backward_extremities"
|
||||||
" SELECT ?, ? WHERE NOT EXISTS ("
|
" WHERE event_id = ? AND room_id = ?"
|
||||||
" SELECT 1 FROM event_backward_extremities"
|
)
|
||||||
" WHERE event_id = ? AND room_id = ?"
|
txn.executemany(
|
||||||
" )"
|
query,
|
||||||
" AND NOT EXISTS ("
|
[
|
||||||
" SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
|
(ev.event_id, ev.room_id) for ev in events
|
||||||
" AND outlier = ?"
|
if not ev.internal_metadata.is_outlier()
|
||||||
" )"
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.executemany(query, [
|
|
||||||
(e_id, room_id, e_id, room_id, e_id, room_id, False)
|
|
||||||
for e_id, _ in prev_events
|
|
||||||
])
|
|
||||||
|
|
||||||
query = (
|
|
||||||
"DELETE FROM event_backward_extremities"
|
|
||||||
" WHERE event_id = ? AND room_id = ?"
|
|
||||||
)
|
|
||||||
txn.execute(query, (event_id, room_id))
|
|
||||||
|
|
||||||
|
for room_id in events_by_room:
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_latest_event_ids_in_room.invalidate, room_id
|
self.get_latest_event_ids_in_room.invalidate, (room_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_backfill_events(self, room_id, event_list, limit):
|
def get_backfill_events(self, room_id, event_list, limit):
|
||||||
@ -400,10 +421,12 @@ class EventFederationStore(SQLBaseStore):
|
|||||||
keyvalues={
|
keyvalues={
|
||||||
"event_id": event_id,
|
"event_id": event_id,
|
||||||
},
|
},
|
||||||
retcol="depth"
|
retcol="depth",
|
||||||
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
queue.put((-depth, event_id))
|
if depth:
|
||||||
|
queue.put((-depth, event_id))
|
||||||
|
|
||||||
while not queue.empty() and len(event_results) < limit:
|
while not queue.empty() and len(event_results) < limit:
|
||||||
try:
|
try:
|
||||||
@ -489,4 +512,4 @@ class EventFederationStore(SQLBaseStore):
|
|||||||
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
|
||||||
|
|
||||||
txn.execute(query, (room_id,))
|
txn.execute(query, (room_id,))
|
||||||
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)
|
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
|
||||||
|
@ -23,9 +23,7 @@ from synapse.events.utils import prune_event
|
|||||||
from synapse.util.logcontext import preserve_context_over_deferred
|
from synapse.util.logcontext import preserve_context_over_deferred
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
|
||||||
|
|
||||||
from syutil.base64util import decode_base64
|
|
||||||
from syutil.jsonutil import encode_json
|
from syutil.jsonutil import encode_json
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
@ -46,6 +44,48 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
|
|||||||
|
|
||||||
|
|
||||||
class EventsStore(SQLBaseStore):
|
class EventsStore(SQLBaseStore):
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def persist_events(self, events_and_contexts, backfilled=False,
|
||||||
|
is_new_state=True):
|
||||||
|
if not events_and_contexts:
|
||||||
|
return
|
||||||
|
|
||||||
|
if backfilled:
|
||||||
|
if not self.min_token_deferred.called:
|
||||||
|
yield self.min_token_deferred
|
||||||
|
start = self.min_token - 1
|
||||||
|
self.min_token -= len(events_and_contexts) + 1
|
||||||
|
stream_orderings = range(start, self.min_token, -1)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def stream_ordering_manager():
|
||||||
|
yield stream_orderings
|
||||||
|
stream_ordering_manager = stream_ordering_manager()
|
||||||
|
else:
|
||||||
|
stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
|
||||||
|
self, len(events_and_contexts)
|
||||||
|
)
|
||||||
|
|
||||||
|
with stream_ordering_manager as stream_orderings:
|
||||||
|
for (event, _), stream in zip(events_and_contexts, stream_orderings):
|
||||||
|
event.internal_metadata.stream_ordering = stream
|
||||||
|
|
||||||
|
chunks = [
|
||||||
|
events_and_contexts[x:x+100]
|
||||||
|
for x in xrange(0, len(events_and_contexts), 100)
|
||||||
|
]
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
# We can't easily parallelize these since different chunks
|
||||||
|
# might contain the same event. :(
|
||||||
|
yield self.runInteraction(
|
||||||
|
"persist_events",
|
||||||
|
self._persist_events_txn,
|
||||||
|
events_and_contexts=chunk,
|
||||||
|
backfilled=backfilled,
|
||||||
|
is_new_state=is_new_state,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def persist_event(self, event, context, backfilled=False,
|
def persist_event(self, event, context, backfilled=False,
|
||||||
@ -67,13 +107,13 @@ class EventsStore(SQLBaseStore):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
with stream_ordering_manager as stream_ordering:
|
with stream_ordering_manager as stream_ordering:
|
||||||
|
event.internal_metadata.stream_ordering = stream_ordering
|
||||||
yield self.runInteraction(
|
yield self.runInteraction(
|
||||||
"persist_event",
|
"persist_event",
|
||||||
self._persist_event_txn,
|
self._persist_event_txn,
|
||||||
event=event,
|
event=event,
|
||||||
context=context,
|
context=context,
|
||||||
backfilled=backfilled,
|
backfilled=backfilled,
|
||||||
stream_ordering=stream_ordering,
|
|
||||||
is_new_state=is_new_state,
|
is_new_state=is_new_state,
|
||||||
current_state=current_state,
|
current_state=current_state,
|
||||||
)
|
)
|
||||||
@ -116,19 +156,14 @@ class EventsStore(SQLBaseStore):
|
|||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def _persist_event_txn(self, txn, event, context, backfilled,
|
def _persist_event_txn(self, txn, event, context, backfilled,
|
||||||
stream_ordering=None, is_new_state=True,
|
is_new_state=True, current_state=None):
|
||||||
current_state=None):
|
|
||||||
|
|
||||||
# Remove the any existing cache entries for the event_id
|
|
||||||
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
|
||||||
|
|
||||||
# We purposefully do this first since if we include a `current_state`
|
# We purposefully do this first since if we include a `current_state`
|
||||||
# key, we *want* to update the `current_state_events` table
|
# key, we *want* to update the `current_state_events` table
|
||||||
if current_state:
|
if current_state:
|
||||||
txn.call_after(self.get_current_state_for_key.invalidate_all)
|
txn.call_after(self.get_current_state_for_key.invalidate_all)
|
||||||
txn.call_after(self.get_rooms_for_user.invalidate_all)
|
txn.call_after(self.get_rooms_for_user.invalidate_all)
|
||||||
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
|
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
|
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
||||||
txn.call_after(self.get_room_name_and_aliases, event.room_id)
|
txn.call_after(self.get_room_name_and_aliases, event.room_id)
|
||||||
|
|
||||||
self._simple_delete_txn(
|
self._simple_delete_txn(
|
||||||
@ -149,37 +184,78 @@ class EventsStore(SQLBaseStore):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
outlier = event.internal_metadata.is_outlier()
|
return self._persist_events_txn(
|
||||||
|
|
||||||
if not outlier:
|
|
||||||
self._update_min_depth_for_room_txn(
|
|
||||||
txn,
|
|
||||||
event.room_id,
|
|
||||||
event.depth
|
|
||||||
)
|
|
||||||
|
|
||||||
have_persisted = self._simple_select_one_txn(
|
|
||||||
txn,
|
txn,
|
||||||
table="events",
|
[(event, context)],
|
||||||
keyvalues={"event_id": event.event_id},
|
backfilled=backfilled,
|
||||||
retcols=["event_id", "outlier"],
|
is_new_state=is_new_state,
|
||||||
allow_none=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_json = encode_json(
|
@log_function
|
||||||
event.internal_metadata.get_dict(),
|
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
|
||||||
using_frozen_dicts=USE_FROZEN_DICTS
|
is_new_state=True):
|
||||||
).decode("UTF-8")
|
|
||||||
|
|
||||||
# If we have already persisted this event, we don't need to do any
|
# Remove the any existing cache entries for the event_ids
|
||||||
# more processing.
|
for event, _ in events_and_contexts:
|
||||||
# The processing above must be done on every call to persist event,
|
txn.call_after(self._invalidate_get_event_cache, event.event_id)
|
||||||
# since they might not have happened on previous calls. For example,
|
|
||||||
# if we are persisting an event that we had persisted as an outlier,
|
depth_updates = {}
|
||||||
# but is no longer one.
|
for event, _ in events_and_contexts:
|
||||||
if have_persisted:
|
if event.internal_metadata.is_outlier():
|
||||||
if not outlier and have_persisted["outlier"]:
|
continue
|
||||||
self._store_state_groups_txn(txn, event, context)
|
depth_updates[event.room_id] = max(
|
||||||
|
event.depth, depth_updates.get(event.room_id, event.depth)
|
||||||
|
)
|
||||||
|
|
||||||
|
for room_id, depth in depth_updates.items():
|
||||||
|
self._update_min_depth_for_room_txn(txn, room_id, depth)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
"SELECT event_id, outlier FROM events WHERE event_id in (%s)" % (
|
||||||
|
",".join(["?"] * len(events_and_contexts)),
|
||||||
|
),
|
||||||
|
[event.event_id for event, _ in events_and_contexts]
|
||||||
|
)
|
||||||
|
have_persisted = {
|
||||||
|
event_id: outlier
|
||||||
|
for event_id, outlier in txn.fetchall()
|
||||||
|
}
|
||||||
|
|
||||||
|
event_map = {}
|
||||||
|
to_remove = set()
|
||||||
|
for event, context in events_and_contexts:
|
||||||
|
# Handle the case of the list including the same event multiple
|
||||||
|
# times. The tricky thing here is when they differ by whether
|
||||||
|
# they are an outlier.
|
||||||
|
if event.event_id in event_map:
|
||||||
|
other = event_map[event.event_id]
|
||||||
|
|
||||||
|
if not other.internal_metadata.is_outlier():
|
||||||
|
to_remove.add(event)
|
||||||
|
continue
|
||||||
|
elif not event.internal_metadata.is_outlier():
|
||||||
|
to_remove.add(event)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
to_remove.add(other)
|
||||||
|
|
||||||
|
event_map[event.event_id] = event
|
||||||
|
|
||||||
|
if event.event_id not in have_persisted:
|
||||||
|
continue
|
||||||
|
|
||||||
|
to_remove.add(event)
|
||||||
|
|
||||||
|
outlier_persisted = have_persisted[event.event_id]
|
||||||
|
if not event.internal_metadata.is_outlier() and outlier_persisted:
|
||||||
|
self._store_state_groups_txn(
|
||||||
|
txn, event, context,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata_json = encode_json(
|
||||||
|
event.internal_metadata.get_dict(),
|
||||||
|
using_frozen_dicts=USE_FROZEN_DICTS
|
||||||
|
).decode("UTF-8")
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"UPDATE event_json SET internal_metadata = ?"
|
"UPDATE event_json SET internal_metadata = ?"
|
||||||
@ -198,94 +274,91 @@ class EventsStore(SQLBaseStore):
|
|||||||
sql,
|
sql,
|
||||||
(False, event.event_id,)
|
(False, event.event_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
events_and_contexts = filter(
|
||||||
|
lambda ec: ec[0] not in to_remove,
|
||||||
|
events_and_contexts
|
||||||
|
)
|
||||||
|
|
||||||
|
if not events_and_contexts:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not outlier:
|
self._store_mult_state_groups_txn(txn, [
|
||||||
self._store_state_groups_txn(txn, event, context)
|
(event, context)
|
||||||
|
for event, context in events_and_contexts
|
||||||
|
if not event.internal_metadata.is_outlier()
|
||||||
|
])
|
||||||
|
|
||||||
self._handle_prev_events(
|
self._handle_mult_prev_events(
|
||||||
txn,
|
txn,
|
||||||
outlier=outlier,
|
events=[event for event, _ in events_and_contexts],
|
||||||
event_id=event.event_id,
|
|
||||||
prev_events=event.prev_events,
|
|
||||||
room_id=event.room_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
for event, _ in events_and_contexts:
|
||||||
self._store_room_member_txn(txn, event)
|
if event.type == EventTypes.Name:
|
||||||
elif event.type == EventTypes.Name:
|
self._store_room_name_txn(txn, event)
|
||||||
self._store_room_name_txn(txn, event)
|
elif event.type == EventTypes.Topic:
|
||||||
elif event.type == EventTypes.Topic:
|
self._store_room_topic_txn(txn, event)
|
||||||
self._store_room_topic_txn(txn, event)
|
elif event.type == EventTypes.Redaction:
|
||||||
elif event.type == EventTypes.Redaction:
|
self._store_redaction(txn, event)
|
||||||
self._store_redaction(txn, event)
|
|
||||||
|
|
||||||
event_dict = {
|
self._store_room_members_txn(
|
||||||
k: v
|
txn,
|
||||||
for k, v in event.get_dict().items()
|
[
|
||||||
if k not in [
|
event
|
||||||
"redacted",
|
for event, _ in events_and_contexts
|
||||||
"redacted_because",
|
if event.type == EventTypes.Member
|
||||||
]
|
]
|
||||||
}
|
)
|
||||||
|
|
||||||
self._simple_insert_txn(
|
def event_dict(event):
|
||||||
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in event.get_dict().items()
|
||||||
|
if k not in [
|
||||||
|
"redacted",
|
||||||
|
"redacted_because",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_json",
|
table="event_json",
|
||||||
values={
|
values=[
|
||||||
"event_id": event.event_id,
|
{
|
||||||
"room_id": event.room_id,
|
"event_id": event.event_id,
|
||||||
"internal_metadata": metadata_json,
|
"room_id": event.room_id,
|
||||||
"json": encode_json(
|
"internal_metadata": encode_json(
|
||||||
event_dict, using_frozen_dicts=USE_FROZEN_DICTS
|
event.internal_metadata.get_dict(),
|
||||||
).decode("UTF-8"),
|
using_frozen_dicts=USE_FROZEN_DICTS
|
||||||
},
|
).decode("UTF-8"),
|
||||||
|
"json": encode_json(
|
||||||
|
event_dict(event), using_frozen_dicts=USE_FROZEN_DICTS
|
||||||
|
).decode("UTF-8"),
|
||||||
|
}
|
||||||
|
for event, _ in events_and_contexts
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
content = encode_json(
|
self._simple_insert_many_txn(
|
||||||
event.content, using_frozen_dicts=USE_FROZEN_DICTS
|
txn,
|
||||||
).decode("UTF-8")
|
table="events",
|
||||||
|
values=[
|
||||||
vals = {
|
{
|
||||||
"topological_ordering": event.depth,
|
"stream_ordering": event.internal_metadata.stream_ordering,
|
||||||
"event_id": event.event_id,
|
"topological_ordering": event.depth,
|
||||||
"type": event.type,
|
"depth": event.depth,
|
||||||
"room_id": event.room_id,
|
"event_id": event.event_id,
|
||||||
"content": content,
|
"room_id": event.room_id,
|
||||||
"processed": True,
|
"type": event.type,
|
||||||
"outlier": outlier,
|
"processed": True,
|
||||||
"depth": event.depth,
|
"outlier": event.internal_metadata.is_outlier(),
|
||||||
}
|
"content": encode_json(
|
||||||
|
event.content, using_frozen_dicts=USE_FROZEN_DICTS
|
||||||
unrec = {
|
).decode("UTF-8"),
|
||||||
k: v
|
}
|
||||||
for k, v in event.get_dict().items()
|
for event, _ in events_and_contexts
|
||||||
if k not in vals.keys() and k not in [
|
],
|
||||||
"redacted",
|
|
||||||
"redacted_because",
|
|
||||||
"signatures",
|
|
||||||
"hashes",
|
|
||||||
"prev_events",
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
vals["unrecognized_keys"] = encode_json(
|
|
||||||
unrec, using_frozen_dicts=USE_FROZEN_DICTS
|
|
||||||
).decode("UTF-8")
|
|
||||||
|
|
||||||
sql = (
|
|
||||||
"INSERT INTO events"
|
|
||||||
" (stream_ordering, topological_ordering, event_id, type,"
|
|
||||||
" room_id, content, processed, outlier, depth)"
|
|
||||||
" VALUES (?,?,?,?,?,?,?,?,?)"
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.execute(
|
|
||||||
sql,
|
|
||||||
(
|
|
||||||
stream_ordering, event.depth, event.event_id, event.type,
|
|
||||||
event.room_id, content, True, outlier, event.depth
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if context.rejected:
|
if context.rejected:
|
||||||
@ -293,20 +366,6 @@ class EventsStore(SQLBaseStore):
|
|||||||
txn, event.event_id, context.rejected
|
txn, event.event_id, context.rejected
|
||||||
)
|
)
|
||||||
|
|
||||||
for hash_alg, hash_base64 in event.hashes.items():
|
|
||||||
hash_bytes = decode_base64(hash_base64)
|
|
||||||
self._store_event_content_hash_txn(
|
|
||||||
txn, event.event_id, hash_alg, hash_bytes,
|
|
||||||
)
|
|
||||||
|
|
||||||
for prev_event_id, prev_hashes in event.prev_events:
|
|
||||||
for alg, hash_base64 in prev_hashes.items():
|
|
||||||
hash_bytes = decode_base64(hash_base64)
|
|
||||||
self._store_prev_event_hash_txn(
|
|
||||||
txn, event.event_id, prev_event_id, alg,
|
|
||||||
hash_bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_auth",
|
table="event_auth",
|
||||||
@ -316,16 +375,22 @@ class EventsStore(SQLBaseStore):
|
|||||||
"room_id": event.room_id,
|
"room_id": event.room_id,
|
||||||
"auth_id": auth_id,
|
"auth_id": auth_id,
|
||||||
}
|
}
|
||||||
|
for event, _ in events_and_contexts
|
||||||
for auth_id, _ in event.auth_events
|
for auth_id, _ in event.auth_events
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
|
self._store_event_reference_hashes_txn(
|
||||||
self._store_event_reference_hash_txn(
|
txn, [event for event, _ in events_and_contexts]
|
||||||
txn, event.event_id, ref_alg, ref_hash_bytes
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.is_state():
|
state_events_and_contexts = filter(
|
||||||
|
lambda i: i[0].is_state(),
|
||||||
|
events_and_contexts,
|
||||||
|
)
|
||||||
|
|
||||||
|
state_values = []
|
||||||
|
for event, context in state_events_and_contexts:
|
||||||
vals = {
|
vals = {
|
||||||
"event_id": event.event_id,
|
"event_id": event.event_id,
|
||||||
"room_id": event.room_id,
|
"room_id": event.room_id,
|
||||||
@ -337,51 +402,55 @@ class EventsStore(SQLBaseStore):
|
|||||||
if hasattr(event, "replaces_state"):
|
if hasattr(event, "replaces_state"):
|
||||||
vals["prev_state"] = event.replaces_state
|
vals["prev_state"] = event.replaces_state
|
||||||
|
|
||||||
self._simple_insert_txn(
|
state_values.append(vals)
|
||||||
txn,
|
|
||||||
"state_events",
|
|
||||||
vals,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_edges",
|
table="state_events",
|
||||||
values=[
|
values=state_values,
|
||||||
{
|
)
|
||||||
"event_id": event.event_id,
|
|
||||||
"prev_event_id": e_id,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"is_state": True,
|
|
||||||
}
|
|
||||||
for e_id, h in event.prev_state
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_new_state and not context.rejected:
|
self._simple_insert_many_txn(
|
||||||
txn.call_after(
|
txn,
|
||||||
self.get_current_state_for_key.invalidate,
|
table="event_edges",
|
||||||
event.room_id, event.type, event.state_key
|
values=[
|
||||||
)
|
{
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"prev_event_id": prev_id,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"is_state": True,
|
||||||
|
}
|
||||||
|
for event, _ in state_events_and_contexts
|
||||||
|
for prev_id, _ in event.prev_state
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
if (event.type == EventTypes.Name
|
if is_new_state:
|
||||||
or event.type == EventTypes.Aliases):
|
for event, _ in state_events_and_contexts:
|
||||||
|
if not context.rejected:
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_room_name_and_aliases.invalidate,
|
self.get_current_state_for_key.invalidate,
|
||||||
event.room_id
|
(event.room_id, event.type, event.state_key,)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._simple_upsert_txn(
|
if event.type in [EventTypes.Name, EventTypes.Aliases]:
|
||||||
txn,
|
txn.call_after(
|
||||||
"current_state_events",
|
self.get_room_name_and_aliases.invalidate,
|
||||||
keyvalues={
|
(event.room_id,)
|
||||||
"room_id": event.room_id,
|
)
|
||||||
"type": event.type,
|
|
||||||
"state_key": event.state_key,
|
self._simple_upsert_txn(
|
||||||
},
|
txn,
|
||||||
values={
|
"current_state_events",
|
||||||
"event_id": event.event_id,
|
keyvalues={
|
||||||
}
|
"room_id": event.room_id,
|
||||||
)
|
"type": event.type,
|
||||||
|
"state_key": event.state_key,
|
||||||
|
},
|
||||||
|
values={
|
||||||
|
"event_id": event.event_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -498,8 +567,9 @@ class EventsStore(SQLBaseStore):
|
|||||||
def _invalidate_get_event_cache(self, event_id):
|
def _invalidate_get_event_cache(self, event_id):
|
||||||
for check_redacted in (False, True):
|
for check_redacted in (False, True):
|
||||||
for get_prev_content in (False, True):
|
for get_prev_content in (False, True):
|
||||||
self._get_event_cache.invalidate(event_id, check_redacted,
|
self._get_event_cache.invalidate(
|
||||||
get_prev_content)
|
(event_id, check_redacted, get_prev_content)
|
||||||
|
)
|
||||||
|
|
||||||
def _get_event_txn(self, txn, event_id, check_redacted=True,
|
def _get_event_txn(self, txn, event_id, check_redacted=True,
|
||||||
get_prev_content=False, allow_rejected=False):
|
get_prev_content=False, allow_rejected=False):
|
||||||
@ -520,7 +590,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
for event_id in events:
|
for event_id in events:
|
||||||
try:
|
try:
|
||||||
ret = self._get_event_cache.get(
|
ret = self._get_event_cache.get(
|
||||||
event_id, check_redacted, get_prev_content
|
(event_id, check_redacted, get_prev_content,)
|
||||||
)
|
)
|
||||||
|
|
||||||
if allow_rejected or not ret.rejected_reason:
|
if allow_rejected or not ret.rejected_reason:
|
||||||
@ -741,6 +811,8 @@ class EventsStore(SQLBaseStore):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if because:
|
if because:
|
||||||
|
# It's fine to do add the event directly, since get_pdu_json
|
||||||
|
# will serialise this field correctly
|
||||||
ev.unsigned["redacted_because"] = because
|
ev.unsigned["redacted_because"] = because
|
||||||
|
|
||||||
if get_prev_content and "replaces_state" in ev.unsigned:
|
if get_prev_content and "replaces_state" in ev.unsigned:
|
||||||
@ -753,7 +825,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
||||||
|
|
||||||
self._get_event_cache.prefill(
|
self._get_event_cache.prefill(
|
||||||
ev.event_id, check_redacted, get_prev_content, ev
|
(ev.event_id, check_redacted, get_prev_content), ev
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue(ev)
|
defer.returnValue(ev)
|
||||||
@ -810,7 +882,7 @@ class EventsStore(SQLBaseStore):
|
|||||||
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
ev.unsigned["prev_content"] = prev.get_dict()["content"]
|
||||||
|
|
||||||
self._get_event_cache.prefill(
|
self._get_event_cache.prefill(
|
||||||
ev.event_id, check_redacted, get_prev_content, ev
|
(ev.event_id, check_redacted, get_prev_content), ev
|
||||||
)
|
)
|
||||||
|
|
||||||
return ev
|
return ev
|
||||||
|
@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from _base import SQLBaseStore
|
from _base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -71,6 +72,24 @@ class KeyStore(SQLBaseStore):
|
|||||||
desc="store_server_certificate",
|
desc="store_server_certificate",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks()
|
||||||
|
def get_all_server_verify_keys(self, server_name):
|
||||||
|
rows = yield self._simple_select_list(
|
||||||
|
table="server_signature_keys",
|
||||||
|
keyvalues={
|
||||||
|
"server_name": server_name,
|
||||||
|
},
|
||||||
|
retcols=["key_id", "verify_key"],
|
||||||
|
desc="get_all_server_verify_keys",
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue({
|
||||||
|
row["key_id"]: decode_verify_key_bytes(
|
||||||
|
row["key_id"], str(row["verify_key"])
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_server_verify_keys(self, server_name, key_ids):
|
def get_server_verify_keys(self, server_name, key_ids):
|
||||||
"""Retrieve the NACL verification key for a given server for the given
|
"""Retrieve the NACL verification key for a given server for the given
|
||||||
@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore):
|
|||||||
Returns:
|
Returns:
|
||||||
(list of VerifyKey): The verification keys.
|
(list of VerifyKey): The verification keys.
|
||||||
"""
|
"""
|
||||||
sql = (
|
keys = yield self.get_all_server_verify_keys(server_name)
|
||||||
"SELECT key_id, verify_key FROM server_signature_keys"
|
defer.returnValue({
|
||||||
" WHERE server_name = ?"
|
k: keys[k]
|
||||||
" AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
|
for k in key_ids
|
||||||
)
|
if k in keys and keys[k]
|
||||||
|
})
|
||||||
rows = yield self._execute_and_decode(
|
|
||||||
"get_server_verify_keys", sql, server_name, *key_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
keys = []
|
|
||||||
for row in rows:
|
|
||||||
key_id = row["key_id"]
|
|
||||||
key_bytes = row["verify_key"]
|
|
||||||
key = decode_verify_key_bytes(key_id, str(key_bytes))
|
|
||||||
keys.append(key)
|
|
||||||
defer.returnValue(keys)
|
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def store_server_verify_key(self, server_name, from_server, time_now_ms,
|
def store_server_verify_key(self, server_name, from_server, time_now_ms,
|
||||||
verify_key):
|
verify_key):
|
||||||
"""Stores a NACL verification key for the given server.
|
"""Stores a NACL verification key for the given server.
|
||||||
@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore):
|
|||||||
ts_now_ms (int): The time now in milliseconds
|
ts_now_ms (int): The time now in milliseconds
|
||||||
verification_key (VerifyKey): The NACL verify key.
|
verification_key (VerifyKey): The NACL verify key.
|
||||||
"""
|
"""
|
||||||
return self._simple_upsert(
|
yield self._simple_upsert(
|
||||||
table="server_signature_keys",
|
table="server_signature_keys",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"server_name": server_name,
|
"server_name": server_name,
|
||||||
@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore):
|
|||||||
desc="store_server_verify_key",
|
desc="store_server_verify_key",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.get_all_server_verify_keys.invalidate((server_name,))
|
||||||
|
|
||||||
def store_server_keys_json(self, server_name, key_id, from_server,
|
def store_server_keys_json(self, server_name, key_id, from_server,
|
||||||
ts_now_ms, ts_expires_ms, key_json_bytes):
|
ts_now_ms, ts_expires_ms, key_json_bytes):
|
||||||
"""Stores the JSON bytes for a set of keys from a server
|
"""Stores the JSON bytes for a set of keys from a server
|
||||||
@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore):
|
|||||||
"ts_valid_until_ms": ts_expires_ms,
|
"ts_valid_until_ms": ts_expires_ms,
|
||||||
"key_json": buffer(key_json_bytes),
|
"key_json": buffer(key_json_bytes),
|
||||||
},
|
},
|
||||||
|
desc="store_server_keys_json",
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_server_keys_json(self, server_keys):
|
def get_server_keys_json(self, server_keys):
|
||||||
|
@ -13,19 +13,23 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
|
||||||
class PresenceStore(SQLBaseStore):
|
class PresenceStore(SQLBaseStore):
|
||||||
def create_presence(self, user_localpart):
|
def create_presence(self, user_localpart):
|
||||||
return self._simple_insert(
|
res = self._simple_insert(
|
||||||
table="presence",
|
table="presence",
|
||||||
values={"user_id": user_localpart},
|
values={"user_id": user_localpart},
|
||||||
desc="create_presence",
|
desc="create_presence",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.get_presence_state.invalidate((user_localpart,))
|
||||||
|
return res
|
||||||
|
|
||||||
def has_presence_state(self, user_localpart):
|
def has_presence_state(self, user_localpart):
|
||||||
return self._simple_select_one(
|
return self._simple_select_one(
|
||||||
table="presence",
|
table="presence",
|
||||||
@ -35,6 +39,7 @@ class PresenceStore(SQLBaseStore):
|
|||||||
desc="has_presence_state",
|
desc="has_presence_state",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cached(max_entries=2000)
|
||||||
def get_presence_state(self, user_localpart):
|
def get_presence_state(self, user_localpart):
|
||||||
return self._simple_select_one(
|
return self._simple_select_one(
|
||||||
table="presence",
|
table="presence",
|
||||||
@ -43,8 +48,27 @@ class PresenceStore(SQLBaseStore):
|
|||||||
desc="get_presence_state",
|
desc="get_presence_state",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@cachedList(get_presence_state.cache, list_name="user_localparts")
|
||||||
|
def get_presence_states(self, user_localparts):
|
||||||
|
def f(txn):
|
||||||
|
results = {}
|
||||||
|
for user_localpart in user_localparts:
|
||||||
|
res = self._simple_select_one_txn(
|
||||||
|
txn,
|
||||||
|
table="presence",
|
||||||
|
keyvalues={"user_id": user_localpart},
|
||||||
|
retcols=["state", "status_msg", "mtime"],
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
if res:
|
||||||
|
results[user_localpart] = res
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
return self.runInteraction("get_presence_states", f)
|
||||||
|
|
||||||
def set_presence_state(self, user_localpart, new_state):
|
def set_presence_state(self, user_localpart, new_state):
|
||||||
return self._simple_update_one(
|
res = self._simple_update_one(
|
||||||
table="presence",
|
table="presence",
|
||||||
keyvalues={"user_id": user_localpart},
|
keyvalues={"user_id": user_localpart},
|
||||||
updatevalues={"state": new_state["state"],
|
updatevalues={"state": new_state["state"],
|
||||||
@ -53,6 +77,9 @@ class PresenceStore(SQLBaseStore):
|
|||||||
desc="set_presence_state",
|
desc="set_presence_state",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.get_presence_state.invalidate((user_localpart,))
|
||||||
|
return res
|
||||||
|
|
||||||
def allow_presence_visible(self, observed_localpart, observer_userid):
|
def allow_presence_visible(self, observed_localpart, observer_userid):
|
||||||
return self._simple_insert(
|
return self._simple_insert(
|
||||||
table="presence_allow_inbound",
|
table="presence_allow_inbound",
|
||||||
@ -98,7 +125,7 @@ class PresenceStore(SQLBaseStore):
|
|||||||
updatevalues={"accepted": True},
|
updatevalues={"accepted": True},
|
||||||
desc="set_presence_list_accepted",
|
desc="set_presence_list_accepted",
|
||||||
)
|
)
|
||||||
self.get_presence_list_accepted.invalidate(observer_localpart)
|
self.get_presence_list_accepted.invalidate((observer_localpart,))
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
def get_presence_list(self, observer_localpart, accepted=None):
|
def get_presence_list(self, observer_localpart, accepted=None):
|
||||||
@ -133,4 +160,4 @@ class PresenceStore(SQLBaseStore):
|
|||||||
"observed_user_id": observed_userid},
|
"observed_user_id": observed_userid},
|
||||||
desc="del_presence_list",
|
desc="del_presence_list",
|
||||||
)
|
)
|
||||||
self.get_presence_list_accepted.invalidate(observer_localpart)
|
self.get_presence_list_accepted.invalidate((observer_localpart,))
|
||||||
|
@ -13,7 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@ -23,8 +24,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class PushRuleStore(SQLBaseStore):
|
class PushRuleStore(SQLBaseStore):
|
||||||
@cached()
|
@cachedInlineCallbacks()
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_push_rules_for_user(self, user_name):
|
def get_push_rules_for_user(self, user_name):
|
||||||
rows = yield self._simple_select_list(
|
rows = yield self._simple_select_list(
|
||||||
table=PushRuleTable.table_name,
|
table=PushRuleTable.table_name,
|
||||||
@ -41,8 +41,7 @@ class PushRuleStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue(rows)
|
defer.returnValue(rows)
|
||||||
|
|
||||||
@cached()
|
@cachedInlineCallbacks()
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_push_rules_enabled_for_user(self, user_name):
|
def get_push_rules_enabled_for_user(self, user_name):
|
||||||
results = yield self._simple_select_list(
|
results = yield self._simple_select_list(
|
||||||
table=PushRuleEnableTable.table_name,
|
table=PushRuleEnableTable.table_name,
|
||||||
@ -153,11 +152,11 @@ class PushRuleStore(SQLBaseStore):
|
|||||||
txn.execute(sql, (user_name, priority_class, new_rule_priority))
|
txn.execute(sql, (user_name, priority_class, new_rule_priority))
|
||||||
|
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_push_rules_for_user.invalidate, user_name
|
self.get_push_rules_for_user.invalidate, (user_name,)
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_push_rules_enabled_for_user.invalidate, user_name
|
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
@ -189,10 +188,10 @@ class PushRuleStore(SQLBaseStore):
|
|||||||
new_rule['priority'] = new_prio
|
new_rule['priority'] = new_prio
|
||||||
|
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_push_rules_for_user.invalidate, user_name
|
self.get_push_rules_for_user.invalidate, (user_name,)
|
||||||
)
|
)
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_push_rules_enabled_for_user.invalidate, user_name
|
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
|
||||||
)
|
)
|
||||||
|
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
@ -218,8 +217,8 @@ class PushRuleStore(SQLBaseStore):
|
|||||||
desc="delete_push_rule",
|
desc="delete_push_rule",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.get_push_rules_for_user.invalidate(user_name)
|
self.get_push_rules_for_user.invalidate((user_name,))
|
||||||
self.get_push_rules_enabled_for_user.invalidate(user_name)
|
self.get_push_rules_enabled_for_user.invalidate((user_name,))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_push_rule_enabled(self, user_name, rule_id, enabled):
|
def set_push_rule_enabled(self, user_name, rule_id, enabled):
|
||||||
@ -240,10 +239,10 @@ class PushRuleStore(SQLBaseStore):
|
|||||||
{'id': new_id},
|
{'id': new_id},
|
||||||
)
|
)
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_push_rules_for_user.invalidate, user_name
|
self.get_push_rules_for_user.invalidate, (user_name,)
|
||||||
)
|
)
|
||||||
txn.call_after(
|
txn.call_after(
|
||||||
self.get_push_rules_enabled_for_user.invalidate, user_name
|
self.get_push_rules_enabled_for_user.invalidate, (user_name,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
406
synapse/storage/receipts.py
Normal file
406
synapse/storage/receipts.py
Normal file
@ -0,0 +1,406 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2014, 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
|
||||||
|
from synapse.util.caches import cache_counter, caches_by_name
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from blist import sorteddict
|
||||||
|
import logging
|
||||||
|
import ujson as json
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ReceiptsStore(SQLBaseStore):
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(ReceiptsStore, self).__init__(hs)
|
||||||
|
|
||||||
|
self._receipts_stream_cache = _RoomStreamChangeCache()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||||
|
"""Get receipts for multiple rooms for sending to clients.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_ids (list): List of room_ids.
|
||||||
|
to_key (int): Max stream id to fetch receipts upto.
|
||||||
|
from_key (int): Min stream id to fetch receipts from. None fetches
|
||||||
|
from the start.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list of receipts.
|
||||||
|
"""
|
||||||
|
room_ids = set(room_ids)
|
||||||
|
|
||||||
|
if from_key:
|
||||||
|
room_ids = yield self._receipts_stream_cache.get_rooms_changed(
|
||||||
|
self, room_ids, from_key
|
||||||
|
)
|
||||||
|
|
||||||
|
results = yield self._get_linearized_receipts_for_rooms(
|
||||||
|
room_ids, to_key, from_key=from_key
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue([ev for res in results.values() for ev in res])
|
||||||
|
|
||||||
|
@cachedInlineCallbacks(num_args=3, max_entries=5000)
|
||||||
|
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
|
||||||
|
"""Get receipts for a single room for sending to clients.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_ids (str): The room id.
|
||||||
|
to_key (int): Max stream id to fetch receipts upto.
|
||||||
|
from_key (int): Min stream id to fetch receipts from. None fetches
|
||||||
|
from the start.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list of receipts.
|
||||||
|
"""
|
||||||
|
def f(txn):
|
||||||
|
if from_key:
|
||||||
|
sql = (
|
||||||
|
"SELECT * FROM receipts_linearized WHERE"
|
||||||
|
" room_id = ? AND stream_id > ? AND stream_id <= ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(room_id, from_key, to_key)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sql = (
|
||||||
|
"SELECT * FROM receipts_linearized WHERE"
|
||||||
|
" room_id = ? AND stream_id <= ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(
|
||||||
|
sql,
|
||||||
|
(room_id, to_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = self.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
return rows
|
||||||
|
|
||||||
|
rows = yield self.runInteraction(
|
||||||
|
"get_linearized_receipts_for_room", f
|
||||||
|
)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
defer.returnValue([])
|
||||||
|
|
||||||
|
content = {}
|
||||||
|
for row in rows:
|
||||||
|
content.setdefault(
|
||||||
|
row["event_id"], {}
|
||||||
|
).setdefault(
|
||||||
|
row["receipt_type"], {}
|
||||||
|
)[row["user_id"]] = json.loads(row["data"])
|
||||||
|
|
||||||
|
defer.returnValue([{
|
||||||
|
"type": "m.receipt",
|
||||||
|
"room_id": room_id,
|
||||||
|
"content": content,
|
||||||
|
}])
|
||||||
|
|
||||||
|
@cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids",
|
||||||
|
num_args=3, inlineCallbacks=True)
|
||||||
|
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||||
|
if not room_ids:
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
|
def f(txn):
|
||||||
|
if from_key:
|
||||||
|
sql = (
|
||||||
|
"SELECT * FROM receipts_linearized WHERE"
|
||||||
|
" room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
|
||||||
|
) % (
|
||||||
|
",".join(["?"] * len(room_ids))
|
||||||
|
)
|
||||||
|
args = list(room_ids)
|
||||||
|
args.extend([from_key, to_key])
|
||||||
|
|
||||||
|
txn.execute(sql, args)
|
||||||
|
else:
|
||||||
|
sql = (
|
||||||
|
"SELECT * FROM receipts_linearized WHERE"
|
||||||
|
" room_id IN (%s) AND stream_id <= ?"
|
||||||
|
) % (
|
||||||
|
",".join(["?"] * len(room_ids))
|
||||||
|
)
|
||||||
|
|
||||||
|
args = list(room_ids)
|
||||||
|
args.append(to_key)
|
||||||
|
|
||||||
|
txn.execute(sql, args)
|
||||||
|
|
||||||
|
return self.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
txn_results = yield self.runInteraction(
|
||||||
|
"_get_linearized_receipts_for_rooms", f
|
||||||
|
)
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
for row in txn_results:
|
||||||
|
# We want a single event per room, since we want to batch the
|
||||||
|
# receipts by room, event and type.
|
||||||
|
room_event = results.setdefault(row["room_id"], {
|
||||||
|
"type": "m.receipt",
|
||||||
|
"room_id": row["room_id"],
|
||||||
|
"content": {},
|
||||||
|
})
|
||||||
|
|
||||||
|
# The content is of the form:
|
||||||
|
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
|
||||||
|
event_entry = room_event["content"].setdefault(row["event_id"], {})
|
||||||
|
receipt_type = event_entry.setdefault(row["receipt_type"], {})
|
||||||
|
|
||||||
|
receipt_type[row["user_id"]] = json.loads(row["data"])
|
||||||
|
|
||||||
|
results = {
|
||||||
|
room_id: [results[room_id]] if room_id in results else []
|
||||||
|
for room_id in room_ids
|
||||||
|
}
|
||||||
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
def get_max_receipt_stream_id(self):
|
||||||
|
return self._receipts_id_gen.get_max_token(self)
|
||||||
|
|
||||||
|
@cachedInlineCallbacks()
|
||||||
|
def get_graph_receipts_for_room(self, room_id):
|
||||||
|
"""Get receipts for sending to remote servers.
|
||||||
|
"""
|
||||||
|
rows = yield self._simple_select_list(
|
||||||
|
table="receipts_graph",
|
||||||
|
keyvalues={"room_id": room_id},
|
||||||
|
retcols=["receipt_type", "user_id", "event_id"],
|
||||||
|
desc="get_linearized_receipts_for_room",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for row in rows:
|
||||||
|
result.setdefault(
|
||||||
|
row["user_id"], {}
|
||||||
|
).setdefault(
|
||||||
|
row["receipt_type"], []
|
||||||
|
).append(row["event_id"])
|
||||||
|
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
|
||||||
|
user_id, event_id, data, stream_id):
|
||||||
|
|
||||||
|
# We don't want to clobber receipts for more recent events, so we
|
||||||
|
# have to compare orderings of existing receipts
|
||||||
|
sql = (
|
||||||
|
"SELECT topological_ordering, stream_ordering, event_id FROM events"
|
||||||
|
" INNER JOIN receipts_linearized as r USING (event_id, room_id)"
|
||||||
|
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.execute(sql, (room_id, receipt_type, user_id))
|
||||||
|
results = txn.fetchall()
|
||||||
|
|
||||||
|
if results:
|
||||||
|
res = self._simple_select_one_txn(
|
||||||
|
txn,
|
||||||
|
table="events",
|
||||||
|
retcols=["topological_ordering", "stream_ordering"],
|
||||||
|
keyvalues={"event_id": event_id},
|
||||||
|
)
|
||||||
|
topological_ordering = int(res["topological_ordering"])
|
||||||
|
stream_ordering = int(res["stream_ordering"])
|
||||||
|
|
||||||
|
for to, so, _ in results:
|
||||||
|
if int(to) > topological_ordering:
|
||||||
|
return False
|
||||||
|
elif int(to) == topological_ordering and int(so) >= stream_ordering:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="receipts_linearized",
|
||||||
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
"receipt_type": receipt_type,
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="receipts_linearized",
|
||||||
|
values={
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
"receipt_type": receipt_type,
|
||||||
|
"user_id": user_id,
|
||||||
|
"event_id": event_id,
|
||||||
|
"data": json.dumps(data),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
|
||||||
|
"""Insert a receipt, either from local client or remote server.
|
||||||
|
|
||||||
|
Automatically does conversion between linearized and graph
|
||||||
|
representations.
|
||||||
|
"""
|
||||||
|
if not event_ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(event_ids) == 1:
|
||||||
|
linearized_event_id = event_ids[0]
|
||||||
|
else:
|
||||||
|
# we need to points in graph -> linearized form.
|
||||||
|
# TODO: Make this better.
|
||||||
|
def graph_to_linear(txn):
|
||||||
|
query = (
|
||||||
|
"SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
|
||||||
|
" SELECT max(stream_ordering) WHERE event_id IN (%s)"
|
||||||
|
")"
|
||||||
|
) % (",".join(["?"] * len(event_ids)))
|
||||||
|
|
||||||
|
txn.execute(query, [room_id] + event_ids)
|
||||||
|
rows = txn.fetchall()
|
||||||
|
if rows:
|
||||||
|
return rows[0][0]
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
|
||||||
|
|
||||||
|
linearized_event_id = yield self.runInteraction(
|
||||||
|
"insert_receipt_conv", graph_to_linear
|
||||||
|
)
|
||||||
|
|
||||||
|
stream_id_manager = yield self._receipts_id_gen.get_next(self)
|
||||||
|
with stream_id_manager as stream_id:
|
||||||
|
yield self._receipts_stream_cache.room_has_changed(
|
||||||
|
self, room_id, stream_id
|
||||||
|
)
|
||||||
|
have_persisted = yield self.runInteraction(
|
||||||
|
"insert_linearized_receipt",
|
||||||
|
self.insert_linearized_receipt_txn,
|
||||||
|
room_id, receipt_type, user_id, linearized_event_id,
|
||||||
|
data,
|
||||||
|
stream_id=stream_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not have_persisted:
|
||||||
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
yield self.insert_graph_receipt(
|
||||||
|
room_id, receipt_type, user_id, event_ids, data
|
||||||
|
)
|
||||||
|
|
||||||
|
max_persisted_id = yield self._stream_id_gen.get_max_token(self)
|
||||||
|
defer.returnValue((stream_id, max_persisted_id))
|
||||||
|
|
||||||
|
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
|
||||||
|
data):
|
||||||
|
return self.runInteraction(
|
||||||
|
"insert_graph_receipt",
|
||||||
|
self.insert_graph_receipt_txn,
|
||||||
|
room_id, receipt_type, user_id, event_ids, data
|
||||||
|
)
|
||||||
|
|
||||||
|
def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
|
||||||
|
user_id, event_ids, data):
|
||||||
|
self._simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="receipts_graph",
|
||||||
|
keyvalues={
|
||||||
|
"room_id": room_id,
|
||||||
|
"receipt_type": receipt_type,
|
||||||
|
"user_id": user_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="receipts_graph",
|
||||||
|
values={
|
||||||
|
"room_id": room_id,
|
||||||
|
"receipt_type": receipt_type,
|
||||||
|
"user_id": user_id,
|
||||||
|
"event_ids": json.dumps(event_ids),
|
||||||
|
"data": json.dumps(data),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _RoomStreamChangeCache(object):
|
||||||
|
"""Keeps track of the stream_id of the latest change in rooms.
|
||||||
|
|
||||||
|
Given a list of rooms and stream key, it will give a subset of rooms that
|
||||||
|
may have changed since that key. If the key is too old then the cache
|
||||||
|
will simply return all rooms.
|
||||||
|
"""
|
||||||
|
def __init__(self, size_of_cache=10000):
|
||||||
|
self._size_of_cache = size_of_cache
|
||||||
|
self._room_to_key = {}
|
||||||
|
self._cache = sorteddict()
|
||||||
|
self._earliest_key = None
|
||||||
|
self.name = "ReceiptsRoomChangeCache"
|
||||||
|
caches_by_name[self.name] = self._cache
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_rooms_changed(self, store, room_ids, key):
|
||||||
|
"""Returns subset of room ids that have had new receipts since the
|
||||||
|
given key. If the key is too old it will just return the given list.
|
||||||
|
"""
|
||||||
|
if key > (yield self._get_earliest_key(store)):
|
||||||
|
keys = self._cache.keys()
|
||||||
|
i = keys.bisect_right(key)
|
||||||
|
|
||||||
|
result = set(
|
||||||
|
self._cache[k] for k in keys[i:]
|
||||||
|
).intersection(room_ids)
|
||||||
|
|
||||||
|
cache_counter.inc_hits(self.name)
|
||||||
|
else:
|
||||||
|
result = room_ids
|
||||||
|
cache_counter.inc_misses(self.name)
|
||||||
|
|
||||||
|
defer.returnValue(result)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def room_has_changed(self, store, room_id, key):
|
||||||
|
"""Informs the cache that the room has been changed at the given key.
|
||||||
|
"""
|
||||||
|
if key > (yield self._get_earliest_key(store)):
|
||||||
|
old_key = self._room_to_key.get(room_id, None)
|
||||||
|
if old_key:
|
||||||
|
key = max(key, old_key)
|
||||||
|
self._cache.pop(old_key, None)
|
||||||
|
self._cache[key] = room_id
|
||||||
|
|
||||||
|
while len(self._cache) > self._size_of_cache:
|
||||||
|
k, r = self._cache.popitem()
|
||||||
|
self._earliest_key = max(k, self._earliest_key)
|
||||||
|
self._room_to_key.pop(r, None)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_earliest_key(self, store):
|
||||||
|
if self._earliest_key is None:
|
||||||
|
self._earliest_key = yield store.get_max_receipt_stream_id()
|
||||||
|
self._earliest_key = int(self._earliest_key)
|
||||||
|
|
||||||
|
defer.returnValue(self._earliest_key)
|
@ -17,7 +17,8 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from synapse.api.errors import StoreError, Codes
|
from synapse.api.errors import StoreError, Codes
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
|
|
||||||
class RegistrationStore(SQLBaseStore):
|
class RegistrationStore(SQLBaseStore):
|
||||||
@ -97,6 +98,20 @@ class RegistrationStore(SQLBaseStore):
|
|||||||
allow_none=True,
|
allow_none=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_users_by_id_case_insensitive(self, user_id):
|
||||||
|
"""Gets users that match user_id case insensitively.
|
||||||
|
Returns a mapping of user_id -> password_hash.
|
||||||
|
"""
|
||||||
|
def f(txn):
|
||||||
|
sql = (
|
||||||
|
"SELECT name, password_hash FROM users"
|
||||||
|
" WHERE lower(name) = lower(?)"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (user_id,))
|
||||||
|
return dict(txn.fetchall())
|
||||||
|
|
||||||
|
return self.runInteraction("get_users_by_id_case_insensitive", f)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_set_password_hash(self, user_id, password_hash):
|
def user_set_password_hash(self, user_id, password_hash):
|
||||||
"""
|
"""
|
||||||
@ -111,16 +126,16 @@ class RegistrationStore(SQLBaseStore):
|
|||||||
})
|
})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_delete_access_tokens_apart_from(self, user_id, token_id):
|
def user_delete_access_tokens(self, user_id):
|
||||||
yield self.runInteraction(
|
yield self.runInteraction(
|
||||||
"user_delete_access_tokens_apart_from",
|
"user_delete_access_tokens",
|
||||||
self._user_delete_access_tokens_apart_from, user_id, token_id
|
self._user_delete_access_tokens, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
def _user_delete_access_tokens_apart_from(self, txn, user_id, token_id):
|
def _user_delete_access_tokens(self, txn, user_id):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"DELETE FROM access_tokens WHERE user_id = ? AND id != ?",
|
"DELETE FROM access_tokens WHERE user_id = ?",
|
||||||
(user_id, token_id)
|
(user_id, )
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -131,7 +146,7 @@ class RegistrationStore(SQLBaseStore):
|
|||||||
user_id
|
user_id
|
||||||
)
|
)
|
||||||
for r in rows:
|
for r in rows:
|
||||||
self.get_user_by_token.invalidate(r)
|
self.get_user_by_token.invalidate((r,))
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_user_by_token(self, token):
|
def get_user_by_token(self, token):
|
||||||
|
@ -17,7 +17,8 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
@ -186,8 +187,7 @@ class RoomStore(SQLBaseStore):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached()
|
@cachedInlineCallbacks()
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_room_name_and_aliases(self, room_id):
|
def get_room_name_and_aliases(self, room_id):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
sql = (
|
sql = (
|
||||||
|
@ -17,7 +17,8 @@ from twisted.internet import defer
|
|||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
@ -35,38 +36,28 @@ RoomsForUser = namedtuple(
|
|||||||
|
|
||||||
class RoomMemberStore(SQLBaseStore):
|
class RoomMemberStore(SQLBaseStore):
|
||||||
|
|
||||||
def _store_room_member_txn(self, txn, event):
|
def _store_room_members_txn(self, txn, events):
|
||||||
"""Store a room member in the database.
|
"""Store a room member in the database.
|
||||||
"""
|
"""
|
||||||
try:
|
self._simple_insert_many_txn(
|
||||||
target_user_id = event.state_key
|
|
||||||
except:
|
|
||||||
logger.exception(
|
|
||||||
"Failed to parse target_user_id=%s", target_user_id
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
"_store_room_member_txn: target_user_id=%s, membership=%s",
|
|
||||||
target_user_id,
|
|
||||||
event.membership,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
txn,
|
||||||
"room_memberships",
|
table="room_memberships",
|
||||||
{
|
values=[
|
||||||
"event_id": event.event_id,
|
{
|
||||||
"user_id": target_user_id,
|
"event_id": event.event_id,
|
||||||
"sender": event.user_id,
|
"user_id": event.state_key,
|
||||||
"room_id": event.room_id,
|
"sender": event.user_id,
|
||||||
"membership": event.membership,
|
"room_id": event.room_id,
|
||||||
}
|
"membership": event.membership,
|
||||||
|
}
|
||||||
|
for event in events
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
|
for event in events:
|
||||||
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
|
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
|
||||||
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
|
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
|
||||||
|
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
|
||||||
|
|
||||||
def get_room_member(self, user_id, room_id):
|
def get_room_member(self, user_id, room_id):
|
||||||
"""Retrieve the current state of a room member.
|
"""Retrieve the current state of a room member.
|
||||||
@ -88,7 +79,7 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
lambda events: events[0] if events else None
|
lambda events: events[0] if events else None
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached()
|
@cached(max_entries=5000)
|
||||||
def get_users_in_room(self, room_id):
|
def get_users_in_room(self, room_id):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
|
|
||||||
@ -164,7 +155,7 @@ class RoomMemberStore(SQLBaseStore):
|
|||||||
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
|
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
|
||||||
]
|
]
|
||||||
|
|
||||||
@cached()
|
@cached(max_entries=5000)
|
||||||
def get_joined_hosts_for_room(self, room_id):
|
def get_joined_hosts_for_room(self, room_id):
|
||||||
return self.runInteraction(
|
return self.runInteraction(
|
||||||
"get_joined_hosts_for_room",
|
"get_joined_hosts_for_room",
|
||||||
|
34
synapse/storage/schema/delta/21/end_to_end_keys.sql
Normal file
34
synapse/storage/schema/delta/21/end_to_end_keys.sql
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
/* Copyright 2015 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS e2e_device_keys_json (
|
||||||
|
user_id TEXT NOT NULL, -- The user these keys are for.
|
||||||
|
device_id TEXT NOT NULL, -- Which of the user's devices these keys are for.
|
||||||
|
ts_added_ms BIGINT NOT NULL, -- When the keys were uploaded.
|
||||||
|
key_json TEXT NOT NULL, -- The keys for the device as a JSON blob.
|
||||||
|
CONSTRAINT e2e_device_keys_json_uniqueness UNIQUE (user_id, device_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json (
|
||||||
|
user_id TEXT NOT NULL, -- The user this one-time key is for.
|
||||||
|
device_id TEXT NOT NULL, -- The device this one-time key is for.
|
||||||
|
algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for.
|
||||||
|
key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
|
||||||
|
ts_added_ms BIGINT NOT NULL, -- When this key was uploaded.
|
||||||
|
key_json TEXT NOT NULL, -- The key as a JSON blob.
|
||||||
|
CONSTRAINT e2e_one_time_keys_json_uniqueness UNIQUE (user_id, device_id, algorithm, key_id)
|
||||||
|
);
|
38
synapse/storage/schema/delta/21/receipts.sql
Normal file
38
synapse/storage/schema/delta/21/receipts.sql
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
/* Copyright 2015 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS receipts_graph(
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
receipt_type TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
event_ids TEXT NOT NULL,
|
||||||
|
data TEXT NOT NULL,
|
||||||
|
CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS receipts_linearized (
|
||||||
|
stream_id BIGINT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
receipt_type TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
event_id TEXT NOT NULL,
|
||||||
|
data TEXT NOT NULL,
|
||||||
|
CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX receipts_linearized_id ON receipts_linearized(
|
||||||
|
stream_id
|
||||||
|
);
|
18
synapse/storage/schema/delta/22/receipts_index.sql
Normal file
18
synapse/storage/schema/delta/22/receipts_index.sql
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
/* Copyright 2015 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
CREATE INDEX receipts_linearized_room_stream ON receipts_linearized(
|
||||||
|
room_id, stream_id
|
||||||
|
);
|
19
synapse/storage/schema/delta/22/user_threepids_unique.sql
Normal file
19
synapse/storage/schema/delta/22/user_threepids_unique.sql
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS user_threepids2 (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
medium TEXT NOT NULL,
|
||||||
|
address TEXT NOT NULL,
|
||||||
|
validated_at BIGINT NOT NULL,
|
||||||
|
added_at BIGINT NOT NULL,
|
||||||
|
CONSTRAINT medium_address UNIQUE (medium, address)
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO user_threepids2
|
||||||
|
SELECT * FROM user_threepids WHERE added_at IN (
|
||||||
|
SELECT max(added_at) FROM user_threepids GROUP BY medium, address
|
||||||
|
)
|
||||||
|
;
|
||||||
|
|
||||||
|
DROP TABLE user_threepids;
|
||||||
|
ALTER TABLE user_threepids2 RENAME TO user_threepids;
|
||||||
|
|
||||||
|
CREATE INDEX user_threepids_user_id ON user_threepids(user_id);
|
@ -18,6 +18,7 @@ from twisted.internet import defer
|
|||||||
from _base import SQLBaseStore
|
from _base import SQLBaseStore
|
||||||
|
|
||||||
from syutil.base64util import encode_base64
|
from syutil.base64util import encode_base64
|
||||||
|
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||||
|
|
||||||
|
|
||||||
class SignatureStore(SQLBaseStore):
|
class SignatureStore(SQLBaseStore):
|
||||||
@ -101,23 +102,26 @@ class SignatureStore(SQLBaseStore):
|
|||||||
txn.execute(query, (event_id, ))
|
txn.execute(query, (event_id, ))
|
||||||
return {k: v for k, v in txn.fetchall()}
|
return {k: v for k, v in txn.fetchall()}
|
||||||
|
|
||||||
def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
|
def _store_event_reference_hashes_txn(self, txn, events):
|
||||||
hash_bytes):
|
|
||||||
"""Store a hash for a PDU
|
"""Store a hash for a PDU
|
||||||
Args:
|
Args:
|
||||||
txn (cursor):
|
txn (cursor):
|
||||||
event_id (str): Id for the Event.
|
events (list): list of Events.
|
||||||
algorithm (str): Hashing algorithm.
|
|
||||||
hash_bytes (bytes): Hash function output bytes.
|
|
||||||
"""
|
"""
|
||||||
self._simple_insert_txn(
|
|
||||||
|
vals = []
|
||||||
|
for event in events:
|
||||||
|
ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
|
||||||
|
vals.append({
|
||||||
|
"event_id": event.event_id,
|
||||||
|
"algorithm": ref_alg,
|
||||||
|
"hash": buffer(ref_hash_bytes),
|
||||||
|
})
|
||||||
|
|
||||||
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
"event_reference_hashes",
|
table="event_reference_hashes",
|
||||||
{
|
values=vals,
|
||||||
"event_id": event_id,
|
|
||||||
"algorithm": algorithm,
|
|
||||||
"hash": buffer(hash_bytes),
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_event_signatures_txn(self, txn, event_id):
|
def _get_event_signatures_txn(self, txn, event_id):
|
||||||
|
@ -13,7 +13,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import (
|
||||||
|
cached, cachedInlineCallbacks, cachedList
|
||||||
|
)
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
@ -44,72 +47,44 @@ class StateStore(SQLBaseStore):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state_groups(self, event_ids):
|
def get_state_groups(self, room_id, event_ids):
|
||||||
""" Get the state groups for the given list of event_ids
|
""" Get the state groups for the given list of event_ids
|
||||||
|
|
||||||
The return value is a dict mapping group names to lists of events.
|
The return value is a dict mapping group names to lists of events.
|
||||||
"""
|
"""
|
||||||
|
if not event_ids:
|
||||||
|
defer.returnValue({})
|
||||||
|
|
||||||
def f(txn):
|
event_to_groups = yield self._get_state_group_for_events(
|
||||||
groups = set()
|
room_id, event_ids,
|
||||||
for event_id in event_ids:
|
|
||||||
group = self._simple_select_one_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="event_to_state_groups",
|
|
||||||
keyvalues={"event_id": event_id},
|
|
||||||
retcol="state_group",
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
if group:
|
|
||||||
groups.add(group)
|
|
||||||
|
|
||||||
res = {}
|
|
||||||
for group in groups:
|
|
||||||
state_ids = self._simple_select_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="state_groups_state",
|
|
||||||
keyvalues={"state_group": group},
|
|
||||||
retcol="event_id",
|
|
||||||
)
|
|
||||||
|
|
||||||
res[group] = state_ids
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
states = yield self.runInteraction(
|
|
||||||
"get_state_groups",
|
|
||||||
f,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
state_list = yield defer.gatherResults(
|
groups = set(event_to_groups.values())
|
||||||
[
|
group_to_state = yield self._get_state_for_groups(groups)
|
||||||
self._fetch_events_for_group(group, vals)
|
|
||||||
for group, vals in states.items()
|
|
||||||
],
|
|
||||||
consumeErrors=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue(dict(state_list))
|
defer.returnValue({
|
||||||
|
group: state_map.values()
|
||||||
@cached(num_args=1)
|
for group, state_map in group_to_state.items()
|
||||||
def _fetch_events_for_group(self, state_group, events):
|
})
|
||||||
return self._get_events(
|
|
||||||
events, get_prev_content=False
|
|
||||||
).addCallback(
|
|
||||||
lambda evs: (state_group, evs)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _store_state_groups_txn(self, txn, event, context):
|
def _store_state_groups_txn(self, txn, event, context):
|
||||||
if context.current_state is None:
|
return self._store_mult_state_groups_txn(txn, [(event, context)])
|
||||||
return
|
|
||||||
|
|
||||||
state_events = dict(context.current_state)
|
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
||||||
|
state_groups = {}
|
||||||
|
for event, context in events_and_contexts:
|
||||||
|
if context.current_state is None:
|
||||||
|
continue
|
||||||
|
|
||||||
if event.is_state():
|
if context.state_group is not None:
|
||||||
state_events[(event.type, event.state_key)] = event
|
state_groups[event.event_id] = context.state_group
|
||||||
|
continue
|
||||||
|
|
||||||
|
state_events = dict(context.current_state)
|
||||||
|
|
||||||
|
if event.is_state():
|
||||||
|
state_events[(event.type, event.state_key)] = event
|
||||||
|
|
||||||
state_group = context.state_group
|
|
||||||
if not state_group:
|
|
||||||
state_group = self._state_groups_id_gen.get_next_txn(txn)
|
state_group = self._state_groups_id_gen.get_next_txn(txn)
|
||||||
self._simple_insert_txn(
|
self._simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
@ -135,14 +110,19 @@ class StateStore(SQLBaseStore):
|
|||||||
for state in state_events.values()
|
for state in state_events.values()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
state_groups[event.event_id] = state_group
|
||||||
|
|
||||||
self._simple_insert_txn(
|
self._simple_insert_many_txn(
|
||||||
txn,
|
txn,
|
||||||
table="event_to_state_groups",
|
table="event_to_state_groups",
|
||||||
values={
|
values=[
|
||||||
"state_group": state_group,
|
{
|
||||||
"event_id": event.event_id,
|
"state_group": state_groups[event.event_id],
|
||||||
},
|
"event_id": event.event_id,
|
||||||
|
}
|
||||||
|
for event, context in events_and_contexts
|
||||||
|
if context.current_state is not None
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@ -177,8 +157,7 @@ class StateStore(SQLBaseStore):
|
|||||||
events = yield self._get_events(event_ids, get_prev_content=False)
|
events = yield self._get_events(event_ids, get_prev_content=False)
|
||||||
defer.returnValue(events)
|
defer.returnValue(events)
|
||||||
|
|
||||||
@cached(num_args=3)
|
@cachedInlineCallbacks(num_args=3)
|
||||||
@defer.inlineCallbacks
|
|
||||||
def get_current_state_for_key(self, room_id, event_type, state_key):
|
def get_current_state_for_key(self, room_id, event_type, state_key):
|
||||||
def f(txn):
|
def f(txn):
|
||||||
sql = (
|
sql = (
|
||||||
@ -194,6 +173,262 @@ class StateStore(SQLBaseStore):
|
|||||||
events = yield self._get_events(event_ids, get_prev_content=False)
|
events = yield self._get_events(event_ids, get_prev_content=False)
|
||||||
defer.returnValue(events)
|
defer.returnValue(events)
|
||||||
|
|
||||||
|
def _get_state_groups_from_groups(self, groups_and_types):
|
||||||
|
"""Returns dictionary state_group -> state event ids
|
||||||
|
|
||||||
|
Args:
|
||||||
|
groups_and_types (list): list of 2-tuple (`group`, `types`)
|
||||||
|
"""
|
||||||
|
def f(txn):
|
||||||
|
results = {}
|
||||||
|
for group, types in groups_and_types:
|
||||||
|
if types is not None:
|
||||||
|
where_clause = "AND (%s)" % (
|
||||||
|
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
where_clause = ""
|
||||||
|
|
||||||
|
sql = (
|
||||||
|
"SELECT event_id FROM state_groups_state WHERE"
|
||||||
|
" state_group = ? %s"
|
||||||
|
) % (where_clause,)
|
||||||
|
|
||||||
|
args = [group]
|
||||||
|
if types is not None:
|
||||||
|
args.extend([i for typ in types for i in typ])
|
||||||
|
|
||||||
|
txn.execute(sql, args)
|
||||||
|
|
||||||
|
results[group] = [r[0] for r in txn.fetchall()]
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"_get_state_groups_from_groups",
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_for_events(self, room_id, event_ids, types):
|
||||||
|
"""Given a list of event_ids and type tuples, return a list of state
|
||||||
|
dicts for each event. The state dicts will only have the type/state_keys
|
||||||
|
that are in the `types` list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id (str)
|
||||||
|
event_ids (list)
|
||||||
|
types (list): List of (type, state_key) tuples which are used to
|
||||||
|
filter the state fetched. `state_key` may be None, which matches
|
||||||
|
any `state_key`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
deferred: A list of dicts corresponding to the event_ids given.
|
||||||
|
The dicts are mappings from (type, state_key) -> state_events
|
||||||
|
"""
|
||||||
|
event_to_groups = yield self._get_state_group_for_events(
|
||||||
|
room_id, event_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
groups = set(event_to_groups.values())
|
||||||
|
group_to_state = yield self._get_state_for_groups(groups, types)
|
||||||
|
|
||||||
|
event_to_state = {
|
||||||
|
event_id: group_to_state[group]
|
||||||
|
for event_id, group in event_to_groups.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue({event: event_to_state[event] for event in event_ids})
|
||||||
|
|
||||||
|
@cached(num_args=2, lru=True, max_entries=10000)
|
||||||
|
def _get_state_group_for_event(self, room_id, event_id):
|
||||||
|
return self._simple_select_one_onecol(
|
||||||
|
table="event_to_state_groups",
|
||||||
|
keyvalues={
|
||||||
|
"event_id": event_id,
|
||||||
|
},
|
||||||
|
retcol="state_group",
|
||||||
|
allow_none=True,
|
||||||
|
desc="_get_state_group_for_event",
|
||||||
|
)
|
||||||
|
|
||||||
|
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
|
||||||
|
num_args=2)
|
||||||
|
def _get_state_group_for_events(self, room_id, event_ids):
|
||||||
|
"""Returns mapping event_id -> state_group
|
||||||
|
"""
|
||||||
|
def f(txn):
|
||||||
|
results = {}
|
||||||
|
for event_id in event_ids:
|
||||||
|
results[event_id] = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="event_to_state_groups",
|
||||||
|
keyvalues={
|
||||||
|
"event_id": event_id,
|
||||||
|
},
|
||||||
|
retcol="state_group",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
return self.runInteraction("_get_state_group_for_events", f)
|
||||||
|
|
||||||
|
def _get_some_state_from_cache(self, group, types):
|
||||||
|
"""Checks if group is in cache. See `_get_state_for_groups`
|
||||||
|
|
||||||
|
Returns 3-tuple (`state_dict`, `missing_types`, `got_all`).
|
||||||
|
`missing_types` is the list of types that aren't in the cache for that
|
||||||
|
group. `got_all` is a bool indicating if we successfully retrieved all
|
||||||
|
requests state from the cache, if False we need to query the DB for the
|
||||||
|
missing state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group: The state group to lookup
|
||||||
|
types (list): List of 2-tuples of the form (`type`, `state_key`),
|
||||||
|
where a `state_key` of `None` matches all state_keys for the
|
||||||
|
`type`.
|
||||||
|
"""
|
||||||
|
is_all, state_dict = self._state_group_cache.get(group)
|
||||||
|
|
||||||
|
type_to_key = {}
|
||||||
|
missing_types = set()
|
||||||
|
for typ, state_key in types:
|
||||||
|
if state_key is None:
|
||||||
|
type_to_key[typ] = None
|
||||||
|
missing_types.add((typ, state_key))
|
||||||
|
else:
|
||||||
|
if type_to_key.get(typ, object()) is not None:
|
||||||
|
type_to_key.setdefault(typ, set()).add(state_key)
|
||||||
|
|
||||||
|
if (typ, state_key) not in state_dict:
|
||||||
|
missing_types.add((typ, state_key))
|
||||||
|
|
||||||
|
sentinel = object()
|
||||||
|
|
||||||
|
def include(typ, state_key):
|
||||||
|
valid_state_keys = type_to_key.get(typ, sentinel)
|
||||||
|
if valid_state_keys is sentinel:
|
||||||
|
return False
|
||||||
|
if valid_state_keys is None:
|
||||||
|
return True
|
||||||
|
if state_key in valid_state_keys:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
got_all = not (missing_types or types is None)
|
||||||
|
|
||||||
|
return {
|
||||||
|
k: v for k, v in state_dict.items()
|
||||||
|
if include(k[0], k[1])
|
||||||
|
}, missing_types, got_all
|
||||||
|
|
||||||
|
def _get_all_state_from_cache(self, group):
|
||||||
|
"""Checks if group is in cache. See `_get_state_for_groups`
|
||||||
|
|
||||||
|
Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
|
||||||
|
indicating if we successfully retrieved all requests state from the
|
||||||
|
cache, if False we need to query the DB for the missing state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group: The state group to lookup
|
||||||
|
"""
|
||||||
|
is_all, state_dict = self._state_group_cache.get(group)
|
||||||
|
return state_dict, is_all
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _get_state_for_groups(self, groups, types=None):
|
||||||
|
"""Given list of groups returns dict of group -> list of state events
|
||||||
|
with matching types. `types` is a list of `(type, state_key)`, where
|
||||||
|
a `state_key` of None matches all state_keys. If `types` is None then
|
||||||
|
all events are returned.
|
||||||
|
"""
|
||||||
|
results = {}
|
||||||
|
missing_groups_and_types = []
|
||||||
|
if types is not None:
|
||||||
|
for group in set(groups):
|
||||||
|
state_dict, missing_types, got_all = self._get_some_state_from_cache(
|
||||||
|
group, types
|
||||||
|
)
|
||||||
|
results[group] = state_dict
|
||||||
|
|
||||||
|
if not got_all:
|
||||||
|
missing_groups_and_types.append((group, missing_types))
|
||||||
|
else:
|
||||||
|
for group in set(groups):
|
||||||
|
state_dict, got_all = self._get_all_state_from_cache(
|
||||||
|
group
|
||||||
|
)
|
||||||
|
results[group] = state_dict
|
||||||
|
|
||||||
|
if not got_all:
|
||||||
|
missing_groups_and_types.append((group, None))
|
||||||
|
|
||||||
|
if not missing_groups_and_types:
|
||||||
|
defer.returnValue({
|
||||||
|
group: {
|
||||||
|
type_tuple: event
|
||||||
|
for type_tuple, event in state.items()
|
||||||
|
if event
|
||||||
|
}
|
||||||
|
for group, state in results.items()
|
||||||
|
})
|
||||||
|
|
||||||
|
# Okay, so we have some missing_types, lets fetch them.
|
||||||
|
cache_seq_num = self._state_group_cache.sequence
|
||||||
|
|
||||||
|
group_state_dict = yield self._get_state_groups_from_groups(
|
||||||
|
missing_groups_and_types
|
||||||
|
)
|
||||||
|
|
||||||
|
state_events = yield self._get_events(
|
||||||
|
[e_id for l in group_state_dict.values() for e_id in l],
|
||||||
|
get_prev_content=False
|
||||||
|
)
|
||||||
|
|
||||||
|
state_events = {e.event_id: e for e in state_events}
|
||||||
|
|
||||||
|
# Now we want to update the cache with all the things we fetched
|
||||||
|
# from the database.
|
||||||
|
for group, state_ids in group_state_dict.items():
|
||||||
|
if types:
|
||||||
|
# We delibrately put key -> None mappings into the cache to
|
||||||
|
# cache absence of the key, on the assumption that if we've
|
||||||
|
# explicitly asked for some types then we will probably ask
|
||||||
|
# for them again.
|
||||||
|
state_dict = {key: None for key in types}
|
||||||
|
state_dict.update(results[group])
|
||||||
|
results[group] = state_dict
|
||||||
|
else:
|
||||||
|
state_dict = results[group]
|
||||||
|
|
||||||
|
for event_id in state_ids:
|
||||||
|
try:
|
||||||
|
state_event = state_events[event_id]
|
||||||
|
state_dict[(state_event.type, state_event.state_key)] = state_event
|
||||||
|
except KeyError:
|
||||||
|
# Hmm. So we do don't have that state event? Interesting.
|
||||||
|
logger.warn(
|
||||||
|
"Can't find state event %r for state group %r",
|
||||||
|
event_id, group,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._state_group_cache.update(
|
||||||
|
cache_seq_num,
|
||||||
|
key=group,
|
||||||
|
value=state_dict,
|
||||||
|
full=(types is None),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Remove all the entries with None values. The None values were just
|
||||||
|
# used for bookkeeping in the cache.
|
||||||
|
for group, state_dict in results.items():
|
||||||
|
results[group] = {
|
||||||
|
key: event for key, event in state_dict.items() if event
|
||||||
|
}
|
||||||
|
|
||||||
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
|
||||||
def _make_group_id(clock):
|
def _make_group_id(clock):
|
||||||
return str(int(clock.time_msec())) + random_string(5)
|
return str(int(clock.time_msec())) + random_string(5)
|
||||||
|
@ -36,6 +36,7 @@ what sort order was used:
|
|||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
@ -299,9 +300,8 @@ class StreamStore(SQLBaseStore):
|
|||||||
|
|
||||||
defer.returnValue((events, token))
|
defer.returnValue((events, token))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@cachedInlineCallbacks(num_args=4)
|
||||||
def get_recent_events_for_room(self, room_id, limit, end_token,
|
def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
|
||||||
with_feedback=False, from_token=None):
|
|
||||||
# TODO (erikj): Handle compressed feedback
|
# TODO (erikj): Handle compressed feedback
|
||||||
|
|
||||||
end_token = RoomStreamToken.parse_stream_token(end_token)
|
end_token = RoomStreamToken.parse_stream_token(end_token)
|
||||||
|
@ -13,7 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore, cached
|
from ._base import SQLBaseStore
|
||||||
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
@ -72,7 +72,10 @@ class StreamIdGenerator(object):
|
|||||||
with stream_id_gen.get_next_txn(txn) as stream_id:
|
with stream_id_gen.get_next_txn(txn) as stream_id:
|
||||||
# ... persist event ...
|
# ... persist event ...
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self, table, column):
|
||||||
|
self.table = table
|
||||||
|
self.column = column
|
||||||
|
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
self._current_max = None
|
self._current_max = None
|
||||||
@ -107,6 +110,37 @@ class StreamIdGenerator(object):
|
|||||||
|
|
||||||
defer.returnValue(manager())
|
defer.returnValue(manager())
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_next_mult(self, store, n):
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
with yield stream_id_gen.get_next(store, n) as stream_ids:
|
||||||
|
# ... persist events ...
|
||||||
|
"""
|
||||||
|
if not self._current_max:
|
||||||
|
yield store.runInteraction(
|
||||||
|
"_compute_current_max",
|
||||||
|
self._get_or_compute_current_max,
|
||||||
|
)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
next_ids = range(self._current_max + 1, self._current_max + n + 1)
|
||||||
|
self._current_max += n
|
||||||
|
|
||||||
|
for next_id in next_ids:
|
||||||
|
self._unfinished_ids.append(next_id)
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def manager():
|
||||||
|
try:
|
||||||
|
yield next_ids
|
||||||
|
finally:
|
||||||
|
with self._lock:
|
||||||
|
for next_id in next_ids:
|
||||||
|
self._unfinished_ids.remove(next_id)
|
||||||
|
|
||||||
|
defer.returnValue(manager())
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_max_token(self, store):
|
def get_max_token(self, store):
|
||||||
"""Returns the maximum stream id such that all stream ids less than or
|
"""Returns the maximum stream id such that all stream ids less than or
|
||||||
@ -126,7 +160,7 @@ class StreamIdGenerator(object):
|
|||||||
|
|
||||||
def _get_or_compute_current_max(self, txn):
|
def _get_or_compute_current_max(self, txn):
|
||||||
with self._lock:
|
with self._lock:
|
||||||
txn.execute("SELECT MAX(stream_ordering) FROM events")
|
txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table))
|
||||||
rows = txn.fetchall()
|
rows = txn.fetchall()
|
||||||
val, = rows[0]
|
val, = rows[0]
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ from synapse.types import StreamToken
|
|||||||
from synapse.handlers.presence import PresenceEventSource
|
from synapse.handlers.presence import PresenceEventSource
|
||||||
from synapse.handlers.room import RoomEventSource
|
from synapse.handlers.room import RoomEventSource
|
||||||
from synapse.handlers.typing import TypingNotificationEventSource
|
from synapse.handlers.typing import TypingNotificationEventSource
|
||||||
|
from synapse.handlers.receipts import ReceiptEventSource
|
||||||
|
|
||||||
|
|
||||||
class NullSource(object):
|
class NullSource(object):
|
||||||
@ -43,6 +44,7 @@ class EventSources(object):
|
|||||||
"room": RoomEventSource,
|
"room": RoomEventSource,
|
||||||
"presence": PresenceEventSource,
|
"presence": PresenceEventSource,
|
||||||
"typing": TypingNotificationEventSource,
|
"typing": TypingNotificationEventSource,
|
||||||
|
"receipt": ReceiptEventSource,
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
@ -62,7 +64,10 @@ class EventSources(object):
|
|||||||
),
|
),
|
||||||
typing_key=(
|
typing_key=(
|
||||||
yield self.sources["typing"].get_current_key()
|
yield self.sources["typing"].get_current_key()
|
||||||
)
|
),
|
||||||
|
receipt_key=(
|
||||||
|
yield self.sources["receipt"].get_current_key()
|
||||||
|
),
|
||||||
)
|
)
|
||||||
defer.returnValue(token)
|
defer.returnValue(token)
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ class EventID(DomainSpecificString):
|
|||||||
class StreamToken(
|
class StreamToken(
|
||||||
namedtuple(
|
namedtuple(
|
||||||
"Token",
|
"Token",
|
||||||
("room_key", "presence_key", "typing_key")
|
("room_key", "presence_key", "typing_key", "receipt_key")
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
_SEPARATOR = "_"
|
_SEPARATOR = "_"
|
||||||
@ -109,6 +109,9 @@ class StreamToken(
|
|||||||
def from_string(cls, string):
|
def from_string(cls, string):
|
||||||
try:
|
try:
|
||||||
keys = string.split(cls._SEPARATOR)
|
keys = string.split(cls._SEPARATOR)
|
||||||
|
if len(keys) == len(cls._fields) - 1:
|
||||||
|
# i.e. old token from before receipt_key
|
||||||
|
keys.append("0")
|
||||||
return cls(*keys)
|
return cls(*keys)
|
||||||
except:
|
except:
|
||||||
raise SynapseError(400, "Invalid Token")
|
raise SynapseError(400, "Invalid Token")
|
||||||
@ -131,6 +134,7 @@ class StreamToken(
|
|||||||
(other_token.room_stream_id < self.room_stream_id)
|
(other_token.room_stream_id < self.room_stream_id)
|
||||||
or (int(other_token.presence_key) < int(self.presence_key))
|
or (int(other_token.presence_key) < int(self.presence_key))
|
||||||
or (int(other_token.typing_key) < int(self.typing_key))
|
or (int(other_token.typing_key) < int(self.typing_key))
|
||||||
|
or (int(other_token.receipt_key) < int(self.receipt_key))
|
||||||
)
|
)
|
||||||
|
|
||||||
def copy_and_advance(self, key, new_value):
|
def copy_and_advance(self, key, new_value):
|
||||||
@ -174,7 +178,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||||||
|
|
||||||
Live tokens start with an "s" followed by the "stream_ordering" id of the
|
Live tokens start with an "s" followed by the "stream_ordering" id of the
|
||||||
event it comes after. Historic tokens start with a "t" followed by the
|
event it comes after. Historic tokens start with a "t" followed by the
|
||||||
"topological_ordering" id of the event it comes after, follewed by "-",
|
"topological_ordering" id of the event it comes after, followed by "-",
|
||||||
followed by the "stream_ordering" id of the event it comes after.
|
followed by the "stream_ordering" id of the event it comes after.
|
||||||
"""
|
"""
|
||||||
__slots__ = []
|
__slots__ = []
|
||||||
@ -207,4 +211,5 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
|
|||||||
return "s%d" % (self.stream,)
|
return "s%d" % (self.stream,)
|
||||||
|
|
||||||
|
|
||||||
|
# token_id is the primary key ID of the access token, not the access token itself.
|
||||||
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
|
ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
|
||||||
|
@ -51,7 +51,7 @@ class ObservableDeferred(object):
|
|||||||
object.__setattr__(self, "_observers", set())
|
object.__setattr__(self, "_observers", set())
|
||||||
|
|
||||||
def callback(r):
|
def callback(r):
|
||||||
self._result = (True, r)
|
object.__setattr__(self, "_result", (True, r))
|
||||||
while self._observers:
|
while self._observers:
|
||||||
try:
|
try:
|
||||||
self._observers.pop().callback(r)
|
self._observers.pop().callback(r)
|
||||||
@ -60,7 +60,7 @@ class ObservableDeferred(object):
|
|||||||
return r
|
return r
|
||||||
|
|
||||||
def errback(f):
|
def errback(f):
|
||||||
self._result = (False, f)
|
object.__setattr__(self, "_result", (False, f))
|
||||||
while self._observers:
|
while self._observers:
|
||||||
try:
|
try:
|
||||||
self._observers.pop().errback(f)
|
self._observers.pop().errback(f)
|
||||||
@ -97,3 +97,8 @@ class ObservableDeferred(object):
|
|||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
setattr(self._deferred, name, value)
|
setattr(self._deferred, name, value)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
|
||||||
|
id(self), self._result, self._deferred,
|
||||||
|
)
|
||||||
|
27
synapse/util/caches/__init__.py
Normal file
27
synapse/util/caches/__init__.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import synapse.metrics
|
||||||
|
|
||||||
|
DEBUG_CACHES = False
|
||||||
|
|
||||||
|
metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
|
||||||
|
|
||||||
|
caches_by_name = {}
|
||||||
|
cache_counter = metrics.register_cache(
|
||||||
|
"cache",
|
||||||
|
lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
|
||||||
|
labels=["name"],
|
||||||
|
)
|
377
synapse/util/caches/descriptors.py
Normal file
377
synapse/util/caches/descriptors.py
Normal file
@ -0,0 +1,377 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from synapse.util.async import ObservableDeferred
|
||||||
|
from synapse.util import unwrapFirstError
|
||||||
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
|
from . import caches_by_name, DEBUG_CACHES, cache_counter
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import inspect
|
||||||
|
import threading
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_CacheSentinel = object()
|
||||||
|
|
||||||
|
|
||||||
|
class Cache(object):
|
||||||
|
|
||||||
|
def __init__(self, name, max_entries=1000, keylen=1, lru=True):
|
||||||
|
if lru:
|
||||||
|
self.cache = LruCache(max_size=max_entries)
|
||||||
|
self.max_entries = None
|
||||||
|
else:
|
||||||
|
self.cache = OrderedDict()
|
||||||
|
self.max_entries = max_entries
|
||||||
|
|
||||||
|
self.name = name
|
||||||
|
self.keylen = keylen
|
||||||
|
self.sequence = 0
|
||||||
|
self.thread = None
|
||||||
|
caches_by_name[name] = self.cache
|
||||||
|
|
||||||
|
def check_thread(self):
|
||||||
|
expected_thread = self.thread
|
||||||
|
if expected_thread is None:
|
||||||
|
self.thread = threading.current_thread()
|
||||||
|
else:
|
||||||
|
if expected_thread is not threading.current_thread():
|
||||||
|
raise ValueError(
|
||||||
|
"Cache objects can only be accessed from the main thread"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get(self, key, default=_CacheSentinel):
|
||||||
|
val = self.cache.get(key, _CacheSentinel)
|
||||||
|
if val is not _CacheSentinel:
|
||||||
|
cache_counter.inc_hits(self.name)
|
||||||
|
return val
|
||||||
|
|
||||||
|
cache_counter.inc_misses(self.name)
|
||||||
|
|
||||||
|
if default is _CacheSentinel:
|
||||||
|
raise KeyError()
|
||||||
|
else:
|
||||||
|
return default
|
||||||
|
|
||||||
|
def update(self, sequence, key, value):
|
||||||
|
self.check_thread()
|
||||||
|
if self.sequence == sequence:
|
||||||
|
# Only update the cache if the caches sequence number matches the
|
||||||
|
# number that the cache had before the SELECT was started (SYN-369)
|
||||||
|
self.prefill(key, value)
|
||||||
|
|
||||||
|
def prefill(self, key, value):
|
||||||
|
if self.max_entries is not None:
|
||||||
|
while len(self.cache) >= self.max_entries:
|
||||||
|
self.cache.popitem(last=False)
|
||||||
|
|
||||||
|
self.cache[key] = value
|
||||||
|
|
||||||
|
def invalidate(self, key):
|
||||||
|
self.check_thread()
|
||||||
|
if not isinstance(key, tuple):
|
||||||
|
raise TypeError(
|
||||||
|
"The cache key must be a tuple not %r" % (type(key),)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Increment the sequence number so that any SELECT statements that
|
||||||
|
# raced with the INSERT don't update the cache (SYN-369)
|
||||||
|
self.sequence += 1
|
||||||
|
self.cache.pop(key, None)
|
||||||
|
|
||||||
|
def invalidate_all(self):
|
||||||
|
self.check_thread()
|
||||||
|
self.sequence += 1
|
||||||
|
self.cache.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class CacheDescriptor(object):
|
||||||
|
""" A method decorator that applies a memoizing cache around the function.
|
||||||
|
|
||||||
|
This caches deferreds, rather than the results themselves. Deferreds that
|
||||||
|
fail are removed from the cache.
|
||||||
|
|
||||||
|
The function is presumed to take zero or more arguments, which are used in
|
||||||
|
a tuple as the key for the cache. Hits are served directly from the cache;
|
||||||
|
misses use the function body to generate the value.
|
||||||
|
|
||||||
|
The wrapped function has an additional member, a callable called
|
||||||
|
"invalidate". This can be used to remove individual entries from the cache.
|
||||||
|
|
||||||
|
The wrapped function has another additional callable, called "prefill",
|
||||||
|
which can be used to insert values into the cache specifically, without
|
||||||
|
calling the calculation function.
|
||||||
|
"""
|
||||||
|
def __init__(self, orig, max_entries=1000, num_args=1, lru=True,
|
||||||
|
inlineCallbacks=False):
|
||||||
|
self.orig = orig
|
||||||
|
|
||||||
|
if inlineCallbacks:
|
||||||
|
self.function_to_call = defer.inlineCallbacks(orig)
|
||||||
|
else:
|
||||||
|
self.function_to_call = orig
|
||||||
|
|
||||||
|
self.max_entries = max_entries
|
||||||
|
self.num_args = num_args
|
||||||
|
self.lru = lru
|
||||||
|
|
||||||
|
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
|
||||||
|
|
||||||
|
if len(self.arg_names) < self.num_args:
|
||||||
|
raise Exception(
|
||||||
|
"Not enough explicit positional arguments to key off of for %r."
|
||||||
|
" (@cached cannot key off of *args or **kwars)"
|
||||||
|
% (orig.__name__,)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.cache = Cache(
|
||||||
|
name=self.orig.__name__,
|
||||||
|
max_entries=self.max_entries,
|
||||||
|
keylen=self.num_args,
|
||||||
|
lru=self.lru,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __get__(self, obj, objtype=None):
|
||||||
|
|
||||||
|
@functools.wraps(self.orig)
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||||
|
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
|
||||||
|
try:
|
||||||
|
cached_result_d = self.cache.get(cache_key)
|
||||||
|
|
||||||
|
observer = cached_result_d.observe()
|
||||||
|
if DEBUG_CACHES:
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def check_result(cached_result):
|
||||||
|
actual_result = yield self.function_to_call(obj, *args, **kwargs)
|
||||||
|
if actual_result != cached_result:
|
||||||
|
logger.error(
|
||||||
|
"Stale cache entry %s%r: cached: %r, actual %r",
|
||||||
|
self.orig.__name__, cache_key,
|
||||||
|
cached_result, actual_result,
|
||||||
|
)
|
||||||
|
raise ValueError("Stale cache entry")
|
||||||
|
defer.returnValue(cached_result)
|
||||||
|
observer.addCallback(check_result)
|
||||||
|
|
||||||
|
return observer
|
||||||
|
except KeyError:
|
||||||
|
# Get the sequence number of the cache before reading from the
|
||||||
|
# database so that we can tell if the cache is invalidated
|
||||||
|
# while the SELECT is executing (SYN-369)
|
||||||
|
sequence = self.cache.sequence
|
||||||
|
|
||||||
|
ret = defer.maybeDeferred(
|
||||||
|
self.function_to_call,
|
||||||
|
obj, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def onErr(f):
|
||||||
|
self.cache.invalidate(cache_key)
|
||||||
|
return f
|
||||||
|
|
||||||
|
ret.addErrback(onErr)
|
||||||
|
|
||||||
|
ret = ObservableDeferred(ret, consumeErrors=True)
|
||||||
|
self.cache.update(sequence, cache_key, ret)
|
||||||
|
|
||||||
|
return ret.observe()
|
||||||
|
|
||||||
|
wrapped.invalidate = self.cache.invalidate
|
||||||
|
wrapped.invalidate_all = self.cache.invalidate_all
|
||||||
|
wrapped.prefill = self.cache.prefill
|
||||||
|
|
||||||
|
obj.__dict__[self.orig.__name__] = wrapped
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
class CacheListDescriptor(object):
|
||||||
|
"""Wraps an existing cache to support bulk fetching of keys.
|
||||||
|
|
||||||
|
Given a list of keys it looks in the cache to find any hits, then passes
|
||||||
|
the list of missing keys to the wrapped fucntion.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, orig, cache, list_name, num_args=1, inlineCallbacks=False):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
orig (function)
|
||||||
|
cache (Cache)
|
||||||
|
list_name (str): Name of the argument which is the bulk lookup list
|
||||||
|
num_args (int)
|
||||||
|
inlineCallbacks (bool): Whether orig is a generator that should
|
||||||
|
be wrapped by defer.inlineCallbacks
|
||||||
|
"""
|
||||||
|
self.orig = orig
|
||||||
|
|
||||||
|
if inlineCallbacks:
|
||||||
|
self.function_to_call = defer.inlineCallbacks(orig)
|
||||||
|
else:
|
||||||
|
self.function_to_call = orig
|
||||||
|
|
||||||
|
self.num_args = num_args
|
||||||
|
self.list_name = list_name
|
||||||
|
|
||||||
|
self.arg_names = inspect.getargspec(orig).args[1:num_args+1]
|
||||||
|
self.list_pos = self.arg_names.index(self.list_name)
|
||||||
|
|
||||||
|
self.cache = cache
|
||||||
|
|
||||||
|
self.sentinel = object()
|
||||||
|
|
||||||
|
if len(self.arg_names) < self.num_args:
|
||||||
|
raise Exception(
|
||||||
|
"Not enough explicit positional arguments to key off of for %r."
|
||||||
|
" (@cached cannot key off of *args or **kwars)"
|
||||||
|
% (orig.__name__,)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.list_name not in self.arg_names:
|
||||||
|
raise Exception(
|
||||||
|
"Couldn't see arguments %r for %r."
|
||||||
|
% (self.list_name, cache.name,)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __get__(self, obj, objtype=None):
|
||||||
|
|
||||||
|
@functools.wraps(self.orig)
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
|
||||||
|
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
|
||||||
|
list_args = arg_dict[self.list_name]
|
||||||
|
|
||||||
|
# cached is a dict arg -> deferred, where deferred results in a
|
||||||
|
# 2-tuple (`arg`, `result`)
|
||||||
|
cached = {}
|
||||||
|
missing = []
|
||||||
|
for arg in list_args:
|
||||||
|
key = list(keyargs)
|
||||||
|
key[self.list_pos] = arg
|
||||||
|
|
||||||
|
try:
|
||||||
|
res = self.cache.get(tuple(key)).observe()
|
||||||
|
res.addCallback(lambda r, arg: (arg, r), arg)
|
||||||
|
cached[arg] = res
|
||||||
|
except KeyError:
|
||||||
|
missing.append(arg)
|
||||||
|
|
||||||
|
if missing:
|
||||||
|
sequence = self.cache.sequence
|
||||||
|
args_to_call = dict(arg_dict)
|
||||||
|
args_to_call[self.list_name] = missing
|
||||||
|
|
||||||
|
ret_d = defer.maybeDeferred(
|
||||||
|
self.function_to_call,
|
||||||
|
**args_to_call
|
||||||
|
)
|
||||||
|
|
||||||
|
ret_d = ObservableDeferred(ret_d)
|
||||||
|
|
||||||
|
# We need to create deferreds for each arg in the list so that
|
||||||
|
# we can insert the new deferred into the cache.
|
||||||
|
for arg in missing:
|
||||||
|
observer = ret_d.observe()
|
||||||
|
observer.addCallback(lambda r, arg: r.get(arg, None), arg)
|
||||||
|
|
||||||
|
observer = ObservableDeferred(observer)
|
||||||
|
|
||||||
|
key = list(keyargs)
|
||||||
|
key[self.list_pos] = arg
|
||||||
|
self.cache.update(sequence, tuple(key), observer)
|
||||||
|
|
||||||
|
def invalidate(f, key):
|
||||||
|
self.cache.invalidate(key)
|
||||||
|
return f
|
||||||
|
observer.addErrback(invalidate, tuple(key))
|
||||||
|
|
||||||
|
res = observer.observe()
|
||||||
|
res.addCallback(lambda r, arg: (arg, r), arg)
|
||||||
|
|
||||||
|
cached[arg] = res
|
||||||
|
|
||||||
|
return defer.gatherResults(
|
||||||
|
cached.values(),
|
||||||
|
consumeErrors=True,
|
||||||
|
).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
|
||||||
|
|
||||||
|
obj.__dict__[self.orig.__name__] = wrapped
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def cached(max_entries=1000, num_args=1, lru=True):
|
||||||
|
return lambda orig: CacheDescriptor(
|
||||||
|
orig,
|
||||||
|
max_entries=max_entries,
|
||||||
|
num_args=num_args,
|
||||||
|
lru=lru
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cachedInlineCallbacks(max_entries=1000, num_args=1, lru=False):
|
||||||
|
return lambda orig: CacheDescriptor(
|
||||||
|
orig,
|
||||||
|
max_entries=max_entries,
|
||||||
|
num_args=num_args,
|
||||||
|
lru=lru,
|
||||||
|
inlineCallbacks=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cachedList(cache, list_name, num_args=1, inlineCallbacks=False):
|
||||||
|
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
||||||
|
|
||||||
|
Used to do batch lookups for an already created cache. A single argument
|
||||||
|
is specified as a list that is iterated through to lookup keys in the
|
||||||
|
original cache. A new list consisting of the keys that weren't in the cache
|
||||||
|
get passed to the original function, the result of which is stored in the
|
||||||
|
cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cache (Cache): The underlying cache to use.
|
||||||
|
list_name (str): The name of the argument that is the list to use to
|
||||||
|
do batch lookups in the cache.
|
||||||
|
num_args (int): Number of arguments to use as the key in the cache.
|
||||||
|
inlineCallbacks (bool): Should the function be wrapped in an
|
||||||
|
`defer.inlineCallbacks`?
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
class Example(object):
|
||||||
|
@cached(num_args=2)
|
||||||
|
def do_something(self, first_arg):
|
||||||
|
...
|
||||||
|
|
||||||
|
@cachedList(do_something.cache, list_name="second_args", num_args=2)
|
||||||
|
def batch_do_something(self, first_arg, second_args):
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
return lambda orig: CacheListDescriptor(
|
||||||
|
orig,
|
||||||
|
cache=cache,
|
||||||
|
list_name=list_name,
|
||||||
|
num_args=num_args,
|
||||||
|
inlineCallbacks=inlineCallbacks,
|
||||||
|
)
|
103
synapse/util/caches/dictionary_cache.py
Normal file
103
synapse/util/caches/dictionary_cache.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2015 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
from collections import namedtuple
|
||||||
|
from . import caches_by_name, cache_counter
|
||||||
|
import threading
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
|
||||||
|
|
||||||
|
|
||||||
|
class DictionaryCache(object):
|
||||||
|
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
|
||||||
|
fetching a subset of dictionary keys for a particular key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name, max_entries=1000):
|
||||||
|
self.cache = LruCache(max_size=max_entries)
|
||||||
|
|
||||||
|
self.name = name
|
||||||
|
self.sequence = 0
|
||||||
|
self.thread = None
|
||||||
|
# caches_by_name[name] = self.cache
|
||||||
|
|
||||||
|
class Sentinel(object):
|
||||||
|
__slots__ = []
|
||||||
|
|
||||||
|
self.sentinel = Sentinel()
|
||||||
|
caches_by_name[name] = self.cache
|
||||||
|
|
||||||
|
def check_thread(self):
|
||||||
|
expected_thread = self.thread
|
||||||
|
if expected_thread is None:
|
||||||
|
self.thread = threading.current_thread()
|
||||||
|
else:
|
||||||
|
if expected_thread is not threading.current_thread():
|
||||||
|
raise ValueError(
|
||||||
|
"Cache objects can only be accessed from the main thread"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get(self, key, dict_keys=None):
|
||||||
|
entry = self.cache.get(key, self.sentinel)
|
||||||
|
if entry is not self.sentinel:
|
||||||
|
cache_counter.inc_hits(self.name)
|
||||||
|
|
||||||
|
if dict_keys is None:
|
||||||
|
return DictionaryEntry(entry.full, dict(entry.value))
|
||||||
|
else:
|
||||||
|
return DictionaryEntry(entry.full, {
|
||||||
|
k: entry.value[k]
|
||||||
|
for k in dict_keys
|
||||||
|
if k in entry.value
|
||||||
|
})
|
||||||
|
|
||||||
|
cache_counter.inc_misses(self.name)
|
||||||
|
return DictionaryEntry(False, {})
|
||||||
|
|
||||||
|
def invalidate(self, key):
|
||||||
|
self.check_thread()
|
||||||
|
|
||||||
|
# Increment the sequence number so that any SELECT statements that
|
||||||
|
# raced with the INSERT don't update the cache (SYN-369)
|
||||||
|
self.sequence += 1
|
||||||
|
self.cache.pop(key, None)
|
||||||
|
|
||||||
|
def invalidate_all(self):
|
||||||
|
self.check_thread()
|
||||||
|
self.sequence += 1
|
||||||
|
self.cache.clear()
|
||||||
|
|
||||||
|
def update(self, sequence, key, value, full=False):
|
||||||
|
self.check_thread()
|
||||||
|
if self.sequence == sequence:
|
||||||
|
# Only update the cache if the caches sequence number matches the
|
||||||
|
# number that the cache had before the SELECT was started (SYN-369)
|
||||||
|
if full:
|
||||||
|
self._insert(key, value)
|
||||||
|
else:
|
||||||
|
self._update_or_insert(key, value)
|
||||||
|
|
||||||
|
def _update_or_insert(self, key, value):
|
||||||
|
entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
|
||||||
|
entry.value.update(value)
|
||||||
|
|
||||||
|
def _insert(self, key, value):
|
||||||
|
self.cache[key] = DictionaryEntry(True, value)
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user