Merge branch 'develop' of github.com:matrix-org/synapse into erikj-perf

This commit is contained in:
Erik Johnston 2015-03-09 13:29:41 +00:00
commit f31e65ca8b
207 changed files with 13953 additions and 3607 deletions

18
.gitignore vendored
View File

@ -26,17 +26,19 @@ htmlcov
demo/*.db demo/*.db
demo/*.log demo/*.log
demo/*.log.*
demo/*.pid demo/*.pid
demo/media_store.*
demo/etc demo/etc
graph/*.svg
graph/*.png
graph/*.dot
**/webclient/config.js
**/webclient/test/coverage/
**/webclient/test/environment-protractor.js
uploads uploads
.idea/ .idea/
media_store/
*.tac
build/
localhost-800*/
static/client/register/register_config.js

View File

@ -1,54 +1,135 @@
Changes in synapse v0.8.0 (2015-03-06)
======================================
General:
* Add support for registration fallback. This is a page hosted on the server
which allows a user to register for an account, regardless of what client
they are using (e.g. mobile devices).
* Added new default push rules and made them configurable by clients:
* Suppress all notice messages.
* Notify when invited to a new room.
* Notify for messages that don't match any rule.
* Notify on incoming call.
Federation:
* Added per host server side rate-limiting of incoming federation requests.
* Added a ``/get_missing_events/`` API to federation to reduce number of
``/events/`` requests.
Configuration:
* Added configuration option to disable registration:
``disable_registration``.
* Added configuration option to change soft limit of number of open file
descriptors: ``soft_file_limit``.
* Make ``tls_private_key_path`` optional when running with ``no_tls``.
Application services:
* Application services can now poll on the CS API ``/events`` for their events,
by providing their application service ``access_token``.
* Added exclusive namespace support to application services API.
Changes in synapse v0.7.1 (2015-02-19)
======================================
* Initial alpha implementation of parts of the Application Services API.
Including:
- AS Registration / Unregistration
- User Query API
- Room Alias Query API
- Push transport for receiving events.
- User/Alias namespace admin control
* Add cache when fetching events from remote servers to stop repeatedly
fetching events with bad signatures.
* Respect the per remote server retry scheme when fetching both events and
server keys to reduce the number of times we send requests to dead servers.
* Inform remote servers when the local server fails to handle a received event.
* Turn off python bytecode generation due to problems experienced when
upgrading from previous versions.
Changes in synapse v0.7.0 (2015-02-12)
======================================
* Add initial implementation of the query auth federation API, allowing
servers to agree on whether an event should be allowed or rejected.
* Persist events we have rejected from federation, fixing the bug where
servers would keep requesting the same events.
* Various federation performance improvements, including:
- Add in memory caches on queries such as:
* Computing the state of a room at a point in time, used for
authorization on federation requests.
* Fetching events from the database.
* User's room membership, used for authorizing presence updates.
- Upgraded JSON library to improve parsing and serialisation speeds.
* Add default avatars to new user accounts using pydenticon library.
* Correctly time out federation requests.
* Retry federation requests against different servers.
* Add support for push and push rules.
* Add alpha versions of proposed new CSv2 APIs, including ``/sync`` API.
Changes in synapse 0.6.1 (2015-01-07) Changes in synapse 0.6.1 (2015-01-07)
===================================== =====================================
* Major optimizations to improve performance of initial sync and event sending * Major optimizations to improve performance of initial sync and event sending
in large rooms (by up to 10x) in large rooms (by up to 10x)
* Media repository now includes a Content-Length header on media downloads. * Media repository now includes a Content-Length header on media downloads.
* Improve quality of thumbnails by changing resizing algorithm. * Improve quality of thumbnails by changing resizing algorithm.
Changes in synapse 0.6.0 (2014-12-16) Changes in synapse 0.6.0 (2014-12-16)
===================================== =====================================
* Add new API for media upload and download that supports thumbnailing. * Add new API for media upload and download that supports thumbnailing.
* Replicate media uploads over multiple homeservers so media is always served * Replicate media uploads over multiple homeservers so media is always served
to clients from their local homeserver. This obsoletes the to clients from their local homeserver. This obsoletes the
--content-addr parameter and confusion over accessing content directly --content-addr parameter and confusion over accessing content directly
from remote homeservers. from remote homeservers.
* Implement exponential backoff when retrying federation requests when * Implement exponential backoff when retrying federation requests when
sending to remote homeservers which are offline. sending to remote homeservers which are offline.
* Implement typing notifications. * Implement typing notifications.
* Fix bugs where we sent events with invalid signatures due to bugs where * Fix bugs where we sent events with invalid signatures due to bugs where
we incorrectly persisted events. we incorrectly persisted events.
* Improve performance of database queries involving retrieving events. * Improve performance of database queries involving retrieving events.
Changes in synapse 0.5.4a (2014-12-13) Changes in synapse 0.5.4a (2014-12-13)
====================================== ======================================
* Fix bug while generating the error message when a file path specified in * Fix bug while generating the error message when a file path specified in
the config doesn't exist. the config doesn't exist.
Changes in synapse 0.5.4 (2014-12-03) Changes in synapse 0.5.4 (2014-12-03)
===================================== =====================================
* Fix presence bug where some rooms did not display presence updates for * Fix presence bug where some rooms did not display presence updates for
remote users. remote users.
* Do not log SQL timing log lines when started with "-v" * Do not log SQL timing log lines when started with "-v"
* Fix potential memory leak. * Fix potential memory leak.
Changes in synapse 0.5.3c (2014-12-02) Changes in synapse 0.5.3c (2014-12-02)
====================================== ======================================
* Change the default value for the `content_addr` option to use the HTTP * Change the default value for the `content_addr` option to use the HTTP
listener, as by default the HTTPS listener will be using a self-signed listener, as by default the HTTPS listener will be using a self-signed
certificate. certificate.
Changes in synapse 0.5.3 (2014-11-27) Changes in synapse 0.5.3 (2014-11-27)
===================================== =====================================
* Fix bug that caused joining a remote room to fail if a single event was not * Fix bug that caused joining a remote room to fail if a single event was not
signed correctly. signed correctly.
* Fix bug which caused servers to continuously try and fetch events from other * Fix bug which caused servers to continuously try and fetch events from other
servers. servers.
Changes in synapse 0.5.2 (2014-11-26) Changes in synapse 0.5.2 (2014-11-26)
===================================== =====================================

View File

@ -1,4 +1,14 @@
recursive-include docs * include synctl
recursive-include tests *.py include LICENSE
include VERSION
include *.rst
include demo/README
recursive-include synapse/storage/schema *.sql recursive-include synapse/storage/schema *.sql
recursive-include syweb/webclient *
recursive-include demo *.dh
recursive-include demo *.py
recursive-include demo *.sh
recursive-include docs *
recursive-include scripts *
recursive-include tests *.py

View File

@ -6,7 +6,7 @@ VoIP. The basics you need to know to get up and running are:
- Everything in Matrix happens in a room. Rooms are distributed and do not - Everything in Matrix happens in a room. Rooms are distributed and do not
exist on any single server. Rooms can be located using convenience aliases exist on any single server. Rooms can be located using convenience aliases
like ``#matrix:matrix.org`` or ``#test:localhost:8008``. like ``#matrix:matrix.org`` or ``#test:localhost:8448``.
- Matrix user IDs look like ``@matthew:matrix.org`` (although in the future - Matrix user IDs look like ``@matthew:matrix.org`` (although in the future
you will normally refer to yourself and others using a 3PID: email you will normally refer to yourself and others using a 3PID: email
@ -95,27 +95,36 @@ Installing prerequisites on Ubuntu or Debian::
$ sudo apt-get install build-essential python2.7-dev libffi-dev \ $ sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \ python-pip python-setuptools sqlite3 \
libssl-dev libssl-dev python-virtualenv libjpeg-dev
Installing prerequisites on ArchLinux::
$ sudo pacman -S base-devel python2 python-pip \
python-setuptools python-virtualenv sqlite3
Installing prerequisites on Mac OS X:: Installing prerequisites on Mac OS X::
$ xcode-select --install $ xcode-select --install
$ sudo pip install virtualenv
To install the synapse homeserver run:: To install the synapse homeserver run::
$ pip install --user --process-dependency-links https://github.com/matrix-org/synapse/tarball/master $ virtualenv ~/.synapse
$ source ~/.synapse/bin/activate
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
This installs synapse, along with the libraries it uses, into This installs synapse, along with the libraries it uses, into a virtual
``$HOME/.local/lib/`` on Linux or ``$HOME/Library/Python/2.7/lib/`` on OSX. environment under ``~/.synapse``.
Your python may not give priority to locally installed libraries over system To set up your homeserver, run (in your virtualenv, as before)::
libraries, in which case you must add your local packages to your python path::
$ # on Linux: $ cd ~/.synapse
$ export PYTHONPATH=$HOME/.local/lib/python2.7/site-packages:$PYTHONPATH $ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
$ # on OSX: Substituting your host and domain name as appropriate.
$ export PYTHONPATH=$HOME/Library/Python/2.7/lib/python/site-packages:$PYTHONPATH
For reliable VoIP calls to be routed via this homeserver, you MUST configure For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See docs/turn-howto.rst for details. a TURN server. See docs/turn-howto.rst for details.
@ -128,23 +137,57 @@ you get errors about ``error: no such option: --process-dependency-links`` you
may need to manually upgrade it:: may need to manually upgrade it::
$ sudo pip install --upgrade pip $ sudo pip install --upgrade pip
If pip crashes mid-installation for reason (e.g. lost terminal), pip may If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it refuse to run until you remove the temporary installation directory it
created. To reset the installation:: created. To reset the installation::
$ rm -rf /tmp/pip_install_matrix $ rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are happens, you will have to individually install the dependencies which are
failing, e.g.:: failing, e.g.::
$ pip install --user twisted $ pip install twisted
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments. will need to export CFLAGS=-Qunused-arguments.
ArchLinux
---------
Installation on ArchLinux may encounter a few hiccups as Arch defaults to
python 3, but synapse currently assumes python 2.7 by default.
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
$ sudo pip2.7 install --upgrade pip
You also may need to explicitly specify python 2.7 again during the install
request::
$ pip2.7 install --process-dependency-links \
https://github.com/matrix-org/synapse/tarball/master
If you encounter an error with lib bcrypt causing an Wrong ELF Class:
ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
compile it under the right architecture. (This should not be needed if
installing under virtualenv)::
$ sudo pip2.7 uninstall py-bcrypt
$ sudo pip2.7 install py-bcrypt
During setup of homeserver you need to call python2.7 directly again::
$ cd ~/.synapse
$ python2.7 -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
...substituting your host and domain name as appropriate.
Windows Install Windows Install
--------------- ---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages: Synapse can be installed on Cygwin. It requires the following Cygwin packages:
@ -155,7 +198,7 @@ Synapse can be installed on Cygwin. It requires the following Cygwin packages:
- openssl (and openssl-devel, python-openssl) - openssl (and openssl-devel, python-openssl)
- python - python
- python-setuptools - python-setuptools
The content repository requires additional packages and will be unable to process The content repository requires additional packages and will be unable to process
uploads without them: uploads without them:
- libjpeg8 - libjpeg8
@ -182,23 +225,13 @@ Running Your Homeserver
To actually run your new homeserver, pick a working directory for Synapse to run To actually run your new homeserver, pick a working directory for Synapse to run
(e.g. ``~/.synapse``), and:: (e.g. ``~/.synapse``), and::
$ mkdir ~/.synapse
$ cd ~/.synapse $ cd ~/.synapse
$ source ./bin/activate
$ # on Linux $ synctl start
$ ~/.local/bin/synctl start
$ # on OSX
$ ~/Library/Python/2.7/bin/synctl start
Troubleshooting Running Troubleshooting Running
----------------------- -----------------------
If ``synctl`` fails with ``pkg_resources.DistributionNotFound`` errors you may
need a newer version of setuptools than that provided by your OS.::
$ sudo pip install setuptools --upgrade
If synapse fails with ``missing "sodium.h"`` crypto errors, you may need If synapse fails with ``missing "sodium.h"`` crypto errors, you may need
to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for
encryption and digital signatures. encryption and digital signatures.
@ -214,6 +247,14 @@ fix try re-installing from PyPI or directly from
$ # Install from github $ # Install from github
$ pip install --user https://github.com/pyca/pynacl/tarball/master $ pip install --user https://github.com/pyca/pynacl/tarball/master
ArchLinux
---------
If running `$ synctl start` fails wit 'returned non-zero exit status 1', you will need to explicitly call Python2.7 - either running as::
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml --pid-file homeserver.pid
...or by editing synctl with the correct python executable.
Homeserver Development Homeserver Development
====================== ======================
@ -225,13 +266,15 @@ directory of your choice::
$ cd synapse $ cd synapse
The homeserver has a number of external dependencies, that are easiest The homeserver has a number of external dependencies, that are easiest
to install by making setup.py do so, in --user mode:: to install using pip and a virtualenv::
$ python setup.py develop --user $ virtualenv env
$ source env/bin/activate
$ python synapse/python_dependencies.py | xargs -n1 pip install
$ pip install setuptools_trial mock
This will run a process of downloading and installing into your This will run a process of downloading and installing all the needed
user's .local/lib directory all of the required dependencies that are dependencies into a virtual env.
missing.
Once this is done, you may wish to run the homeserver's unit tests, to Once this is done, you may wish to run the homeserver's unit tests, to
check that everything is installed as it should be:: check that everything is installed as it should be::
@ -252,7 +295,7 @@ IMPORTANT: Before upgrading an existing homeserver to a new version, please
refer to UPGRADE.rst for any additional instructions. refer to UPGRADE.rst for any additional instructions.
Otherwise, simply re-install the new codebase over the current one - e.g. Otherwise, simply re-install the new codebase over the current one - e.g.
by ``pip install --user --process-dependency-links by ``pip install --process-dependency-links
https://github.com/matrix-org/synapse/tarball/master`` https://github.com/matrix-org/synapse/tarball/master``
if using pip, or by ``git pull`` if running off a git working copy. if using pip, or by ``git pull`` if running off a git working copy.
@ -279,9 +322,9 @@ For the first form, simply pass the required hostname (of the machine) as the
$ python -m synapse.app.homeserver \ $ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.config \ --config-path homeserver.yaml \
--generate-config --generate-config
$ python -m synapse.app.homeserver --config-path homeserver.config $ python -m synapse.app.homeserver --config-path homeserver.yaml
Alternatively, you can run ``synctl start`` to guide you through the process. Alternatively, you can run ``synctl start`` to guide you through the process.
@ -301,9 +344,9 @@ SRV record, as that is the name other machines will expect it to have::
$ python -m synapse.app.homeserver \ $ python -m synapse.app.homeserver \
--server-name YOURDOMAIN \ --server-name YOURDOMAIN \
--bind-port 8448 \ --bind-port 8448 \
--config-path homeserver.config \ --config-path homeserver.yaml \
--generate-config --generate-config
$ python -m synapse.app.homeserver --config-path homeserver.config $ python -m synapse.app.homeserver --config-path homeserver.yaml
You may additionally want to pass one or more "-v" options, in order to You may additionally want to pass one or more "-v" options, in order to

View File

@ -1,3 +1,32 @@
Upgrading to v0.8.0
===================
Servers which use captchas will need to add their public key to::
static/client/register/register_config.js
window.matrixRegistrationConfig = {
recaptcha_public_key: "YOUR_PUBLIC_KEY"
};
This is required in order to support registration fallback (typically used on
mobile devices).
Upgrading to v0.7.0
===================
New dependencies are:
- pydenticon
- simplejson
- syutil
- matrix-angular-sdk
To pull in these dependencies in a virtual env, run::
python synapse/python_dependencies.py | xargs -n 1 pip install
Upgrading to v0.6.0 Upgrading to v0.6.0
=================== ===================

View File

@ -1 +0,0 @@
0.6.1b

View File

@ -21,6 +21,7 @@ import datetime
import argparse import argparse
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.util.frozenutils import unfreeze
def make_graph(db_name, room_id, file_prefix, limit): def make_graph(db_name, room_id, file_prefix, limit):
@ -70,7 +71,7 @@ def make_graph(db_name, room_id, file_prefix, limit):
float(event.origin_server_ts) / 1000 float(event.origin_server_ts) / 1000
).strftime('%Y-%m-%d %H:%M:%S,%f') ).strftime('%Y-%m-%d %H:%M:%S,%f')
content = json.dumps(event.get_dict()["content"]) content = json.dumps(unfreeze(event.get_dict()["content"]))
label = ( label = (
"<" "<"

View File

@ -39,43 +39,43 @@ ROOMDOMAIN="meet.jit.si"
#ROOMDOMAIN="conference.jitsi.vuc.me" #ROOMDOMAIN="conference.jitsi.vuc.me"
class TrivialMatrixClient: class TrivialMatrixClient:
def __init__(self, access_token): def __init__(self, access_token):
self.token = None self.token = None
self.access_token = access_token self.access_token = access_token
def getEvent(self): def getEvent(self):
while True: while True:
url = MATRIXBASE+'events?access_token='+self.access_token+"&timeout=60000" url = MATRIXBASE+'events?access_token='+self.access_token+"&timeout=60000"
if self.token: if self.token:
url += "&from="+self.token url += "&from="+self.token
req = grequests.get(url) req = grequests.get(url)
resps = grequests.map([req]) resps = grequests.map([req])
obj = json.loads(resps[0].content) obj = json.loads(resps[0].content)
print "incoming from matrix",obj print "incoming from matrix",obj
if 'end' not in obj: if 'end' not in obj:
continue continue
self.token = obj['end'] self.token = obj['end']
if len(obj['chunk']): if len(obj['chunk']):
return obj['chunk'][0] return obj['chunk'][0]
def joinRoom(self, roomId): def joinRoom(self, roomId):
url = MATRIXBASE+'rooms/'+roomId+'/join?access_token='+self.access_token url = MATRIXBASE+'rooms/'+roomId+'/join?access_token='+self.access_token
print url print url
headers={ 'Content-Type': 'application/json' } headers={ 'Content-Type': 'application/json' }
req = grequests.post(url, headers=headers, data='{}') req = grequests.post(url, headers=headers, data='{}')
resps = grequests.map([req]) resps = grequests.map([req])
obj = json.loads(resps[0].content) obj = json.loads(resps[0].content)
print "response: ",obj print "response: ",obj
def sendEvent(self, roomId, evType, event): def sendEvent(self, roomId, evType, event):
url = MATRIXBASE+'rooms/'+roomId+'/send/'+evType+'?access_token='+self.access_token url = MATRIXBASE+'rooms/'+roomId+'/send/'+evType+'?access_token='+self.access_token
print url print url
print json.dumps(event) print json.dumps(event)
headers={ 'Content-Type': 'application/json' } headers={ 'Content-Type': 'application/json' }
req = grequests.post(url, headers=headers, data=json.dumps(event)) req = grequests.post(url, headers=headers, data=json.dumps(event))
resps = grequests.map([req]) resps = grequests.map([req])
obj = json.loads(resps[0].content) obj = json.loads(resps[0].content)
print "response: ",obj print "response: ",obj
@ -83,178 +83,178 @@ xmppClients = {}
def matrixLoop(): def matrixLoop():
while True: while True:
ev = matrixCli.getEvent() ev = matrixCli.getEvent()
print ev print ev
if ev['type'] == 'm.room.member': if ev['type'] == 'm.room.member':
print 'membership event' print 'membership event'
if ev['membership'] == 'invite' and ev['state_key'] == MYUSERNAME: if ev['membership'] == 'invite' and ev['state_key'] == MYUSERNAME:
roomId = ev['room_id'] roomId = ev['room_id']
print "joining room %s" % (roomId) print "joining room %s" % (roomId)
matrixCli.joinRoom(roomId) matrixCli.joinRoom(roomId)
elif ev['type'] == 'm.room.message': elif ev['type'] == 'm.room.message':
if ev['room_id'] in xmppClients: if ev['room_id'] in xmppClients:
print "already have a bridge for that user, ignoring" print "already have a bridge for that user, ignoring"
continue continue
print "got message, connecting" print "got message, connecting"
xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id'])
gevent.spawn(xmppClients[ev['room_id']].xmppLoop) gevent.spawn(xmppClients[ev['room_id']].xmppLoop)
elif ev['type'] == 'm.call.invite': elif ev['type'] == 'm.call.invite':
print "Incoming call" print "Incoming call"
#sdp = ev['content']['offer']['sdp'] #sdp = ev['content']['offer']['sdp']
#print "sdp: %s" % (sdp) #print "sdp: %s" % (sdp)
#xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) #xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id'])
#gevent.spawn(xmppClients[ev['room_id']].xmppLoop) #gevent.spawn(xmppClients[ev['room_id']].xmppLoop)
elif ev['type'] == 'm.call.answer': elif ev['type'] == 'm.call.answer':
print "Call answered" print "Call answered"
sdp = ev['content']['answer']['sdp'] sdp = ev['content']['answer']['sdp']
if ev['room_id'] not in xmppClients: if ev['room_id'] not in xmppClients:
print "We didn't have a call for that room" print "We didn't have a call for that room"
continue continue
# should probably check call ID too # should probably check call ID too
xmppCli = xmppClients[ev['room_id']] xmppCli = xmppClients[ev['room_id']]
xmppCli.sendAnswer(sdp) xmppCli.sendAnswer(sdp)
elif ev['type'] == 'm.call.hangup': elif ev['type'] == 'm.call.hangup':
if ev['room_id'] in xmppClients: if ev['room_id'] in xmppClients:
xmppClients[ev['room_id']].stop() xmppClients[ev['room_id']].stop()
del xmppClients[ev['room_id']] del xmppClients[ev['room_id']]
class TrivialXmppClient: class TrivialXmppClient:
def __init__(self, matrixRoom, userId): def __init__(self, matrixRoom, userId):
self.rid = 0 self.rid = 0
self.matrixRoom = matrixRoom self.matrixRoom = matrixRoom
self.userId = userId self.userId = userId
self.running = True self.running = True
def stop(self): def stop(self):
self.running = False self.running = False
def nextRid(self): def nextRid(self):
self.rid += 1 self.rid += 1
return '%d' % (self.rid) return '%d' % (self.rid)
def sendIq(self, xml): def sendIq(self, xml):
fullXml = "<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s'>%s</body>" % (self.nextRid(), self.sid, xml) fullXml = "<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s'>%s</body>" % (self.nextRid(), self.sid, xml)
#print "\t>>>%s" % (fullXml) #print "\t>>>%s" % (fullXml)
return self.xmppPoke(fullXml) return self.xmppPoke(fullXml)
def xmppPoke(self, xml):
headers = {'Content-Type': 'application/xml'}
req = grequests.post(HTTPBIND, verify=False, headers=headers, data=xml)
resps = grequests.map([req])
obj = BeautifulSoup(resps[0].content)
return obj
def sendAnswer(self, answer): def xmppPoke(self, xml):
print "sdp from matrix client",answer headers = {'Content-Type': 'application/xml'}
p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--sdp'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) req = grequests.post(HTTPBIND, verify=False, headers=headers, data=xml)
jingle, out_err = p.communicate(answer) resps = grequests.map([req])
jingle = jingle % { obj = BeautifulSoup(resps[0].content)
'tojid': self.callfrom, return obj
'action': 'session-accept',
'initiator': self.callfrom,
'responder': self.jid,
'sid': self.callsid
}
print "answer jingle from sdp",jingle
res = self.sendIq(jingle)
print "reply from answer: ",res
self.ssrcs = {}
jingleSoup = BeautifulSoup(jingle)
for cont in jingleSoup.iq.jingle.findAll('content'):
if cont.description:
self.ssrcs[cont['name']] = cont.description['ssrc']
print "my ssrcs:",self.ssrcs
gevent.joinall([ def sendAnswer(self, answer):
gevent.spawn(self.advertiseSsrcs) print "sdp from matrix client",answer
]) p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--sdp'], stdin=subprocess.PIPE, stdout=subprocess.PIPE)
jingle, out_err = p.communicate(answer)
def advertiseSsrcs(self): jingle = jingle % {
'tojid': self.callfrom,
'action': 'session-accept',
'initiator': self.callfrom,
'responder': self.jid,
'sid': self.callsid
}
print "answer jingle from sdp",jingle
res = self.sendIq(jingle)
print "reply from answer: ",res
self.ssrcs = {}
jingleSoup = BeautifulSoup(jingle)
for cont in jingleSoup.iq.jingle.findAll('content'):
if cont.description:
self.ssrcs[cont['name']] = cont.description['ssrc']
print "my ssrcs:",self.ssrcs
gevent.joinall([
gevent.spawn(self.advertiseSsrcs)
])
def advertiseSsrcs(self):
time.sleep(7) time.sleep(7)
print "SSRC spammer started" print "SSRC spammer started"
while self.running: while self.running:
ssrcMsg = "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" % { 'tojid': "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), 'nick': self.userId, 'assrc': self.ssrcs['audio'], 'vssrc': self.ssrcs['video'] } ssrcMsg = "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" % { 'tojid': "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), 'nick': self.userId, 'assrc': self.ssrcs['audio'], 'vssrc': self.ssrcs['video'] }
res = self.sendIq(ssrcMsg) res = self.sendIq(ssrcMsg)
print "reply from ssrc announce: ",res print "reply from ssrc announce: ",res
time.sleep(10) time.sleep(10)
def xmppLoop(self):
self.matrixCallId = time.time()
res = self.xmppPoke("<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' to='%s' xml:lang='en' wait='60' hold='1' content='text/xml; charset=utf-8' ver='1.6' xmpp:version='1.0' xmlns:xmpp='urn:xmpp:xbosh'/>" % (self.nextRid(), HOST))
print res
self.sid = res.body['sid']
print "sid %s" % (self.sid)
res = self.sendIq("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='ANONYMOUS'/>") def xmppLoop(self):
self.matrixCallId = time.time()
res = self.xmppPoke("<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' to='%s' xml:lang='en' wait='60' hold='1' content='text/xml; charset=utf-8' ver='1.6' xmpp:version='1.0' xmlns:xmpp='urn:xmpp:xbosh'/>" % (self.nextRid(), HOST))
res = self.xmppPoke("<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s' to='%s' xml:lang='en' xmpp:restart='true' xmlns:xmpp='urn:xmpp:xbosh'/>" % (self.nextRid(), self.sid, HOST)) print res
self.sid = res.body['sid']
res = self.sendIq("<iq type='set' id='_bind_auth_2' xmlns='jabber:client'><bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/></iq>") print "sid %s" % (self.sid)
print res
self.jid = res.body.iq.bind.jid.string res = self.sendIq("<auth xmlns='urn:ietf:params:xml:ns:xmpp-sasl' mechanism='ANONYMOUS'/>")
print "jid: %s" % (self.jid)
self.shortJid = self.jid.split('-')[0]
res = self.sendIq("<iq type='set' id='_session_auth_2' xmlns='jabber:client'><session xmlns='urn:ietf:params:xml:ns:xmpp-session'/></iq>") res = self.xmppPoke("<body rid='%s' xmlns='http://jabber.org/protocol/httpbind' sid='%s' to='%s' xml:lang='en' xmpp:restart='true' xmlns:xmpp='urn:xmpp:xbosh'/>" % (self.nextRid(), self.sid, HOST))
#randomthing = res.body.iq['to'] res = self.sendIq("<iq type='set' id='_bind_auth_2' xmlns='jabber:client'><bind xmlns='urn:ietf:params:xml:ns:xmpp-bind'/></iq>")
#whatsitpart = randomthing.split('-')[0] print res
#print "other random bind thing: %s" % (randomthing) self.jid = res.body.iq.bind.jid.string
print "jid: %s" % (self.jid)
self.shortJid = self.jid.split('-')[0]
# advertise preence to the jitsi room, with our nick res = self.sendIq("<iq type='set' id='_session_auth_2' xmlns='jabber:client'><session xmlns='urn:ietf:params:xml:ns:xmpp-session'/></iq>")
res = self.sendIq("<iq type='get' to='%s' xmlns='jabber:client' id='1:sendIQ'><services xmlns='urn:xmpp:extdisco:1'><service host='%s'/></services></iq><presence to='%s@%s/d98f6c40' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%s</nick></presence>" % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId))
self.muc = {'users': []}
for p in res.body.findAll('presence'):
u = {}
u['shortJid'] = p['from'].split('/')[1]
if p.c and p.c.nick:
u['nick'] = p.c.nick.string
self.muc['users'].append(u)
print "muc: ",self.muc
# wait for stuff #randomthing = res.body.iq['to']
while True: #whatsitpart = randomthing.split('-')[0]
print "waiting..."
res = self.sendIq("") #print "other random bind thing: %s" % (randomthing)
print "got from stream: ",res
if res.body.iq: # advertise preence to the jitsi room, with our nick
jingles = res.body.iq.findAll('jingle') res = self.sendIq("<iq type='get' to='%s' xmlns='jabber:client' id='1:sendIQ'><services xmlns='urn:xmpp:extdisco:1'><service host='%s'/></services></iq><presence to='%s@%s/d98f6c40' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%s</nick></presence>" % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId))
if len(jingles): self.muc = {'users': []}
self.callfrom = res.body.iq['from'] for p in res.body.findAll('presence'):
self.handleInvite(jingles[0]) u = {}
elif 'type' in res.body and res.body['type'] == 'terminate': u['shortJid'] = p['from'].split('/')[1]
self.running = False if p.c and p.c.nick:
del xmppClients[self.matrixRoom] u['nick'] = p.c.nick.string
return self.muc['users'].append(u)
print "muc: ",self.muc
# wait for stuff
while True:
print "waiting..."
res = self.sendIq("")
print "got from stream: ",res
if res.body.iq:
jingles = res.body.iq.findAll('jingle')
if len(jingles):
self.callfrom = res.body.iq['from']
self.handleInvite(jingles[0])
elif 'type' in res.body and res.body['type'] == 'terminate':
self.running = False
del xmppClients[self.matrixRoom]
return
def handleInvite(self, jingle):
self.initiator = jingle['initiator']
self.callsid = jingle['sid']
p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--jingle'], stdin=subprocess.PIPE, stdout=subprocess.PIPE)
print "raw jingle invite",str(jingle)
sdp, out_err = p.communicate(str(jingle))
print "transformed remote offer sdp",sdp
inviteEvent = {
'offer': {
'type': 'offer',
'sdp': sdp
},
'call_id': self.matrixCallId,
'version': 0,
'lifetime': 30000
}
matrixCli.sendEvent(self.matrixRoom, 'm.call.invite', inviteEvent)
def handleInvite(self, jingle):
self.initiator = jingle['initiator']
self.callsid = jingle['sid']
p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--jingle'], stdin=subprocess.PIPE, stdout=subprocess.PIPE)
print "raw jingle invite",str(jingle)
sdp, out_err = p.communicate(str(jingle))
print "transformed remote offer sdp",sdp
inviteEvent = {
'offer': {
'type': 'offer',
'sdp': sdp
},
'call_id': self.matrixCallId,
'version': 0,
'lifetime': 30000
}
matrixCli.sendEvent(self.matrixRoom, 'm.call.invite', inviteEvent)
matrixCli = TrivialMatrixClient(ACCESS_TOKEN) matrixCli = TrivialMatrixClient(ACCESS_TOKEN)
gevent.joinall([ gevent.joinall([
gevent.spawn(matrixLoop) gevent.spawn(matrixLoop)
]) ])

489
contrib/vertobot/bridge.pl Executable file
View File

@ -0,0 +1,489 @@
#!/usr/bin/env perl
use strict;
use warnings;
use 5.010; # //
use IO::Socket::SSL qw(SSL_VERIFY_NONE);
use IO::Async::Loop;
use Net::Async::WebSocket::Client;
use Net::Async::HTTP;
use Net::Async::HTTP::Server;
use JSON;
use YAML;
use Data::UUID;
use Getopt::Long;
use Data::Dumper;
use URI::Encode qw(uri_encode uri_decode);
binmode STDOUT, ":encoding(UTF-8)";
binmode STDERR, ":encoding(UTF-8)";
my $msisdn_to_matrix = {
'447417892400' => '@matthew:matrix.org',
};
my $matrix_to_msisdn = {};
foreach (keys %$msisdn_to_matrix) {
$matrix_to_msisdn->{$msisdn_to_matrix->{$_}} = $_;
}
my $loop = IO::Async::Loop->new;
# Net::Async::HTTP + SSL + IO::Poll doesn't play well. See
# https://rt.cpan.org/Ticket/Display.html?id=93107
# ref $loop eq "IO::Async::Loop::Poll" and
# warn "Using SSL with IO::Poll causes known memory-leaks!!\n";
GetOptions(
'C|config=s' => \my $CONFIG,
'eval-from=s' => \my $EVAL_FROM,
) or exit 1;
if( defined $EVAL_FROM ) {
# An emergency 'eval() this file' hack
$SIG{HUP} = sub {
my $code = do {
open my $fh, "<", $EVAL_FROM or warn( "Cannot read - $!" ), return;
local $/; <$fh>
};
eval $code or warn "Cannot eval() - $@";
};
}
defined $CONFIG or die "Must supply --config\n";
my %CONFIG = %{ YAML::LoadFile( $CONFIG ) };
my %MATRIX_CONFIG = %{ $CONFIG{matrix} };
# No harm in always applying this
$MATRIX_CONFIG{SSL_verify_mode} = SSL_VERIFY_NONE;
my $bridgestate = {};
my $roomid_by_callid = {};
my $sessid = lc new Data::UUID->create_str();
my $as_token = $CONFIG{"matrix-bot"}->{as_token};
my $hs_domain = $CONFIG{"matrix-bot"}->{domain};
my $http = Net::Async::HTTP->new();
$loop->add( $http );
sub create_virtual_user
{
my ($localpart) = @_;
my ( $response ) = $http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/register?".
"access_token=$as_token&user_id=$localpart"
),
content_type => "application/json",
content => <<EOT
{
"type": "m.login.application_service",
"user": "$localpart"
}
EOT
)->get;
warn $response->as_string if ($response->code != 200);
}
my $http_server = Net::Async::HTTP::Server->new(
on_request => sub {
my $self = shift;
my ( $req ) = @_;
my $response;
my $path = uri_decode($req->path);
warn("request: $path");
if ($path =~ m#/users/\@(\+.*)#) {
# when queried about virtual users, auto-create them in the HS
my $localpart = $1;
create_virtual_user($localpart);
$response = HTTP::Response->new( 200 );
$response->add_content('{}');
$response->content_type( "application/json" );
}
elsif ($path =~ m#/transactions/(.*)#) {
my $event = JSON->new->decode($req->body);
print Dumper($event);
my $room_id = $event->{room_id};
my %dp = %{$CONFIG{'verto-dialog-params'}};
$dp{callID} = $bridgestate->{$room_id}->{callid};
if ($event->{type} eq 'm.room.membership') {
my $membership = $event->{content}->{membership};
my $state_key = $event->{state_key};
my $room_id = $event->{state_id};
if ($membership eq 'invite') {
# autojoin invites
my ( $response ) = $http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/rooms/$room_id/join?".
"access_token=$as_token&user_id=$state_key"
),
content_type => "application/json",
content => "{}",
)->get;
warn $response->as_string if ($response->code != 200);
}
}
elsif ($event->{type} eq 'm.call.invite') {
my $room_id = $event->{room_id};
$bridgestate->{$room_id}->{matrix_callid} = $event->{content}->{call_id};
$bridgestate->{$room_id}->{callid} = lc new Data::UUID->create_str();
$bridgestate->{$room_id}->{sessid} = $sessid;
# $bridgestate->{$room_id}->{offer} = $event->{content}->{offer}->{sdp};
my $offer = $event->{content}->{offer}->{sdp};
# $bridgestate->{$room_id}->{gathered_candidates} = 0;
$roomid_by_callid->{ $bridgestate->{$room_id}->{callid} } = $room_id;
# no trickle ICE in verto apparently
my $f = send_verto_json_request("verto.invite", {
"sdp" => $offer,
"dialogParams" => \%dp,
"sessid" => $bridgestate->{$room_id}->{sessid},
});
$self->adopt_future($f);
}
# elsif ($event->{type} eq 'm.call.candidates') {
# # XXX: this could fire for both matrix->verto and verto->matrix calls
# # and races as it collects candidates. much better to just turn off
# # candidate gathering in the webclient entirely for now
#
# my $room_id = $event->{room_id};
# # XXX: compare call IDs
# if (!$bridgestate->{$room_id}->{gathered_candidates}) {
# $bridgestate->{$room_id}->{gathered_candidates} = 1;
# my $offer = $bridgestate->{$room_id}->{offer};
# my $candidate_block = "";
# foreach (@{$event->{content}->{candidates}}) {
# $candidate_block .= "a=" . $_->{candidate} . "\r\n";
# }
# # XXX: collate using the right m= line - for now assume audio call
# $offer =~ s/(a=rtcp.*[\r\n]+)/$1$candidate_block/;
#
# my $f = send_verto_json_request("verto.invite", {
# "sdp" => $offer,
# "dialogParams" => \%dp,
# "sessid" => $bridgestate->{$room_id}->{sessid},
# });
# $self->adopt_future($f);
# }
# else {
# # ignore them, as no trickle ICE, although we might as well
# # batch them up
# # foreach (@{$event->{content}->{candidates}}) {
# # push @{$bridgestate->{$room_id}->{candidates}}, $_;
# # }
# }
# }
elsif ($event->{type} eq 'm.call.answer') {
# grab the answer and relay it to verto as a verto.answer
my $room_id = $event->{room_id};
my $answer = $event->{content}->{answer}->{sdp};
my $f = send_verto_json_request("verto.answer", {
"sdp" => $answer,
"dialogParams" => \%dp,
"sessid" => $bridgestate->{$room_id}->{sessid},
});
$self->adopt_future($f);
}
elsif ($event->{type} eq 'm.call.hangup') {
my $room_id = $event->{room_id};
if ($bridgestate->{$room_id}->{matrix_callid} eq $event->{content}->{call_id}) {
my $f = send_verto_json_request("verto.bye", {
"dialogParams" => \%dp,
"sessid" => $bridgestate->{$room_id}->{sessid},
});
$self->adopt_future($f);
}
else {
warn "Ignoring unrecognised callid: ".$event->{content}->{call_id};
}
}
else {
warn "Unhandled event: $event->{type}";
}
$response = HTTP::Response->new( 200 );
$response->add_content('{}');
$response->content_type( "application/json" );
}
else {
warn "Unhandled path: $path";
$response = HTTP::Response->new( 404 );
}
$req->respond( $response );
},
);
$loop->add( $http_server );
$http_server->listen(
addr => { family => "inet", socktype => "stream", port => 8009 },
on_listen_error => sub { die "Cannot listen - $_[-1]\n" },
);
my $bot_verto = Net::Async::WebSocket::Client->new(
on_frame => sub {
my ( $self, $frame ) = @_;
warn "[Verto] receiving $frame";
on_verto_json($frame);
},
);
$loop->add( $bot_verto );
my $verto_connecting = $loop->new_future;
$bot_verto->connect(
%{ $CONFIG{"verto-bot"} },
on_connected => sub {
warn("[Verto] connected to websocket");
if (not $verto_connecting->is_done) {
$verto_connecting->done($bot_verto);
send_verto_json_request("login", {
'login' => $CONFIG{'verto-dialog-params'}{'login'},
'passwd' => $CONFIG{'verto-config'}{'passwd'},
'sessid' => $sessid,
});
}
},
on_connect_error => sub { die "Cannot connect to verto - $_[-1]" },
on_resolve_error => sub { die "Cannot resolve to verto - $_[-1]" },
);
# die Dumper($verto_connecting);
my $as_url = $CONFIG{"matrix-bot"}->{as_url};
Future->needs_all(
$http->do_request(
method => "POST",
uri => URI->new( $CONFIG{"matrix"}->{server}."/_matrix/appservice/v1/register" ),
content_type => "application/json",
content => <<EOT
{
"as_token": "$as_token",
"url": "$as_url",
"namespaces": { "users": ["\@\\\\+.*"] }
}
EOT
),
$verto_connecting,
)->get;
$loop->attach_signal(
PIPE => sub { warn "pipe\n" }
);
$loop->attach_signal(
INT => sub { $loop->stop },
);
$loop->attach_signal(
TERM => sub { $loop->stop },
);
eval {
$loop->run;
} or my $e = $@;
die $e if $e;
exit 0;
{
my $json_id;
my $requests;
sub send_verto_json_request
{
$json_id ||= 1;
my ($method, $params) = @_;
my $json = {
jsonrpc => "2.0",
method => $method,
params => $params,
id => $json_id,
};
my $text = JSON->new->encode( $json );
warn "[Verto] sending $text";
$bot_verto->send_frame ( $text );
my $request = $loop->new_future;
$requests->{$json_id} = $request;
$json_id++;
return $request;
}
sub send_verto_json_response
{
my ($result, $id) = @_;
my $json = {
jsonrpc => "2.0",
result => $result,
id => $id,
};
my $text = JSON->new->encode( $json );
warn "[Verto] sending $text";
$bot_verto->send_frame ( $text );
}
sub on_verto_json
{
my $json = JSON->new->decode( $_[0] );
if ($json->{method}) {
if (($json->{method} eq 'verto.answer' && $json->{params}->{sdp}) ||
$json->{method} eq 'verto.media') {
my $caller = $json->{dialogParams}->{caller_id_number};
my $callee = $json->{dialogParams}->{destination_number};
my $caller_user = '@+' . $caller . ':' . $hs_domain;
my $callee_user = $msisdn_to_matrix->{$callee} || warn "unrecogised callee: $callee";
my $room_id = $roomid_by_callid->{$json->{params}->{callID}};
if ($json->{params}->{sdp}) {
$http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/send/m.call.answer?".
"access_token=$as_token&user_id=$caller_user"
),
content_type => "application/json",
content => JSON->new->encode({
call_id => $bridgestate->{$room_id}->{matrix_callid},
version => 0,
answer => {
sdp => $json->{params}->{sdp},
type => "answer",
},
}),
)->then( sub {
send_verto_json_response( {
method => $json->{method},
}, $json->{id});
})->get;
}
}
elsif ($json->{method} eq 'verto.invite') {
my $caller = $json->{dialogParams}->{caller_id_number};
my $callee = $json->{dialogParams}->{destination_number};
my $caller_user = '@+' . $caller . ':' . $hs_domain;
my $callee_user = $msisdn_to_matrix->{$callee} || warn "unrecogised callee: $callee";
my $alias = ($caller lt $callee) ? ($caller.'-'.$callee) : ($callee.'-'.$caller);
my $room_id;
# create a virtual user for the caller if needed.
create_virtual_user($caller);
# create a room of form #peer-peer and invite the callee
$http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/createRoom?".
"access_token=$as_token&user_id=$caller_user"
),
content_type => "application/json",
content => JSON->new->encode({
room_alias_name => $alias,
invite => [ $callee_user ],
}),
)->then( sub {
my ( $response ) = @_;
my $resp = JSON->new->decode($response->content);
$room_id = $resp->{room_id};
$roomid_by_callid->{$json->{params}->{callID}} = $room_id;
})->get;
# join it
my ($response) = $http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/join/$room_id?".
"access_token=$as_token&user_id=$caller_user"
),
content_type => "application/json",
content => '{}',
)->get;
$bridgestate->{$room_id}->{matrix_callid} = lc new Data::UUID->create_str();
$bridgestate->{$room_id}->{callid} = $json->{dialogParams}->{callID};
$bridgestate->{$room_id}->{sessid} = $sessid;
# put the m.call.invite in there
$http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/send/m.call.invite?".
"access_token=$as_token&user_id=$caller_user"
),
content_type => "application/json",
content => JSON->new->encode({
call_id => $bridgestate->{$room_id}->{matrix_callid},
version => 0,
answer => {
sdp => $json->{params}->{sdp},
type => "offer",
},
}),
)->then( sub {
# acknowledge the verto
send_verto_json_response( {
method => $json->{method},
}, $json->{id});
})->get;
}
elsif ($json->{method} eq 'verto.bye') {
my $caller = $json->{dialogParams}->{caller_id_number};
my $callee = $json->{dialogParams}->{destination_number};
my $caller_user = '@+' . $caller . ':' . $hs_domain;
my $callee_user = $msisdn_to_matrix->{$callee} || warn "unrecogised callee: $callee";
my $room_id = $roomid_by_callid->{$json->{params}->{callID}};
# put the m.call.hangup into the room
$http->do_request(
method => "POST",
uri => URI->new(
$CONFIG{"matrix"}->{server}.
"/_matrix/client/api/v1/send/m.call.hangup?".
"access_token=$as_token&user_id=$caller_user"
),
content_type => "application/json",
content => JSON->new->encode({
call_id => $bridgestate->{$room_id}->{matrix_callid},
version => 0,
}),
)->then( sub {
# acknowledge the verto
send_verto_json_response( {
method => $json->{method},
}, $json->{id});
})->get;
}
else {
warn ("[Verto] unhandled method: " . $json->{method});
send_verto_json_response( {
method => $json->{method},
}, $json->{id});
}
}
elsif ($json->{result}) {
$requests->{$json->{id}}->done($json->{result});
}
elsif ($json->{error}) {
$requests->{$json->{id}}->fail($json->{error}->{message}, $json->{error});
}
}
}

View File

@ -32,7 +32,8 @@ for port in 8080 8081 8082; do
-D --pid-file "$DIR/$port.pid" \ -D --pid-file "$DIR/$port.pid" \
--manhole $((port + 1000)) \ --manhole $((port + 1000)) \
--tls-dh-params-path "demo/demo.tls.dh" \ --tls-dh-params-path "demo/demo.tls.dh" \
$PARAMS $SYNAPSE_PARAMS --media-store-path "demo/media_store.$port" \
$PARAMS $SYNAPSE_PARAMS \
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--config-path "demo/etc/$port.config" \ --config-path "demo/etc/$port.config" \

View File

@ -81,7 +81,7 @@ Your home server configuration file needs the following extra keys:
As an example, here is the relevant section of the config file for As an example, here is the relevant section of the config file for
matrix.org:: matrix.org::
turn_uris: turn:turn.matrix.org:3478?transport=udp,turn:turn.matrix.org:3478?transport=tcp turn_uris: [ "turn:turn.matrix.org:3478?transport=udp", "turn:turn.matrix.org:3478?transport=tcp" ]
turn_shared_secret: n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons turn_shared_secret: n0t4ctuAllymatr1Xd0TorgSshar3d5ecret4obvIousreAsons
turn_user_lifetime: 86400000 turn_user_lifetime: 86400000

65
scripts/check_auth.py Normal file
View File

@ -0,0 +1,65 @@
from synapse.events import FrozenEvent
from synapse.api.auth import Auth
from mock import Mock
import argparse
import itertools
import json
import sys
def check_auth(auth, auth_chain, events):
auth_chain.sort(key=lambda e: e.depth)
auth_map = {
e.event_id: e
for e in auth_chain
}
create_events = {}
for e in auth_chain:
if e.type == "m.room.create":
create_events[e.room_id] = e
for e in itertools.chain(auth_chain, events):
auth_events_list = [auth_map[i] for i, _ in e.auth_events]
auth_events = {
(e.type, e.state_key): e
for e in auth_events_list
}
auth_events[("m.room.create", "")] = create_events[e.room_id]
try:
auth.check(e, auth_events=auth_events)
except Exception as ex:
print "Failed:", e.event_id, e.type, e.state_key
print "Auth_events:", auth_events
print ex
print json.dumps(e.get_dict(), sort_keys=True, indent=4)
# raise
print "Success:", e.event_id, e.type, e.state_key
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'json',
nargs='?',
type=argparse.FileType('r'),
default=sys.stdin,
)
args = parser.parse_args()
js = json.load(args.json)
auth = Auth(Mock())
check_auth(
auth,
[FrozenEvent(d) for d in js["auth_chain"]],
[FrozenEvent(d) for d in js["pdus"]],
)

View File

@ -97,8 +97,11 @@ def lookup(destination, path):
if ":" in destination: if ":" in destination:
return "https://%s%s" % (destination, path) return "https://%s%s" % (destination, path)
else: else:
srv = srvlookup.lookup("matrix", "tcp", destination)[0] try:
return "https://%s:%d%s" % (srv.host, srv.port, path) srv = srvlookup.lookup("matrix", "tcp", destination)[0]
return "https://%s:%d%s" % (srv.host, srv.port, path)
except:
return "https://%s:%d%s" % (destination, 8448, path)
def get_json(origin_name, origin_key, destination, path): def get_json(origin_name, origin_key, destination, path):
request_json = { request_json = {

39
scripts/make_identicons.pl Executable file
View File

@ -0,0 +1,39 @@
#!/usr/bin/env perl
use strict;
use warnings;
use DBI;
use DBD::SQLite;
use JSON;
use Getopt::Long;
my $db; # = "homeserver.db";
my $server = "http://localhost:8008";
my $size = 320;
GetOptions("db|d=s", \$db,
"server|s=s", \$server,
"width|w=i", \$size) or usage();
usage() unless $db;
my $dbh = DBI->connect("dbi:SQLite:dbname=$db","","") || die $DBI::errstr;
my $res = $dbh->selectall_arrayref("select token, name from access_tokens, users where access_tokens.user_id = users.id group by user_id") || die $DBI::errstr;
foreach (@$res) {
my ($token, $mxid) = ($_->[0], $_->[1]);
my ($user_id) = ($mxid =~ m/@(.*):/);
my ($url) = $dbh->selectrow_array("select avatar_url from profiles where user_id=?", undef, $user_id);
if (!$url || $url =~ /#auto$/) {
`curl -s -o tmp.png "$server/_matrix/media/v1/identicon?name=${mxid}&width=$size&height=$size"`;
my $json = `curl -s -X POST -H "Content-Type: image/png" -T "tmp.png" $server/_matrix/media/v1/upload?access_token=$token`;
my $content_uri = from_json($json)->{content_uri};
`curl -X PUT -H "Content-Type: application/json" --data '{ "avatar_url": "${content_uri}#auto"}' $server/_matrix/client/api/v1/profile/${mxid}/avatar_url?access_token=$token`;
}
}
sub usage {
die "usage: ./make-identicons.pl\n\t-d database [e.g. homeserver.db]\n\t-s homeserver (default: http://localhost:8008)\n\t-w identicon size in pixels (default 320)";
}

View File

@ -8,3 +8,11 @@ test = trial
[trial] [trial]
test_suite = tests test_suite = tests
[check-manifest]
ignore =
contrib
contrib/*
docs/*
pylint.cfg
tox.ini

View File

@ -18,49 +18,42 @@ import os
from setuptools import setup, find_packages from setuptools import setup, find_packages
# Utility function to read the README file. here = os.path.abspath(os.path.dirname(__file__))
# Used for the long_description. It's nice, because now 1) we have a top level
# README file and 2) it's easier to type in the README file than to put a raw
# string in below ... def read_file(path_segments):
def read(fname): """Read a file from the package. Takes a list of strings to join to
return open(os.path.join(os.path.dirname(__file__), fname)).read() make the path"""
file_path = os.path.join(here, *path_segments)
with open(file_path) as f:
return f.read()
def exec_file(path_segments):
"""Execute a single python file to get the variables defined in it"""
result = {}
code = read_file(path_segments)
exec(code, result)
return result
version = exec_file(("synapse", "__init__.py"))["__version__"]
dependencies = exec_file(("synapse", "python_dependencies.py"))
long_description = read_file(("README.rst",))
setup( setup(
name="matrix-synapse", name="matrix-synapse",
version=read("VERSION").strip(), version=version,
packages=find_packages(exclude=["tests", "tests.*"]), packages=find_packages(exclude=["tests", "tests.*"]),
description="Reference Synapse Home Server", description="Reference Synapse Home Server",
install_requires=[ install_requires=dependencies["REQUIREMENTS"].keys(),
"syutil==0.0.2",
"matrix_angular_sdk==0.6.0",
"Twisted>=14.0.0",
"service_identity>=1.0.0",
"pyopenssl>=0.14",
"pyyaml",
"pyasn1",
"pynacl",
"daemonize",
"py-bcrypt",
"frozendict>=0.4",
"pillow",
],
dependency_links=[
"https://github.com/matrix-org/syutil/tarball/v0.0.2#egg=syutil-0.0.2",
"https://github.com/pyca/pynacl/tarball/d4d3175589b892f6ea7c22f466e0e223853516fa#egg=pynacl-0.3.0",
"https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.6.0/#egg=matrix_angular_sdk-0.6.0",
],
setup_requires=[ setup_requires=[
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"setuptools_trial", "setuptools_trial",
"setuptools>=1.0.0", # Needs setuptools that supports git+ssh.
# TODO: Do we need this now? we don't use git+ssh.
"mock" "mock"
], ],
dependency_links=dependencies["DEPENDENCY_LINKS"],
include_package_data=True, include_package_data=True,
zip_safe=False, zip_safe=False,
long_description=read("README.rst"), long_description=long_description,
entry_points=""" scripts=["synctl"],
[console_scripts]
synctl=synapse.app.synctl:main
synapse-homeserver=synapse.app.homeserver:main
"""
) )

View File

@ -0,0 +1,32 @@
<html>
<head>
<title> Registration </title>
<meta name='viewport' content='width=device-width, initial-scale=1, user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
<link rel="stylesheet" href="style.css">
<script src="js/jquery-2.1.3.min.js"></script>
<script src="js/recaptcha_ajax.js"></script>
<script src="register_config.js"></script>
<script src="js/register.js"></script>
</head>
<body onload="matrixRegistration.onLoad()">
<form id="registrationForm" onsubmit="matrixRegistration.signUp(); return false;">
<div>
Create account:<br/>
<div style="text-align: center">
<input id="desired_user_id" size="32" type="text" placeholder="Matrix ID (e.g. bob)" autocapitalize="off" autocorrect="off" />
<br/>
<input id="pwd1" size="32" type="password" placeholder="Type a password"/>
<br/>
<input id="pwd2" size="32" type="password" placeholder="Confirm your password"/>
<br/>
<span id="feedback" style="color: #f00"></span>
<br/>
<div id="regcaptcha"></div>
<button type="submit" style="margin: 10px">Sign up</button>
</div>
</div>
</form>
</body>
</html>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,117 @@
window.matrixRegistration = {
endpoint: location.origin + "/_matrix/client/api/v1/register"
};
var setupCaptcha = function() {
if (!window.matrixRegistrationConfig) {
return;
}
$.get(matrixRegistration.endpoint, function(response) {
var serverExpectsCaptcha = false;
for (var i=0; i<response.flows.length; i++) {
var flow = response.flows[i];
if ("m.login.recaptcha" === flow.type) {
serverExpectsCaptcha = true;
break;
}
}
if (!serverExpectsCaptcha) {
console.log("This server does not require a captcha.");
return;
}
console.log("Setting up ReCaptcha for "+matrixRegistration.endpoint);
var public_key = window.matrixRegistrationConfig.recaptcha_public_key;
if (public_key === undefined) {
console.error("No public key defined for captcha!");
setFeedbackString("Misconfigured captcha for server. Contact server admin.");
return;
}
Recaptcha.create(public_key,
"regcaptcha",
{
theme: "red",
callback: Recaptcha.focus_response_field
});
window.matrixRegistration.isUsingRecaptcha = true;
}).error(errorFunc);
};
var submitCaptcha = function(user, pwd) {
var challengeToken = Recaptcha.get_challenge();
var captchaEntry = Recaptcha.get_response();
var data = {
type: "m.login.recaptcha",
challenge: challengeToken,
response: captchaEntry
};
console.log("Submitting captcha");
$.post(matrixRegistration.endpoint, JSON.stringify(data), function(response) {
console.log("Success -> "+JSON.stringify(response));
submitPassword(user, pwd, response.session);
}).error(function(err) {
Recaptcha.reload();
errorFunc(err);
});
};
var submitPassword = function(user, pwd, session) {
console.log("Registering...");
var data = {
type: "m.login.password",
user: user,
password: pwd,
session: session
};
$.post(matrixRegistration.endpoint, JSON.stringify(data), function(response) {
matrixRegistration.onRegistered(
response.home_server, response.user_id, response.access_token
);
}).error(errorFunc);
};
var errorFunc = function(err) {
if (err.responseJSON && err.responseJSON.error) {
setFeedbackString(err.responseJSON.error + " (" + err.responseJSON.errcode + ")");
}
else {
setFeedbackString("Request failed: " + err.status);
}
};
var setFeedbackString = function(text) {
$("#feedback").text(text);
};
matrixRegistration.onLoad = function() {
setupCaptcha();
};
matrixRegistration.signUp = function() {
var user = $("#desired_user_id").val();
if (user.length == 0) {
setFeedbackString("Must specify a username.");
return;
}
var pwd1 = $("#pwd1").val();
var pwd2 = $("#pwd2").val();
if (pwd1.length < 6) {
setFeedbackString("Password: min. 6 characters.");
return;
}
if (pwd1 != pwd2) {
setFeedbackString("Passwords do not match.");
return;
}
if (window.matrixRegistration.isUsingRecaptcha) {
submitCaptcha(user, pwd1);
}
else {
submitPassword(user, pwd1);
}
};
matrixRegistration.onRegistered = function(hs_url, user_id, access_token) {
// clobber this function
console.log("onRegistered - This function should be replaced to proceed.");
};

View File

@ -0,0 +1,3 @@
window.matrixRegistrationConfig = {
recaptcha_public_key: "YOUR_PUBLIC_KEY"
};

View File

@ -0,0 +1,56 @@
html {
height: 100%;
}
body {
height: 100%;
font-family: "Myriad Pro", "Myriad", Helvetica, Arial, sans-serif;
font-size: 12pt;
margin: 0px;
}
h1 {
font-size: 20pt;
}
a:link { color: #666; }
a:visited { color: #666; }
a:hover { color: #000; }
a:active { color: #000; }
input {
width: 100%
}
textarea, input {
font-family: inherit;
font-size: inherit;
}
.smallPrint {
color: #888;
font-size: 9pt ! important;
font-style: italic ! important;
}
#recaptcha_area {
margin: auto
}
#registrationForm {
text-align: left;
padding: 1em;
margin-bottom: 40px;
display: inline-block;
-webkit-border-radius: 10px;
-moz-border-radius: 10px;
border-radius: 10px;
-webkit-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
-moz-box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
box-shadow: 0px 0px 20px 0px rgba(0,0,0,0.15);
background-color: #f8f8f8;
border: 1px #ccc solid;
}

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.6.1b" __version__ = "0.8.0"

View File

@ -21,6 +21,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.types import UserID, ClientInfo
import logging import logging
@ -88,12 +89,19 @@ class Auth(object):
raise raise
@defer.inlineCallbacks @defer.inlineCallbacks
def check_joined_room(self, room_id, user_id): def check_joined_room(self, room_id, user_id, current_state=None):
member = yield self.state.get_current_state( if current_state:
room_id=room_id, member = current_state.get(
event_type=EventTypes.Member, (EventTypes.Member, user_id),
state_key=user_id None
) )
else:
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
self._check_joined_room(member, user_id, room_id) self._check_joined_room(member, user_id, room_id)
defer.returnValue(member) defer.returnValue(member)
@ -101,10 +109,10 @@ class Auth(object):
def check_host_in_room(self, room_id, host): def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id) curr_state = yield self.state.get_current_state(room_id)
for event in curr_state: for event in curr_state.values():
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
try: try:
if self.hs.parse_userid(event.state_key).domain != host: if UserID.from_string(event.state_key).domain != host:
continue continue
except: except:
logger.warn("state_key not user_id: %s", event.state_key) logger.warn("state_key not user_id: %s", event.state_key)
@ -289,15 +297,47 @@ class Auth(object):
Args: Args:
request - An HTTP request with an access_token query parameter. request - An HTTP request with an access_token query parameter.
Returns: Returns:
UserID : User ID object of the user making the request tuple : of UserID and device string:
User ID object of the user making the request
Client ID object of the client instance the user is using
Raises: Raises:
AuthError if no user by that token exists or the token is invalid. AuthError if no user by that token exists or the token is invalid.
""" """
# Can optionally look elsewhere in the request (e.g. headers) # Can optionally look elsewhere in the request (e.g. headers)
try: try:
access_token = request.args["access_token"][0] access_token = request.args["access_token"][0]
# Check for application service tokens with a user_id override
try:
app_service = yield self.store.get_app_service_by_token(
access_token
)
if not app_service:
raise KeyError
user_id = app_service.sender
if "user_id" in request.args:
user_id = request.args["user_id"][0]
if not app_service.is_interested_in_user(user_id):
raise AuthError(
403,
"Application service cannot masquerade as this user."
)
if not user_id:
raise KeyError
defer.returnValue(
(UserID.from_string(user_id), ClientInfo("", ""))
)
return
except KeyError:
pass # normal users won't have this query parameter set
user_info = yield self.get_user_by_token(access_token) user_info = yield self.get_user_by_token(access_token)
user = user_info["user"] user = user_info["user"]
device_id = user_info["device_id"]
token_id = user_info["token_id"]
ip_addr = self.hs.get_ip_from_request(request) ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders( user_agent = request.requestHeaders.getRawHeaders(
@ -313,7 +353,7 @@ class Auth(object):
user_agent=user_agent user_agent=user_agent
) )
defer.returnValue(user) defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError: except KeyError:
raise AuthError(403, "Missing access token.") raise AuthError(403, "Missing access token.")
@ -332,12 +372,12 @@ class Auth(object):
try: try:
ret = yield self.store.get_user_by_token(token=token) ret = yield self.store.get_user_by_token(token=token)
if not ret: if not ret:
raise StoreError() raise StoreError(400, "Unknown token")
user_info = { user_info = {
"admin": bool(ret.get("admin", False)), "admin": bool(ret.get("admin", False)),
"device_id": ret.get("device_id"), "device_id": ret.get("device_id"),
"user": self.hs.parse_userid(ret.get("name")), "user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
} }
defer.returnValue(user_info) defer.returnValue(user_info)
@ -345,6 +385,18 @@ class Auth(object):
raise AuthError(403, "Unrecognised access token.", raise AuthError(403, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN) errcode=Codes.UNKNOWN_TOKEN)
@defer.inlineCallbacks
def get_appservice_by_req(self, request):
try:
token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token)
if not service:
raise AuthError(403, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN)
defer.returnValue(service)
except KeyError:
raise AuthError(403, "Missing access token.")
def is_server_admin(self, user): def is_server_admin(self, user):
return self.store.is_server_admin(user) return self.store.is_server_admin(user)
@ -352,26 +404,40 @@ class Auth(object):
def add_auth_events(self, builder, context): def add_auth_events(self, builder, context):
yield run_on_reactor() yield run_on_reactor()
if builder.type == EventTypes.Create: auth_ids = self.compute_auth_events(builder, context.current_state)
builder.auth_events = []
return auth_events_entries = yield self.store.add_event_hashes(
auth_ids
)
builder.auth_events = auth_events_entries
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
def compute_auth_events(self, event, current_state):
if event.type == EventTypes.Create:
return []
auth_ids = [] auth_ids = []
key = (EventTypes.PowerLevels, "", ) key = (EventTypes.PowerLevels, "", )
power_level_event = context.current_state.get(key) power_level_event = current_state.get(key)
if power_level_event: if power_level_event:
auth_ids.append(power_level_event.event_id) auth_ids.append(power_level_event.event_id)
key = (EventTypes.JoinRules, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event = context.current_state.get(key) join_rule_event = current_state.get(key)
key = (EventTypes.Member, builder.user_id, ) key = (EventTypes.Member, event.user_id, )
member_event = context.current_state.get(key) member_event = current_state.get(key)
key = (EventTypes.Create, "", ) key = (EventTypes.Create, "", )
create_event = context.current_state.get(key) create_event = current_state.get(key)
if create_event: if create_event:
auth_ids.append(create_event.event_id) auth_ids.append(create_event.event_id)
@ -381,8 +447,8 @@ class Auth(object):
else: else:
is_public = False is_public = False
if builder.type == EventTypes.Member: if event.type == EventTypes.Member:
e_type = builder.content["membership"] e_type = event.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]: if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event: if join_rule_event:
auth_ids.append(join_rule_event.event_id) auth_ids.append(join_rule_event.event_id)
@ -397,17 +463,7 @@ class Auth(object):
if member_event.content["membership"] == Membership.JOIN: if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id) auth_ids.append(member_event.event_id)
auth_events_entries = yield self.store.add_event_hashes( return auth_ids
auth_ids
)
builder.auth_events = auth_events_entries
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
@log_function @log_function
def _can_send_event(self, event, auth_events): def _can_send_event(self, event, auth_events):
@ -461,7 +517,7 @@ class Auth(object):
"You are not allowed to set others state" "You are not allowed to set others state"
) )
else: else:
sender_domain = self.hs.parse_userid( sender_domain = UserID.from_string(
event.user_id event.user_id
).domain ).domain
@ -496,7 +552,7 @@ class Auth(object):
# Validate users # Validate users
for k, v in user_list.items(): for k, v in user_list.items():
try: try:
self.hs.parse_userid(k) UserID.from_string(k)
except: except:
raise SynapseError(400, "Not a valid user_id: %s" % (k,)) raise SynapseError(400, "Not a valid user_id: %s" % (k,))

View File

@ -59,6 +59,7 @@ class LoginType(object):
EMAIL_URL = u"m.login.email.url" EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity" EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha" RECAPTCHA = u"m.login.recaptcha"
APPLICATION_SERVICE = u"m.login.application_service"
class EventTypes(object): class EventTypes(object):
@ -74,3 +75,9 @@ class EventTypes(object):
Message = "m.room.message" Message = "m.room.message"
Topic = "m.room.topic" Topic = "m.room.topic"
Name = "m.room.name" Name = "m.room.name"
class RejectedReason(object):
AUTH_ERROR = "auth_error"
REPLACED = "replaced"
NOT_ANCESTOR = "not_ancestor"

View File

@ -21,6 +21,7 @@ logger = logging.getLogger(__name__)
class Codes(object): class Codes(object):
UNRECOGNIZED = "M_UNRECOGNIZED"
UNAUTHORIZED = "M_UNAUTHORIZED" UNAUTHORIZED = "M_UNAUTHORIZED"
FORBIDDEN = "M_FORBIDDEN" FORBIDDEN = "M_FORBIDDEN"
BAD_JSON = "M_BAD_JSON" BAD_JSON = "M_BAD_JSON"
@ -34,10 +35,12 @@ class Codes(object):
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID" CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
TOO_LARGE = "M_TOO_LARGE" MISSING_PARAM = "M_MISSING_PARAM",
TOO_LARGE = "M_TOO_LARGE",
EXCLUSIVE = "M_EXCLUSIVE"
class CodeMessageException(Exception): class CodeMessageException(RuntimeError):
"""An exception with integer code and message string attributes.""" """An exception with integer code and message string attributes."""
def __init__(self, code, msg): def __init__(self, code, msg):
@ -81,6 +84,35 @@ class RegistrationError(SynapseError):
pass pass
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.UNRECOGNIZED
message = None
if len(args) == 0:
message = "Unrecognized request"
else:
message = args[0]
super(UnrecognizedRequestError, self).__init__(
400,
message,
**kwargs
)
class NotFoundError(SynapseError):
"""An error indicating we can't find the thing you asked for"""
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.NOT_FOUND
super(NotFoundError, self).__init__(
404,
"Not found",
**kwargs
)
class AuthError(SynapseError): class AuthError(SynapseError):
"""An error raised when there was a problem authorising an event.""" """An error raised when there was a problem authorising an event."""
@ -196,3 +228,9 @@ class FederationError(RuntimeError):
"affected": self.affected, "affected": self.affected,
"source": self.source if self.source else self.affected, "source": self.source if self.source else self.affected,
} }
class HttpResponseException(CodeMessageException):
def __init__(self, code, msg, response):
self.response = response
super(HttpResponseException, self).__init__(code, msg)

229
synapse/api/filtering.py Normal file
View File

@ -0,0 +1,229 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.errors import SynapseError
from synapse.types import UserID, RoomID
class Filtering(object):
def __init__(self, hs):
super(Filtering, self).__init__()
self.store = hs.get_datastore()
def get_user_filter(self, user_localpart, filter_id):
result = self.store.get_user_filter(user_localpart, filter_id)
result.addCallback(Filter)
return result
def add_user_filter(self, user_localpart, user_filter):
self._check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter)
# TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for
# them however
def _check_valid_filter(self, user_filter_json):
"""Check if the provided filter is valid.
This inspects all definitions contained within the filter.
Args:
user_filter_json(dict): The filter
Raises:
SynapseError: If the filter is not valid.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
top_level_definitions = [
"public_user_data", "private_user_data", "server_data"
]
room_level_definitions = [
"state", "events", "ephemeral"
]
for key in top_level_definitions:
if key in user_filter_json:
self._check_definition(user_filter_json[key])
if "room" in user_filter_json:
for key in room_level_definitions:
if key in user_filter_json["room"]:
self._check_definition(user_filter_json["room"][key])
def _check_definition(self, definition):
"""Check if the provided definition is valid.
This inspects not only the types but also the values to make sure they
make sense.
Args:
definition(dict): The filter definition
Raises:
SynapseError: If there was a problem with this definition.
"""
# NB: Filters are the complete json blobs. "Definitions" are an
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
if type(definition) != dict:
raise SynapseError(
400, "Expected JSON object, not %s" % (definition,)
)
# check rooms are valid room IDs
room_id_keys = ["rooms", "not_rooms"]
for key in room_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for room_id in definition[key]:
RoomID.from_string(room_id)
# check senders are valid user IDs
user_id_keys = ["senders", "not_senders"]
for key in user_id_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for user_id in definition[key]:
UserID.from_string(user_id)
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
event_keys = ["types", "not_types"]
for key in event_keys:
if key in definition:
if type(definition[key]) != list:
raise SynapseError(400, "Expected %s to be a list." % key)
for event_type in definition[key]:
if not isinstance(event_type, basestring):
raise SynapseError(400, "Event type should be a string")
if "format" in definition:
event_format = definition["format"]
if event_format not in ["federation", "events"]:
raise SynapseError(400, "Invalid format: %s" % (event_format,))
if "select" in definition:
event_select_list = definition["select"]
for select_key in event_select_list:
if select_key not in ["event_id", "origin_server_ts",
"thread_id", "content", "content.body"]:
raise SynapseError(400, "Bad select: %s" % (select_key,))
if ("bundle_updates" in definition and
type(definition["bundle_updates"]) != bool):
raise SynapseError(400, "Bad bundle_updates: expected bool.")
class Filter(object):
def __init__(self, filter_json):
self.filter_json = filter_json
def filter_public_user_data(self, events):
return self._filter_on_key(events, ["public_user_data"])
def filter_private_user_data(self, events):
return self._filter_on_key(events, ["private_user_data"])
def filter_room_state(self, events):
return self._filter_on_key(events, ["room", "state"])
def filter_room_events(self, events):
return self._filter_on_key(events, ["room", "events"])
def filter_room_ephemeral(self, events):
return self._filter_on_key(events, ["room", "ephemeral"])
def _filter_on_key(self, events, keys):
filter_json = self.filter_json
if not filter_json:
return events
try:
# extract the right definition from the filter
definition = filter_json
for key in keys:
definition = definition[key]
return self._filter_with_definition(events, definition)
except KeyError:
# return all events if definition isn't specified.
return events
def _filter_with_definition(self, events, definition):
return [e for e in events if self._passes_definition(definition, e)]
def _passes_definition(self, definition, event):
"""Check if the event passes through the given definition.
Args:
definition(dict): The definition to check against.
event(Event): The event to check.
Returns:
True if the event passes through the filter.
"""
# Algorithm notes:
# For each key in the definition, check the event meets the criteria:
# * For types: Literal match or prefix match (if ends with wildcard)
# * For senders/rooms: Literal match only
# * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
# and 'not_types' then it is treated as only being in 'not_types')
# room checks
if hasattr(event, "room_id"):
room_id = event.room_id
allow_rooms = definition.get("rooms", None)
reject_rooms = definition.get("not_rooms", None)
if reject_rooms and room_id in reject_rooms:
return False
if allow_rooms and room_id not in allow_rooms:
return False
# sender checks
if hasattr(event, "sender"):
# Should we be including event.state_key for some event types?
sender = event.sender
allow_senders = definition.get("senders", None)
reject_senders = definition.get("not_senders", None)
if reject_senders and sender in reject_senders:
return False
if allow_senders and sender not in allow_senders:
return False
# type checks
if "not_types" in definition:
for def_type in definition["not_types"]:
if self._event_matches_type(event, def_type):
return False
if "types" in definition:
included = False
for def_type in definition["types"]:
if self._event_matches_type(event, def_type):
included = True
break
if not included:
return False
return True
def _event_matches_type(self, event, def_type):
if def_type.endswith("*"):
type_prefix = def_type[:-1]
return event.type.startswith(type_prefix)
else:
return event.type == def_type

View File

@ -16,8 +16,11 @@
"""Contains the URL paths to prefix various aspects of the server with. """ """Contains the URL paths to prefix various aspects of the server with. """
CLIENT_PREFIX = "/_matrix/client/api/v1" CLIENT_PREFIX = "/_matrix/client/api/v1"
CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
FEDERATION_PREFIX = "/_matrix/federation/v1" FEDERATION_PREFIX = "/_matrix/federation/v1"
STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client" WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content" CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1" SERVER_KEY_PREFIX = "/_matrix/key/v1"
MEDIA_PREFIX = "/_matrix/media/v1" MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View File

@ -14,7 +14,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.storage import prepare_database, UpgradeDatabaseException import sys
sys.dont_write_bytecode = True
from synapse.storage import (
prepare_database, prepare_sqlite3_database, UpgradeDatabaseException,
)
from synapse.server import HomeServer from synapse.server import HomeServer
@ -27,17 +32,21 @@ from twisted.web.resource import Resource
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site from twisted.web.server import Site
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import JsonResource, RootRedirect
from synapse.media.v0.content_repository import ContentRepoResource from synapse.rest.appservice.v1 import AppServiceRestResource
from synapse.media.v1.media_repository import MediaRepositoryResource from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.http.server_key_resource import LocalKey from synapse.http.server_key_resource import LocalKey
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import ( from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX,
STATIC_PREFIX
) )
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.rest.client.v1 import ClientV1RestResource
from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
from daemonize import Daemonize from daemonize import Daemonize
import twisted.manhole.telnet import twisted.manhole.telnet
@ -47,7 +56,8 @@ import synapse
import logging import logging
import os import os
import re import re
import sys import resource
import subprocess
import sqlite3 import sqlite3
import syweb import syweb
@ -60,16 +70,25 @@ class SynapseHomeServer(HomeServer):
return MatrixFederationHttpClient(self) return MatrixFederationHttpClient(self)
def build_resource_for_client(self): def build_resource_for_client(self):
return JsonResource() return ClientV1RestResource(self)
def build_resource_for_client_v2_alpha(self):
return ClientV2AlphaRestResource(self)
def build_resource_for_federation(self): def build_resource_for_federation(self):
return JsonResource() return JsonResource(self)
def build_resource_for_app_services(self):
return AppServiceRestResource(self)
def build_resource_for_web_client(self): def build_resource_for_web_client(self):
syweb_path = os.path.dirname(syweb.__file__) syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient") webclient_path = os.path.join(syweb_path, "webclient")
return File(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable?
def build_resource_for_static_content(self):
return File("static")
def build_resource_for_content_repo(self): def build_resource_for_content_repo(self):
return ContentRepoResource( return ContentRepoResource(
self, self.upload_dir, self.auth, self.content_addr self, self.upload_dir, self.auth, self.content_addr
@ -86,7 +105,9 @@ class SynapseHomeServer(HomeServer):
"sqlite3", self.get_db_name(), "sqlite3", self.get_db_name(),
check_same_thread=False, check_same_thread=False,
cp_min=1, cp_min=1,
cp_max=1 cp_max=1,
cp_openfun=prepare_database, # Prepare the database for each conn
# so that :memory: sqlite works
) )
def create_resource_tree(self, web_client, redirect_root_to_web_client): def create_resource_tree(self, web_client, redirect_root_to_web_client):
@ -105,11 +126,15 @@ class SynapseHomeServer(HomeServer):
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ] # [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
desired_tree = [ desired_tree = [
(CLIENT_PREFIX, self.get_resource_for_client()), (CLIENT_PREFIX, self.get_resource_for_client()),
(CLIENT_V2_ALPHA_PREFIX, self.get_resource_for_client_v2_alpha()),
(FEDERATION_PREFIX, self.get_resource_for_federation()), (FEDERATION_PREFIX, self.get_resource_for_federation()),
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()), (CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()), (SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
(MEDIA_PREFIX, self.get_resource_for_media_repository()), (MEDIA_PREFIX, self.get_resource_for_media_repository()),
(APP_SERVICE_PREFIX, self.get_resource_for_app_services()),
(STATIC_PREFIX, self.get_resource_for_static_content()),
] ]
if web_client: if web_client:
logger.info("Adding the web client.") logger.info("Adding the web client.")
desired_tree.append((WEB_CLIENT_PREFIX, desired_tree.append((WEB_CLIENT_PREFIX,
@ -125,11 +150,11 @@ class SynapseHomeServer(HomeServer):
# instead, we'll store a copy of this mapping so we can actually add # instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key. # extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {} resource_mappings = {}
for (full_path, resource) in desired_tree: for full_path, res in desired_tree:
logger.info("Attaching %s to path %s", resource, full_path) logger.info("Attaching %s to path %s", res, full_path)
last_resource = self.root_resource last_resource = self.root_resource
for path_seg in full_path.split('/')[1:-1]: for path_seg in full_path.split('/')[1:-1]:
if not path_seg in last_resource.listNames(): if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource" # resource doesn't exist, so make a "dummy resource"
child_resource = Resource() child_resource = Resource()
last_resource.putChild(path_seg, child_resource) last_resource.putChild(path_seg, child_resource)
@ -157,12 +182,12 @@ class SynapseHomeServer(HomeServer):
child_name) child_name)
child_resource = resource_mappings[child_res_id] child_resource = resource_mappings[child_res_id]
# steal the children # steal the children
resource.putChild(child_name, child_resource) res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place # finally, insert the desired resource in the right place
last_resource.putChild(last_path_seg, resource) last_resource.putChild(last_path_seg, res)
res_id = self._resource_id(last_resource, last_path_seg) res_id = self._resource_id(last_resource, last_path_seg)
resource_mappings[res_id] = resource resource_mappings[res_id] = res
return self.root_resource return self.root_resource
@ -194,6 +219,83 @@ class SynapseHomeServer(HomeServer):
logger.info("Synapse now listening on port %d", unsecure_port) logger.info("Synapse now listening on port %d", unsecure_port)
def get_version_string():
try:
null = open(os.devnull, 'w')
cwd = os.path.dirname(os.path.abspath(__file__))
try:
git_branch = subprocess.check_output(
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
stderr=null,
cwd=cwd,
).strip()
git_branch = "b=" + git_branch
except subprocess.CalledProcessError:
git_branch = ""
try:
git_tag = subprocess.check_output(
['git', 'describe', '--exact-match'],
stderr=null,
cwd=cwd,
).strip()
git_tag = "t=" + git_tag
except subprocess.CalledProcessError:
git_tag = ""
try:
git_commit = subprocess.check_output(
['git', 'rev-parse', '--short', 'HEAD'],
stderr=null,
cwd=cwd,
).strip()
except subprocess.CalledProcessError:
git_commit = ""
try:
dirty_string = "-this_is_a_dirty_checkout"
is_dirty = subprocess.check_output(
['git', 'describe', '--dirty=' + dirty_string],
stderr=null,
cwd=cwd,
).strip().endswith(dirty_string)
git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError:
git_dirty = ""
if git_branch or git_tag or git_commit or git_dirty:
git_version = ",".join(
s for s in
(git_branch, git_tag, git_commit, git_dirty,)
if s
)
return (
"Synapse/%s (%s)" % (
synapse.__version__, git_version,
)
).encode("ascii")
except Exception as e:
logger.warn("Failed to check for git repository: %s", e)
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
def change_resource_limit(soft_file_no):
try:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
if not soft_file_no:
soft_file_no = hard
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
logger.info("Set file limit to: %d", soft_file_no)
except (ValueError, resource.error) as e:
logger.warn("Failed to set file limit: %s", e)
def setup(config_options, should_run=True): def setup(config_options, should_run=True):
config = HomeServerConfig.load_config( config = HomeServerConfig.load_config(
"Synapse Homeserver", "Synapse Homeserver",
@ -205,8 +307,10 @@ def setup(config_options, should_run=True):
check_requirements() check_requirements()
version_string = get_version_string()
logger.info("Server hostname: %s", config.server_name) logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", synapse.__version__) logger.info("Server version: %s", version_string)
if re.search(":[0-9]+$", config.server_name): if re.search(":[0-9]+$", config.server_name):
domain_with_port = config.server_name domain_with_port = config.server_name
@ -223,10 +327,9 @@ def setup(config_options, should_run=True):
tls_context_factory=tls_context_factory, tls_context_factory=tls_context_factory,
config=config, config=config,
content_addr=config.content_addr, content_addr=config.content_addr,
version_string=version_string,
) )
hs.register_servlets()
hs.create_resource_tree( hs.create_resource_tree(
web_client=config.webclient, web_client=config.webclient,
redirect_root_to_web_client=True, redirect_root_to_web_client=True,
@ -238,6 +341,7 @@ def setup(config_options, should_run=True):
try: try:
with sqlite3.connect(db_name) as db_conn: with sqlite3.connect(db_name) as db_conn:
prepare_sqlite3_database(db_conn)
prepare_database(db_conn) prepare_database(db_conn)
except UpgradeDatabaseException: except UpgradeDatabaseException:
sys.stderr.write( sys.stderr.write(
@ -249,14 +353,6 @@ def setup(config_options, should_run=True):
logger.info("Database prepared in %s.", db_name) logger.info("Database prepared in %s.", db_name)
db_pool = hs.get_db_pool()
if db_name == ":memory:":
# Memory databases will need to be setup each time they are opened.
reactor.callWhenRunning(
db_pool.runWithConnection, prepare_database
)
if config.manhole: if config.manhole:
f = twisted.manhole.telnet.ShellFactory() f = twisted.manhole.telnet.ShellFactory()
f.username = "matrix" f.username = "matrix"
@ -267,17 +363,24 @@ def setup(config_options, should_run=True):
bind_port = config.bind_port bind_port = config.bind_port
if config.no_tls: if config.no_tls:
bind_port = None bind_port = None
hs.start_listening(bind_port, config.unsecure_port) hs.start_listening(bind_port, config.unsecure_port)
hs.get_pusherpool().start()
hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
hs.get_replication_layer().start_get_pdu_cache()
if not should_run: if not should_run:
return return
if config.daemonize: if config.daemonize:
print config.pid_file print config.pid_file
daemon = Daemonize( daemon = Daemonize(
app="synapse-homeserver", app="synapse-homeserver",
pid=config.pid_file, pid=config.pid_file,
action=run, action=lambda: run(config),
auto_close_fds=False, auto_close_fds=False,
verbose=True, verbose=True,
logger=logger, logger=logger,
@ -285,7 +388,7 @@ def setup(config_options, should_run=True):
daemon.start() daemon.start()
else: else:
reactor.run() run(config)
class SynapseService(service.Service): class SynapseService(service.Service):
@ -299,8 +402,10 @@ class SynapseService(service.Service):
return self._port.stopListening() return self._port.stopListening()
def run(): def run(config):
with LoggingContext("run"): with LoggingContext("run"):
change_resource_limit(config.soft_file_limit)
reactor.run() reactor.run()

View File

@ -19,7 +19,7 @@ import os
import subprocess import subprocess
import signal import signal
SYNAPSE = ["python", "-m", "synapse.app.homeserver"] SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
CONFIGFILE = "homeserver.yaml" CONFIGFILE = "homeserver.yaml"
PIDFILE = "homeserver.pid" PIDFILE = "homeserver.pid"

View File

@ -0,0 +1,176 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.api.constants import EventTypes
import logging
import re
logger = logging.getLogger(__name__)
class ApplicationService(object):
"""Defines an application service. This definition is mostly what is
provided to the /register AS API.
Provides methods to check if this service is "interested" in events.
"""
NS_USERS = "users"
NS_ALIASES = "aliases"
NS_ROOMS = "rooms"
# The ordering here is important as it is used to map database values (which
# are stored as ints representing the position in this list) to namespace
# values.
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None,
sender=None, txn_id=None):
self.token = token
self.url = url
self.hs_token = hs_token
self.sender = sender
self.namespaces = self._check_namespaces(namespaces)
self.txn_id = txn_id
def _check_namespaces(self, namespaces):
# Sanity check that it is of the form:
# {
# users: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# aliases: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# }
if not namespaces:
return None
for ns in ApplicationService.NS_LIST:
if ns not in namespaces:
namespaces[ns] = []
continue
if type(namespaces[ns]) != list:
raise ValueError("Bad namespace value for '%s'" % ns)
for regex_obj in namespaces[ns]:
if not isinstance(regex_obj, dict):
raise ValueError("Expected dict regex for ns '%s'" % ns)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Expected bool for 'exclusive' in ns '%s'" % ns
)
if not isinstance(regex_obj.get("regex"), basestring):
raise ValueError(
"Expected string for 'regex' in ns '%s'" % ns
)
return namespaces
def _matches_regex(self, test_string, namespace_key, return_obj=False):
if not isinstance(test_string, basestring):
logger.error(
"Expected a string to test regex against, but got %s",
test_string
)
return False
for regex_obj in self.namespaces[namespace_key]:
if re.match(regex_obj["regex"], test_string):
if return_obj:
return regex_obj
return True
return False
def _is_exclusive(self, ns_key, test_string):
regex_obj = self._matches_regex(test_string, ns_key, return_obj=True)
if regex_obj:
return regex_obj["exclusive"]
return False
def _matches_user(self, event, member_list):
if (hasattr(event, "sender") and
self.is_interested_in_user(event.sender)):
return True
# also check m.room.member state key
if (hasattr(event, "type") and event.type == EventTypes.Member
and hasattr(event, "state_key")
and self.is_interested_in_user(event.state_key)):
return True
# check joined member events
for member in member_list:
if self.is_interested_in_user(member.state_key):
return True
return False
def _matches_room_id(self, event):
if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id)
return False
def _matches_aliases(self, event, alias_list):
for alias in alias_list:
if self.is_interested_in_alias(alias):
return True
return False
def is_interested(self, event, restrict_to=None, aliases_for_event=None,
member_list=None):
"""Check if this service is interested in this event.
Args:
event(Event): The event to check.
restrict_to(str): The namespace to restrict regex tests to.
aliases_for_event(list): A list of all the known room aliases for
this event.
member_list(list): A list of all joined room members in this room.
Returns:
bool: True if this service would like to know about this event.
"""
if aliases_for_event is None:
aliases_for_event = []
if member_list is None:
member_list = []
if restrict_to and restrict_to not in ApplicationService.NS_LIST:
# this is a programming error, so fail early and raise a general
# exception
raise Exception("Unexpected restrict_to value: %s". restrict_to)
if not restrict_to:
return (self._matches_user(event, member_list)
or self._matches_aliases(event, aliases_for_event)
or self._matches_room_id(event))
elif restrict_to == ApplicationService.NS_ALIASES:
return self._matches_aliases(event, aliases_for_event)
elif restrict_to == ApplicationService.NS_ROOMS:
return self._matches_room_id(event)
elif restrict_to == ApplicationService.NS_USERS:
return self._matches_user(event, member_list)
def is_interested_in_user(self, user_id):
return self._matches_regex(user_id, ApplicationService.NS_USERS)
def is_interested_in_alias(self, alias):
return self._matches_regex(alias, ApplicationService.NS_ALIASES)
def is_interested_in_room(self, room_id):
return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
def is_exclusive_user(self, user_id):
return self._is_exclusive(ApplicationService.NS_USERS, user_id)
def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
def __str__(self):
return "ApplicationService: %s" % (self.__dict__,)

108
synapse/appservice/api.py Normal file
View File

@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
import logging
import urllib
logger = logging.getLogger(__name__)
class ApplicationServiceApi(SimpleHttpClient):
"""This class manages HS -> AS communications, including querying and
pushing.
"""
def __init__(self, hs):
super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock()
@defer.inlineCallbacks
def query_user(self, service, user_id):
uri = service.url + ("/users/%s" % urllib.quote(user_id))
response = None
try:
response = yield self.get_json(uri, {
"access_token": service.hs_token
})
if response is not None: # just an empty json object
defer.returnValue(True)
except CodeMessageException as e:
if e.code == 404:
defer.returnValue(False)
return
logger.warning("query_user to %s received %s", uri, e.code)
except Exception as ex:
logger.warning("query_user to %s threw exception %s", uri, ex)
defer.returnValue(False)
@defer.inlineCallbacks
def query_alias(self, service, alias):
uri = service.url + ("/rooms/%s" % urllib.quote(alias))
response = None
try:
response = yield self.get_json(uri, {
"access_token": service.hs_token
})
if response is not None: # just an empty json object
defer.returnValue(True)
except CodeMessageException as e:
logger.warning("query_alias to %s received %s", uri, e.code)
if e.code == 404:
defer.returnValue(False)
return
except Exception as ex:
logger.warning("query_alias to %s threw exception %s", uri, ex)
defer.returnValue(False)
@defer.inlineCallbacks
def push_bulk(self, service, events):
events = self._serialize(events)
uri = service.url + ("/transactions/%s" %
urllib.quote(str(0))) # TODO txn_ids
response = None
try:
response = yield self.put_json(
uri=uri,
json_body={
"events": events
},
args={
"access_token": service.hs_token
})
if response: # just an empty json object
# TODO: Mark txn as sent successfully
defer.returnValue(True)
except CodeMessageException as e:
logger.warning("push_bulk to %s received %s", uri, e.code)
except Exception as ex:
logger.warning("push_bulk to %s threw exception %s", uri, ex)
defer.returnValue(False)
@defer.inlineCallbacks
def push(self, service, event):
response = yield self.push_bulk(service, [event])
defer.returnValue(response)
def _serialize(self, events):
time_now = self.clock.time_msec()
return [
serialize_event(e, time_now, as_client_event=True) for e in events
]

View File

@ -27,6 +27,16 @@ class Config(object):
def __init__(self, args): def __init__(self, args):
pass pass
@staticmethod
def parse_size(string):
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
suffix = string[-1]
if suffix in sizes:
string = string[:-1]
size = sizes[suffix]
return int(string) * size
@staticmethod @staticmethod
def abspath(file_path): def abspath(file_path):
return os.path.abspath(file_path) if file_path else file_path return os.path.abspath(file_path) if file_path else file_path
@ -50,8 +60,9 @@ class Config(object):
) )
return cls.abspath(file_path) return cls.abspath(file_path)
@staticmethod @classmethod
def ensure_directory(dir_path): def ensure_directory(cls, dir_path):
dir_path = cls.abspath(dir_path)
if not os.path.exists(dir_path): if not os.path.exists(dir_path):
os.makedirs(dir_path) os.makedirs(dir_path)
if not os.path.isdir(dir_path): if not os.path.isdir(dir_path):

View File

@ -24,6 +24,7 @@ class DatabaseConfig(Config):
self.database_path = ":memory:" self.database_path = ":memory:"
else: else:
self.database_path = self.abspath(args.database_path) self.database_path = self.abspath(args.database_path)
self.event_cache_size = self.parse_size(args.event_cache_size)
@classmethod @classmethod
def add_arguments(cls, parser): def add_arguments(cls, parser):
@ -33,6 +34,10 @@ class DatabaseConfig(Config):
"-d", "--database-path", default="homeserver.db", "-d", "--database-path", default="homeserver.db",
help="The database name." help="The database name."
) )
db_group.add_argument(
"--event-cache-size", default="100K",
help="Number of events to cache in memory."
)
@classmethod @classmethod
def generate_config(cls, args, config_dir_path): def generate_config(cls, args, config_dir_path):

View File

@ -22,11 +22,12 @@ from .repository import ContentRepositoryConfig
from .captcha import CaptchaConfig from .captcha import CaptchaConfig
from .email import EmailConfig from .email import EmailConfig
from .voip import VoipConfig from .voip import VoipConfig
from .registration import RegistrationConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig, RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
EmailConfig, VoipConfig): EmailConfig, VoipConfig, RegistrationConfig,):
pass pass

View File

@ -18,6 +18,7 @@ from synapse.util.logcontext import LoggingContextFilter
from twisted.python.log import PythonLoggingObserver from twisted.python.log import PythonLoggingObserver
import logging import logging
import logging.config import logging.config
import yaml
class LoggingConfig(Config): class LoggingConfig(Config):
@ -79,7 +80,8 @@ class LoggingConfig(Config):
logger.addHandler(handler) logger.addHandler(handler)
logger.info("Test") logger.info("Test")
else: else:
logging.config.fileConfig(self.log_config) with open(self.log_config, 'r') as f:
logging.config.dictConfig(yaml.load(f))
observer = PythonLoggingObserver() observer = PythonLoggingObserver()
observer.start() observer.start()

View File

@ -22,6 +22,12 @@ class RatelimitConfig(Config):
self.rc_messages_per_second = args.rc_messages_per_second self.rc_messages_per_second = args.rc_messages_per_second
self.rc_message_burst_count = args.rc_message_burst_count self.rc_message_burst_count = args.rc_message_burst_count
self.federation_rc_window_size = args.federation_rc_window_size
self.federation_rc_sleep_limit = args.federation_rc_sleep_limit
self.federation_rc_sleep_delay = args.federation_rc_sleep_delay
self.federation_rc_reject_limit = args.federation_rc_reject_limit
self.federation_rc_concurrent = args.federation_rc_concurrent
@classmethod @classmethod
def add_arguments(cls, parser): def add_arguments(cls, parser):
super(RatelimitConfig, cls).add_arguments(parser) super(RatelimitConfig, cls).add_arguments(parser)
@ -34,3 +40,33 @@ class RatelimitConfig(Config):
"--rc-message-burst-count", type=float, default=10, "--rc-message-burst-count", type=float, default=10,
help="number of message a client can send before being throttled" help="number of message a client can send before being throttled"
) )
rc_group.add_argument(
"--federation-rc-window-size", type=int, default=10000,
help="The federation window size in milliseconds",
)
rc_group.add_argument(
"--federation-rc-sleep-limit", type=int, default=10,
help="The number of federation requests from a single server"
" in a window before the server will delay processing the"
" request.",
)
rc_group.add_argument(
"--federation-rc-sleep-delay", type=int, default=500,
help="The duration in milliseconds to delay processing events from"
" remote servers by if they go over the sleep limit.",
)
rc_group.add_argument(
"--federation-rc-reject-limit", type=int, default=50,
help="The maximum number of concurrent federation requests allowed"
" from a single server",
)
rc_group.add_argument(
"--federation-rc-concurrent", type=int, default=3,
help="The number of federation requests to concurrently process"
" from a single server",
)

View File

@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class RegistrationConfig(Config):
def __init__(self, args):
super(RegistrationConfig, self).__init__(args)
self.disable_registration = args.disable_registration
@classmethod
def add_arguments(cls, parser):
super(RegistrationConfig, cls).add_arguments(parser)
reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
"--disable-registration",
action='store_true',
help="Disable registration of new users."
)

View File

@ -30,7 +30,7 @@ class ServerConfig(Config):
self.pid_file = self.abspath(args.pid_file) self.pid_file = self.abspath(args.pid_file)
self.webclient = True self.webclient = True
self.manhole = args.manhole self.manhole = args.manhole
self.no_tls = args.no_tls self.soft_file_limit = args.soft_file_limit
if not args.content_addr: if not args.content_addr:
host = args.server_name host = args.server_name
@ -75,8 +75,12 @@ class ServerConfig(Config):
server_group.add_argument("--content-addr", default=None, server_group.add_argument("--content-addr", default=None,
help="The host and scheme to use for the " help="The host and scheme to use for the "
"content repository") "content repository")
server_group.add_argument("--no-tls", action='store_true', server_group.add_argument("--soft-file-limit", type=int, default=0,
help="Don't bind to the https port.") help="Set the soft limit on the number of "
"file descriptors synapse can use. "
"Zero is used to indicate synapse "
"should set the soft limit to the hard"
"limit.")
def read_signing_key(self, signing_key_path): def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key") signing_keys = self.read_file(signing_key_path, "signing_key")

View File

@ -28,9 +28,16 @@ class TlsConfig(Config):
self.tls_certificate = self.read_tls_certificate( self.tls_certificate = self.read_tls_certificate(
args.tls_certificate_path args.tls_certificate_path
) )
self.tls_private_key = self.read_tls_private_key(
args.tls_private_key_path self.no_tls = args.no_tls
)
if self.no_tls:
self.tls_private_key = None
else:
self.tls_private_key = self.read_tls_private_key(
args.tls_private_key_path
)
self.tls_dh_params_path = self.check_file( self.tls_dh_params_path = self.check_file(
args.tls_dh_params_path, "tls_dh_params" args.tls_dh_params_path, "tls_dh_params"
) )
@ -45,6 +52,8 @@ class TlsConfig(Config):
help="PEM encoded private key for TLS") help="PEM encoded private key for TLS")
tls_group.add_argument("--tls-dh-params-path", tls_group.add_argument("--tls-dh-params-path",
help="PEM dh parameters for ephemeral keys") help="PEM dh parameters for ephemeral keys")
tls_group.add_argument("--no-tls", action='store_true',
help="Don't bind to the https port.")
def read_tls_certificate(self, cert_path): def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate") cert_pem = self.read_file(cert_path, "tls_certificate")

View File

@ -28,7 +28,7 @@ class VoipConfig(Config):
super(VoipConfig, cls).add_arguments(parser) super(VoipConfig, cls).add_arguments(parser)
group = parser.add_argument_group("voip") group = parser.add_argument_group("voip")
group.add_argument( group.add_argument(
"--turn-uris", type=str, default=None, "--turn-uris", type=str, default=None, action='append',
help="The public URIs of the TURN server to give to clients" help="The public URIs of the TURN server to give to clients"
) )
group.add_argument( group.add_argument(

View File

@ -38,7 +38,10 @@ class ServerContextFactory(ssl.ContextFactory):
logger.exception("Failed to enable eliptic curve for TLS") logger.exception("Failed to enable eliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3) context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
context.use_certificate(config.tls_certificate) context.use_certificate(config.tls_certificate)
context.use_privatekey(config.tls_private_key)
if not config.no_tls:
context.use_privatekey(config.tls_private_key)
context.load_tmp_dh(config.tls_dh_params_path) context.load_tmp_dh(config.tls_dh_params_path)
context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH") context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")

View File

@ -19,7 +19,7 @@ from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
import json import simplejson as json
import logging import logging
@ -61,9 +61,11 @@ class SynapseKeyClientProtocol(HTTPClient):
def __init__(self): def __init__(self):
self.remote_key = defer.Deferred() self.remote_key = defer.Deferred()
self.host = None
def connectionMade(self): def connectionMade(self):
logger.debug("Connected to %s", self.transport.getHost()) self.host = self.transport.getHost()
logger.debug("Connected to %s", self.host)
self.sendCommand(b"GET", b"/_matrix/key/v1/") self.sendCommand(b"GET", b"/_matrix/key/v1/")
self.endHeaders() self.endHeaders()
self.timer = reactor.callLater( self.timer = reactor.callLater(
@ -73,7 +75,7 @@ class SynapseKeyClientProtocol(HTTPClient):
def handleStatus(self, version, status, message): def handleStatus(self, version, status, message):
if status != b"200": if status != b"200":
#logger.info("Non-200 response from %s: %s %s", # logger.info("Non-200 response from %s: %s %s",
# self.transport.getHost(), status, message) # self.transport.getHost(), status, message)
self.transport.abortConnection() self.transport.abortConnection()
@ -81,7 +83,7 @@ class SynapseKeyClientProtocol(HTTPClient):
try: try:
json_response = json.loads(response_body_bytes) json_response = json.loads(response_body_bytes)
except ValueError: except ValueError:
#logger.info("Invalid JSON response from %s", # logger.info("Invalid JSON response from %s",
# self.transport.getHost()) # self.transport.getHost())
self.transport.abortConnection() self.transport.abortConnection()
return return
@ -92,8 +94,7 @@ class SynapseKeyClientProtocol(HTTPClient):
self.timer.cancel() self.timer.cancel()
def on_timeout(self): def on_timeout(self):
logger.debug("Timeout waiting for response from %s", logger.debug("Timeout waiting for response from %s", self.host)
self.transport.getHost())
self.remote_key.errback(IOError("Timeout waiting for response")) self.remote_key.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection() self.transport.abortConnection()

View File

@ -22,6 +22,8 @@ from syutil.crypto.signing_key import (
from syutil.base64util import decode_base64, encode_base64 from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from OpenSSL import crypto from OpenSSL import crypto
import logging import logging
@ -48,18 +50,27 @@ class Keyring(object):
) )
try: try:
verify_key = yield self.get_server_verify_key(server_name, key_ids) verify_key = yield self.get_server_verify_key(server_name, key_ids)
except IOError: except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError( raise SynapseError(
502, 502,
"Error downloading keys for %s" % (server_name,), "Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
except: except Exception as e:
logger.warn(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError( raise SynapseError(
401, 401,
"No key for %s with id %s" % (server_name, key_ids), "No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
try: try:
verify_signed_json(json_object, server_name, verify_key) verify_signed_json(json_object, server_name, verify_key)
except: except:
@ -87,12 +98,18 @@ class Keyring(object):
return return
# Try to fetch the key from the remote server. # Try to fetch the key from the remote server.
# TODO(markjh): Ratelimit requests to a given server.
(response, tls_certificate) = yield fetch_server_key( limiter = yield get_retry_limiter(
server_name, self.hs.tls_context_factory server_name,
self.clock,
self.store,
) )
with limiter:
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory
)
# Check the response. # Check the response.
x509_certificate_bytes = crypto.dump_certificate( x509_certificate_bytes = crypto.dump_certificate(

View File

@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.util.frozenutils import freeze, unfreeze from synapse.util.frozenutils import freeze
class _EventInternalMetadata(object): class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict): def __init__(self, internal_metadata_dict):
self.__dict__ = internal_metadata_dict self.__dict__ = dict(internal_metadata_dict)
def get_dict(self): def get_dict(self):
return dict(self.__dict__) return dict(self.__dict__)
@ -77,7 +77,7 @@ class EventBase(object):
return self.content["membership"] return self.content["membership"]
def is_state(self): def is_state(self):
return hasattr(self, "state_key") return hasattr(self, "state_key") and self.state_key is not None
def get_dict(self): def get_dict(self):
d = dict(self._event_dict) d = dict(self._event_dict)
@ -140,10 +140,6 @@ class FrozenEvent(EventBase):
return e return e
def get_dict(self):
# We need to unfreeze what we return
return unfreeze(super(FrozenEvent, self).get_dict())
def __str__(self): def __str__(self):
return self.__repr__() return self.__repr__()

View File

@ -23,14 +23,15 @@ import copy
class EventBuilder(EventBase): class EventBuilder(EventBase):
def __init__(self, key_values={}): def __init__(self, key_values={}, internal_metadata_dict={}):
signatures = copy.deepcopy(key_values.pop("signatures", {})) signatures = copy.deepcopy(key_values.pop("signatures", {}))
unsigned = copy.deepcopy(key_values.pop("unsigned", {})) unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
super(EventBuilder, self).__init__( super(EventBuilder, self).__init__(
key_values, key_values,
signatures=signatures, signatures=signatures,
unsigned=unsigned unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
) )
def build(self): def build(self):

View File

@ -20,3 +20,4 @@ class EventContext(object):
self.current_state = current_state self.current_state = current_state
self.auth_events = auth_events self.auth_events = auth_events
self.state_group = None self.state_group = None
self.rejected = False

View File

@ -45,12 +45,14 @@ def prune_event(event):
"membership", "membership",
] ]
event_dict = event.get_dict()
new_content = {} new_content = {}
def add_fields(*fields): def add_fields(*fields):
for field in fields: for field in fields:
if field in event.content: if field in event.content:
new_content[field] = event.content[field] new_content[field] = event_dict["content"][field]
if event_type == EventTypes.Member: if event_type == EventTypes.Member:
add_fields("membership") add_fields("membership")
@ -75,7 +77,7 @@ def prune_event(event):
allowed_fields = { allowed_fields = {
k: v k: v
for k, v in event.get_dict().items() for k, v in event_dict.items()
if k in allowed_keys if k in allowed_keys
} }
@ -86,56 +88,78 @@ def prune_event(event):
if "age_ts" in event.unsigned: if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"] allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
return type(event)(allowed_fields) return type(event)(
allowed_fields,
internal_metadata_dict=event.internal_metadata.get_dict()
)
def serialize_event(hs, e, client_event=True): def format_event_raw(d):
return d
def format_event_for_client_v1(d):
d["user_id"] = d.pop("sender", None)
move_keys = ("age", "redacted_because", "replaces_state", "prev_content")
for key in move_keys:
if key in d["unsigned"]:
d[key] = d["unsigned"][key]
drop_keys = (
"auth_events", "prev_events", "hashes", "signatures", "depth",
"unsigned", "origin", "prev_state"
)
for key in drop_keys:
d.pop(key, None)
return d
def format_event_for_client_v2(d):
drop_keys = (
"auth_events", "prev_events", "hashes", "signatures", "depth",
"origin", "prev_state",
)
for key in drop_keys:
d.pop(key, None)
return d
def format_event_for_client_v2_without_event_id(d):
d = format_event_for_client_v2(d)
d.pop("room_id", None)
d.pop("event_id", None)
return d
def serialize_event(e, time_now_ms, as_client_event=True,
event_format=format_event_for_client_v1,
token_id=None):
# FIXME(erikj): To handle the case of presence events and the like # FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase): if not isinstance(e, EventBase):
return e return e
time_now_ms = int(time_now_ms)
# Should this strip out None's? # Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()} d = {k: v for k, v in e.get_dict().items()}
if not client_event:
# set the age and keep all other keys
if "age_ts" in d["unsigned"]:
now = int(hs.get_clock().time_msec())
d["unsigned"]["age"] = now - d["unsigned"]["age_ts"]
return d
if "age_ts" in d["unsigned"]: if "age_ts" in d["unsigned"]:
now = int(hs.get_clock().time_msec()) d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"]
d["age"] = now - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"] del d["unsigned"]["age_ts"]
d["user_id"] = d.pop("sender", None)
if "redacted_because" in e.unsigned: if "redacted_because" in e.unsigned:
d["redacted_because"] = serialize_event( d["unsigned"]["redacted_because"] = serialize_event(
hs, e.unsigned["redacted_because"] e.unsigned["redacted_because"], time_now_ms
) )
del d["unsigned"]["redacted_because"] if token_id is not None:
if token_id == getattr(e.internal_metadata, "token_id", None):
txn_id = getattr(e.internal_metadata, "txn_id", None)
if txn_id is not None:
d["unsigned"]["transaction_id"] = txn_id
if "redacted_by" in e.unsigned: if as_client_event:
d["redacted_by"] = e.unsigned["redacted_by"] return event_format(d)
del d["unsigned"]["redacted_by"] else:
return d
if "replaces_state" in e.unsigned:
d["replaces_state"] = e.unsigned["replaces_state"]
del d["unsigned"]["replaces_state"]
if "prev_content" in e.unsigned:
d["prev_content"] = e.unsigned["prev_content"]
del d["unsigned"]["prev_content"]
del d["auth_events"]
del d["prev_events"]
del d["hashes"]
del d["signatures"]
d.pop("depth", None)
d.pop("unsigned", None)
d.pop("origin", None)
return d

View File

@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
import logging
logger = logging.getLogger(__name__)
class FederationBase(object):
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
that PDU.
If a PDU fails its content hash check then it is redacted.
The given list of PDUs are not modified, instead the function returns
a new list.
Args:
pdu (list)
outlier (bool)
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
signed_pdus = []
@defer.inlineCallbacks
def do(pdu):
try:
new_pdu = yield self._check_sigs_and_hash(pdu)
signed_pdus.append(new_pdu)
except SynapseError:
# FIXME: We should handle signature failures more gracefully.
# Check local db.
new_pdu = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
if new_pdu:
signed_pdus.append(new_pdu)
return
# Check pdu.origin
if pdu.origin != origin:
try:
new_pdu = yield self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
outlier=outlier,
)
if new_pdu:
signed_pdus.append(new_pdu)
return
except:
pass
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
yield defer.gatherResults(
[do(pdu) for pdu in pdus],
consumeErrors=True
)
defer.returnValue(signed_pdus)
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
"""Throws a SynapseError if the PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
)
except SynapseError:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
)
defer.returnValue(redacted_event)
defer.returnValue(pdu)

View File

@ -0,0 +1,563 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .federation_base import FederationBase
from .units import Edu
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
import itertools
import logging
import random
logger = logging.getLogger(__name__)
class FederationClient(FederationBase):
def __init__(self):
self._get_pdu_cache = None
def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
clock=self._clock,
max_len=1000,
expiry_ms=120*1000,
reset_expiry_on_get=False,
)
self._get_pdu_cache.start()
@log_function
def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
TODO: Figure out when we should actually resolve the deferred.
Args:
pdu (Pdu): The new Pdu.
Returns:
Deferred: Completes when we have successfully processed the PDU
and replicated it to any interested remote home servers.
"""
order = self._order
self._order += 1
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, destinations, order)
logger.debug(
"[%s] transaction_layer.enqueue_pdu... done",
pdu.event_id
)
@log_function
def send_edu(self, destination, edu_type, content):
edu = Edu(
origin=self.server_name,
destination=destination,
edu_type=edu_type,
content=content,
)
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu)
return defer.succeed(None)
@log_function
def send_failure(self, failure, destination):
self._transaction_queue.enqueue_failure(failure, destination)
return defer.succeed(None)
@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=True):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
Args:
destination (str): Domain name of the remote homeserver
query_type (str): Category of the query type; should match the
handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details
of the query request.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
)
@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
Args:
dest (str): The remote home server to ask.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus
we have seen from the context
Returns:
Deferred: Results in the received PDUs.
"""
logger.debug("backfill extrem=%s", extremities)
# If there are no extremeties then we've (probably) reached the start.
if not extremities:
return
transaction_data = yield self.transport_layer.backfill(
dest, context, extremities, limit)
logger.debug("backfill transaction_data=%s", repr(transaction_data))
pdus = [
self.event_from_pdu_json(p, outlier=False)
for p in transaction_data["pdus"]
]
for i, pdu in enumerate(pdus):
pdus[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destinations, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
servers.
Will attempt to get the PDU from each destination in the list until
one succeeds.
This will persist the PDU locally upon receipt.
Args:
destinations (list): Which home servers to query
pdu_origin (str): The home server that originally sent the pdu.
event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
Returns:
Deferred: Results in the requested PDU.
"""
# TODO: Rate limit the number of times we try and get the same event.
if self._get_pdu_cache:
e = self._get_pdu_cache.get(event_id)
if e:
defer.returnValue(e)
pdu = None
for destination in destinations:
try:
limiter = yield get_retry_limiter(
destination,
self._clock,
self.store,
)
with limiter:
transaction_data = yield self.transport_layer.get_event(
destination, event_id
)
logger.debug("transaction_data %r", transaction_data)
pdu_list = [
self.event_from_pdu_json(p, outlier=outlier)
for p in transaction_data["pdus"]
]
if pdu_list:
pdu = pdu_list[0]
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
break
except SynapseError:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
continue
if self._get_pdu_cache is not None:
self._get_pdu_cache[event_id] = pdu
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
def get_state_for_room(self, destination, room_id, event_id):
"""Requests all of the `current` state PDUs for a given room from
a remote home server.
Args:
destination (str): The remote homeserver to query for the state.
room_id (str): The id of the room we're interested in.
event_id (str): The id of the event we want the state at.
Returns:
Deferred: Results in a list of PDUs.
"""
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
)
pdus = [
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, pdus, outlier=True
)
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
signed_auth.sort(key=lambda e: e.depth)
defer.returnValue((signed_pdus, signed_auth))
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
destination, room_id, event_id,
)
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
signed_auth.sort(key=lambda e: e.depth)
defer.returnValue(signed_auth)
@defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id):
for destination in destinations:
try:
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
)
pdu_dict = ret["event"]
logger.debug("Got response to make_join: %s", pdu_dict)
defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict))
)
break
except CodeMessageException:
raise
except Exception as e:
logger.warn(
"Failed to make_join via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def send_join(self, destinations, pdu):
for destination in destinations:
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
state = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
signed_state = yield self._check_sigs_and_hash_and_fetch(
destination, state, outlier=True
)
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue({
"state": signed_state,
"auth_chain": signed_auth,
"origin": destination,
})
except CodeMessageException:
raise
except Exception as e:
logger.warn(
"Failed to send_join via %s: %s",
destination, e.message
)
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
time_now = self._clock.time_msec()
code, content = yield self.transport_layer.send_invite(
destination=destination,
room_id=room_id,
event_id=event_id,
content=pdu.get_pdu_json(time_now),
)
pdu_dict = content["event"]
logger.debug("Got response to send_invite: %s", pdu_dict)
pdu = self.event_from_pdu_json(pdu_dict)
# Check signatures are correct.
pdu = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue(pdu)
@defer.inlineCallbacks
def query_auth(self, destination, room_id, event_id, local_auth):
"""
Params:
destination (str)
event_it (str)
local_auth (list)
"""
time_now = self._clock.time_msec()
send_content = {
"auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
}
code, content = yield self.transport_layer.send_query_auth(
destination=destination,
room_id=room_id,
event_id=event_id,
content=send_content,
)
auth_chain = [
self.event_from_pdu_json(e)
for e in content["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
signed_auth.sort(key=lambda e: e.depth)
ret = {
"auth_chain": signed_auth,
"rejects": content.get("rejects", []),
"missing": content.get("missing", []),
}
defer.returnValue(ret)
@defer.inlineCallbacks
def get_missing_events(self, destination, room_id, earliest_events_ids,
latest_events, limit, min_depth):
"""Tries to fetch events we are missing. This is called when we receive
an event without having received all of its ancestors.
Args:
destination (str)
room_id (str)
earliest_events_ids (list): List of event ids. Effectively the
events we expected to receive, but haven't. `get_missing_events`
should only return events that didn't happen before these.
latest_events (list): List of events we have received that we don't
have all previous events for.
limit (int): Maximum number of events to return.
min_depth (int): Minimum depth of events tor return.
"""
try:
content = yield self.transport_layer.get_missing_events(
destination=destination,
room_id=room_id,
earliest_events=earliest_events_ids,
latest_events=[e.event_id for e in latest_events],
limit=limit,
min_depth=min_depth,
)
events = [
self.event_from_pdu_json(e)
for e in content.get("events", [])
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=True
)
have_gotten_all_from_destination = True
except HttpResponseException as e:
if not e.code == 400:
raise
# We are probably hitting an old server that doesn't support
# get_missing_events
signed_events = []
have_gotten_all_from_destination = False
if len(signed_events) >= limit:
defer.returnValue(signed_events)
servers = yield self.store.get_joined_hosts_for_room(room_id)
servers = set(servers)
servers.discard(self.server_name)
failed_to_fetch = set()
while len(signed_events) < limit:
# Are we missing any?
seen_events = set(earliest_events_ids)
seen_events.update(e.event_id for e in signed_events)
missing_events = {}
for e in itertools.chain(latest_events, signed_events):
if e.depth > min_depth:
missing_events.update({
e_id: e.depth for e_id, _ in e.prev_events
if e_id not in seen_events
and e_id not in failed_to_fetch
})
if not missing_events:
break
have_seen = yield self.store.have_events(missing_events)
for k in have_seen:
missing_events.pop(k, None)
if not missing_events:
break
# Okay, we haven't gotten everything yet. Lets get them.
ordered_missing = sorted(missing_events.items(), key=lambda x: x[0])
if have_gotten_all_from_destination:
servers.discard(destination)
def random_server_list():
srvs = list(servers)
random.shuffle(srvs)
return srvs
deferreds = [
self.get_pdu(
destinations=random_server_list(),
event_id=e_id,
)
for e_id, depth in ordered_missing[:limit - len(signed_events)]
]
res = yield defer.DeferredList(deferreds, consumeErrors=True)
for (result, val), (e_id, _) in zip(res, ordered_missing):
if result:
signed_events.append(val)
else:
failed_to_fetch.add(e_id)
defer.returnValue(signed_events)
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event

View File

@ -0,0 +1,474 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .federation_base import FederationBase
from .units import Transaction, Edu
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
import logging
logger = logging.getLogger(__name__)
class FederationServer(FederationBase):
def set_handler(self, handler):
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
documented on :py:class:`.ReplicationHandler`.
"""
self.handler = handler
def register_edu_handler(self, edu_type, handler):
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation Query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (callable): Invoked to handle incoming queries of this type
handler is invoked as:
result = handler(args)
where 'args' is a dict mapping strings to strings of the query
arguments. It should return a Deferred that will eventually yield an
object to encode as JSON.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
pdus = yield self.handler.on_backfill_request(
origin, room_id, versions, limit
)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
transaction = Transaction(**transaction_data)
for p in transaction.pdus:
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
pdu_list = [
self.event_from_pdu_json(p) for p in transaction.pdus
]
logger.debug("[%s] Got transaction", transaction.transaction_id)
response = yield self.transaction_actions.have_responded(transaction)
if response:
logger.debug(
"[%s] We've already responed to this request",
transaction.transaction_id
)
defer.returnValue(response)
return
logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext():
results = []
for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
try:
yield d
results.append({})
except FederationError as e:
self.send_failure(e, transaction.origin)
results.append({"error": str(e)})
except Exception as e:
results.append({"error": str(e)})
logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure)
logger.debug("Returning: %s", str(results))
response = {
"pdus": dict(zip(
(p.event_id for p in pdu_list), results
)),
}
yield self.transaction_actions.set_response(
transaction,
200, response
)
defer.returnValue((200, response))
def received_edu(self, origin, edu_type, content):
if edu_type in self.edu_handlers:
self.edu_handlers[edu_type](origin, content)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, origin, room_id, event_id):
if event_id:
pdus = yield self.handler.get_state_for_pdu(
origin, room_id, event_id,
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
for event in auth_chain:
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
else:
raise NotImplementedError("Specify an event")
defer.returnValue((200, {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}))
@defer.inlineCallbacks
@log_function
def on_pdu_request(self, origin, event_id):
pdu = yield self._get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
(200, self._transaction_from_pdus([pdu]).get_dict())
)
else:
defer.returnValue((404, ""))
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
raise NotImplementedError("Pull transactions not implemented")
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
if query_type in self.query_handlers:
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type,))
)
@defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id):
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
pdu = self.event_from_pdu_json(content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
],
}))
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
defer.returnValue((200, {
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
}))
@defer.inlineCallbacks
def on_query_auth_request(self, origin, content, event_id):
"""
Content is a dict with keys::
auth_chain (list): A list of events that give the auth chain.
missing (list): A list of event_ids indicating what the other
side (`origin`) think we're missing.
rejects (dict): A mapping from event_id to a 2-tuple of reason
string and a proof (or None) of why the event was rejected.
The keys of this dict give the list of events the `origin` has
rejected.
Args:
origin (str)
content (dict)
event_id (str)
Returns:
Deferred: Results in `dict` with the same format as `content`
"""
auth_chain = [
self.event_from_pdu_json(e)
for e in content["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
origin, auth_chain, outlier=True
)
ret = yield self.handler.on_query_auth(
origin,
event_id,
signed_auth,
content.get("rejects", []),
content.get("missing", []),
)
time_now = self._clock.time_msec()
send_content = {
"auth_chain": [
e.get_pdu_json(time_now)
for e in ret["auth_chain"]
],
"rejects": ret.get("rejects", []),
"missing": ret.get("missing", []),
}
defer.returnValue(
(200, send_content)
)
@defer.inlineCallbacks
@log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
missing_events = yield self.handler.on_get_missing_events(
origin, room_id, earliest_events, latest_events, limit, min_depth
)
time_now = self._clock.time_msec()
defer.returnValue({
"events": [ev.get_pdu_json(time_now) for ev in missing_events],
})
@log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
return self.handler.get_persisted_pdu(
origin, event_id, do_auth=do_auth
)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
time_now = self._clock.time_msec()
pdus = [p.get_pdu_json(time_now) for p in pdu_list]
return Transaction(
origin=self.server_name,
pdus=pdus,
origin_server_ts=int(time_now),
destination=None,
)
@defer.inlineCallbacks
@log_function
def _handle_new_pdu(self, origin, pdu, get_missing=True):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False
)
# FIXME: Currently we fetch an event again when we already have it
# if it has been marked as an outlier.
already_seen = (
existing and (
not existing.internal_metadata.is_outlier()
or pdu.internal_metadata.is_outlier()
)
)
if already_seen:
logger.debug("Already seen pdu %s", pdu.event_id)
return
# Check signature.
try:
pdu = yield self._check_sigs_and_hash(pdu)
except SynapseError as e:
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=pdu.event_id,
)
state = None
auth_chain = []
have_seen = yield self.store.have_events(
[ev for ev, _ in pdu.prev_events]
)
fetch_state = False
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id
)
logger.debug(
"_handle_new_pdu min_depth for %s: %d",
pdu.room_id, min_depth
)
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
# message, to work around the fact that some events will
# reference really really old events we really don't want to
# send to the clients.
pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth:
if get_missing and prevs - seen:
latest_tuples = yield self.store.get_latest_events_in_room(
pdu.room_id
)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
latest = set(e_id for e_id, _, _ in latest_tuples)
latest |= seen
missing_events = yield self.get_missing_events(
origin,
pdu.room_id,
earliest_events_ids=list(latest),
latest_events=[pdu],
limit=10,
min_depth=min_depth,
)
# We want to sort these by depth so we process them and
# tell clients about them in order.
missing_events.sort(key=lambda x: x.depth)
for e in missing_events:
yield self._handle_new_pdu(
origin,
e,
get_missing=False
)
have_seen = yield self.store.have_events(
[ev for ev, _ in pdu.prev_events]
)
prevs = {e_id for e_id, _ in pdu.prev_events}
seen = set(have_seen.keys())
if prevs - seen:
fetch_state = True
if fetch_state:
# We need to get the state at this event, since we haven't
# processed all the prev events.
logger.debug(
"_handle_new_pdu getting state for %s",
pdu.room_id
)
try:
state, auth_chain = yield self.get_state_for_room(
origin, pdu.room_id, pdu.event_id,
)
except:
logger.warn("Failed to get state for event: %s", pdu.event_id)
yield self.handler.on_receive_pdu(
origin,
pdu,
backfilled=False,
state=state,
auth_chain=auth_chain,
)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event

View File

@ -23,7 +23,8 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
import json from syutil.jsonutil import encode_canonical_json
import logging import logging
@ -70,7 +71,7 @@ class TransactionActions(object):
transaction.transaction_id, transaction.transaction_id,
transaction.origin, transaction.origin,
code, code,
json.dumps(response) encode_canonical_json(response)
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -100,5 +101,5 @@ class TransactionActions(object):
transaction.transaction_id, transaction.transaction_id,
transaction.destination, transaction.destination,
response_code, response_code,
json.dumps(response_dict) encode_canonical_json(response_dict)
) )

View File

@ -17,23 +17,20 @@
a given transport. a given transport.
""" """
from twisted.internet import defer from .federation_client import FederationClient
from .federation_server import FederationServer
from .units import Transaction, Edu from .transaction_queue import TransactionQueue
from .persistence import TransactionActions from .persistence import TransactionActions
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ReplicationLayer(object): class ReplicationLayer(FederationClient, FederationServer):
"""This layer is responsible for replicating with remote home servers over """This layer is responsible for replicating with remote home servers over
the given transport. I.e., does the sending and receiving of PDUs to the given transport. I.e., does the sending and receiving of PDUs to
remote home servers. remote home servers.
@ -54,898 +51,28 @@ class ReplicationLayer(object):
def __init__(self, hs, transport_layer): def __init__(self, hs, transport_layer):
self.server_name = hs.hostname self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.transport_layer = transport_layer self.transport_layer = transport_layer
self.transport_layer.register_received_handler(self) self.transport_layer.register_received_handler(self)
self.transport_layer.register_request_handler(self) self.transport_layer.register_request_handler(self)
self.store = hs.get_datastore() self.federation_client = self
# self.pdu_actions = PduActions(self.store)
self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = _TransactionQueue( self.store = hs.get_datastore()
hs, self.transaction_actions, transport_layer
)
self.handler = None self.handler = None
self.edu_handlers = {} self.edu_handlers = {}
self.query_handlers = {} self.query_handlers = {}
self._order = 0
self._clock = hs.get_clock() self._clock = hs.get_clock()
self.event_builder_factory = hs.get_event_builder_factory() self.transaction_actions = TransactionActions(self.store)
self._transaction_queue = TransactionQueue(hs, transport_layer)
def set_handler(self, handler): self._order = 0
"""Sets the handler that the replication layer will use to communicate
receipt of new PDUs from other home servers. The required methods are
documented on :py:class:`.ReplicationHandler`.
"""
self.handler = handler
def register_edu_handler(self, edu_type, handler): self.hs = hs
if edu_type in self.edu_handlers:
raise KeyError("Already have an EDU handler for %s" % (edu_type,))
self.edu_handlers[edu_type] = handler
def register_query_handler(self, query_type, handler):
"""Sets the handler callable that will be used to handle an incoming
federation Query of the given type.
Args:
query_type (str): Category name of the query, which should match
the string used by make_query.
handler (callable): Invoked to handle incoming queries of this type
handler is invoked as:
result = handler(args)
where 'args' is a dict mapping strings to strings of the query
arguments. It should return a Deferred that will eventually yield an
object to encode as JSON.
"""
if query_type in self.query_handlers:
raise KeyError(
"Already have a Query handler for %s" % (query_type,)
)
self.query_handlers[query_type] = handler
@log_function
def send_pdu(self, pdu, destinations):
"""Informs the replication layer about a new PDU generated within the
home server that should be transmitted to others.
TODO: Figure out when we should actually resolve the deferred.
Args:
pdu (Pdu): The new Pdu.
Returns:
Deferred: Completes when we have successfully processed the PDU
and replicated it to any interested remote home servers.
"""
order = self._order
self._order += 1
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
self._transaction_queue.enqueue_pdu(pdu, destinations, order)
logger.debug(
"[%s] transaction_layer.enqueue_pdu... done",
pdu.event_id
)
@log_function
def send_edu(self, destination, edu_type, content):
edu = Edu(
origin=self.server_name,
destination=destination,
edu_type=edu_type,
content=content,
)
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu)
return defer.succeed(None)
@log_function
def send_failure(self, failure, destination):
self._transaction_queue.enqueue_failure(failure, destination)
return defer.succeed(None)
@log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=True):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
Args:
destination (str): Domain name of the remote homeserver
query_type (str): Category of the query type; should match the
handler name used in register_query_handler().
args (dict): Mapping of strings to strings containing the details
of the query request.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
)
@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit, extremities):
"""Requests some more historic PDUs for the given context from the
given destination server.
Args:
dest (str): The remote home server to ask.
context (str): The context to backfill.
limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus
we have seen from the context
Returns:
Deferred: Results in the received PDUs.
"""
logger.debug("backfill extrem=%s", extremities)
# If there are no extremeties then we've (probably) reached the start.
if not extremities:
return
transaction_data = yield self.transport_layer.backfill(
dest, context, extremities, limit)
logger.debug("backfill transaction_data=%s", repr(transaction_data))
transaction = Transaction(**transaction_data)
pdus = [
self.event_from_pdu_json(p, outlier=False)
for p in transaction.pdus
]
for pdu in pdus:
yield self._handle_new_pdu(dest, pdu, backfilled=True)
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destination, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
server.
This will persist the PDU locally upon receipt.
Args:
destination (str): Which home server to query
pdu_origin (str): The home server that originally sent the pdu.
event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
Returns:
Deferred: Results in the requested PDU.
"""
transaction_data = yield self.transport_layer.get_event(
destination, event_id
)
transaction = Transaction(**transaction_data)
pdu_list = [
self.event_from_pdu_json(p, outlier=outlier)
for p in transaction.pdus
]
pdu = None
if pdu_list:
pdu = pdu_list[0]
yield self._handle_new_pdu(destination, pdu)
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
def get_state_for_room(self, destination, room_id, event_id):
"""Requests all of the `current` state PDUs for a given room from
a remote home server.
Args:
destination (str): The remote homeserver to query for the state.
room_id (str): The id of the room we're interested in.
event_id (str): The id of the event we want the state at.
Returns:
Deferred: Results in a list of PDUs.
"""
result = yield self.transport_layer.get_room_state(
destination, room_id, event_id=event_id,
)
pdus = [
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in result.get("auth_chain", [])
]
defer.returnValue((pdus, auth_chain))
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
res = yield self.transport_layer.get_event_auth(
destination, room_id, event_id,
)
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in res["auth_chain"]
]
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue(auth_chain)
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, versions, limit):
pdus = yield self.handler.on_backfill_request(
origin, room_id, versions, limit
)
defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
transaction = Transaction(**transaction_data)
for p in transaction.pdus:
if "unsigned" in p:
unsigned = p["unsigned"]
if "age" in unsigned:
p["age"] = unsigned["age"]
if "age" in p:
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
del p["age"]
pdu_list = [
self.event_from_pdu_json(p) for p in transaction.pdus
]
logger.debug("[%s] Got transaction", transaction.transaction_id)
response = yield self.transaction_actions.have_responded(transaction)
if response:
logger.debug("[%s] We've already responed to this request",
transaction.transaction_id)
defer.returnValue(response)
return
logger.debug("[%s] Transaction is new", transaction.transaction_id)
with PreserveLoggingContext():
dl = []
for pdu in pdu_list:
dl.append(self._handle_new_pdu(transaction.origin, pdu))
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
results = yield defer.DeferredList(dl)
ret = []
for r in results:
if r[0]:
ret.append({})
else:
logger.exception(r[1])
ret.append({"error": str(r[1])})
logger.debug("Returning: %s", str(ret))
yield self.transaction_actions.set_response(
transaction,
200, response
)
defer.returnValue((200, response))
def received_edu(self, origin, edu_type, content):
if edu_type in self.edu_handlers:
self.edu_handlers[edu_type](origin, content)
else:
logger.warn("Received EDU of type %s with no handler", edu_type)
@defer.inlineCallbacks
@log_function
def on_context_state_request(self, origin, room_id, event_id):
if event_id:
pdus = yield self.handler.get_state_for_pdu(
origin, room_id, event_id,
)
auth_chain = yield self.store.get_auth_chain(
[pdu.event_id for pdu in pdus]
)
else:
raise NotImplementedError("Specify an event")
defer.returnValue((200, {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
}))
@defer.inlineCallbacks
@log_function
def on_pdu_request(self, origin, event_id):
pdu = yield self._get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
(200, self._transaction_from_pdus([pdu]).get_dict())
)
else:
defer.returnValue((404, ""))
@defer.inlineCallbacks
@log_function
def on_pull_request(self, origin, versions):
raise NotImplementedError("Pull transactions not implemented")
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
if query_type in self.query_handlers:
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
else:
defer.returnValue(
(404, "No handler for Query type '%s'" % (query_type,))
)
@defer.inlineCallbacks
def on_make_join_request(self, room_id, user_id):
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@defer.inlineCallbacks
def on_invite_request(self, origin, content):
pdu = self.event_from_pdu_json(content)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@defer.inlineCallbacks
def on_send_join_request(self, origin, content):
logger.debug("on_send_join_request: content: %s", content)
pdu = self.event_from_pdu_json(content)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
defer.returnValue((200, {
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
],
}))
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
defer.returnValue((200, {
"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
}))
@defer.inlineCallbacks
def make_join(self, destination, room_id, user_id):
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
)
pdu_dict = ret["event"]
logger.debug("Got response to make_join: %s", pdu_dict)
defer.returnValue(self.event_from_pdu_json(pdu_dict))
@defer.inlineCallbacks
def send_join(self, destination, pdu):
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
state = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
self.event_from_pdu_json(p, outlier=True)
for p in content.get("auth_chain", [])
]
auth_chain.sort(key=lambda e: e.depth)
defer.returnValue({
"state": state,
"auth_chain": auth_chain,
})
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
time_now = self._clock.time_msec()
code, content = yield self.transport_layer.send_invite(
destination=destination,
room_id=room_id,
event_id=event_id,
content=pdu.get_pdu_json(time_now),
)
pdu_dict = content["event"]
logger.debug("Got response to send_invite: %s", pdu_dict)
defer.returnValue(self.event_from_pdu_json(pdu_dict))
@log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
return self.handler.get_persisted_pdu(
origin, event_id, do_auth=do_auth
)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
transmission.
"""
time_now = self._clock.time_msec()
pdus = [p.get_pdu_json(time_now) for p in pdu_list]
return Transaction(
origin=self.server_name,
pdus=pdus,
origin_server_ts=int(time_now),
destination=None,
)
@defer.inlineCallbacks
@log_function
def _handle_new_pdu(self, origin, pdu, backfilled=False):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False
)
already_seen = (
existing and (
not existing.internal_metadata.is_outlier()
or pdu.internal_metadata.is_outlier()
)
)
if already_seen:
logger.debug("Already seen pdu %s", pdu.event_id)
defer.returnValue({})
return
state = None
auth_chain = []
# We need to make sure we have all the auth events.
# for e_id, _ in pdu.auth_events:
# exists = yield self._get_persisted_pdu(
# origin,
# e_id,
# do_auth=False
# )
#
# if not exists:
# try:
# logger.debug(
# "_handle_new_pdu fetch missing auth event %s from %s",
# e_id,
# origin,
# )
#
# yield self.get_pdu(
# origin,
# event_id=e_id,
# outlier=True,
# )
#
# logger.debug("Processed pdu %s", e_id)
# except:
# logger.warn(
# "Failed to get auth event %s from %s",
# e_id,
# origin
# )
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
min_depth = yield self.handler.get_min_depth_for_context(
pdu.room_id
)
logger.debug(
"_handle_new_pdu min_depth for %s: %d",
pdu.room_id, min_depth
)
if min_depth and pdu.depth > min_depth:
for event_id, hashes in pdu.prev_events:
exists = yield self._get_persisted_pdu(
origin,
event_id,
do_auth=False
)
if not exists:
logger.debug(
"_handle_new_pdu requesting pdu %s",
event_id
)
try:
yield self.get_pdu(
origin,
event_id=event_id,
)
logger.debug("Processed pdu %s", event_id)
except:
# TODO(erikj): Do some more intelligent retries.
logger.exception("Failed to get PDU")
else:
# We need to get the state at this event, since we have reached
# a backward extremity edge.
logger.debug(
"_handle_new_pdu getting state for %s",
pdu.room_id
)
state, auth_chain = yield self.get_state_for_room(
origin, pdu.room_id, pdu.event_id,
)
if not backfilled:
ret = yield self.handler.on_receive_pdu(
origin,
pdu,
backfilled=backfilled,
state=state,
auth_chain=auth_chain,
)
else:
ret = None
# yield self.pdu_actions.mark_as_processed(pdu)
defer.returnValue(ret)
def __str__(self): def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name return "<ReplicationLayer(%s)>" % self.server_name
def event_from_pdu_json(self, pdu_json, outlier=False):
event = FrozenEvent(
pdu_json
)
event.internal_metadata.outlier = outlier
return event
class _TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
a time for a given destination.
It batches pending PDUs into single transactions.
"""
def __init__(self, hs, transaction_actions, transport_layer):
self.server_name = hs.hostname
self.transaction_actions = transaction_actions
self.transport_layer = transport_layer
self._clock = hs.get_clock()
self.store = hs.get_datastore()
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
# done
self.pending_transactions = {}
# Is a mapping from destination -> list of
# tuple(pending pdus, deferred, order)
self.pending_pdus_by_dest = {}
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = {}
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
@defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
destinations = set(destinations)
destinations.discard(self.server_name)
destinations.discard("localhost")
logger.debug("Sending to: %s", str(destinations))
if not destinations:
return
deferreds = []
for destination in destinations:
deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, deferred, order)
)
def eb(failure):
if not deferred.called:
deferred.errback(failure)
else:
logger.warn("Failed to send pdu", failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb)
deferreds.append(deferred)
yield defer.DeferredList(deferreds)
# NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
if destination == self.server_name:
return
deferred = defer.Deferred()
self.pending_edus_by_dest.setdefault(destination, []).append(
(edu, deferred)
)
def eb(failure):
if not deferred.called:
deferred.errback(failure)
else:
logger.warn("Failed to send edu", failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(eb)
return deferred
@defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
deferred = defer.Deferred()
self.pending_failures_by_dest.setdefault(
destination, []
).append(
(failure, deferred)
)
yield deferred
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
(retry_last_ts, retry_interval) = (0, 0)
retry_timings = yield self.store.get_destination_retry_timings(
destination
)
if retry_timings:
(retry_last_ts, retry_interval) = (
retry_timings.retry_last_ts, retry_timings.retry_interval
)
if retry_last_ts + retry_interval > int(self._clock.time_msec()):
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
return
else:
logger.info("TX [%s] is ready for retry", destination)
logger.info("TX [%s] _attempt_new_transaction", destination)
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
return
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
return
logger.debug(
"TX [%s] Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
try:
self.pending_transactions[destination] = 1
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] Sending transaction [%s]",
destination,
transaction.transaction_id,
)
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self._clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
code, response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
logger.info("TX [%s] got %d response", destination, code)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)
yield self.transaction_actions.delivered(
transaction, code, response
)
logger.debug("TX [%s] Marked as delivered", destination)
logger.debug("TX [%s] Yielding to callbacks...", destination)
for deferred in deferreds:
if code == 200:
if retry_last_ts:
# this host is alive! reset retry schedule
yield self.store.set_destination_retry_timings(
destination, 0, 0
)
deferred.callback(None)
else:
self.set_retrying(destination, retry_interval)
deferred.errback(RuntimeError("Got status %d" % code))
# Ensures we don't continue until all callbacks on that
# deferred have fired
try:
yield deferred
except:
pass
logger.debug("TX [%s] Yielded to callbacks", destination)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
self.set_retrying(destination, retry_interval)
for deferred in deferreds:
if not deferred.called:
deferred.errback(e)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
# Check to see if there is anything else to send.
self._attempt_new_transaction(destination)
@defer.inlineCallbacks
def set_retrying(self, destination, retry_interval):
# track that this destination is having problems and we should
# give it a chance to recover before trying it again
if retry_interval:
retry_interval *= 2
# plateau at hourly retries for now
if retry_interval >= 60 * 60 * 1000:
retry_interval = 60 * 60 * 1000
else:
retry_interval = 2000 # try again at first after 2 seconds
yield self.store.set_destination_retry_timings(
destination,
int(self._clock.time_msec()),
retry_interval
)

View File

@ -0,0 +1,359 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .persistence import TransactionActions
from .units import Transaction
from synapse.api.errors import HttpResponseException
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
import logging
logger = logging.getLogger(__name__)
class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
a time for a given destination.
It batches pending PDUs into single transactions.
"""
def __init__(self, hs, transport_layer):
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.transaction_actions = TransactionActions(self.store)
self.transport_layer = transport_layer
self._clock = hs.get_clock()
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
# done
self.pending_transactions = {}
# Is a mapping from destination -> list of
# tuple(pending pdus, deferred, order)
self.pending_pdus_by_dest = {}
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = {}
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self._clock.time_msec())
def can_send_to(self, destination):
"""Can we send messages to the given server?
We can't send messages to ourselves. If we are running on localhost
then we can only federation with other servers running on localhost.
Otherwise we only federate with servers on a public domain.
Args:
destination(str): The server we are possibly trying to send to.
Returns:
bool: True if we can send to the server.
"""
if destination == self.server_name:
return False
if self.server_name.startswith("localhost"):
return destination.startswith("localhost")
else:
return not destination.startswith("localhost")
@defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
destinations = set(destinations)
destinations = set(
dest for dest in destinations if self.can_send_to(dest)
)
logger.debug("Sending to: %s", str(destinations))
if not destinations:
return
deferreds = []
for destination in destinations:
deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, deferred, order)
)
def chain(failure):
if not deferred.called:
deferred.errback(failure)
def log_failure(failure):
logger.warn("Failed to send pdu", failure.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
deferreds.append(deferred)
yield defer.DeferredList(deferreds, consumeErrors=True)
# NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
if not self.can_send_to(destination):
return
deferred = defer.Deferred()
self.pending_edus_by_dest.setdefault(destination, []).append(
(edu, deferred)
)
def chain(failure):
if not deferred.called:
deferred.errback(failure)
def log_failure(failure):
logger.warn("Failed to send pdu", failure.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
return deferred
@defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
return
deferred = defer.Deferred()
if not self.can_send_to(destination):
return
self.pending_failures_by_dest.setdefault(
destination, []
).append(
(failure, deferred)
)
def chain(f):
if not deferred.called:
deferred.errback(f)
def log_failure(f):
logger.warn("Failed to send pdu", f.value)
deferred.addErrback(log_failure)
with PreserveLoggingContext():
self._attempt_new_transaction(destination).addErrback(chain)
yield deferred
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
logger.info(
"TX [%s] Transaction already in progress",
destination
)
return
logger.info("TX [%s] _attempt_new_transaction", destination)
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.info("TX [%s] Nothing to send", destination)
return
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[2])
pdus = [x[0] for x in pending_pdus]
edus = [x[0] for x in pending_edus]
failures = [x[0].get_dict() for x in pending_failures]
deferreds = [
x[1]
for x in pending_pdus + pending_edus + pending_failures
]
try:
self.pending_transactions[destination] = 1
limiter = yield get_retry_limiter(
destination,
self._clock,
self.store,
)
logger.debug(
"TX [%s] Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
)
logger.debug("TX [%s] Persisting transaction...", destination)
transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
edus=edus,
pdu_failures=failures,
)
self._next_txn_id += 1
yield self.transaction_actions.prepare_to_send(transaction)
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] Sending transaction [%s]",
destination,
transaction.transaction_id,
)
with limiter:
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self._clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
try:
response = yield self.transport_layer.send_transaction(
transaction, json_data_cb
)
code = 200
if response:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warn(
"Transaction returned error for %s: %s",
e_id, r,
)
except HttpResponseException as e:
code = e.code
response = e.response
logger.info("TX [%s] got %d response", destination, code)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)
yield self.transaction_actions.delivered(
transaction, code, response
)
logger.debug("TX [%s] Marked as delivered", destination)
logger.debug("TX [%s] Yielding to callbacks...", destination)
for deferred in deferreds:
if code == 200:
deferred.callback(None)
else:
deferred.errback(RuntimeError("Got status %d" % code))
# Ensures we don't continue until all callbacks on that
# deferred have fired
try:
yield deferred
except:
pass
logger.debug("TX [%s] Yielded to callbacks", destination)
except NotRetryingDestination:
logger.info(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
)
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
except Exception as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
logger.warn(
"TX [%s] Problem in _attempt_transaction: %s",
destination,
e,
)
for deferred in deferreds:
if not deferred.called:
deferred.errback(e)
finally:
# We want to be *very* sure we delete this after we stop processing
self.pending_transactions.pop(destination, None)
# Check to see if there is anything else to send.
self._attempt_new_transaction(destination)

View File

@ -24,6 +24,8 @@ communicate over a different (albeit still reliable) protocol.
from .server import TransportLayerServer from .server import TransportLayerServer
from .client import TransportLayerClient from .client import TransportLayerClient
from synapse.util.ratelimitutils import FederationRateLimiter
class TransportLayer(TransportLayerServer, TransportLayerClient): class TransportLayer(TransportLayerServer, TransportLayerClient):
"""This is a basic implementation of the transport layer that translates """This is a basic implementation of the transport layer that translates
@ -55,8 +57,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient):
send requests send requests
""" """
self.keyring = homeserver.get_keyring() self.keyring = homeserver.get_keyring()
self.clock = homeserver.get_clock()
self.server_name = server_name self.server_name = server_name
self.server = server self.server = server
self.client = client self.client = client
self.request_handler = None self.request_handler = None
self.received_handler = None self.received_handler = None
self.ratelimiter = FederationRateLimiter(
self.clock,
window_size=homeserver.config.federation_rc_window_size,
sleep_limit=homeserver.config.federation_rc_sleep_limit,
sleep_msec=homeserver.config.federation_rc_sleep_delay,
reject_limit=homeserver.config.federation_rc_reject_limit,
concurrent_requests=homeserver.config.federation_rc_concurrent,
)

View File

@ -19,7 +19,6 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
import logging import logging
import json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -129,7 +128,7 @@ class TransportLayerClient(object):
# generated by the json_data_callback. # generated by the json_data_callback.
json_data = transaction.get_dict() json_data = transaction.get_dict()
code, response = yield self.client.put_json( response = yield self.client.put_json(
transaction.destination, transaction.destination,
path=PREFIX + "/send/%s/" % transaction.transaction_id, path=PREFIX + "/send/%s/" % transaction.transaction_id,
data=json_data, data=json_data,
@ -137,79 +136,105 @@ class TransportLayerClient(object):
) )
logger.debug( logger.debug(
"send_data dest=%s, txid=%s, got response: %d", "send_data dest=%s, txid=%s, got response: 200",
transaction.destination, transaction.transaction_id, code transaction.destination, transaction.transaction_id,
) )
defer.returnValue((code, response)) defer.returnValue(response)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail): def make_query(self, destination, query_type, args, retry_on_dns_fail):
path = PREFIX + "/query/%s" % query_type path = PREFIX + "/query/%s" % query_type
response = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
args=args, args=args,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
) )
defer.returnValue(response) defer.returnValue(content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def make_join(self, destination, room_id, user_id, retry_on_dns_fail=True): def make_join(self, destination, room_id, user_id, retry_on_dns_fail=True):
path = PREFIX + "/make_join/%s/%s" % (room_id, user_id) path = PREFIX + "/make_join/%s/%s" % (room_id, user_id)
response = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
retry_on_dns_fail=retry_on_dns_fail, retry_on_dns_fail=retry_on_dns_fail,
) )
defer.returnValue(response) defer.returnValue(content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_join(self, destination, room_id, event_id, content): def send_join(self, destination, room_id, event_id, content):
path = PREFIX + "/send_join/%s/%s" % (room_id, event_id) path = PREFIX + "/send_join/%s/%s" % (room_id, event_id)
code, content = yield self.client.put_json( response = yield self.client.put_json(
destination=destination, destination=destination,
path=path, path=path,
data=content, data=content,
) )
if not 200 <= code < 300: defer.returnValue(response)
raise RuntimeError("Got %d from send_join", code)
defer.returnValue(json.loads(content))
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def send_invite(self, destination, room_id, event_id, content): def send_invite(self, destination, room_id, event_id, content):
path = PREFIX + "/invite/%s/%s" % (room_id, event_id) path = PREFIX + "/invite/%s/%s" % (room_id, event_id)
code, content = yield self.client.put_json( response = yield self.client.put_json(
destination=destination, destination=destination,
path=path, path=path,
data=content, data=content,
) )
if not 200 <= code < 300: defer.returnValue(response)
raise RuntimeError("Got %d from send_invite", code)
defer.returnValue(json.loads(content))
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_event_auth(self, destination, room_id, event_id): def get_event_auth(self, destination, room_id, event_id):
path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id) path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id)
response = yield self.client.get_json( content = yield self.client.get_json(
destination=destination, destination=destination,
path=path, path=path,
) )
defer.returnValue(response) defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def send_query_auth(self, destination, room_id, event_id, content):
path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
content = yield self.client.post_json(
destination=destination,
path=path,
data=content,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth):
path = PREFIX + "/get_missing_events/%s" % (room_id,)
content = yield self.client.post_json(
destination=destination,
path=path,
data={
"limit": int(limit),
"min_depth": int(min_depth),
"earliest_events": earliest_events,
"latest_events": latest_events,
}
)
defer.returnValue(content)

View File

@ -20,7 +20,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
import logging import logging
import json import simplejson as json
import re import re
@ -42,7 +42,7 @@ class TransportLayerServer(object):
content = None content = None
origin = None origin = None
if request.method == "PUT": if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types? # TODO: Handle other method types? other content types?
try: try:
content_bytes = request.content.read() content_bytes = request.content.read()
@ -98,15 +98,23 @@ class TransportLayerServer(object):
def new_handler(request, *args, **kwargs): def new_handler(request, *args, **kwargs):
try: try:
(origin, content) = yield self._authenticate_request(request) (origin, content) = yield self._authenticate_request(request)
response = yield handler( with self.ratelimiter.ratelimit(origin) as d:
origin, content, request.args, *args, **kwargs yield d
) response = yield handler(
origin, content, request.args, *args, **kwargs
)
except: except:
logger.exception("_authenticate_request failed") logger.exception("_authenticate_request failed")
raise raise
defer.returnValue(response) defer.returnValue(response)
return new_handler return new_handler
def rate_limit_origin(self, handler):
def new_handler(origin, *args, **kwargs):
response = yield handler(origin, *args, **kwargs)
defer.returnValue(response)
return new_handler()
@log_function @log_function
def register_received_handler(self, handler): def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data. """ Register a handler that will be fired when we receive data.
@ -235,6 +243,28 @@ class TransportLayerServer(object):
) )
) )
self.server.register_path(
"POST",
re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
self._with_authentication(
lambda origin, content, query, context, event_id:
self._on_query_auth_request(
origin, content, event_id,
)
)
)
self.server.register_path(
"POST",
re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"),
self._with_authentication(
lambda origin, content, query, room_id:
self._get_missing_events(
origin, content, room_id,
)
)
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _on_send_request(self, origin, content, query, transaction_id): def _on_send_request(self, origin, content, query, transaction_id):
@ -325,3 +355,31 @@ class TransportLayerServer(object):
) )
defer.returnValue((200, content)) defer.returnValue((200, content))
@defer.inlineCallbacks
@log_function
def _on_query_auth_request(self, origin, content, event_id):
new_content = yield self.request_handler.on_query_auth_request(
origin, content, event_id
)
defer.returnValue((200, new_content))
@defer.inlineCallbacks
@log_function
def _get_missing_events(self, origin, content, room_id):
limit = int(content.get("limit", 10))
min_depth = int(content.get("min_depth", 0))
earliest_events = content.get("earliest_events", [])
latest_events = content.get("latest_events", [])
content = yield self.request_handler.on_get_missing_events(
origin,
room_id=room_id,
earliest_events=earliest_events,
latest_events=latest_events,
min_depth=min_depth,
limit=limit,
)
defer.returnValue((200, content))

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler from .register import RegistrationHandler
from .room import ( from .room import (
RoomCreationHandler, RoomMemberHandler, RoomListHandler RoomCreationHandler, RoomMemberHandler, RoomListHandler
@ -26,6 +27,8 @@ from .presence import PresenceHandler
from .directory import DirectoryHandler from .directory import DirectoryHandler
from .typing import TypingNotificationHandler from .typing import TypingNotificationHandler
from .admin import AdminHandler from .admin import AdminHandler
from .appservice import ApplicationServicesHandler
from .sync import SyncHandler
class Handlers(object): class Handlers(object):
@ -51,3 +54,7 @@ class Handlers(object):
self.directory_handler = DirectoryHandler(hs) self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs) self.admin_handler = AdminHandler(hs)
self.appservice_handler = ApplicationServicesHandler(
hs, ApplicationServiceApi(hs)
)
self.sync_handler = SyncHandler(hs)

View File

@ -19,6 +19,7 @@ from synapse.api.errors import LimitExceededError, SynapseError
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID
import logging import logging
@ -113,7 +114,7 @@ class BaseHandler(object):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE: if event.content["membership"] == Membership.INVITE:
invitee = self.hs.parse_userid(event.state_key) invitee = UserID.from_string(event.state_key)
if not self.hs.is_mine(invitee): if not self.hs.is_mine(invitee):
# TODO: Can we add signature from remote server in a nicer # TODO: Can we add signature from remote server in a nicer
# way? If we have been invited by a remote server, we need # way? If we have been invited by a remote server, we need
@ -134,7 +135,7 @@ class BaseHandler(object):
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN: if s.content["membership"] == Membership.JOIN:
destinations.add( destinations.add(
self.hs.parse_userid(s.state_key).domain UserID.from_string(s.state_key).domain
) )
except SynapseError: except SynapseError:
logger.warn( logger.warn(

View File

@ -0,0 +1,211 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.appservice import ApplicationService
from synapse.types import UserID
import synapse.util.stringutils as stringutils
import logging
logger = logging.getLogger(__name__)
# NB: Purposefully not inheriting BaseHandler since that contains way too much
# setup code which this handler does not need or use. This makes testing a lot
# easier.
class ApplicationServicesHandler(object):
def __init__(self, hs, appservice_api):
self.store = hs.get_datastore()
self.hs = hs
self.appservice_api = appservice_api
@defer.inlineCallbacks
def register(self, app_service):
logger.info("Register -> %s", app_service)
# check the token is recognised
try:
stored_service = yield self.store.get_app_service_by_token(
app_service.token
)
if not stored_service:
raise StoreError(404, "Application service not found")
except StoreError:
raise SynapseError(
403, "Unrecognised application services token. "
"Consult the home server admin.",
errcode=Codes.FORBIDDEN
)
app_service.hs_token = self._generate_hs_token()
# create a sender for this application service which is used when
# creating rooms, etc..
account = yield self.hs.get_handlers().registration_handler.register()
app_service.sender = account[0]
yield self.store.update_app_service(app_service)
defer.returnValue(app_service)
@defer.inlineCallbacks
def unregister(self, token):
logger.info("Unregister as_token=%s", token)
yield self.store.unregister_app_service(token)
@defer.inlineCallbacks
def notify_interested_services(self, event):
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
prolonged length of time.
Args:
event(Event): The event to push out to interested services.
"""
# Gather interested services
services = yield self._get_services_for_event(event)
if len(services) == 0:
return # no services need notifying
# Do we know this user exists? If not, poke the user query API for
# all services which match that user regex. This needs to block as these
# user queries need to be made BEFORE pushing the event.
yield self._check_user_exists(event.sender)
if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key)
# Fork off pushes to these services - XXX First cut, best effort
for service in services:
self.appservice_api.push(service, event)
@defer.inlineCallbacks
def query_user_exists(self, user_id):
"""Check if any application service knows this user_id exists.
Args:
user_id(str): The user to query if they exist on any AS.
Returns:
True if this user exists on at least one application service.
"""
user_query_services = yield self._get_services_for_user(
user_id=user_id
)
for user_service in user_query_services:
is_known_user = yield self.appservice_api.query_user(
user_service, user_id
)
if is_known_user:
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks
def query_room_alias_exists(self, room_alias):
"""Check if an application service knows this room alias exists.
Args:
room_alias(RoomAlias): The room alias to query.
Returns:
namedtuple: with keys "room_id" and "servers" or None if no
association can be found.
"""
room_alias_str = room_alias.to_string()
alias_query_services = yield self._get_services_for_event(
event=None,
restrict_to=ApplicationService.NS_ALIASES,
alias_list=[room_alias_str]
)
for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias(
alias_service, room_alias_str
)
if is_known_alias:
# the alias exists now so don't query more ASes.
result = yield self.store.get_association_from_room_alias(
room_alias
)
defer.returnValue(result)
@defer.inlineCallbacks
def _get_services_for_event(self, event, restrict_to="", alias_list=None):
"""Retrieve a list of application services interested in this event.
Args:
event(Event): The event to check. Can be None if alias_list is not.
restrict_to(str): The namespace to restrict regex tests to.
alias_list: A list of aliases to get services for. If None, this
list is obtained from the database.
Returns:
list<ApplicationService>: A list of services interested in this
event based on the service regex.
"""
member_list = None
if hasattr(event, "room_id"):
# We need to know the aliases associated with this event.room_id,
# if any.
if not alias_list:
alias_list = yield self.store.get_aliases_for_room(
event.room_id
)
# We need to know the members associated with this event.room_id,
# if any.
member_list = yield self.store.get_room_members(
room_id=event.room_id,
membership=Membership.JOIN
)
services = yield self.store.get_app_services()
interested_list = [
s for s in services if (
s.is_interested(event, restrict_to, alias_list, member_list)
)
]
defer.returnValue(interested_list)
@defer.inlineCallbacks
def _get_services_for_user(self, user_id):
services = yield self.store.get_app_services()
interested_list = [
s for s in services if (
s.is_interested_in_user(user_id)
)
]
defer.returnValue(interested_list)
@defer.inlineCallbacks
def _is_unknown_user(self, user_id):
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
defer.returnValue(False)
return
user_info = yield self.store.get_user_by_id(user_id)
defer.returnValue(len(user_info) == 0)
@defer.inlineCallbacks
def _check_user_exists(self, user_id):
unknown_user = yield self._is_unknown_user(user_id)
if unknown_user:
exists = yield self.query_user_exists(user_id)
defer.returnValue(exists)
defer.returnValue(True)
def _generate_hs_token(self):
return stringutils.random_string(24)

View File

@ -19,6 +19,7 @@ from ._base import BaseHandler
from synapse.api.errors import SynapseError, Codes, CodeMessageException from synapse.api.errors import SynapseError, Codes, CodeMessageException
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.types import RoomAlias
import logging import logging
@ -36,18 +37,15 @@ class DirectoryHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def create_association(self, user_id, room_alias, room_id, servers=None): def _create_association(self, room_alias, room_id, servers=None):
# general association creation for both human users and app services
# TODO(erikj): Do auth.
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this. # TODO(erikj): Change this.
# TODO(erikj): Add transactions. # TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association. # TODO(erikj): Check if there is a current association.
if not servers: if not servers:
servers = yield self.store.get_joined_hosts_for_room(room_id) servers = yield self.store.get_joined_hosts_for_room(room_id)
@ -60,23 +58,78 @@ class DirectoryHandler(BaseHandler):
servers servers
) )
@defer.inlineCallbacks
def create_association(self, user_id, room_alias, room_id, servers=None):
# association creation for human users
# TODO(erikj): Do user auth.
can_create = yield self.can_modify_alias(
room_alias,
user_id=user_id
)
if not can_create:
raise SynapseError(
400, "This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE
)
yield self._create_association(room_alias, room_id, servers)
@defer.inlineCallbacks
def create_appservice_association(self, service, room_alias, room_id,
servers=None):
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400, "This application service has not reserved"
" this kind of alias.", errcode=Codes.EXCLUSIVE
)
# association creation for app services
yield self._create_association(room_alias, room_id, servers)
@defer.inlineCallbacks @defer.inlineCallbacks
def delete_association(self, user_id, room_alias): def delete_association(self, user_id, room_alias):
# association deletion for human users
# TODO Check if server admin # TODO Check if server admin
can_delete = yield self.can_modify_alias(
room_alias,
user_id=user_id
)
if not can_delete:
raise SynapseError(
400, "This alias is reserved by an application service.",
errcode=Codes.EXCLUSIVE
)
yield self._delete_association(room_alias)
@defer.inlineCallbacks
def delete_appservice_association(self, service, room_alias):
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400,
"This application service has not reserved this kind of alias",
errcode=Codes.EXCLUSIVE
)
yield self._delete_association(room_alias)
@defer.inlineCallbacks
def _delete_association(self, room_alias):
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
room_id = yield self.store.delete_room_alias(room_alias) yield self.store.delete_room_alias(room_alias)
if room_id: # TODO - Looks like _update_room_alias_event has never been implemented
yield self._update_room_alias_events(user_id, room_id) # if room_id:
# yield self._update_room_alias_events(user_id, room_id)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_association(self, room_alias): def get_association(self, room_alias):
room_id = None room_id = None
if self.hs.is_mine(room_alias): if self.hs.is_mine(room_alias):
result = yield self.store.get_association_from_room_alias( result = yield self.get_association_from_room_alias(
room_alias room_alias
) )
@ -107,12 +160,21 @@ class DirectoryHandler(BaseHandler):
if not room_id: if not room_id:
raise SynapseError( raise SynapseError(
404, 404,
"Room alias %r not found" % (room_alias.to_string(),), "Room alias %s not found" % (room_alias.to_string(),),
Codes.NOT_FOUND Codes.NOT_FOUND
) )
extra_servers = yield self.store.get_joined_hosts_for_room(room_id) extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
servers = list(set(extra_servers) | set(servers)) servers = set(extra_servers) | set(servers)
# If this server is in the list of servers, return it first.
if self.server_name in servers:
servers = (
[self.server_name]
+ [s for s in servers if s != self.server_name]
)
else:
servers = list(servers)
defer.returnValue({ defer.returnValue({
"room_id": room_id, "room_id": room_id,
@ -122,13 +184,13 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_directory_query(self, args): def on_directory_query(self, args):
room_alias = self.hs.parse_roomalias(args["room_alias"]) room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError( raise SynapseError(
400, "Room Alias is not hosted on this Home Server" 400, "Room Alias is not hosted on this Home Server"
) )
result = yield self.store.get_association_from_room_alias( result = yield self.get_association_from_room_alias(
room_alias room_alias
) )
@ -156,3 +218,37 @@ class DirectoryHandler(BaseHandler):
"sender": user_id, "sender": user_id,
"content": {"aliases": aliases}, "content": {"aliases": aliases},
}, ratelimit=False) }, ratelimit=False)
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
result = yield self.store.get_association_from_room_alias(
room_alias
)
if not result:
# Query AS to see if it exists
as_handler = self.hs.get_handlers().appservice_handler
result = yield as_handler.query_room_alias_exists(room_alias)
defer.returnValue(result)
@defer.inlineCallbacks
def can_modify_alias(self, alias, user_id=None):
# Any application service "interested" in an alias they are regexing on
# can modify the alias.
# Users can only modify the alias if ALL the interested services have
# non-exclusive locks on the alias (or there are no interested services)
services = yield self.store.get_app_services()
interested_services = [
s for s in services if s.is_interested_in_alias(alias.to_string())
]
for service in interested_services:
if user_id == service.sender:
# this user IS the app service so they can do whatever they like
defer.returnValue(True)
return
elif service.is_exclusive_alias(alias.to_string()):
# another service has an exclusive lock on this alias.
defer.returnValue(False)
return
# either no interested services, or no service with an exclusive lock
defer.returnValue(True)

View File

@ -17,10 +17,13 @@ from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.types import UserID
from synapse.events.utils import serialize_event
from ._base import BaseHandler from ._base import BaseHandler
import logging import logging
import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -47,38 +50,46 @@ class EventStreamHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True): as_client_event=True, affect_presence=True):
auth_user = self.hs.parse_userid(auth_user_id) auth_user = UserID.from_string(auth_user_id)
try: try:
if auth_user not in self._streams_per_user: if affect_presence:
self._streams_per_user[auth_user] = 0 if auth_user not in self._streams_per_user:
if auth_user in self._stop_timer_per_user: self._streams_per_user[auth_user] = 0
try: if auth_user in self._stop_timer_per_user:
self.clock.cancel_call_later( try:
self._stop_timer_per_user.pop(auth_user) self.clock.cancel_call_later(
self._stop_timer_per_user.pop(auth_user)
)
except:
logger.exception("Failed to cancel event timer")
else:
yield self.distributor.fire(
"started_user_eventstream", auth_user
) )
except: self._streams_per_user[auth_user] += 1
logger.exception("Failed to cancel event timer")
else:
yield self.distributor.fire(
"started_user_eventstream", auth_user
)
self._streams_per_user[auth_user] += 1
if pagin_config.from_token is None:
pagin_config.from_token = None
rm_handler = self.hs.get_handlers().room_member_handler rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(auth_user) room_ids = yield rm_handler.get_rooms_for_user(auth_user)
if timeout:
# If they've set a timeout set a minimum limit.
timeout = max(timeout, 500)
# Add some randomness to this value to try and mitigate against
# thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
with PreserveLoggingContext(): with PreserveLoggingContext():
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout auth_user, room_ids, pagin_config, timeout
) )
time_now = self.clock.time_msec()
chunks = [ chunks = [
self.hs.serialize_event(e, as_client_event) for e in events serialize_event(e, time_now, as_client_event) for e in events
] ]
chunk = { chunk = {
@ -90,28 +101,29 @@ class EventStreamHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
finally: finally:
self._streams_per_user[auth_user] -= 1 if affect_presence:
if not self._streams_per_user[auth_user]: self._streams_per_user[auth_user] -= 1
del self._streams_per_user[auth_user] if not self._streams_per_user[auth_user]:
del self._streams_per_user[auth_user]
# 10 seconds of grace to allow the client to reconnect again # 10 seconds of grace to allow the client to reconnect again
# before we think they're gone # before we think they're gone
def _later(): def _later():
logger.debug( logger.debug(
"_later stopped_user_eventstream %s", auth_user "_later stopped_user_eventstream %s", auth_user
)
self._stop_timer_per_user.pop(auth_user, None)
return self.distributor.fire(
"stopped_user_eventstream", auth_user
)
logger.debug("Scheduling _later: for %s", auth_user)
self._stop_timer_per_user[auth_user] = (
self.clock.call_later(30, _later)
) )
self._stop_timer_per_user.pop(auth_user, None)
yield self.distributor.fire(
"stopped_user_eventstream", auth_user
)
logger.debug("Scheduling _later: for %s", auth_user)
self._stop_timer_per_user[auth_user] = (
self.clock.call_later(30, _later)
)
class EventHandler(BaseHandler): class EventHandler(BaseHandler):

View File

@ -17,21 +17,21 @@
from ._base import BaseHandler from ._base import BaseHandler
from synapse.events.utils import prune_event
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, FederationError, SynapseError, StoreError, AuthError, FederationError, StoreError,
) )
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import ( from synapse.crypto.event_signing import (
compute_event_signature, check_event_content_hash, compute_event_signature, add_hashes_and_signatures,
add_hashes_and_signatures,
) )
from syutil.jsonutil import encode_canonical_json from synapse.types import UserID
from twisted.internet import defer from twisted.internet import defer
import itertools
import logging import logging
@ -112,33 +112,6 @@ class FederationHandler(BaseHandler):
logger.debug("Processing event: %s", event.event_id) logger.debug("Processing event: %s", event.event_id)
redacted_event = prune_event(event)
redacted_pdu_json = redacted_event.get_pdu_json()
try:
yield self.keyring.verify_json_for_server(
event.origin, redacted_pdu_json
)
except SynapseError as e:
logger.warn(
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
raise FederationError(
"ERROR",
e.code,
e.msg,
affected=event.event_id,
)
if not check_event_content_hash(event):
logger.warn(
"Event content has been tampered, redacting %s, %s",
event.event_id, encode_canonical_json(event.get_dict())
)
event = redacted_event
logger.debug("Event: %s", event) logger.debug("Event: %s", event)
# FIXME (erikj): Awful hack to make the case where we are not currently # FIXME (erikj): Awful hack to make the case where we are not currently
@ -148,41 +121,38 @@ class FederationHandler(BaseHandler):
event.room_id, event.room_id,
self.server_name self.server_name
) )
if not is_in_room and not event.internal_metadata.outlier: if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.") logger.debug("Got event for room we're not in.")
replication = self.replication_layer
if not state:
state, auth_chain = yield replication.get_state_for_room(
origin, context=event.room_id, event_id=event.event_id,
)
if not auth_chain:
auth_chain = yield replication.get_event_auth(
origin,
context=event.room_id,
event_id=event.event_id,
)
for e in auth_chain:
e.internal_metadata.outlier = True
try:
yield self._handle_new_event(e, fetch_auth_from=origin)
except:
logger.exception(
"Failed to handle auth event %s",
e.event_id,
)
current_state = state current_state = state
event_ids = set()
if state: if state:
for e in state: event_ids |= {e.event_id for e in state}
logging.info("A :) %r", e) if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
seen_ids = set(
(yield self.store.have_events(event_ids)).keys()
)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
yield self._handle_new_event(e) auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, e, auth_events=auth
)
seen_ids.add(e.event_id)
except: except:
logger.exception( logger.exception(
"Failed to handle state event %s", "Failed to handle state event %s",
@ -191,6 +161,7 @@ class FederationHandler(BaseHandler):
try: try:
yield self._handle_new_event( yield self._handle_new_event(
origin,
event, event,
state=state, state=state,
backfilled=backfilled, backfilled=backfilled,
@ -227,7 +198,7 @@ class FederationHandler(BaseHandler):
extra_users = [] extra_users = []
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
target_user_id = event.state_key target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id) target_user = UserID.from_string(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
yield self.notifier.on_new_room_event( yield self.notifier.on_new_room_event(
@ -236,7 +207,7 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
user = self.hs.parse_userid(event.state_key) user = UserID.from_string(event.state_key)
yield self.distributor.fire( yield self.distributor.fire(
"user_joined_room", user=user, room_id=event.room_id "user_joined_room", user=user, room_id=event.room_id
) )
@ -305,7 +276,7 @@ class FederationHandler(BaseHandler):
@log_function @log_function
@defer.inlineCallbacks @defer.inlineCallbacks
def do_invite_join(self, target_host, room_id, joinee, content, snapshot): def do_invite_join(self, target_hosts, room_id, joinee, content, snapshot):
""" Attempts to join the `joinee` to the room `room_id` via the """ Attempts to join the `joinee` to the room `room_id` via the
server `target_host`. server `target_host`.
@ -319,8 +290,8 @@ class FederationHandler(BaseHandler):
""" """
logger.debug("Joining %s to %s", joinee, room_id) logger.debug("Joining %s to %s", joinee, room_id)
pdu = yield self.replication_layer.make_join( origin, pdu = yield self.replication_layer.make_join(
target_host, target_hosts,
room_id, room_id,
joinee joinee
) )
@ -341,7 +312,7 @@ class FederationHandler(BaseHandler):
self.room_queues[room_id] = [] self.room_queues[room_id] = []
builder = self.event_builder_factory.new( builder = self.event_builder_factory.new(
event.get_pdu_json() unfreeze(event.get_pdu_json())
) )
handled_events = set() handled_events = set()
@ -362,11 +333,20 @@ class FederationHandler(BaseHandler):
new_event = builder.build() new_event = builder.build()
# Try the host we successfully got a response to /make_join/
# request first.
try:
target_hosts.remove(origin)
target_hosts.insert(0, origin)
except ValueError:
pass
ret = yield self.replication_layer.send_join( ret = yield self.replication_layer.send_join(
target_host, target_hosts,
new_event new_event
) )
origin = ret["origin"]
state = ret["state"] state = ret["state"]
auth_chain = ret["auth_chain"] auth_chain = ret["auth_chain"]
auth_chain.sort(key=lambda e: e.depth) auth_chain.sort(key=lambda e: e.depth)
@ -392,8 +372,19 @@ class FederationHandler(BaseHandler):
for e in auth_chain: for e in auth_chain:
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
if e.event_id == event.event_id:
continue
try: try:
yield self._handle_new_event(e) auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, e, auth_events=auth
)
except: except:
logger.exception( logger.exception(
"Failed to handle auth event %s", "Failed to handle auth event %s",
@ -401,11 +392,18 @@ class FederationHandler(BaseHandler):
) )
for e in state: for e in state:
# FIXME: Auth these. if e.event_id == event.event_id:
continue
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
try: try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event( yield self._handle_new_event(
e, fetch_auth_from=target_host origin, e, auth_events=auth
) )
except: except:
logger.exception( logger.exception(
@ -413,10 +411,18 @@ class FederationHandler(BaseHandler):
e.event_id, e.event_id,
) )
auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event( yield self._handle_new_event(
origin,
new_event, new_event,
state=state, state=state,
current_state=state, current_state=state,
auth_events=auth_events,
) )
yield self.notifier.on_new_room_event( yield self.notifier.on_new_room_event(
@ -480,7 +486,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False event.internal_metadata.outlier = False
context = yield self._handle_new_event(event) context = yield self._handle_new_event(origin, event)
logger.debug( logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s", "on_send_join_request: After _handle_new_event: %s, sigs: %s",
@ -491,7 +497,7 @@ class FederationHandler(BaseHandler):
extra_users = [] extra_users = []
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
target_user_id = event.state_key target_user_id = event.state_key
target_user = self.hs.parse_userid(target_user_id) target_user = UserID.from_string(target_user_id)
extra_users.append(target_user) extra_users.append(target_user)
yield self.notifier.on_new_room_event( yield self.notifier.on_new_room_event(
@ -500,7 +506,7 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = self.hs.parse_userid(event.state_key) user = UserID.from_string(event.state_key)
yield self.distributor.fire( yield self.distributor.fire(
"user_joined_room", user=user, room_id=event.room_id "user_joined_room", user=user, room_id=event.room_id
) )
@ -514,13 +520,15 @@ class FederationHandler(BaseHandler):
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN: if s.content["membership"] == Membership.JOIN:
destinations.add( destinations.add(
self.hs.parse_userid(s.state_key).domain UserID.from_string(s.state_key).domain
) )
except: except:
logger.warn( logger.warn(
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
) )
destinations.discard(origin)
logger.debug( logger.debug(
"on_send_join_request: Sending event: %s, signatures: %s", "on_send_join_request: Sending event: %s, signatures: %s",
event.event_id, event.event_id,
@ -565,7 +573,7 @@ class FederationHandler(BaseHandler):
backfilled=False, backfilled=False,
) )
target_user = self.hs.parse_userid(event.state_key) target_user = UserID.from_string(event.state_key)
yield self.notifier.on_new_room_event( yield self.notifier.on_new_room_event(
event, extra_users=[target_user], event, extra_users=[target_user],
) )
@ -573,12 +581,13 @@ class FederationHandler(BaseHandler):
defer.returnValue(event) defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_for_pdu(self, origin, room_id, event_id): def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
yield run_on_reactor() yield run_on_reactor()
in_room = yield self.auth.check_host_in_room(room_id, origin) if do_auth:
if not in_room: in_room = yield self.auth.check_host_in_room(room_id, origin)
raise AuthError(403, "Host not in room.") if not in_room:
raise AuthError(403, "Host not in room.")
state_groups = yield self.store.get_state_groups( state_groups = yield self.store.get_state_groups(
[event_id] [event_id]
@ -641,6 +650,7 @@ class FederationHandler(BaseHandler):
event = yield self.store.get_event( event = yield self.store.get_event(
event_id, event_id,
allow_none=True, allow_none=True,
allow_rejected=True,
) )
if event: if event:
@ -681,11 +691,12 @@ class FederationHandler(BaseHandler):
waiters.pop().callback(None) waiters.pop().callback(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def _handle_new_event(self, event, state=None, backfilled=False, @log_function
current_state=None, fetch_auth_from=None): def _handle_new_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
logger.debug( logger.debug(
"_handle_new_event: Before annotate: %s, sigs: %s", "_handle_new_event: %s, sigs: %s",
event.event_id, event.signatures, event.event_id, event.signatures,
) )
@ -693,65 +704,46 @@ class FederationHandler(BaseHandler):
event, old_state=state event, old_state=state
) )
if not auth_events:
auth_events = context.auth_events
logger.debug( logger.debug(
"_handle_new_event: Before auth fetch: %s, sigs: %s", "_handle_new_event: %s, auth_events: %s",
event.event_id, event.signatures, event.event_id, auth_events,
) )
is_new_state = not event.internal_metadata.is_outlier() is_new_state = not event.internal_metadata.is_outlier()
known_ids = set( # This is a hack to fix some old rooms where the initial join event
[s.event_id for s in context.auth_events.values()] # didn't reference the create event in its auth events.
)
for e_id, _ in event.auth_events:
if e_id not in known_ids:
e = yield self.store.get_event(e_id, allow_none=True)
if not e and fetch_auth_from is not None:
# Grab the auth_chain over federation if we are missing
# auth events.
auth_chain = yield self.replication_layer.get_event_auth(
fetch_auth_from, event.event_id, event.room_id
)
for auth_event in auth_chain:
yield self._handle_new_event(auth_event)
e = yield self.store.get_event(e_id, allow_none=True)
if not e:
# TODO: Do some conflict res to make sure that we're
# not the ones who are wrong.
logger.info(
"Rejecting %s as %s not in db or %s",
event.event_id, e_id, known_ids,
)
# FIXME: How does raising AuthError work with federation?
raise AuthError(403, "Cannot find auth event")
context.auth_events[(e.type, e.state_key)] = e
logger.debug(
"_handle_new_event: Before hack: %s, sigs: %s",
event.event_id, event.signatures,
)
if event.type == EventTypes.Member and not event.auth_events: if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1: if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0]) c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create: if c.type == EventTypes.Create:
context.auth_events[(c.type, c.state_key)] = c auth_events[(c.type, c.state_key)] = c
logger.debug( try:
"_handle_new_event: Before auth check: %s, sigs: %s", yield self.do_auth(
event.event_id, event.signatures, origin, event, context, auth_events=auth_events
) )
except AuthError as e:
logger.warn(
"Rejecting %s because %s",
event.event_id, e.msg
)
self.auth.check(event, auth_events=context.auth_events) context.rejected = RejectedReason.AUTH_ERROR
logger.debug( # FIXME: Don't store as rejected with AUTH_ERROR if we haven't
"_handle_new_event: Before persist_event: %s, sigs: %s", # seen all the auth events.
event.event_id, event.signatures, yield self.store.persist_event(
) event,
context=context,
backfilled=backfilled,
is_new_state=False,
current_state=current_state,
)
raise
yield self.store.persist_event( yield self.store.persist_event(
event, event,
@ -761,9 +753,388 @@ class FederationHandler(BaseHandler):
current_state=current_state, current_state=current_state,
) )
logger.debug( defer.returnValue(context)
"_handle_new_event: After persist_event: %s, sigs: %s",
event.event_id, event.signatures, @defer.inlineCallbacks
def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
missing):
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
for e in remote_auth_chain:
try:
yield self._handle_new_event(origin, e)
except AuthError:
pass
# Now get the current auth_chain for the event.
local_auth_chain = yield self.store.get_auth_chain([event_id])
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
ret = yield self.construct_auth_difference(
local_auth_chain, remote_auth_chain
) )
defer.returnValue(context) for event in ret["auth_chain"]:
event.signatures.update(
compute_event_signature(
event,
self.hs.hostname,
self.hs.config.signing_key[0]
)
)
logger.debug("on_query_auth returning: %s", ret)
defer.returnValue(ret)
@defer.inlineCallbacks
def on_get_missing_events(self, origin, room_id, earliest_events,
latest_events, limit, min_depth):
in_room = yield self.auth.check_host_in_room(
room_id,
origin
)
if not in_room:
raise AuthError(403, "Host not in room.")
limit = min(limit, 20)
min_depth = max(min_depth, 0)
missing_events = yield self.store.get_missing_events(
room_id=room_id,
earliest_events=earliest_events,
latest_events=latest_events,
limit=limit,
min_depth=min_depth,
)
defer.returnValue(missing_events)
@defer.inlineCallbacks
@log_function
def do_auth(self, origin, event, context, auth_events):
# Check if we have all the auth events.
have_events = yield self.store.have_events(
[e_id for e_id, _ in event.auth_events]
)
event_auth_events = set(e_id for e_id, _ in event.auth_events)
seen_events = set(have_events.keys())
missing_auth = event_auth_events - seen_events
if missing_auth:
logger.info("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them.
try:
remote_auth_chain = yield self.replication_layer.get_event_auth(
origin, event.room_id, event.event_id
)
seen_remotes = yield self.store.have_events(
[e.event_id for e in remote_auth_chain]
)
for e in remote_auth_chain:
if e.event_id in seen_remotes.keys():
continue
if e.event_id == event.event_id:
continue
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in remote_auth_chain
if e.event_id in auth_ids
}
e.internal_metadata.outlier = True
logger.debug(
"do_auth %s missing_auth: %s",
event.event_id, e.event_id
)
yield self._handle_new_event(
origin, e, auth_events=auth
)
if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e
except AuthError:
pass
have_events = yield self.store.have_events(
[e_id for e_id, _ in event.auth_events]
)
seen_events = set(have_events.keys())
except:
# FIXME:
logger.exception("Failed to get auth chain")
# FIXME: Assumes we have and stored all the state for all the
# prev_events
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
if different_auth and not event.internal_metadata.is_outlier():
# Do auth conflict res.
logger.info("Different auth: %s", different_auth)
different_events = yield defer.gatherResults(
[
self.store.get_event(
d,
allow_none=True,
allow_rejected=False,
)
for d in different_auth
if d in have_events and not have_events[d]
],
consumeErrors=True
)
if different_events:
local_view = dict(auth_events)
remote_view = dict(auth_events)
remote_view.update({
(d.type, d.state_key): d for d in different_events
})
new_state, prev_state = self.state_handler.resolve_events(
[local_view.values(), remote_view.values()],
event
)
auth_events.update(new_state)
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
context.current_state.update(auth_events)
context.state_group = None
if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth)
# Only do auth resolution if we have something new to say.
# We can't rove an auth failure.
do_resolution = False
provable = [
RejectedReason.NOT_ANCESTOR, RejectedReason.NOT_ANCESTOR,
]
for e_id in different_auth:
if e_id in have_events:
if have_events[e_id] in provable:
do_resolution = True
break
if do_resolution:
# 1. Get what we think is the auth chain.
auth_ids = self.auth.compute_auth_events(
event, context.current_state
)
local_auth_chain = yield self.store.get_auth_chain(auth_ids)
try:
# 2. Get remote difference.
result = yield self.replication_layer.query_auth(
origin,
event.room_id,
event.event_id,
local_auth_chain,
)
seen_remotes = yield self.store.have_events(
[e.event_id for e in result["auth_chain"]]
)
# 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]:
if ev.event_id in seen_remotes.keys():
continue
if ev.event_id == event.event_id:
continue
try:
auth_ids = [e_id for e_id, _ in ev.auth_events]
auth = {
(e.type, e.state_key): e
for e in result["auth_chain"]
if e.event_id in auth_ids
}
ev.internal_metadata.outlier = True
logger.debug(
"do_auth %s different_auth: %s",
event.event_id, e.event_id
)
yield self._handle_new_event(
origin, ev, auth_events=auth
)
if ev.event_id in event_auth_events:
auth_events[(ev.type, ev.state_key)] = ev
except AuthError:
pass
except:
# FIXME:
logger.exception("Failed to query auth chain")
# 4. Look at rejects and their proofs.
# TODO.
context.current_state.update(auth_events)
context.state_group = None
try:
self.auth.check(event, auth_events=auth_events)
except AuthError:
raise
@defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth):
""" Given a local and remote auth chain, find the differences. This
assumes that we have already processed all events in remote_auth
Params:
local_auth (list)
remote_auth (list)
Returns:
dict
"""
logger.debug("construct_auth_difference Start!")
# TODO: Make sure we are OK with local_auth or remote_auth having more
# auth events in them than strictly necessary.
def sort_fun(ev):
return ev.depth, ev.event_id
logger.debug("construct_auth_difference after sort_fun!")
# We find the differences by starting at the "bottom" of each list
# and iterating up on both lists. The lists are ordered by depth and
# then event_id, we iterate up both lists until we find the event ids
# don't match. Then we look at depth/event_id to see which side is
# missing that event, and iterate only up that list. Repeat.
remote_list = list(remote_auth)
remote_list.sort(key=sort_fun)
local_list = list(local_auth)
local_list.sort(key=sort_fun)
local_iter = iter(local_list)
remote_iter = iter(remote_list)
logger.debug("construct_auth_difference before get_next!")
def get_next(it, opt=None):
try:
return it.next()
except:
return opt
current_local = get_next(local_iter)
current_remote = get_next(remote_iter)
logger.debug("construct_auth_difference before while")
missing_remotes = []
missing_locals = []
while current_local or current_remote:
if current_remote is None:
missing_locals.append(current_local)
current_local = get_next(local_iter)
continue
if current_local is None:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
if current_local.event_id == current_remote.event_id:
current_local = get_next(local_iter)
current_remote = get_next(remote_iter)
continue
if current_local.depth < current_remote.depth:
missing_locals.append(current_local)
current_local = get_next(local_iter)
continue
if current_local.depth > current_remote.depth:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
# They have the same depth, so we fall back to the event_id order
if current_local.event_id < current_remote.event_id:
missing_locals.append(current_local)
current_local = get_next(local_iter)
if current_local.event_id > current_remote.event_id:
missing_remotes.append(current_remote)
current_remote = get_next(remote_iter)
continue
logger.debug("construct_auth_difference after while")
# missing locals should be sent to the server
# We should find why we are missing remotes, as they will have been
# rejected.
# Remove events from missing_remotes if they are referencing a missing
# remote. We only care about the "root" rejected ones.
missing_remote_ids = [e.event_id for e in missing_remotes]
base_remote_rejected = list(missing_remotes)
for e in missing_remotes:
for e_id, _ in e.auth_events:
if e_id in missing_remote_ids:
try:
base_remote_rejected.remove(e)
except ValueError:
pass
reason_map = {}
for e in base_remote_rejected:
reason = yield self.store.get_rejection_reason(e.event_id)
if reason is None:
# TODO: e is not in the current state, so we should
# construct some proof of that.
continue
reason_map[e.event_id] = reason
if reason == RejectedReason.AUTH_ERROR:
pass
elif reason == RejectedReason.REPLACED:
# TODO: Get proof
pass
elif reason == RejectedReason.NOT_ANCESTOR:
# TODO: Get proof.
pass
logger.debug("construct_auth_difference returning")
defer.returnValue({
"auth_chain": local_auth,
"rejects": {
e.event_id: {
"reason": reason_map[e.event_id],
"proof": None,
}
for e in base_remote_rejected
},
"missing": [e.event_id for e in missing_locals],
})

View File

@ -16,12 +16,13 @@
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes from synapse.api.errors import LoginError, Codes, CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.util.emailutils import EmailException from synapse.util.emailutils import EmailException
import synapse.util.emailutils as emailutils import synapse.util.emailutils as emailutils
import bcrypt import bcrypt
import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -96,16 +97,20 @@ class LoginHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def _query_email(self, email): def _query_email(self, email):
httpCli = SimpleHttpClient(self.hs) http_client = SimpleHttpClient(self.hs)
data = yield httpCli.get_json( try:
# TODO FIXME This should be configurable. data = yield http_client.get_json(
# XXX: ID servers need to use HTTPS # TODO FIXME This should be configurable.
"http://%s%s" % ( # XXX: ID servers need to use HTTPS
"matrix.org:8090", "/_matrix/identity/api/v1/lookup" "http://%s%s" % (
), "matrix.org:8090", "/_matrix/identity/api/v1/lookup"
{ ),
'medium': 'email', {
'address': email 'medium': 'email',
} 'address': email
) }
defer.returnValue(data) )
defer.returnValue(data)
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)

View File

@ -16,10 +16,12 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import RoomError from synapse.api.errors import RoomError, SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
from ._base import BaseHandler from ._base import BaseHandler
@ -33,6 +35,7 @@ class MessageHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(MessageHandler, self).__init__(hs) super(MessageHandler, self).__init__(hs)
self.hs = hs self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.validator = EventValidator() self.validator = EventValidator()
@ -89,7 +92,7 @@ class MessageHandler(BaseHandler):
yield self.hs.get_event_sources().get_current_token() yield self.hs.get_event_sources().get_current_token()
) )
user = self.hs.parse_userid(user_id) user = UserID.from_string(user_id)
events, next_key = yield data_source.get_pagination_rows( events, next_key = yield data_source.get_pagination_rows(
user, pagin_config.get_source_config("room"), room_id user, pagin_config.get_source_config("room"), room_id
@ -99,9 +102,11 @@ class MessageHandler(BaseHandler):
"room_key", next_key "room_key", next_key
) )
time_now = self.clock.time_msec()
chunk = { chunk = {
"chunk": [ "chunk": [
self.hs.serialize_event(e, as_client_event) for e in events serialize_event(e, time_now, as_client_event) for e in events
], ],
"start": pagin_config.from_token.to_string(), "start": pagin_config.from_token.to_string(),
"end": next_token.to_string(), "end": next_token.to_string(),
@ -110,7 +115,8 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks @defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True): def create_and_send_event(self, event_dict, ratelimit=True,
client=None, txn_id=None):
""" Given a dict from a client, create and handle a new event. """ Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events, Creates an FrozenEvent object, filling out auth_events, prev_events,
@ -130,13 +136,13 @@ class MessageHandler(BaseHandler):
if ratelimit: if ratelimit:
self.ratelimit(builder.user_id) self.ratelimit(builder.user_id)
# TODO(paul): Why does 'event' not have a 'user' object? # TODO(paul): Why does 'event' not have a 'user' object?
user = self.hs.parse_userid(builder.user_id) user = UserID.from_string(builder.user_id)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,) assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if builder.type == EventTypes.Member: if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None) membership = builder.content.get("membership", None)
if membership == Membership.JOIN: if membership == Membership.JOIN:
joinee = self.hs.parse_userid(builder.state_key) joinee = UserID.from_string(builder.state_key)
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
yield self.distributor.fire( yield self.distributor.fire(
"collect_presencelike_data", "collect_presencelike_data",
@ -144,6 +150,15 @@ class MessageHandler(BaseHandler):
builder.content builder.content
) )
if client is not None:
if client.token_id is not None:
builder.internal_metadata.token_id = client.token_id
if client.device_id is not None:
builder.internal_metadata.device_id = client.device_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self._create_new_client_event( event, context = yield self._create_new_client_event(
builder=builder, builder=builder,
) )
@ -210,7 +225,10 @@ class MessageHandler(BaseHandler):
# TODO: This is duplicating logic from snapshot_all_rooms # TODO: This is duplicating logic from snapshot_all_rooms
current_state = yield self.state_handler.get_current_state(room_id) current_state = yield self.state_handler.get_current_state(room_id)
defer.returnValue([self.hs.serialize_event(c) for c in current_state]) now = self.clock.time_msec()
defer.returnValue(
[serialize_event(c, now) for c in current_state.values()]
)
@defer.inlineCallbacks @defer.inlineCallbacks
def snapshot_all_rooms(self, user_id=None, pagin_config=None, def snapshot_all_rooms(self, user_id=None, pagin_config=None,
@ -237,7 +255,7 @@ class MessageHandler(BaseHandler):
membership_list=[Membership.INVITE, Membership.JOIN] membership_list=[Membership.INVITE, Membership.JOIN]
) )
user = self.hs.parse_userid(user_id) user = UserID.from_string(user_id)
rooms_ret = [] rooms_ret = []
@ -282,10 +300,11 @@ class MessageHandler(BaseHandler):
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1]) end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
d["messages"] = { d["messages"] = {
"chunk": [ "chunk": [
self.hs.serialize_event(m, as_client_event) serialize_event(m, time_now, as_client_event)
for m in messages for m in messages
], ],
"start": start_token.to_string(), "start": start_token.to_string(),
@ -296,7 +315,8 @@ class MessageHandler(BaseHandler):
event.room_id event.room_id
) )
d["state"] = [ d["state"] = [
self.hs.serialize_event(c) for c in current_state serialize_event(c, time_now, as_client_event)
for c in current_state.values()
] ]
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
@ -312,20 +332,27 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None, def room_initial_sync(self, user_id, room_id, pagin_config=None,
feedback=False): feedback=False):
yield self.auth.check_joined_room(room_id, user_id) current_state = yield self.state.get_current_state(
room_id=room_id,
)
yield self.auth.check_joined_room(
room_id, user_id,
current_state=current_state
)
# TODO(paul): I wish I was called with user objects not user_id # TODO(paul): I wish I was called with user objects not user_id
# strings... # strings...
auth_user = self.hs.parse_userid(user_id) auth_user = UserID.from_string(user_id)
# TODO: These concurrently # TODO: These concurrently
state_tuples = yield self.state_handler.get_current_state(room_id) time_now = self.clock.time_msec()
state = [self.hs.serialize_event(x) for x in state_tuples] state = [
serialize_event(x, time_now)
for x in current_state.values()
]
member_event = (yield self.store.get_room_member( member_event = current_state.get((EventTypes.Member, user_id,))
user_id=user_id,
room_id=room_id
))
now_token = yield self.hs.get_event_sources().get_current_token() now_token = yield self.hs.get_event_sources().get_current_token()
@ -342,28 +369,34 @@ class MessageHandler(BaseHandler):
start_token = now_token.copy_and_replace("room_key", token[0]) start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1]) end_token = now_token.copy_and_replace("room_key", token[1])
room_members = yield self.store.get_room_members(room_id) room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
presence_handler = self.hs.get_handlers().presence_handler presence_handler = self.hs.get_handlers().presence_handler
presence = [] presence = []
for m in room_members: for m in room_members:
try: try:
member_presence = yield presence_handler.get_state( member_presence = yield presence_handler.get_state(
target_user=self.hs.parse_userid(m.user_id), target_user=UserID.from_string(m.user_id),
auth_user=auth_user, auth_user=auth_user,
as_event=True, as_event=True,
) )
presence.append(member_presence) presence.append(member_presence)
except Exception: except SynapseError:
logger.exception( logger.exception(
"Failed to get member presence of %r", m.user_id "Failed to get member presence of %r", m.user_id
) )
time_now = self.clock.time_msec()
defer.returnValue({ defer.returnValue({
"membership": member_event.membership, "membership": member_event.membership,
"room_id": room_id, "room_id": room_id,
"messages": { "messages": {
"chunk": [self.hs.serialize_event(m) for m in messages], "chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(), "start": start_token.to_string(),
"end": end_token.to_string(), "end": end_token.to_string(),
}, },

View File

@ -20,6 +20,7 @@ from synapse.api.constants import PresenceState
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
from ._base import BaseHandler from ._base import BaseHandler
@ -86,6 +87,10 @@ class PresenceHandler(BaseHandler):
"changed_presencelike_data", self.changed_presencelike_data "changed_presencelike_data", self.changed_presencelike_data
) )
# outbound signal from the presence module to advertise when a user's
# presence has changed
distributor.declare("user_presence_changed")
self.distributor = distributor self.distributor = distributor
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
@ -96,22 +101,22 @@ class PresenceHandler(BaseHandler):
self.federation.register_edu_handler( self.federation.register_edu_handler(
"m.presence_invite", "m.presence_invite",
lambda origin, content: self.invite_presence( lambda origin, content: self.invite_presence(
observed_user=hs.parse_userid(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=hs.parse_userid(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
self.federation.register_edu_handler( self.federation.register_edu_handler(
"m.presence_accept", "m.presence_accept",
lambda origin, content: self.accept_presence( lambda origin, content: self.accept_presence(
observed_user=hs.parse_userid(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=hs.parse_userid(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
self.federation.register_edu_handler( self.federation.register_edu_handler(
"m.presence_deny", "m.presence_deny",
lambda origin, content: self.deny_presence( lambda origin, content: self.deny_presence(
observed_user=hs.parse_userid(content["observed_user"]), observed_user=UserID.from_string(content["observed_user"]),
observer_user=hs.parse_userid(content["observer_user"]), observer_user=UserID.from_string(content["observer_user"]),
) )
) )
@ -418,7 +423,7 @@ class PresenceHandler(BaseHandler):
) )
for p in presence: for p in presence:
observed_user = self.hs.parse_userid(p.pop("observed_user_id")) observed_user = UserID.from_string(p.pop("observed_user_id"))
p["observed_user"] = observed_user p["observed_user"] = observed_user
p.update(self._get_or_offline_usercache(observed_user).get_state()) p.update(self._get_or_offline_usercache(observed_user).get_state())
if "last_active" in p: if "last_active" in p:
@ -441,7 +446,7 @@ class PresenceHandler(BaseHandler):
user.localpart, accepted=True user.localpart, accepted=True
) )
target_users = set([ target_users = set([
self.hs.parse_userid(x["observed_user_id"]) for x in presence UserID.from_string(x["observed_user_id"]) for x in presence
]) ])
# Also include people in all my rooms # Also include people in all my rooms
@ -452,9 +457,9 @@ class PresenceHandler(BaseHandler):
if state is None: if state is None:
state = yield self.store.get_presence_state(user.localpart) state = yield self.store.get_presence_state(user.localpart)
else: else:
# statuscache = self._get_or_make_usercache(user) # statuscache = self._get_or_make_usercache(user)
# self._user_cachemap_latest_serial += 1 # self._user_cachemap_latest_serial += 1
# statuscache.update(state, self._user_cachemap_latest_serial) # statuscache.update(state, self._user_cachemap_latest_serial)
pass pass
yield self.push_update_to_local_and_remote( yield self.push_update_to_local_and_remote(
@ -487,7 +492,7 @@ class PresenceHandler(BaseHandler):
user, domain, remoteusers user, domain, remoteusers
)) ))
yield defer.DeferredList(deferreds) yield defer.DeferredList(deferreds, consumeErrors=True)
def _start_polling_local(self, user, target_user): def _start_polling_local(self, user, target_user):
target_localpart = target_user.localpart target_localpart = target_user.localpart
@ -543,7 +548,7 @@ class PresenceHandler(BaseHandler):
self._stop_polling_remote(user, domain, remoteusers) self._stop_polling_remote(user, domain, remoteusers)
) )
return defer.DeferredList(deferreds) return defer.DeferredList(deferreds, consumeErrors=True)
def _stop_polling_local(self, user, target_user): def _stop_polling_local(self, user, target_user):
for localpart in self._local_pushmap.keys(): for localpart in self._local_pushmap.keys():
@ -603,6 +608,7 @@ class PresenceHandler(BaseHandler):
room_ids=room_ids, room_ids=room_ids,
statuscache=statuscache, statuscache=statuscache,
) )
yield self.distributor.fire("user_presence_changed", user, statuscache)
@defer.inlineCallbacks @defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None): def _push_presence_remote(self, user, destination, state=None):
@ -646,13 +652,15 @@ class PresenceHandler(BaseHandler):
deferreds = [] deferreds = []
for push in content.get("push", []): for push in content.get("push", []):
user = self.hs.parse_userid(push["user_id"]) user = UserID.from_string(push["user_id"])
logger.debug("Incoming presence update from %s", user) logger.debug("Incoming presence update from %s", user)
observers = set(self._remote_recvmap.get(user, set())) observers = set(self._remote_recvmap.get(user, set()))
if observers: if observers:
logger.debug(" | %d interested local observers %r", len(observers), observers) logger.debug(
" | %d interested local observers %r", len(observers), observers
)
rm_handler = self.homeserver.get_handlers().room_member_handler rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(user) room_ids = yield rm_handler.get_rooms_for_user(user)
@ -694,14 +702,14 @@ class PresenceHandler(BaseHandler):
del self._user_cachemap[user] del self._user_cachemap[user]
for poll in content.get("poll", []): for poll in content.get("poll", []):
user = self.hs.parse_userid(poll) user = UserID.from_string(poll)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
continue continue
# TODO(paul) permissions checks # TODO(paul) permissions checks
if not user in self._remote_sendmap: if user not in self._remote_sendmap:
self._remote_sendmap[user] = set() self._remote_sendmap[user] = set()
self._remote_sendmap[user].add(origin) self._remote_sendmap[user].add(origin)
@ -709,7 +717,7 @@ class PresenceHandler(BaseHandler):
deferreds.append(self._push_presence_remote(user, origin)) deferreds.append(self._push_presence_remote(user, origin))
for unpoll in content.get("unpoll", []): for unpoll in content.get("unpoll", []):
user = self.hs.parse_userid(unpoll) user = UserID.from_string(unpoll)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
continue continue
@ -721,7 +729,7 @@ class PresenceHandler(BaseHandler):
del self._remote_sendmap[user] del self._remote_sendmap[user]
with PreserveLoggingContext(): with PreserveLoggingContext():
yield defer.DeferredList(deferreds) yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_update_to_local_and_remote(self, observed_user, statuscache, def push_update_to_local_and_remote(self, observed_user, statuscache,
@ -760,7 +768,7 @@ class PresenceHandler(BaseHandler):
) )
) )
yield defer.DeferredList(deferreds) yield defer.DeferredList(deferreds, consumeErrors=True)
defer.returnValue((localusers, remote_domains)) defer.returnValue((localusers, remote_domains))

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
from ._base import BaseHandler from ._base import BaseHandler
@ -169,7 +170,7 @@ class ProfileHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_profile_query(self, args): def on_profile_query(self, args):
user = self.hs.parse_userid(args["user_id"]) user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
@ -211,10 +212,16 @@ class ProfileHandler(BaseHandler):
) )
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_event({ try:
"type": EventTypes.Member, yield msg_handler.create_and_send_event({
"room_id": j.room_id, "type": EventTypes.Member,
"state_key": user.to_string(), "room_id": j.room_id,
"content": content, "state_key": user.to_string(),
"sender": user.to_string() "content": content,
}, ratelimit=False) "sender": user.to_string()
}, ratelimit=False)
except Exception as e:
logger.warn(
"Failed to update join event for room %s - %s",
j.room_id, str(e.message)
)

View File

@ -18,7 +18,8 @@ from twisted.internet import defer
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, RegistrationError, InvalidCaptchaError AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError,
CodeMessageException
) )
from ._base import BaseHandler from ._base import BaseHandler
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
@ -28,6 +29,7 @@ from synapse.http.client import CaptchaServerHttpClient
import base64 import base64
import bcrypt import bcrypt
import json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -64,6 +66,8 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id) token = self._generate_token(user_id)
yield self.store.register( yield self.store.register(
user_id=user_id, user_id=user_id,
@ -82,6 +86,7 @@ class RegistrationHandler(BaseHandler):
localpart = self._generate_user_id() localpart = self._generate_user_id()
user = UserID(localpart, self.hs.hostname) user = UserID(localpart, self.hs.hostname)
user_id = user.to_string() user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id) token = self._generate_token(user_id)
yield self.store.register( yield self.store.register(
@ -99,6 +104,47 @@ class RegistrationHandler(BaseHandler):
raise RegistrationError( raise RegistrationError(
500, "Cannot generate user ID.") 500, "Cannot generate user ID.")
# create a default avatar for the user
# XXX: ideally clients would explicitly specify one, but given they don't
# and we want consistent and pretty identicons for random users, we'll
# do it here.
try:
auth_user = UserID.from_string(user_id)
media_repository = self.hs.get_resource_for_media_repository()
identicon_resource = media_repository.getChildWithDefault("identicon", None)
upload_resource = media_repository.getChildWithDefault("upload", None)
identicon_bytes = identicon_resource.generate_identicon(user_id, 320, 320)
content_uri = yield upload_resource.create_content(
"image/png", None, identicon_bytes, len(identicon_bytes), auth_user
)
profile_handler = self.hs.get_handlers().profile_handler
profile_handler.set_avatar_url(
auth_user, auth_user, ("%s#auto" % (content_uri,))
)
except NotImplementedError:
pass # make tests pass without messing around creating default avatars
defer.returnValue((user_id, token))
@defer.inlineCallbacks
def appservice_register(self, user_localpart, as_token):
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = yield self.store.get_app_service_by_token(as_token)
if not service:
raise AuthError(403, "Invalid application service token.")
if not service.is_interested_in_user(user_id):
raise SynapseError(
400, "Invalid user localpart for this application service.",
errcode=Codes.EXCLUSIVE
)
token = self._generate_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
password_hash=""
)
self.distributor.fire("registered_user", user)
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -147,6 +193,21 @@ class RegistrationHandler(BaseHandler):
# XXX: This should be a deferred list, shouldn't it? # XXX: This should be a deferred list, shouldn't it?
yield self._bind_threepid(c, user_id) yield self._bind_threepid(c, user_id)
@defer.inlineCallbacks
def check_user_id_is_valid(self, user_id):
# valid user IDs must not clash with any user ID namespaces claimed by
# application services.
services = yield self.store.get_app_services()
interested_services = [
s for s in services if s.is_interested_in_user(user_id)
]
for service in interested_services:
if service.is_exclusive_user(user_id):
raise SynapseError(
400, "This user ID is reserved by an application service.",
errcode=Codes.EXCLUSIVE
)
def _generate_token(self, user_id): def _generate_token(self, user_id):
# urlsafe variant uses _ and - so use . as the separator and replace # urlsafe variant uses _ and - so use . as the separator and replace
# all =s with .s so http clients don't quote =s when it is used as # all =s with .s so http clients don't quote =s when it is used as
@ -161,21 +222,26 @@ class RegistrationHandler(BaseHandler):
def _threepid_from_creds(self, creds): def _threepid_from_creds(self, creds):
# TODO: get this from the homeserver rather than creating a new one for # TODO: get this from the homeserver rather than creating a new one for
# each request # each request
httpCli = SimpleHttpClient(self.hs) http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable! # XXX: make this configurable!
trustedIdServers = ['matrix.org:8090'] trustedIdServers = ['matrix.org:8090', 'matrix.org']
if not creds['idServer'] in trustedIdServers: if not creds['idServer'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' + logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['idServer']) 'credentials', creds['idServer'])
defer.returnValue(None) defer.returnValue(None)
data = yield httpCli.get_json(
# XXX: This should be HTTPS data = {}
"http://%s%s" % ( try:
creds['idServer'], data = yield http_client.get_json(
"/_matrix/identity/api/v1/3pid/getValidated3pid" # XXX: This should be HTTPS
), "http://%s%s" % (
{'sid': creds['sid'], 'clientSecret': creds['clientSecret']} creds['idServer'],
) "/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
)
except CodeMessageException as e:
data = json.loads(e.msg)
if 'medium' in data: if 'medium' in data:
defer.returnValue(data) defer.returnValue(data)
@ -185,19 +251,23 @@ class RegistrationHandler(BaseHandler):
def _bind_threepid(self, creds, mxid): def _bind_threepid(self, creds, mxid):
yield yield
logger.debug("binding threepid") logger.debug("binding threepid")
httpCli = SimpleHttpClient(self.hs) http_client = SimpleHttpClient(self.hs)
data = yield httpCli.post_urlencoded_get_json( data = None
# XXX: Change when ID servers are all HTTPS try:
"http://%s%s" % ( data = yield http_client.post_urlencoded_get_json(
creds['idServer'], "/_matrix/identity/api/v1/3pid/bind" # XXX: Change when ID servers are all HTTPS
), "http://%s%s" % (
{ creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
'sid': creds['sid'], ),
'clientSecret': creds['clientSecret'], {
'mxid': mxid, 'sid': creds['sid'],
} 'clientSecret': creds['clientSecret'],
) 'mxid': mxid,
logger.debug("bound threepid") }
)
logger.debug("bound threepid")
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -16,12 +16,14 @@
"""Contains functions for performing events on rooms.""" """Contains functions for performing events on rooms."""
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from ._base import BaseHandler from synapse.events.utils import serialize_event
import logging import logging
@ -64,7 +66,7 @@ class RoomCreationHandler(BaseHandler):
invite_list = config.get("invite", []) invite_list = config.get("invite", [])
for i in invite_list: for i in invite_list:
try: try:
self.hs.parse_userid(i) UserID.from_string(i)
except: except:
raise SynapseError(400, "Invalid user_id: %s" % (i,)) raise SynapseError(400, "Invalid user_id: %s" % (i,))
@ -114,7 +116,7 @@ class RoomCreationHandler(BaseHandler):
servers=[self.hs.hostname], servers=[self.hs.hostname],
) )
user = self.hs.parse_userid(user_id) user = UserID.from_string(user_id)
creation_events = self._create_events_for_new_room( creation_events = self._create_events_for_new_room(
user, room_id, is_public=is_public user, room_id, is_public=is_public
) )
@ -246,11 +248,9 @@ class RoomMemberHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_members(self, room_id): def get_room_members(self, room_id):
hs = self.hs
users = yield self.store.get_users_in_room(room_id) users = yield self.store.get_users_in_room(room_id)
defer.returnValue([hs.parse_userid(u) for u in users]) defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks @defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None, def fetch_room_distributions_into(self, room_id, localusers=None,
@ -295,8 +295,9 @@ class RoomMemberHandler(BaseHandler):
yield self.auth.check_joined_room(room_id, user_id) yield self.auth.check_joined_room(room_id, user_id)
member_list = yield self.store.get_room_members(room_id=room_id) member_list = yield self.store.get_room_members(room_id=room_id)
time_now = self.clock.time_msec()
event_list = [ event_list = [
self.hs.serialize_event(entry) serialize_event(entry, time_now)
for entry in member_list for entry in member_list
] ]
chunk_data = { chunk_data = {
@ -368,7 +369,7 @@ class RoomMemberHandler(BaseHandler):
) )
if prev_state and prev_state.membership == Membership.JOIN: if prev_state and prev_state.membership == Membership.JOIN:
user = self.hs.parse_userid(event.user_id) user = UserID.from_string(event.user_id)
self.distributor.fire( self.distributor.fire(
"user_left_room", user=user, room_id=event.room_id "user_left_room", user=user, room_id=event.room_id
) )
@ -388,8 +389,6 @@ class RoomMemberHandler(BaseHandler):
if not hosts: if not hosts:
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
host = hosts[0]
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
yield self.distributor.fire( yield self.distributor.fire(
"collect_presencelike_data", joinee, content "collect_presencelike_data", joinee, content
@ -406,13 +405,13 @@ class RoomMemberHandler(BaseHandler):
}) })
event, context = yield self._create_new_client_event(builder) event, context = yield self._create_new_client_event(builder)
yield self._do_join(event, context, room_host=host, do_auth=True) yield self._do_join(event, context, room_hosts=hosts, do_auth=True)
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_join(self, event, context, room_host=None, do_auth=True): def _do_join(self, event, context, room_hosts=None, do_auth=True):
joinee = self.hs.parse_userid(event.state_key) joinee = UserID.from_string(event.state_key)
# room_id = RoomID.from_string(event.room_id, self.hs) # room_id = RoomID.from_string(event.room_id, self.hs)
room_id = event.room_id room_id = event.room_id
@ -440,7 +439,7 @@ class RoomMemberHandler(BaseHandler):
if is_host_in_room: if is_host_in_room:
should_do_dance = False should_do_dance = False
elif room_host: # TODO: Shouldn't this be remote_room_host? elif room_hosts: # TODO: Shouldn't this be remote_room_host?
should_do_dance = True should_do_dance = True
else: else:
# TODO(markjh): get prev_state from snapshot # TODO(markjh): get prev_state from snapshot
@ -452,7 +451,7 @@ class RoomMemberHandler(BaseHandler):
inviter = UserID.from_string(prev_state.user_id) inviter = UserID.from_string(prev_state.user_id)
should_do_dance = not self.hs.is_mine(inviter) should_do_dance = not self.hs.is_mine(inviter)
room_host = inviter.domain room_hosts = [inviter.domain]
else: else:
# return the same error as join_room_alias does # return the same error as join_room_alias does
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
@ -460,10 +459,10 @@ class RoomMemberHandler(BaseHandler):
if should_do_dance: if should_do_dance:
handler = self.hs.get_handlers().federation_handler handler = self.hs.get_handlers().federation_handler
yield handler.do_invite_join( yield handler.do_invite_join(
room_host, room_hosts,
room_id, room_id,
event.user_id, event.user_id,
event.get_dict()["content"], # FIXME To get a non-frozen dict event.content, # FIXME To get a non-frozen dict
context context
) )
else: else:
@ -476,7 +475,7 @@ class RoomMemberHandler(BaseHandler):
do_auth=do_auth, do_auth=do_auth,
) )
user = self.hs.parse_userid(event.user_id) user = UserID.from_string(event.user_id)
yield self.distributor.fire( yield self.distributor.fire(
"user_joined_room", user=user, room_id=room_id "user_joined_room", user=user, room_id=room_id
) )
@ -511,9 +510,16 @@ class RoomMemberHandler(BaseHandler):
def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]): def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]):
"""Returns a list of roomids that the user has any of the given """Returns a list of roomids that the user has any of the given
membership states in.""" membership states in."""
rooms = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user.to_string(), membership_list=membership_list app_service = yield self.store.get_app_service_by_user_id(
user.to_string()
) )
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
else:
rooms = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user.to_string(), membership_list=membership_list
)
# For some reason the list of events contains duplicates # For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should # TODO(paul): work out why because I really don't think it should
@ -526,7 +532,7 @@ class RoomMemberHandler(BaseHandler):
do_auth): do_auth):
yield run_on_reactor() yield run_on_reactor()
target_user = self.hs.parse_userid(event.state_key) target_user = UserID.from_string(event.state_key)
yield self.handle_new_client_event( yield self.handle_new_client_event(
event, event,
@ -560,13 +566,24 @@ class RoomEventSource(object):
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
events, end_key = yield self.store.get_room_events_stream( app_service = yield self.store.get_app_service_by_user_id(
user_id=user.to_string(), user.to_string()
from_key=from_key,
to_key=to_key,
room_id=None,
limit=limit,
) )
if app_service:
events, end_key = yield self.store.get_appservice_room_stream(
service=app_service,
from_key=from_key,
to_key=to_key,
limit=limit,
)
else:
events, end_key = yield self.store.get_room_events_stream(
user_id=user.to_string(),
from_key=from_key,
to_key=to_key,
room_id=None,
limit=limit,
)
defer.returnValue((events, end_key)) defer.returnValue((events, end_key))

439
synapse/handlers/sync.py Normal file
View File

@ -0,0 +1,439 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseHandler
from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes
from twisted.internet import defer
import collections
import logging
logger = logging.getLogger(__name__)
SyncConfig = collections.namedtuple("SyncConfig", [
"user",
"client_info",
"limit",
"gap",
"sort",
"backfill",
"filter",
])
class RoomSyncResult(collections.namedtuple("RoomSyncResult", [
"room_id",
"limited",
"published",
"events",
"state",
"prev_batch",
"ephemeral",
])):
__slots__ = []
def __nonzero__(self):
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(self.events or self.state or self.ephemeral)
class SyncResult(collections.namedtuple("SyncResult", [
"next_batch", # Token for the next sync
"private_user_data", # List of private events for the user.
"public_user_data", # List of public events for all users.
"rooms", # RoomSyncResult for each room.
])):
__slots__ = []
def __nonzero__(self):
"""Make the result appear empty if there are no updates. This is used
to tell if the notifier needs to wait for more events when polling for
events.
"""
return bool(
self.private_user_data or self.public_user_data or self.rooms
)
class SyncHandler(BaseHandler):
def __init__(self, hs):
super(SyncHandler, self).__init__(hs)
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0):
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
Returns:
A Deferred SyncResult.
"""
if timeout == 0 or since_token is None:
result = yield self.current_sync_for_user(sync_config, since_token)
defer.returnValue(result)
else:
def current_sync_callback():
return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
result = yield self.notifier.wait_for_events(
sync_config.user, room_ids,
sync_config.filter, timeout, current_sync_callback
)
defer.returnValue(result)
def current_sync_for_user(self, sync_config, since_token=None):
"""Get the sync for client needed to match what the server has now.
Returns:
A Deferred SyncResult.
"""
if since_token is None:
return self.initial_sync(sync_config)
else:
if sync_config.gap:
return self.incremental_sync_with_gap(sync_config, since_token)
else:
# TODO(mjark): Handle gapless sync
raise NotImplementedError()
@defer.inlineCallbacks
def initial_sync(self, sync_config):
"""Get a sync for a client which is starting without any state
Returns:
A Deferred SyncResult.
"""
if sync_config.sort == "timeline,desc":
# TODO(mjark): Handle going through events in reverse order?.
# What does "most recent events" mean when applying the limits mean
# in this case?
raise NotImplementedError()
now_token = yield self.event_sources.get_current_token()
presence_stream = self.event_sources.sources["presence"]
# TODO (mjark): This looks wrong, shouldn't we be getting the presence
# UP to the present rather than after the present?
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user=sync_config.user,
pagination_config=pagination_config.get_source_config("presence"),
key=None
)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=sync_config.user.to_string(),
membership_list=[Membership.INVITE, Membership.JOIN]
)
# TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True)
published_room_ids = set(r["room_id"] for r in published_rooms)
rooms = []
for event in room_list:
room_sync = yield self.initial_sync_for_room(
event.room_id, sync_config, now_token, published_room_ids
)
rooms.append(room_sync)
defer.returnValue(SyncResult(
public_user_data=presence,
private_user_data=[],
rooms=rooms,
next_batch=now_token,
))
@defer.inlineCallbacks
def initial_sync_for_room(self, room_id, sync_config, now_token,
published_room_ids):
"""Sync a room for a client which is starting without any state
Returns:
A Deferred RoomSyncResult.
"""
recents, prev_batch_token, limited = yield self.load_filtered_recents(
room_id, sync_config, now_token,
)
current_state = yield self.state_handler.get_current_state(
room_id
)
current_state_events = current_state.values()
defer.returnValue(RoomSyncResult(
room_id=room_id,
published=room_id in published_room_ids,
events=recents,
prev_batch=prev_batch_token,
state=current_state_events,
limited=limited,
ephemeral=[],
))
@defer.inlineCallbacks
def incremental_sync_with_gap(self, sync_config, since_token):
""" Get the incremental delta needed to bring the client up to
date with the server.
Returns:
A Deferred SyncResult.
"""
if sync_config.sort == "timeline,desc":
# TODO(mjark): Handle going through events in reverse order?.
# What does "most recent events" mean when applying the limits mean
# in this case?
raise NotImplementedError()
now_token = yield self.event_sources.get_current_token()
presence_source = self.event_sources.sources["presence"]
presence, presence_key = yield presence_source.get_new_events_for_user(
user=sync_config.user,
from_key=since_token.presence_key,
limit=sync_config.limit,
)
now_token = now_token.copy_and_replace("presence_key", presence_key)
typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events_for_user(
user=sync_config.user,
from_key=since_token.typing_key,
limit=sync_config.limit,
)
now_token = now_token.copy_and_replace("typing_key", typing_key)
typing_by_room = {event["room_id"]: [event] for event in typing}
for event in typing:
event.pop("room_id")
logger.debug("Typing %r", typing_by_room)
rm_handler = self.hs.get_handlers().room_member_handler
room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
# TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True)
published_room_ids = set(r["room_id"] for r in published_rooms)
room_events, _ = yield self.store.get_room_events_stream(
sync_config.user.to_string(),
from_key=since_token.room_key,
to_key=now_token.room_key,
room_id=None,
limit=sync_config.limit + 1,
)
rooms = []
if len(room_events) <= sync_config.limit:
# There is no gap in any of the rooms. Therefore we can just
# partition the new events by room and return them.
events_by_room_id = {}
for event in room_events:
events_by_room_id.setdefault(event.room_id, []).append(event)
for room_id in room_ids:
recents = events_by_room_id.get(room_id, [])
state = [event for event in recents if event.is_state()]
if recents:
prev_batch = now_token.copy_and_replace(
"room_key", recents[0].internal_metadata.before
)
else:
prev_batch = now_token
state = yield self.check_joined_room(
sync_config, room_id, state
)
room_sync = RoomSyncResult(
room_id=room_id,
published=room_id in published_room_ids,
events=recents,
prev_batch=prev_batch,
state=state,
limited=False,
ephemeral=typing_by_room.get(room_id, [])
)
if room_sync:
rooms.append(room_sync)
else:
for room_id in room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token,
published_room_ids, typing_by_room
)
if room_sync:
rooms.append(room_sync)
defer.returnValue(SyncResult(
public_user_data=presence,
private_user_data=[],
rooms=rooms,
next_batch=now_token,
))
@defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None):
limited = True
recents = []
filtering_factor = 2
load_limit = max(sync_config.limit * filtering_factor, 100)
max_repeat = 3 # Only try a few times per room, otherwise
room_key = now_token.room_key
end_key = room_key
while limited and len(recents) < sync_config.limit and max_repeat:
events, keys = yield self.store.get_recent_events_for_room(
room_id,
limit=load_limit + 1,
from_token=since_token.room_key if since_token else None,
end_token=end_key,
)
(room_key, _) = keys
end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_events(events)
loaded_recents.extend(recents)
recents = loaded_recents
if len(events) <= load_limit:
limited = False
max_repeat -= 1
if len(recents) > sync_config.limit:
recents = recents[-sync_config.limit:]
room_key = recents[0].internal_metadata.before
prev_batch_token = now_token.copy_and_replace(
"room_key", room_key
)
defer.returnValue((recents, prev_batch_token, limited))
@defer.inlineCallbacks
def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token,
published_room_ids, typing_by_room):
""" Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to
state.
Returns:
A Deferred RoomSyncResult
"""
# TODO(mjark): Check for redactions we might have missed.
recents, prev_batch_token, limited = yield self.load_filtered_recents(
room_id, sync_config, now_token, since_token,
)
logging.debug("Recents %r", recents)
# TODO(mjark): This seems racy since this isn't being passed a
# token to indicate what point in the stream this is
current_state = yield self.state_handler.get_current_state(
room_id
)
current_state_events = current_state.values()
state_at_previous_sync = yield self.get_state_at_previous_sync(
room_id, since_token=since_token
)
state_events_delta = yield self.compute_state_delta(
since_token=since_token,
previous_state=state_at_previous_sync,
current_state=current_state_events,
)
state_events_delta = yield self.check_joined_room(
sync_config, room_id, state_events_delta
)
room_sync = RoomSyncResult(
room_id=room_id,
published=room_id in published_room_ids,
events=recents,
prev_batch=prev_batch_token,
state=state_events_delta,
limited=limited,
ephemeral=typing_by_room.get(room_id, [])
)
logging.debug("Room sync: %r", room_sync)
defer.returnValue(room_sync)
@defer.inlineCallbacks
def get_state_at_previous_sync(self, room_id, since_token):
""" Get the room state at the previous sync the client made.
Returns:
A Deferred list of Events.
"""
last_events, token = yield self.store.get_recent_events_for_room(
room_id, end_token=since_token.room_key, limit=1,
)
if last_events:
last_event = last_events[0]
last_context = yield self.state_handler.compute_event_context(
last_event
)
if last_event.is_state():
state = [last_event] + last_context.current_state.values()
else:
state = last_context.current_state.values()
else:
state = ()
defer.returnValue(state)
def compute_state_delta(self, since_token, previous_state, current_state):
""" Works out the differnce in state between the current state and the
state the client got when it last performed a sync.
Returns:
A list of events.
"""
# TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state
# updates even if they occured logically before the previous event.
# TODO(mjark) Check for new redactions in the state events.
previous_dict = {event.event_id: event for event in previous_state}
state_delta = []
for event in current_state:
if event.event_id not in previous_dict:
state_delta.append(event)
return state_delta
@defer.inlineCallbacks
def check_joined_room(self, sync_config, room_id, state_delta):
joined = False
for event in state_delta:
if (
event.type == EventTypes.Member
and event.state_key == sync_config.user.to_string()
):
if event.content["membership"] == Membership.JOIN:
joined = True
if joined:
res = yield self.state_handler.get_current_state(room_id)
state_delta = res.values()
defer.returnValue(state_delta)

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError from synapse.api.errors import SynapseError, AuthError
from synapse.types import UserID
import logging import logging
@ -180,12 +181,12 @@ class TypingNotificationHandler(BaseHandler):
}, },
)) ))
yield defer.DeferredList(deferreds, consumeErrors=False) yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def _recv_edu(self, origin, content): def _recv_edu(self, origin, content):
room_id = content["room_id"] room_id = content["room_id"]
user = self.homeserver.parse_userid(content["user_id"]) user = UserID.from_string(content["user_id"])
localusers = set() localusers = set()

View File

@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.api.errors import CodeMessageException
from syutil.jsonutil import encode_canonical_json
from synapse.http.agent_name import AGENT_NAME
from twisted.internet import defer, reactor from twisted.internet import defer, reactor
from twisted.web.client import ( from twisted.web.client import (
Agent, readBody, FileBodyProducer, PartialDownloadError Agent, readBody, FileBodyProducer, PartialDownloadError
@ -23,7 +24,7 @@ from twisted.web.http_headers import Headers
from StringIO import StringIO from StringIO import StringIO
import json import simplejson as json
import logging import logging
import urllib import urllib
@ -42,6 +43,7 @@ class SimpleHttpClient(object):
# BrowserLikePolicyForHTTPS which will do regular cert validation # BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser' # 'like a browser'
self.agent = Agent(reactor) self.agent = Agent(reactor)
self.version_string = hs.version_string
@defer.inlineCallbacks @defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}): def post_urlencoded_get_json(self, uri, args={}):
@ -53,7 +55,7 @@ class SimpleHttpClient(object):
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"], b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [AGENT_NAME], b"User-Agent": [self.version_string],
}), }),
bodyProducer=FileBodyProducer(StringIO(query_bytes)) bodyProducer=FileBodyProducer(StringIO(query_bytes))
) )
@ -62,9 +64,28 @@ class SimpleHttpClient(object):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def post_json_get_json(self, uri, post_json):
json_str = encode_canonical_json(post_json)
logger.info("HTTP POST %s -> %s", json_str, uri)
response = yield self.agent.request(
"POST",
uri.encode("ascii"),
headers=Headers({
"Content-Type": ["application/json"]
}),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
body = yield readBody(response)
defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, uri, args={}): def get_json(self, uri, args={}):
""" Get's some json from the given host and path """ Gets some json from the given URI.
Args: Args:
uri (str): The URI to request, not including query parameters uri (str): The URI to request, not including query parameters
@ -72,15 +93,13 @@ class SimpleHttpClient(object):
None. None.
**Note**: The value of each key is assumed to be an iterable **Note**: The value of each key is assumed to be an iterable
and *not* a string. and *not* a string.
Returns: Returns:
Deferred: Succeeds when we get *any* HTTP response. Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
The result of the deferred is a tuple of `(code, response)`, Raises:
where `response` is a dict representing the decoded JSON body. On a non-2xx HTTP response. The response body will be used as the
error message.
""" """
yield
if len(args): if len(args):
query_bytes = urllib.urlencode(args, True) query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes) uri = "%s?%s" % (uri, query_bytes)
@ -89,13 +108,62 @@ class SimpleHttpClient(object):
"GET", "GET",
uri.encode("ascii"), uri.encode("ascii"),
headers=Headers({ headers=Headers({
b"User-Agent": [AGENT_NAME], b"User-Agent": [self.version_string],
}) })
) )
body = yield readBody(response) body = yield readBody(response)
defer.returnValue(json.loads(body)) if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
else:
# NB: This is explicitly not json.loads(body)'d because the contract
# of CodeMessageException is a *string* message. Callers can always
# load it into JSON if they want.
raise CodeMessageException(response.code, body)
@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}):
""" Puts some json to the given URI.
Args:
uri (str): The URI to request, not including query parameters
json_body (dict): The JSON to put in the HTTP body,
args (dict): A dictionary used to create query strings, defaults to
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
HTTP body as JSON.
Raises:
On a non-2xx HTTP response.
"""
if len(args):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
json_str = encode_canonical_json(json_body)
response = yield self.agent.request(
"PUT",
uri.encode("ascii"),
headers=Headers({
b"User-Agent": [self.version_string],
"Content-Type": ["application/json"]
}),
bodyProducer=FileBodyProducer(StringIO(json_str))
)
body = yield readBody(response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
else:
# NB: This is explicitly not json.loads(body)'d because the contract
# of CodeMessageException is a *string* message. Callers can always
# load it into JSON if they want.
raise CodeMessageException(response.code, body)
class CaptchaServerHttpClient(SimpleHttpClient): class CaptchaServerHttpClient(SimpleHttpClient):
@ -114,7 +182,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
bodyProducer=FileBodyProducer(StringIO(query_bytes)), bodyProducer=FileBodyProducer(StringIO(query_bytes)),
headers=Headers({ headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"], b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [AGENT_NAME], b"User-Agent": [self.version_string],
}) })
) )

View File

@ -20,18 +20,19 @@ from twisted.web.client import readBody, _AgentBase, _URI
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone from twisted.web._newclient import ResponseDone
from synapse.http.agent_name import AGENT_NAME
from synapse.http.endpoint import matrix_federation_endpoint from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep from synapse.util.async import sleep
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from syutil.jsonutil import encode_canonical_json from syutil.jsonutil import encode_canonical_json
from synapse.api.errors import CodeMessageException, SynapseError, Codes from synapse.api.errors import (
SynapseError, Codes, HttpResponseException,
)
from syutil.crypto.jsonsign import sign_json from syutil.crypto.jsonsign import sign_json
import json import simplejson as json
import logging import logging
import urllib import urllib
import urlparse import urlparse
@ -77,6 +78,8 @@ class MatrixFederationHttpClient(object):
self.signing_key = hs.config.signing_key[0] self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname self.server_name = hs.hostname
self.agent = MatrixFederationHttpAgent(reactor) self.agent = MatrixFederationHttpAgent(reactor)
self.clock = hs.get_clock()
self.version_string = hs.version_string
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes, def _create_request(self, destination, method, path_bytes,
@ -84,7 +87,7 @@ class MatrixFederationHttpClient(object):
query_bytes=b"", retry_on_dns_fail=True): query_bytes=b"", retry_on_dns_fail=True):
""" Creates and sends a request to the given url """ Creates and sends a request to the given url
""" """
headers_dict[b"User-Agent"] = [AGENT_NAME] headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination] headers_dict[b"Host"] = [destination]
url_bytes = urlparse.urlunparse( url_bytes = urlparse.urlunparse(
@ -116,7 +119,7 @@ class MatrixFederationHttpClient(object):
try: try:
with PreserveLoggingContext(): with PreserveLoggingContext():
response = yield self.agent.request( request_deferred = self.agent.request(
destination, destination,
endpoint, endpoint,
method, method,
@ -127,6 +130,11 @@ class MatrixFederationHttpClient(object):
producer producer
) )
response = yield self.clock.time_bound_deferred(
request_deferred,
time_out=60,
)
logger.debug("Got response to %s", method) logger.debug("Got response to %s", method)
break break
except Exception as e: except Exception as e:
@ -136,16 +144,16 @@ class MatrixFederationHttpClient(object):
destination, destination,
e e
) )
raise SynapseError(400, "Domain specified not found.") raise
logger.warn( logger.warn(
"Sending request failed to %s: %s %s : %s", "Sending request failed to %s: %s %s: %s - %s",
destination, destination,
method, method,
url_bytes, url_bytes,
e type(e).__name__,
_flatten_response_never_received(e),
) )
_print_ex(e)
if retries_left: if retries_left:
yield sleep(2 ** (5 - retries_left)) yield sleep(2 ** (5 - retries_left))
@ -163,13 +171,13 @@ class MatrixFederationHttpClient(object):
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
pass pass
else: else:
# :'( # :'(
# Update transactions table? # Update transactions table?
raise CodeMessageException( body = yield readBody(response)
response.code, response.phrase raise HttpResponseException(
response.code, response.phrase, body
) )
defer.returnValue(response) defer.returnValue(response)
@ -238,11 +246,66 @@ class MatrixFederationHttpClient(object):
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
) )
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
c_type = response.headers.getRawHeaders("Content-Type")
if "application/json" not in c_type:
raise RuntimeError(
"Content-Type not application/json"
)
logger.debug("Getting resp body") logger.debug("Getting resp body")
body = yield readBody(response) body = yield readBody(response)
logger.debug("Got resp body") logger.debug("Got resp body")
defer.returnValue((response.code, body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def post_json(self, destination, path, data={}):
""" Sends the specifed json data using POST
Args:
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body. On a 4xx or 5xx error response a
CodeMessageException is raised.
"""
def body_callback(method, url_bytes, headers_dict):
self.sign_request(
destination, method, url_bytes, headers_dict, data
)
return _JsonProducer(data)
response = yield self._create_request(
destination.encode("ascii"),
"POST",
path.encode("ascii"),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
)
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
c_type = response.headers.getRawHeaders("Content-Type")
if "application/json" not in c_type:
raise RuntimeError(
"Content-Type not application/json"
)
logger.debug("Getting resp body")
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True): def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
@ -284,7 +347,18 @@ class MatrixFederationHttpClient(object):
retry_on_dns_fail=retry_on_dns_fail retry_on_dns_fail=retry_on_dns_fail
) )
if 200 <= response.code < 300:
# We need to update the transactions table to say it was sent?
c_type = response.headers.getRawHeaders("Content-Type")
if "application/json" not in c_type:
raise RuntimeError(
"Content-Type not application/json"
)
logger.debug("Getting resp body")
body = yield readBody(response) body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@ -373,14 +447,6 @@ def _readBodyToFile(response, stream, max_size):
return d return d
def _print_ex(e):
if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons:
_print_ex(ex)
else:
logger.warn(e)
class _JsonProducer(object): class _JsonProducer(object):
""" Used by the twisted http client to create the HTTP body from json """ Used by the twisted http client to create the HTTP body from json
""" """
@ -400,3 +466,13 @@ class _JsonProducer(object):
def stopProducing(self): def stopProducing(self):
pass pass
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
return ", ".join(
_flatten_response_never_received(f.value)
for f in e.reasons
)
else:
return "%s: %s" % (type(e).__name__, e.message,)

View File

@ -14,9 +14,8 @@
# limitations under the License. # limitations under the License.
from synapse.http.agent_name import AGENT_NAME
from synapse.api.errors import ( from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
) )
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
@ -69,10 +68,12 @@ class JsonResource(HttpServer, resource.Resource):
_PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"]) _PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
def __init__(self): def __init__(self, hs):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.clock = hs.get_clock()
self.path_regexs = {} self.path_regexs = {}
self.version_string = hs.version_string
def register_path(self, method, path_pattern, callback): def register_path(self, method, path_pattern, callback):
self.path_regexs.setdefault(method, []).append( self.path_regexs.setdefault(method, []).append(
@ -111,6 +112,8 @@ class JsonResource(HttpServer, resource.Resource):
This checks if anyone has registered a callback for that method and This checks if anyone has registered a callback for that method and
path. path.
""" """
code = None
start = self.clock.time_msec()
try: try:
# Just say yes to OPTIONS. # Just say yes to OPTIONS.
if request.method == "OPTIONS": if request.method == "OPTIONS":
@ -121,37 +124,42 @@ class JsonResource(HttpServer, resource.Resource):
# and path regex match # and path regex match
for path_entry in self.path_regexs.get(request.method, []): for path_entry in self.path_regexs.get(request.method, []):
m = path_entry.pattern.match(request.path) m = path_entry.pattern.match(request.path)
if m: if not m:
# We found a match! Trigger callback and then return the continue
# returned response. We pass both the request and any
# matched groups from the regex to the callback.
args = [ # We found a match! Trigger callback and then return the
urllib.unquote(u).decode("UTF-8") for u in m.groups() # returned response. We pass both the request and any
] # matched groups from the regex to the callback.
code, response = yield path_entry.callback( args = [
request, urllib.unquote(u).decode("UTF-8") for u in m.groups()
*args ]
)
self._send_response(request, code, response) logger.info(
return "Received request: %s %s",
request.method, request.path
)
code, response = yield path_entry.callback(
request,
*args
)
self._send_response(request, code, response)
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400. # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
self._send_response( raise UnrecognizedRequestError()
request,
400,
{"error": "Unrecognized request"}
)
except CodeMessageException as e: except CodeMessageException as e:
if isinstance(e, SynapseError): if isinstance(e, SynapseError):
logger.info("%s SynapseError: %s - %s", request, e.code, e.msg) logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)
else: else:
logger.exception(e) logger.exception(e)
code = e.code
self._send_response( self._send_response(
request, request,
e.code, code,
cs_exception(e), cs_exception(e),
response_code_message=e.response_code_message response_code_message=e.response_code_message
) )
@ -162,6 +170,14 @@ class JsonResource(HttpServer, resource.Resource):
500, 500,
{"error": "Internal server error"} {"error": "Internal server error"}
) )
finally:
code = str(code) if code else "-"
end = self.clock.time_msec()
logger.info(
"Processed request: %dms %s %s %s",
end-start, code, request.method, request.path
)
def _send_response(self, request, code, response_json_object, def _send_response(self, request, code, response_json_object,
response_code_message=None): response_code_message=None):
@ -175,9 +191,13 @@ class JsonResource(HttpServer, resource.Resource):
return return
# TODO: Only enable CORS for the requests that need it. # TODO: Only enable CORS for the requests that need it.
respond_with_json(request, code, response_json_object, send_cors=True, respond_with_json(
response_code_message=response_code_message, request, code, response_json_object,
pretty_print=self._request_user_agent_is_curl) send_cors=True,
response_code_message=response_code_message,
pretty_print=self._request_user_agent_is_curl,
version_string=self.version_string,
)
@staticmethod @staticmethod
def _request_user_agent_is_curl(request): def _request_user_agent_is_curl(request):
@ -207,18 +227,23 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False, def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False): response_code_message=None, pretty_print=False,
version_string=""):
if not pretty_print: if not pretty_print:
json_bytes = encode_pretty_printed_json(json_object) json_bytes = encode_pretty_printed_json(json_object)
else: else:
json_bytes = encode_canonical_json(json_object) json_bytes = encode_canonical_json(json_object)
return respond_with_json_bytes(request, code, json_bytes, send_cors, return respond_with_json_bytes(
response_code_message=response_code_message) request, code, json_bytes,
send_cors=send_cors,
response_code_message=response_code_message,
version_string=version_string
)
def respond_with_json_bytes(request, code, json_bytes, send_cors=False, def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
response_code_message=None): version_string="", response_code_message=None):
"""Sends encoded JSON in response to the given request. """Sends encoded JSON in response to the given request.
Args: Args:
@ -232,7 +257,7 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
request.setResponseCode(code, message=response_code_message) request.setResponseCode(code, message=response_code_message)
request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Content-Type", b"application/json")
request.setHeader(b"Server", AGENT_NAME) request.setHeader(b"Server", version_string)
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
if send_cors: if send_cors:

View File

@ -50,6 +50,7 @@ class LocalKey(Resource):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.version_string = hs.version_string
self.response_body = encode_canonical_json( self.response_body = encode_canonical_json(
self.response_json_object(hs.config) self.response_json_object(hs.config)
) )
@ -82,7 +83,10 @@ class LocalKey(Resource):
return json_object return json_object
def render_GET(self, request): def render_GET(self, request):
return respond_with_json_bytes(request, 200, self.response_body) return respond_with_json_bytes(
request, 200, self.response_body,
version_string=self.version_string
)
def getChild(self, name, request): def getChild(self, name, request):
if name == '': if name == '':

113
synapse/http/servlet.py Normal file
View File

@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
from synapse.api.errors import SynapseError
import logging
logger = logging.getLogger(__name__)
class RestServlet(object):
""" A Synapse REST Servlet.
An implementing class can either provide its own custom 'register' method,
or use the automatic pattern handling provided by the base class.
To use this latter, the implementing class instead provides a `PATTERN`
class attribute containing a pre-compiled regular expression. The automatic
register method will then use this method to register any of the following
instance methods associated with the corresponding HTTP method:
on_GET
on_PUT
on_POST
on_DELETE
on_OPTIONS
Automatically handles turning CodeMessageExceptions thrown by these methods
into the appropriate HTTP response.
"""
def register(self, http_server):
""" Register this servlet with the given HTTP server. """
if hasattr(self, "PATTERN"):
pattern = self.PATTERN
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method)):
method_handler = getattr(self, "on_%s" % (method))
http_server.register_path(method, pattern, method_handler)
else:
raise NotImplementedError("RestServlet must register something.")
@staticmethod
def parse_integer(request, name, default=None, required=False):
if name in request.args:
try:
return int(request.args[name][0])
except:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing integer query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
@staticmethod
def parse_boolean(request, name, default=None, required=False):
if name in request.args:
try:
return {
"true": True,
"false": False,
}[request.args[name][0]]
except:
message = (
"Boolean query parameter %r must be one of"
" ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
if required:
message = "Missing boolean query parameter %r" % (name,)
raise SynapseError(400, message)
else:
return default
@staticmethod
def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"):
if name in request.args:
value = request.args[name][0]
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
name, ", ".join(repr(v) for v in allowed_values)
)
raise SynapseError(message)
else:
return value
else:
if required:
message = "Missing %s query parameter %r" % (param_type, name)
raise SynapseError(400, message)
else:
return default

View File

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.types import StreamToken
import logging import logging
@ -35,8 +36,10 @@ class _NotificationListener(object):
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user, rooms, from_token, limit, timeout, deferred): def __init__(self, user, rooms, from_token, limit, timeout, deferred,
appservice=None):
self.user = user self.user = user
self.appservice = appservice
self.from_token = from_token self.from_token = from_token
self.limit = limit self.limit = limit
self.timeout = timeout self.timeout = timeout
@ -60,10 +63,14 @@ class _NotificationListener(object):
pass pass
for room in self.rooms: for room in self.rooms:
lst = notifier.rooms_to_listeners.get(room, set()) lst = notifier.room_to_listeners.get(room, set())
lst.discard(self) lst.discard(self)
notifier.user_to_listeners.get(self.user, set()).discard(self) notifier.user_to_listeners.get(self.user, set()).discard(self)
if self.appservice:
notifier.appservice_to_listeners.get(
self.appservice, set()
).discard(self)
class Notifier(object): class Notifier(object):
@ -76,8 +83,9 @@ class Notifier(object):
def __init__(self, hs): def __init__(self, hs):
self.hs = hs self.hs = hs
self.rooms_to_listeners = {} self.room_to_listeners = {}
self.user_to_listeners = {} self.user_to_listeners = {}
self.appservice_to_listeners = {}
self.event_sources = hs.get_event_sources() self.event_sources = hs.get_event_sources()
@ -98,15 +106,32 @@ class Notifier(object):
`extra_users` param. `extra_users` param.
""" """
yield run_on_reactor() yield run_on_reactor()
# poke any interested application service.
self.hs.get_handlers().appservice_handler.notify_interested_services(
event
)
room_id = event.room_id room_id = event.room_id
room_source = self.event_sources.sources["room"] room_source = self.event_sources.sources["room"]
listeners = self.rooms_to_listeners.get(room_id, set()).copy() listeners = self.room_to_listeners.get(room_id, set()).copy()
for user in extra_users: for user in extra_users:
listeners |= self.user_to_listeners.get(user, set()).copy() listeners |= self.user_to_listeners.get(user, set()).copy()
for appservice in self.appservice_to_listeners:
# TODO (kegan): Redundant appservice listener checks?
# App services will already be in the room_to_listeners set, but
# that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this
# make the room_to_listeners check somewhat obselete?
if appservice.is_interested(event):
listeners |= self.appservice_to_listeners.get(
appservice, set()
).copy()
logger.debug("on_new_room_event listeners %s", listeners) logger.debug("on_new_room_event listeners %s", listeners)
# TODO (erikj): Can we make this more efficient by hitting the # TODO (erikj): Can we make this more efficient by hitting the
@ -134,7 +159,8 @@ class Notifier(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
yield defer.DeferredList( yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners] [notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -158,7 +184,7 @@ class Notifier(object):
listeners |= self.user_to_listeners.get(user, set()).copy() listeners |= self.user_to_listeners.get(user, set()).copy()
for room in rooms: for room in rooms:
listeners |= self.rooms_to_listeners.get(room, set()).copy() listeners |= self.room_to_listeners.get(room, set()).copy()
@defer.inlineCallbacks @defer.inlineCallbacks
def notify(listener): def notify(listener):
@ -202,9 +228,57 @@ class Notifier(object):
with PreserveLoggingContext(): with PreserveLoggingContext():
yield defer.DeferredList( yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners] [notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
) )
@defer.inlineCallbacks
def wait_for_events(self, user, rooms, filter, timeout, callback):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
deferred = defer.Deferred()
from_token = StreamToken("s0", "0", "0")
listener = [_NotificationListener(
user=user,
rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred,
)]
if timeout:
self._register_with_keys(listener[0])
result = yield callback()
if timeout:
timed_out = [False]
def _timeout_listener():
timed_out[0] = True
listener[0].notify(self, [], from_token, from_token)
self.clock.call_later(timeout/1000., _timeout_listener)
while not result and not timed_out[0]:
yield deferred
deferred = defer.Deferred()
listener[0] = _NotificationListener(
user=user,
rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred,
)
self._register_with_keys(listener[0])
result = yield callback()
defer.returnValue(result)
def get_events_for(self, user, rooms, pagination_config, timeout): def get_events_for(self, user, rooms, pagination_config, timeout):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any
@ -224,6 +298,10 @@ class Notifier(object):
if not from_token: if not from_token:
from_token = yield self.event_sources.get_current_token() from_token = yield self.event_sources.get_current_token()
appservice = yield self.hs.get_datastore().get_app_service_by_user_id(
user.to_string()
)
listener = _NotificationListener( listener = _NotificationListener(
user, user,
rooms, rooms,
@ -231,6 +309,7 @@ class Notifier(object):
limit, limit,
timeout, timeout,
deferred, deferred,
appservice=appservice
) )
def _timeout_listener(): def _timeout_listener():
@ -258,11 +337,16 @@ class Notifier(object):
@log_function @log_function
def _register_with_keys(self, listener): def _register_with_keys(self, listener):
for room in listener.rooms: for room in listener.rooms:
s = self.rooms_to_listeners.setdefault(room, set()) s = self.room_to_listeners.setdefault(room, set())
s.add(listener) s.add(listener)
self.user_to_listeners.setdefault(listener.user, set()).add(listener) self.user_to_listeners.setdefault(listener.user, set()).add(listener)
if listener.appservice:
self.appservice_to_listeners.setdefault(
listener.appservice, set()
).add(listener)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _check_for_updates(self, listener): def _check_for_updates(self, listener):
@ -296,5 +380,5 @@ class Notifier(object):
def _user_joined_room(self, user, room_id): def _user_joined_room(self, user, room_id):
new_listeners = self.user_to_listeners.get(user, set()) new_listeners = self.user_to_listeners.get(user, set())
listeners = self.rooms_to_listeners.setdefault(room_id, set()) listeners = self.room_to_listeners.setdefault(room_id, set())
listeners |= new_listeners listeners |= new_listeners

427
synapse/push/__init__.py Normal file
View File

@ -0,0 +1,427 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
import synapse.util.async
import baserules
import logging
import simplejson as json
import re
logger = logging.getLogger(__name__)
class Pusher(object):
INITIAL_BACKOFF = 1000
MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000
DEFAULT_ACTIONS = ['dont-notify']
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def __init__(self, _hs, profile_tag, user_name, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
self.hs = _hs
self.evStreamHandler = self.hs.get_handlers().event_stream_handler
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.profile_tag = profile_tag
self.user_name = user_name
self.app_id = app_id
self.app_display_name = app_display_name
self.device_display_name = device_display_name
self.pushkey = pushkey
self.pushkey_ts = pushkey_ts
self.data = data
self.last_token = last_token
self.last_success = last_success # not actually used
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.failing_since = failing_since
self.alive = True
# The last value of last_active_time that we saw
self.last_last_active_time = 0
self.has_unread = True
@defer.inlineCallbacks
def _actions_for_event(self, ev):
"""
This should take into account notification settings that the user
has configured both globally and per-room when we have the ability
to do such things.
"""
if ev['user_id'] == self.user_name:
# let's assume you probably know about messages you sent yourself
defer.returnValue(['dont_notify'])
rawrules = yield self.store.get_push_rules_for_user(self.user_name)
for r in rawrules:
r['conditions'] = json.loads(r['conditions'])
r['actions'] = json.loads(r['actions'])
enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name)
user = UserID.from_string(self.user_name)
rules = baserules.list_with_base_rules(rawrules, user)
# get *our* member event for display name matching
member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=None
)
my_display_name = None
room_member_count = 0
for mev in member_events_for_room:
if mev.content['membership'] != 'join':
continue
# This loop does two things:
# 1) Find our current display name
if mev.state_key == self.user_name and 'displayname' in mev.content:
my_display_name = mev.content['displayname']
# and 2) Get the number of people in that room
room_member_count += 1
for r in rules:
if r['rule_id'] in enabled_map and not enabled_map[r['rule_id']]:
continue
matches = True
conditions = r['conditions']
actions = r['actions']
for c in conditions:
matches &= self._event_fulfills_condition(
ev, c, display_name=my_display_name,
room_member_count=room_member_count
)
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
# ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0:
logger.warn(
"Ignoring rule id %s with no actions for user %s" %
(r['rule_id'], r['user_name'])
)
continue
if matches:
defer.returnValue(actions)
defer.returnValue(Pusher.DEFAULT_ACTIONS)
@staticmethod
def _glob_to_regexp(glob):
r = re.escape(glob)
r = re.sub(r'\\\*', r'.*?', r)
r = re.sub(r'\\\?', r'.', r)
# handle [abc], [a-z] and [!a-z] style ranges.
r = re.sub(r'\\\[(\\\!|)(.*)\\\]',
lambda x: ('[%s%s]' % (x.group(1) and '^' or '',
re.sub(r'\\\-', '-', x.group(2)))), r)
return r
def _event_fulfills_condition(self, ev, condition, display_name, room_member_count):
if condition['kind'] == 'event_match':
if 'pattern' not in condition:
logger.warn("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
r = r'\b%s\b' % self._glob_to_regexp(condition['pattern'])
else:
r = r'^%s$' % self._glob_to_regexp(condition['pattern'])
val = _value_for_dotted_key(condition['key'], ev)
if val is None:
return False
return re.search(r, val, flags=re.IGNORECASE) is not None
elif condition['kind'] == 'device':
if 'profile_tag' not in condition:
return True
return condition['profile_tag'] == self.profile_tag
elif condition['kind'] == 'contains_display_name':
# This is special because display names can be different
# between rooms and so you can't really hard code it in a rule.
# Optimisation: we should cache these names and update them from
# the event stream.
if 'content' not in ev or 'body' not in ev['content']:
return False
if not display_name:
return False
return re.search(
"\b%s\b" % re.escape(display_name), ev['content']['body'],
flags=re.IGNORECASE
) is not None
elif condition['kind'] == 'room_member_count':
if 'is' not in condition:
return False
m = Pusher.INEQUALITY_EXPR.match(condition['is'])
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs = int(rhs)
if ineq == '' or ineq == '==':
return room_member_count == rhs
elif ineq == '<':
return room_member_count < rhs
elif ineq == '>':
return room_member_count > rhs
elif ineq == '>=':
return room_member_count >= rhs
elif ineq == '<=':
return room_member_count <= rhs
else:
return False
else:
return True
@defer.inlineCallbacks
def get_context_for_event(self, ev):
name_aliases = yield self.store.get_room_name_and_aliases(
ev['room_id']
)
ctx = {'aliases': name_aliases[1]}
if name_aliases[0] is not None:
ctx['name'] = name_aliases[0]
their_member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=ev['user_id']
)
for mev in their_member_events_for_room:
if mev.content['membership'] == 'join' and 'displayname' in mev.content:
dn = mev.content['displayname']
if dn is not None:
ctx['sender_display_name'] = dn
defer.returnValue(ctx)
@defer.inlineCallbacks
def start(self):
if not self.last_token:
# First-time setup: get a token to start from (we can't
# just start from no token, ie. 'now'
# because we need the result to be reproduceable in case
# we fail to dispatch the push)
config = PaginationConfig(from_token=None, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, timeout=0)
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.last_token)
logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token)
while self.alive:
from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config,
timeout=100*365*24*60*60*1000, affect_presence=False
)
# limiting to 1 may get 1 event plus 1 presence event, so
# pick out the actual event
single_event = None
for c in chunk['chunk']:
if 'event_id' in c: # Hmmm...
single_event = c
break
if not single_event:
self.last_token = chunk['end']
continue
if not self.alive:
continue
processed = False
actions = yield self._actions_for_event(single_event)
tweaks = _tweaks_for_actions(actions)
if len(actions) == 0:
logger.warn("Empty actions! Using default action.")
actions = Pusher.DEFAULT_ACTIONS
if 'notify' not in actions and 'dont_notify' not in actions:
logger.warn("Neither notify nor dont_notify in actions: adding default")
actions.extend(Pusher.DEFAULT_ACTIONS)
if 'dont_notify' in actions:
logger.debug(
"%s for %s: dont_notify",
single_event['event_id'], self.user_name
)
processed = True
else:
rejected = yield self.dispatch_push(single_event, tweaks)
self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
("Ignoring rejected pushkey %s because we"
" didn't send it"), pk
)
else:
logger.info(
"Pushkey %s was rejected: removing",
pk
)
yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk
)
if not self.alive:
continue
if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token_and_success(
self.app_id,
self.pushkey,
self.last_token,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.failing_since)
else:
if not self.failing_since:
self.failing_since = self.clock.time_msec()
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.failing_since
)
if (self.failing_since and
self.failing_since <
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_name, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.last_token
)
self.failing_since = None
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.failing_since
)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
self.user_name,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
self.backoff_delay *= 2
if self.backoff_delay > Pusher.MAX_BACKOFF:
self.backoff_delay = Pusher.MAX_BACKOFF
def stop(self):
self.alive = False
def dispatch_push(self, p, tweaks):
"""
Overridden by implementing classes to actually deliver the notification
Args:
p: The event to notify for as a single event from the event stream
Returns: If the notification was delivered, an array containing any
pushkeys that were rejected by the push gateway.
False if the notification could not be delivered (ie.
should be retried).
"""
pass
def reset_badge_count(self):
pass
def presence_changed(self, state):
"""
We clear badge counts whenever a user's last_active time is bumped
This is by no means perfect but I think it's the best we can do
without read receipts.
"""
if 'last_active' in state.state:
last_active = state.state['last_active']
if last_active > self.last_last_active_time:
self.last_last_active_time = last_active
if self.has_unread:
logger.info("Resetting badge count for %s", self.user_name)
self.reset_badge_count()
self.has_unread = False
def _value_for_dotted_key(dotted_key, event):
parts = dotted_key.split(".")
val = event
while len(parts) > 0:
if parts[0] not in val:
return None
val = val[parts[0]]
parts = parts[1:]
return val
def _tweaks_for_actions(actions):
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if 'set_tweak' in a and 'value' in a:
tweaks[a['set_tweak']] = a['value']
return tweaks
class PusherConfigException(Exception):
def __init__(self, msg):
super(PusherConfigException, self).__init__(msg)

209
synapse/push/baserules.py Normal file
View File

@ -0,0 +1,209 @@
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
def list_with_base_rules(rawrules, user_name):
ruleslist = []
# shove the server default rules for each kind onto the end of each
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
for r in rawrules:
if r['priority_class'] < current_prio_class:
while r['priority_class'] < current_prio_class:
ruleslist.extend(make_base_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
current_prio_class -= 1
ruleslist.append(r)
while current_prio_class > 0:
ruleslist.extend(make_base_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
current_prio_class -= 1
return ruleslist
def make_base_rules(user, kind):
rules = []
if kind == 'override':
rules = make_base_override_rules()
elif kind == 'underride':
rules = make_base_underride_rules(user)
elif kind == 'content':
rules = make_base_content_rules(user)
for r in rules:
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
r['default'] = True # Deprecated, left for backwards compat
return rules
def make_base_content_rules(user):
return [
{
'rule_id': 'global/content/.m.rule.contains_user_name',
'conditions': [
{
'kind': 'event_match',
'key': 'content.body',
'pattern': user.localpart, # Matrix ID match
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default',
}, {
'set_tweak': 'highlight'
}
]
},
]
def make_base_override_rules():
return [
{
'rule_id': 'global/override/.m.rule.call',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.call.invite',
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'ring'
}, {
'set_tweak': 'highlight',
'value': 'false'
}
]
},
{
'rule_id': 'global/override/.m.rule.suppress_notices',
'conditions': [
{
'kind': 'event_match',
'key': 'content.msgtype',
'pattern': 'm.notice',
}
],
'actions': [
'dont_notify',
]
},
{
'rule_id': 'global/override/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight'
}
]
},
{
'rule_id': 'global/override/.m.rule.room_one_to_one',
'conditions': [
{
'kind': 'room_member_count',
'is': '2'
}
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': 'false'
}
]
}
]
def make_base_underride_rules(user):
return [
{
'rule_id': 'global/underride/.m.rule.invite_for_me',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
},
{
'kind': 'event_match',
'key': 'content.membership',
'pattern': 'invite',
},
{
'kind': 'event_match',
'key': 'state_key',
'pattern': user.to_string(),
},
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': 'false'
}
]
},
{
'rule_id': 'global/underride/.m.rule.member_event',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
}
],
'actions': [
'notify', {
'set_tweak': 'highlight',
'value': 'false'
}
]
},
{
'rule_id': 'global/underride/.m.rule.message',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.message',
}
],
'actions': [
'notify', {
'set_tweak': 'highlight',
'value': 'false'
}
]
}
]

148
synapse/push/httppusher.py Normal file
View File

@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push import Pusher, PusherConfigException
from synapse.http.client import SimpleHttpClient
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class HttpPusher(Pusher):
def __init__(self, _hs, profile_tag, user_name, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__(
_hs,
profile_tag,
user_name,
app_id,
app_display_name,
device_display_name,
pushkey,
pushkey_ts,
data,
last_token,
last_success,
failing_since
)
if 'url' not in data:
raise PusherConfigException(
"'url' required in data for HTTP pusher"
)
self.url = data['url']
self.httpCli = SimpleHttpClient(self.hs)
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url['url']
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks):
# we probably do not want to push for every presence update
# (we may want to be able to set up notifications when specific
# people sign in, but we'd want to only deliver the pertinent ones)
# Actually, presence events will not get this far now because we
# need to filter them out in the main Pusher code.
if 'event_id' not in event:
defer.returnValue(None)
ctx = yield self.get_context_for_event(event)
d = {
'notification': {
'id': event['event_id'],
'room_id': event['room_id'],
'type': event['type'],
'sender': event['user_id'],
'counts': { # -- we don't mark messages as read yet so
# we have no way of knowing
# Just set the badge to 1 until we have read receipts
'unread': 1,
# 'missed_calls': 2
},
'devices': [
{
'app_id': self.app_id,
'pushkey': self.pushkey,
'pushkey_ts': long(self.pushkey_ts / 1000),
'data': self.data_minus_url,
'tweaks': tweaks
}
]
}
}
if event['type'] == 'm.room.member':
d['notification']['membership'] = event['content']['membership']
d['notification']['user_is_target'] = event['state_key'] == self.user_name
if 'content' in event:
d['notification']['content'] = event['content']
if len(ctx['aliases']):
d['notification']['room_alias'] = ctx['aliases'][0]
if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
d['notification']['sender_display_name'] = ctx['sender_display_name']
if 'name' in ctx and len(ctx['name']) > 0:
d['notification']['room_name'] = ctx['name']
defer.returnValue(d)
@defer.inlineCallbacks
def dispatch_push(self, event, tweaks):
notification_dict = yield self._build_notification_dict(event, tweaks)
if not notification_dict:
defer.returnValue([])
try:
resp = yield self.httpCli.post_json_get_json(self.url, notification_dict)
except:
logger.warn("Failed to push %s ", self.url)
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
rejected = resp['rejected']
defer.returnValue(rejected)
@defer.inlineCallbacks
def reset_badge_count(self):
d = {
'notification': {
'id': '',
'type': None,
'sender': '',
'counts': {
'unread': 0,
'missed_calls': 0
},
'devices': [
{
'app_id': self.app_id,
'pushkey': self.pushkey,
'pushkey_ts': long(self.pushkey_ts / 1000),
'data': self.data_minus_url,
}
]
}
}
try:
resp = yield self.httpCli.post_json_get_json(self.url, d)
except:
logger.exception("Failed to push %s ", self.url)
defer.returnValue(False)
rejected = []
if 'rejected' in resp:
rejected = resp['rejected']
defer.returnValue(rejected)

154
synapse/push/pusherpool.py Normal file
View File

@ -0,0 +1,154 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from httppusher import HttpPusher
from synapse.push import PusherConfigException
from syutil.jsonutil import encode_canonical_json
import logging
import simplejson as json
logger = logging.getLogger(__name__)
class PusherPool:
def __init__(self, _hs):
self.hs = _hs
self.store = self.hs.get_datastore()
self.pushers = {}
self.last_pusher_started = -1
distributor = self.hs.get_distributor()
distributor.observe(
"user_presence_changed", self.user_presence_changed
)
@defer.inlineCallbacks
def user_presence_changed(self, user, state):
user_name = user.to_string()
# until we have read receipts, pushers use this to reset a user's
# badge counters to zero
for p in self.pushers.values():
if p.user_name == user_name:
yield p.presence_changed(state)
@defer.inlineCallbacks
def start(self):
pushers = yield self.store.get_all_pushers()
for p in pushers:
p['data'] = json.loads(p['data'])
self._start_pushers(pushers)
@defer.inlineCallbacks
def add_pusher(self, user_name, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, lang, data):
# we try to create the pusher just to validate the config: it
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
self._create_pusher({
"user_name": user_name,
"kind": kind,
"profile_tag": profile_tag,
"app_id": app_id,
"app_display_name": app_display_name,
"device_display_name": device_display_name,
"pushkey": pushkey,
"pushkey_ts": self.hs.get_clock().time_msec(),
"lang": lang,
"data": data,
"last_token": None,
"last_success": None,
"failing_since": None
})
yield self._add_pusher_to_store(
user_name, profile_tag, kind, app_id,
app_display_name, device_display_name,
pushkey, lang, data
)
@defer.inlineCallbacks
def _add_pusher_to_store(self, user_name, profile_tag, kind, app_id,
app_display_name, device_display_name,
pushkey, lang, data):
yield self.store.add_pusher(
user_name=user_name,
profile_tag=profile_tag,
kind=kind,
app_id=app_id,
app_display_name=app_display_name,
device_display_name=device_display_name,
pushkey=pushkey,
pushkey_ts=self.hs.get_clock().time_msec(),
lang=lang,
data=encode_canonical_json(data).decode("UTF-8"),
)
self._refresh_pusher((app_id, pushkey))
def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http':
return HttpPusher(
self.hs,
profile_tag=pusherdict['profile_tag'],
user_name=pusherdict['user_name'],
app_id=pusherdict['app_id'],
app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'],
pushkey=pusherdict['pushkey'],
pushkey_ts=pusherdict['pushkey_ts'],
data=pusherdict['data'],
last_token=pusherdict['last_token'],
last_success=pusherdict['last_success'],
failing_since=pusherdict['failing_since']
)
else:
raise PusherConfigException(
"Unknown pusher type '%s' for user %s" %
(pusherdict['kind'], pusherdict['user_name'])
)
@defer.inlineCallbacks
def _refresh_pusher(self, app_id_pushkey):
p = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id_pushkey
)
p['data'] = json.loads(p['data'])
self._start_pushers([p])
def _start_pushers(self, pushers):
logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers:
p = self._create_pusher(pusherdict)
if p:
fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey'])
if fullid in self.pushers:
self.pushers[fullid].stop()
self.pushers[fullid] = p
p.start()
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey):
fullid = "%s:%s" % (app_id, pushkey)
if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid)
self.pushers[fullid].stop()
del self.pushers[fullid]
yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey)

View File

@ -0,0 +1,8 @@
PRIORITY_CLASS_MAP = {
'underride': 1,
'sender': 2,
'room': 3,
'content': 4,
'override': 5,
}
PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()}

View File

@ -4,9 +4,9 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
REQUIREMENTS = { REQUIREMENTS = {
"syutil==0.0.2": ["syutil"], "syutil>=0.0.3": ["syutil"],
"matrix_angular_sdk==0.6.0": ["syweb==0.6.0"], "matrix_angular_sdk>=0.6.4": ["syweb>=0.6.4"],
"Twisted>=14.0.0": ["twisted>=14.0.0"], "Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"], "service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"], "pyyaml": ["yaml"],
@ -16,9 +16,32 @@ REQUIREMENTS = {
"py-bcrypt": ["bcrypt"], "py-bcrypt": ["bcrypt"],
"frozendict>=0.4": ["frozendict"], "frozendict>=0.4": ["frozendict"],
"pillow": ["PIL"], "pillow": ["PIL"],
"pydenticon": ["pydenticon"],
} }
def github_link(project, version, egg):
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
DEPENDENCY_LINKS = [
github_link(
project="pyca/pynacl",
version="d4d3175589b892f6ea7c22f466e0e223853516fa",
egg="pynacl-0.3.0",
),
github_link(
project="matrix-org/syutil",
version="v0.0.3",
egg="syutil-0.0.3",
),
github_link(
project="matrix-org/matrix-angular-sdk",
version="v0.6.4",
egg="matrix_angular_sdk-0.6.4",
),
]
class MissingRequirementError(Exception): class MissingRequirementError(Exception):
pass pass
@ -78,3 +101,24 @@ def check_requirements():
"Unexpected version of %r in %r. %r != %r" "Unexpected version of %r in %r. %r != %r"
% (dependency, file_path, version, required_version) % (dependency, file_path, version, required_version)
) )
def list_requirements():
result = []
linked = []
for link in DEPENDENCY_LINKS:
egg = link.split("#egg=")[1]
linked.append(egg.split('-')[0])
result.append(link)
for requirement in REQUIREMENTS:
is_linked = False
for link in linked:
if requirement.replace('-', '_').startswith(link):
is_linked = True
if not is_linked:
result.append(requirement)
return result
if __name__ == "__main__":
import sys
sys.stdout.writelines(req + "\n" for req in list_requirements())

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd # Copyright 2015 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,36 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import (
room, events, register, login, profile, presence, initial_sync, directory,
voip, admin,
)
class RestServletFactory(object):
""" A factory for creating REST servlets.
These REST servlets represent the entire client-server REST API. Generally
speaking, they serve as wrappers around events and the handlers that
process them.
See synapse.events for information on synapse events.
"""
def __init__(self, hs):
client_resource = hs.get_resource_for_client()
# TODO(erikj): There *must* be a better way of doing this.
room.register_servlets(hs, client_resource)
events.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)
initial_sync.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource)

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import register
from synapse.http.server import JsonResource
class AppServiceRestResource(JsonResource):
"""A resource for version 1 of the matrix application service API."""
def __init__(self, hs):
JsonResource.__init__(self, hs)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(appservice_resource, hs):
register.register_servlets(hs, appservice_resource)

View File

@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
from synapse.http.servlet import RestServlet
from synapse.api.urls import APP_SERVICE_PREFIX
import re
import logging
logger = logging.getLogger(__name__)
def as_path_pattern(path_regex):
"""Creates a regex compiled appservice path with the correct path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
return re.compile("^" + APP_SERVICE_PREFIX + path_regex)
class AppServiceRestServlet(RestServlet):
"""A base Synapse REST Servlet for the application services version 1 API.
"""
def __init__(self, hs):
self.hs = hs
self.handler = hs.get_handlers().appservice_handler

View File

@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains REST servlets to do with registration: /register"""
from twisted.internet import defer
from base import AppServiceRestServlet, as_path_pattern
from synapse.api.errors import CodeMessageException, SynapseError
from synapse.storage.appservice import ApplicationService
import json
import logging
logger = logging.getLogger(__name__)
class RegisterRestServlet(AppServiceRestServlet):
"""Handles AS registration with the home server.
"""
PATTERN = as_path_pattern("/register$")
@defer.inlineCallbacks
def on_POST(self, request):
params = _parse_json(request)
# sanity check required params
try:
as_token = params["as_token"]
as_url = params["url"]
if (not isinstance(as_token, basestring) or
not isinstance(as_url, basestring)):
raise ValueError
except (KeyError, ValueError):
raise SynapseError(
400, "Missed required keys: as_token(str) / url(str)."
)
try:
app_service = ApplicationService(
as_token, as_url, params["namespaces"]
)
except ValueError as e:
raise SynapseError(400, e.message)
app_service = yield self.handler.register(app_service)
hs_token = app_service.hs_token
defer.returnValue((200, {
"hs_token": hs_token
}))
class UnregisterRestServlet(AppServiceRestServlet):
"""Handles AS registration with the home server.
"""
PATTERN = as_path_pattern("/unregister$")
def on_POST(self, request):
params = _parse_json(request)
try:
as_token = params["as_token"]
if not isinstance(as_token, basestring):
raise ValueError
except (KeyError, ValueError):
raise SynapseError(400, "Missing required key: as_token(str)")
yield self.handler.unregister(as_token)
raise CodeMessageException(500, "Not implemented")
def _parse_json(request):
try:
content = json.loads(request.content.read())
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
except ValueError:
raise SynapseError(400, "Content not JSON.")
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)
UnregisterRestServlet(hs).register(http_server)

View File

@ -1,80 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
from synapse.api.urls import CLIENT_PREFIX
from synapse.rest.transactions import HttpTransactionStore
import re
import logging
logger = logging.getLogger(__name__)
def client_path_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
return re.compile("^" + CLIENT_PREFIX + path_regex)
class RestServlet(object):
""" A Synapse REST Servlet.
An implementing class can either provide its own custom 'register' method,
or use the automatic pattern handling provided by the base class.
To use this latter, the implementing class instead provides a `PATTERN`
class attribute containing a pre-compiled regular expression. The automatic
register method will then use this method to register any of the following
instance methods associated with the corresponding HTTP method:
on_GET
on_PUT
on_POST
on_DELETE
on_OPTIONS
Automatically handles turning CodeMessageExceptions thrown by these methods
into the appropriate HTTP response.
"""
def __init__(self, hs):
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionStore()
def register(self, http_server):
""" Register this servlet with the given HTTP server. """
if hasattr(self, "PATTERN"):
pattern = self.PATTERN
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method)):
method_handler = getattr(self, "on_%s" % (method))
http_server.register_path(method, pattern, method_handler)
else:
raise NotImplementedError("RestServlet must register something.")

View File

@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import (
room, events, register, login, profile, presence, initial_sync, directory,
voip, admin, pusher, push_rule
)
from synapse.http.server import JsonResource
class ClientV1RestResource(JsonResource):
"""A resource for version 1 of the matrix client API."""
def __init__(self, hs):
JsonResource.__init__(self, hs)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(client_resource, hs):
room.register_servlets(hs, client_resource)
events.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)
initial_sync.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource)
pusher.register_servlets(hs, client_resource)
push_rule.register_servlets(hs, client_resource)

View File

@ -16,20 +16,22 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from base import RestServlet, client_path_pattern from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class WhoisRestServlet(RestServlet): class WhoisRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/admin/whois/(?P<user_id>[^/]*)") PATTERN = client_path_pattern("/admin/whois/(?P<user_id>[^/]*)")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
target_user = self.hs.parse_userid(user_id) target_user = UserID.from_string(user_id)
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user) is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user: if not is_admin and target_user != auth_user:

View File

@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base REST classes for constructing client v1 servlets.
"""
from synapse.http.servlet import RestServlet
from synapse.api.urls import CLIENT_PREFIX
from .transactions import HttpTransactionStore
import re
import logging
logger = logging.getLogger(__name__)
def client_path_pattern(path_regex):
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
as this will be prefixed.
Returns:
SRE_Pattern
"""
return re.compile("^" + CLIENT_PREFIX + path_regex)
class ClientV1RestServlet(RestServlet):
"""A base Synapse REST Servlet for the client version 1 API.
"""
def __init__(self, hs):
self.hs = hs
self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionStore()

View File

@ -17,9 +17,10 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError, Codes from synapse.api.errors import AuthError, SynapseError, Codes
from base import RestServlet, client_path_pattern from synapse.types import RoomAlias
from .base import ClientV1RestServlet, client_path_pattern
import json import simplejson as json
import logging import logging
@ -30,12 +31,12 @@ def register_servlets(hs, http_server):
ClientDirectoryServer(hs).register(http_server) ClientDirectoryServer(hs).register(http_server)
class ClientDirectoryServer(RestServlet): class ClientDirectoryServer(ClientV1RestServlet):
PATTERN = client_path_pattern("/directory/room/(?P<room_alias>[^/]*)$") PATTERN = client_path_pattern("/directory/room/(?P<room_alias>[^/]*)$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_alias): def on_GET(self, request, room_alias):
room_alias = self.hs.parse_roomalias(room_alias) room_alias = RoomAlias.from_string(room_alias)
dir_handler = self.handlers.directory_handler dir_handler = self.handlers.directory_handler
res = yield dir_handler.get_association(room_alias) res = yield dir_handler.get_association(room_alias)
@ -44,16 +45,14 @@ class ClientDirectoryServer(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, room_alias): def on_PUT(self, request, room_alias):
user = yield self.auth.get_user_by_req(request)
content = _parse_json(request) content = _parse_json(request)
if not "room_id" in content: if "room_id" not in content:
raise SynapseError(400, "Missing room_id key", raise SynapseError(400, "Missing room_id key",
errcode=Codes.BAD_JSON) errcode=Codes.BAD_JSON)
logger.debug("Got content: %s", content) logger.debug("Got content: %s", content)
room_alias = self.hs.parse_roomalias(room_alias) room_alias = RoomAlias.from_string(room_alias)
logger.debug("Got room name: %s", room_alias.to_string()) logger.debug("Got room name: %s", room_alias.to_string())
@ -69,34 +68,70 @@ class ClientDirectoryServer(RestServlet):
dir_handler = self.handlers.directory_handler dir_handler = self.handlers.directory_handler
try: try:
user_id = user.to_string() # try to auth as a user
yield dir_handler.create_association( user, client = yield self.auth.get_user_by_req(request)
user_id, room_alias, room_id, servers try:
user_id = user.to_string()
yield dir_handler.create_association(
user_id, room_alias, room_id, servers
)
yield dir_handler.send_room_alias_update_event(user_id, room_id)
except SynapseError as e:
raise e
except:
logger.exception("Failed to create association")
raise
except AuthError:
# try to auth as an application service
service = yield self.auth.get_appservice_by_req(request)
yield dir_handler.create_appservice_association(
service, room_alias, room_id, servers
)
logger.info(
"Application service at %s created alias %s pointing to %s",
service.url,
room_alias.to_string(),
room_id
) )
yield dir_handler.send_room_alias_update_event(user_id, room_id)
except SynapseError as e:
raise e
except:
logger.exception("Failed to create association")
raise
defer.returnValue((200, {})) defer.returnValue((200, {}))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_DELETE(self, request, room_alias): def on_DELETE(self, request, room_alias):
user = yield self.auth.get_user_by_req(request) dir_handler = self.handlers.directory_handler
try:
service = yield self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
yield dir_handler.delete_appservice_association(
service, room_alias
)
logger.info(
"Application service at %s deleted alias %s",
service.url,
room_alias.to_string()
)
defer.returnValue((200, {}))
except AuthError:
# fallback to default user behaviour if they aren't an AS
pass
user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user) is_admin = yield self.auth.is_server_admin(user)
if not is_admin: if not is_admin:
raise AuthError(403, "You need to be a server admin") raise AuthError(403, "You need to be a server admin")
dir_handler = self.handlers.directory_handler room_alias = RoomAlias.from_string(room_alias)
room_alias = self.hs.parse_roomalias(room_alias)
yield dir_handler.delete_association( yield dir_handler.delete_association(
user.to_string(), room_alias user.to_string(), room_alias
) )
logger.info(
"User %s deleted alias %s",
user.to_string(),
room_alias.to_string()
)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -18,7 +18,8 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.rest.base import RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_pattern
from synapse.events.utils import serialize_event
import logging import logging
@ -26,14 +27,14 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class EventStreamRestServlet(RestServlet): class EventStreamRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/events$") PATTERN = client_path_pattern("/events$")
DEFAULT_LONGPOLL_TIME_MS = 30000 DEFAULT_LONGPOLL_TIME_MS = 30000
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
try: try:
handler = self.handlers.event_stream_handler handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request) pagin_config = PaginationConfig.from_request(request)
@ -61,17 +62,22 @@ class EventStreamRestServlet(RestServlet):
# TODO: Unit test gets, with and without auth, with different kinds of events. # TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(RestServlet): class EventRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$") PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$")
def __init__(self, hs):
super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, event_id): def on_GET(self, request, event_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id) event = yield handler.get_event(auth_user, event_id)
time_now = self.clock.time_msec()
if event: if event:
defer.returnValue((200, self.hs.serialize_event(event))) defer.returnValue((200, serialize_event(event, time_now)))
else: else:
defer.returnValue((404, "Event not found.")) defer.returnValue((404, "Event not found."))

View File

@ -16,16 +16,16 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from base import RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_pattern
# TODO: Needs unit testing # TODO: Needs unit testing
class InitialSyncRestServlet(RestServlet): class InitialSyncRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/initialSync$") PATTERN = client_path_pattern("/initialSync$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
user = yield self.auth.get_user_by_req(request) user, client = yield self.auth.get_user_by_req(request)
with_feedback = "feedback" in request.args with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request) pagination_config = PaginationConfig.from_request(request)

View File

@ -17,12 +17,12 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import UserID from synapse.types import UserID
from base import RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_pattern
import json import simplejson as json
class LoginRestServlet(RestServlet): class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$") PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password" PASS_TYPE = "m.login.password"
@ -64,7 +64,7 @@ class LoginRestServlet(RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
class LoginFallbackRestServlet(RestServlet): class LoginFallbackRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/fallback$") PATTERN = client_path_pattern("/login/fallback$")
def on_GET(self, request): def on_GET(self, request):
@ -73,7 +73,7 @@ class LoginFallbackRestServlet(RestServlet):
return (200, {}) return (200, {})
class PasswordResetRestServlet(RestServlet): class PasswordResetRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/reset") PATTERN = client_path_pattern("/login/reset")
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -18,21 +18,22 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from base import RestServlet, client_path_pattern from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_pattern
import json import simplejson as json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(RestServlet): class PresenceStatusRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status") PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id) user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state( state = yield self.handlers.presence_handler.get_state(
target_user=user, auth_user=auth_user) target_user=user, auth_user=auth_user)
@ -41,8 +42,8 @@ class PresenceStatusRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id) user = UserID.from_string(user_id)
state = {} state = {}
try: try:
@ -71,13 +72,13 @@ class PresenceStatusRestServlet(RestServlet):
return (200, {}) return (200, {})
class PresenceListRestServlet(RestServlet): class PresenceListRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/presence/list/(?P<user_id>[^/]*)") PATTERN = client_path_pattern("/presence/list/(?P<user_id>[^/]*)")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server") raise SynapseError(400, "User not hosted on this Home Server")
@ -96,8 +97,8 @@ class PresenceListRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, user_id): def on_POST(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id) user = UserID.from_string(user_id)
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server") raise SynapseError(400, "User not hosted on this Home Server")
@ -118,7 +119,7 @@ class PresenceListRestServlet(RestServlet):
raise SynapseError(400, "Bad invite value.") raise SynapseError(400, "Bad invite value.")
if len(u) == 0: if len(u) == 0:
continue continue
invited_user = self.hs.parse_userid(u) invited_user = UserID.from_string(u)
yield self.handlers.presence_handler.send_invite( yield self.handlers.presence_handler.send_invite(
observer_user=user, observed_user=invited_user observer_user=user, observed_user=invited_user
) )
@ -129,7 +130,7 @@ class PresenceListRestServlet(RestServlet):
raise SynapseError(400, "Bad drop value.") raise SynapseError(400, "Bad drop value.")
if len(u) == 0: if len(u) == 0:
continue continue
dropped_user = self.hs.parse_userid(u) dropped_user = UserID.from_string(u)
yield self.handlers.presence_handler.drop( yield self.handlers.presence_handler.drop(
observer_user=user, observed_user=dropped_user observer_user=user, observed_user=dropped_user
) )

View File

@ -16,17 +16,18 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """ """ This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer from twisted.internet import defer
from base import RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_pattern
from synapse.types import UserID
import json import simplejson as json
class ProfileDisplaynameRestServlet(RestServlet): class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname") PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = self.hs.parse_userid(user_id) user = UserID.from_string(user_id)
displayname = yield self.handlers.profile_handler.get_displayname( displayname = yield self.handlers.profile_handler.get_displayname(
user, user,
@ -36,8 +37,8 @@ class ProfileDisplaynameRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id) user = UserID.from_string(user_id)
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -54,12 +55,12 @@ class ProfileDisplaynameRestServlet(RestServlet):
return (200, {}) return (200, {})
class ProfileAvatarURLRestServlet(RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url") PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = self.hs.parse_userid(user_id) user = UserID.from_string(user_id)
avatar_url = yield self.handlers.profile_handler.get_avatar_url( avatar_url = yield self.handlers.profile_handler.get_avatar_url(
user, user,
@ -69,8 +70,8 @@ class ProfileAvatarURLRestServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
auth_user = yield self.auth.get_user_by_req(request) auth_user, client = yield self.auth.get_user_by_req(request)
user = self.hs.parse_userid(user_id) user = UserID.from_string(user_id)
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -87,12 +88,12 @@ class ProfileAvatarURLRestServlet(RestServlet):
return (200, {}) return (200, {})
class ProfileRestServlet(RestServlet): class ProfileRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)") PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
user = self.hs.parse_userid(user_id) user = UserID.from_string(user_id)
displayname = yield self.handlers.profile_handler.get_displayname( displayname = yield self.handlers.profile_handler.get_displayname(
user, user,

View File

@ -0,0 +1,456 @@
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import (
SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError
)
from .base import ClientV1RestServlet, client_path_pattern
from synapse.storage.push_rule import (
InconsistentRuleException, RuleNotFoundException
)
import synapse.push.baserules as baserules
from synapse.push.rulekinds import (
PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
)
import simplejson as json
class PushRuleRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/pushrules/.*$")
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash")
@defer.inlineCallbacks
def on_PUT(self, request):
spec = _rule_spec_from_path(request.postpath)
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
user, _ = yield self.auth.get_user_by_req(request)
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
raise SynapseError(400, "rule_id may not contain slashes")
content = _parse_json(request)
if 'attr' in spec:
self.set_rule_attr(user.to_string(), spec, content)
defer.returnValue((200, {}))
try:
(conditions, actions) = _rule_tuple_from_request_object(
spec['template'],
spec['rule_id'],
content,
device=spec['device'] if 'device' in spec else None
)
except InvalidRuleException as e:
raise SynapseError(400, e.message)
before = request.args.get("before", None)
if before and len(before):
before = before[0]
after = request.args.get("after", None)
if after and len(after):
after = after[0]
try:
yield self.hs.get_datastore().add_push_rule(
user_name=user.to_string(),
rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class,
conditions=conditions,
actions=actions,
before=before,
after=after
)
except InconsistentRuleException as e:
raise SynapseError(400, e.message)
except RuleNotFoundException as e:
raise SynapseError(400, e.message)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_DELETE(self, request):
spec = _rule_spec_from_path(request.postpath)
user, _ = yield self.auth.get_user_by_req(request)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try:
yield self.hs.get_datastore().delete_push_rule(
user.to_string(), namespaced_rule_id
)
defer.returnValue((200, {}))
except StoreError as e:
if e.code == 404:
raise NotFoundError()
else:
raise
@defer.inlineCallbacks
def on_GET(self, request):
user, _ = yield self.auth.get_user_by_req(request)
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
rawrules = yield self.hs.get_datastore().get_push_rules_for_user(
user.to_string()
)
for r in rawrules:
r["conditions"] = json.loads(r["conditions"])
r["actions"] = json.loads(r["actions"])
ruleslist = baserules.list_with_base_rules(rawrules, user)
rules = {'global': {}, 'device': {}}
rules['global'] = _add_empty_priority_class_arrays(rules['global'])
enabled_map = yield self.hs.get_datastore().\
get_push_rules_enabled_for_user(user.to_string())
for r in ruleslist:
rulearray = None
template_name = _priority_class_to_template_name(r['priority_class'])
if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
# per-device rule
profile_tag = _profile_tag_from_conditions(r["conditions"])
r = _strip_device_condition(r)
if not profile_tag:
continue
if profile_tag not in rules['device']:
rules['device'][profile_tag] = {}
rules['device'][profile_tag] = (
_add_empty_priority_class_arrays(
rules['device'][profile_tag]
)
)
rulearray = rules['device'][profile_tag][template_name]
else:
rulearray = rules['global'][template_name]
template_rule = _rule_to_template(r)
if template_rule:
template_rule['enabled'] = True
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
rulearray.append(template_rule)
path = request.postpath[1:]
if path == []:
# we're a reference impl: pedantry is our job.
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
defer.returnValue((200, rules))
elif path[0] == 'global':
path = path[1:]
result = _filter_ruleset_with_path(rules['global'], path)
defer.returnValue((200, result))
elif path[0] == 'device':
path = path[1:]
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
defer.returnValue((200, rules['device']))
profile_tag = path[0]
path = path[1:]
if profile_tag not in rules['device']:
ret = {}
ret = _add_empty_priority_class_arrays(ret)
defer.returnValue((200, ret))
ruleset = rules['device'][profile_tag]
result = _filter_ruleset_with_path(ruleset, path)
defer.returnValue((200, result))
else:
raise UnrecognizedRequestError()
def on_OPTIONS(self, _):
return 200, {}
def set_rule_attr(self, user_name, spec, val):
if spec['attr'] == 'enabled':
if not isinstance(val, bool):
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
self.hs.get_datastore().set_push_rule_enabled(
user_name, namespaced_rule_id, val
)
else:
raise UnrecognizedRequestError()
def get_rule_attr(self, user_name, namespaced_rule_id, attr):
if attr == 'enabled':
return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id(
user_name, namespaced_rule_id
)
else:
raise UnrecognizedRequestError()
def _rule_spec_from_path(path):
if len(path) < 2:
raise UnrecognizedRequestError()
if path[0] != 'pushrules':
raise UnrecognizedRequestError()
scope = path[1]
path = path[2:]
if scope not in ['global', 'device']:
raise UnrecognizedRequestError()
device = None
if scope == 'device':
if len(path) == 0:
raise UnrecognizedRequestError()
device = path[0]
path = path[1:]
if len(path) == 0:
raise UnrecognizedRequestError()
template = path[0]
path = path[1:]
if len(path) == 0 or len(path[0]) == 0:
raise UnrecognizedRequestError()
rule_id = path[0]
spec = {
'scope': scope,
'template': template,
'rule_id': rule_id
}
if device:
spec['profile_tag'] = device
path = path[1:]
if len(path) > 0 and len(path[0]) > 0:
spec['attr'] = path[0]
return spec
def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None):
if rule_template in ['override', 'underride']:
if 'conditions' not in req_obj:
raise InvalidRuleException("Missing 'conditions'")
conditions = req_obj['conditions']
for c in conditions:
if 'kind' not in c:
raise InvalidRuleException("Condition without 'kind'")
elif rule_template == 'room':
conditions = [{
'kind': 'event_match',
'key': 'room_id',
'pattern': rule_id
}]
elif rule_template == 'sender':
conditions = [{
'kind': 'event_match',
'key': 'user_id',
'pattern': rule_id
}]
elif rule_template == 'content':
if 'pattern' not in req_obj:
raise InvalidRuleException("Content rule missing 'pattern'")
pat = req_obj['pattern']
conditions = [{
'kind': 'event_match',
'key': 'content.body',
'pattern': pat
}]
else:
raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
if device:
conditions.append({
'kind': 'device',
'profile_tag': device
})
if 'actions' not in req_obj:
raise InvalidRuleException("No actions found")
actions = req_obj['actions']
for a in actions:
if a in ['notify', 'dont_notify', 'coalesce']:
pass
elif isinstance(a, dict) and 'set_tweak' in a:
pass
else:
raise InvalidRuleException("Unrecognised action")
return conditions, actions
def _add_empty_priority_class_arrays(d):
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
def _profile_tag_from_conditions(conditions):
"""
Given a list of conditions, return the profile tag of the
device rule if there is one
"""
for c in conditions:
if c['kind'] == 'device':
return c['profile_tag']
return None
def _filter_ruleset_with_path(ruleset, path):
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
return ruleset
template_kind = path[0]
if template_kind not in ruleset:
raise UnrecognizedRequestError()
path = path[1:]
if path == []:
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
if path[0] == '':
return ruleset[template_kind]
rule_id = path[0]
the_rule = None
for r in ruleset[template_kind]:
if r['rule_id'] == rule_id:
the_rule = r
if the_rule is None:
raise NotFoundError
path = path[1:]
if len(path) == 0:
return the_rule
attr = path[0]
if attr in the_rule:
return the_rule[attr]
else:
raise UnrecognizedRequestError()
def _priority_class_from_spec(spec):
if spec['template'] not in PRIORITY_CLASS_MAP.keys():
raise InvalidRuleException("Unknown template: %s" % (spec['kind']))
pc = PRIORITY_CLASS_MAP[spec['template']]
if spec['scope'] == 'device':
pc += len(PRIORITY_CLASS_MAP)
return pc
def _priority_class_to_template_name(pc):
if pc > PRIORITY_CLASS_MAP['override']:
# per-device
prio_class_index = pc - len(PRIORITY_CLASS_MAP)
return PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
else:
return PRIORITY_CLASS_INVERSE_MAP[pc]
def _rule_to_template(rule):
unscoped_rule_id = None
if 'rule_id' in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
template_name = _priority_class_to_template_name(rule['priority_class'])
if template_name in ['override', 'underride']:
templaterule = {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ["sender", "room"]:
templaterule = {'actions': rule['actions']}
unscoped_rule_id = rule['conditions'][0]['pattern']
elif template_name == 'content':
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
templaterule = {'actions': rule['actions']}
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
templaterule['rule_id'] = unscoped_rule_id
if 'default' in rule:
templaterule['default'] = rule['default']
return templaterule
def _strip_device_condition(rule):
for i, c in enumerate(rule['conditions']):
if c['kind'] == 'device':
del rule['conditions'][i]
return rule
def _namespaced_rule_id_from_spec(spec):
if spec['scope'] == 'global':
scope = 'global'
else:
scope = 'device/%s' % (spec['profile_tag'])
return "%s/%s/%s" % (scope, spec['template'], spec['rule_id'])
def _rule_id_from_namespaced(in_rule_id):
return in_rule_id.split('/')[-1]
class InvalidRuleException(Exception):
pass
# XXX: C+ped from rest/room.py - surely this should be common?
def _parse_json(request):
try:
content = json.loads(request.content.read())
return content
except ValueError:
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
def register_servlets(hs, http_server):
PushRuleRestServlet(hs).register(http_server)

Some files were not shown because too many files have changed in this diff Show More