Merge branch 'develop' into paul/tiny-fixes

This commit is contained in:
Paul "LeoNerd" Evans 2015-12-10 16:21:00 +00:00
commit d7ee7b589f
95 changed files with 2504 additions and 911 deletions

View File

@ -48,3 +48,6 @@ Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
Steven Hammerton <steven.hammerton at openmarket.com> Steven Hammerton <steven.hammerton at openmarket.com>
* Add CAS support for registration and login. * Add CAS support for registration and login.
Mads Robin Christensen <mads at v42 dot dk>
* CentOS 7 installation instructions.

View File

@ -1,3 +1,42 @@
Changes in synapse v0.11.1 (2015-11-20)
=======================================
* Add extra options to search API (PR #394)
* Fix bug where we did not correctly cap federation retry timers. This meant it
could take several hours for servers to start talking to ressurected servers,
even when they were receiving traffic from them (PR #393)
* Don't advertise login token flow unless CAS is enabled. This caused issues
where some clients would always use the fallback API if they did not
recognize all login flows (PR #391)
* Change /v2 sync API to rename ``private_user_data`` to ``account_data``
(PR #386)
* Change /v2 sync API to remove the ``event_map`` and rename keys in ``rooms``
object (PR #389)
Changes in synapse v0.11.0-r2 (2015-11-19)
==========================================
* Fix bug in database port script (PR #387)
Changes in synapse v0.11.0-r1 (2015-11-18)
==========================================
* Retry and fail federation requests more aggressively for requests that block
client side requests (PR #384)
Changes in synapse v0.11.0 (2015-11-17)
=======================================
* Change CAS login API (PR #349)
Changes in synapse v0.11.0-rc2 (2015-11-13)
===========================================
* Various changes to /sync API response format (PR #373)
* Fix regression when setting display name in newly joined room over
federation (PR #368)
* Fix problem where /search was slow when using SQLite (PR #366)
Changes in synapse v0.11.0-rc1 (2015-11-11) Changes in synapse v0.11.0-rc1 (2015-11-11)
=========================================== ===========================================

View File

@ -20,4 +20,6 @@ recursive-include synapse/static *.gif
recursive-include synapse/static *.html recursive-include synapse/static *.html
recursive-include synapse/static *.js recursive-include synapse/static *.js
exclude jenkins.sh
prune demo/etc prune demo/etc

View File

@ -111,6 +111,14 @@ Installing prerequisites on ArchLinux::
sudo pacman -S base-devel python2 python-pip \ sudo pacman -S base-devel python2 python-pip \
python-setuptools python-virtualenv sqlite3 python-setuptools python-virtualenv sqlite3
Installing prerequisites on CentOS 7::
sudo yum install libtiff-devel libjpeg-devel libzip-devel freetype-devel \
lcms2-devel libwebp-devel tcl-devel tk-devel \
python-virtualenv libffi-devel openssl-devel
sudo yum groupinstall "Development Tools"
Installing prerequisites on Mac OS X:: Installing prerequisites on Mac OS X::
xcode-select --install xcode-select --install
@ -133,15 +141,21 @@ In case of problems, please see the _Troubleshooting section below.
Alternatively, Silvio Fricke has contributed a Dockerfile to automate the Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/. above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
Another alternative is to install via apt from http://matrix.org/packages/debian/.
Note that these packages do not include a client - choose one from
https://matrix.org/blog/try-matrix-now/ (or build your own with
https://github.com/matrix-org/matrix-js-sdk/).
To set up your homeserver, run (in your virtualenv, as before):: To set up your homeserver, run (in your virtualenv, as before)::
cd ~/.synapse cd ~/.synapse
python -m synapse.app.homeserver \ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \ --server-name machine.my.domain.name \
--config-path homeserver.yaml \ --config-path homeserver.yaml \
--generate-config --generate-config \
--report-stats=[yes|no]
Substituting your host and domain name as appropriate. ...substituting your host and domain name as appropriate.
This will generate you a config file that you can then customise, but it will This will generate you a config file that you can then customise, but it will
also generate a set of keys for you. These keys will allow your Home Server to also generate a set of keys for you. These keys will allow your Home Server to
@ -154,10 +168,11 @@ key in the <server name>.signing.key file (the second word, which by default is
By default, registration of new users is disabled. You can either enable By default, registration of new users is disabled. You can either enable
registration in the config by specifying ``enable_registration: true`` registration in the config by specifying ``enable_registration: true``
(it is then recommended to also set up CAPTCHA), or (it is then recommended to also set up CAPTCHA - see docs/CAPTCHA_SETUP), or
you can use the command line to register new users:: you can use the command line to register new users::
$ source ~/.synapse/bin/activate $ source ~/.synapse/bin/activate
$ synctl start # if not already running
$ register_new_matrix_user -c homeserver.yaml https://localhost:8448 $ register_new_matrix_user -c homeserver.yaml https://localhost:8448
New user localpart: erikj New user localpart: erikj
Password: Password:
@ -167,6 +182,16 @@ you can use the command line to register new users::
For reliable VoIP calls to be routed via this homeserver, you MUST configure For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See docs/turn-howto.rst for details. a TURN server. See docs/turn-howto.rst for details.
Running Synapse
===============
To actually run your new homeserver, pick a working directory for Synapse to
run (e.g. ``~/.synapse``), and::
cd ~/.synapse
source ./bin/activate
synctl start
Using PostgreSQL Using PostgreSQL
================ ================
@ -189,16 +214,6 @@ may have a few regressions relative to SQLite.
For information on how to install and use PostgreSQL, please see For information on how to install and use PostgreSQL, please see
`docs/postgres.rst <docs/postgres.rst>`_. `docs/postgres.rst <docs/postgres.rst>`_.
Running Synapse
===============
To actually run your new homeserver, pick a working directory for Synapse to
run (e.g. ``~/.synapse``), and::
cd ~/.synapse
source ./bin/activate
synctl start
Platform Specific Instructions Platform Specific Instructions
============================== ==============================
@ -425,6 +440,10 @@ SRV record, as that is the name other machines will expect it to have::
python -m synapse.app.homeserver --config-path homeserver.yaml python -m synapse.app.homeserver --config-path homeserver.yaml
If you've already generated the config file, you need to edit the "server_name"
in you ```homeserver.yaml``` file. If you've already started Synapse and a
database has been created, you will have to recreate the database.
You may additionally want to pass one or more "-v" options, in order to You may additionally want to pass one or more "-v" options, in order to
increase the verbosity of logging output; at least for initial testing. increase the verbosity of logging output; at least for initial testing.

View File

@ -30,6 +30,19 @@ running:
python synapse/python_dependencies.py | xargs -n1 pip install python synapse/python_dependencies.py | xargs -n1 pip install
Upgrading to v0.11.0
====================
This release includes the option to send anonymous usage stats to matrix.org,
and requires that administrators explictly opt in or out by setting the
``report_stats`` option to either ``true`` or ``false``.
We would really appreciate it if you could help our project out by reporting
anonymized usage statistics from your homeserver. Only very basic aggregate
data (e.g. number of users) will be reported, but it helps us to track the
growth of the Matrix community, and helps us to make Matrix a success, as well
as to convince other networks that they should peer with us.
Upgrading to v0.9.0 Upgrading to v0.9.0
=================== ===================

View File

@ -18,8 +18,8 @@ encoding use, e.g.::
This would create an appropriate database named ``synapse`` owned by the This would create an appropriate database named ``synapse`` owned by the
``synapse_user`` user (which must already exist). ``synapse_user`` user (which must already exist).
Set up client Set up client in Debian/Ubuntu
============= ===========================
Postgres support depends on the postgres python connector ``psycopg2``. In the Postgres support depends on the postgres python connector ``psycopg2``. In the
virtual env:: virtual env::
@ -27,6 +27,19 @@ virtual env::
sudo apt-get install libpq-dev sudo apt-get install libpq-dev
pip install psycopg2 pip install psycopg2
Set up client in RHEL/CentOs 7
==============================
Make sure you have the appropriate version of postgres-devel installed. For a
postgres 9.4, use the postgres 9.4 packages from
[here](https://wiki.postgresql.org/wiki/YUM_Installation).
As with Debian/Ubuntu, postgres support depends on the postgres python connector
``psycopg2``. In the virtual env::
sudo yum install postgresql-devel libpqxx-devel.x86_64
export PATH=/usr/pgsql-9.4/bin/:$PATH
pip install psycopg2
Synapse config Synapse config
============== ==============

70
jenkins.sh Executable file
View File

@ -0,0 +1,70 @@
#!/bin/bash -eu
export PYTHONDONTWRITEBYTECODE=yep
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Output coverage to coverage.xml
export DUMP_COVERAGE_COMMAND="coverage xml -o coverage.xml"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
tox
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
set +u
. .tox/py27/bin/activate
set -u
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PERL5LIB:=$WORKSPACE/perl5/lib/perl5}
: ${PERL_MB_OPT:=--install_base=$WORKSPACE/perl5}
: ${PERL_MM_OPT:=INSTALL_BASE=$WORKSPACE/perl5}
export PERL5LIB PERL_MB_OPT PERL_MM_OPT
./install-deps.pl
: ${PORT_BASE:=8000}
echo >&2 "Running sytest with SQLite3";
./run-tests.pl -O tap --synapse-directory .. --all --port-base $PORT_BASE > results-sqlite3.tap
RUN_POSTGRES=""
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
if psql synapse_jenkins_$port <<< ""; then
RUN_POSTGRES=$RUN_POSTGRES:$port
cat > localhost-$port/database.yaml << EOF
name: psycopg2
args:
database: synapse_jenkins_$port
EOF
fi
done
# Run if both postgresql databases exist
if test $RUN_POSTGRES = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
echo >&2 "Running sytest with PostgreSQL";
pip install psycopg2
./run-tests.pl -O tap --synapse-directory .. --all --port-base $PORT_BASE > results-postgresql.tap
else
echo >&2 "Skipping running sytest with PostgreSQL, $RUN_POSTGRES"
fi

View File

@ -79,16 +79,16 @@ def defined_names(prefix, defs, names):
defined_names(prefix + name + ".", funcs, names) defined_names(prefix + name + ".", funcs, names)
def used_names(prefix, defs, names): def used_names(prefix, item, defs, names):
for name, funcs in defs.get('def', {}).items(): for name, funcs in defs.get('def', {}).items():
used_names(prefix + name + ".", funcs, names) used_names(prefix + name + ".", name, funcs, names)
for name, funcs in defs.get('class', {}).items(): for name, funcs in defs.get('class', {}).items():
used_names(prefix + name + ".", funcs, names) used_names(prefix + name + ".", name, funcs, names)
for used in defs.get('uses', ()): for used in defs.get('uses', ()):
if used in names: if used in names:
names[used].setdefault('used', []).append(prefix.rstrip('.')) names[used].setdefault('used', {}).setdefault(item, []).append(prefix.rstrip('.'))
if __name__ == '__main__': if __name__ == '__main__':
@ -109,6 +109,14 @@ if __name__ == '__main__':
"directories", nargs='+', metavar="DIR", "directories", nargs='+', metavar="DIR",
help="Directories to search for definitions" help="Directories to search for definitions"
) )
parser.add_argument(
"--referrers", default=0, type=int,
help="Include referrers up to the given depth"
)
parser.add_argument(
"--format", default="yaml",
help="Output format, one of 'yaml' or 'dot'"
)
args = parser.parse_args() args = parser.parse_args()
definitions = {} definitions = {}
@ -124,7 +132,7 @@ if __name__ == '__main__':
defined_names(filepath + ":", defs, names) defined_names(filepath + ":", defs, names)
for filepath, defs in definitions.items(): for filepath, defs in definitions.items():
used_names(filepath + ":", defs, names) used_names(filepath + ":", None, defs, names)
patterns = [re.compile(pattern) for pattern in args.pattern or ()] patterns = [re.compile(pattern) for pattern in args.pattern or ()]
ignore = [re.compile(pattern) for pattern in args.ignore or ()] ignore = [re.compile(pattern) for pattern in args.ignore or ()]
@ -139,4 +147,29 @@ if __name__ == '__main__':
continue continue
result[name] = definition result[name] = definition
referrer_depth = args.referrers
referrers = set()
while referrer_depth:
referrer_depth -= 1
for entry in result.values():
for used_by in entry.get("used", ()):
referrers.add(used_by)
for name, definition in names.items():
if not name in referrers:
continue
if ignore and any(pattern.match(name) for pattern in ignore):
continue
result[name] = definition
if args.format == 'yaml':
yaml.dump(result, sys.stdout, default_flow_style=False) yaml.dump(result, sys.stdout, default_flow_style=False)
elif args.format == 'dot':
print "digraph {"
for name, entry in result.items():
print name
for used_by in entry.get("used", ()):
if used_by in result:
print used_by, "->", name
print "}"
else:
raise ValueError("Unknown format %r" % (args.format))

View File

@ -68,6 +68,7 @@ APPEND_ONLY_TABLES = [
"state_groups_state", "state_groups_state",
"event_to_state_groups", "event_to_state_groups",
"rejections", "rejections",
"event_search",
] ]
@ -229,6 +230,38 @@ class Porter(object):
if rows: if rows:
next_chunk = rows[-1][0] + 1 next_chunk = rows[-1][0] + 1
if table == "event_search":
# We have to treat event_search differently since it has a
# different structure in the two different databases.
def insert(txn):
sql = (
"INSERT INTO event_search (event_id, room_id, key, sender, vector)"
" VALUES (?,?,?,?,to_tsvector('english', ?))"
)
rows_dict = [
dict(zip(headers, row))
for row in rows
]
txn.executemany(sql, [
(
row["event_id"],
row["room_id"],
row["key"],
row["sender"],
row["value"],
)
for row in rows_dict
])
self.postgres_store._simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
updatevalues={"rowid": next_chunk},
)
else:
self._convert_rows(table, headers, rows) self._convert_rows(table, headers, rows)
def insert(txn): def insert(txn):

View File

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server. """ This is a reference implementation of a Matrix home server.
""" """
__version__ = "0.11.0-rc1" __version__ = "0.11.1"

View File

@ -207,6 +207,13 @@ class Auth(object):
user_id, room_id user_id, room_id
)) ))
if membership == Membership.LEAVE:
forgot = yield self.store.did_forget(user_id, room_id)
if forgot:
raise AuthError(403, "User %s not in room %s" % (
user_id, room_id
))
defer.returnValue(member) defer.returnValue(member)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -587,7 +594,7 @@ class Auth(object):
def _get_user_from_macaroon(self, macaroon_str): def _get_user_from_macaroon(self, macaroon_str):
try: try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str) macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self._validate_macaroon(macaroon) self.validate_macaroon(macaroon, "access", False)
user_prefix = "user_id = " user_prefix = "user_id = "
user = None user = None
@ -635,13 +642,27 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
def _validate_macaroon(self, macaroon): def validate_macaroon(self, macaroon, type_string, verify_expiry):
"""
validate that a Macaroon is understood by and was signed by this server.
Args:
macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token this is (e.g. "access", "refresh")
verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet.
"""
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1") v.satisfy_exact("gen = 1")
v.satisfy_exact("type = access") v.satisfy_exact("type = " + type_string)
v.satisfy_general(lambda c: c.startswith("user_id = ")) v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(self._verify_expiry)
v.satisfy_exact("guest = true") v.satisfy_exact("guest = true")
if verify_expiry:
v.satisfy_general(self._verify_expiry)
else:
v.satisfy_general(lambda c: c.startswith("time < "))
v.verify(macaroon, self.hs.config.macaroon_secret_key) v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier() v = pymacaroons.Verifier()
@ -652,9 +673,6 @@ class Auth(object):
prefix = "time < " prefix = "time < "
if not caveat.startswith(prefix): if not caveat.startswith(prefix):
return False return False
# TODO(daniel): Enable expiry check when clients actually know how to
# refresh tokens. (And remember to enable the tests)
return True
expiry = int(caveat[len(prefix):]) expiry = int(caveat[len(prefix):])
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
return now < expiry return now < expiry
@ -842,7 +860,7 @@ class Auth(object):
redact_level = self._get_named_level(auth_events, "redact", 50) redact_level = self._get_named_level(auth_events, "redact", 50)
if user_level > redact_level: if user_level >= redact_level:
return False return False
redacter_domain = EventID.from_string(event.event_id).domain redacter_domain = EventID.from_string(event.event_id).domain

View File

@ -50,11 +50,11 @@ class Filtering(object):
# many definitions. # many definitions.
top_level_definitions = [ top_level_definitions = [
"presence" "presence", "account_data"
] ]
room_level_definitions = [ room_level_definitions = [
"state", "timeline", "ephemeral", "private_user_data" "state", "timeline", "ephemeral", "account_data"
] ]
for key in top_level_definitions: for key in top_level_definitions:
@ -131,14 +131,22 @@ class FilterCollection(object):
self.filter_json.get("room", {}).get("ephemeral", {}) self.filter_json.get("room", {}).get("ephemeral", {})
) )
self.room_private_user_data = Filter( self.room_account_data = Filter(
self.filter_json.get("room", {}).get("private_user_data", {}) self.filter_json.get("room", {}).get("account_data", {})
) )
self.presence_filter = Filter( self.presence_filter = Filter(
self.filter_json.get("presence", {}) self.filter_json.get("presence", {})
) )
self.account_data = Filter(
self.filter_json.get("account_data", {})
)
self.include_leave = self.filter_json.get("room", {}).get(
"include_leave", False
)
def timeline_limit(self): def timeline_limit(self):
return self.room_timeline_filter.limit() return self.room_timeline_filter.limit()
@ -151,6 +159,9 @@ class FilterCollection(object):
def filter_presence(self, events): def filter_presence(self, events):
return self.presence_filter.filter(events) return self.presence_filter.filter(events)
def filter_account_data(self, events):
return self.account_data.filter(events)
def filter_room_state(self, events): def filter_room_state(self, events):
return self.room_state_filter.filter(events) return self.room_state_filter.filter(events)
@ -160,8 +171,8 @@ class FilterCollection(object):
def filter_room_ephemeral(self, events): def filter_room_ephemeral(self, events):
return self.room_ephemeral_filter.filter(events) return self.room_ephemeral_filter.filter(events)
def filter_room_private_user_data(self, events): def filter_room_account_data(self, events):
return self.room_private_user_data.filter(events) return self.room_account_data.filter(events)
class Filter(object): class Filter(object):

View File

@ -15,6 +15,8 @@
# limitations under the License. # limitations under the License.
import sys import sys
from synapse.rest import ClientRestResource
sys.dont_write_bytecode = True sys.dont_write_bytecode = True
from synapse.python_dependencies import ( from synapse.python_dependencies import (
check_requirements, DEPENDENCY_LINKS, MissingRequirementError check_requirements, DEPENDENCY_LINKS, MissingRequirementError
@ -53,15 +55,13 @@ from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import ( from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, STATIC_PREFIX, SERVER_KEY_PREFIX, MEDIA_PREFIX, STATIC_PREFIX,
SERVER_KEY_V2_PREFIX, SERVER_KEY_V2_PREFIX,
) )
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
from synapse.rest.client.v1 import ClientV1RestResource
from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse import events from synapse import events
@ -92,11 +92,8 @@ class SynapseHomeServer(HomeServer):
def build_http_client(self): def build_http_client(self):
return MatrixFederationHttpClient(self) return MatrixFederationHttpClient(self)
def build_resource_for_client(self): def build_client_resource(self):
return ClientV1RestResource(self) return ClientRestResource(self)
def build_resource_for_client_v2_alpha(self):
return ClientV2AlphaRestResource(self)
def build_resource_for_federation(self): def build_resource_for_federation(self):
return JsonResource(self) return JsonResource(self)
@ -179,16 +176,15 @@ class SynapseHomeServer(HomeServer):
for res in listener_config["resources"]: for res in listener_config["resources"]:
for name in res["names"]: for name in res["names"]:
if name == "client": if name == "client":
client_resource = self.get_client_resource()
if res["compress"]: if res["compress"]:
client_v1 = gz_wrap(self.get_resource_for_client()) client_resource = gz_wrap(client_resource)
client_v2 = gz_wrap(self.get_resource_for_client_v2_alpha())
else:
client_v1 = self.get_resource_for_client()
client_v2 = self.get_resource_for_client_v2_alpha()
resources.update({ resources.update({
CLIENT_PREFIX: client_v1, "/_matrix/client/api/v1": client_resource,
CLIENT_V2_ALPHA_PREFIX: client_v2, "/_matrix/client/r0": client_resource,
"/_matrix/client/unstable": client_resource,
"/_matrix/client/v2_alpha": client_resource,
}) })
if name == "federation": if name == "federation":
@ -499,13 +495,28 @@ class SynapseRequest(Request):
self.start_time = int(time.time() * 1000) self.start_time = int(time.time() * 1000)
def finished_processing(self): def finished_processing(self):
try:
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count
db_txn_duration = context.db_txn_duration
except:
ru_utime, ru_stime = (0, 0)
db_txn_count, db_txn_duration = (0, 0)
self.site.access_logger.info( self.site.access_logger.info(
"%s - %s - {%s}" "%s - %s - {%s}"
" Processed request: %dms %sB %s \"%s %s %s\" \"%s\"", " Processed request: %dms (%dms, %dms) (%dms/%d)"
" %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(), self.getClientIP(),
self.site.site_tag, self.site.site_tag,
self.authenticated_entity, self.authenticated_entity,
int(time.time() * 1000) - self.start_time, int(time.time() * 1000) - self.start_time,
int(ru_utime * 1000),
int(ru_stime * 1000),
int(db_txn_duration * 1000),
int(db_txn_count),
self.sentLength, self.sentLength,
self.code, self.code,
self.method, self.method,

View File

@ -25,18 +25,29 @@ class ConfigError(Exception):
pass pass
# We split these messages out to allow packages to override with package
# specific instructions.
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\
Please opt in or out of reporting anonymized homeserver usage statistics, by
setting the `report_stats` key in your config file to either True or False.
"""
MISSING_REPORT_STATS_SPIEL = """\
We would really appreciate it if you could help our project out by reporting
anonymized usage statistics from your homeserver. Only very basic aggregate
data (e.g. number of users) will be reported, but it helps us to track the
growth of the Matrix community, and helps us to make Matrix a success, as well
as to convince other networks that they should peer with us.
Thank you.
"""
MISSING_SERVER_NAME = """\
Missing mandatory `server_name` config option.
"""
class Config(object): class Config(object):
stats_reporting_begging_spiel = (
"We would really appreciate it if you could help our project out by"
" reporting anonymized usage statistics from your homeserver. Only very"
" basic aggregate data (e.g. number of users) will be reported, but it"
" helps us to track the growth of the Matrix community, and helps us to"
" make Matrix a success, as well as to convince other networks that they"
" should peer with us."
"\nThank you."
)
@staticmethod @staticmethod
def parse_size(value): def parse_size(value):
if isinstance(value, int) or isinstance(value, long): if isinstance(value, int) or isinstance(value, long):
@ -215,7 +226,7 @@ class Config(object):
if config_args.report_stats is None: if config_args.report_stats is None:
config_parser.error( config_parser.error(
"Please specify either --report-stats=yes or --report-stats=no\n\n" + "Please specify either --report-stats=yes or --report-stats=no\n\n" +
cls.stats_reporting_begging_spiel MISSING_REPORT_STATS_SPIEL
) )
if not config_files: if not config_files:
config_parser.error( config_parser.error(
@ -290,6 +301,10 @@ class Config(object):
yaml_config = cls.read_config_file(config_file) yaml_config = cls.read_config_file(config_file)
specified_config.update(yaml_config) specified_config.update(yaml_config)
if "server_name" not in specified_config:
sys.stderr.write("\n" + MISSING_SERVER_NAME + "\n")
sys.exit(1)
server_name = specified_config["server_name"] server_name = specified_config["server_name"]
_, config = obj.generate_config( _, config = obj.generate_config(
config_dir_path=config_dir_path, config_dir_path=config_dir_path,
@ -299,11 +314,8 @@ class Config(object):
config.update(specified_config) config.update(specified_config)
if "report_stats" not in config: if "report_stats" not in config:
sys.stderr.write( sys.stderr.write(
"Please opt in or out of reporting anonymized homeserver usage " "\n" + MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
"statistics, by setting the report_stats key in your config file " MISSING_REPORT_STATS_SPIEL + "\n")
" ( " + config_path + " ) " +
"to either True or False.\n\n" +
Config.stats_reporting_begging_spiel + "\n")
sys.exit(1) sys.exit(1)
if generate_keys: if generate_keys:

View File

@ -27,10 +27,12 @@ class CasConfig(Config):
if cas_config: if cas_config:
self.cas_enabled = cas_config.get("enabled", True) self.cas_enabled = cas_config.get("enabled", True)
self.cas_server_url = cas_config["server_url"] self.cas_server_url = cas_config["server_url"]
self.cas_service_url = cas_config["service_url"]
self.cas_required_attributes = cas_config.get("required_attributes", {}) self.cas_required_attributes = cas_config.get("required_attributes", {})
else: else:
self.cas_enabled = False self.cas_enabled = False
self.cas_server_url = None self.cas_server_url = None
self.cas_service_url = None
self.cas_required_attributes = {} self.cas_required_attributes = {}
def default_config(self, config_dir_path, server_name, **kwargs): def default_config(self, config_dir_path, server_name, **kwargs):
@ -39,6 +41,7 @@ class CasConfig(Config):
#cas_config: #cas_config:
# enabled: true # enabled: true
# server_url: "https://cas-server.com" # server_url: "https://cas-server.com"
# service_url: "https://homesever.domain.com:8448"
# #required_attributes: # #required_attributes:
# # name: value # # name: value
""" """

View File

@ -133,6 +133,7 @@ class ServerConfig(Config):
# The domain name of the server, with optional explicit port. # The domain name of the server, with optional explicit port.
# This is used by remote servers to connect to this server, # This is used by remote servers to connect to this server,
# e.g. matrix.org, localhost:8080, etc. # e.g. matrix.org, localhost:8080, etc.
# This is also the last part of your UserID.
server_name: "%(server_name)s" server_name: "%(server_name)s"
# When running as a daemon, the file to store the pid in # When running as a daemon, the file to store the pid in

View File

@ -381,11 +381,6 @@ class Keyring(object):
def get_server_verify_key_v2_indirect(self, server_names_and_key_ids, def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name, perspective_name,
perspective_keys): perspective_keys):
limiter = yield get_retry_limiter(
perspective_name, self.clock, self.store
)
with limiter:
# TODO(mark): Set the minimum_valid_until_ts to that needed by # TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating # the events being validated or the current time if validating
# an incoming request. # an incoming request.
@ -402,6 +397,7 @@ class Keyring(object):
for server_name, key_ids in server_names_and_key_ids for server_name, key_ids in server_names_and_key_ids
} }
}, },
long_retries=True,
) )
keys = {} keys = {}

View File

@ -100,22 +100,20 @@ def format_event_raw(d):
def format_event_for_client_v1(d): def format_event_for_client_v1(d):
d["user_id"] = d.pop("sender", None) d = format_event_for_client_v2(d)
move_keys = ( sender = d.get("sender")
if sender is not None:
d["user_id"] = sender
copy_keys = (
"age", "redacted_because", "replaces_state", "prev_content", "age", "redacted_because", "replaces_state", "prev_content",
"invite_room_state", "invite_room_state",
) )
for key in move_keys: for key in copy_keys:
if key in d["unsigned"]: if key in d["unsigned"]:
d[key] = d["unsigned"][key] d[key] = d["unsigned"][key]
drop_keys = (
"auth_events", "prev_events", "hashes", "signatures", "depth",
"unsigned", "origin", "prev_state"
)
for key in drop_keys:
d.pop(key, None)
return d return d
@ -129,10 +127,9 @@ def format_event_for_client_v2(d):
return d return d
def format_event_for_client_v2_without_event_id(d): def format_event_for_client_v2_without_room_id(d):
d = format_event_for_client_v2(d) d = format_event_for_client_v2(d)
d.pop("room_id", None) d.pop("room_id", None)
d.pop("event_id", None)
return d return d

View File

@ -136,6 +136,7 @@ class TransportLayerClient(object):
path=PREFIX + "/send/%s/" % transaction.transaction_id, path=PREFIX + "/send/%s/" % transaction.transaction_id,
data=json_data, data=json_data,
json_data_callback=json_data_callback, json_data_callback=json_data_callback,
long_retries=True,
) )
logger.debug( logger.debug(

View File

@ -165,7 +165,7 @@ class BaseFederationServlet(object):
if code is None: if code is None:
continue continue
server.register_path(method, pattern, self._wrap(code)) server.register_paths(method, (pattern,), self._wrap(code))
class FederationSendServlet(BaseFederationServlet): class FederationSendServlet(BaseFederationServlet):

View File

@ -92,6 +92,14 @@ class BaseHandler(object):
membership_event = state.get((EventTypes.Member, user_id), None) membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event: if membership_event:
was_forgotten_at_event = yield self.store.was_forgotten_at(
membership_event.state_key,
membership_event.room_id,
membership_event.event_id
)
if was_forgotten_at_event:
membership = None
else:
membership = membership_event.membership membership = membership_event.membership
else: else:
membership = None membership = None

View File

@ -16,22 +16,23 @@
from twisted.internet import defer from twisted.internet import defer
class PrivateUserDataEventSource(object): class AccountDataEventSource(object):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
def get_current_key(self, direction='f'): def get_current_key(self, direction='f'):
return self.store.get_max_private_user_data_stream_id() return self.store.get_max_account_data_stream_id()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events(self, user, from_key, **kwargs): def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string() user_id = user.to_string()
last_stream_id = from_key last_stream_id = from_key
current_stream_id = yield self.store.get_max_private_user_data_stream_id() current_stream_id = yield self.store.get_max_account_data_stream_id()
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
results = [] results = []
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items(): for room_id, room_tags in tags.items():
results.append({ results.append({
"type": "m.tag", "type": "m.tag",
@ -39,6 +40,24 @@ class PrivateUserDataEventSource(object):
"room_id": room_id, "room_id": room_id,
}) })
account_data, room_account_data = (
yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
)
for account_data_type, content in account_data.items():
results.append({
"type": account_data_type,
"content": content,
})
for room_id, account_data in room_account_data.items():
for account_data_type, content in account_data.items():
results.append({
"type": account_data_type,
"content": content,
"room_id": room_id,
})
defer.returnValue((results, current_stream_id)) defer.returnValue((results, current_stream_id))
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -30,34 +30,27 @@ class AdminHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_whois(self, user): def get_whois(self, user):
res = yield self.store.get_user_ip_and_agents(user) connections = []
d = {} sessions = yield self.store.get_user_ip_and_agents(user)
for r in res: for session in sessions:
# Note that device_id is always None connections.append({
device = d.setdefault(r["device_id"], {}) "ip": session["ip"],
session = device.setdefault(r["access_token"], []) "last_seen": session["last_seen"],
session.append({ "user_agent": session["user_agent"],
"ip": r["ip"],
"user_agent": r["user_agent"],
"last_seen": r["last_seen"],
}) })
ret = { ret = {
"user_id": user.to_string(), "user_id": user.to_string(),
"devices": [ "devices": {
{ "": {
"device_id": k,
"sessions": [ "sessions": [
{ {
# "access_token": x, TODO (erikj) "connections": connections,
"connections": y,
} }
for x, y in v.items()
] ]
} },
for k, v in d.items() },
],
} }
defer.returnValue(ret) defer.returnValue(ret)

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.types import UserID from synapse.types import UserID
from synapse.api.errors import LoginError, Codes from synapse.api.errors import AuthError, LoginError, Codes
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError from twisted.web.client import PartialDownloadError
@ -46,6 +46,7 @@ class AuthHandler(BaseHandler):
} }
self.bcrypt_rounds = hs.config.bcrypt_rounds self.bcrypt_rounds = hs.config.bcrypt_rounds
self.sessions = {} self.sessions = {}
self.INVALID_TOKEN_HTTP_STATUS = 401
@defer.inlineCallbacks @defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip): def check_auth(self, flows, clientdict, clientip):
@ -297,10 +298,11 @@ class AuthHandler(BaseHandler):
defer.returnValue((user_id, access_token, refresh_token)) defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks @defer.inlineCallbacks
def login_with_cas_user_id(self, user_id): def get_login_tuple_for_user_id(self, user_id):
""" """
Authenticates the user with the given user ID, Gets login tuple for the user with the given user ID.
intended to have been captured from a CAS response The user is assumed to have been authenticated by some other
machanism (e.g. CAS)
Args: Args:
user_id (str): User ID user_id (str): User ID
@ -393,6 +395,23 @@ class AuthHandler(BaseHandler):
)) ))
return m.serialize() return m.serialize()
def generate_short_term_login_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + (2 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token):
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
auth_api = self.hs.get_auth()
auth_api.validate_macaroon(macaroon, "login", True)
return self._get_user_from_macaroon(macaroon)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
def _generate_base_macaroon(self, user_id): def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon( macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name, location=self.hs.config.server_name,
@ -402,6 +421,16 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon return macaroon
def _get_user_from_macaroon(self, macaroon):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix):]
raise AuthError(
self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
errcode=Codes.UNKNOWN_TOKEN
)
@defer.inlineCallbacks @defer.inlineCallbacks
def set_password(self, user_id, newpassword): def set_password(self, user_id, newpassword):
password_hash = self.hash(newpassword) password_hash = self.hash(newpassword)

View File

@ -28,6 +28,18 @@ import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def started_user_eventstream(distributor, user):
return distributor.fire("started_user_eventstream", user)
def stopped_user_eventstream(distributor, user):
return distributor.fire("stopped_user_eventstream", user)
def user_joined_room(distributor, user, room_id):
return distributor.fire("user_joined_room", user, room_id)
class EventStreamHandler(BaseHandler): class EventStreamHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -66,7 +78,7 @@ class EventStreamHandler(BaseHandler):
except: except:
logger.exception("Failed to cancel event timer") logger.exception("Failed to cancel event timer")
else: else:
yield self.distributor.fire("started_user_eventstream", user) yield started_user_eventstream(self.distributor, user)
self._streams_per_user[user] += 1 self._streams_per_user[user] += 1
@ -89,7 +101,7 @@ class EventStreamHandler(BaseHandler):
self._stop_timer_per_user.pop(user, None) self._stop_timer_per_user.pop(user, None)
return self.distributor.fire("stopped_user_eventstream", user) return stopped_user_eventstream(self.distributor, user)
logger.debug("Scheduling _later: for %s", user) logger.debug("Scheduling _later: for %s", user)
self._stop_timer_per_user[user] = ( self._stop_timer_per_user[user] = (
@ -120,9 +132,7 @@ class EventStreamHandler(BaseHandler):
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
if is_guest: if is_guest:
yield self.distributor.fire( yield user_joined_room(self.distributor, auth_user, room_id)
"user_joined_room", user=auth_user, room_id=room_id
)
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout, auth_user, pagin_config, timeout,

View File

@ -44,6 +44,10 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def user_joined_room(distributor, user, room_id):
return distributor.fire("user_joined_room", user, room_id)
class FederationHandler(BaseHandler): class FederationHandler(BaseHandler):
"""Handles events that originated from federation. """Handles events that originated from federation.
Responsible for: Responsible for:
@ -60,10 +64,7 @@ class FederationHandler(BaseHandler):
self.hs = hs self.hs = hs
self.distributor.observe( self.distributor.observe("user_joined_room", self.user_joined_room)
"user_joined_room",
self._on_user_joined
)
self.waiting_for_join_list = {} self.waiting_for_join_list = {}
@ -176,7 +177,7 @@ class FederationHandler(BaseHandler):
) )
try: try:
_, event_stream_id, max_stream_id = yield self._handle_new_event( context, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, origin,
event, event,
state=state, state=state,
@ -233,10 +234,13 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.membership == Membership.JOIN: if event.membership == Membership.JOIN:
prev_state = context.current_state.get((event.type, event.state_key))
if not prev_state or prev_state.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally
# joined the room. Don't bother if the user is just
# changing their profile info.
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield self.distributor.fire( yield user_joined_room(self.distributor, user, event.room_id)
"user_joined_room", user=user, room_id=event.room_id
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events): def _filter_events_for_server(self, server_name, room_id, events):
@ -733,9 +737,7 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN: if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key) user = UserID.from_string(event.state_key)
yield self.distributor.fire( yield user_joined_room(self.distributor, user, event.room_id)
"user_joined_room", user=user, room_id=event.room_id
)
new_pdu = event new_pdu = event
@ -1082,7 +1084,7 @@ class FederationHandler(BaseHandler):
return self.store.get_min_depth(context) return self.store.get_min_depth(context)
@log_function @log_function
def _on_user_joined(self, user, room_id): def user_joined_room(self, user, room_id):
waiters = self.waiting_for_join_list.get( waiters = self.waiting_for_join_list.get(
(user.to_string(), room_id), (user.to_string(), room_id),
[] []

View File

@ -20,7 +20,6 @@ from synapse.api.errors import (
CodeMessageException CodeMessageException
) )
from ._base import BaseHandler from ._base import BaseHandler
from synapse.http.client import SimpleHttpClient
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
@ -35,13 +34,12 @@ class IdentityHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(IdentityHandler, self).__init__(hs) super(IdentityHandler, self).__init__(hs)
self.http_client = hs.get_simple_http_client()
@defer.inlineCallbacks @defer.inlineCallbacks
def threepid_from_creds(self, creds): def threepid_from_creds(self, creds):
yield run_on_reactor() yield run_on_reactor()
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable! # XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090'] # trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org', 'vector.im'] trustedIdServers = ['matrix.org', 'vector.im']
@ -67,7 +65,7 @@ class IdentityHandler(BaseHandler):
data = {} data = {}
try: try:
data = yield http_client.get_json( data = yield self.http_client.get_json(
"https://%s%s" % ( "https://%s%s" % (
id_server, id_server,
"/_matrix/identity/api/v1/3pid/getValidated3pid" "/_matrix/identity/api/v1/3pid/getValidated3pid"
@ -85,7 +83,6 @@ class IdentityHandler(BaseHandler):
def bind_threepid(self, creds, mxid): def bind_threepid(self, creds, mxid):
yield run_on_reactor() yield run_on_reactor()
logger.debug("binding threepid %r to %s", creds, mxid) logger.debug("binding threepid %r to %s", creds, mxid)
http_client = SimpleHttpClient(self.hs)
data = None data = None
if 'id_server' in creds: if 'id_server' in creds:
@ -103,7 +100,7 @@ class IdentityHandler(BaseHandler):
raise SynapseError(400, "No client_secret in creds") raise SynapseError(400, "No client_secret in creds")
try: try:
data = yield http_client.post_urlencoded_get_json( data = yield self.http_client.post_urlencoded_get_json(
"https://%s%s" % ( "https://%s%s" % (
id_server, "/_matrix/identity/api/v1/3pid/bind" id_server, "/_matrix/identity/api/v1/3pid/bind"
), ),
@ -121,7 +118,6 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs): def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
yield run_on_reactor() yield run_on_reactor()
http_client = SimpleHttpClient(self.hs)
params = { params = {
'email': email, 'email': email,
@ -131,7 +127,7 @@ class IdentityHandler(BaseHandler):
params.update(kwargs) params.update(kwargs)
try: try:
data = yield http_client.post_urlencoded_get_json( data = yield self.http_client.post_urlencoded_get_json(
"https://%s%s" % ( "https://%s%s" % (
id_server, id_server,
"/_matrix/identity/api/v1/validate/email/requestToken" "/_matrix/identity/api/v1/validate/email/requestToken"

View File

@ -26,11 +26,17 @@ from synapse.types import UserID, RoomStreamToken, StreamToken
from ._base import BaseHandler from ._base import BaseHandler
from canonicaljson import encode_canonical_json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def collect_presencelike_data(distributor, user, content):
return distributor.fire("collect_presencelike_data", user, content)
class MessageHandler(BaseHandler): class MessageHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -195,10 +201,8 @@ class MessageHandler(BaseHandler):
if membership == Membership.JOIN: if membership == Membership.JOIN:
joinee = UserID.from_string(builder.state_key) joinee = UserID.from_string(builder.state_key)
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
yield self.distributor.fire( yield collect_presencelike_data(
"collect_presencelike_data", self.distributor, joinee, builder.content
joinee,
builder.content
) )
if token_id is not None: if token_id is not None:
@ -211,6 +215,16 @@ class MessageHandler(BaseHandler):
builder=builder, builder=builder,
) )
if event.is_state():
prev_state = context.current_state.get((event.type, event.state_key))
if prev_state and event.user_id == prev_state.user_id:
prev_content = encode_canonical_json(prev_state.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
# Duplicate suppression for state updates with same sender
# and content.
defer.returnValue(prev_state)
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
member_handler = self.hs.get_handlers().room_member_handler member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.change_membership(event, context, is_guest=is_guest) yield member_handler.change_membership(event, context, is_guest=is_guest)
@ -359,6 +373,10 @@ class MessageHandler(BaseHandler):
tags_by_room = yield self.store.get_tags_for_user(user_id) tags_by_room = yield self.store.get_tags_for_user(user_id)
account_data, account_data_by_room = (
yield self.store.get_account_data_for_user(user_id)
)
public_room_ids = yield self.store.get_public_room_ids() public_room_ids = yield self.store.get_public_room_ids()
limit = pagin_config.limit limit = pagin_config.limit
@ -436,14 +454,22 @@ class MessageHandler(BaseHandler):
for c in current_state.values() for c in current_state.values()
] ]
private_user_data = [] account_data_events = []
tags = tags_by_room.get(event.room_id) tags = tags_by_room.get(event.room_id)
if tags: if tags:
private_user_data.append({ account_data_events.append({
"type": "m.tag", "type": "m.tag",
"content": {"tags": tags}, "content": {"tags": tags},
}) })
d["private_user_data"] = private_user_data
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
d["account_data"] = account_data_events
except: except:
logger.exception("Failed to get snapshot") logger.exception("Failed to get snapshot")
@ -456,9 +482,17 @@ class MessageHandler(BaseHandler):
consumeErrors=True consumeErrors=True
).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError)
account_data_events = []
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
ret = { ret = {
"rooms": rooms_ret, "rooms": rooms_ret,
"presence": presence, "presence": presence,
"account_data": account_data_events,
"receipts": receipt, "receipts": receipt,
"end": now_token.to_string(), "end": now_token.to_string(),
} }
@ -498,14 +532,22 @@ class MessageHandler(BaseHandler):
user_id, room_id, pagin_config, membership, member_event_id, is_guest user_id, room_id, pagin_config, membership, member_event_id, is_guest
) )
private_user_data = [] account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id) tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags: if tags:
private_user_data.append({ account_data_events.append({
"type": "m.tag", "type": "m.tag",
"content": {"tags": tags}, "content": {"tags": tags},
}) })
result["private_user_data"] = private_user_data
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
result["account_data"] = account_data_events
defer.returnValue(result) defer.returnValue(result)
@ -588,8 +630,6 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence(): def get_presence():
states = {}
if not is_guest:
states = yield presence_handler.get_states( states = yield presence_handler.get_states(
target_users=[UserID.from_string(m.user_id) for m in room_members], target_users=[UserID.from_string(m.user_id) for m in room_members],
auth_user=auth_user, auth_user=auth_user,
@ -599,12 +639,19 @@ class MessageHandler(BaseHandler):
defer.returnValue(states.values()) defer.returnValue(states.values())
@defer.inlineCallbacks
def get_receipts():
receipts_handler = self.hs.get_handlers().receipts_handler receipts_handler = self.hs.get_handlers().receipts_handler
receipts = yield receipts_handler.get_receipts_for_room(
room_id,
now_token.receipt_key
)
defer.returnValue(receipts)
presence, receipts, (messages, token) = yield defer.gatherResults( presence, receipts, (messages, token) = yield defer.gatherResults(
[ [
get_presence(), get_presence(),
receipts_handler.get_receipts_for_room(room_id, now_token.receipt_key), get_receipts(),
self.store.get_recent_events_for_room( self.store.get_recent_events_for_room(
room_id, room_id,
limit=limit, limit=limit,

View File

@ -62,6 +62,14 @@ def partitionbool(l, func):
return ret.get(True, []), ret.get(False, []) return ret.get(True, []), ret.get(False, [])
def user_presence_changed(distributor, user, statuscache):
return distributor.fire("user_presence_changed", user, statuscache)
def collect_presencelike_data(distributor, user, content):
return distributor.fire("collect_presencelike_data", user, content)
class PresenceHandler(BaseHandler): class PresenceHandler(BaseHandler):
STATE_LEVELS = { STATE_LEVELS = {
@ -361,9 +369,7 @@ class PresenceHandler(BaseHandler):
yield self.store.set_presence_state( yield self.store.set_presence_state(
target_user.localpart, state_to_store target_user.localpart, state_to_store
) )
yield self.distributor.fire( yield collect_presencelike_data(self.distributor, target_user, state)
"collect_presencelike_data", target_user, state
)
if now_level > was_level: if now_level > was_level:
state["last_active"] = self.clock.time_msec() state["last_active"] = self.clock.time_msec()
@ -467,7 +473,7 @@ class PresenceHandler(BaseHandler):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, observer_user, observed_user): def send_presence_invite(self, observer_user, observed_user):
"""Request the presence of a local or remote user for a local user""" """Request the presence of a local or remote user for a local user"""
if not self.hs.is_mine(observer_user): if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
@ -878,7 +884,7 @@ class PresenceHandler(BaseHandler):
room_ids=room_ids, room_ids=room_ids,
statuscache=statuscache, statuscache=statuscache,
) )
yield self.distributor.fire("user_presence_changed", user, statuscache) yield user_presence_changed(self.distributor, user, statuscache)
@defer.inlineCallbacks @defer.inlineCallbacks
def incoming_presence(self, origin, content): def incoming_presence(self, origin, content):
@ -1116,9 +1122,7 @@ class PresenceHandler(BaseHandler):
self._user_cachemap[user].get_state()["last_active"] self._user_cachemap[user].get_state()["last_active"]
) )
yield self.distributor.fire( yield collect_presencelike_data(self.distributor, user, state)
"collect_presencelike_data", user, state
)
if "last_active" in state: if "last_active" in state:
state = dict(state) state = dict(state)

View File

@ -28,6 +28,14 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def changed_presencelike_data(distributor, user, state):
return distributor.fire("changed_presencelike_data", user, state)
def collect_presencelike_data(distributor, user, content):
return distributor.fire("collect_presencelike_data", user, content)
class ProfileHandler(BaseHandler): class ProfileHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -95,11 +103,9 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_displayname target_user.localpart, new_displayname
) )
yield self.distributor.fire( yield changed_presencelike_data(self.distributor, target_user, {
"changed_presencelike_data", target_user, {
"displayname": new_displayname, "displayname": new_displayname,
} })
)
yield self._update_join_states(target_user) yield self._update_join_states(target_user)
@ -144,11 +150,9 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_avatar_url target_user.localpart, new_avatar_url
) )
yield self.distributor.fire( yield changed_presencelike_data(self.distributor, target_user, {
"changed_presencelike_data", target_user, {
"avatar_url": new_avatar_url, "avatar_url": new_avatar_url,
} })
)
yield self._update_join_states(target_user) yield self._update_join_states(target_user)
@ -208,9 +212,7 @@ class ProfileHandler(BaseHandler):
"membership": Membership.JOIN, "membership": Membership.JOIN,
} }
yield self.distributor.fire( yield collect_presencelike_data(self.distributor, user, content)
"collect_presencelike_data", user, content
)
msg_handler = self.hs.get_handlers().message_handler msg_handler = self.hs.get_handlers().message_handler
try: try:

View File

@ -31,6 +31,10 @@ import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def registered_user(distributor, user):
return distributor.fire("registered_user", user)
class RegistrationHandler(BaseHandler): class RegistrationHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
@ -38,6 +42,7 @@ class RegistrationHandler(BaseHandler):
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.distributor.declare("registered_user") self.distributor.declare("registered_user")
self.captch_client = CaptchaServerHttpClient(hs)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart): def check_username(self, localpart):
@ -98,7 +103,7 @@ class RegistrationHandler(BaseHandler):
password_hash=password_hash password_hash=password_hash
) )
yield self.distributor.fire("registered_user", user) yield registered_user(self.distributor, user)
else: else:
# autogen a random user ID # autogen a random user ID
attempts = 0 attempts = 0
@ -117,7 +122,7 @@ class RegistrationHandler(BaseHandler):
token=token, token=token,
password_hash=password_hash) password_hash=password_hash)
self.distributor.fire("registered_user", user) yield registered_user(self.distributor, user)
except SynapseError: except SynapseError:
# if user id is taken, just generate another # if user id is taken, just generate another
user_id = None user_id = None
@ -167,7 +172,7 @@ class RegistrationHandler(BaseHandler):
token=token, token=token,
password_hash="" password_hash=""
) )
self.distributor.fire("registered_user", user) registered_user(self.distributor, user)
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
@defer.inlineCallbacks @defer.inlineCallbacks
@ -215,7 +220,7 @@ class RegistrationHandler(BaseHandler):
token=token, token=token,
password_hash=None password_hash=None
) )
yield self.distributor.fire("registered_user", user) yield registered_user(self.distributor, user)
except Exception, e: except Exception, e:
yield self.store.add_access_token_to_user(user_id, token) yield self.store.add_access_token_to_user(user_id, token)
# Ignore Registration errors # Ignore Registration errors
@ -302,10 +307,7 @@ class RegistrationHandler(BaseHandler):
""" """
Used only by c/s api v1 Used only by c/s api v1
""" """
# TODO: get this from the homeserver rather than creating a new one for data = yield self.captcha_client.post_urlencoded_get_raw(
# each request
client = CaptchaServerHttpClient(self.hs)
data = yield client.post_urlencoded_get_raw(
"http://www.google.com:80/recaptcha/api/verify", "http://www.google.com:80/recaptcha/api/verify",
args={ args={
'privatekey': private_key, 'privatekey': private_key,

View File

@ -41,6 +41,18 @@ logger = logging.getLogger(__name__)
id_server_scheme = "https://" id_server_scheme = "https://"
def collect_presencelike_data(distributor, user, content):
return distributor.fire("collect_presencelike_data", user, content)
def user_left_room(distributor, user, room_id):
return distributor.fire("user_left_room", user=user, room_id=room_id)
def user_joined_room(distributor, user, room_id):
return distributor.fire("user_joined_room", user=user, room_id=room_id)
class RoomCreationHandler(BaseHandler): class RoomCreationHandler(BaseHandler):
PRESETS_DICT = { PRESETS_DICT = {
@ -438,9 +450,7 @@ class RoomMemberHandler(BaseHandler):
if prev_state and prev_state.membership == Membership.JOIN: if prev_state and prev_state.membership == Membership.JOIN:
user = UserID.from_string(event.user_id) user = UserID.from_string(event.user_id)
self.distributor.fire( user_left_room(self.distributor, user, event.room_id)
"user_left_room", user=user, room_id=event.room_id
)
defer.returnValue({"room_id": room_id}) defer.returnValue({"room_id": room_id})
@ -458,9 +468,7 @@ class RoomMemberHandler(BaseHandler):
raise SynapseError(404, "No known servers") raise SynapseError(404, "No known servers")
# If event doesn't include a display name, add one. # If event doesn't include a display name, add one.
yield self.distributor.fire( yield collect_presencelike_data(self.distributor, joinee, content)
"collect_presencelike_data", joinee, content
)
content.update({"membership": Membership.JOIN}) content.update({"membership": Membership.JOIN})
builder = self.event_builder_factory.new({ builder = self.event_builder_factory.new({
@ -517,10 +525,13 @@ class RoomMemberHandler(BaseHandler):
do_auth=do_auth, do_auth=do_auth,
) )
prev_state = context.current_state.get((event.type, event.state_key))
if not prev_state or prev_state.membership != Membership.JOIN:
# Only fire user_joined_room if the user has acutally joined the
# room. Don't bother if the user is just changing their profile
# info.
user = UserID.from_string(event.user_id) user = UserID.from_string(event.user_id)
yield self.distributor.fire( yield user_joined_room(self.distributor, user, room_id)
"user_joined_room", user=user, room_id=room_id
)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_inviter(self, event): def get_inviter(self, event):
@ -743,6 +754,9 @@ class RoomMemberHandler(BaseHandler):
) )
defer.returnValue((token, public_key, key_validity_url, display_name)) defer.returnValue((token, public_key, key_validity_url, display_name))
def forget(self, user, room_id):
self.store.forget(user.to_string(), room_id)
class RoomListHandler(BaseHandler): class RoomListHandler(BaseHandler):

View File

@ -17,13 +17,14 @@ from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
from synapse.api.constants import Membership from synapse.api.constants import Membership, EventTypes
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from unpaddedbase64 import decode_base64, encode_base64 from unpaddedbase64 import decode_base64, encode_base64
import itertools
import logging import logging
@ -79,6 +80,9 @@ class SearchHandler(BaseHandler):
# What to order results by (impacts whether pagination can be doen) # What to order results by (impacts whether pagination can be doen)
order_by = room_cat.get("order_by", "rank") order_by = room_cat.get("order_by", "rank")
# Return the current state of the rooms?
include_state = room_cat.get("include_state", False)
# Include context around each event? # Include context around each event?
event_context = room_cat.get( event_context = room_cat.get(
"event_context", None "event_context", None
@ -96,6 +100,10 @@ class SearchHandler(BaseHandler):
after_limit = int(event_context.get( after_limit = int(event_context.get(
"after_limit", 5 "after_limit", 5
)) ))
# Return the historic display name and avatar for the senders
# of the events?
include_profile = bool(event_context.get("include_profile", False))
except KeyError: except KeyError:
raise SynapseError(400, "Invalid search query") raise SynapseError(400, "Invalid search query")
@ -123,6 +131,17 @@ class SearchHandler(BaseHandler):
if batch_group == "room_id": if batch_group == "room_id":
room_ids.intersection_update({batch_group_key}) room_ids.intersection_update({batch_group_key})
if not room_ids:
defer.returnValue({
"search_categories": {
"room_events": {
"results": [],
"count": 0,
"highlights": [],
}
}
})
rank_map = {} # event_id -> rank of event rank_map = {} # event_id -> rank of event
allowed_events = [] allowed_events = []
room_groups = {} # Holds result of grouping by room, if applicable room_groups = {} # Holds result of grouping by room, if applicable
@ -131,11 +150,18 @@ class SearchHandler(BaseHandler):
# Holds the next_batch for the entire result set if one of those exists # Holds the next_batch for the entire result set if one of those exists
global_next_batch = None global_next_batch = None
highlights = set()
if order_by == "rank": if order_by == "rank":
results = yield self.store.search_msgs( search_result = yield self.store.search_msgs(
room_ids, search_term, keys room_ids, search_term, keys
) )
if search_result["highlights"]:
highlights.update(search_result["highlights"])
results = search_result["results"]
results_map = {r["event"].event_id: r for r in results} results_map = {r["event"].event_id: r for r in results}
rank_map.update({r["event"].event_id: r["rank"] for r in results}) rank_map.update({r["event"].event_id: r["rank"] for r in results})
@ -163,27 +189,26 @@ class SearchHandler(BaseHandler):
s["results"].append(e.event_id) s["results"].append(e.event_id)
elif order_by == "recent": elif order_by == "recent":
# In this case we specifically loop through each room as the given
# limit applies to each room, rather than a global list.
# This is not necessarilly a good idea.
for room_id in room_ids:
room_events = [] room_events = []
if batch_group == "room_id" and batch_group_key == room_id:
pagination_token = batch_token
else:
pagination_token = None
i = 0 i = 0
pagination_token = batch_token
# We keep looping and we keep filtering until we reach the limit # We keep looping and we keep filtering until we reach the limit
# or we run out of things. # or we run out of things.
# But only go around 5 times since otherwise synapse will be sad. # But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5: while len(room_events) < search_filter.limit() and i < 5:
i += 1 i += 1
results = yield self.store.search_room( search_result = yield self.store.search_rooms(
room_id, search_term, keys, search_filter.limit() * 2, room_ids, search_term, keys, search_filter.limit() * 2,
pagination_token=pagination_token, pagination_token=pagination_token,
) )
if search_result["highlights"]:
highlights.update(search_result["highlights"])
results = search_result["results"]
results_map = {r["event"].event_id: r for r in results} results_map = {r["event"].event_id: r for r in results}
rank_map.update({r["event"].event_id: r["rank"] for r in results}) rank_map.update({r["event"].event_id: r["rank"] for r in results})
@ -205,39 +230,36 @@ class SearchHandler(BaseHandler):
else: else:
pagination_token = results[-1]["pagination_token"] pagination_token = results[-1]["pagination_token"]
if room_events: for event in room_events:
res = results_map[room_events[-1].event_id] group = room_groups.setdefault(event.room_id, {
pagination_token = res["pagination_token"] "results": [],
})
group["results"].append(event.event_id)
group = room_groups.setdefault(room_id, {}) if room_events and len(room_events) >= search_filter.limit():
if pagination_token: last_event_id = room_events[-1].event_id
next_batch = encode_base64("%s\n%s\n%s" % ( pagination_token = results_map[last_event_id]["pagination_token"]
# We want to respect the given batch group and group keys so
# that if people blindly use the top level `next_batch` token
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key:
global_next_batch = encode_base64("%s\n%s\n%s" % (
batch_group, batch_group_key, pagination_token
))
else:
global_next_batch = encode_base64("%s\n%s\n%s" % (
"all", "", pagination_token
))
for room_id, group in room_groups.items():
group["next_batch"] = encode_base64("%s\n%s\n%s" % (
"room_id", room_id, pagination_token "room_id", room_id, pagination_token
)) ))
group["next_batch"] = next_batch
if batch_token:
global_next_batch = next_batch
group["results"] = [e.event_id for e in room_events]
group["order"] = max(
e.origin_server_ts/1000 for e in room_events
if hasattr(e, "origin_server_ts")
)
allowed_events.extend(room_events) allowed_events.extend(room_events)
# Normalize the group orders
if room_groups:
if len(room_groups) > 1:
mx = max(g["order"] for g in room_groups.values())
mn = min(g["order"] for g in room_groups.values())
for g in room_groups.values():
g["order"] = (g["order"] - mn) * 1.0 / (mx - mn)
else:
room_groups.values()[0]["order"] = 1
else: else:
# We should never get here due to the guard earlier. # We should never get here due to the guard earlier.
raise NotImplementedError() raise NotImplementedError()
@ -269,6 +291,33 @@ class SearchHandler(BaseHandler):
"room_key", res["end"] "room_key", res["end"]
).to_string() ).to_string()
if include_profile:
senders = set(
ev.sender
for ev in itertools.chain(
res["events_before"], [event], res["events_after"]
)
)
if res["events_after"]:
last_event_id = res["events_after"][-1].event_id
else:
last_event_id = event.event_id
state = yield self.store.get_state_for_event(
last_event_id,
types=[(EventTypes.Member, sender) for sender in senders]
)
res["profile_info"] = {
s.state_key: {
"displayname": s.content.get("displayname", None),
"avatar_url": s.content.get("avatar_url", None),
}
for s in state.values()
if s.type == EventTypes.Member and s.state_key in senders
}
contexts[event.event_id] = res contexts[event.event_id] = res
else: else:
contexts = {} contexts = {}
@ -287,20 +336,37 @@ class SearchHandler(BaseHandler):
for e in context["events_after"] for e in context["events_after"]
] ]
results = { state_results = {}
e.event_id: { if include_state:
rooms = set(e.room_id for e in allowed_events)
for room_id in rooms:
state = yield self.state_handler.get_current_state(room_id)
state_results[room_id] = state.values()
state_results.values()
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise the 'age' will be wrong
results = [
{
"rank": rank_map[e.event_id], "rank": rank_map[e.event_id],
"result": serialize_event(e, time_now), "result": serialize_event(e, time_now),
"context": contexts.get(e.event_id, {}), "context": contexts.get(e.event_id, {}),
} }
for e in allowed_events for e in allowed_events
} ]
logger.info("Found %d results", len(results))
rooms_cat_res = { rooms_cat_res = {
"results": results, "results": results,
"count": len(results) "count": len(results),
"highlights": list(highlights),
}
if state_results:
rooms_cat_res["state"] = {
room_id: [serialize_event(e, time_now) for e in state]
for room_id, state in state_results.items()
} }
if room_groups and "room_id" in group_keys: if room_groups and "room_id" in group_keys:

View File

@ -51,7 +51,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
"timeline", # TimelineBatch "timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent] "state", # dict[(str, str), FrozenEvent]
"ephemeral", "ephemeral",
"private_user_data", "account_data",
])): ])):
__slots__ = [] __slots__ = []
@ -63,7 +63,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
self.timeline self.timeline
or self.state or self.state
or self.ephemeral or self.ephemeral
or self.private_user_data or self.account_data
) )
@ -71,7 +71,7 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
"room_id", # str "room_id", # str
"timeline", # TimelineBatch "timeline", # TimelineBatch
"state", # dict[(str, str), FrozenEvent] "state", # dict[(str, str), FrozenEvent]
"private_user_data", "account_data",
])): ])):
__slots__ = [] __slots__ = []
@ -82,7 +82,7 @@ class ArchivedSyncResult(collections.namedtuple("JoinedSyncResult", [
return bool( return bool(
self.timeline self.timeline
or self.state or self.state
or self.private_user_data or self.account_data
) )
@ -100,6 +100,7 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
class SyncResult(collections.namedtuple("SyncResult", [ class SyncResult(collections.namedtuple("SyncResult", [
"next_batch", # Token for the next sync "next_batch", # Token for the next sync
"presence", # List of presence events for the user. "presence", # List of presence events for the user.
"account_data", # List of account_data events for the user.
"joined", # JoinedSyncResult for each joined room. "joined", # JoinedSyncResult for each joined room.
"invited", # InvitedSyncResult for each invited room. "invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room. "archived", # ArchivedSyncResult for each archived room.
@ -185,13 +186,19 @@ class SyncHandler(BaseHandler):
pagination_config=pagination_config.get_source_config("presence"), pagination_config=pagination_config.get_source_config("presence"),
key=None key=None
) )
membership_list = (Membership.INVITE, Membership.JOIN)
if sync_config.filter.include_leave:
membership_list += (Membership.LEAVE, Membership.BAN)
room_list = yield self.store.get_rooms_for_user_where_membership_is( room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=sync_config.user.to_string(), user_id=sync_config.user.to_string(),
membership_list=( membership_list=membership_list
Membership.INVITE, )
Membership.JOIN,
Membership.LEAVE, account_data, account_data_by_room = (
Membership.BAN yield self.store.get_account_data_for_user(
sync_config.user.to_string()
) )
) )
@ -211,6 +218,7 @@ class SyncHandler(BaseHandler):
timeline_since_token=timeline_since_token, timeline_since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room, ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room, tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
) )
joined.append(room_sync) joined.append(room_sync)
elif event.membership == Membership.INVITE: elif event.membership == Membership.INVITE:
@ -230,11 +238,13 @@ class SyncHandler(BaseHandler):
leave_token=leave_token, leave_token=leave_token,
timeline_since_token=timeline_since_token, timeline_since_token=timeline_since_token,
tags_by_room=tags_by_room, tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
) )
archived.append(room_sync) archived.append(room_sync)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
account_data=self.account_data_for_user(account_data),
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived, archived=archived,
@ -244,7 +254,8 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def full_state_sync_for_joined_room(self, room_id, sync_config, def full_state_sync_for_joined_room(self, room_id, sync_config,
now_token, timeline_since_token, now_token, timeline_since_token,
ephemeral_by_room, tags_by_room): ephemeral_by_room, tags_by_room,
account_data_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred JoinedSyncResult.
@ -261,20 +272,39 @@ class SyncHandler(BaseHandler):
timeline=batch, timeline=batch,
state=current_state, state=current_state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room room_id, tags_by_room, account_data_by_room
), ),
)) ))
def private_user_data_for_room(self, room_id, tags_by_room): def account_data_for_user(self, account_data):
private_user_data = [] account_data_events = []
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
return account_data_events
def account_data_for_room(self, room_id, tags_by_room, account_data_by_room):
account_data_events = []
tags = tags_by_room.get(room_id) tags = tags_by_room.get(room_id)
if tags is not None: if tags is not None:
private_user_data.append({ account_data_events.append({
"type": "m.tag", "type": "m.tag",
"content": {"tags": tags}, "content": {"tags": tags},
}) })
return private_user_data
account_data = account_data_by_room.get(room_id, {})
for account_data_type, content in account_data.items():
account_data_events.append({
"type": account_data_type,
"content": content,
})
return account_data_events
@defer.inlineCallbacks @defer.inlineCallbacks
def ephemeral_by_room(self, sync_config, now_token, since_token=None): def ephemeral_by_room(self, sync_config, now_token, since_token=None):
@ -341,7 +371,8 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def full_state_sync_for_archived_room(self, room_id, sync_config, def full_state_sync_for_archived_room(self, room_id, sync_config,
leave_event_id, leave_token, leave_event_id, leave_token,
timeline_since_token, tags_by_room): timeline_since_token, tags_by_room,
account_data_by_room):
"""Sync a room for a client which is starting without any state """Sync a room for a client which is starting without any state
Returns: Returns:
A Deferred JoinedSyncResult. A Deferred JoinedSyncResult.
@ -357,8 +388,8 @@ class SyncHandler(BaseHandler):
room_id=room_id, room_id=room_id,
timeline=batch, timeline=batch,
state=leave_state, state=leave_state,
private_user_data=self.private_user_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room room_id, tags_by_room, account_data_by_room
), ),
)) ))
@ -412,7 +443,14 @@ class SyncHandler(BaseHandler):
tags_by_room = yield self.store.get_updated_tags( tags_by_room = yield self.store.get_updated_tags(
sync_config.user.to_string(), sync_config.user.to_string(),
since_token.private_user_data_key, since_token.account_data_key,
)
account_data, account_data_by_room = (
yield self.store.get_updated_account_data_for_user(
sync_config.user.to_string(),
since_token.account_data_key,
)
) )
joined = [] joined = []
@ -468,8 +506,8 @@ class SyncHandler(BaseHandler):
), ),
state=state, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room room_id, tags_by_room, account_data_by_room
), ),
) )
logger.debug("Result for room %s: %r", room_id, room_sync) logger.debug("Result for room %s: %r", room_id, room_sync)
@ -492,14 +530,15 @@ class SyncHandler(BaseHandler):
for room_id in joined_room_ids: for room_id in joined_room_ids:
room_sync = yield self.incremental_sync_with_gap_for_room( room_sync = yield self.incremental_sync_with_gap_for_room(
room_id, sync_config, since_token, now_token, room_id, sync_config, since_token, now_token,
ephemeral_by_room, tags_by_room ephemeral_by_room, tags_by_room, account_data_by_room
) )
if room_sync: if room_sync:
joined.append(room_sync) joined.append(room_sync)
for leave_event in leave_events: for leave_event in leave_events:
room_sync = yield self.incremental_sync_for_archived_room( room_sync = yield self.incremental_sync_for_archived_room(
sync_config, leave_event, since_token, tags_by_room sync_config, leave_event, since_token, tags_by_room,
account_data_by_room
) )
archived.append(room_sync) archived.append(room_sync)
@ -510,6 +549,7 @@ class SyncHandler(BaseHandler):
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=presence, presence=presence,
account_data=self.account_data_for_user(account_data),
joined=joined, joined=joined,
invited=invited, invited=invited,
archived=archived, archived=archived,
@ -566,7 +606,8 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_with_gap_for_room(self, room_id, sync_config, def incremental_sync_with_gap_for_room(self, room_id, sync_config,
since_token, now_token, since_token, now_token,
ephemeral_by_room, tags_by_room): ephemeral_by_room, tags_by_room,
account_data_by_room):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the room. Gives the client the most recent events and the changes to the room. Gives the client the most recent events and the changes to
state. state.
@ -605,8 +646,8 @@ class SyncHandler(BaseHandler):
timeline=batch, timeline=batch,
state=state, state=state,
ephemeral=ephemeral_by_room.get(room_id, []), ephemeral=ephemeral_by_room.get(room_id, []),
private_user_data=self.private_user_data_for_room( account_data=self.account_data_for_room(
room_id, tags_by_room room_id, tags_by_room, account_data_by_room
), ),
) )
@ -616,7 +657,8 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def incremental_sync_for_archived_room(self, sync_config, leave_event, def incremental_sync_for_archived_room(self, sync_config, leave_event,
since_token, tags_by_room): since_token, tags_by_room,
account_data_by_room):
""" Get the incremental delta needed to bring the client up to date for """ Get the incremental delta needed to bring the client up to date for
the archived room. the archived room.
Returns: Returns:
@ -653,8 +695,8 @@ class SyncHandler(BaseHandler):
room_id=leave_event.room_id, room_id=leave_event.room_id,
timeline=batch, timeline=batch,
state=state_events_delta, state=state_events_delta,
private_user_data=self.private_user_data_for_room( account_data=self.account_data_for_room(
leave_event.room_id, tags_by_room leave_event.room_id, tags_by_room, account_data_by_room
), ),
) )

View File

@ -56,7 +56,8 @@ incoming_responses_counter = metrics.register_counter(
) )
MAX_RETRIES = 4 MAX_LONG_RETRIES = 10
MAX_SHORT_RETRIES = 3
class MatrixFederationEndpointFactory(object): class MatrixFederationEndpointFactory(object):
@ -103,7 +104,7 @@ class MatrixFederationHttpClient(object):
def _create_request(self, destination, method, path_bytes, def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"", body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True, query_bytes=b"", retry_on_dns_fail=True,
timeout=None): timeout=None, long_retries=False):
""" Creates and sends a request to the given url """ Creates and sends a request to the given url
""" """
headers_dict[b"User-Agent"] = [self.version_string] headers_dict[b"User-Agent"] = [self.version_string]
@ -123,7 +124,10 @@ class MatrixFederationHttpClient(object):
# XXX: Would be much nicer to retry only at the transaction-layer # XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place) # (once we have reliable transactions in place)
retries_left = MAX_RETRIES if long_retries:
retries_left = MAX_LONG_RETRIES
else:
retries_left = MAX_SHORT_RETRIES
http_url_bytes = urlparse.urlunparse( http_url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "") ("", "", path_bytes, param_bytes, query_bytes, "")
@ -184,8 +188,15 @@ class MatrixFederationHttpClient(object):
) )
if retries_left and not timeout: if retries_left and not timeout:
delay = 5 ** (MAX_RETRIES + 1 - retries_left) if long_retries:
delay = 4 ** (MAX_LONG_RETRIES + 1 - retries_left)
delay = min(delay, 60)
delay *= random.uniform(0.8, 1.4) delay *= random.uniform(0.8, 1.4)
else:
delay = 0.5 * 2 ** (MAX_SHORT_RETRIES - retries_left)
delay = min(delay, 2)
delay *= random.uniform(0.8, 1.4)
yield sleep(delay) yield sleep(delay)
retries_left -= 1 retries_left -= 1
else: else:
@ -236,7 +247,8 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks @defer.inlineCallbacks
def put_json(self, destination, path, data={}, json_data_callback=None): def put_json(self, destination, path, data={}, json_data_callback=None,
long_retries=False):
""" Sends the specifed json data using PUT """ Sends the specifed json data using PUT
Args: Args:
@ -247,6 +259,8 @@ class MatrixFederationHttpClient(object):
the request body. This will be encoded as JSON. the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to json_data_callback (callable): A callable returning the dict to
use as the request body. use as the request body.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time.
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
@ -272,6 +286,7 @@ class MatrixFederationHttpClient(object):
path.encode("ascii"), path.encode("ascii"),
body_callback=body_callback, body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
long_retries=long_retries,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
@ -287,7 +302,7 @@ class MatrixFederationHttpClient(object):
defer.returnValue(json.loads(body)) defer.returnValue(json.loads(body))
@defer.inlineCallbacks @defer.inlineCallbacks
def post_json(self, destination, path, data={}): def post_json(self, destination, path, data={}, long_retries=True):
""" Sends the specifed json data using POST """ Sends the specifed json data using POST
Args: Args:
@ -296,6 +311,8 @@ class MatrixFederationHttpClient(object):
path (str): The HTTP path. path (str): The HTTP path.
data (dict): A dict containing the data that will be used as data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON. the request body. This will be encoded as JSON.
long_retries (bool): A boolean that indicates whether we should
retry for a short or long time.
Returns: Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result Deferred: Succeeds when we get a 2xx HTTP response. The result
@ -315,6 +332,7 @@ class MatrixFederationHttpClient(object):
path.encode("ascii"), path.encode("ascii"),
body_callback=body_callback, body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]}, headers_dict={"Content-Type": ["application/json"]},
long_retries=True,
) )
if 200 <= response.code < 300: if 200 <= response.code < 300:
@ -490,6 +508,9 @@ class _JsonProducer(object):
def stopProducing(self): def stopProducing(self):
pass pass
def resumeProducing(self):
pass
def _flatten_response_never_received(e): def _flatten_response_never_received(e):
if hasattr(e, "reasons"): if hasattr(e, "reasons"):

View File

@ -53,6 +53,23 @@ response_timer = metrics.register_distribution(
labels=["method", "servlet"] labels=["method", "servlet"]
) )
response_ru_utime = metrics.register_distribution(
"response_ru_utime", labels=["method", "servlet"]
)
response_ru_stime = metrics.register_distribution(
"response_ru_stime", labels=["method", "servlet"]
)
response_db_txn_count = metrics.register_distribution(
"response_db_txn_count", labels=["method", "servlet"]
)
response_db_txn_duration = metrics.register_distribution(
"response_db_txn_duration", labels=["method", "servlet"]
)
_next_request_id = 0 _next_request_id = 0
@ -120,7 +137,7 @@ class HttpServer(object):
""" Interface for registering callbacks on a HTTP server """ Interface for registering callbacks on a HTTP server
""" """
def register_path(self, method, path_pattern, callback): def register_paths(self, method, path_patterns, callback):
""" Register a callback that gets fired if we receive a http request """ Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex. with the given method for a path that matches the given regex.
@ -129,7 +146,7 @@ class HttpServer(object):
Args: Args:
method (str): The method to listen to. method (str): The method to listen to.
path_pattern (str): The regex used to match requests. path_patterns (list<SRE_Pattern>): The regex used to match requests.
callback (function): The function to fire if we receive a matched callback (function): The function to fire if we receive a matched
request. The first argument will be the request object and request. The first argument will be the request object and
subsequent arguments will be any matched groups from the regex. subsequent arguments will be any matched groups from the regex.
@ -165,7 +182,8 @@ class JsonResource(HttpServer, resource.Resource):
self.version_string = hs.version_string self.version_string = hs.version_string
self.hs = hs self.hs = hs
def register_path(self, method, path_pattern, callback): def register_paths(self, method, path_patterns, callback):
for path_pattern in path_patterns:
self.path_regexs.setdefault(method, []).append( self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback) self._PathEntry(path_pattern, callback)
) )
@ -220,6 +238,21 @@ class JsonResource(HttpServer, resource.Resource):
self.clock.time_msec() - start, request.method, servlet_classname self.clock.time_msec() - start, request.method, servlet_classname
) )
try:
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
response_ru_utime.inc_by(ru_utime, request.method, servlet_classname)
response_ru_stime.inc_by(ru_stime, request.method, servlet_classname)
response_db_txn_count.inc_by(
context.db_txn_count, request.method, servlet_classname
)
response_db_txn_duration.inc_by(
context.db_txn_duration, request.method, servlet_classname
)
except:
pass
return return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400. # Huh. No one wanted to handle that? Fiiiiiine. Send 400.

View File

@ -19,7 +19,6 @@ from synapse.api.errors import SynapseError
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -102,12 +101,13 @@ class RestServlet(object):
def register(self, http_server): def register(self, http_server):
""" Register this servlet with the given HTTP server. """ """ Register this servlet with the given HTTP server. """
if hasattr(self, "PATTERN"): if hasattr(self, "PATTERNS"):
pattern = self.PATTERN patterns = self.PATTERNS
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method,)): if hasattr(self, "on_%s" % (method,)):
method_handler = getattr(self, "on_%s" % (method,)) method_handler = getattr(self, "on_%s" % (method,))
http_server.register_path(method, pattern, method_handler) http_server.register_paths(method, patterns, method_handler)
else: else:
raise NotImplementedError("RestServlet must register something.") raise NotImplementedError("RestServlet must register something.")

View File

@ -16,14 +16,12 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID from synapse.types import StreamToken
import synapse.util.async import synapse.util.async
import baserules import push_rule_evaluator as push_rule_evaluator
import logging import logging
import simplejson as json
import re
import random import random
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,9 +31,6 @@ class Pusher(object):
INITIAL_BACKOFF = 1000 INITIAL_BACKOFF = 1000
MAX_BACKOFF = 60 * 60 * 1000 MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000 GIVE_UP_AFTER = 24 * 60 * 60 * 1000
DEFAULT_ACTIONS = ['dont_notify']
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def __init__(self, _hs, profile_tag, user_name, app_id, def __init__(self, _hs, profile_tag, user_name, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts, app_display_name, device_display_name, pushkey, pushkey_ts,
@ -62,161 +57,6 @@ class Pusher(object):
self.last_last_active_time = 0 self.last_last_active_time = 0
self.has_unread = True self.has_unread = True
@defer.inlineCallbacks
def _actions_for_event(self, ev):
"""
This should take into account notification settings that the user
has configured both globally and per-room when we have the ability
to do such things.
"""
if ev['user_id'] == self.user_name:
# let's assume you probably know about messages you sent yourself
defer.returnValue(['dont_notify'])
rawrules = yield self.store.get_push_rules_for_user(self.user_name)
rules = []
for rawrule in rawrules:
rule = dict(rawrule)
rule['conditions'] = json.loads(rawrule['conditions'])
rule['actions'] = json.loads(rawrule['actions'])
rules.append(rule)
enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name)
user = UserID.from_string(self.user_name)
rules = baserules.list_with_base_rules(rules, user)
room_id = ev['room_id']
# get *our* member event for display name matching
my_display_name = None
our_member_event = yield self.store.get_current_state(
room_id=room_id,
event_type='m.room.member',
state_key=self.user_name,
)
if our_member_event:
my_display_name = our_member_event[0].content.get("displayname")
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
for r in rules:
if r['rule_id'] in enabled_map:
r['enabled'] = enabled_map[r['rule_id']]
elif 'enabled' not in r:
r['enabled'] = True
if not r['enabled']:
continue
matches = True
conditions = r['conditions']
actions = r['actions']
for c in conditions:
matches &= self._event_fulfills_condition(
ev, c, display_name=my_display_name,
room_member_count=room_member_count
)
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
# ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0:
logger.warn(
"Ignoring rule id %s with no actions for user %s",
r['rule_id'], self.user_name
)
continue
if matches:
logger.info(
"%s matches for user %s, event %s",
r['rule_id'], self.user_name, ev['event_id']
)
defer.returnValue(actions)
logger.info(
"No rules match for user %s, event %s",
self.user_name, ev['event_id']
)
defer.returnValue(Pusher.DEFAULT_ACTIONS)
@staticmethod
def _glob_to_regexp(glob):
r = re.escape(glob)
r = re.sub(r'\\\*', r'.*?', r)
r = re.sub(r'\\\?', r'.', r)
# handle [abc], [a-z] and [!a-z] style ranges.
r = re.sub(r'\\\[(\\\!|)(.*)\\\]',
lambda x: ('[%s%s]' % (x.group(1) and '^' or '',
re.sub(r'\\\-', '-', x.group(2)))), r)
return r
def _event_fulfills_condition(self, ev, condition, display_name, room_member_count):
if condition['kind'] == 'event_match':
if 'pattern' not in condition:
logger.warn("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
r = r'\b%s\b' % self._glob_to_regexp(condition['pattern'])
else:
r = r'^%s$' % self._glob_to_regexp(condition['pattern'])
val = _value_for_dotted_key(condition['key'], ev)
if val is None:
return False
return re.search(r, val, flags=re.IGNORECASE) is not None
elif condition['kind'] == 'device':
if 'profile_tag' not in condition:
return True
return condition['profile_tag'] == self.profile_tag
elif condition['kind'] == 'contains_display_name':
# This is special because display names can be different
# between rooms and so you can't really hard code it in a rule.
# Optimisation: we should cache these names and update them from
# the event stream.
if 'content' not in ev or 'body' not in ev['content']:
return False
if not display_name:
return False
return re.search(
r"\b%s\b" % re.escape(display_name), ev['content']['body'],
flags=re.IGNORECASE
) is not None
elif condition['kind'] == 'room_member_count':
if 'is' not in condition:
return False
m = Pusher.INEQUALITY_EXPR.match(condition['is'])
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs = int(rhs)
if ineq == '' or ineq == '==':
return room_member_count == rhs
elif ineq == '<':
return room_member_count < rhs
elif ineq == '>':
return room_member_count > rhs
elif ineq == '>=':
return room_member_count >= rhs
elif ineq == '<=':
return room_member_count <= rhs
else:
return False
else:
return True
@defer.inlineCallbacks @defer.inlineCallbacks
def get_context_for_event(self, ev): def get_context_for_event(self, ev):
name_aliases = yield self.store.get_room_name_and_aliases( name_aliases = yield self.store.get_room_name_and_aliases(
@ -308,8 +148,14 @@ class Pusher(object):
return return
processed = False processed = False
actions = yield self._actions_for_event(single_event)
tweaks = _tweaks_for_actions(actions) rule_evaluator = yield \
push_rule_evaluator.evaluator_for_user_name_and_profile_tag(
self.user_name, self.profile_tag, single_event['room_id'], self.store
)
actions = yield rule_evaluator.actions_for_event(single_event)
tweaks = rule_evaluator.tweaks_for_actions(actions)
if len(actions) == 0: if len(actions) == 0:
logger.warn("Empty actions! Using default action.") logger.warn("Empty actions! Using default action.")
@ -448,27 +294,6 @@ class Pusher(object):
self.has_unread = False self.has_unread = False
def _value_for_dotted_key(dotted_key, event):
parts = dotted_key.split(".")
val = event
while len(parts) > 0:
if parts[0] not in val:
return None
val = val[parts[0]]
parts = parts[1:]
return val
def _tweaks_for_actions(actions):
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if 'set_tweak' in a and 'value' in a:
tweaks[a['set_tweak']] = a['value']
return tweaks
class PusherConfigException(Exception): class PusherConfigException(Exception):
def __init__(self, msg): def __init__(self, msg):
super(PusherConfigException, self).__init__(msg) super(PusherConfigException, self).__init__(msg)

View File

@ -247,6 +247,7 @@ def make_base_append_underride_rules(user):
}, },
{ {
'rule_id': 'global/underride/.m.rule.message', 'rule_id': 'global/underride/.m.rule.message',
'enabled': False,
'conditions': [ 'conditions': [
{ {
'kind': 'event_match', 'kind': 'event_match',

View File

@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
from synapse.push import Pusher, PusherConfigException from synapse.push import Pusher, PusherConfigException
from synapse.http.client import SimpleHttpClient
from twisted.internet import defer from twisted.internet import defer
@ -46,7 +45,7 @@ class HttpPusher(Pusher):
"'url' required in data for HTTP pusher" "'url' required in data for HTTP pusher"
) )
self.url = data['url'] self.url = data['url']
self.httpCli = SimpleHttpClient(self.hs) self.http_client = _hs.get_simple_http_client()
self.data_minus_url = {} self.data_minus_url = {}
self.data_minus_url.update(self.data) self.data_minus_url.update(self.data)
del self.data_minus_url['url'] del self.data_minus_url['url']
@ -107,7 +106,7 @@ class HttpPusher(Pusher):
if not notification_dict: if not notification_dict:
defer.returnValue([]) defer.returnValue([])
try: try:
resp = yield self.httpCli.post_json_get_json(self.url, notification_dict) resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
except: except:
logger.warn("Failed to push %s ", self.url) logger.warn("Failed to push %s ", self.url)
defer.returnValue(False) defer.returnValue(False)
@ -138,7 +137,7 @@ class HttpPusher(Pusher):
} }
} }
try: try:
resp = yield self.httpCli.post_json_get_json(self.url, d) resp = yield self.http_client.post_json_get_json(self.url, d)
except: except:
logger.exception("Failed to push %s ", self.url) logger.exception("Failed to push %s ", self.url)
defer.returnValue(False) defer.returnValue(False)

View File

@ -0,0 +1,224 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.types import UserID
import baserules
import logging
import simplejson as json
import re
logger = logging.getLogger(__name__)
@defer.inlineCallbacks
def evaluator_for_user_name_and_profile_tag(user_name, profile_tag, room_id, store):
rawrules = yield store.get_push_rules_for_user(user_name)
enabled_map = yield store.get_push_rules_enabled_for_user(user_name)
our_member_event = yield store.get_current_state(
room_id=room_id,
event_type='m.room.member',
state_key=user_name,
)
defer.returnValue(PushRuleEvaluator(
user_name, profile_tag, rawrules, enabled_map,
room_id, our_member_event, store
))
class PushRuleEvaluator:
DEFAULT_ACTIONS = ['dont_notify']
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
def __init__(self, user_name, profile_tag, raw_rules, enabled_map, room_id,
our_member_event, store):
self.user_name = user_name
self.profile_tag = profile_tag
self.room_id = room_id
self.our_member_event = our_member_event
self.store = store
rules = []
for raw_rule in raw_rules:
rule = dict(raw_rule)
rule['conditions'] = json.loads(raw_rule['conditions'])
rule['actions'] = json.loads(raw_rule['actions'])
rules.append(rule)
user = UserID.from_string(self.user_name)
self.rules = baserules.list_with_base_rules(rules, user)
self.enabled_map = enabled_map
@staticmethod
def tweaks_for_actions(actions):
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if 'set_tweak' in a and 'value' in a:
tweaks[a['set_tweak']] = a['value']
return tweaks
@defer.inlineCallbacks
def actions_for_event(self, ev):
"""
This should take into account notification settings that the user
has configured both globally and per-room when we have the ability
to do such things.
"""
if ev['user_id'] == self.user_name:
# let's assume you probably know about messages you sent yourself
defer.returnValue(['dont_notify'])
room_id = ev['room_id']
# get *our* member event for display name matching
my_display_name = None
if self.our_member_event:
my_display_name = self.our_member_event[0].content.get("displayname")
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
for r in self.rules:
if r['rule_id'] in self.enabled_map:
r['enabled'] = self.enabled_map[r['rule_id']]
elif 'enabled' not in r:
r['enabled'] = True
if not r['enabled']:
continue
matches = True
conditions = r['conditions']
actions = r['actions']
for c in conditions:
matches &= self._event_fulfills_condition(
ev, c, display_name=my_display_name,
room_member_count=room_member_count
)
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
# ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0:
logger.warn(
"Ignoring rule id %s with no actions for user %s",
r['rule_id'], self.user_name
)
continue
if matches:
logger.info(
"%s matches for user %s, event %s",
r['rule_id'], self.user_name, ev['event_id']
)
defer.returnValue(actions)
logger.info(
"No rules match for user %s, event %s",
self.user_name, ev['event_id']
)
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
@staticmethod
def _glob_to_regexp(glob):
r = re.escape(glob)
r = re.sub(r'\\\*', r'.*?', r)
r = re.sub(r'\\\?', r'.', r)
# handle [abc], [a-z] and [!a-z] style ranges.
r = re.sub(r'\\\[(\\\!|)(.*)\\\]',
lambda x: ('[%s%s]' % (x.group(1) and '^' or '',
re.sub(r'\\\-', '-', x.group(2)))), r)
return r
def _event_fulfills_condition(self, ev, condition, display_name, room_member_count):
if condition['kind'] == 'event_match':
if 'pattern' not in condition:
logger.warn("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
if condition['key'] == 'content.body':
r = r'\b%s\b' % self._glob_to_regexp(condition['pattern'])
else:
r = r'^%s$' % self._glob_to_regexp(condition['pattern'])
val = _value_for_dotted_key(condition['key'], ev)
if val is None:
return False
return re.search(r, val, flags=re.IGNORECASE) is not None
elif condition['kind'] == 'device':
if 'profile_tag' not in condition:
return True
return condition['profile_tag'] == self.profile_tag
elif condition['kind'] == 'contains_display_name':
# This is special because display names can be different
# between rooms and so you can't really hard code it in a rule.
# Optimisation: we should cache these names and update them from
# the event stream.
if 'content' not in ev or 'body' not in ev['content']:
return False
if not display_name:
return False
return re.search(
r"\b%s\b" % re.escape(display_name), ev['content']['body'],
flags=re.IGNORECASE
) is not None
elif condition['kind'] == 'room_member_count':
if 'is' not in condition:
return False
m = PushRuleEvaluator.INEQUALITY_EXPR.match(condition['is'])
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
rhs = int(rhs)
if ineq == '' or ineq == '==':
return room_member_count == rhs
elif ineq == '<':
return room_member_count < rhs
elif ineq == '>':
return room_member_count > rhs
elif ineq == '>=':
return room_member_count >= rhs
elif ineq == '<=':
return room_member_count <= rhs
else:
return False
else:
return True
def _value_for_dotted_key(dotted_key, event):
parts = dotted_key.split(".")
val = event
while len(parts) > 0:
if parts[0] not in val:
return None
val = val[parts[0]]
parts = parts[1:]
return val

View File

@ -1,5 +1,5 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd # Copyright 2014, 2015 OpenMarket Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,3 +12,69 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.rest.client.v1 import (
room,
events,
profile,
presence,
initial_sync,
directory,
voip,
admin,
pusher,
push_rule,
register as v1_register,
login as v1_login,
)
from synapse.rest.client.v2_alpha import (
sync,
filter,
account,
register,
auth,
receipts,
keys,
tokenrefresh,
tags,
account_data,
)
from synapse.http.server import JsonResource
class ClientRestResource(JsonResource):
"""A resource for version 1 of the matrix client API."""
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(client_resource, hs):
# "v1"
room.register_servlets(hs, client_resource)
events.register_servlets(hs, client_resource)
v1_register.register_servlets(hs, client_resource)
v1_login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)
initial_sync.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource)
pusher.register_servlets(hs, client_resource)
push_rule.register_servlets(hs, client_resource)
# "v2"
sync.register_servlets(hs, client_resource)
filter.register_servlets(hs, client_resource)
account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
auth.register_servlets(hs, client_resource)
receipts.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource)
tokenrefresh.register_servlets(hs, client_resource)
tags.register_servlets(hs, client_resource)
account_data.register_servlets(hs, client_resource)

View File

@ -12,33 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import (
room, events, register, login, profile, presence, initial_sync, directory,
voip, admin, pusher, push_rule
)
from synapse.http.server import JsonResource
class ClientV1RestResource(JsonResource):
"""A resource for version 1 of the matrix client API."""
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(client_resource, hs):
room.register_servlets(hs, client_resource)
events.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
login.register_servlets(hs, client_resource)
profile.register_servlets(hs, client_resource)
presence.register_servlets(hs, client_resource)
initial_sync.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource)
pusher.register_servlets(hs, client_resource)
push_rule.register_servlets(hs, client_resource)

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.types import UserID from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_patterns
import logging import logging
@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class WhoisRestServlet(ClientV1RestServlet): class WhoisRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/admin/whois/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):

View File

@ -27,7 +27,7 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def client_path_pattern(path_regex): def client_path_patterns(path_regex, releases=(0,), include_in_unstable=True):
"""Creates a regex compiled client path with the correct client path """Creates a regex compiled client path with the correct client path
prefix. prefix.
@ -37,7 +37,14 @@ def client_path_pattern(path_regex):
Returns: Returns:
SRE_Pattern SRE_Pattern
""" """
return re.compile("^" + CLIENT_PREFIX + path_regex) patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)]
if include_in_unstable:
unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable")
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_PREFIX.replace("/api/v1", "/r%d" % release)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
class ClientV1RestServlet(RestServlet): class ClientV1RestServlet(RestServlet):

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError, Codes from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.types import RoomAlias from synapse.types import RoomAlias
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json import simplejson as json
import logging import logging
@ -32,7 +32,7 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(ClientV1RestServlet): class ClientDirectoryServer(ClientV1RestServlet):
PATTERN = client_path_pattern("/directory/room/(?P<room_alias>[^/]*)$") PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_alias): def on_GET(self, request, room_alias):

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_patterns
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
import logging import logging
@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
class EventStreamRestServlet(ClientV1RestServlet): class EventStreamRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/events$") PATTERNS = client_path_patterns("/events$")
DEFAULT_LONGPOLL_TIME_MS = 30000 DEFAULT_LONGPOLL_TIME_MS = 30000
@ -72,7 +72,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
# TODO: Unit test gets, with and without auth, with different kinds of events. # TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(ClientV1RestServlet): class EventRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/events/(?P<event_id>[^/]*)$") PATTERNS = client_path_patterns("/events/(?P<event_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs):
super(EventRestServlet, self).__init__(hs) super(EventRestServlet, self).__init__(hs)

View File

@ -16,12 +16,12 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_patterns
# TODO: Needs unit testing # TODO: Needs unit testing
class InitialSyncRestServlet(ClientV1RestServlet): class InitialSyncRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/initialSync$") PATTERNS = client_path_patterns("/initialSync$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View File

@ -16,12 +16,12 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import SynapseError, LoginError, Codes from synapse.api.errors import SynapseError, LoginError, Codes
from synapse.http.client import SimpleHttpClient
from synapse.types import UserID from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_patterns
import simplejson as json import simplejson as json
import urllib import urllib
import urlparse
import logging import logging
from saml2 import BINDING_HTTP_POST from saml2 import BINDING_HTTP_POST
@ -35,10 +35,11 @@ logger = logging.getLogger(__name__)
class LoginRestServlet(ClientV1RestServlet): class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$") PATTERNS = client_path_patterns("/login$", releases=(), include_in_unstable=False)
PASS_TYPE = "m.login.password" PASS_TYPE = "m.login.password"
SAML2_TYPE = "m.login.saml2" SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas" CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
def __init__(self, hs): def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs) super(LoginRestServlet, self).__init__(hs)
@ -49,6 +50,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.cas_server_url = hs.config.cas_server_url self.cas_server_url = hs.config.cas_server_url
self.cas_required_attributes = hs.config.cas_required_attributes self.cas_required_attributes = hs.config.cas_required_attributes
self.servername = hs.config.server_name self.servername = hs.config.server_name
self.http_client = hs.get_simple_http_client()
def on_GET(self, request): def on_GET(self, request):
flows = [] flows = []
@ -56,8 +58,18 @@ class LoginRestServlet(ClientV1RestServlet):
flows.append({"type": LoginRestServlet.SAML2_TYPE}) flows.append({"type": LoginRestServlet.SAML2_TYPE})
if self.cas_enabled: if self.cas_enabled:
flows.append({"type": LoginRestServlet.CAS_TYPE}) flows.append({"type": LoginRestServlet.CAS_TYPE})
# While its valid for us to advertise this login type generally,
# synapse currently only gives out these tokens as part of the
# CAS login flow.
# Generally we don't want to advertise login flows that clients
# don't know how to implement, since they (currently) will always
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
if self.password_enabled: if self.password_enabled:
flows.append({"type": LoginRestServlet.PASS_TYPE}) flows.append({"type": LoginRestServlet.PASS_TYPE})
return (200, {"flows": flows}) return (200, {"flows": flows})
def on_OPTIONS(self, request): def on_OPTIONS(self, request):
@ -83,19 +95,20 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state) "uri": "%s%s" % (self.idp_redirect_url, relay_state)
} }
defer.returnValue((200, result)) defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
elif self.cas_enabled and (login_submission["type"] == elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE): LoginRestServlet.CAS_TYPE):
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
uri = "%s/proxyValidate" % (self.cas_server_url,) uri = "%s/proxyValidate" % (self.cas_server_url,)
args = { args = {
"ticket": login_submission["ticket"], "ticket": login_submission["ticket"],
"service": login_submission["service"] "service": login_submission["service"]
} }
body = yield http_client.get_raw(uri, args) body = yield self.http_client.get_raw(uri, args)
result = yield self.do_cas_login(body) result = yield self.do_cas_login(body)
defer.returnValue(result) defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
else: else:
raise SynapseError(400, "Bad login type.") raise SynapseError(400, "Bad login type.")
except KeyError: except KeyError:
@ -131,6 +144,26 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
@defer.inlineCallbacks
def do_token_login(self, login_submission):
token = login_submission['token']
auth_handler = self.handlers.auth_handler
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
@defer.inlineCallbacks @defer.inlineCallbacks
def do_cas_login(self, cas_response_body): def do_cas_login(self, cas_response_body):
user, attributes = self.parse_cas_response(cas_response_body) user, attributes = self.parse_cas_response(cas_response_body)
@ -152,7 +185,7 @@ class LoginRestServlet(ClientV1RestServlet):
user_exists = yield auth_handler.does_user_exist(user_id) user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists: if user_exists:
user_id, access_token, refresh_token = ( user_id, access_token, refresh_token = (
yield auth_handler.login_with_cas_user_id(user_id) yield auth_handler.get_login_tuple_for_user_id(user_id)
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
@ -173,6 +206,7 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result)) defer.returnValue((200, result))
# TODO Delete this after all CAS clients switch to token login instead
def parse_cas_response(self, cas_response_body): def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body) root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"): if not root.tag.endswith("serviceResponse"):
@ -201,7 +235,7 @@ class LoginRestServlet(ClientV1RestServlet):
class SAML2RestServlet(ClientV1RestServlet): class SAML2RestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/saml2") PATTERNS = client_path_patterns("/login/saml2", releases=())
def __init__(self, hs): def __init__(self, hs):
super(SAML2RestServlet, self).__init__(hs) super(SAML2RestServlet, self).__init__(hs)
@ -243,8 +277,9 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"})) defer.returnValue((200, {"status": "not_authenticated"}))
# TODO Delete this after all CAS clients switch to token login instead
class CasRestServlet(ClientV1RestServlet): class CasRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login/cas") PATTERNS = client_path_patterns("/login/cas", releases=())
def __init__(self, hs): def __init__(self, hs):
super(CasRestServlet, self).__init__(hs) super(CasRestServlet, self).__init__(hs)
@ -254,6 +289,115 @@ class CasRestServlet(ClientV1RestServlet):
return (200, {"serverUrl": self.cas_server_url}) return (200, {"serverUrl": self.cas_server_url})
class CasRedirectServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
def __init__(self, hs):
super(CasRedirectServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
def on_GET(self, request):
args = request.args
if "redirectUrl" not in args:
return (400, "Redirect URL not specified for CAS auth")
client_redirect_url_param = urllib.urlencode({
"redirectUrl": args["redirectUrl"][0]
})
hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
service_param = urllib.urlencode({
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
})
request.redirect("%s?%s" % (self.cas_server_url, service_param))
request.finish()
class CasTicketServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas/ticket", releases=())
def __init__(self, hs):
super(CasTicketServlet, self).__init__(hs)
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
@defer.inlineCallbacks
def on_GET(self, request):
client_redirect_url = request.args["redirectUrl"][0]
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": request.args["ticket"],
"service": self.cas_service_url
}
body = yield http_client.get_raw(uri, args)
result = yield self.handle_cas_response(request, body, client_redirect_url)
defer.returnValue(result)
@defer.inlineCallbacks
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
if required_attribute not in attributes:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
# Also need to check value
if required_value is not None:
actual_value = attributes[required_attribute]
# If required attribute value does not match expected - Forbidden
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if not user_exists:
user_id, _ = (
yield self.handlers.registration_handler.register(localpart=user)
)
login_token = auth_handler.generate_short_term_login_token(user_id)
redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
login_token)
request.redirect(redirect_url)
request.finish()
def add_login_token_to_redirect_url(self, url, token):
url_parts = list(urlparse.urlparse(url))
query = dict(urlparse.parse_qsl(url_parts[4]))
query.update({"loginToken": token})
url_parts[4] = urllib.urlencode(query)
return urlparse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not root[0].tag.endswith("authenticationSuccess"):
raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
if child.tag.endswith("attributes"):
attributes = {}
for attribute in child:
# ElementTree library expands the namespace in attribute tags
# to the full URL of the namespace.
# See (https://docs.python.org/2/library/xml.etree.elementtree.html)
# We don't care about namespace here and it will always be encased in
# curly braces, so we remove them.
if "}" in attribute.tag:
attributes[attribute.tag.split("}")[1]] = attribute.text
else:
attributes[attribute.tag] = attribute.text
if user is None or attributes is None:
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
return (user, attributes)
def _parse_json(request): def _parse_json(request):
try: try:
content = json.loads(request.content.read()) content = json.loads(request.content.read())
@ -269,5 +413,7 @@ def register_servlets(hs, http_server):
if hs.config.saml2_enabled: if hs.config.saml2_enabled:
SAML2RestServlet(hs).register(http_server) SAML2RestServlet(hs).register(http_server)
if hs.config.cas_enabled: if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
CasRestServlet(hs).register(http_server) CasRestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server) # TODO PasswordResetRestServlet(hs).register(http_server)

View File

@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.types import UserID from synapse.types import UserID
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json import simplejson as json
import logging import logging
@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
class PresenceStatusRestServlet(ClientV1RestServlet): class PresenceStatusRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/presence/(?P<user_id>[^/]*)/status") PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -73,7 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
class PresenceListRestServlet(ClientV1RestServlet): class PresenceListRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/presence/list/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -120,7 +120,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if len(u) == 0: if len(u) == 0:
continue continue
invited_user = UserID.from_string(u) invited_user = UserID.from_string(u)
yield self.handlers.presence_handler.send_invite( yield self.handlers.presence_handler.send_presence_invite(
observer_user=user, observed_user=invited_user observer_user=user, observed_user=invited_user
) )

View File

@ -16,14 +16,14 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """ """ This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer from twisted.internet import defer
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_patterns
from synapse.types import UserID from synapse.types import UserID
import simplejson as json import simplejson as json
class ProfileDisplaynameRestServlet(ClientV1RestServlet): class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/displayname") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -56,7 +56,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
class ProfileAvatarURLRestServlet(ClientV1RestServlet): class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)/avatar_url") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):
@ -89,7 +89,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
class ProfileRestServlet(ClientV1RestServlet): class ProfileRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/profile/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id): def on_GET(self, request, user_id):

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError
) )
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_patterns
from synapse.storage.push_rule import ( from synapse.storage.push_rule import (
InconsistentRuleException, RuleNotFoundException InconsistentRuleException, RuleNotFoundException
) )
@ -31,7 +31,7 @@ import simplejson as json
class PushRuleRestServlet(ClientV1RestServlet): class PushRuleRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/pushrules/.*$") PATTERNS = client_path_patterns("/pushrules/.*$")
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash") "Unrecognised request: You probably wanted a trailing slash")
@ -207,7 +207,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
def set_rule_attr(self, user_name, spec, val): def set_rule_attr(self, user_name, spec, val):
if spec['attr'] == 'enabled': if spec['attr'] == 'enabled':
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
if not isinstance(val, bool): if not isinstance(val, bool):
# Legacy fallback
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean") raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec) namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
self.hs.get_datastore().set_push_rule_enabled( self.hs.get_datastore().set_push_rule_enabled(

View File

@ -17,13 +17,16 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.push import PusherConfigException from synapse.push import PusherConfigException
from .base import ClientV1RestServlet, client_path_pattern from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json import simplejson as json
import logging
logger = logging.getLogger(__name__)
class PusherRestServlet(ClientV1RestServlet): class PusherRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/pushers/set$") PATTERNS = client_path_patterns("/pushers/set$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request): def on_POST(self, request):
@ -51,6 +54,9 @@ class PusherRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Missing parameters: "+','.join(missing), raise SynapseError(400, "Missing parameters: "+','.join(missing),
errcode=Codes.MISSING_PARAM) errcode=Codes.MISSING_PARAM)
logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
logger.debug("Got pushers request with body: %r", content)
append = False append = False
if 'append' in content: if 'append' in content:
append = content['append'] append = content['append']

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_patterns
import synapse.util.stringutils as stringutils import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
@ -48,7 +48,7 @@ class RegisterRestServlet(ClientV1RestServlet):
handler doesn't have a concept of multi-stages or sessions. handler doesn't have a concept of multi-stages or sessions.
""" """
PATTERN = client_path_pattern("/register$") PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False)
def __init__(self, hs): def __init__(self, hs):
super(RegisterRestServlet, self).__init__(hs) super(RegisterRestServlet, self).__init__(hs)

View File

@ -16,7 +16,7 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """ """ This module contains REST servlets to do with rooms: /rooms/<paths> """
from twisted.internet import defer from twisted.internet import defer
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_patterns
from synapse.api.errors import SynapseError, Codes, AuthError from synapse.api.errors import SynapseError, Codes, AuthError
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -34,15 +34,15 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# No PATTERN; we have custom dispatch rules here # No PATTERN; we have custom dispatch rules here
def register(self, http_server): def register(self, http_server):
PATTERN = "/createRoom" PATTERNS = "/createRoom"
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERNS, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity # define CORS for all of /rooms in RoomCreateRestServlet for simplicity
http_server.register_path("OPTIONS", http_server.register_paths("OPTIONS",
client_path_pattern("/rooms(?:/.*)?$"), client_path_patterns("/rooms(?:/.*)?$"),
self.on_OPTIONS) self.on_OPTIONS)
# define CORS for /createRoom[/txnid] # define CORS for /createRoom[/txnid]
http_server.register_path("OPTIONS", http_server.register_paths("OPTIONS",
client_path_pattern("/createRoom(?:/.*)?$"), client_path_patterns("/createRoom(?:/.*)?$"),
self.on_OPTIONS) self.on_OPTIONS)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -103,17 +103,17 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
state_key = ("/rooms/(?P<room_id>[^/]*)/state/" state_key = ("/rooms/(?P<room_id>[^/]*)/state/"
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$") "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
http_server.register_path("GET", http_server.register_paths("GET",
client_path_pattern(state_key), client_path_patterns(state_key),
self.on_GET) self.on_GET)
http_server.register_path("PUT", http_server.register_paths("PUT",
client_path_pattern(state_key), client_path_patterns(state_key),
self.on_PUT) self.on_PUT)
http_server.register_path("GET", http_server.register_paths("GET",
client_path_pattern(no_state_key), client_path_patterns(no_state_key),
self.on_GET_no_state_key) self.on_GET_no_state_key)
http_server.register_path("PUT", http_server.register_paths("PUT",
client_path_pattern(no_state_key), client_path_patterns(no_state_key),
self.on_PUT_no_state_key) self.on_PUT_no_state_key)
def on_GET_no_state_key(self, request, room_id, event_type): def on_GET_no_state_key(self, request, room_id, event_type):
@ -170,8 +170,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
PATTERN = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
register_txn_path(self, PATTERN, http_server, with_get=True) register_txn_path(self, PATTERNS, http_server, with_get=True)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_type, txn_id=None): def on_POST(self, request, room_id, event_type, txn_id=None):
@ -215,8 +215,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
def register(self, http_server): def register(self, http_server):
# /join/$room_identifier[/$txn_id] # /join/$room_identifier[/$txn_id]
PATTERN = ("/join/(?P<room_identifier>[^/]*)") PATTERNS = ("/join/(?P<room_identifier>[^/]*)")
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None): def on_POST(self, request, room_identifier, txn_id=None):
@ -280,7 +280,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class PublicRoomListRestServlet(ClientV1RestServlet): class PublicRoomListRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/publicRooms$") PATTERNS = client_path_patterns("/publicRooms$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
@ -291,7 +291,7 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomMemberListRestServlet(ClientV1RestServlet): class RoomMemberListRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/members$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -328,7 +328,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
# TODO: Needs better unit testing # TODO: Needs better unit testing
class RoomMessageListRestServlet(ClientV1RestServlet): class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/messages$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -351,7 +351,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomStateRestServlet(ClientV1RestServlet): class RoomStateRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/state$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -368,7 +368,7 @@ class RoomStateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing # TODO: Needs unit testing
class RoomInitialSyncRestServlet(ClientV1RestServlet): class RoomInitialSyncRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/initialSync$") PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
@ -383,32 +383,8 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content)) defer.returnValue((200, content))
class RoomTriggerBackfill(ClientV1RestServlet):
PATTERN = client_path_pattern("/rooms/(?P<room_id>[^/]*)/backfill$")
def __init__(self, hs):
super(RoomTriggerBackfill, self).__init__(hs)
self.clock = hs.get_clock()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
remote_server = urllib.unquote(
request.args["remote"][0]
).decode("UTF-8")
limit = int(request.args["limit"][0])
handler = self.handlers.federation_handler
events = yield handler.backfill(remote_server, room_id, limit)
time_now = self.clock.time_msec()
res = [serialize_event(event, time_now) for event in events]
defer.returnValue((200, res))
class RoomEventContext(ClientV1RestServlet): class RoomEventContext(ClientV1RestServlet):
PATTERN = client_path_pattern( PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
) )
@ -447,9 +423,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/[invite|join|leave] # /rooms/$roomid/[invite|join|leave]
PATTERN = ("/rooms/(?P<room_id>[^/]*)/" PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
"(?P<membership_action>join|invite|leave|ban|kick)") "(?P<membership_action>join|invite|leave|ban|kick|forget)")
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action, txn_id=None): def on_POST(self, request, room_id, membership_action, txn_id=None):
@ -458,6 +434,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
allow_guest=True allow_guest=True
) )
effective_membership_action = membership_action
if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}: if is_guest and membership_action not in {Membership.JOIN, Membership.LEAVE}:
raise AuthError(403, "Guest access not allowed") raise AuthError(403, "Guest access not allowed")
@ -488,11 +466,13 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
UserID.from_string(state_key) UserID.from_string(state_key)
if membership_action == "kick": if membership_action == "kick":
membership_action = "leave" effective_membership_action = "leave"
elif membership_action == "forget":
effective_membership_action = "leave"
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
content = {"membership": unicode(membership_action)} content = {"membership": unicode(effective_membership_action)}
if is_guest: if is_guest:
content["kind"] = "guest" content["kind"] = "guest"
@ -509,6 +489,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
is_guest=is_guest, is_guest=is_guest,
) )
if membership_action == "forget":
self.handlers.room_member_handler.forget(user, room_id)
defer.returnValue((200, {})) defer.returnValue((200, {}))
def _has_3pid_invite_keys(self, content): def _has_3pid_invite_keys(self, content):
@ -536,8 +519,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(ClientV1RestServlet): class RoomRedactEventRestServlet(ClientV1RestServlet):
def register(self, http_server): def register(self, http_server):
PATTERN = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERN, http_server) register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id, event_id, txn_id=None): def on_POST(self, request, room_id, event_id, txn_id=None):
@ -575,7 +558,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
class RoomTypingRestServlet(ClientV1RestServlet): class RoomTypingRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern( PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
) )
@ -608,7 +591,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
class SearchRestServlet(ClientV1RestServlet): class SearchRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern( PATTERNS = client_path_patterns(
"/search$" "/search$"
) )
@ -648,20 +631,20 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
http_server : The http_server to register paths with. http_server : The http_server to register paths with.
with_get: True to also register respective GET paths for the PUTs. with_get: True to also register respective GET paths for the PUTs.
""" """
http_server.register_path( http_server.register_paths(
"POST", "POST",
client_path_pattern(regex_string + "$"), client_path_patterns(regex_string + "$"),
servlet.on_POST servlet.on_POST
) )
http_server.register_path( http_server.register_paths(
"PUT", "PUT",
client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"), client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"),
servlet.on_PUT servlet.on_PUT
) )
if with_get: if with_get:
http_server.register_path( http_server.register_paths(
"GET", "GET",
client_path_pattern(regex_string + "/(?P<txn_id>[^/]*)$"), client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"),
servlet.on_GET servlet.on_GET
) )
@ -672,7 +655,6 @@ def register_servlets(hs, http_server):
RoomMemberListRestServlet(hs).register(http_server) RoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server) RoomMessageListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server)
RoomTriggerBackfill(hs).register(http_server)
RoomMembershipRestServlet(hs).register(http_server) RoomMembershipRestServlet(hs).register(http_server)
RoomSendEventRestServlet(hs).register(http_server) RoomSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server) PublicRoomListRestServlet(hs).register(http_server)

View File

@ -15,7 +15,7 @@
from twisted.internet import defer from twisted.internet import defer
from base import ClientV1RestServlet, client_path_pattern from base import ClientV1RestServlet, client_path_patterns
import hmac import hmac
@ -24,7 +24,7 @@ import base64
class VoipRestServlet(ClientV1RestServlet): class VoipRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/voip/turnServer$") PATTERNS = client_path_patterns("/voip/turnServer$")
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):

View File

@ -12,37 +12,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from . import (
sync,
filter,
account,
register,
auth,
receipts,
keys,
tokenrefresh,
tags,
)
from synapse.http.server import JsonResource
class ClientV2AlphaRestResource(JsonResource):
"""A resource for version 2 alpha of the matrix client API."""
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(client_resource, hs):
sync.register_servlets(hs, client_resource)
filter.register_servlets(hs, client_resource)
account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
auth.register_servlets(hs, client_resource)
receipts.register_servlets(hs, client_resource)
keys.register_servlets(hs, client_resource)
tokenrefresh.register_servlets(hs, client_resource)
tags.register_servlets(hs, client_resource)

View File

@ -27,7 +27,7 @@ import simplejson
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def client_v2_pattern(path_regex): def client_v2_patterns(path_regex, releases=(0,)):
"""Creates a regex compiled client path with the correct client path """Creates a regex compiled client path with the correct client path
prefix. prefix.
@ -37,7 +37,13 @@ def client_v2_pattern(path_regex):
Returns: Returns:
SRE_Pattern SRE_Pattern
""" """
return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex) patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)]
unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
def parse_request_allow_empty(request): def parse_request_allow_empty(request):

View File

@ -20,7 +20,7 @@ from synapse.api.errors import LoginError, SynapseError, Codes
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.util.async import run_on_reactor from synapse.util.async import run_on_reactor
from ._base import client_v2_pattern, parse_json_dict_from_request from ._base import client_v2_patterns, parse_json_dict_from_request
import logging import logging
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class PasswordRestServlet(RestServlet): class PasswordRestServlet(RestServlet):
PATTERN = client_v2_pattern("/account/password") PATTERNS = client_v2_patterns("/account/password")
def __init__(self, hs): def __init__(self, hs):
super(PasswordRestServlet, self).__init__() super(PasswordRestServlet, self).__init__()
@ -89,7 +89,7 @@ class PasswordRestServlet(RestServlet):
class ThreepidRestServlet(RestServlet): class ThreepidRestServlet(RestServlet):
PATTERN = client_v2_pattern("/account/3pid") PATTERNS = client_v2_patterns("/account/3pid")
def __init__(self, hs): def __init__(self, hs):
super(ThreepidRestServlet, self).__init__() super(ThreepidRestServlet, self).__init__()

View File

@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import client_v2_patterns
from synapse.http.servlet import RestServlet
from synapse.api.errors import AuthError, SynapseError
from twisted.internet import defer
import logging
import simplejson as json
logger = logging.getLogger(__name__)
class AccountDataServlet(RestServlet):
"""
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
"""
PATTERNS = client_v2_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
)
def __init__(self, hs):
super(AccountDataServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def on_PUT(self, request, user_id, account_data_type):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
try:
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid JSON")
max_id = yield self.store.add_account_data_for_user(
user_id, account_data_type, body
)
yield self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
class RoomAccountDataServlet(RestServlet):
"""
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
"""
PATTERNS = client_v2_patterns(
"/user/(?P<user_id>[^/]*)"
"/rooms/(?P<room_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)"
)
def __init__(self, hs):
super(RoomAccountDataServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@defer.inlineCallbacks
def on_PUT(self, request, user_id, room_id, account_data_type):
auth_user, _, _ = yield self.auth.get_user_by_req(request)
if user_id != auth_user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
try:
content_bytes = request.content.read()
body = json.loads(content_bytes)
except:
raise SynapseError(400, "Invalid JSON")
if not isinstance(body, dict):
raise ValueError("Expected a JSON object")
max_id = yield self.store.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
yield self.notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
AccountDataServlet(hs).register(http_server)
RoomAccountDataServlet(hs).register(http_server)

View File

@ -20,7 +20,7 @@ from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern from ._base import client_v2_patterns
import logging import logging
@ -97,7 +97,7 @@ class AuthRestServlet(RestServlet):
cannot be handled in the normal flow (with requests to the same endpoint). cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth. Current use is for web fallback auth.
""" """
PATTERN = client_v2_pattern("/auth/(?P<stagetype>[\w\.]*)/fallback/web") PATTERNS = client_v2_patterns("/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs): def __init__(self, hs):
super(AuthRestServlet, self).__init__() super(AuthRestServlet, self).__init__()

View File

@ -19,7 +19,7 @@ from synapse.api.errors import AuthError, SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import UserID from synapse.types import UserID
from ._base import client_v2_pattern from ._base import client_v2_patterns
import simplejson as json import simplejson as json
import logging import logging
@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet): class GetFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs): def __init__(self, hs):
super(GetFilterRestServlet, self).__init__() super(GetFilterRestServlet, self).__init__()
@ -65,7 +65,7 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter") PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs): def __init__(self, hs):
super(CreateFilterRestServlet, self).__init__() super(CreateFilterRestServlet, self).__init__()

View File

@ -21,7 +21,7 @@ from synapse.types import UserID
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from ._base import client_v2_pattern from ._base import client_v2_patterns
import simplejson as json import simplejson as json
import logging import logging
@ -54,7 +54,7 @@ class KeyUploadServlet(RestServlet):
}, },
} }
""" """
PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)") PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=())
def __init__(self, hs): def __init__(self, hs):
super(KeyUploadServlet, self).__init__() super(KeyUploadServlet, self).__init__()
@ -154,12 +154,13 @@ class KeyQueryServlet(RestServlet):
} } } } } } } } } } } }
""" """
PATTERN = client_v2_pattern( PATTERNS = client_v2_patterns(
"/keys/query(?:" "/keys/query(?:"
"/(?P<user_id>[^/]*)(?:" "/(?P<user_id>[^/]*)(?:"
"/(?P<device_id>[^/]*)" "/(?P<device_id>[^/]*)"
")?" ")?"
")?" ")?",
releases=()
) )
def __init__(self, hs): def __init__(self, hs):
@ -245,10 +246,11 @@ class OneTimeKeyServlet(RestServlet):
} } } } } } } }
""" """
PATTERN = client_v2_pattern( PATTERNS = client_v2_patterns(
"/keys/claim(?:/?|(?:/" "/keys/claim(?:/?|(?:/"
"(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)" "(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
")?)" ")?)",
releases=()
) )
def __init__(self, hs): def __init__(self, hs):

View File

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern from ._base import client_v2_patterns
import logging import logging
@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
class ReceiptRestServlet(RestServlet): class ReceiptRestServlet(RestServlet):
PATTERN = client_v2_pattern( PATTERNS = client_v2_patterns(
"/rooms/(?P<room_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)" "/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$" "/(?P<event_id>[^/]*)$"

View File

@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_json_dict_from_request from ._base import client_v2_patterns, parse_json_dict_from_request
import logging import logging
import hmac import hmac
@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
class RegisterRestServlet(RestServlet): class RegisterRestServlet(RestServlet):
PATTERN = client_v2_pattern("/register") PATTERNS = client_v2_patterns("/register")
def __init__(self, hs): def __init__(self, hs):
super(RegisterRestServlet, self).__init__() super(RegisterRestServlet, self).__init__()

View File

@ -22,14 +22,17 @@ from synapse.handlers.sync import SyncConfig
from synapse.types import StreamToken from synapse.types import StreamToken
from synapse.events import FrozenEvent from synapse.events import FrozenEvent
from synapse.events.utils import ( from synapse.events.utils import (
serialize_event, format_event_for_client_v2_without_event_id, serialize_event, format_event_for_client_v2_without_room_id,
) )
from synapse.api.filtering import FilterCollection from synapse.api.filtering import FilterCollection
from ._base import client_v2_pattern from synapse.api.errors import SynapseError
from ._base import client_v2_patterns
import copy import copy
import logging import logging
import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -48,7 +51,7 @@ class SyncRestServlet(RestServlet):
"next_batch": // batch token for the next /sync "next_batch": // batch token for the next /sync
"presence": // presence data for the user. "presence": // presence data for the user.
"rooms": { "rooms": {
"joined": { // Joined rooms being updated. "join": { // Joined rooms being updated.
"${room_id}": { // Id of the room being updated "${room_id}": { // Id of the room being updated
"event_map": // Map of EventID -> event JSON. "event_map": // Map of EventID -> event JSON.
"timeline": { // The recent events in the room if gap is "true" "timeline": { // The recent events in the room if gap is "true"
@ -63,13 +66,13 @@ class SyncRestServlet(RestServlet):
"ephemeral": {"events": []} // list of event objects "ephemeral": {"events": []} // list of event objects
} }
}, },
"invited": {}, // Invited rooms being updated. "invite": {}, // Invited rooms being updated.
"archived": {} // Archived rooms being updated. "leave": {} // Archived rooms being updated.
} }
} }
""" """
PATTERN = client_v2_pattern("/sync$") PATTERNS = client_v2_patterns("/sync$")
ALLOWED_PRESENCE = set(["online", "offline"]) ALLOWED_PRESENCE = set(["online", "offline"])
def __init__(self, hs): def __init__(self, hs):
@ -100,6 +103,15 @@ class SyncRestServlet(RestServlet):
) )
) )
if filter_id and filter_id.startswith('{'):
logging.error("MJH %r", filter_id)
try:
filter_object = json.loads(filter_id)
except:
raise SynapseError(400, "Invalid filter JSON")
self.filtering._check_valid_filter(filter_object)
filter = FilterCollection(filter_object)
else:
try: try:
filter = yield self.filtering.get_user_filter( filter = yield self.filtering.get_user_filter(
user.localpart, filter_id user.localpart, filter_id
@ -144,13 +156,16 @@ class SyncRestServlet(RestServlet):
) )
response_content = { response_content = {
"account_data": self.encode_account_data(
sync_result.account_data, filter, time_now
),
"presence": self.encode_presence( "presence": self.encode_presence(
sync_result.presence, filter, time_now sync_result.presence, filter, time_now
), ),
"rooms": { "rooms": {
"joined": joined, "join": joined,
"invited": invited, "invite": invited,
"archived": archived, "leave": archived,
}, },
"next_batch": sync_result.next_batch.to_string(), "next_batch": sync_result.next_batch.to_string(),
} }
@ -165,6 +180,9 @@ class SyncRestServlet(RestServlet):
formatted.append(event) formatted.append(event)
return {"events": filter.filter_presence(formatted)} return {"events": filter.filter_presence(formatted)}
def encode_account_data(self, events, filter, time_now):
return {"events": filter.filter_account_data(events)}
def encode_joined(self, rooms, filter, time_now, token_id): def encode_joined(self, rooms, filter, time_now, token_id):
""" """
Encode the joined rooms in a sync result Encode the joined rooms in a sync result
@ -207,7 +225,7 @@ class SyncRestServlet(RestServlet):
for room in rooms: for room in rooms:
invite = serialize_event( invite = serialize_event(
room.invite, time_now, token_id=token_id, room.invite, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_event_id, event_format=format_event_for_client_v2_without_room_id,
) )
invited_state = invite.get("unsigned", {}).pop("invite_room_state", []) invited_state = invite.get("unsigned", {}).pop("invite_room_state", [])
invited_state.append(invite) invited_state.append(invite)
@ -256,7 +274,13 @@ class SyncRestServlet(RestServlet):
:return: the room, encoded in our response format :return: the room, encoded in our response format
:rtype: dict[str, object] :rtype: dict[str, object]
""" """
event_map = {} def serialize(event):
# TODO(mjark): Respect formatting requirements in the filter.
return serialize_event(
event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id,
)
state_dict = room.state state_dict = room.state
timeline_events = filter.filter_room_timeline(room.timeline.events) timeline_events = filter.filter_room_timeline(room.timeline.events)
@ -264,37 +288,22 @@ class SyncRestServlet(RestServlet):
state_dict, timeline_events) state_dict, timeline_events)
state_events = filter.filter_room_state(state_dict.values()) state_events = filter.filter_room_state(state_dict.values())
state_event_ids = []
for event in state_events:
# TODO(mjark): Respect formatting requirements in the filter.
event_map[event.event_id] = serialize_event(
event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_event_id,
)
state_event_ids.append(event.event_id)
timeline_event_ids = [] serialized_state = [serialize(e) for e in state_events]
for event in timeline_events: serialized_timeline = [serialize(e) for e in timeline_events]
# TODO(mjark): Respect formatting requirements in the filter.
event_map[event.event_id] = serialize_event(
event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_event_id,
)
timeline_event_ids.append(event.event_id)
private_user_data = filter.filter_room_private_user_data( account_data = filter.filter_room_account_data(
room.private_user_data room.account_data
) )
result = { result = {
"event_map": event_map,
"timeline": { "timeline": {
"events": timeline_event_ids, "events": serialized_timeline,
"prev_batch": room.timeline.prev_batch.to_string(), "prev_batch": room.timeline.prev_batch.to_string(),
"limited": room.timeline.limited, "limited": room.timeline.limited,
}, },
"state": {"events": state_event_ids}, "state": {"events": serialized_state},
"private_user_data": {"events": private_user_data}, "account_data": {"events": account_data},
} }
if joined: if joined:

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import client_v2_pattern from ._base import client_v2_patterns
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
@ -31,7 +31,7 @@ class TagListServlet(RestServlet):
""" """
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
""" """
PATTERN = client_v2_pattern( PATTERNS = client_v2_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
) )
@ -56,7 +56,7 @@ class TagServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
""" """
PATTERN = client_v2_pattern( PATTERNS = client_v2_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)" "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
) )
@ -81,7 +81,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body) max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
yield self.notifier.on_new_event( yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -95,7 +95,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
yield self.notifier.on_new_event( yield self.notifier.on_new_event(
"private_user_data_key", max_id, users=[user_id] "account_data_key", max_id, users=[user_id]
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))

View File

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, StoreError, SynapseError from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from ._base import client_v2_pattern, parse_json_dict_from_request from ._base import client_v2_patterns, parse_json_dict_from_request
class TokenRefreshRestServlet(RestServlet): class TokenRefreshRestServlet(RestServlet):
@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet):
Exchanges refresh tokens for a pair of an access token and a new refresh Exchanges refresh tokens for a pair of an access token and a new refresh
token. token.
""" """
PATTERN = client_v2_pattern("/tokenrefresh") PATTERNS = client_v2_patterns("/tokenrefresh")
def __init__(self, hs): def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__() super(TokenRefreshRestServlet, self).__init__()

View File

@ -71,8 +71,7 @@ class BaseHomeServer(object):
'state_handler', 'state_handler',
'notifier', 'notifier',
'distributor', 'distributor',
'resource_for_client', 'client_resource',
'resource_for_client_v2_alpha',
'resource_for_federation', 'resource_for_federation',
'resource_for_static_content', 'resource_for_static_content',
'resource_for_web_client', 'resource_for_web_client',

View File

@ -17,12 +17,11 @@ var submitPassword = function(user, pwd) {
}).error(errorFunc); }).error(errorFunc);
}; };
var submitCas = function(ticket, service) { var submitToken = function(loginToken) {
console.log("Logging in with cas..."); console.log("Logging in with login token...");
var data = { var data = {
type: "m.login.cas", type: "m.login.token",
ticket: ticket, token: loginToken
service: service,
}; };
$.post(matrixLogin.endpoint, JSON.stringify(data), function(response) { $.post(matrixLogin.endpoint, JSON.stringify(data), function(response) {
show_login(); show_login();
@ -41,23 +40,10 @@ var errorFunc = function(err) {
} }
}; };
var getCasURL = function(cb) {
$.get(matrixLogin.endpoint + "/cas", function(response) {
var cas_url = response.serverUrl;
cb(cas_url);
}).error(errorFunc);
};
var gotoCas = function() { var gotoCas = function() {
getCasURL(function(cas_url) {
var this_page = window.location.origin + window.location.pathname; var this_page = window.location.origin + window.location.pathname;
var redirect_url = matrixLogin.endpoint + "/cas/redirect?redirectUrl=" + encodeURIComponent(this_page);
var redirect_url = cas_url + "/login?service=" + encodeURIComponent(this_page);
window.location.replace(redirect_url); window.location.replace(redirect_url);
});
} }
var setFeedbackString = function(text) { var setFeedbackString = function(text) {
@ -111,7 +97,7 @@ var fetch_info = function(cb) {
matrixLogin.onLoad = function() { matrixLogin.onLoad = function() {
fetch_info(function() { fetch_info(function() {
if (!try_cas()) { if (!try_token()) {
show_login(); show_login();
} }
}); });
@ -148,20 +134,20 @@ var parseQsFromUrl = function(query) {
return result; return result;
}; };
var try_cas = function() { var try_token = function() {
var pos = window.location.href.indexOf("?"); var pos = window.location.href.indexOf("?");
if (pos == -1) { if (pos == -1) {
return false; return false;
} }
var qs = parseQsFromUrl(window.location.href.substr(pos+1)); var qs = parseQsFromUrl(window.location.href.substr(pos+1));
var ticket = qs.ticket; var loginToken = qs.loginToken;
if (!ticket) { if (!loginToken) {
return false; return false;
} }
submitCas(ticket, location.origin); submitToken(loginToken);
return true; return true;
}; };

View File

@ -42,6 +42,7 @@ from .end_to_end_keys import EndToEndKeyStore
from .receipts import ReceiptsStore from .receipts import ReceiptsStore
from .search import SearchStore from .search import SearchStore
from .tags import TagsStore from .tags import TagsStore
from .account_data import AccountDataStore
import logging import logging
@ -73,6 +74,7 @@ class DataStore(RoomMemberStore, RoomStore,
EndToEndKeyStore, EndToEndKeyStore,
SearchStore, SearchStore,
TagsStore, TagsStore,
AccountDataStore,
): ):
def __init__(self, hs): def __init__(self, hs):

View File

@ -214,7 +214,8 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000) self._clock.looping_call(loop, 10000)
def _new_transaction(self, conn, desc, after_callbacks, func, *args, **kwargs): def _new_transaction(self, conn, desc, after_callbacks, logging_context,
func, *args, **kwargs):
start = time.time() * 1000 start = time.time() * 1000
txn_id = self._TXN_ID txn_id = self._TXN_ID
@ -277,6 +278,9 @@ class SQLBaseStore(object):
end = time.time() * 1000 end = time.time() * 1000
duration = end - start duration = end - start
if logging_context is not None:
logging_context.add_database_transaction(duration)
transaction_logger.debug("[TXN END] {%s} %f", name, duration) transaction_logger.debug("[TXN END] {%s} %f", name, duration)
self._current_txn_total_time += duration self._current_txn_total_time += duration
@ -302,7 +306,8 @@ class SQLBaseStore(object):
current_context.copy_to(context) current_context.copy_to(context)
return self._new_transaction( return self._new_transaction(
conn, desc, after_callbacks, func, *args, **kwargs conn, desc, after_callbacks, current_context,
func, *args, **kwargs
) )
result = yield preserve_context_over_fn( result = yield preserve_context_over_fn(

View File

@ -0,0 +1,211 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from twisted.internet import defer
import ujson as json
import logging
logger = logging.getLogger(__name__)
class AccountDataStore(SQLBaseStore):
def get_account_data_for_user(self, user_id):
"""Get all the client account_data for a user.
Args:
user_id(str): The user to get the account_data for.
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
"""
def get_account_data_for_user_txn(txn):
rows = self._simple_select_list_txn(
txn, "account_data", {"user_id": user_id},
["account_data_type", "content"]
)
global_account_data = {
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
rows = self._simple_select_list_txn(
txn, "room_account_data", {"user_id": user_id},
["room_id", "account_data_type", "content"]
)
by_room = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
room_data[row["account_data_type"]] = json.loads(row["content"])
return (global_account_data, by_room)
return self.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
def get_account_data_for_room(self, user_id, room_id):
"""Get all the client account_data for a user for a room.
Args:
user_id(str): The user to get the account_data for.
room_id(str): The room to get the account_data for.
Returns:
A deferred dict of the room account_data
"""
def get_account_data_for_room_txn(txn):
rows = self._simple_select_list_txn(
txn, "room_account_data", {"user_id": user_id, "room_id": room_id},
["account_data_type", "content"]
)
return {
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
return self.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):
"""Get all the client account_data for a that's changed.
Args:
user_id(str): The user to get the account_data for.
stream_id(int): The point in the stream since which to get updates
Returns:
A deferred pair of a dict of global account_data and a dict
mapping from room_id string to per room account_data dicts.
"""
def get_updated_account_data_for_user_txn(txn):
sql = (
"SELECT account_data_type, content FROM account_data"
" WHERE user_id = ? AND stream_id > ?"
)
txn.execute(sql, (user_id, stream_id))
global_account_data = {
row[0]: json.loads(row[1]) for row in txn.fetchall()
}
sql = (
"SELECT room_id, account_data_type, content FROM room_account_data"
" WHERE user_id = ? AND stream_id > ?"
)
txn.execute(sql, (user_id, stream_id))
account_data_by_room = {}
for row in txn.fetchall():
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2])
return (global_account_data, account_data_by_room)
return self.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@defer.inlineCallbacks
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
"""Add some account_data to a room for a user.
Args:
user_id(str): The user to add a tag for.
room_id(str): The room to add a tag for.
account_data_type(str): The type of account_data to add.
content(dict): A json object to associate with the tag.
Returns:
A deferred that completes once the account_data has been added.
"""
content_json = json.dumps(content)
def add_account_data_txn(txn, next_id):
self._simple_upsert_txn(
txn,
table="room_account_data",
keyvalues={
"user_id": user_id,
"room_id": room_id,
"account_data_type": account_data_type,
},
values={
"stream_id": next_id,
"content": content_json,
}
)
self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction(
"add_room_account_data", add_account_data_txn, next_id
)
result = yield self._account_data_id_gen.get_max_token(self)
defer.returnValue(result)
@defer.inlineCallbacks
def add_account_data_for_user(self, user_id, account_data_type, content):
"""Add some account_data to a room for a user.
Args:
user_id(str): The user to add a tag for.
account_data_type(str): The type of account_data to add.
content(dict): A json object to associate with the tag.
Returns:
A deferred that completes once the account_data has been added.
"""
content_json = json.dumps(content)
def add_account_data_txn(txn, next_id):
self._simple_upsert_txn(
txn,
table="account_data",
keyvalues={
"user_id": user_id,
"account_data_type": account_data_type,
},
values={
"stream_id": next_id,
"content": content_json,
}
)
self._update_max_stream_id(txn, next_id)
with (yield self._account_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction(
"add_user_account_data", add_account_data_txn, next_id
)
result = yield self._account_data_id_gen.get_max_token(self)
defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id):
"""Update the max stream_id
Args:
txn: The database cursor
next_id(int): The the revision to advance to.
"""
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
" SET stream_id = ?"
" WHERE stream_id < ?"
)
txn.execute(update_max_id_sql, (next_id, next_id))

View File

@ -51,6 +51,14 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
class EventsStore(SQLBaseStore): class EventsStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
def __init__(self, hs):
super(EventsStore, self).__init__(hs)
self.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
@defer.inlineCallbacks @defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False, def persist_events(self, events_and_contexts, backfilled=False,
is_new_state=True): is_new_state=True):
@ -365,6 +373,7 @@ class EventsStore(SQLBaseStore):
"processed": True, "processed": True,
"outlier": event.internal_metadata.is_outlier(), "outlier": event.internal_metadata.is_outlier(),
"content": encode_json(event.content).decode("UTF-8"), "content": encode_json(event.content).decode("UTF-8"),
"origin_server_ts": int(event.origin_server_ts),
} }
for event, _ in events_and_contexts for event, _ in events_and_contexts
], ],
@ -640,7 +649,7 @@ class EventsStore(SQLBaseStore):
] ]
rows = self._new_transaction( rows = self._new_transaction(
conn, "do_fetch", [], self._fetch_event_rows, event_ids conn, "do_fetch", [], None, self._fetch_event_rows, event_ids
) )
row_dict = { row_dict = {
@ -964,3 +973,71 @@ class EventsStore(SQLBaseStore):
ret = yield self.runInteraction("count_messages", _count_messages) ret = yield self.runInteraction("count_messages", _count_messages)
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
INSERT_CLUMP_SIZE = 1000
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
)
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
rows = txn.fetchall()
if not rows:
return 0
min_stream_id = rows[-1][0]
event_ids = [row[1] for row in rows]
events = self._get_events_txn(txn, event_ids)
rows = []
for event in events:
try:
event_id = event.event_id
origin_server_ts = event.origin_server_ts
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
continue
rows.append((origin_server_ts, event_id))
sql = (
"UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
)
for index in range(0, len(rows), INSERT_CLUMP_SIZE):
clump = rows[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(rows)
}
self._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
)
return len(rows)
result = yield self.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
defer.returnValue(result)

View File

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 25 SCHEMA_VERSION = 27
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@ -160,7 +160,7 @@ class RoomMemberStore(SQLBaseStore):
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id, def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
membership_list): membership_list):
where_clause = "user_id = ? AND (%s)" % ( where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
" OR ".join(["membership = ?" for _ in membership_list]), " OR ".join(["membership = ?" for _ in membership_list]),
) )
@ -269,3 +269,67 @@ class RoomMemberStore(SQLBaseStore):
ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0 ret = len(room_id_lists.pop(0).intersection(*room_id_lists)) > 0
defer.returnValue(ret) defer.returnValue(ret)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
sql = (
"UPDATE"
" room_memberships"
" SET"
" forgotten = 1"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
self.runInteraction("forget_membership", f)
@defer.inlineCallbacks
def did_forget(self, user_id, room_id):
"""Returns whether user_id has elected to discard history for room_id.
Returns False if they have since re-joined."""
def f(txn):
sql = (
"SELECT"
" COUNT(*)"
" FROM"
" room_memberships"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
" AND"
" forgotten = 0"
)
txn.execute(sql, (user_id, room_id))
rows = txn.fetchall()
return rows[0][0]
count = yield self.runInteraction("did_forget_membership", f)
defer.returnValue(count == 0)
@defer.inlineCallbacks
def was_forgotten_at(self, user_id, room_id, event_id):
"""Returns whether user_id has elected to discard history for room_id at event_id.
event_id must be a membership event."""
def f(txn):
sql = (
"SELECT"
" forgotten"
" FROM"
" room_memberships"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
" AND"
" event_id = ?"
)
txn.execute(sql, (user_id, room_id, event_id))
rows = txn.fetchall()
return rows[0][0]
forgot = yield self.runInteraction("did_forget_membership_at", f)
defer.returnValue(forgot == 1)

View File

@ -1,23 +1,22 @@
-- Drop, copy & recreate pushers table to change unique key -- Drop, copy & recreate pushers table to change unique key
-- Also add access_token column at the same time -- Also add access_token column at the same time
CREATE TABLE IF NOT EXISTS pushers2 ( CREATE TABLE IF NOT EXISTS pushers2 (
id INTEGER PRIMARY KEY AUTOINCREMENT, id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL, user_name TEXT NOT NULL,
access_token INTEGER DEFAULT NULL, access_token BIGINT DEFAULT NULL,
profile_tag varchar(32) NOT NULL, profile_tag VARCHAR(32) NOT NULL,
kind varchar(8) NOT NULL, kind VARCHAR(8) NOT NULL,
app_id varchar(64) NOT NULL, app_id VARCHAR(64) NOT NULL,
app_display_name varchar(64) NOT NULL, app_display_name VARCHAR(64) NOT NULL,
device_display_name varchar(128) NOT NULL, device_display_name VARCHAR(128) NOT NULL,
pushkey blob NOT NULL, pushkey bytea NOT NULL,
ts BIGINT NOT NULL, ts BIGINT NOT NULL,
lang varchar(8), lang VARCHAR(8),
data blob, data bytea,
last_token TEXT, last_token TEXT,
last_success BIGINT, last_success BIGINT,
failing_since BIGINT, failing_since BIGINT,
FOREIGN KEY(user_name) REFERENCES users(name), UNIQUE (app_id, pushkey)
UNIQUE (app_id, pushkey, user_name)
); );
INSERT INTO pushers2 (id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since) INSERT INTO pushers2 (id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since)
SELECT id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers; SELECT id, user_name, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_token, last_success, failing_since FROM pushers;

View File

@ -38,7 +38,7 @@ CREATE INDEX event_search_ev_ridx ON event_search(room_id);
SQLITE_TABLE = ( SQLITE_TABLE = (
"CREATE VIRTUAL TABLE IF NOT EXISTS event_search" "CREATE VIRTUAL TABLE event_search"
" USING fts4 ( event_id, room_id, sender, key, value )" " USING fts4 ( event_id, room_id, sender, key, value )"
) )

View File

@ -0,0 +1,17 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE private_user_data_max_stream_id RENAME TO account_data_max_stream_id;

View File

@ -0,0 +1,36 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS account_data(
user_id TEXT NOT NULL,
account_data_type TEXT NOT NULL, -- The type of the account_data.
stream_id BIGINT NOT NULL, -- The version of the account_data.
content TEXT NOT NULL, -- The JSON content of the account_data
CONSTRAINT account_data_uniqueness UNIQUE (user_id, account_data_type)
);
CREATE TABLE IF NOT EXISTS room_account_data(
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
account_data_type TEXT NOT NULL, -- The type of the account_data.
stream_id BIGINT NOT NULL, -- The version of the account_data.
content TEXT NOT NULL, -- The JSON content of the account_data
CONSTRAINT room_account_data_uniqueness UNIQUE (user_id, room_id, account_data_type)
);
CREATE INDEX account_data_stream_id on account_data(user_id, stream_id);
CREATE INDEX room_account_data_stream_id on room_account_data(user_id, stream_id);

View File

@ -0,0 +1,26 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Keeps track of what rooms users have left and don't want to be able to
* access again.
*
* If all users on this server have left a room, we can delete the room
* entirely.
*
* This column should always contain either 0 or 1.
*/
ALTER TABLE room_memberships ADD COLUMN forgotten INTEGER DEFAULT 0;

View File

@ -0,0 +1,57 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from synapse.storage.prepare_database import get_statements
import ujson
logger = logging.getLogger(__name__)
ALTER_TABLE = (
"ALTER TABLE events ADD COLUMN origin_server_ts BIGINT;"
"CREATE INDEX events_ts ON events(origin_server_ts, stream_ordering);"
)
def run_upgrade(cur, database_engine, *args, **kwargs):
for statement in get_statements(ALTER_TABLE.splitlines()):
cur.execute(statement)
cur.execute("SELECT MIN(stream_ordering) FROM events")
rows = cur.fetchall()
min_stream_id = rows[0][0]
cur.execute("SELECT MAX(stream_ordering) FROM events")
rows = cur.fetchall()
max_stream_id = rows[0][0]
if min_stream_id is not None and max_stream_id is not None:
progress = {
"target_min_stream_id_inclusive": min_stream_id,
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
progress_json = ujson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
" VALUES (?, ?)"
)
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_origin_server_ts", progress_json))

View File

@ -20,6 +20,7 @@ from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
import logging import logging
import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -139,7 +140,10 @@ class SearchStore(BackgroundUpdateStore):
list of dicts list of dicts
""" """
clauses = [] clauses = []
args = []
search_query = search_query = _parse_query(self.database_engine, search_term)
args = [search_query]
# Make sure we don't explode because the person is in too many rooms. # Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless. # We filter the results below regardless.
@ -161,7 +165,7 @@ class SearchStore(BackgroundUpdateStore):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
sql = ( sql = (
"SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id" "SELECT ts_rank_cd(vector, query) AS rank, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search" " FROM to_tsquery('english', ?) as query, event_search"
" WHERE vector @@ query" " WHERE vector @@ query"
) )
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
@ -182,7 +186,7 @@ class SearchStore(BackgroundUpdateStore):
sql += " ORDER BY rank DESC LIMIT 500" sql += " ORDER BY rank DESC LIMIT 500"
results = yield self._execute( results = yield self._execute(
"search_msgs", self.cursor_to_dict, sql, *([search_term] + args) "search_msgs", self.cursor_to_dict, sql, *args
) )
results = filter(lambda row: row["room_id"] in room_ids, results) results = filter(lambda row: row["room_id"] in room_ids, results)
@ -194,21 +198,28 @@ class SearchStore(BackgroundUpdateStore):
for ev in events for ev in events
} }
defer.returnValue([ highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_query, events)
defer.returnValue({
"results": [
{ {
"event": event_map[r["event_id"]], "event": event_map[r["event_id"]],
"rank": r["rank"], "rank": r["rank"],
} }
for r in results for r in results
if r["event_id"] in event_map if r["event_id"] in event_map
]) ],
"highlights": highlights,
})
@defer.inlineCallbacks @defer.inlineCallbacks
def search_room(self, room_id, search_term, keys, limit, pagination_token=None): def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
"""Performs a full text search over events with given keys. """Performs a full text search over events with given keys.
Args: Args:
room_id (str): The room_id to search in room_id (list): The room_ids to search in
search_term (str): Search term to search for search_term (str): Search term to search for
keys (list): List of keys to search in, currently supports keys (list): List of keys to search in, currently supports
"content.body", "content.name", "content.topic" "content.body", "content.name", "content.topic"
@ -218,7 +229,18 @@ class SearchStore(BackgroundUpdateStore):
list of dicts list of dicts
""" """
clauses = [] clauses = []
args = [search_term, room_id]
search_query = search_query = _parse_query(self.database_engine, search_term)
args = [search_query]
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
clauses.append(
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
)
args.extend(room_ids)
local_clauses = [] local_clauses = []
for key in keys: for key in keys:
@ -231,25 +253,25 @@ class SearchStore(BackgroundUpdateStore):
if pagination_token: if pagination_token:
try: try:
topo, stream = pagination_token.split(",") origin_server_ts, stream = pagination_token.split(",")
topo = int(topo) origin_server_ts = int(origin_server_ts)
stream = int(stream) stream = int(stream)
except: except:
raise SynapseError(400, "Invalid pagination token") raise SynapseError(400, "Invalid pagination token")
clauses.append( clauses.append(
"(topological_ordering < ?" "(origin_server_ts < ?"
" OR (topological_ordering = ? AND stream_ordering < ?))" " OR (origin_server_ts = ? AND stream_ordering < ?))"
) )
args.extend([topo, topo, stream]) args.extend([origin_server_ts, origin_server_ts, stream])
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
sql = ( sql = (
"SELECT ts_rank_cd(vector, query) as rank," "SELECT ts_rank_cd(vector, query) as rank,"
" topological_ordering, stream_ordering, room_id, event_id" " origin_server_ts, stream_ordering, room_id, event_id"
" FROM plainto_tsquery('english', ?) as query, event_search" " FROM to_tsquery('english', ?) as query, event_search"
" NATURAL JOIN events" " NATURAL JOIN events"
" WHERE vector @@ query AND room_id = ?" " WHERE vector @@ query AND "
) )
elif isinstance(self.database_engine, Sqlite3Engine): elif isinstance(self.database_engine, Sqlite3Engine):
# We use CROSS JOIN here to ensure we use the right indexes. # We use CROSS JOIN here to ensure we use the right indexes.
@ -262,24 +284,23 @@ class SearchStore(BackgroundUpdateStore):
# MATCH unless it uses the full text search index # MATCH unless it uses the full text search index
sql = ( sql = (
"SELECT rank(matchinfo) as rank, room_id, event_id," "SELECT rank(matchinfo) as rank, room_id, event_id,"
" topological_ordering, stream_ordering" " origin_server_ts, stream_ordering"
" FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo" " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
" FROM event_search" " FROM event_search"
" WHERE value MATCH ?" " WHERE value MATCH ?"
" )" " )"
" CROSS JOIN events USING (event_id)" " CROSS JOIN events USING (event_id)"
" WHERE room_id = ?" " WHERE "
) )
else: else:
# This should be unreachable. # This should be unreachable.
raise Exception("Unrecognized database engine") raise Exception("Unrecognized database engine")
for clause in clauses: sql += " AND ".join(clauses)
sql += " AND " + clause
# We add an arbitrary limit here to ensure we don't try to pull the # We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database. # entire table from the database.
sql += " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
args.append(limit) args.append(limit)
@ -287,6 +308,8 @@ class SearchStore(BackgroundUpdateStore):
"search_rooms", self.cursor_to_dict, sql, *args "search_rooms", self.cursor_to_dict, sql, *args
) )
results = filter(lambda row: row["room_id"] in room_ids, results)
events = yield self._get_events([r["event_id"] for r in results]) events = yield self._get_events([r["event_id"] for r in results])
event_map = { event_map = {
@ -294,14 +317,110 @@ class SearchStore(BackgroundUpdateStore):
for ev in events for ev in events
} }
defer.returnValue([ highlights = None
if isinstance(self.database_engine, PostgresEngine):
highlights = yield self._find_highlights_in_postgres(search_query, events)
defer.returnValue({
"results": [
{ {
"event": event_map[r["event_id"]], "event": event_map[r["event_id"]],
"rank": r["rank"], "rank": r["rank"],
"pagination_token": "%s,%s" % ( "pagination_token": "%s,%s" % (
r["topological_ordering"], r["stream_ordering"] r["origin_server_ts"], r["stream_ordering"]
), ),
} }
for r in results for r in results
if r["event_id"] in event_map if r["event_id"] in event_map
]) ],
"highlights": highlights,
})
def _find_highlights_in_postgres(self, search_query, events):
"""Given a list of events and a search term, return a list of words
that match from the content of the event.
This is used to give a list of words that clients can match against to
highlight the matching parts.
Args:
search_query (str)
events (list): A list of events
Returns:
deferred : A set of strings.
"""
def f(txn):
highlight_words = set()
for event in events:
# As a hack we simply join values of all possible keys. This is
# fine since we're only using them to find possible highlights.
values = []
for key in ("body", "name", "topic"):
v = event.content.get(key, None)
if v:
values.append(v)
if not values:
continue
value = " ".join(values)
# We need to find some values for StartSel and StopSel that
# aren't in the value so that we can pick results out.
start_sel = "<"
stop_sel = ">"
while start_sel in value:
start_sel += "<"
while stop_sel in value:
stop_sel += ">"
query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
_to_postgres_options({
"StartSel": start_sel,
"StopSel": stop_sel,
"MaxFragments": "50",
})
)
txn.execute(query, (value, search_query,))
headline, = txn.fetchall()[0]
# Now we need to pick the possible highlights out of the haedline
# result.
matcher_regex = "%s(.*?)%s" % (
re.escape(start_sel),
re.escape(stop_sel),
)
res = re.findall(matcher_regex, headline)
highlight_words.update([r.lower() for r in res])
return highlight_words
return self.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):
return "'%s'" % (
",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
)
def _parse_query(database_engine, search_term):
"""Takes a plain unicode string from the user and converts it into a form
that can be passed to database.
We use this so that we can add prefix matching, which isn't something
that is supported by default.
"""
# Pull out the individual words, discarding any non-word characters.
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
if isinstance(database_engine, PostgresEngine):
return " & ".join(result + ":*" for result in results)
elif isinstance(database_engine, Sqlite3Engine):
return " & ".join(result + "*" for result in results)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")

View File

@ -28,17 +28,17 @@ class TagsStore(SQLBaseStore):
def __init__(self, hs): def __init__(self, hs):
super(TagsStore, self).__init__(hs) super(TagsStore, self).__init__(hs)
self._private_user_data_id_gen = StreamIdGenerator( self._account_data_id_gen = StreamIdGenerator(
"private_user_data_max_stream_id", "stream_id" "account_data_max_stream_id", "stream_id"
) )
def get_max_private_user_data_stream_id(self): def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream """Get the current max stream id for the private user data stream
Returns: Returns:
A deferred int. A deferred int.
""" """
return self._private_user_data_id_gen.get_max_token(self) return self._account_data_id_gen.get_max_token(self)
@cached() @cached()
def get_tags_for_user(self, user_id): def get_tags_for_user(self, user_id):
@ -48,8 +48,8 @@ class TagsStore(SQLBaseStore):
Args: Args:
user_id(str): The user to get the tags for. user_id(str): The user to get the tags for.
Returns: Returns:
A deferred dict mapping from room_id strings to lists of tag A deferred dict mapping from room_id strings to dicts mapping from
strings. tag strings to tag content.
""" """
deferred = self._simple_select_list( deferred = self._simple_select_list(
@ -144,12 +144,12 @@ class TagsStore(SQLBaseStore):
) )
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._private_user_data_id_gen.get_next(self)) as next_id: with (yield self._account_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id) yield self.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self) result = yield self._account_data_id_gen.get_max_token(self)
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -166,12 +166,12 @@ class TagsStore(SQLBaseStore):
txn.execute(sql, (user_id, room_id, tag)) txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
with (yield self._private_user_data_id_gen.get_next(self)) as next_id: with (yield self._account_data_id_gen.get_next(self)) as next_id:
yield self.runInteraction("remove_tag", remove_tag_txn, next_id) yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,)) self.get_tags_for_user.invalidate((user_id,))
result = yield self._private_user_data_id_gen.get_max_token(self) result = yield self._account_data_id_gen.get_max_token(self)
defer.returnValue(result) defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id): def _update_revision_txn(self, txn, user_id, room_id, next_id):
@ -185,7 +185,7 @@ class TagsStore(SQLBaseStore):
""" """
update_max_id_sql = ( update_max_id_sql = (
"UPDATE private_user_data_max_stream_id" "UPDATE account_data_max_stream_id"
" SET stream_id = ?" " SET stream_id = ?"
" WHERE stream_id < ?" " WHERE stream_id < ?"
) )

View File

@ -21,7 +21,7 @@ from synapse.handlers.presence import PresenceEventSource
from synapse.handlers.room import RoomEventSource from synapse.handlers.room import RoomEventSource
from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.receipts import ReceiptEventSource
from synapse.handlers.private_user_data import PrivateUserDataEventSource from synapse.handlers.account_data import AccountDataEventSource
class EventSources(object): class EventSources(object):
@ -30,7 +30,7 @@ class EventSources(object):
"presence": PresenceEventSource, "presence": PresenceEventSource,
"typing": TypingNotificationEventSource, "typing": TypingNotificationEventSource,
"receipt": ReceiptEventSource, "receipt": ReceiptEventSource,
"private_user_data": PrivateUserDataEventSource, "account_data": AccountDataEventSource,
} }
def __init__(self, hs): def __init__(self, hs):
@ -54,8 +54,8 @@ class EventSources(object):
receipt_key=( receipt_key=(
yield self.sources["receipt"].get_current_key() yield self.sources["receipt"].get_current_key()
), ),
private_user_data_key=( account_data_key=(
yield self.sources["private_user_data"].get_current_key() yield self.sources["account_data"].get_current_key()
), ),
) )
defer.returnValue(token) defer.returnValue(token)

View File

@ -103,7 +103,7 @@ class StreamToken(
"presence_key", "presence_key",
"typing_key", "typing_key",
"receipt_key", "receipt_key",
"private_user_data_key", "account_data_key",
)) ))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -138,7 +138,7 @@ class StreamToken(
or (int(other.presence_key) < int(self.presence_key)) or (int(other.presence_key) < int(self.presence_key))
or (int(other.typing_key) < int(self.typing_key)) or (int(other.typing_key) < int(self.typing_key))
or (int(other.receipt_key) < int(self.receipt_key)) or (int(other.receipt_key) < int(self.receipt_key))
or (int(other.private_user_data_key) < int(self.private_user_data_key)) or (int(other.account_data_key) < int(self.account_data_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):

View File

@ -64,8 +64,7 @@ class Clock(object):
current_context = LoggingContext.current_context() current_context = LoggingContext.current_context()
def wrapped_callback(*args, **kwargs): def wrapped_callback(*args, **kwargs):
with PreserveLoggingContext(): with PreserveLoggingContext(current_context):
LoggingContext.thread_local.current_context = current_context
callback(*args, **kwargs) callback(*args, **kwargs)
with PreserveLoggingContext(): with PreserveLoggingContext():

View File

@ -30,8 +30,7 @@ def debug_deferreds():
context = LoggingContext.current_context() context = LoggingContext.current_context()
def restore_context_callback(x): def restore_context_callback(x):
with PreserveLoggingContext(): with PreserveLoggingContext(context):
LoggingContext.thread_local.current_context = context
return fn(x) return fn(x)
return restore_context_callback return restore_context_callback

View File

@ -19,6 +19,25 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try:
import resource
# Python doesn't ship with a definition of RUSAGE_THREAD but it's defined
# to be 1 on linux so we hard code it.
RUSAGE_THREAD = 1
# If the system doesn't support RUSAGE_THREAD then this should throw an
# exception.
resource.getrusage(RUSAGE_THREAD)
def get_thread_resource_usage():
return resource.getrusage(RUSAGE_THREAD)
except:
# If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
# won't track resource usage by returning None.
def get_thread_resource_usage():
return None
class LoggingContext(object): class LoggingContext(object):
"""Additional context for log formatting. Contexts are scoped within a """Additional context for log formatting. Contexts are scoped within a
@ -27,7 +46,9 @@ class LoggingContext(object):
name (str): Name for the context for debugging. name (str): Name for the context for debugging.
""" """
__slots__ = ["parent_context", "name", "__dict__"] __slots__ = [
"parent_context", "name", "usage_start", "usage_end", "main_thread", "__dict__"
]
thread_local = threading.local() thread_local = threading.local()
@ -42,11 +63,26 @@ class LoggingContext(object):
def copy_to(self, record): def copy_to(self, record):
pass pass
def start(self):
pass
def stop(self):
pass
def add_database_transaction(self, duration_ms):
pass
sentinel = Sentinel() sentinel = Sentinel()
def __init__(self, name=None): def __init__(self, name=None):
self.parent_context = None self.parent_context = None
self.name = name self.name = name
self.ru_stime = 0.
self.ru_utime = 0.
self.db_txn_count = 0
self.db_txn_duration = 0.
self.usage_start = None
self.main_thread = threading.current_thread()
def __str__(self): def __str__(self):
return "%s@%x" % (self.name, id(self)) return "%s@%x" % (self.name, id(self))
@ -56,12 +92,26 @@ class LoggingContext(object):
"""Get the current logging context from thread local storage""" """Get the current logging context from thread local storage"""
return getattr(cls.thread_local, "current_context", cls.sentinel) return getattr(cls.thread_local, "current_context", cls.sentinel)
@classmethod
def set_current_context(cls, context):
"""Set the current logging context in thread local storage
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
current = cls.current_context()
if current is not context:
current.stop()
cls.thread_local.current_context = context
context.start()
return current
def __enter__(self): def __enter__(self):
"""Enters this logging context into thread local storage""" """Enters this logging context into thread local storage"""
if self.parent_context is not None: if self.parent_context is not None:
raise Exception("Attempt to enter logging context multiple times") raise Exception("Attempt to enter logging context multiple times")
self.parent_context = self.current_context() self.parent_context = self.set_current_context(self)
self.thread_local.current_context = self
return self return self
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
@ -70,16 +120,16 @@ class LoggingContext(object):
Returns: Returns:
None to avoid suppressing any exeptions that were thrown. None to avoid suppressing any exeptions that were thrown.
""" """
if self.thread_local.current_context is not self: current = self.set_current_context(self.parent_context)
if self.thread_local.current_context is self.sentinel: if current is not self:
if current is self.sentinel:
logger.debug("Expected logging context %s has been lost", self) logger.debug("Expected logging context %s has been lost", self)
else: else:
logger.warn( logger.warn(
"Current logging context %s is not expected context %s", "Current logging context %s is not expected context %s",
self.thread_local.current_context, current,
self self
) )
self.thread_local.current_context = self.parent_context
self.parent_context = None self.parent_context = None
def __getattr__(self, name): def __getattr__(self, name):
@ -93,6 +143,43 @@ class LoggingContext(object):
for key, value in self.__dict__.items(): for key, value in self.__dict__.items():
setattr(record, key, value) setattr(record, key, value)
record.ru_utime, record.ru_stime = self.get_resource_usage()
def start(self):
if threading.current_thread() is not self.main_thread:
return
if self.usage_start and self.usage_end:
self.ru_utime += self.usage_end.ru_utime - self.usage_start.ru_utime
self.ru_stime += self.usage_end.ru_stime - self.usage_start.ru_stime
self.usage_start = None
self.usage_end = None
if not self.usage_start:
self.usage_start = get_thread_resource_usage()
def stop(self):
if threading.current_thread() is not self.main_thread:
return
if self.usage_start:
self.usage_end = get_thread_resource_usage()
def get_resource_usage(self):
ru_utime = self.ru_utime
ru_stime = self.ru_stime
if self.usage_start and threading.current_thread() is self.main_thread:
current = get_thread_resource_usage()
ru_utime += current.ru_utime - self.usage_start.ru_utime
ru_stime += current.ru_stime - self.usage_start.ru_stime
return ru_utime, ru_stime
def add_database_transaction(self, duration_ms):
self.db_txn_count += 1
self.db_txn_duration += duration_ms / 1000.
class LoggingContextFilter(logging.Filter): class LoggingContextFilter(logging.Filter):
"""Logging filter that adds values from the current logging context to each """Logging filter that adds values from the current logging context to each
@ -121,17 +208,20 @@ class PreserveLoggingContext(object):
exited. Used to restore the context after a function using exited. Used to restore the context after a function using
@defer.inlineCallbacks is resumed by a callback from the reactor.""" @defer.inlineCallbacks is resumed by a callback from the reactor."""
__slots__ = ["current_context"] __slots__ = ["current_context", "new_context"]
def __init__(self, new_context=LoggingContext.sentinel):
self.new_context = new_context
def __enter__(self): def __enter__(self):
"""Captures the current logging context""" """Captures the current logging context"""
self.current_context = LoggingContext.current_context() self.current_context = LoggingContext.set_current_context(
LoggingContext.thread_local.current_context = LoggingContext.sentinel self.new_context
)
def __exit__(self, type, value, traceback): def __exit__(self, type, value, traceback):
"""Restores the current logging context""" """Restores the current logging context"""
LoggingContext.thread_local.current_context = self.current_context LoggingContext.set_current_context(self.current_context)
if self.current_context is not LoggingContext.sentinel: if self.current_context is not LoggingContext.sentinel:
if self.current_context.parent_context is None: if self.current_context.parent_context is None:
logger.warn( logger.warn(
@ -164,8 +254,7 @@ class _PreservingContextDeferred(defer.Deferred):
def _wrap_callback(self, f): def _wrap_callback(self, f):
def g(res, *args, **kwargs): def g(res, *args, **kwargs):
with PreserveLoggingContext(): with PreserveLoggingContext(self._log_context):
LoggingContext.thread_local.current_context = self._log_context
res = f(res, *args, **kwargs) res = f(res, *args, **kwargs)
return res return res
return g return g

View File

@ -197,6 +197,7 @@ class FederationTestCase(unittest.TestCase):
'pdu_failures': [], 'pdu_failures': [],
}, },
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks
@ -228,6 +229,7 @@ class FederationTestCase(unittest.TestCase):
'pdu_failures': [], 'pdu_failures': [],
}, },
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
) )
@defer.inlineCallbacks @defer.inlineCallbacks

View File

@ -365,7 +365,7 @@ class PresenceInvitesTestCase(PresenceTestCase):
# TODO(paul): This test will likely break if/when real auth permissions # TODO(paul): This test will likely break if/when real auth permissions
# are added; for now the HS will always accept any invite # are added; for now the HS will always accept any invite
yield self.handler.send_invite( yield self.handler.send_presence_invite(
observer_user=self.u_apple, observed_user=self.u_banana) observer_user=self.u_apple, observed_user=self.u_banana)
self.assertEquals( self.assertEquals(
@ -384,7 +384,7 @@ class PresenceInvitesTestCase(PresenceTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_invite_local_nonexistant(self): def test_invite_local_nonexistant(self):
yield self.handler.send_invite( yield self.handler.send_presence_invite(
observer_user=self.u_apple, observed_user=self.u_durian) observer_user=self.u_apple, observed_user=self.u_durian)
self.assertEquals( self.assertEquals(
@ -409,11 +409,12 @@ class PresenceInvitesTestCase(PresenceTestCase):
} }
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
yield self.handler.send_invite( yield self.handler.send_presence_invite(
observer_user=self.u_apple, observed_user=u_rocket) observer_user=self.u_apple, observed_user=u_rocket)
self.assertEquals( self.assertEquals(
@ -443,6 +444,7 @@ class PresenceInvitesTestCase(PresenceTestCase):
} }
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -483,6 +485,7 @@ class PresenceInvitesTestCase(PresenceTestCase):
} }
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -827,6 +830,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
} }
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -843,6 +847,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
} }
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1033,6 +1038,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
} }
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1048,6 +1054,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
} }
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1078,6 +1085,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
} }
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1184,6 +1192,7 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
}, },
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1200,6 +1209,7 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
}, },
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1232,6 +1242,7 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
}, },
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1265,6 +1276,7 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
}, },
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -1297,6 +1309,7 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
}, },
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )

View File

@ -218,6 +218,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
} }
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )
@ -284,6 +285,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
} }
), ),
json_data_callback=ANY, json_data_callback=ANY,
long_retries=True,
), ),
defer.succeed((200, "OK")) defer.succeed((200, "OK"))
) )

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import uuid import uuid
from mock.mock import Mock from mock import Mock
from synapse.types import RoomID, UserID from synapse.types import RoomID, UserID
from tests import unittest from tests import unittest

View File

@ -168,7 +168,8 @@ class MockHttpResource(HttpServer):
raise KeyError("No event can handle %s" % path) raise KeyError("No event can handle %s" % path)
def register_path(self, method, path_pattern, callback): def register_paths(self, method, path_patterns, callback):
for path_pattern in path_patterns:
self.callbacks.append((method, path_pattern, callback)) self.callbacks.append((method, path_pattern, callback))

View File

@ -6,11 +6,13 @@ deps =
coverage coverage
Twisted>=15.1 Twisted>=15.1
mock mock
python-subunit
junitxml
setenv = setenv =
PYTHONDONTWRITEBYTECODE = no_byte_code PYTHONDONTWRITEBYTECODE = no_byte_code
commands = commands =
coverage run --source=synapse {envbindir}/trial {posargs:tests} /bin/bash -c "coverage run --source=synapse {envbindir}/trial {env:TRIAL_FLAGS:} {posargs:tests} {env:TOXSUFFIX:}"
coverage report -m {env:DUMP_COVERAGE_COMMAND:coverage report -m}
[testenv:packaging] [testenv:packaging]
deps = deps =
@ -23,4 +25,4 @@ skip_install = True
basepython = python2.7 basepython = python2.7
deps = deps =
flake8 flake8
commands = flake8 synapse commands = /bin/bash -c "flake8 synapse {env:PEP8SUFFIX:}"